
fix(api): download pretrained models from HF correctly (#371)

Sean Sube 2023-05-03 22:21:04 -05:00
parent b6692f068e
commit d66bf9e54f
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
2 changed files with 34 additions and 22 deletions


@@ -207,7 +207,7 @@ def fetch_model(
     format: Optional[str] = None,
     hf_hub_fetch: bool = False,
     hf_hub_filename: Optional[str] = None,
-) -> str:
+) -> Tuple[str, bool]:
     cache_path = dest or conversion.cache_path
     cache_name = path.join(cache_path, name)
@@ -223,7 +223,7 @@ def fetch_model(
     if path.exists(cache_name):
         logger.debug("model already exists in cache, skipping fetch")
-        return cache_name
+        return cache_name, False

     for proto in model_sources:
         api_name, api_root = model_sources.get(proto)
@@ -232,33 +232,36 @@ def fetch_model(
             logger.info(
                 "downloading model from %s: %s -> %s", api_name, api_source, cache_name
             )
-            return download_progress([(api_source, cache_name)])
+            return download_progress([(api_source, cache_name)]), False

     if source.startswith(model_source_huggingface):
         hub_source = remove_prefix(source, model_source_huggingface)
         logger.info("downloading model from Huggingface Hub: %s", hub_source)
         # from_pretrained has a bunch of useful logic that snapshot_download by itself does not
         if hf_hub_fetch:
-            return hf_hub_download(
+            return (
+                hf_hub_download(
                 repo_id=hub_source,
                 filename=hf_hub_filename,
                 cache_dir=cache_path,
                 force_filename=f"{name}.bin",
+                ),
+                False,
             )
         else:
-            return hub_source
+            return hub_source, True
     elif source.startswith("https://"):
         logger.info("downloading model from: %s", source)
-        return download_progress([(source, cache_name)])
+        return download_progress([(source, cache_name)]), False
     elif source.startswith("http://"):
         logger.warning("downloading model from insecure source: %s", source)
-        return download_progress([(source, cache_name)])
+        return download_progress([(source, cache_name)]), False
     elif source.startswith(path.sep) or source.startswith("."):
         logger.info("using local model: %s", source)
-        return source
+        return source, False
     else:
         logger.info("unknown model location, using path as provided: %s", source)
-        return source
+        return source, False


 def convert_models(conversion: ConversionContext, args, models: Models):
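
For orientation: fetch_model now returns a (path, hf) pair instead of a bare path, and the boolean is True only for Huggingface Hub sources that were not fetched to disk, telling the caller that the string is a Hub repo id rather than a local file. Below is a minimal caller sketch, not part of the commit, assuming fetch_model and a ConversionContext named conversion are in scope from onnx-web's convert module, diffusers' StableDiffusionPipeline as the loader, and a purely illustrative repo id:

from diffusers import StableDiffusionPipeline

# Hypothetical caller of the new API; fetch_model and conversion come from onnx-web's convert module.
source, hf = fetch_model(conversion, "example-model", "huggingface://runwayml/stable-diffusion-v1-5")

if hf:
    # source is still a Hub repo id; let from_pretrained resolve, download, and cache it
    pipeline = StableDiffusionPipeline.from_pretrained(source, use_auth_token=conversion.token)
else:
    # source is a local path or a file already downloaded into the cache
    pipeline = StableDiffusionPipeline.from_pretrained(source)
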
@@ -280,7 +283,7 @@ def convert_models(conversion: ConversionContext, args, models: Models):
             if "dest" in model:
                 dest_path = path.join(conversion.model_path, model["dest"])

-            dest = fetch_model(
+            dest, hf = fetch_model(
                 conversion, name, source, format=model_format, dest=dest_path
             )
             logger.info("finished downloading source: %s -> %s", source, dest)
@@ -302,7 +305,7 @@ def convert_models(conversion: ConversionContext, args, models: Models):
             try:
                 if network_type == "control":
-                    dest = fetch_model(
+                    dest, hf = fetch_model(
                         conversion,
                         name,
                         source,
@@ -315,7 +318,7 @@ def convert_models(conversion: ConversionContext, args, models: Models):
                         dest,
                     )
                 if network_type == "inversion" and network_model == "concept":
-                    dest = fetch_model(
+                    dest, hf = fetch_model(
                         conversion,
                         name,
                         source,
@@ -325,7 +328,7 @@ def convert_models(conversion: ConversionContext, args, models: Models):
                         hf_hub_filename="learned_embeds.bin",
                     )
                 else:
-                    dest = fetch_model(
+                    dest, hf = fetch_model(
                         conversion,
                         name,
                         source,
@@ -349,7 +352,7 @@ def convert_models(conversion: ConversionContext, args, models: Models):
             model_format = source_format(model)
             try:
-                source = fetch_model(
+                source, hf = fetch_model(
                     conversion, name, model["source"], format=model_format
                 )
@@ -358,6 +361,7 @@ def convert_models(conversion: ConversionContext, args, models: Models):
                     model,
                     source,
                     model_format,
+                    hf=hf,
                 )

                 # make sure blending only happens once, not every run
@@ -389,7 +393,7 @@ def convert_models(conversion: ConversionContext, args, models: Models):
                     inversion_name = inversion["name"]
                     inversion_source = inversion["source"]
                     inversion_format = inversion.get("format", None)
-                    inversion_source = fetch_model(
+                    inversion_source, hf = fetch_model(
                         conversion,
                         inversion_name,
                         inversion_source,
@@ -430,7 +434,7 @@ def convert_models(conversion: ConversionContext, args, models: Models):
                     # load models if not loaded yet
                     lora_name = lora["name"]
                     lora_source = lora["source"]
-                    lora_source = fetch_model(
+                    lora_source, hf = fetch_model(
                         conversion,
                         f"{name}-lora-{lora_name}",
                         lora_source,
@@ -489,7 +493,7 @@ def convert_models(conversion: ConversionContext, args, models: Models):
             model_format = source_format(model)

             try:
-                source = fetch_model(
+                source, hf = fetch_model(
                     conversion, name, model["source"], format=model_format
                 )
                 model_type = model.get("model", "resrgan")
@@ -521,7 +525,7 @@ def convert_models(conversion: ConversionContext, args, models: Models):
             else:
                 model_format = source_format(model)
             try:
-                source = fetch_model(
+                source, hf = fetch_model(
                     conversion, name, model["source"], format=model_format
                 )
                 model_type = model.get("model", "gfpgan")


@@ -247,6 +247,7 @@ def convert_diffusion_diffusers(
     model: Dict,
     source: str,
     format: str,
+    hf: bool = False,
 ) -> Tuple[bool, str]:
     """
     From https://github.com/huggingface/diffusers/blob/main/scripts/convert_stable_diffusion_checkpoint_to_onnx.py
@@ -316,6 +317,13 @@ def convert_diffusion_diffusers(
             pipeline_class=pipe_class,
             **pipe_args,
         ).to(device, torch_dtype=dtype)
+    elif hf:
+        logger.debug("downloading pretrained model from Huggingface hub: %s", source)
+        pipeline = pipe_class.from_pretrained(
+            source,
+            torch_dtype=dtype,
+            use_auth_token=conversion.token,
+        ).to(device)
     else:
         logger.warning("pipeline source not found or not recognized: %s", source)
         raise ValueError(f"pipeline source not found or not recognized: {source}")
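
On the converter side, the new hf flag makes convert_diffusion_diffusers defer the download to diffusers itself instead of treating the source as a local checkpoint, so from_pretrained handles Hub resolution, caching, and authentication. A rough standalone sketch of what that branch does, with illustrative stand-in values rather than the converter's real arguments (StableDiffusionPipeline stands in for pipe_class, and use_auth_token would be conversion.token for gated repos):

import torch
from diffusers import StableDiffusionPipeline

# Illustrative stand-ins for the values the converter receives.
pipe_class = StableDiffusionPipeline
source = "runwayml/stable-diffusion-v1-5"  # Hub repo id returned by fetch_model with hf=True
dtype = torch.float32
device = "cpu"

# Mirrors the new elif hf: branch above.
pipeline = pipe_class.from_pretrained(
    source,
    torch_dtype=dtype,
    use_auth_token=None,
).to(device)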