lint(api): fix some Sonar issues
This commit is contained in:
parent
454abcdddc
commit
157067b554
|
@ -15,6 +15,7 @@ from .upscale_resrgan import convert_upscale_resrgan
|
||||||
from .utils import (
|
from .utils import (
|
||||||
ConversionContext,
|
ConversionContext,
|
||||||
download_progress,
|
download_progress,
|
||||||
|
model_formats_original,
|
||||||
source_format,
|
source_format,
|
||||||
tuple_to_correction,
|
tuple_to_correction,
|
||||||
tuple_to_diffusion,
|
tuple_to_diffusion,
|
||||||
|
@ -103,12 +104,12 @@ base_models: Models = {
|
||||||
|
|
||||||
|
|
||||||
def fetch_model(
|
def fetch_model(
|
||||||
ctx: ConversionContext, name: str, source: str, format: Optional[str] = None
|
ctx: ConversionContext, name: str, source: str, model_format: Optional[str] = None
|
||||||
) -> str:
|
) -> str:
|
||||||
cache_name = path.join(ctx.cache_path, name)
|
cache_name = path.join(ctx.cache_path, name)
|
||||||
if format is not None:
|
if model_format is not None:
|
||||||
# add an extension if possible, some of the conversion code checks for it
|
# add an extension if possible, some of the conversion code checks for it
|
||||||
cache_name = "%s.%s" % (cache_name, format)
|
cache_name = "%s.%s" % (cache_name, model_format)
|
||||||
|
|
||||||
for proto in model_sources:
|
for proto in model_sources:
|
||||||
api_name, api_root = model_sources.get(proto)
|
api_name, api_root = model_sources.get(proto)
|
||||||
|
@ -147,10 +148,12 @@ def convert_models(ctx: ConversionContext, args, models: Models):
|
||||||
if name in args.skip:
|
if name in args.skip:
|
||||||
logger.info("Skipping model: %s", name)
|
logger.info("Skipping model: %s", name)
|
||||||
else:
|
else:
|
||||||
format = source_format(model)
|
model_format = source_format(model)
|
||||||
source = fetch_model(ctx, name, model["source"], format=format)
|
source = fetch_model(
|
||||||
|
ctx, name, model["source"], model_format=model_format
|
||||||
|
)
|
||||||
|
|
||||||
if format in ["safetensors", "ckpt"]:
|
if model_format in model_formats_original:
|
||||||
convert_diffusion_original(
|
convert_diffusion_original(
|
||||||
ctx,
|
ctx,
|
||||||
model,
|
model,
|
||||||
|
@ -171,8 +174,10 @@ def convert_models(ctx: ConversionContext, args, models: Models):
|
||||||
if name in args.skip:
|
if name in args.skip:
|
||||||
logger.info("Skipping model: %s", name)
|
logger.info("Skipping model: %s", name)
|
||||||
else:
|
else:
|
||||||
format = source_format(model)
|
model_format = source_format(model)
|
||||||
source = fetch_model(ctx, name, model["source"], format=format)
|
source = fetch_model(
|
||||||
|
ctx, name, model["source"], model_format=model_format
|
||||||
|
)
|
||||||
convert_upscale_resrgan(ctx, model, source)
|
convert_upscale_resrgan(ctx, model, source)
|
||||||
|
|
||||||
if args.correction:
|
if args.correction:
|
||||||
|
@ -183,8 +188,10 @@ def convert_models(ctx: ConversionContext, args, models: Models):
|
||||||
if name in args.skip:
|
if name in args.skip:
|
||||||
logger.info("Skipping model: %s", name)
|
logger.info("Skipping model: %s", name)
|
||||||
else:
|
else:
|
||||||
format = source_format(model)
|
model_format = source_format(model)
|
||||||
source = fetch_model(ctx, name, model["source"], format=format)
|
source = fetch_model(
|
||||||
|
ctx, name, model["source"], model_format=model_format
|
||||||
|
)
|
||||||
convert_correction_gfpgan(ctx, model, source)
|
convert_correction_gfpgan(ctx, model, source)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -7,8 +7,6 @@
|
||||||
#
|
#
|
||||||
# d8ahazard portions do not include a license header or file
|
# d8ahazard portions do not include a license header or file
|
||||||
# HuggingFace portions used under the Apache License, Version 2.0
|
# HuggingFace portions used under the Apache License, Version 2.0
|
||||||
#
|
|
||||||
# TODO: ask about license before merging
|
|
||||||
###
|
###
|
||||||
|
|
||||||
import json
|
import json
|
||||||
|
@ -59,6 +57,10 @@ logger = getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class TrainingConfig():
|
class TrainingConfig():
|
||||||
|
"""
|
||||||
|
From https://github.com/d8ahazard/sd_dreambooth_extension/blob/main/dreambooth/db_config.py
|
||||||
|
"""
|
||||||
|
|
||||||
adamw_weight_decay: float = 0.01
|
adamw_weight_decay: float = 0.01
|
||||||
attention: str = "default"
|
attention: str = "default"
|
||||||
cache_latents: bool = True
|
cache_latents: bool = True
|
||||||
|
@ -1314,7 +1316,7 @@ def extract_checkpoint(
|
||||||
elif scheduler_type == "dpm":
|
elif scheduler_type == "dpm":
|
||||||
scheduler = DPMSolverMultistepScheduler.from_config(scheduler.config)
|
scheduler = DPMSolverMultistepScheduler.from_config(scheduler.config)
|
||||||
elif scheduler_type == "ddim":
|
elif scheduler_type == "ddim":
|
||||||
scheduler = scheduler
|
pass
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Scheduler of type {scheduler_type} doesn't exist!")
|
raise ValueError(f"Scheduler of type {scheduler_type} doesn't exist!")
|
||||||
|
|
||||||
|
|
|
@ -130,7 +130,8 @@ def tuple_to_upscaling(model: Union[ModelDict, LegacyModel]):
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
known_formats = ["onnx", "pth", "ckpt", "safetensors"]
|
model_formats = ["onnx", "pth", "ckpt", "safetensors"]
|
||||||
|
model_formats_original = ["ckpt", "safetensors"]
|
||||||
|
|
||||||
|
|
||||||
def source_format(model: Dict) -> Optional[str]:
|
def source_format(model: Dict) -> Optional[str]:
|
||||||
|
@ -139,7 +140,7 @@ def source_format(model: Dict) -> Optional[str]:
|
||||||
|
|
||||||
if "source" in model:
|
if "source" in model:
|
||||||
ext = path.splitext(model["source"])
|
ext = path.splitext(model["source"])
|
||||||
if ext in known_formats:
|
if ext in model_formats:
|
||||||
return ext
|
return ext
|
||||||
|
|
||||||
return None
|
return None
|
||||||
|
|
Loading…
Reference in New Issue