From f70de1ca79eee4c128226a9581e803b744f17064 Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Tue, 30 Jan 2024 09:28:43 -0600 Subject: [PATCH] correctly handle falsey args in with_args methods --- api/onnx_web/chain/result.py | 16 +++++----- api/onnx_web/chain/upscale_swinir.py | 8 +++-- api/onnx_web/params.py | 45 ++++++++++++++-------------- api/onnx_web/utils.py | 11 +++++++ 4 files changed, 48 insertions(+), 32 deletions(-) diff --git a/api/onnx_web/chain/result.py b/api/onnx_web/chain/result.py index 918391a9..805559c7 100644 --- a/api/onnx_web/chain/result.py +++ b/api/onnx_web/chain/result.py @@ -11,7 +11,7 @@ from ..convert.utils import resolve_tensor from ..params import Border, HighresParams, ImageParams, Size, UpscaleParams from ..server.context import ServerContext from ..server.load import get_extra_hashes -from ..utils import hash_file, load_config_str +from ..utils import coalesce, hash_file, load_config_str logger = getLogger(__name__) @@ -269,13 +269,13 @@ class ImageMetadata: return ImageMetadata( params or self.params, size or self.size, - upscale=upscale or self.upscale, - border=border or self.border, - highres=highres or self.highres, - inversions=inversions or self.inversions, - loras=loras or self.loras, - models=models or self.models, - ancestors=ancestors or self.ancestors, + upscale=coalesce(upscale, self.upscale), + border=coalesce(border, self.border), + highres=coalesce(highres, self.highres), + inversions=coalesce(inversions, self.inversions), + loras=coalesce(loras, self.loras), + models=coalesce(models, self.models), + ancestors=coalesce(ancestors, self.ancestors), ) @staticmethod diff --git a/api/onnx_web/chain/upscale_swinir.py b/api/onnx_web/chain/upscale_swinir.py index 97ee76aa..2f6831ae 100644 --- a/api/onnx_web/chain/upscale_swinir.py +++ b/api/onnx_web/chain/upscale_swinir.py @@ -73,10 +73,14 @@ class UpscaleSwinIRStage(BaseStage): upscale = upscale.with_args(**kwargs) if upscale.upscale_model is None: - logger.warning("no correction model given, skipping") + logger.warning("no upscale model given, skipping") return sources - logger.info("correcting faces with SwinIR model: %s", upscale.upscale_model) + logger.info( + "upscaling %sx with SwinIR model: %s", + upscale.outscale, + upscale.upscale_model, + ) device = worker.get_device() swinir = self.load(server, stage, upscale, device) diff --git a/api/onnx_web/params.py b/api/onnx_web/params.py index 57fa4d8f..f0bf7b7f 100644 --- a/api/onnx_web/params.py +++ b/api/onnx_web/params.py @@ -5,6 +5,7 @@ from typing import Any, Dict, List, Literal, Optional, Tuple, Union from .models.meta import NetworkModel from .torch_before_ort import GraphOptimizationLevel, SessionOptions +from .utils import coalesce logger = getLogger(__name__) @@ -417,10 +418,10 @@ class StageParams: ): logger.debug("ignoring extra kwargs for stage: %s", kwargs) return StageParams( - name=name or self.name, - outscale=outscale or self.outscale, - tile_order=tile_order or self.tile_order, - tile_size=tile_size or self.tile_size, + name=coalesce(name, self.name), + outscale=coalesce(outscale, self.outscale), + tile_order=coalesce(tile_order, self.tile_order), + tile_size=coalesce(tile_size, self.tile_size), ) @@ -516,18 +517,18 @@ class UpscaleParams: ): logger.debug("ignoring extra kwargs for upscale: %s", kwargs) return UpscaleParams( - upscale_model=upscale_model or self.upscale_model, - correction_model=correction_model or self.correction_model, - denoise=denoise or self.denoise, - upscale=upscale or self.upscale, - faces=faces or self.faces, - face_outscale=face_outscale or self.face_outscale, - face_strength=face_strength or self.face_strength, - outscale=outscale or self.outscale, - scale=scale or self.scale, - pre_pad=pre_pad or self.pre_pad, - tile_pad=tile_pad or self.tile_pad, - upscale_order=upscale_order or self.upscale_order, + upscale_model=coalesce(upscale_model, self.upscale_model), + correction_model=coalesce(correction_model, self.correction_model), + denoise=coalesce(denoise, self.denoise), + upscale=coalesce(upscale, self.upscale), + faces=coalesce(faces, self.faces), + face_outscale=coalesce(face_outscale, self.face_outscale), + face_strength=coalesce(face_strength, self.face_strength), + outscale=coalesce(outscale, self.outscale), + scale=coalesce(scale, self.scale), + pre_pad=coalesce(pre_pad, self.pre_pad), + tile_pad=coalesce(tile_pad, self.tile_pad), + upscale_order=coalesce(upscale_order, self.upscale_order), ) @@ -583,10 +584,10 @@ class HighresParams: ): logger.debug("ignoring extra kwargs for highres: %s", kwargs) return HighresParams( - enabled=enabled or self.enabled, - scale=scale or self.scale, - steps=steps or self.steps, - strength=strength or self.strength, - method=method or self.method, - iterations=iterations or self.iterations, + enabled=coalesce(enabled, self.enabled), + scale=coalesce(scale, self.scale), + steps=coalesce(steps, self.steps), + strength=coalesce(strength, self.strength), + method=coalesce(method, self.method), + iterations=coalesce(iterations, self.iterations), ) diff --git a/api/onnx_web/utils.py b/api/onnx_web/utils.py index b2eacec3..3db2ad2e 100644 --- a/api/onnx_web/utils.py +++ b/api/onnx_web/utils.py @@ -251,3 +251,14 @@ def hash_value(sha, param: Optional[Param]): sha.update(param.encode("utf-8")) else: logger.warning("cannot hash param: %s, %s", param, type(param)) + + +def coalesce(*args, throw=False): + for arg in args: + if arg is not None: + return arg + + if throw: + raise ValueError("no value found") + + return None