
update black, apply latest lint

Sean Sube 2024-02-25 08:23:42 -06:00
parent 9d87e92a1c
commit fd28c095fd
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
5 changed files with 38 additions and 36 deletions
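
Most of the hunks below apply black 24's stable style, which prefers splitting an over-long assignment on its right-hand side: rather than breaking inside a long subscript target, the target is kept on one line and the value is wrapped in parentheses. A minimal, runnable sketch of the before/after, using hypothetical names (source, target) rather than the repository's own:

# hypothetical dicts, only to make the sketch executable
i = 0
source = {f"input_blocks.{i}.0.op.weight": "w"}
target = {}

# black <= 23.x split the long subscript target across lines:
# target[
#     f"down_blocks.{i}.downsamplers.0.conv.weight"
# ] = source.pop(f"input_blocks.{i}.0.op.weight")

# black >= 24.1 keeps the target whole and parenthesizes the value:
target[f"down_blocks.{i}.downsamplers.0.conv.weight"] = (
    source.pop(f"input_blocks.{i}.0.op.weight")
)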

View File

@@ -612,12 +612,12 @@ def convert_ldm_unet_checkpoint(checkpoint, config, path=None, extract_ema=False
         attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key]
 
         if f"input_blocks.{i}.0.op.weight" in unet_state_dict:
-            new_checkpoint[
-                f"down_blocks.{block_id}.downsamplers.0.conv.weight"
-            ] = unet_state_dict.pop(f"input_blocks.{i}.0.op.weight")
-            new_checkpoint[
-                f"down_blocks.{block_id}.downsamplers.0.conv.bias"
-            ] = unet_state_dict.pop(f"input_blocks.{i}.0.op.bias")
+            new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = (
+                unet_state_dict.pop(f"input_blocks.{i}.0.op.weight")
+            )
+            new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = (
+                unet_state_dict.pop(f"input_blocks.{i}.0.op.bias")
+            )
 
         paths = renew_resnet_paths(resnets)
         meta_path = {
@@ -705,12 +705,12 @@ def convert_ldm_unet_checkpoint(checkpoint, config, path=None, extract_ema=False
             index = list(output_block_list.values()).index(
                 ["conv.bias", "conv.weight"]
             )
-            new_checkpoint[
-                f"up_blocks.{block_id}.upsamplers.0.conv.weight"
-            ] = unet_state_dict[f"output_blocks.{i}.{index}.conv.weight"]
-            new_checkpoint[
-                f"up_blocks.{block_id}.upsamplers.0.conv.bias"
-            ] = unet_state_dict[f"output_blocks.{i}.{index}.conv.bias"]
+            new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = (
+                unet_state_dict[f"output_blocks.{i}.{index}.conv.weight"]
+            )
+            new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = (
+                unet_state_dict[f"output_blocks.{i}.{index}.conv.bias"]
+            )
 
             # Clear attentions as they have been attributed above.
             if len(attentions) == 2:
@@ -818,12 +818,12 @@ def convert_ldm_vae_checkpoint(checkpoint, config, first_stage=True):
         ]
 
         if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict:
-            new_checkpoint[
-                f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"
-            ] = vae_state_dict.pop(f"encoder.down.{i}.downsample.conv.weight")
-            new_checkpoint[
-                f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"
-            ] = vae_state_dict.pop(f"encoder.down.{i}.downsample.conv.bias")
+            new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = (
+                vae_state_dict.pop(f"encoder.down.{i}.downsample.conv.weight")
+            )
+            new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = (
+                vae_state_dict.pop(f"encoder.down.{i}.downsample.conv.bias")
+            )
 
         paths = renew_vae_resnet_paths(resnets)
         meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"}
@@ -871,12 +871,12 @@ def convert_ldm_vae_checkpoint(checkpoint, config, first_stage=True):
         ]
 
         if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict:
-            new_checkpoint[
-                f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"
-            ] = vae_state_dict[f"decoder.up.{block_id}.upsample.conv.weight"]
-            new_checkpoint[
-                f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"
-            ] = vae_state_dict[f"decoder.up.{block_id}.upsample.conv.bias"]
+            new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = (
+                vae_state_dict[f"decoder.up.{block_id}.upsample.conv.weight"]
+            )
+            new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = (
+                vae_state_dict[f"decoder.up.{block_id}.upsample.conv.bias"]
+            )
 
         paths = renew_vae_resnet_paths(resnets)
         meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"}
@@ -983,9 +983,9 @@ def convert_ldm_clip_checkpoint(checkpoint):
                     "text_model." + key[len("cond_stage_model.transformer.") :]
                 ] = checkpoint[key]
             else:
-                text_model_dict[
-                    key[len("cond_stage_model.transformer.") :]
-                ] = checkpoint[key]
+                text_model_dict[key[len("cond_stage_model.transformer.") :]] = (
+                    checkpoint[key]
+                )
 
     text_model.load_state_dict(text_model_dict, strict=False)
@@ -1109,9 +1109,9 @@ def convert_open_clip_checkpoint(checkpoint):
     else:
         logger.debug("no projection shape found, setting to 1024")
         d_model = 1024
-    text_model_dict[
-        "text_model.embeddings.position_ids"
-    ] = text_model.text_model.embeddings.get_buffer("position_ids")
+    text_model_dict["text_model.embeddings.position_ids"] = (
+        text_model.text_model.embeddings.get_buffer("position_ids")
+    )
 
     for key in keys:
         if (

View File

@@ -465,6 +465,7 @@ class OnnxStableDiffusionLongPromptWeightingPipeline(OnnxStableDiffusionPipeline
     This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
     library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
     """
+
     if version.parse(
         version.parse(diffusers.__version__).base_version
     ) >= version.parse("0.9.0"):

View File

@@ -88,6 +88,7 @@ class OnnxStableDiffusionInstructPix2PixPipeline(DiffusionPipeline):
         feature_extractor ([`CLIPFeatureExtractor`]):
             Model that extracts features from generated images to be used as inputs for the `safety_checker`.
     """
+
    vae_encoder: OnnxRuntimeModel
    vae_decoder: OnnxRuntimeModel
    text_encoder: OnnxRuntimeModel
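
This hunk and the one in the previous file each add only a blank line: black 24's stable style appears to enforce an empty line between a class docstring and the first statement of the class body. A minimal sketch with a hypothetical class:

class Example:
    """Hypothetical class illustrating the docstring spacing rule."""

    # black >= 24.1 inserts the blank line above, between the
    # docstring and the first attribute or method
    field: int = 0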

View File

@@ -495,9 +495,9 @@ class BasicLayer(nn.Module):
                     qk_scale=qk_scale,
                     drop=drop,
                     attn_drop=attn_drop,
-                    drop_path=drop_path[i]
-                    if isinstance(drop_path, list)
-                    else drop_path,
+                    drop_path=(
+                        drop_path[i] if isinstance(drop_path, list) else drop_path
+                    ),
                     norm_layer=norm_layer,
                 )
                 for i in range(depth)
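
The hunk above reflects another black 24 stable-style rule: a conditional expression that has to be split is now wrapped in parentheses instead of hanging its if/else clauses on separate lines. A small runnable sketch with hypothetical names (resolve_drop_path is illustrative, not part of the repository):

from typing import List, Union

def resolve_drop_path(drop_path: Union[float, List[float]], i: int) -> float:
    # when the expression exceeds the line length, black >= 24.1 formats
    # the split ternary as a parenthesized group (expanded here to show
    # the shape; black would collapse a line this short)
    return (
        drop_path[i] if isinstance(drop_path, list) else drop_path
    )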

View File

@@ -244,9 +244,9 @@ def load_extras(server: ServerContext):
                             inversion_name,
                             model_name,
                         )
-                        labels[
-                            f"inversion.{inversion_name}"
-                        ] = inversion["label"]
+                        labels[f"inversion.{inversion_name}"] = (
+                            inversion["label"]
+                        )
 
                 if "loras" in model:
                     for lora in model["loras"]: