
lint(api): remove unused hub download code path

Author: Sean Sube, 2023-02-11 14:36:17 -06:00
Parent: 157067b554
Commit: 5eae80cbdd
Signed by: ssube (GPG Key ID: 3EED7B957D362AF1)
2 changed files with 40 additions and 74 deletions

Changed file 1 of 2 (filename not preserved in this view):

@@ -1135,9 +1135,6 @@ def extract_checkpoint(
     new_model_name: str,
     checkpoint_file: str,
     scheduler_type="ddim",
-    from_hub=False,
-    new_model_url="",
-    new_model_token="",
     extract_ema=False,
     train_unfrozen=False,
     is_512=True,
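With the hub download path gone, extract_checkpoint always reads a local checkpoint file. For orientation, a sketch of the signature after this hunk, assembled from the kept lines; the leading ctx parameter is inferred from the call site further down, and parameters outside the hunk are elided:

    def extract_checkpoint(
        ctx,
        new_model_name: str,
        checkpoint_file: str,
        scheduler_type="ddim",
        extract_ema=False,
        train_unfrozen=False,
        is_512=True,
        # ... remaining parameters are outside this hunk
    ):
        ...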
@@ -1177,81 +1174,51 @@ def extract_checkpoint(
     # Create empty config
     db_config = TrainingConfig(ctx, model_name=new_model_name, scheduler=scheduler_type,
-                               src=checkpoint_file if not from_hub else new_model_url)
+                               src=checkpoint_file)
     original_config_file = None
 
-    # Okay then. So, if it's from the hub, try to download it
-    if from_hub:
-        model_info, config = download_model(db_config, new_model_token)
-        if db_config is not None:
-            original_config_file = config
-        if model_info is not None:
-            logger.debug("Got model info.")
-            if ".ckpt" in model_info or ".safetensors" in model_info:
-                # Set this to false, because we have a checkpoint where we can *maybe* get a revision.
-                from_hub = False
-                db_config.src = model_info
-                checkpoint_file = model_info
-        else:
-            msg = "Unable to fetch model from hub."
-            logger.warning(msg)
-            return
-
     try:
         checkpoint = None
         map_location = torch.device("cpu")
 
         # Try to determine if v1 or v2 model if we have a ckpt
-        if not from_hub:
-            logger.info("Loading model from checkpoint.")
-            _, extension = os.path.splitext(checkpoint_file)
-            if extension.lower() == ".safetensors":
-                os.environ["SAFETENSORS_FAST_GPU"] = "1"
-                try:
-                    logger.debug("Loading safetensors...")
-                    checkpoint = safetensors.torch.load_file(checkpoint_file, device="cpu")
-                except Exception as e:
-                    logger.warn("Failed to load as safetensors file, falling back to torch...", e)
-                    checkpoint = torch.jit.load(checkpoint_file)
-            else:
-                logger.debug("Loading ckpt...")
-                checkpoint = torch.load(checkpoint_file, map_location=map_location)
-                checkpoint = checkpoint["state_dict"] if "state_dict" in checkpoint else checkpoint
-            rev_keys = ["db_global_step", "global_step"]
-            epoch_keys = ["db_epoch", "epoch"]
-            for key in rev_keys:
-                if key in checkpoint:
-                    revision = checkpoint[key]
-                    break
-            for key in epoch_keys:
-                if key in checkpoint:
-                    epoch = checkpoint[key]
-                    break
-            key_name = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight"
-            if key_name in checkpoint and checkpoint[key_name].shape[-1] == 1024:
-                if not is_512:
-                    # v2.1 needs to upcast attention
-                    logger.debug("Setting upcast_attention")
-                    upcast_attention = True
-                v2 = True
-            else:
-                v2 = False
-        else:
-            unet_dir = os.path.join(db_config.pretrained_model_name_or_path, "unet")
-            try:
-                unet = UNet2DConditionModel.from_pretrained(unet_dir)
-                logger.debug("Loaded unet.")
-                unet_dict = unet.state_dict()
-                key_name = "down_blocks.2.attentions.1.transformer_blocks.0.attn2.to_k.weight"
-                if key_name in unet_dict and unet_dict[key_name].shape[-1] == 1024:
-                    logger.debug("UNet using v2 parameters.")
-                    v2 = True
-            except Exception:
-                logger.error("Exception loading unet!", traceback.format_exception(*sys.exc_info()))
+        logger.info("Loading model from checkpoint.")
+        _, extension = os.path.splitext(checkpoint_file)
+        if extension.lower() == ".safetensors":
+            os.environ["SAFETENSORS_FAST_GPU"] = "1"
+            try:
+                logger.debug("Loading safetensors...")
+                checkpoint = safetensors.torch.load_file(checkpoint_file, device="cpu")
+            except Exception as e:
+                logger.warn("Failed to load as safetensors file, falling back to torch...", e)
+                checkpoint = torch.jit.load(checkpoint_file)
+        else:
+            logger.debug("Loading ckpt...")
+            checkpoint = torch.load(checkpoint_file, map_location=map_location)
+            checkpoint = checkpoint["state_dict"] if "state_dict" in checkpoint else checkpoint
+        rev_keys = ["db_global_step", "global_step"]
+        epoch_keys = ["db_epoch", "epoch"]
+        for key in rev_keys:
+            if key in checkpoint:
+                revision = checkpoint[key]
+                break
+        for key in epoch_keys:
+            if key in checkpoint:
+                epoch = checkpoint[key]
+                break
+        key_name = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight"
+        if key_name in checkpoint and checkpoint[key_name].shape[-1] == 1024:
+            if not is_512:
+                # v2.1 needs to upcast attention
+                logger.debug("Setting upcast_attention")
+                upcast_attention = True
+            v2 = True
+        else:
+            v2 = False
 
         if v2 and not is_512:
             prediction_type = "v_prediction"
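The loading logic that survives this hunk is a useful pattern on its own: prefer the safetensors loader for .safetensors files, fall back to torch on failure, and unwrap a nested state_dict. A minimal standalone sketch under the same assumptions; load_checkpoint and is_v2_checkpoint are hypothetical names, and the v1/v2 distinction (768- vs. 1024-wide cross-attention keys) matches the shape check above:

    import os
    from logging import getLogger

    import safetensors.torch
    import torch

    logger = getLogger(__name__)

    def load_checkpoint(checkpoint_file: str):
        # Hypothetical helper mirroring the diff: pick a loader by extension.
        _, extension = os.path.splitext(checkpoint_file)
        if extension.lower() == ".safetensors":
            try:
                # safetensors reads raw tensors without unpickling arbitrary code
                return safetensors.torch.load_file(checkpoint_file, device="cpu")
            except Exception:
                # The diff falls back to torch.jit.load when safetensors fails
                logger.warning("Failed to load as safetensors, falling back to torch...")
                return torch.jit.load(checkpoint_file)
        # Legacy pickle-based checkpoint; unwrap the training wrapper if present
        checkpoint = torch.load(checkpoint_file, map_location=torch.device("cpu"))
        return checkpoint.get("state_dict", checkpoint)

    def is_v2_checkpoint(checkpoint) -> bool:
        # SD v2 uses 1024-wide cross-attention keys (OpenCLIP); v1 uses 768
        key = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight"
        return key in checkpoint and checkpoint[key].shape[-1] == 1024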
@@ -1265,10 +1232,6 @@ def extract_checkpoint(
     db_config.lifetime_revision = revision
     db_config.epoch = epoch
     db_config.v2 = v2
-    if from_hub:
-        result_status = "Model fetched from hub."
-        db_config.save()
-        return
 
     logger.info(f"{'v2' if v2 else 'v1'} model loaded.")
@@ -1457,7 +1420,7 @@ def convert_diffusion_original(
         logger.info("Torch pipeline already exists, reusing: %s", torch_path)
     else:
         logger.info("Converting original Diffusers check to Torch model: %s -> %s", source, torch_path)
-        extract_checkpoint(ctx, torch_name, source, from_hub=False)
+        extract_checkpoint(ctx, torch_name, source)
         logger.info("Converted original Diffusers checkpoint to Torch model.")
 
     convert_diffusion_stable(ctx, model, working_name)
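One consequence of dropping the parameter, sketched below: any caller still passing the old keyword fails immediately with a TypeError, which is why the call site above was updated in the same commit.

    # Before this commit:
    extract_checkpoint(ctx, torch_name, source, from_hub=False)

    # After: the keyword no longer exists, so the old call would raise
    # TypeError: extract_checkpoint() got an unexpected keyword argument 'from_hub'
    extract_checkpoint(ctx, torch_name, source)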

Changed file 2 of 2 (filename not preserved in this view):

@@ -147,6 +147,9 @@ def source_format(model: Dict) -> Optional[str]:
 class Config(object):
+    """
+    Shim for pydantic-style config.
+    """
     def __init__(self, kwargs):
         self.__dict__.update(kwargs)
         for k, v in self.__dict__.items():
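The shim makes a plain dict readable through attribute access, in the style of a pydantic model, without the dependency. A hypothetical usage sketch; the loop over __dict__.items() is truncated in this hunk, so nested-dict behavior is not shown:

    # Keys of the passed dict become attributes on the instance
    config = Config({"scheduler": "ddim", "v2": False})
    print(config.scheduler)  # "ddim"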