1
0
Fork 0
onnx-web/api/onnx_web/chain/upscale_highres.py

44 lines
1.0 KiB
Python

from logging import getLogger
from typing import Optional
from PIL import Image
from ..chain.highres import stage_highres
from ..params import HighresParams, ImageParams, StageParams, UpscaleParams
from ..server import ServerContext
from ..worker import WorkerContext
from ..worker.context import ProgressCallback
logger = getLogger(__name__)
class UpscaleHighresStage:
def run(
self,
job: WorkerContext,
server: ServerContext,
stage: StageParams,
params: ImageParams,
source: Image.Image,
*,
highres: HighresParams,
upscale: UpscaleParams,
stage_source: Optional[Image.Image] = None,
callback: Optional[ProgressCallback] = None,
**kwargs,
) -> Image.Image:
source = stage_source or source
if highres.scale <= 1:
return source
chain = stage_highres(stage, params, highres, upscale)
return chain(
job,
server,
params,
source,
callback=callback,
)