diff --git a/api/onnx_web/chain/blend_mask.py b/api/onnx_web/chain/blend_mask.py index 44358f25..91d464fb 100644 --- a/api/onnx_web/chain/blend_mask.py +++ b/api/onnx_web/chain/blend_mask.py @@ -27,7 +27,7 @@ class BlendMaskStage(BaseStage): stage_source: Optional[Image.Image] = None, stage_mask: Optional[Image.Image] = None, tile_mask: Optional[Image.Image] = None, - _callback: Optional[ProgressCallback] = None, + callback: Optional[ProgressCallback] = None, **kwargs, ) -> StageResult: logger.info("blending image using mask") diff --git a/api/onnx_web/chain/correct_codeformer.py b/api/onnx_web/chain/correct_codeformer.py index 8a33cba8..803c82e9 100644 --- a/api/onnx_web/chain/correct_codeformer.py +++ b/api/onnx_web/chain/correct_codeformer.py @@ -10,6 +10,7 @@ from torchvision.transforms.functional import normalize from ..params import HighresParams, ImageParams, StageParams, UpscaleParams from ..server import ServerContext from ..worker import WorkerContext +from ..worker.context import ProgressCallback from .base import BaseStage from .result import StageResult @@ -29,6 +30,7 @@ class CorrectCodeformerStage(BaseStage): sources: StageResult, *, upscale: UpscaleParams, + callback: Optional[ProgressCallback] = None, highres: Optional[HighresParams] = None, stage_source: Optional[Image.Image] = None, **kwargs, diff --git a/api/onnx_web/chain/correct_gfpgan.py b/api/onnx_web/chain/correct_gfpgan.py index d80363dc..d8023215 100644 --- a/api/onnx_web/chain/correct_gfpgan.py +++ b/api/onnx_web/chain/correct_gfpgan.py @@ -13,7 +13,7 @@ from ..params import ( ) from ..server import ModelTypes, ServerContext from ..utils import run_gc -from ..worker import WorkerContext +from ..worker import ProgressCallback, WorkerContext from .base import BaseStage from .result import StageResult @@ -66,6 +66,7 @@ class CorrectGFPGANStage(BaseStage): sources: StageResult, *, upscale: UpscaleParams, + callback: Optional[ProgressCallback] = None, highres: Optional[HighresParams] = None, stage_source: Optional[Image.Image] = None, **kwargs, diff --git a/api/onnx_web/chain/edit_metadata.py b/api/onnx_web/chain/edit_metadata.py index 0ca4d1c5..02edaf78 100644 --- a/api/onnx_web/chain/edit_metadata.py +++ b/api/onnx_web/chain/edit_metadata.py @@ -9,7 +9,7 @@ from ..params import ( UpscaleParams, ) from ..server import ServerContext -from ..worker import WorkerContext +from ..worker import ProgressCallback, WorkerContext from .base import BaseStage from .result import StageResult @@ -25,6 +25,7 @@ class EditMetadataStage(BaseStage): _params: ImageParams, source: StageResult, *, + callback: Optional[ProgressCallback] = None, size: Optional[Size] = None, upscale: Optional[UpscaleParams] = None, highres: Optional[HighresParams] = None, diff --git a/api/onnx_web/chain/edit_text.py b/api/onnx_web/chain/edit_text.py index ddbe8cd9..2aa453d2 100644 --- a/api/onnx_web/chain/edit_text.py +++ b/api/onnx_web/chain/edit_text.py @@ -1,10 +1,10 @@ -from typing import Tuple +from typing import Optional, Tuple from PIL import ImageDraw from ..params import ImageParams, SizeChart, StageParams from ..server import ServerContext -from ..worker import WorkerContext +from ..worker import ProgressCallback, WorkerContext from .base import BaseStage from .result import StageResult @@ -25,6 +25,7 @@ class EditTextStage(BaseStage): fill: str = "white", stroke: str = "black", stroke_width: int = 1, + callback: Optional[ProgressCallback] = None, **kwargs, ) -> StageResult: # Add text to each image in source at the given position diff --git a/api/onnx_web/chain/persist_disk.py b/api/onnx_web/chain/persist_disk.py index 2afaed00..340956a9 100644 --- a/api/onnx_web/chain/persist_disk.py +++ b/api/onnx_web/chain/persist_disk.py @@ -6,7 +6,7 @@ from PIL import Image from ..output import save_result from ..params import ImageParams, SizeChart, StageParams from ..server import ServerContext -from ..worker import WorkerContext +from ..worker import ProgressCallback, WorkerContext from .base import BaseStage from .result import StageResult @@ -24,6 +24,7 @@ class PersistDiskStage(BaseStage): params: ImageParams, sources: StageResult, *, + callback: Optional[ProgressCallback] = None, stage_source: Optional[Image.Image] = None, **kwargs, ) -> StageResult: diff --git a/api/onnx_web/chain/persist_s3.py b/api/onnx_web/chain/persist_s3.py index 4e118bb1..c1825d1b 100644 --- a/api/onnx_web/chain/persist_s3.py +++ b/api/onnx_web/chain/persist_s3.py @@ -9,7 +9,7 @@ from PIL import Image from ..output import make_output_names from ..params import ImageParams, StageParams from ..server import ServerContext -from ..worker import WorkerContext +from ..worker import ProgressCallback, WorkerContext from .base import BaseStage from .result import StageResult @@ -29,6 +29,7 @@ class PersistS3Stage(BaseStage): endpoint_url: Optional[str] = None, profile_name: Optional[str] = None, stage_source: Optional[Image.Image] = None, + callback: Optional[ProgressCallback] = None, **kwargs, ) -> StageResult: session = Session(profile_name=profile_name) diff --git a/api/onnx_web/chain/reduce_crop.py b/api/onnx_web/chain/reduce_crop.py index 3a81ce39..6261f045 100644 --- a/api/onnx_web/chain/reduce_crop.py +++ b/api/onnx_web/chain/reduce_crop.py @@ -5,7 +5,7 @@ from PIL import Image from ..params import ImageParams, Size, StageParams from ..server import ServerContext -from ..worker import WorkerContext +from ..worker import ProgressCallback, WorkerContext from .base import BaseStage from .result import StageResult @@ -24,6 +24,7 @@ class ReduceCropStage(BaseStage): origin: Size, size: Size, stage_source: Optional[Image.Image] = None, + callback: Optional[ProgressCallback] = None, **kwargs, ) -> StageResult: outputs = [] diff --git a/api/onnx_web/chain/reduce_thumbnail.py b/api/onnx_web/chain/reduce_thumbnail.py index 79970301..565c8c0a 100644 --- a/api/onnx_web/chain/reduce_thumbnail.py +++ b/api/onnx_web/chain/reduce_thumbnail.py @@ -1,10 +1,11 @@ from logging import getLogger +from typing import Optional from PIL import Image from ..params import ImageParams, Size, StageParams from ..server import ServerContext -from ..worker import WorkerContext +from ..worker import ProgressCallback, WorkerContext from .base import BaseStage from .result import StageResult @@ -22,6 +23,7 @@ class ReduceThumbnailStage(BaseStage): *, size: Size, stage_source: Image.Image, + callback: Optional[ProgressCallback] = None, **kwargs, ) -> StageResult: outputs = [] diff --git a/api/onnx_web/chain/source_noise.py b/api/onnx_web/chain/source_noise.py index bfa2d94b..19554af2 100644 --- a/api/onnx_web/chain/source_noise.py +++ b/api/onnx_web/chain/source_noise.py @@ -5,7 +5,7 @@ from PIL import Image from ..params import ImageParams, Size, StageParams from ..server import ServerContext -from ..worker import WorkerContext +from ..worker import ProgressCallback, WorkerContext from .base import BaseStage from .result import StageResult @@ -24,6 +24,7 @@ class SourceNoiseStage(BaseStage): size: Size, noise_source: Callable, stage_source: Optional[Image.Image] = None, + callback: Optional[ProgressCallback] = None, **kwargs, ) -> StageResult: logger.info("generating image from noise source") diff --git a/api/onnx_web/chain/source_s3.py b/api/onnx_web/chain/source_s3.py index a34d087f..6e2df551 100644 --- a/api/onnx_web/chain/source_s3.py +++ b/api/onnx_web/chain/source_s3.py @@ -6,7 +6,7 @@ from boto3 import Session from PIL import Image from ..params import ImageParams, StageParams -from ..server import ServerContext +from ..server import ProgressCallback, ServerContext from ..worker import WorkerContext from .base import BaseStage from .result import ImageMetadata, StageResult @@ -27,6 +27,7 @@ class SourceS3Stage(BaseStage): bucket: str, endpoint_url: Optional[str] = None, profile_name: Optional[str] = None, + callback: Optional[ProgressCallback] = None, **kwargs, ) -> StageResult: session = Session(profile_name=profile_name) diff --git a/api/onnx_web/chain/source_url.py b/api/onnx_web/chain/source_url.py index 60a3ca4f..e25fa590 100644 --- a/api/onnx_web/chain/source_url.py +++ b/api/onnx_web/chain/source_url.py @@ -7,7 +7,7 @@ from PIL import Image from ..params import ImageParams, StageParams from ..server import ServerContext -from ..worker import WorkerContext +from ..worker import ProgressCallback, WorkerContext from .base import BaseStage from .result import ImageMetadata, StageResult @@ -25,6 +25,7 @@ class SourceURLStage(BaseStage): *, source_urls: List[str], stage_source: Optional[Image.Image] = None, + callback: Optional[ProgressCallback] = None, **kwargs, ) -> StageResult: logger.info("loading image from URL source") diff --git a/api/onnx_web/chain/upscale_bsrgan.py b/api/onnx_web/chain/upscale_bsrgan.py index bffc32f1..4b4d5752 100644 --- a/api/onnx_web/chain/upscale_bsrgan.py +++ b/api/onnx_web/chain/upscale_bsrgan.py @@ -17,7 +17,7 @@ from ..params import ( ) from ..server import ModelTypes, ServerContext from ..utils import run_gc -from ..worker import WorkerContext +from ..worker import ProgressCallback, WorkerContext from .base import BaseStage from .result import StageResult @@ -68,6 +68,7 @@ class UpscaleBSRGANStage(BaseStage): upscale: UpscaleParams, highres: Optional[HighresParams] = None, stage_source: Optional[Image.Image] = None, + callback: Optional[ProgressCallback] = None, **kwargs, ) -> StageResult: upscale = upscale.with_args(**kwargs) diff --git a/api/onnx_web/chain/upscale_resrgan.py b/api/onnx_web/chain/upscale_resrgan.py index f818057c..ff9cb98d 100644 --- a/api/onnx_web/chain/upscale_resrgan.py +++ b/api/onnx_web/chain/upscale_resrgan.py @@ -14,7 +14,7 @@ from ..params import ( ) from ..server import ModelTypes, ServerContext from ..utils import run_gc -from ..worker import WorkerContext +from ..worker import ProgressCallback, WorkerContext from .base import BaseStage from .result import StageResult @@ -110,6 +110,7 @@ class UpscaleRealESRGANStage(BaseStage): upscale: UpscaleParams, highres: Optional[HighresParams] = None, stage_source: Optional[Image.Image] = None, + callback: Optional[ProgressCallback] = None, **kwargs, ) -> StageResult: logger.info("upscaling image with Real ESRGAN: x%s", upscale.scale) diff --git a/api/onnx_web/chain/upscale_simple.py b/api/onnx_web/chain/upscale_simple.py index 5cf5f24c..686f9c71 100644 --- a/api/onnx_web/chain/upscale_simple.py +++ b/api/onnx_web/chain/upscale_simple.py @@ -3,9 +3,9 @@ from typing import Optional from PIL import Image -from ..params import ImageParams, StageParams, UpscaleParams +from ..params import ImageParams, SizeChart, StageParams, UpscaleParams from ..server import ServerContext -from ..worker import WorkerContext +from ..worker import ProgressCallback, WorkerContext from .base import BaseStage from .result import StageResult @@ -13,6 +13,8 @@ logger = getLogger(__name__) class UpscaleSimpleStage(BaseStage): + max_tile = SizeChart.max + def run( self, _worker: WorkerContext, @@ -24,6 +26,7 @@ class UpscaleSimpleStage(BaseStage): method: str, upscale: UpscaleParams, stage_source: Optional[Image.Image] = None, + callback: Optional[ProgressCallback] = None, **kwargs, ) -> StageResult: if upscale.scale <= 1: diff --git a/api/onnx_web/chain/upscale_swinir.py b/api/onnx_web/chain/upscale_swinir.py index a5510c28..97ee76aa 100644 --- a/api/onnx_web/chain/upscale_swinir.py +++ b/api/onnx_web/chain/upscale_swinir.py @@ -16,7 +16,7 @@ from ..params import ( ) from ..server import ModelTypes, ServerContext from ..utils import run_gc -from ..worker import WorkerContext +from ..worker import ProgressCallback, WorkerContext from .base import BaseStage from .result import StageResult @@ -67,6 +67,7 @@ class UpscaleSwinIRStage(BaseStage): upscale: UpscaleParams, highres: Optional[HighresParams] = None, stage_source: Optional[Image.Image] = None, + callback: Optional[ProgressCallback] = None, **kwargs, ) -> StageResult: upscale = upscale.with_args(**kwargs)