lint(api): remove unused hub download code path
This commit is contained in:
parent
157067b554
commit
5eae80cbdd
|
@ -1135,9 +1135,6 @@ def extract_checkpoint(
|
||||||
new_model_name: str,
|
new_model_name: str,
|
||||||
checkpoint_file: str,
|
checkpoint_file: str,
|
||||||
scheduler_type="ddim",
|
scheduler_type="ddim",
|
||||||
from_hub=False,
|
|
||||||
new_model_url="",
|
|
||||||
new_model_token="",
|
|
||||||
extract_ema=False,
|
extract_ema=False,
|
||||||
train_unfrozen=False,
|
train_unfrozen=False,
|
||||||
is_512=True,
|
is_512=True,
|
||||||
|
@ -1177,33 +1174,15 @@ def extract_checkpoint(
|
||||||
|
|
||||||
# Create empty config
|
# Create empty config
|
||||||
db_config = TrainingConfig(ctx, model_name=new_model_name, scheduler=scheduler_type,
|
db_config = TrainingConfig(ctx, model_name=new_model_name, scheduler=scheduler_type,
|
||||||
src=checkpoint_file if not from_hub else new_model_url)
|
src=checkpoint_file)
|
||||||
|
|
||||||
original_config_file = None
|
original_config_file = None
|
||||||
|
|
||||||
# Okay then. So, if it's from the hub, try to download it
|
|
||||||
if from_hub:
|
|
||||||
model_info, config = download_model(db_config, new_model_token)
|
|
||||||
if db_config is not None:
|
|
||||||
original_config_file = config
|
|
||||||
if model_info is not None:
|
|
||||||
logger.debug("Got model info.")
|
|
||||||
if ".ckpt" in model_info or ".safetensors" in model_info:
|
|
||||||
# Set this to false, because we have a checkpoint where we can *maybe* get a revision.
|
|
||||||
from_hub = False
|
|
||||||
db_config.src = model_info
|
|
||||||
checkpoint_file = model_info
|
|
||||||
else:
|
|
||||||
msg = "Unable to fetch model from hub."
|
|
||||||
logger.warning(msg)
|
|
||||||
return
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
checkpoint = None
|
checkpoint = None
|
||||||
map_location = torch.device("cpu")
|
map_location = torch.device("cpu")
|
||||||
|
|
||||||
# Try to determine if v1 or v2 model if we have a ckpt
|
# Try to determine if v1 or v2 model if we have a ckpt
|
||||||
if not from_hub:
|
|
||||||
logger.info("Loading model from checkpoint.")
|
logger.info("Loading model from checkpoint.")
|
||||||
_, extension = os.path.splitext(checkpoint_file)
|
_, extension = os.path.splitext(checkpoint_file)
|
||||||
if extension.lower() == ".safetensors":
|
if extension.lower() == ".safetensors":
|
||||||
|
@ -1240,18 +1219,6 @@ def extract_checkpoint(
|
||||||
v2 = True
|
v2 = True
|
||||||
else:
|
else:
|
||||||
v2 = False
|
v2 = False
|
||||||
else:
|
|
||||||
unet_dir = os.path.join(db_config.pretrained_model_name_or_path, "unet")
|
|
||||||
try:
|
|
||||||
unet = UNet2DConditionModel.from_pretrained(unet_dir)
|
|
||||||
logger.debug("Loaded unet.")
|
|
||||||
unet_dict = unet.state_dict()
|
|
||||||
key_name = "down_blocks.2.attentions.1.transformer_blocks.0.attn2.to_k.weight"
|
|
||||||
if key_name in unet_dict and unet_dict[key_name].shape[-1] == 1024:
|
|
||||||
logger.debug("UNet using v2 parameters.")
|
|
||||||
v2 = True
|
|
||||||
except Exception:
|
|
||||||
logger.error("Exception loading unet!", traceback.format_exception(*sys.exc_info()))
|
|
||||||
|
|
||||||
if v2 and not is_512:
|
if v2 and not is_512:
|
||||||
prediction_type = "v_prediction"
|
prediction_type = "v_prediction"
|
||||||
|
@ -1265,10 +1232,6 @@ def extract_checkpoint(
|
||||||
db_config.lifetime_revision = revision
|
db_config.lifetime_revision = revision
|
||||||
db_config.epoch = epoch
|
db_config.epoch = epoch
|
||||||
db_config.v2 = v2
|
db_config.v2 = v2
|
||||||
if from_hub:
|
|
||||||
result_status = "Model fetched from hub."
|
|
||||||
db_config.save()
|
|
||||||
return
|
|
||||||
|
|
||||||
logger.info(f"{'v2' if v2 else 'v1'} model loaded.")
|
logger.info(f"{'v2' if v2 else 'v1'} model loaded.")
|
||||||
|
|
||||||
|
@ -1457,7 +1420,7 @@ def convert_diffusion_original(
|
||||||
logger.info("Torch pipeline already exists, reusing: %s", torch_path)
|
logger.info("Torch pipeline already exists, reusing: %s", torch_path)
|
||||||
else:
|
else:
|
||||||
logger.info("Converting original Diffusers check to Torch model: %s -> %s", source, torch_path)
|
logger.info("Converting original Diffusers check to Torch model: %s -> %s", source, torch_path)
|
||||||
extract_checkpoint(ctx, torch_name, source, from_hub=False)
|
extract_checkpoint(ctx, torch_name, source)
|
||||||
logger.info("Converted original Diffusers checkpoint to Torch model.")
|
logger.info("Converted original Diffusers checkpoint to Torch model.")
|
||||||
|
|
||||||
convert_diffusion_stable(ctx, model, working_name)
|
convert_diffusion_stable(ctx, model, working_name)
|
||||||
|
|
|
@ -147,6 +147,9 @@ def source_format(model: Dict) -> Optional[str]:
|
||||||
|
|
||||||
|
|
||||||
class Config(object):
|
class Config(object):
|
||||||
|
"""
|
||||||
|
Shim for pydantic-style config.
|
||||||
|
"""
|
||||||
def __init__(self, kwargs):
|
def __init__(self, kwargs):
|
||||||
self.__dict__.update(kwargs)
|
self.__dict__.update(kwargs)
|
||||||
for k, v in self.__dict__.items():
|
for k, v in self.__dict__.items():
|
||||||
|
|
Loading…
Reference in New Issue