1
0
Fork 0

correctly handle falsey args in with_args methods

This commit is contained in:
Sean Sube 2024-01-30 09:28:43 -06:00
parent c2d45b03dc
commit f70de1ca79
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
4 changed files with 48 additions and 32 deletions

View File

@ -11,7 +11,7 @@ from ..convert.utils import resolve_tensor
from ..params import Border, HighresParams, ImageParams, Size, UpscaleParams
from ..server.context import ServerContext
from ..server.load import get_extra_hashes
from ..utils import hash_file, load_config_str
from ..utils import coalesce, hash_file, load_config_str
logger = getLogger(__name__)
@ -269,13 +269,13 @@ class ImageMetadata:
return ImageMetadata(
params or self.params,
size or self.size,
upscale=upscale or self.upscale,
border=border or self.border,
highres=highres or self.highres,
inversions=inversions or self.inversions,
loras=loras or self.loras,
models=models or self.models,
ancestors=ancestors or self.ancestors,
upscale=coalesce(upscale, self.upscale),
border=coalesce(border, self.border),
highres=coalesce(highres, self.highres),
inversions=coalesce(inversions, self.inversions),
loras=coalesce(loras, self.loras),
models=coalesce(models, self.models),
ancestors=coalesce(ancestors, self.ancestors),
)
@staticmethod

View File

@ -73,10 +73,14 @@ class UpscaleSwinIRStage(BaseStage):
upscale = upscale.with_args(**kwargs)
if upscale.upscale_model is None:
logger.warning("no correction model given, skipping")
logger.warning("no upscale model given, skipping")
return sources
logger.info("correcting faces with SwinIR model: %s", upscale.upscale_model)
logger.info(
"upscaling %sx with SwinIR model: %s",
upscale.outscale,
upscale.upscale_model,
)
device = worker.get_device()
swinir = self.load(server, stage, upscale, device)

View File

@ -5,6 +5,7 @@ from typing import Any, Dict, List, Literal, Optional, Tuple, Union
from .models.meta import NetworkModel
from .torch_before_ort import GraphOptimizationLevel, SessionOptions
from .utils import coalesce
logger = getLogger(__name__)
@ -417,10 +418,10 @@ class StageParams:
):
logger.debug("ignoring extra kwargs for stage: %s", kwargs)
return StageParams(
name=name or self.name,
outscale=outscale or self.outscale,
tile_order=tile_order or self.tile_order,
tile_size=tile_size or self.tile_size,
name=coalesce(name, self.name),
outscale=coalesce(outscale, self.outscale),
tile_order=coalesce(tile_order, self.tile_order),
tile_size=coalesce(tile_size, self.tile_size),
)
@ -516,18 +517,18 @@ class UpscaleParams:
):
logger.debug("ignoring extra kwargs for upscale: %s", kwargs)
return UpscaleParams(
upscale_model=upscale_model or self.upscale_model,
correction_model=correction_model or self.correction_model,
denoise=denoise or self.denoise,
upscale=upscale or self.upscale,
faces=faces or self.faces,
face_outscale=face_outscale or self.face_outscale,
face_strength=face_strength or self.face_strength,
outscale=outscale or self.outscale,
scale=scale or self.scale,
pre_pad=pre_pad or self.pre_pad,
tile_pad=tile_pad or self.tile_pad,
upscale_order=upscale_order or self.upscale_order,
upscale_model=coalesce(upscale_model, self.upscale_model),
correction_model=coalesce(correction_model, self.correction_model),
denoise=coalesce(denoise, self.denoise),
upscale=coalesce(upscale, self.upscale),
faces=coalesce(faces, self.faces),
face_outscale=coalesce(face_outscale, self.face_outscale),
face_strength=coalesce(face_strength, self.face_strength),
outscale=coalesce(outscale, self.outscale),
scale=coalesce(scale, self.scale),
pre_pad=coalesce(pre_pad, self.pre_pad),
tile_pad=coalesce(tile_pad, self.tile_pad),
upscale_order=coalesce(upscale_order, self.upscale_order),
)
@ -583,10 +584,10 @@ class HighresParams:
):
logger.debug("ignoring extra kwargs for highres: %s", kwargs)
return HighresParams(
enabled=enabled or self.enabled,
scale=scale or self.scale,
steps=steps or self.steps,
strength=strength or self.strength,
method=method or self.method,
iterations=iterations or self.iterations,
enabled=coalesce(enabled, self.enabled),
scale=coalesce(scale, self.scale),
steps=coalesce(steps, self.steps),
strength=coalesce(strength, self.strength),
method=coalesce(method, self.method),
iterations=coalesce(iterations, self.iterations),
)

View File

@ -251,3 +251,14 @@ def hash_value(sha, param: Optional[Param]):
sha.update(param.encode("utf-8"))
else:
logger.warning("cannot hash param: %s, %s", param, type(param))
def coalesce(*args, throw=False):
for arg in args:
if arg is not None:
return arg
if throw:
raise ValueError("no value found")
return None