apply lint
This commit is contained in:
parent
3f3811e16a
commit
44a23a6366
|
@ -130,19 +130,11 @@ def fix_xl_names(keys: Dict[str, Any], nodes: List[NodeProto]) -> Dict[str, Any]
|
|||
logger.trace("searching for XL node: %s -> /%s/*/%s", root, block, suffix)
|
||||
match: Optional[str] = None
|
||||
if "conv" in suffix:
|
||||
match = next(
|
||||
node for node in names if node == f"{root}_Conv"
|
||||
)
|
||||
match = next(node for node in names if node == f"{root}_Conv")
|
||||
elif "time_emb_proj" in root:
|
||||
match = next(
|
||||
node for node in names if node == f"{root}_Gemm"
|
||||
)
|
||||
match = next(node for node in names if node == f"{root}_Gemm")
|
||||
elif block == "text_model" or simple:
|
||||
match = next(
|
||||
node
|
||||
for node in names
|
||||
if node == f"{root}_MatMul"
|
||||
)
|
||||
match = next(node for node in names if node == f"{root}_MatMul")
|
||||
else:
|
||||
# search in order. one side has sparse indices, so they will not match.
|
||||
match = next(
|
||||
|
@ -174,8 +166,9 @@ def fix_xl_names(keys: Dict[str, Any], nodes: List[NodeProto]) -> Dict[str, Any]
|
|||
names.remove(match)
|
||||
|
||||
logger.debug(
|
||||
"SDXL LoRA key fixup matched %s keys, %s remaining",
|
||||
"SDXL LoRA key fixup matched %s of %s keys, %s nodes remaining",
|
||||
len(fixed.keys()),
|
||||
len(keys.keys()),
|
||||
len(names),
|
||||
)
|
||||
|
||||
|
|
|
@ -7,12 +7,6 @@ from optimum.onnxruntime import ( # ORTStableDiffusionXLInpaintPipeline,
|
|||
ORTStableDiffusionXLImg2ImgPipeline,
|
||||
ORTStableDiffusionXLPipeline,
|
||||
)
|
||||
from optimum.onnxruntime.modeling_diffusion import (
|
||||
ORTModelTextEncoder,
|
||||
ORTModelUnet,
|
||||
ORTModelVaeDecoder,
|
||||
ORTModelVaeEncoder,
|
||||
)
|
||||
from transformers import CLIPTokenizer
|
||||
|
||||
from ..constants import ONNX_MODEL
|
||||
|
@ -232,7 +226,9 @@ def load_pipeline(
|
|||
tokenizer_2=components.get("tokenizer_2", None),
|
||||
)
|
||||
else:
|
||||
logger.debug("loading pretrained SD pipeline for %s", pipeline_class.__name__)
|
||||
logger.debug(
|
||||
"loading pretrained SD pipeline for %s", pipeline_class.__name__
|
||||
)
|
||||
pipe = pipeline_class.from_pretrained(
|
||||
model,
|
||||
provider=device.ort_provider(),
|
||||
|
|
Loading…
Reference in New Issue