
fix(api): add model image size and version hint to extras file

commit bc71583393
parent 2690eafe09
Author: Sean Sube, 2023-04-29 22:56:52 -05:00
Signed by: ssube (GPG Key ID: 3EED7B957D362AF1)
3 changed files with 59 additions and 36 deletions


@@ -52,42 +52,51 @@ available_pipelines = {
 def get_model_version(
-    checkpoint,
-    size=None,
+    source,
+    map_location,
+    size=None,
+    version=None,
 ) -> Tuple[bool, Dict[str, Union[bool, int, str]]]:
-    if "global_step" in checkpoint:
-        global_step = checkpoint["global_step"]
-    else:
-        print("global_step key not found in model")
-        global_step = None
-
-    if size is None:
-        # NOTE: For stable diffusion 2 base one has to pass `image_size==512`
-        # as it relies on a brittle global step parameter here
-        size = 512 if global_step == 875000 else 768
-
-    v2 = False
+    v2 = version is not None and "v2" in version
     opts = {
         "extract_ema": True,
-        "image_size": size,
     }
 
-    key_name = (
-        "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight"
-    )
-    if key_name in checkpoint and checkpoint[key_name].shape[-1] == 1024:
-        v2 = True
-        if size != 512:
-            # v2.1 needs to upcast attention
-            logger.debug("setting upcast_attention")
-            opts["upcast_attention"] = True
+    try:
+        checkpoint = load_tensor(source, map_location=map_location)
 
-    if v2 and size != 512:
-        opts["model_type"] = "FrozenOpenCLIPEmbedder"
-        opts["prediction_type"] = "v_prediction"
-    else:
-        opts["model_type"] = "FrozenCLIPEmbedder"
-        opts["prediction_type"] = "epsilon"
+        if "global_step" in checkpoint:
+            global_step = checkpoint["global_step"]
+        else:
+            print("global_step key not found in model")
+            global_step = None
+
+        if size is None:
+            # NOTE: For stable diffusion 2 base one has to pass `image_size==512`
+            # as it relies on a brittle global step parameter here
+            size = 512 if global_step == 875000 else 768
+
+        opts["image_size"] = size
+
+        key_name = (
+            "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight"
+        )
+        if key_name in checkpoint and checkpoint[key_name].shape[-1] == 1024:
+            v2 = True
+            if size != 512:
+                # v2.1 needs to upcast attention
+                logger.debug("setting upcast_attention")
+                opts["upcast_attention"] = True
+
+        if v2 and size != 512:
+            opts["model_type"] = "FrozenOpenCLIPEmbedder"
+            opts["prediction_type"] = "v_prediction"
+        else:
+            opts["model_type"] = "FrozenCLIPEmbedder"
+            opts["prediction_type"] = "epsilon"
+    except:
+        logger.debug("unable to load tensor for version check")
+        pass
 
     return (v2, opts)
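Net effect of this hunk: get_model_version now loads the checkpoint itself, and the extras-file version hint is an independent signal, so a model can be flagged as v2 even when the tensor cannot be loaded for shape inspection. A minimal standalone sketch of that fallback (illustrative names, not the project's API):

from typing import Dict, Optional

V2_KEY = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight"


def detect_v2(version: Optional[str], checkpoint: Optional[Dict]) -> bool:
    # the extras-file hint alone is enough to mark a model as v2
    v2 = version is not None and "v2" in version
    if checkpoint is not None:
        tensor = checkpoint.get(V2_KEY)
        # shape-based detection still refines the answer when loading worked
        if tensor is not None and tensor.shape[-1] == 1024:
            v2 = True
    return v2


# the hint holds even when the checkpoint could not be loaded
assert detect_v2("v2.1", None) is True
assert detect_v2(None, None) is False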
@@ -241,12 +250,14 @@ def convert_diffusion_diffusers(
     """
     From https://github.com/huggingface/diffusers/blob/main/scripts/convert_stable_diffusion_checkpoint_to_onnx.py
     """
-    name = model.get("name")
-    source = source or model.get("source")
     config = model.get("config", None)
-    single_vae = model.get("single_vae")
-    replace_vae = model.get("vae")
+    image_size = model.get("image_size", None)
+    name = model.get("name")
     pipe_type = model.get("pipeline", "txt2img")
+    single_vae = model.get("single_vae")
+    source = source or model.get("source")
+    replace_vae = model.get("vae")
+    version = model.get("version", None)
 
     device = conversion.training_device
     dtype = conversion.torch_dtype()
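Both new fields are read with dict.get and a None default, so extras files written before this change continue to convert unchanged. A toy illustration (hypothetical entry, not from the repo):

model = {"name": "example", "source": "example.safetensors"}  # no new fields
image_size = model.get("image_size", None)  # None: size is auto-detected later
version = model.get("version", None)  # None: no v2 hint for get_model_version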
@@ -279,7 +290,7 @@ def convert_diffusion_diffusers(
         return (False, dest_path)
 
     pipe_class = available_pipelines.get(pipe_type)
-    v2, pipe_args = get_model_version(load_tensor(source, conversion.map_location))
+    v2, pipe_args = get_model_version(source, conversion.map_location, size=image_size, version=version)
 
     if pipe_type == "inpaint":
         pipe_args["num_in_channels"] = 9
@@ -299,7 +310,7 @@ def convert_diffusion_diffusers(
             pipeline_class=pipe_class,
             torch_dtype=dtype,
             **pipe_args,
-        ).to(device)
+        ).to(device, torch_dtype=dtype)
     else:
         logger.warning("pipeline source not found or not recognized: %s", source)
         raise ValueError(f"pipeline source not found or not recognized: {source}")


@@ -268,6 +268,9 @@ def load_tensor(name: str, map_location=None) -> Optional[Dict]:
         except Exception as e:
             logger.warning("error loading tensor: %s", e)
 
+    if checkpoint is None:
+        raise ValueError("error loading tensor")
+
     if checkpoint is not None and "state_dict" in checkpoint:
         checkpoint = checkpoint["state_dict"]
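With this guard, a failed load surfaces as an exception instead of a silent None, which the new try/except in get_model_version then absorbs. A compressed sketch of that interplay (stand-in functions, simulating a load that always fails):

def load_tensor_or_raise(path):
    # stand-in for load_tensor: simulate the new fail-fast behavior
    raise ValueError("error loading tensor")


def version_check(path, version=None):
    v2 = version is not None and "v2" in version
    try:
        checkpoint = load_tensor_or_raise(path)
        # shape-based detection would refine v2 here
    except Exception:
        # the failure is logged and swallowed; the extras hint still stands
        pass
    return v2


print(version_check("missing.ckpt", version="v2.1"))  # True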


@@ -78,6 +78,8 @@ $defs:
         properties:
           config:
             type: string
+          image_size:
+            type: number
           inversions:
             type: array
             items:
@@ -94,6 +96,13 @@ $defs:
             ]
           vae:
             type: string
+          version:
+            type: string
+            enum: [
+              v1,
+              v2,
+              v2.1,
+            ]
   upscaling_model:
     allOf:
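Under the updated schema, a diffusion model entry in the extras file can carry both hints. A hypothetical entry (name and source are made up; the surrounding structure is assumed to match the existing extras format):

diffusion:
  - name: example-v2-model
    source: example-v2-model.safetensors
    image_size: 768
    version: v2.1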