correctly handle falsey args in with_args methods
This commit is contained in:
parent
c2d45b03dc
commit
f70de1ca79
|
@ -11,7 +11,7 @@ from ..convert.utils import resolve_tensor
|
|||
from ..params import Border, HighresParams, ImageParams, Size, UpscaleParams
|
||||
from ..server.context import ServerContext
|
||||
from ..server.load import get_extra_hashes
|
||||
from ..utils import hash_file, load_config_str
|
||||
from ..utils import coalesce, hash_file, load_config_str
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
|
@ -269,13 +269,13 @@ class ImageMetadata:
|
|||
return ImageMetadata(
|
||||
params or self.params,
|
||||
size or self.size,
|
||||
upscale=upscale or self.upscale,
|
||||
border=border or self.border,
|
||||
highres=highres or self.highres,
|
||||
inversions=inversions or self.inversions,
|
||||
loras=loras or self.loras,
|
||||
models=models or self.models,
|
||||
ancestors=ancestors or self.ancestors,
|
||||
upscale=coalesce(upscale, self.upscale),
|
||||
border=coalesce(border, self.border),
|
||||
highres=coalesce(highres, self.highres),
|
||||
inversions=coalesce(inversions, self.inversions),
|
||||
loras=coalesce(loras, self.loras),
|
||||
models=coalesce(models, self.models),
|
||||
ancestors=coalesce(ancestors, self.ancestors),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
|
|
|
@ -73,10 +73,14 @@ class UpscaleSwinIRStage(BaseStage):
|
|||
upscale = upscale.with_args(**kwargs)
|
||||
|
||||
if upscale.upscale_model is None:
|
||||
logger.warning("no correction model given, skipping")
|
||||
logger.warning("no upscale model given, skipping")
|
||||
return sources
|
||||
|
||||
logger.info("correcting faces with SwinIR model: %s", upscale.upscale_model)
|
||||
logger.info(
|
||||
"upscaling %sx with SwinIR model: %s",
|
||||
upscale.outscale,
|
||||
upscale.upscale_model,
|
||||
)
|
||||
device = worker.get_device()
|
||||
swinir = self.load(server, stage, upscale, device)
|
||||
|
||||
|
|
|
@ -5,6 +5,7 @@ from typing import Any, Dict, List, Literal, Optional, Tuple, Union
|
|||
|
||||
from .models.meta import NetworkModel
|
||||
from .torch_before_ort import GraphOptimizationLevel, SessionOptions
|
||||
from .utils import coalesce
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
|
@ -417,10 +418,10 @@ class StageParams:
|
|||
):
|
||||
logger.debug("ignoring extra kwargs for stage: %s", kwargs)
|
||||
return StageParams(
|
||||
name=name or self.name,
|
||||
outscale=outscale or self.outscale,
|
||||
tile_order=tile_order or self.tile_order,
|
||||
tile_size=tile_size or self.tile_size,
|
||||
name=coalesce(name, self.name),
|
||||
outscale=coalesce(outscale, self.outscale),
|
||||
tile_order=coalesce(tile_order, self.tile_order),
|
||||
tile_size=coalesce(tile_size, self.tile_size),
|
||||
)
|
||||
|
||||
|
||||
|
@ -516,18 +517,18 @@ class UpscaleParams:
|
|||
):
|
||||
logger.debug("ignoring extra kwargs for upscale: %s", kwargs)
|
||||
return UpscaleParams(
|
||||
upscale_model=upscale_model or self.upscale_model,
|
||||
correction_model=correction_model or self.correction_model,
|
||||
denoise=denoise or self.denoise,
|
||||
upscale=upscale or self.upscale,
|
||||
faces=faces or self.faces,
|
||||
face_outscale=face_outscale or self.face_outscale,
|
||||
face_strength=face_strength or self.face_strength,
|
||||
outscale=outscale or self.outscale,
|
||||
scale=scale or self.scale,
|
||||
pre_pad=pre_pad or self.pre_pad,
|
||||
tile_pad=tile_pad or self.tile_pad,
|
||||
upscale_order=upscale_order or self.upscale_order,
|
||||
upscale_model=coalesce(upscale_model, self.upscale_model),
|
||||
correction_model=coalesce(correction_model, self.correction_model),
|
||||
denoise=coalesce(denoise, self.denoise),
|
||||
upscale=coalesce(upscale, self.upscale),
|
||||
faces=coalesce(faces, self.faces),
|
||||
face_outscale=coalesce(face_outscale, self.face_outscale),
|
||||
face_strength=coalesce(face_strength, self.face_strength),
|
||||
outscale=coalesce(outscale, self.outscale),
|
||||
scale=coalesce(scale, self.scale),
|
||||
pre_pad=coalesce(pre_pad, self.pre_pad),
|
||||
tile_pad=coalesce(tile_pad, self.tile_pad),
|
||||
upscale_order=coalesce(upscale_order, self.upscale_order),
|
||||
)
|
||||
|
||||
|
||||
|
@ -583,10 +584,10 @@ class HighresParams:
|
|||
):
|
||||
logger.debug("ignoring extra kwargs for highres: %s", kwargs)
|
||||
return HighresParams(
|
||||
enabled=enabled or self.enabled,
|
||||
scale=scale or self.scale,
|
||||
steps=steps or self.steps,
|
||||
strength=strength or self.strength,
|
||||
method=method or self.method,
|
||||
iterations=iterations or self.iterations,
|
||||
enabled=coalesce(enabled, self.enabled),
|
||||
scale=coalesce(scale, self.scale),
|
||||
steps=coalesce(steps, self.steps),
|
||||
strength=coalesce(strength, self.strength),
|
||||
method=coalesce(method, self.method),
|
||||
iterations=coalesce(iterations, self.iterations),
|
||||
)
|
||||
|
|
|
@ -251,3 +251,14 @@ def hash_value(sha, param: Optional[Param]):
|
|||
sha.update(param.encode("utf-8"))
|
||||
else:
|
||||
logger.warning("cannot hash param: %s, %s", param, type(param))
|
||||
|
||||
|
||||
def coalesce(*args, throw=False):
|
||||
for arg in args:
|
||||
if arg is not None:
|
||||
return arg
|
||||
|
||||
if throw:
|
||||
raise ValueError("no value found")
|
||||
|
||||
return None
|
||||
|
|
Loading…
Reference in New Issue