diff --git a/api/onnx_web/convert/__main__.py b/api/onnx_web/convert/__main__.py index a11d706f..5d6923be 100644 --- a/api/onnx_web/convert/__main__.py +++ b/api/onnx_web/convert/__main__.py @@ -15,6 +15,7 @@ from .upscale_resrgan import convert_upscale_resrgan from .utils import ( ConversionContext, download_progress, + model_formats_original, source_format, tuple_to_correction, tuple_to_diffusion, @@ -103,12 +104,12 @@ base_models: Models = { def fetch_model( - ctx: ConversionContext, name: str, source: str, format: Optional[str] = None + ctx: ConversionContext, name: str, source: str, model_format: Optional[str] = None ) -> str: cache_name = path.join(ctx.cache_path, name) - if format is not None: + if model_format is not None: # add an extension if possible, some of the conversion code checks for it - cache_name = "%s.%s" % (cache_name, format) + cache_name = "%s.%s" % (cache_name, model_format) for proto in model_sources: api_name, api_root = model_sources.get(proto) @@ -147,10 +148,12 @@ def convert_models(ctx: ConversionContext, args, models: Models): if name in args.skip: logger.info("Skipping model: %s", name) else: - format = source_format(model) - source = fetch_model(ctx, name, model["source"], format=format) + model_format = source_format(model) + source = fetch_model( + ctx, name, model["source"], model_format=model_format + ) - if format in ["safetensors", "ckpt"]: + if model_format in model_formats_original: convert_diffusion_original( ctx, model, @@ -171,8 +174,10 @@ def convert_models(ctx: ConversionContext, args, models: Models): if name in args.skip: logger.info("Skipping model: %s", name) else: - format = source_format(model) - source = fetch_model(ctx, name, model["source"], format=format) + model_format = source_format(model) + source = fetch_model( + ctx, name, model["source"], model_format=model_format + ) convert_upscale_resrgan(ctx, model, source) if args.correction: @@ -183,8 +188,10 @@ def convert_models(ctx: ConversionContext, args, models: Models): if name in args.skip: logger.info("Skipping model: %s", name) else: - format = source_format(model) - source = fetch_model(ctx, name, model["source"], format=format) + model_format = source_format(model) + source = fetch_model( + ctx, name, model["source"], model_format=model_format + ) convert_correction_gfpgan(ctx, model, source) diff --git a/api/onnx_web/convert/diffusion_original.py b/api/onnx_web/convert/diffusion_original.py index 63a6575a..b28d89d3 100644 --- a/api/onnx_web/convert/diffusion_original.py +++ b/api/onnx_web/convert/diffusion_original.py @@ -7,8 +7,6 @@ # # d8ahazard portions do not include a license header or file # HuggingFace portions used under the Apache License, Version 2.0 -# -# TODO: ask about license before merging ### import json @@ -59,6 +57,10 @@ logger = getLogger(__name__) class TrainingConfig(): + """ + From https://github.com/d8ahazard/sd_dreambooth_extension/blob/main/dreambooth/db_config.py + """ + adamw_weight_decay: float = 0.01 attention: str = "default" cache_latents: bool = True @@ -1314,7 +1316,7 @@ def extract_checkpoint( elif scheduler_type == "dpm": scheduler = DPMSolverMultistepScheduler.from_config(scheduler.config) elif scheduler_type == "ddim": - scheduler = scheduler + pass else: raise ValueError(f"Scheduler of type {scheduler_type} doesn't exist!") diff --git a/api/onnx_web/convert/utils.py b/api/onnx_web/convert/utils.py index d36e3e2c..3fef6dea 100644 --- a/api/onnx_web/convert/utils.py +++ b/api/onnx_web/convert/utils.py @@ -130,7 +130,8 @@ def tuple_to_upscaling(model: Union[ModelDict, LegacyModel]): return model -known_formats = ["onnx", "pth", "ckpt", "safetensors"] +model_formats = ["onnx", "pth", "ckpt", "safetensors"] +model_formats_original = ["ckpt", "safetensors"] def source_format(model: Dict) -> Optional[str]: @@ -139,7 +140,7 @@ def source_format(model: Dict) -> Optional[str]: if "source" in model: ext = path.splitext(model["source"]) - if ext in known_formats: + if ext in model_formats: return ext return None