From 44a23a6366a1f60537737bc8535259acf541dce8 Mon Sep 17 00:00:00 2001
From: Sean Sube
Date: Fri, 24 Nov 2023 22:40:01 -0600
Subject: [PATCH] apply lint

---
 api/onnx_web/convert/diffusion/lora.py | 17 +++++------------
 api/onnx_web/diffusers/load.py         | 12 ++++--------
 2 files changed, 9 insertions(+), 20 deletions(-)

diff --git a/api/onnx_web/convert/diffusion/lora.py b/api/onnx_web/convert/diffusion/lora.py
index 6e8ecc6e..f115942f 100644
--- a/api/onnx_web/convert/diffusion/lora.py
+++ b/api/onnx_web/convert/diffusion/lora.py
@@ -130,19 +130,11 @@ def fix_xl_names(keys: Dict[str, Any], nodes: List[NodeProto]) -> Dict[str, Any]
         logger.trace("searching for XL node: %s -> /%s/*/%s", root, block, suffix)
         match: Optional[str] = None
         if "conv" in suffix:
-            match = next(
-                node for node in names if node == f"{root}_Conv"
-            )
+            match = next(node for node in names if node == f"{root}_Conv")
         elif "time_emb_proj" in root:
-            match = next(
-                node for node in names if node == f"{root}_Gemm"
-            )
+            match = next(node for node in names if node == f"{root}_Gemm")
         elif block == "text_model" or simple:
-            match = next(
-                node
-                for node in names
-                if node == f"{root}_MatMul"
-            )
+            match = next(node for node in names if node == f"{root}_MatMul")
         else:
             # search in order. one side has sparse indices, so they will not match.
             match = next(
@@ -174,8 +166,9 @@ def fix_xl_names(keys: Dict[str, Any], nodes: List[NodeProto]) -> Dict[str, Any]
             names.remove(match)

     logger.debug(
-        "SDXL LoRA key fixup matched %s keys, %s remaining",
+        "SDXL LoRA key fixup matched %s of %s keys, %s nodes remaining",
         len(fixed.keys()),
+        len(keys.keys()),
         len(names),
     )

diff --git a/api/onnx_web/diffusers/load.py b/api/onnx_web/diffusers/load.py
index 248f8887..e7394dc7 100644
--- a/api/onnx_web/diffusers/load.py
+++ b/api/onnx_web/diffusers/load.py
@@ -7,12 +7,6 @@ from optimum.onnxruntime import (  # ORTStableDiffusionXLInpaintPipeline,
     ORTStableDiffusionXLImg2ImgPipeline,
     ORTStableDiffusionXLPipeline,
 )
-from optimum.onnxruntime.modeling_diffusion import (
-    ORTModelTextEncoder,
-    ORTModelUnet,
-    ORTModelVaeDecoder,
-    ORTModelVaeEncoder,
-)
 from transformers import CLIPTokenizer

 from ..constants import ONNX_MODEL
@@ -224,7 +218,7 @@ def load_pipeline(
             components["vae_decoder_session"],
             components["text_encoder_session"],
             components["unet_session"],
-            {}, # empty config
+            {},  # empty config
             components["tokenizer"],
             scheduler,
             vae_encoder_session=components.get("vae_encoder_session", None),
@@ -232,7 +226,9 @@
             tokenizer_2=components.get("tokenizer_2", None),
         )
     else:
-        logger.debug("loading pretrained SD pipeline for %s", pipeline_class.__name__)
+        logger.debug(
+            "loading pretrained SD pipeline for %s", pipeline_class.__name__
+        )
         pipe = pipeline_class.from_pretrained(
             model,
             provider=device.ort_provider(),