correctly handle falsey args in with_args methods
This commit is contained in:
parent
c2d45b03dc
commit
f70de1ca79
|
@ -11,7 +11,7 @@ from ..convert.utils import resolve_tensor
|
||||||
from ..params import Border, HighresParams, ImageParams, Size, UpscaleParams
|
from ..params import Border, HighresParams, ImageParams, Size, UpscaleParams
|
||||||
from ..server.context import ServerContext
|
from ..server.context import ServerContext
|
||||||
from ..server.load import get_extra_hashes
|
from ..server.load import get_extra_hashes
|
||||||
from ..utils import hash_file, load_config_str
|
from ..utils import coalesce, hash_file, load_config_str
|
||||||
|
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
@ -269,13 +269,13 @@ class ImageMetadata:
|
||||||
return ImageMetadata(
|
return ImageMetadata(
|
||||||
params or self.params,
|
params or self.params,
|
||||||
size or self.size,
|
size or self.size,
|
||||||
upscale=upscale or self.upscale,
|
upscale=coalesce(upscale, self.upscale),
|
||||||
border=border or self.border,
|
border=coalesce(border, self.border),
|
||||||
highres=highres or self.highres,
|
highres=coalesce(highres, self.highres),
|
||||||
inversions=inversions or self.inversions,
|
inversions=coalesce(inversions, self.inversions),
|
||||||
loras=loras or self.loras,
|
loras=coalesce(loras, self.loras),
|
||||||
models=models or self.models,
|
models=coalesce(models, self.models),
|
||||||
ancestors=ancestors or self.ancestors,
|
ancestors=coalesce(ancestors, self.ancestors),
|
||||||
)
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
|
|
@ -73,10 +73,14 @@ class UpscaleSwinIRStage(BaseStage):
|
||||||
upscale = upscale.with_args(**kwargs)
|
upscale = upscale.with_args(**kwargs)
|
||||||
|
|
||||||
if upscale.upscale_model is None:
|
if upscale.upscale_model is None:
|
||||||
logger.warning("no correction model given, skipping")
|
logger.warning("no upscale model given, skipping")
|
||||||
return sources
|
return sources
|
||||||
|
|
||||||
logger.info("correcting faces with SwinIR model: %s", upscale.upscale_model)
|
logger.info(
|
||||||
|
"upscaling %sx with SwinIR model: %s",
|
||||||
|
upscale.outscale,
|
||||||
|
upscale.upscale_model,
|
||||||
|
)
|
||||||
device = worker.get_device()
|
device = worker.get_device()
|
||||||
swinir = self.load(server, stage, upscale, device)
|
swinir = self.load(server, stage, upscale, device)
|
||||||
|
|
||||||
|
|
|
@ -5,6 +5,7 @@ from typing import Any, Dict, List, Literal, Optional, Tuple, Union
|
||||||
|
|
||||||
from .models.meta import NetworkModel
|
from .models.meta import NetworkModel
|
||||||
from .torch_before_ort import GraphOptimizationLevel, SessionOptions
|
from .torch_before_ort import GraphOptimizationLevel, SessionOptions
|
||||||
|
from .utils import coalesce
|
||||||
|
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
@ -417,10 +418,10 @@ class StageParams:
|
||||||
):
|
):
|
||||||
logger.debug("ignoring extra kwargs for stage: %s", kwargs)
|
logger.debug("ignoring extra kwargs for stage: %s", kwargs)
|
||||||
return StageParams(
|
return StageParams(
|
||||||
name=name or self.name,
|
name=coalesce(name, self.name),
|
||||||
outscale=outscale or self.outscale,
|
outscale=coalesce(outscale, self.outscale),
|
||||||
tile_order=tile_order or self.tile_order,
|
tile_order=coalesce(tile_order, self.tile_order),
|
||||||
tile_size=tile_size or self.tile_size,
|
tile_size=coalesce(tile_size, self.tile_size),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -516,18 +517,18 @@ class UpscaleParams:
|
||||||
):
|
):
|
||||||
logger.debug("ignoring extra kwargs for upscale: %s", kwargs)
|
logger.debug("ignoring extra kwargs for upscale: %s", kwargs)
|
||||||
return UpscaleParams(
|
return UpscaleParams(
|
||||||
upscale_model=upscale_model or self.upscale_model,
|
upscale_model=coalesce(upscale_model, self.upscale_model),
|
||||||
correction_model=correction_model or self.correction_model,
|
correction_model=coalesce(correction_model, self.correction_model),
|
||||||
denoise=denoise or self.denoise,
|
denoise=coalesce(denoise, self.denoise),
|
||||||
upscale=upscale or self.upscale,
|
upscale=coalesce(upscale, self.upscale),
|
||||||
faces=faces or self.faces,
|
faces=coalesce(faces, self.faces),
|
||||||
face_outscale=face_outscale or self.face_outscale,
|
face_outscale=coalesce(face_outscale, self.face_outscale),
|
||||||
face_strength=face_strength or self.face_strength,
|
face_strength=coalesce(face_strength, self.face_strength),
|
||||||
outscale=outscale or self.outscale,
|
outscale=coalesce(outscale, self.outscale),
|
||||||
scale=scale or self.scale,
|
scale=coalesce(scale, self.scale),
|
||||||
pre_pad=pre_pad or self.pre_pad,
|
pre_pad=coalesce(pre_pad, self.pre_pad),
|
||||||
tile_pad=tile_pad or self.tile_pad,
|
tile_pad=coalesce(tile_pad, self.tile_pad),
|
||||||
upscale_order=upscale_order or self.upscale_order,
|
upscale_order=coalesce(upscale_order, self.upscale_order),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -583,10 +584,10 @@ class HighresParams:
|
||||||
):
|
):
|
||||||
logger.debug("ignoring extra kwargs for highres: %s", kwargs)
|
logger.debug("ignoring extra kwargs for highres: %s", kwargs)
|
||||||
return HighresParams(
|
return HighresParams(
|
||||||
enabled=enabled or self.enabled,
|
enabled=coalesce(enabled, self.enabled),
|
||||||
scale=scale or self.scale,
|
scale=coalesce(scale, self.scale),
|
||||||
steps=steps or self.steps,
|
steps=coalesce(steps, self.steps),
|
||||||
strength=strength or self.strength,
|
strength=coalesce(strength, self.strength),
|
||||||
method=method or self.method,
|
method=coalesce(method, self.method),
|
||||||
iterations=iterations or self.iterations,
|
iterations=coalesce(iterations, self.iterations),
|
||||||
)
|
)
|
||||||
|
|
|
@ -251,3 +251,14 @@ def hash_value(sha, param: Optional[Param]):
|
||||||
sha.update(param.encode("utf-8"))
|
sha.update(param.encode("utf-8"))
|
||||||
else:
|
else:
|
||||||
logger.warning("cannot hash param: %s, %s", param, type(param))
|
logger.warning("cannot hash param: %s, %s", param, type(param))
|
||||||
|
|
||||||
|
|
||||||
|
def coalesce(*args, throw=False):
|
||||||
|
for arg in args:
|
||||||
|
if arg is not None:
|
||||||
|
return arg
|
||||||
|
|
||||||
|
if throw:
|
||||||
|
raise ValueError("no value found")
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
Loading…
Reference in New Issue