fix(api): add model image size and version hint to extras file
This commit is contained in:
parent
2690eafe09
commit
bc71583393
|
@ -52,42 +52,51 @@ available_pipelines = {
|
|||
|
||||
|
||||
def get_model_version(
    source,
    map_location,
    size=None,
    version=None,
) -> Tuple[bool, Dict[str, Union[bool, int, str]]]:
    """
    Detect whether a checkpoint is a v2 model and build the pipeline options for it.

    The checkpoint is loaded with `load_tensor` and inspected on a best-effort
    basis: if it cannot be loaded or inspected, the error is logged and the
    options collected so far are returned instead of raising.

    Args:
        source: checkpoint location, passed through to `load_tensor`.
        map_location: passed through to `load_tensor` as `map_location`.
        size: optional image size hint. When None, the size is guessed from
            the checkpoint's `global_step` value.
        version: optional version hint. Any value containing "v2" marks the
            model as v2 before the checkpoint is inspected.

    Returns:
        A `(v2, opts)` tuple, where `v2` is True for v2 models and `opts` is a
        dict of keyword options for the conversion pipeline.
    """
    v2 = version is not None and "v2" in version
    opts = {
        "extract_ema": True,
    }

    # Honor an explicit size hint even if the checkpoint cannot be inspected.
    if size is not None:
        opts["image_size"] = size

    try:
        checkpoint = load_tensor(source, map_location=map_location)

        if "global_step" in checkpoint:
            global_step = checkpoint["global_step"]
        else:
            logger.debug("global_step key not found in model")
            global_step = None

        if size is None:
            # NOTE: For stable diffusion 2 base one has to pass `image_size==512`
            # as it relies on a brittle global step parameter here
            size = 512 if global_step == 875000 else 768

        opts["image_size"] = size

        # NOTE(review): a last dimension of 1024 on this attention weight is
        # treated as the v2 marker - presumably the wider v2 text encoder; confirm.
        key_name = (
            "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight"
        )
        if key_name in checkpoint and checkpoint[key_name].shape[-1] == 1024:
            v2 = True
            if size != 512:
                # v2.1 needs to upcast attention
                logger.debug("setting upcast_attention")
                opts["upcast_attention"] = True

        if v2 and size != 512:
            opts["model_type"] = "FrozenOpenCLIPEmbedder"
            opts["prediction_type"] = "v_prediction"
        else:
            opts["model_type"] = "FrozenCLIPEmbedder"
            opts["prediction_type"] = "epsilon"
    except Exception as err:
        # Best-effort detection: keep going with whatever options were
        # collected, but do not swallow KeyboardInterrupt/SystemExit.
        logger.debug("unable to load tensor for version check: %s", err)

    return (v2, opts)
|
||||
|
||||
|
@ -241,12 +250,14 @@ def convert_diffusion_diffusers(
|
|||
"""
|
||||
From https://github.com/huggingface/diffusers/blob/main/scripts/convert_stable_diffusion_checkpoint_to_onnx.py
|
||||
"""
|
||||
name = model.get("name")
|
||||
source = source or model.get("source")
|
||||
config = model.get("config", None)
|
||||
single_vae = model.get("single_vae")
|
||||
replace_vae = model.get("vae")
|
||||
image_size = model.get("image_size", None)
|
||||
name = model.get("name")
|
||||
pipe_type = model.get("pipeline", "txt2img")
|
||||
single_vae = model.get("single_vae")
|
||||
source = source or model.get("source")
|
||||
replace_vae = model.get("vae")
|
||||
version = model.get("version", None)
|
||||
|
||||
device = conversion.training_device
|
||||
dtype = conversion.torch_dtype()
|
||||
|
@ -279,7 +290,7 @@ def convert_diffusion_diffusers(
|
|||
return (False, dest_path)
|
||||
|
||||
pipe_class = available_pipelines.get(pipe_type)
|
||||
v2, pipe_args = get_model_version(load_tensor(source, conversion.map_location))
|
||||
v2, pipe_args = get_model_version(source, conversion.map_location, size=image_size, version=version)
|
||||
|
||||
if pipe_type == "inpaint":
|
||||
pipe_args["num_in_channels"] = 9
|
||||
|
@ -299,7 +310,7 @@ def convert_diffusion_diffusers(
|
|||
pipeline_class=pipe_class,
|
||||
torch_dtype=dtype,
|
||||
**pipe_args,
|
||||
).to(device)
|
||||
).to(device, torch_dtype=dtype)
|
||||
else:
|
||||
logger.warning("pipeline source not found or not recognized: %s", source)
|
||||
raise ValueError(f"pipeline source not found or not recognized: {source}")
|
||||
|
|
|
@ -268,6 +268,9 @@ def load_tensor(name: str, map_location=None) -> Optional[Dict]:
|
|||
except Exception as e:
|
||||
logger.warning("error loading tensor: %s", e)
|
||||
|
||||
if checkpoint is None:
|
||||
raise ValueError("error loading tensor")
|
||||
|
||||
if checkpoint is not None and "state_dict" in checkpoint:
|
||||
checkpoint = checkpoint["state_dict"]
|
||||
|
||||
|
|
|
@ -78,6 +78,8 @@ $defs:
|
|||
properties:
|
||||
config:
|
||||
type: string
|
||||
image_size:
|
||||
type: number
|
||||
inversions:
|
||||
type: array
|
||||
items:
|
||||
|
@ -94,6 +96,13 @@ $defs:
|
|||
]
|
||||
vae:
|
||||
type: string
|
||||
version:
|
||||
type: string
|
||||
enum: [
|
||||
v1,
|
||||
v2,
|
||||
v2.1,
|
||||
]
|
||||
|
||||
upscaling_model:
|
||||
allOf:
|
||||
|
|
Loading…
Reference in New Issue