1
0
Fork 0

correctly handle falsey args in with_args methods

This commit is contained in:
Sean Sube 2024-01-30 09:28:43 -06:00
parent c2d45b03dc
commit f70de1ca79
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
4 changed files with 48 additions and 32 deletions

View File

@ -11,7 +11,7 @@ from ..convert.utils import resolve_tensor
from ..params import Border, HighresParams, ImageParams, Size, UpscaleParams from ..params import Border, HighresParams, ImageParams, Size, UpscaleParams
from ..server.context import ServerContext from ..server.context import ServerContext
from ..server.load import get_extra_hashes from ..server.load import get_extra_hashes
from ..utils import hash_file, load_config_str from ..utils import coalesce, hash_file, load_config_str
logger = getLogger(__name__) logger = getLogger(__name__)
@ -269,13 +269,13 @@ class ImageMetadata:
return ImageMetadata( return ImageMetadata(
params or self.params, params or self.params,
size or self.size, size or self.size,
upscale=upscale or self.upscale, upscale=coalesce(upscale, self.upscale),
border=border or self.border, border=coalesce(border, self.border),
highres=highres or self.highres, highres=coalesce(highres, self.highres),
inversions=inversions or self.inversions, inversions=coalesce(inversions, self.inversions),
loras=loras or self.loras, loras=coalesce(loras, self.loras),
models=models or self.models, models=coalesce(models, self.models),
ancestors=ancestors or self.ancestors, ancestors=coalesce(ancestors, self.ancestors),
) )
@staticmethod @staticmethod

View File

@ -73,10 +73,14 @@ class UpscaleSwinIRStage(BaseStage):
upscale = upscale.with_args(**kwargs) upscale = upscale.with_args(**kwargs)
if upscale.upscale_model is None: if upscale.upscale_model is None:
logger.warning("no correction model given, skipping") logger.warning("no upscale model given, skipping")
return sources return sources
logger.info("correcting faces with SwinIR model: %s", upscale.upscale_model) logger.info(
"upscaling %sx with SwinIR model: %s",
upscale.outscale,
upscale.upscale_model,
)
device = worker.get_device() device = worker.get_device()
swinir = self.load(server, stage, upscale, device) swinir = self.load(server, stage, upscale, device)

View File

@ -5,6 +5,7 @@ from typing import Any, Dict, List, Literal, Optional, Tuple, Union
from .models.meta import NetworkModel from .models.meta import NetworkModel
from .torch_before_ort import GraphOptimizationLevel, SessionOptions from .torch_before_ort import GraphOptimizationLevel, SessionOptions
from .utils import coalesce
logger = getLogger(__name__) logger = getLogger(__name__)
@ -417,10 +418,10 @@ class StageParams:
): ):
logger.debug("ignoring extra kwargs for stage: %s", kwargs) logger.debug("ignoring extra kwargs for stage: %s", kwargs)
return StageParams( return StageParams(
name=name or self.name, name=coalesce(name, self.name),
outscale=outscale or self.outscale, outscale=coalesce(outscale, self.outscale),
tile_order=tile_order or self.tile_order, tile_order=coalesce(tile_order, self.tile_order),
tile_size=tile_size or self.tile_size, tile_size=coalesce(tile_size, self.tile_size),
) )
@ -516,18 +517,18 @@ class UpscaleParams:
): ):
logger.debug("ignoring extra kwargs for upscale: %s", kwargs) logger.debug("ignoring extra kwargs for upscale: %s", kwargs)
return UpscaleParams( return UpscaleParams(
upscale_model=upscale_model or self.upscale_model, upscale_model=coalesce(upscale_model, self.upscale_model),
correction_model=correction_model or self.correction_model, correction_model=coalesce(correction_model, self.correction_model),
denoise=denoise or self.denoise, denoise=coalesce(denoise, self.denoise),
upscale=upscale or self.upscale, upscale=coalesce(upscale, self.upscale),
faces=faces or self.faces, faces=coalesce(faces, self.faces),
face_outscale=face_outscale or self.face_outscale, face_outscale=coalesce(face_outscale, self.face_outscale),
face_strength=face_strength or self.face_strength, face_strength=coalesce(face_strength, self.face_strength),
outscale=outscale or self.outscale, outscale=coalesce(outscale, self.outscale),
scale=scale or self.scale, scale=coalesce(scale, self.scale),
pre_pad=pre_pad or self.pre_pad, pre_pad=coalesce(pre_pad, self.pre_pad),
tile_pad=tile_pad or self.tile_pad, tile_pad=coalesce(tile_pad, self.tile_pad),
upscale_order=upscale_order or self.upscale_order, upscale_order=coalesce(upscale_order, self.upscale_order),
) )
@ -583,10 +584,10 @@ class HighresParams:
): ):
logger.debug("ignoring extra kwargs for highres: %s", kwargs) logger.debug("ignoring extra kwargs for highres: %s", kwargs)
return HighresParams( return HighresParams(
enabled=enabled or self.enabled, enabled=coalesce(enabled, self.enabled),
scale=scale or self.scale, scale=coalesce(scale, self.scale),
steps=steps or self.steps, steps=coalesce(steps, self.steps),
strength=strength or self.strength, strength=coalesce(strength, self.strength),
method=method or self.method, method=coalesce(method, self.method),
iterations=iterations or self.iterations, iterations=coalesce(iterations, self.iterations),
) )

View File

@ -251,3 +251,14 @@ def hash_value(sha, param: Optional[Param]):
sha.update(param.encode("utf-8")) sha.update(param.encode("utf-8"))
else: else:
logger.warning("cannot hash param: %s, %s", param, type(param)) logger.warning("cannot hash param: %s, %s", param, type(param))
def coalesce(*args, throw=False):
for arg in args:
if arg is not None:
return arg
if throw:
raise ValueError("no value found")
return None