lint(api): fix some Sonar issues
This commit is contained in:
parent
454abcdddc
commit
157067b554
|
@ -15,6 +15,7 @@ from .upscale_resrgan import convert_upscale_resrgan
|
|||
from .utils import (
|
||||
ConversionContext,
|
||||
download_progress,
|
||||
model_formats_original,
|
||||
source_format,
|
||||
tuple_to_correction,
|
||||
tuple_to_diffusion,
|
||||
|
@ -103,12 +104,12 @@ base_models: Models = {
|
|||
|
||||
|
||||
def fetch_model(
|
||||
ctx: ConversionContext, name: str, source: str, format: Optional[str] = None
|
||||
ctx: ConversionContext, name: str, source: str, model_format: Optional[str] = None
|
||||
) -> str:
|
||||
cache_name = path.join(ctx.cache_path, name)
|
||||
if format is not None:
|
||||
if model_format is not None:
|
||||
# add an extension if possible, some of the conversion code checks for it
|
||||
cache_name = "%s.%s" % (cache_name, format)
|
||||
cache_name = "%s.%s" % (cache_name, model_format)
|
||||
|
||||
for proto in model_sources:
|
||||
api_name, api_root = model_sources.get(proto)
|
||||
|
@ -147,10 +148,12 @@ def convert_models(ctx: ConversionContext, args, models: Models):
|
|||
if name in args.skip:
|
||||
logger.info("Skipping model: %s", name)
|
||||
else:
|
||||
format = source_format(model)
|
||||
source = fetch_model(ctx, name, model["source"], format=format)
|
||||
model_format = source_format(model)
|
||||
source = fetch_model(
|
||||
ctx, name, model["source"], model_format=model_format
|
||||
)
|
||||
|
||||
if format in ["safetensors", "ckpt"]:
|
||||
if model_format in model_formats_original:
|
||||
convert_diffusion_original(
|
||||
ctx,
|
||||
model,
|
||||
|
@ -171,8 +174,10 @@ def convert_models(ctx: ConversionContext, args, models: Models):
|
|||
if name in args.skip:
|
||||
logger.info("Skipping model: %s", name)
|
||||
else:
|
||||
format = source_format(model)
|
||||
source = fetch_model(ctx, name, model["source"], format=format)
|
||||
model_format = source_format(model)
|
||||
source = fetch_model(
|
||||
ctx, name, model["source"], model_format=model_format
|
||||
)
|
||||
convert_upscale_resrgan(ctx, model, source)
|
||||
|
||||
if args.correction:
|
||||
|
@ -183,8 +188,10 @@ def convert_models(ctx: ConversionContext, args, models: Models):
|
|||
if name in args.skip:
|
||||
logger.info("Skipping model: %s", name)
|
||||
else:
|
||||
format = source_format(model)
|
||||
source = fetch_model(ctx, name, model["source"], format=format)
|
||||
model_format = source_format(model)
|
||||
source = fetch_model(
|
||||
ctx, name, model["source"], model_format=model_format
|
||||
)
|
||||
convert_correction_gfpgan(ctx, model, source)
|
||||
|
||||
|
||||
|
|
|
@ -7,8 +7,6 @@
|
|||
#
|
||||
# d8ahazard portions do not include a license header or file
|
||||
# HuggingFace portions used under the Apache License, Version 2.0
|
||||
#
|
||||
# TODO: ask about license before merging
|
||||
###
|
||||
|
||||
import json
|
||||
|
@ -59,6 +57,10 @@ logger = getLogger(__name__)
|
|||
|
||||
|
||||
class TrainingConfig():
|
||||
"""
|
||||
From https://github.com/d8ahazard/sd_dreambooth_extension/blob/main/dreambooth/db_config.py
|
||||
"""
|
||||
|
||||
adamw_weight_decay: float = 0.01
|
||||
attention: str = "default"
|
||||
cache_latents: bool = True
|
||||
|
@ -1314,7 +1316,7 @@ def extract_checkpoint(
|
|||
elif scheduler_type == "dpm":
|
||||
scheduler = DPMSolverMultistepScheduler.from_config(scheduler.config)
|
||||
elif scheduler_type == "ddim":
|
||||
scheduler = scheduler
|
||||
pass
|
||||
else:
|
||||
raise ValueError(f"Scheduler of type {scheduler_type} doesn't exist!")
|
||||
|
||||
|
|
|
@ -130,7 +130,8 @@ def tuple_to_upscaling(model: Union[ModelDict, LegacyModel]):
|
|||
return model
|
||||
|
||||
|
||||
known_formats = ["onnx", "pth", "ckpt", "safetensors"]
|
||||
model_formats = ["onnx", "pth", "ckpt", "safetensors"]
|
||||
model_formats_original = ["ckpt", "safetensors"]
|
||||
|
||||
|
||||
def source_format(model: Dict) -> Optional[str]:
|
||||
|
@ -139,7 +140,7 @@ def source_format(model: Dict) -> Optional[str]:
|
|||
|
||||
if "source" in model:
|
||||
ext = path.splitext(model["source"])
|
||||
if ext in known_formats:
|
||||
if ext in model_formats:
|
||||
return ext
|
||||
|
||||
return None
|
||||
|
|
Loading…
Reference in New Issue