fix(api): download pretrained models from HF correctly (#371)
This commit is contained in:
parent
b6692f068e
commit
d66bf9e54f
|
@ -207,7 +207,7 @@ def fetch_model(
|
|||
format: Optional[str] = None,
|
||||
hf_hub_fetch: bool = False,
|
||||
hf_hub_filename: Optional[str] = None,
|
||||
) -> str:
|
||||
) -> Tuple[str, bool]:
|
||||
cache_path = dest or conversion.cache_path
|
||||
cache_name = path.join(cache_path, name)
|
||||
|
||||
|
@ -223,7 +223,7 @@ def fetch_model(
|
|||
|
||||
if path.exists(cache_name):
|
||||
logger.debug("model already exists in cache, skipping fetch")
|
||||
return cache_name
|
||||
return cache_name, False
|
||||
|
||||
for proto in model_sources:
|
||||
api_name, api_root = model_sources.get(proto)
|
||||
|
@ -232,33 +232,36 @@ def fetch_model(
|
|||
logger.info(
|
||||
"downloading model from %s: %s -> %s", api_name, api_source, cache_name
|
||||
)
|
||||
return download_progress([(api_source, cache_name)])
|
||||
return download_progress([(api_source, cache_name)]), False
|
||||
|
||||
if source.startswith(model_source_huggingface):
|
||||
hub_source = remove_prefix(source, model_source_huggingface)
|
||||
logger.info("downloading model from Huggingface Hub: %s", hub_source)
|
||||
# from_pretrained has a bunch of useful logic that snapshot_download by itself down not
|
||||
if hf_hub_fetch:
|
||||
return hf_hub_download(
|
||||
repo_id=hub_source,
|
||||
filename=hf_hub_filename,
|
||||
cache_dir=cache_path,
|
||||
force_filename=f"{name}.bin",
|
||||
return (
|
||||
hf_hub_download(
|
||||
repo_id=hub_source,
|
||||
filename=hf_hub_filename,
|
||||
cache_dir=cache_path,
|
||||
force_filename=f"{name}.bin",
|
||||
),
|
||||
False,
|
||||
)
|
||||
else:
|
||||
return hub_source
|
||||
return hub_source, True
|
||||
elif source.startswith("https://"):
|
||||
logger.info("downloading model from: %s", source)
|
||||
return download_progress([(source, cache_name)])
|
||||
return download_progress([(source, cache_name)]), False
|
||||
elif source.startswith("http://"):
|
||||
logger.warning("downloading model from insecure source: %s", source)
|
||||
return download_progress([(source, cache_name)])
|
||||
return download_progress([(source, cache_name)]), False
|
||||
elif source.startswith(path.sep) or source.startswith("."):
|
||||
logger.info("using local model: %s", source)
|
||||
return source
|
||||
return source, False
|
||||
else:
|
||||
logger.info("unknown model location, using path as provided: %s", source)
|
||||
return source
|
||||
return source, False
|
||||
|
||||
|
||||
def convert_models(conversion: ConversionContext, args, models: Models):
|
||||
|
@ -280,7 +283,7 @@ def convert_models(conversion: ConversionContext, args, models: Models):
|
|||
if "dest" in model:
|
||||
dest_path = path.join(conversion.model_path, model["dest"])
|
||||
|
||||
dest = fetch_model(
|
||||
dest, hf = fetch_model(
|
||||
conversion, name, source, format=model_format, dest=dest_path
|
||||
)
|
||||
logger.info("finished downloading source: %s -> %s", source, dest)
|
||||
|
@ -302,7 +305,7 @@ def convert_models(conversion: ConversionContext, args, models: Models):
|
|||
|
||||
try:
|
||||
if network_type == "control":
|
||||
dest = fetch_model(
|
||||
dest, hf = fetch_model(
|
||||
conversion,
|
||||
name,
|
||||
source,
|
||||
|
@ -315,7 +318,7 @@ def convert_models(conversion: ConversionContext, args, models: Models):
|
|||
dest,
|
||||
)
|
||||
if network_type == "inversion" and network_model == "concept":
|
||||
dest = fetch_model(
|
||||
dest, hf = fetch_model(
|
||||
conversion,
|
||||
name,
|
||||
source,
|
||||
|
@ -325,7 +328,7 @@ def convert_models(conversion: ConversionContext, args, models: Models):
|
|||
hf_hub_filename="learned_embeds.bin",
|
||||
)
|
||||
else:
|
||||
dest = fetch_model(
|
||||
dest, hf = fetch_model(
|
||||
conversion,
|
||||
name,
|
||||
source,
|
||||
|
@ -349,7 +352,7 @@ def convert_models(conversion: ConversionContext, args, models: Models):
|
|||
model_format = source_format(model)
|
||||
|
||||
try:
|
||||
source = fetch_model(
|
||||
source, hf = fetch_model(
|
||||
conversion, name, model["source"], format=model_format
|
||||
)
|
||||
|
||||
|
@ -358,6 +361,7 @@ def convert_models(conversion: ConversionContext, args, models: Models):
|
|||
model,
|
||||
source,
|
||||
model_format,
|
||||
hf=hf,
|
||||
)
|
||||
|
||||
# make sure blending only happens once, not every run
|
||||
|
@ -389,7 +393,7 @@ def convert_models(conversion: ConversionContext, args, models: Models):
|
|||
inversion_name = inversion["name"]
|
||||
inversion_source = inversion["source"]
|
||||
inversion_format = inversion.get("format", None)
|
||||
inversion_source = fetch_model(
|
||||
inversion_source, hf = fetch_model(
|
||||
conversion,
|
||||
inversion_name,
|
||||
inversion_source,
|
||||
|
@ -430,7 +434,7 @@ def convert_models(conversion: ConversionContext, args, models: Models):
|
|||
# load models if not loaded yet
|
||||
lora_name = lora["name"]
|
||||
lora_source = lora["source"]
|
||||
lora_source = fetch_model(
|
||||
lora_source, hf = fetch_model(
|
||||
conversion,
|
||||
f"{name}-lora-{lora_name}",
|
||||
lora_source,
|
||||
|
@ -489,7 +493,7 @@ def convert_models(conversion: ConversionContext, args, models: Models):
|
|||
model_format = source_format(model)
|
||||
|
||||
try:
|
||||
source = fetch_model(
|
||||
source, hf = fetch_model(
|
||||
conversion, name, model["source"], format=model_format
|
||||
)
|
||||
model_type = model.get("model", "resrgan")
|
||||
|
@ -521,7 +525,7 @@ def convert_models(conversion: ConversionContext, args, models: Models):
|
|||
else:
|
||||
model_format = source_format(model)
|
||||
try:
|
||||
source = fetch_model(
|
||||
source, hf = fetch_model(
|
||||
conversion, name, model["source"], format=model_format
|
||||
)
|
||||
model_type = model.get("model", "gfpgan")
|
||||
|
|
|
@ -247,6 +247,7 @@ def convert_diffusion_diffusers(
|
|||
model: Dict,
|
||||
source: str,
|
||||
format: str,
|
||||
hf: bool = False,
|
||||
) -> Tuple[bool, str]:
|
||||
"""
|
||||
From https://github.com/huggingface/diffusers/blob/main/scripts/convert_stable_diffusion_checkpoint_to_onnx.py
|
||||
|
@ -316,6 +317,13 @@ def convert_diffusion_diffusers(
|
|||
pipeline_class=pipe_class,
|
||||
**pipe_args,
|
||||
).to(device, torch_dtype=dtype)
|
||||
elif hf:
|
||||
logger.debug("downloading pretrained model from Huggingface hub: %s", source)
|
||||
pipeline = pipe_class.from_pretrained(
|
||||
source,
|
||||
torch_dtype=dtype,
|
||||
use_auth_token=conversion.token,
|
||||
).to(device)
|
||||
else:
|
||||
logger.warning("pipeline source not found or not recognized: %s", source)
|
||||
raise ValueError(f"pipeline source not found or not recognized: {source}")
|
||||
|
|
Loading…
Reference in New Issue