diff --git a/api/onnx_web/convert/diffusion/checkpoint.py b/api/onnx_web/convert/diffusion/checkpoint.py
index 0362494a..5d72baa8 100644
--- a/api/onnx_web/convert/diffusion/checkpoint.py
+++ b/api/onnx_web/convert/diffusion/checkpoint.py
@@ -612,12 +612,12 @@ def convert_ldm_unet_checkpoint(checkpoint, config, path=None, extract_ema=False
         attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key]
 
         if f"input_blocks.{i}.0.op.weight" in unet_state_dict:
-            new_checkpoint[
-                f"down_blocks.{block_id}.downsamplers.0.conv.weight"
-            ] = unet_state_dict.pop(f"input_blocks.{i}.0.op.weight")
-            new_checkpoint[
-                f"down_blocks.{block_id}.downsamplers.0.conv.bias"
-            ] = unet_state_dict.pop(f"input_blocks.{i}.0.op.bias")
+            new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = (
+                unet_state_dict.pop(f"input_blocks.{i}.0.op.weight")
+            )
+            new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = (
+                unet_state_dict.pop(f"input_blocks.{i}.0.op.bias")
+            )
 
         paths = renew_resnet_paths(resnets)
         meta_path = {
@@ -705,12 +705,12 @@ def convert_ldm_unet_checkpoint(checkpoint, config, path=None, extract_ema=False
                 index = list(output_block_list.values()).index(
                     ["conv.bias", "conv.weight"]
                 )
-                new_checkpoint[
-                    f"up_blocks.{block_id}.upsamplers.0.conv.weight"
-                ] = unet_state_dict[f"output_blocks.{i}.{index}.conv.weight"]
-                new_checkpoint[
-                    f"up_blocks.{block_id}.upsamplers.0.conv.bias"
-                ] = unet_state_dict[f"output_blocks.{i}.{index}.conv.bias"]
+                new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = (
+                    unet_state_dict[f"output_blocks.{i}.{index}.conv.weight"]
+                )
+                new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = (
+                    unet_state_dict[f"output_blocks.{i}.{index}.conv.bias"]
+                )
 
                 # Clear attentions as they have been attributed above.
                 if len(attentions) == 2:
@@ -818,12 +818,12 @@ def convert_ldm_vae_checkpoint(checkpoint, config, first_stage=True):
         ]
 
         if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict:
-            new_checkpoint[
-                f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"
-            ] = vae_state_dict.pop(f"encoder.down.{i}.downsample.conv.weight")
-            new_checkpoint[
-                f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"
-            ] = vae_state_dict.pop(f"encoder.down.{i}.downsample.conv.bias")
+            new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = (
+                vae_state_dict.pop(f"encoder.down.{i}.downsample.conv.weight")
+            )
+            new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = (
+                vae_state_dict.pop(f"encoder.down.{i}.downsample.conv.bias")
+            )
 
         paths = renew_vae_resnet_paths(resnets)
         meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"}
@@ -871,12 +871,12 @@ def convert_ldm_vae_checkpoint(checkpoint, config, first_stage=True):
         ]
 
         if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict:
-            new_checkpoint[
-                f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"
-            ] = vae_state_dict[f"decoder.up.{block_id}.upsample.conv.weight"]
-            new_checkpoint[
-                f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"
-            ] = vae_state_dict[f"decoder.up.{block_id}.upsample.conv.bias"]
+            new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = (
+                vae_state_dict[f"decoder.up.{block_id}.upsample.conv.weight"]
+            )
+            new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = (
+                vae_state_dict[f"decoder.up.{block_id}.upsample.conv.bias"]
+            )
 
         paths = renew_vae_resnet_paths(resnets)
         meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"}
@@ -983,9 +983,9 @@ def convert_ldm_clip_checkpoint(checkpoint):
                     "text_model." + key[len("cond_stage_model.transformer.") :]
                 ] = checkpoint[key]
             else:
-                text_model_dict[
-                    key[len("cond_stage_model.transformer.") :]
-                ] = checkpoint[key]
+                text_model_dict[key[len("cond_stage_model.transformer.") :]] = (
+                    checkpoint[key]
+                )
 
     text_model.load_state_dict(text_model_dict, strict=False)
 
@@ -1109,9 +1109,9 @@ def convert_open_clip_checkpoint(checkpoint):
     else:
         logger.debug("no projection shape found, setting to 1024")
         d_model = 1024
-    text_model_dict[
-        "text_model.embeddings.position_ids"
-    ] = text_model.text_model.embeddings.get_buffer("position_ids")
+    text_model_dict["text_model.embeddings.position_ids"] = (
+        text_model.text_model.embeddings.get_buffer("position_ids")
+    )
 
     for key in keys:
         if (
diff --git a/api/onnx_web/diffusers/pipelines/lpw.py b/api/onnx_web/diffusers/pipelines/lpw.py
index 8febc277..f8b1a4cf 100644
--- a/api/onnx_web/diffusers/pipelines/lpw.py
+++ b/api/onnx_web/diffusers/pipelines/lpw.py
@@ -465,6 +465,7 @@ class OnnxStableDiffusionLongPromptWeightingPipeline(OnnxStableDiffusionPipeline
     This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
     library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
     """
+
     if version.parse(
         version.parse(diffusers.__version__).base_version
     ) >= version.parse("0.9.0"):
diff --git a/api/onnx_web/diffusers/pipelines/pix2pix.py b/api/onnx_web/diffusers/pipelines/pix2pix.py
index 2689fa4f..d257315f 100644
--- a/api/onnx_web/diffusers/pipelines/pix2pix.py
+++ b/api/onnx_web/diffusers/pipelines/pix2pix.py
@@ -88,6 +88,7 @@ class OnnxStableDiffusionInstructPix2PixPipeline(DiffusionPipeline):
         feature_extractor ([`CLIPFeatureExtractor`]):
             Model that extracts features from generated images to be used as inputs for the `safety_checker`.
     """
+
     vae_encoder: OnnxRuntimeModel
     vae_decoder: OnnxRuntimeModel
     text_encoder: OnnxRuntimeModel
diff --git a/api/onnx_web/models/swinir.py b/api/onnx_web/models/swinir.py
index 81c96931..ff964993 100644
--- a/api/onnx_web/models/swinir.py
+++ b/api/onnx_web/models/swinir.py
@@ -495,9 +495,9 @@ class BasicLayer(nn.Module):
                     qk_scale=qk_scale,
                     drop=drop,
                     attn_drop=attn_drop,
-                    drop_path=drop_path[i]
-                    if isinstance(drop_path, list)
-                    else drop_path,
+                    drop_path=(
+                        drop_path[i] if isinstance(drop_path, list) else drop_path
+                    ),
                     norm_layer=norm_layer,
                 )
                 for i in range(depth)
diff --git a/api/onnx_web/server/load.py b/api/onnx_web/server/load.py
index 16bbfcc8..28f79a86 100644
--- a/api/onnx_web/server/load.py
+++ b/api/onnx_web/server/load.py
@@ -244,9 +244,9 @@ def load_extras(server: ServerContext):
                                     inversion_name,
                                     model_name,
                                 )
-                                labels[
-                                    f"inversion.{inversion_name}"
-                                ] = inversion["label"]
+                                labels[f"inversion.{inversion_name}"] = (
+                                    inversion["label"]
+                                )
 
                             if "loras" in model:
                                 for lora in model["loras"]: