lint(api): remove unused hub download code path
This commit is contained in:
parent
157067b554
commit
5eae80cbdd
|
@ -1135,9 +1135,6 @@ def extract_checkpoint(
|
||||||
new_model_name: str,
|
new_model_name: str,
|
||||||
checkpoint_file: str,
|
checkpoint_file: str,
|
||||||
scheduler_type="ddim",
|
scheduler_type="ddim",
|
||||||
from_hub=False,
|
|
||||||
new_model_url="",
|
|
||||||
new_model_token="",
|
|
||||||
extract_ema=False,
|
extract_ema=False,
|
||||||
train_unfrozen=False,
|
train_unfrozen=False,
|
||||||
is_512=True,
|
is_512=True,
|
||||||
|
@ -1177,81 +1174,51 @@ def extract_checkpoint(
|
||||||
|
|
||||||
# Create empty config
|
# Create empty config
|
||||||
db_config = TrainingConfig(ctx, model_name=new_model_name, scheduler=scheduler_type,
|
db_config = TrainingConfig(ctx, model_name=new_model_name, scheduler=scheduler_type,
|
||||||
src=checkpoint_file if not from_hub else new_model_url)
|
src=checkpoint_file)
|
||||||
|
|
||||||
original_config_file = None
|
original_config_file = None
|
||||||
|
|
||||||
# Okay then. So, if it's from the hub, try to download it
|
|
||||||
if from_hub:
|
|
||||||
model_info, config = download_model(db_config, new_model_token)
|
|
||||||
if db_config is not None:
|
|
||||||
original_config_file = config
|
|
||||||
if model_info is not None:
|
|
||||||
logger.debug("Got model info.")
|
|
||||||
if ".ckpt" in model_info or ".safetensors" in model_info:
|
|
||||||
# Set this to false, because we have a checkpoint where we can *maybe* get a revision.
|
|
||||||
from_hub = False
|
|
||||||
db_config.src = model_info
|
|
||||||
checkpoint_file = model_info
|
|
||||||
else:
|
|
||||||
msg = "Unable to fetch model from hub."
|
|
||||||
logger.warning(msg)
|
|
||||||
return
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
checkpoint = None
|
checkpoint = None
|
||||||
map_location = torch.device("cpu")
|
map_location = torch.device("cpu")
|
||||||
|
|
||||||
# Try to determine if v1 or v2 model if we have a ckpt
|
# Try to determine if v1 or v2 model if we have a ckpt
|
||||||
if not from_hub:
|
logger.info("Loading model from checkpoint.")
|
||||||
logger.info("Loading model from checkpoint.")
|
_, extension = os.path.splitext(checkpoint_file)
|
||||||
_, extension = os.path.splitext(checkpoint_file)
|
if extension.lower() == ".safetensors":
|
||||||
if extension.lower() == ".safetensors":
|
os.environ["SAFETENSORS_FAST_GPU"] = "1"
|
||||||
os.environ["SAFETENSORS_FAST_GPU"] = "1"
|
|
||||||
try:
|
|
||||||
logger.debug("Loading safetensors...")
|
|
||||||
checkpoint = safetensors.torch.load_file(checkpoint_file, device="cpu")
|
|
||||||
except Exception as e:
|
|
||||||
logger.warn("Failed to load as safetensors file, falling back to torch...", e)
|
|
||||||
checkpoint = torch.jit.load(checkpoint_file)
|
|
||||||
else:
|
|
||||||
logger.debug("Loading ckpt...")
|
|
||||||
checkpoint = torch.load(checkpoint_file, map_location=map_location)
|
|
||||||
checkpoint = checkpoint["state_dict"] if "state_dict" in checkpoint else checkpoint
|
|
||||||
|
|
||||||
rev_keys = ["db_global_step", "global_step"]
|
|
||||||
epoch_keys = ["db_epoch", "epoch"]
|
|
||||||
for key in rev_keys:
|
|
||||||
if key in checkpoint:
|
|
||||||
revision = checkpoint[key]
|
|
||||||
break
|
|
||||||
|
|
||||||
for key in epoch_keys:
|
|
||||||
if key in checkpoint:
|
|
||||||
epoch = checkpoint[key]
|
|
||||||
break
|
|
||||||
|
|
||||||
key_name = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight"
|
|
||||||
if key_name in checkpoint and checkpoint[key_name].shape[-1] == 1024:
|
|
||||||
if not is_512:
|
|
||||||
# v2.1 needs to upcast attention
|
|
||||||
logger.debug("Setting upcast_attention")
|
|
||||||
upcast_attention = True
|
|
||||||
v2 = True
|
|
||||||
else:
|
|
||||||
v2 = False
|
|
||||||
else:
|
|
||||||
unet_dir = os.path.join(db_config.pretrained_model_name_or_path, "unet")
|
|
||||||
try:
|
try:
|
||||||
unet = UNet2DConditionModel.from_pretrained(unet_dir)
|
logger.debug("Loading safetensors...")
|
||||||
logger.debug("Loaded unet.")
|
checkpoint = safetensors.torch.load_file(checkpoint_file, device="cpu")
|
||||||
unet_dict = unet.state_dict()
|
except Exception as e:
|
||||||
key_name = "down_blocks.2.attentions.1.transformer_blocks.0.attn2.to_k.weight"
|
logger.warn("Failed to load as safetensors file, falling back to torch...", e)
|
||||||
if key_name in unet_dict and unet_dict[key_name].shape[-1] == 1024:
|
checkpoint = torch.jit.load(checkpoint_file)
|
||||||
logger.debug("UNet using v2 parameters.")
|
else:
|
||||||
v2 = True
|
logger.debug("Loading ckpt...")
|
||||||
except Exception:
|
checkpoint = torch.load(checkpoint_file, map_location=map_location)
|
||||||
logger.error("Exception loading unet!", traceback.format_exception(*sys.exc_info()))
|
checkpoint = checkpoint["state_dict"] if "state_dict" in checkpoint else checkpoint
|
||||||
|
|
||||||
|
rev_keys = ["db_global_step", "global_step"]
|
||||||
|
epoch_keys = ["db_epoch", "epoch"]
|
||||||
|
for key in rev_keys:
|
||||||
|
if key in checkpoint:
|
||||||
|
revision = checkpoint[key]
|
||||||
|
break
|
||||||
|
|
||||||
|
for key in epoch_keys:
|
||||||
|
if key in checkpoint:
|
||||||
|
epoch = checkpoint[key]
|
||||||
|
break
|
||||||
|
|
||||||
|
key_name = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight"
|
||||||
|
if key_name in checkpoint and checkpoint[key_name].shape[-1] == 1024:
|
||||||
|
if not is_512:
|
||||||
|
# v2.1 needs to upcast attention
|
||||||
|
logger.debug("Setting upcast_attention")
|
||||||
|
upcast_attention = True
|
||||||
|
v2 = True
|
||||||
|
else:
|
||||||
|
v2 = False
|
||||||
|
|
||||||
if v2 and not is_512:
|
if v2 and not is_512:
|
||||||
prediction_type = "v_prediction"
|
prediction_type = "v_prediction"
|
||||||
|
@ -1265,10 +1232,6 @@ def extract_checkpoint(
|
||||||
db_config.lifetime_revision = revision
|
db_config.lifetime_revision = revision
|
||||||
db_config.epoch = epoch
|
db_config.epoch = epoch
|
||||||
db_config.v2 = v2
|
db_config.v2 = v2
|
||||||
if from_hub:
|
|
||||||
result_status = "Model fetched from hub."
|
|
||||||
db_config.save()
|
|
||||||
return
|
|
||||||
|
|
||||||
logger.info(f"{'v2' if v2 else 'v1'} model loaded.")
|
logger.info(f"{'v2' if v2 else 'v1'} model loaded.")
|
||||||
|
|
||||||
|
@ -1457,7 +1420,7 @@ def convert_diffusion_original(
|
||||||
logger.info("Torch pipeline already exists, reusing: %s", torch_path)
|
logger.info("Torch pipeline already exists, reusing: %s", torch_path)
|
||||||
else:
|
else:
|
||||||
logger.info("Converting original Diffusers check to Torch model: %s -> %s", source, torch_path)
|
logger.info("Converting original Diffusers check to Torch model: %s -> %s", source, torch_path)
|
||||||
extract_checkpoint(ctx, torch_name, source, from_hub=False)
|
extract_checkpoint(ctx, torch_name, source)
|
||||||
logger.info("Converted original Diffusers checkpoint to Torch model.")
|
logger.info("Converted original Diffusers checkpoint to Torch model.")
|
||||||
|
|
||||||
convert_diffusion_stable(ctx, model, working_name)
|
convert_diffusion_stable(ctx, model, working_name)
|
||||||
|
|
|
@ -147,6 +147,9 @@ def source_format(model: Dict) -> Optional[str]:
|
||||||
|
|
||||||
|
|
||||||
class Config(object):
|
class Config(object):
|
||||||
|
"""
|
||||||
|
Shim for pydantic-style config.
|
||||||
|
"""
|
||||||
def __init__(self, kwargs):
|
def __init__(self, kwargs):
|
||||||
self.__dict__.update(kwargs)
|
self.__dict__.update(kwargs)
|
||||||
for k, v in self.__dict__.items():
|
for k, v in self.__dict__.items():
|
||||||
|
|
Loading…
Reference in New Issue