lint(api): remove unused hub download code path
parent 157067b554
commit 5eae80cbdd
@@ -1135,9 +1135,6 @@ def extract_checkpoint(
     new_model_name: str,
     checkpoint_file: str,
     scheduler_type="ddim",
-    from_hub=False,
-    new_model_url="",
-    new_model_token="",
     extract_ema=False,
     train_unfrozen=False,
     is_512=True,
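With the hub parameters dropped, extract_checkpoint only works on local files. A minimal sketch of the reduced signature, reconstructed from the hunk above (the body and the type of ctx are not shown in the diff and are left as assumptions here):

    def extract_checkpoint(
        ctx,                    # conversion context; its type is not shown in the diff
        new_model_name: str,    # name for the extracted diffusers model
        checkpoint_file: str,   # local .ckpt or .safetensors path, now the only source
        scheduler_type="ddim",
        extract_ema=False,      # presumably: prefer EMA weights when both sets exist
        train_unfrozen=False,
        is_512=True,            # 512px (v1-style) vs 768px (v2-style) base resolution
    ):
        ...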
@@ -1177,33 +1174,15 @@ def extract_checkpoint(
 
     # Create empty config
     db_config = TrainingConfig(ctx, model_name=new_model_name, scheduler=scheduler_type,
-                               src=checkpoint_file if not from_hub else new_model_url)
+                               src=checkpoint_file)
 
     original_config_file = None
 
-    # Okay then. So, if it's from the hub, try to download it
-    if from_hub:
-        model_info, config = download_model(db_config, new_model_token)
-        if db_config is not None:
-            original_config_file = config
-        if model_info is not None:
-            logger.debug("Got model info.")
-            if ".ckpt" in model_info or ".safetensors" in model_info:
-                # Set this to false, because we have a checkpoint where we can *maybe* get a revision.
-                from_hub = False
-                db_config.src = model_info
-                checkpoint_file = model_info
-        else:
-            msg = "Unable to fetch model from hub."
-            logger.warning(msg)
-            return
 
     try:
         checkpoint = None
         map_location = torch.device("cpu")
 
         # Try to determine if v1 or v2 model if we have a ckpt
-        if not from_hub:
-            logger.info("Loading model from checkpoint.")
         _, extension = os.path.splitext(checkpoint_file)
         if extension.lower() == ".safetensors":
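The context lines above keep the extension-based loading: .safetensors files go through the safetensors reader, everything else through torch.load on CPU. A self-contained sketch of that pattern, assuming the safetensors package is installed; this is not the project's exact code:

    import os

    import torch
    from safetensors.torch import load_file

    def load_checkpoint(checkpoint_file: str):
        # Always map to CPU so extraction works without a GPU.
        map_location = torch.device("cpu")
        _, extension = os.path.splitext(checkpoint_file)
        if extension.lower() == ".safetensors":
            # safetensors files have their own reader.
            return load_file(checkpoint_file, device="cpu")
        # Legacy .ckpt files are pickled torch archives.
        return torch.load(checkpoint_file, map_location=map_location)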
@@ -1240,18 +1219,6 @@ def extract_checkpoint(
                 v2 = True
             else:
                 v2 = False
-        else:
-            unet_dir = os.path.join(db_config.pretrained_model_name_or_path, "unet")
-            try:
-                unet = UNet2DConditionModel.from_pretrained(unet_dir)
-                logger.debug("Loaded unet.")
-                unet_dict = unet.state_dict()
-                key_name = "down_blocks.2.attentions.1.transformer_blocks.0.attn2.to_k.weight"
-                if key_name in unet_dict and unet_dict[key_name].shape[-1] == 1024:
-                    logger.debug("UNet using v2 parameters.")
-                    v2 = True
-            except Exception:
-                logger.error("Exception loading unet!", traceback.format_exception(*sys.exc_info()))
 
         if v2 and not is_512:
             prediction_type = "v_prediction"
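The deleted fallback inspected an already-extracted UNet to tell v1 from v2: the cross-attention key projection is 768 wide for SD v1 (CLIP ViT-L) and 1024 wide for SD v2 (OpenCLIP). A standalone sketch of the same check, decoupled from db_config:

    from diffusers import UNet2DConditionModel

    # Same key the removed code probed; its last dimension equals the
    # text-encoder hidden size (768 for v1, 1024 for v2).
    V2_KEY = "down_blocks.2.attentions.1.transformer_blocks.0.attn2.to_k.weight"

    def unet_is_v2(unet_dir: str) -> bool:
        unet = UNet2DConditionModel.from_pretrained(unet_dir)
        weights = unet.state_dict()
        return V2_KEY in weights and weights[V2_KEY].shape[-1] == 1024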
@@ -1265,10 +1232,6 @@ def extract_checkpoint(
         db_config.lifetime_revision = revision
         db_config.epoch = epoch
         db_config.v2 = v2
-        if from_hub:
-            result_status = "Model fetched from hub."
-            db_config.save()
-            return
 
         logger.info(f"{'v2' if v2 else 'v1'} model loaded.")
 
@@ -1457,7 +1420,7 @@ def convert_diffusion_original(
         logger.info("Torch pipeline already exists, reusing: %s", torch_path)
     else:
         logger.info("Converting original Diffusers check to Torch model: %s -> %s", source, torch_path)
-        extract_checkpoint(ctx, torch_name, source, from_hub=False)
+        extract_checkpoint(ctx, torch_name, source)
         logger.info("Converted original Diffusers checkpoint to Torch model.")
 
     convert_diffusion_stable(ctx, model, working_name)
 
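The call site follows a reuse-or-convert pattern: skip extraction entirely when the Torch pipeline is already on disk. A hedged sketch of that shape under assumed names (ensure_torch_model is hypothetical; torch_path, torch_name, and source are taken from the context lines, and extract_checkpoint is the function changed above):

    import os
    from logging import getLogger

    logger = getLogger(__name__)

    def ensure_torch_model(ctx, torch_name, source, torch_path):
        # Only run extraction when no converted pipeline is cached on disk.
        if os.path.exists(torch_path):
            logger.info("Torch pipeline already exists, reusing: %s", torch_path)
        else:
            extract_checkpoint(ctx, torch_name, source)  # hub path removed
        return torch_path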
@@ -147,6 +147,9 @@ def source_format(model: Dict) -> Optional[str]:
 
 
 class Config(object):
+    """
+    Shim for pydantic-style config.
+    """
     def __init__(self, kwargs):
         self.__dict__.update(kwargs)
         for k, v in self.__dict__.items():
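The new docstring documents a small shim: Config promotes a plain dict to attribute access, mimicking a pydantic model. The loop body is cut off by the hunk, so the recursion below is an assumption about how nested dicts are handled:

    class Config(object):
        """
        Shim for pydantic-style config.
        """

        def __init__(self, kwargs):
            self.__dict__.update(kwargs)
            # Assumed: wrap nested dicts so attribute access works all the way down.
            for k, v in self.__dict__.items():
                if isinstance(v, dict):
                    self.__dict__[k] = Config(v)

    config = Config({"scheduler": "ddim", "unet": {"v2": False}})
    print(config.scheduler, config.unet.v2)  # -> ddim False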