1
0
Fork 0

apply lint

This commit is contained in:
Sean Sube 2023-11-24 22:40:01 -06:00
parent 3f3811e16a
commit 44a23a6366
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
2 changed files with 9 additions and 20 deletions

View File

@ -130,19 +130,11 @@ def fix_xl_names(keys: Dict[str, Any], nodes: List[NodeProto]) -> Dict[str, Any]
logger.trace("searching for XL node: %s -> /%s/*/%s", root, block, suffix)
match: Optional[str] = None
if "conv" in suffix:
match = next(
node for node in names if node == f"{root}_Conv"
)
match = next(node for node in names if node == f"{root}_Conv")
elif "time_emb_proj" in root:
match = next(
node for node in names if node == f"{root}_Gemm"
)
match = next(node for node in names if node == f"{root}_Gemm")
elif block == "text_model" or simple:
match = next(
node
for node in names
if node == f"{root}_MatMul"
)
match = next(node for node in names if node == f"{root}_MatMul")
else:
# search in order. one side has sparse indices, so they will not match.
match = next(
@ -174,8 +166,9 @@ def fix_xl_names(keys: Dict[str, Any], nodes: List[NodeProto]) -> Dict[str, Any]
names.remove(match)
logger.debug(
"SDXL LoRA key fixup matched %s keys, %s remaining",
"SDXL LoRA key fixup matched %s of %s keys, %s nodes remaining",
len(fixed.keys()),
len(keys.keys()),
len(names),
)

View File

@ -7,12 +7,6 @@ from optimum.onnxruntime import ( # ORTStableDiffusionXLInpaintPipeline,
ORTStableDiffusionXLImg2ImgPipeline,
ORTStableDiffusionXLPipeline,
)
from optimum.onnxruntime.modeling_diffusion import (
ORTModelTextEncoder,
ORTModelUnet,
ORTModelVaeDecoder,
ORTModelVaeEncoder,
)
from transformers import CLIPTokenizer
from ..constants import ONNX_MODEL
@ -224,7 +218,7 @@ def load_pipeline(
components["vae_decoder_session"],
components["text_encoder_session"],
components["unet_session"],
{}, # empty config
{}, # empty config
components["tokenizer"],
scheduler,
vae_encoder_session=components.get("vae_encoder_session", None),
@ -232,7 +226,9 @@ def load_pipeline(
tokenizer_2=components.get("tokenizer_2", None),
)
else:
logger.debug("loading pretrained SD pipeline for %s", pipeline_class.__name__)
logger.debug(
"loading pretrained SD pipeline for %s", pipeline_class.__name__
)
pipe = pipeline_class.from_pretrained(
model,
provider=device.ort_provider(),