
lint(api): fix some Sonar issues

Sean Sube 2023-02-11 14:19:42 -06:00
parent 454abcdddc
commit 157067b554
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
3 changed files with 25 additions and 15 deletions

View File

@@ -15,6 +15,7 @@ from .upscale_resrgan import convert_upscale_resrgan
 from .utils import (
     ConversionContext,
     download_progress,
+    model_formats_original,
     source_format,
     tuple_to_correction,
     tuple_to_diffusion,
@@ -103,12 +104,12 @@ base_models: Models = {
 
 def fetch_model(
-    ctx: ConversionContext, name: str, source: str, format: Optional[str] = None
+    ctx: ConversionContext, name: str, source: str, model_format: Optional[str] = None
 ) -> str:
     cache_name = path.join(ctx.cache_path, name)
 
-    if format is not None:
+    if model_format is not None:
         # add an extension if possible, some of the conversion code checks for it
-        cache_name = "%s.%s" % (cache_name, format)
+        cache_name = "%s.%s" % (cache_name, model_format)
 
     for proto in model_sources:
         api_name, api_root = model_sources.get(proto)
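
The rename clears Sonar's shadowed-builtin rule: a parameter named format hides Python's built-in format() for the entire function body. A minimal sketch of the problem the rename avoids (hypothetical example, not code from this repo):

# Hypothetical example of the shadowed-builtin issue; not repo code.
def cache_key(name: str, format: str) -> str:
    # Calling format(0.5, ".2f") here would raise
    # TypeError: 'str' object is not callable,
    # because format is now the str parameter.
    return "%s.%s" % (name, format)

# The renamed parameter leaves the built-in reachable:
def cache_key_fixed(name: str, model_format: str) -> str:
    return "%s.%s" % (name, model_format)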
@@ -147,10 +148,12 @@ def convert_models(ctx: ConversionContext, args, models: Models):
             if name in args.skip:
                 logger.info("Skipping model: %s", name)
             else:
-                format = source_format(model)
-                source = fetch_model(ctx, name, model["source"], format=format)
+                model_format = source_format(model)
+                source = fetch_model(
+                    ctx, name, model["source"], model_format=model_format
+                )
 
-                if format in ["safetensors", "ckpt"]:
+                if model_format in model_formats_original:
                     convert_diffusion_original(
                         ctx,
                         model,
@@ -171,8 +174,10 @@ def convert_models(ctx: ConversionContext, args, models: Models):
             if name in args.skip:
                 logger.info("Skipping model: %s", name)
             else:
-                format = source_format(model)
-                source = fetch_model(ctx, name, model["source"], format=format)
+                model_format = source_format(model)
+                source = fetch_model(
+                    ctx, name, model["source"], model_format=model_format
+                )
                 convert_upscale_resrgan(ctx, model, source)
 
     if args.correction:
@@ -183,8 +188,10 @@ def convert_models(ctx: ConversionContext, args, models: Models):
             if name in args.skip:
                 logger.info("Skipping model: %s", name)
             else:
-                format = source_format(model)
-                source = fetch_model(ctx, name, model["source"], format=format)
+                model_format = source_format(model)
+                source = fetch_model(
+                    ctx, name, model["source"], model_format=model_format
+                )
                 convert_correction_gfpgan(ctx, model, source)
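
All three call sites in convert_models now pass the renamed keyword; a usage sketch, assuming a ConversionContext instance named ctx and placeholder model values:

# Usage sketch; the model name and URL are placeholders.
model = {"name": "example", "source": "https://example.com/example.safetensors"}
model_format = source_format(model)  # e.g. "safetensors"
source = fetch_model(
    ctx, model["name"], model["source"], model_format=model_format
)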

View File

@@ -7,8 +7,6 @@
 #
 # d8ahazard portions do not include a license header or file
 # HuggingFace portions used under the Apache License, Version 2.0
-#
-# TODO: ask about license before merging
 ###
 
 import json
@@ -59,6 +57,10 @@ logger = getLogger(__name__)
 
 
 class TrainingConfig():
+    """
+    From https://github.com/d8ahazard/sd_dreambooth_extension/blob/main/dreambooth/db_config.py
+    """
+
     adamw_weight_decay: float = 0.01
     attention: str = "default"
     cache_latents: bool = True
@@ -1314,7 +1316,7 @@ def extract_checkpoint(
     elif scheduler_type == "dpm":
         scheduler = DPMSolverMultistepScheduler.from_config(scheduler.config)
     elif scheduler_type == "ddim":
-        scheduler = scheduler
+        pass
     else:
         raise ValueError(f"Scheduler of type {scheduler_type} doesn't exist!")
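
The ddim branch was a self-assignment, which Sonar reports as a no-op; pass keeps the branch explicit so unknown types still fall through to the ValueError. A reduced, hypothetical sketch of the same dispatch shape:

# Hypothetical reduction of the dispatch; the real code builds
# diffusers schedulers from a config.
def pick(scheduler_type: str, prepared: str) -> str:
    if scheduler_type == "dpm":
        prepared = "dpm:" + prepared
    elif scheduler_type == "ddim":
        pass  # prepared is already correct; 'prepared = prepared' did nothing
    else:
        raise ValueError(f"Scheduler of type {scheduler_type} doesn't exist!")
    return prepared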

View File

@@ -130,7 +130,8 @@ def tuple_to_upscaling(model: Union[ModelDict, LegacyModel]):
     return model
 
 
-known_formats = ["onnx", "pth", "ckpt", "safetensors"]
+model_formats = ["onnx", "pth", "ckpt", "safetensors"]
+model_formats_original = ["ckpt", "safetensors"]
 
 
 def source_format(model: Dict) -> Optional[str]:
@@ -139,7 +140,7 @@ def source_format(model: Dict) -> Optional[str]:
     if "source" in model:
         ext = path.splitext(model["source"])
-        if ext in known_formats:
+        if ext in model_formats:
             return ext
 
     return None
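
Defining both lists side by side keeps the membership checks in convert_models and source_format in sync; the intended invariant, as a quick sketch:

# Every original (pre-conversion) format is also a recognized model
# format, so anything routed to convert_diffusion_original still
# passes the extension check in source_format.
assert set(model_formats_original) <= set(model_formats)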