1
0
Fork 0

fix(api): add missing callback params to stages

This commit is contained in:
Sean Sube 2024-01-28 13:09:06 -06:00
parent bdbe6549bc
commit 256714b661
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
16 changed files with 36 additions and 17 deletions

View File

@ -27,7 +27,7 @@ class BlendMaskStage(BaseStage):
stage_source: Optional[Image.Image] = None, stage_source: Optional[Image.Image] = None,
stage_mask: Optional[Image.Image] = None, stage_mask: Optional[Image.Image] = None,
tile_mask: Optional[Image.Image] = None, tile_mask: Optional[Image.Image] = None,
_callback: Optional[ProgressCallback] = None, callback: Optional[ProgressCallback] = None,
**kwargs, **kwargs,
) -> StageResult: ) -> StageResult:
logger.info("blending image using mask") logger.info("blending image using mask")

View File

@ -10,6 +10,7 @@ from torchvision.transforms.functional import normalize
from ..params import HighresParams, ImageParams, StageParams, UpscaleParams from ..params import HighresParams, ImageParams, StageParams, UpscaleParams
from ..server import ServerContext from ..server import ServerContext
from ..worker import WorkerContext from ..worker import WorkerContext
from ..worker.context import ProgressCallback
from .base import BaseStage from .base import BaseStage
from .result import StageResult from .result import StageResult
@ -29,6 +30,7 @@ class CorrectCodeformerStage(BaseStage):
sources: StageResult, sources: StageResult,
*, *,
upscale: UpscaleParams, upscale: UpscaleParams,
callback: Optional[ProgressCallback] = None,
highres: Optional[HighresParams] = None, highres: Optional[HighresParams] = None,
stage_source: Optional[Image.Image] = None, stage_source: Optional[Image.Image] = None,
**kwargs, **kwargs,

View File

@ -13,7 +13,7 @@ from ..params import (
) )
from ..server import ModelTypes, ServerContext from ..server import ModelTypes, ServerContext
from ..utils import run_gc from ..utils import run_gc
from ..worker import WorkerContext from ..worker import ProgressCallback, WorkerContext
from .base import BaseStage from .base import BaseStage
from .result import StageResult from .result import StageResult
@ -66,6 +66,7 @@ class CorrectGFPGANStage(BaseStage):
sources: StageResult, sources: StageResult,
*, *,
upscale: UpscaleParams, upscale: UpscaleParams,
callback: Optional[ProgressCallback] = None,
highres: Optional[HighresParams] = None, highres: Optional[HighresParams] = None,
stage_source: Optional[Image.Image] = None, stage_source: Optional[Image.Image] = None,
**kwargs, **kwargs,

View File

@ -9,7 +9,7 @@ from ..params import (
UpscaleParams, UpscaleParams,
) )
from ..server import ServerContext from ..server import ServerContext
from ..worker import WorkerContext from ..worker import ProgressCallback, WorkerContext
from .base import BaseStage from .base import BaseStage
from .result import StageResult from .result import StageResult
@ -25,6 +25,7 @@ class EditMetadataStage(BaseStage):
_params: ImageParams, _params: ImageParams,
source: StageResult, source: StageResult,
*, *,
callback: Optional[ProgressCallback] = None,
size: Optional[Size] = None, size: Optional[Size] = None,
upscale: Optional[UpscaleParams] = None, upscale: Optional[UpscaleParams] = None,
highres: Optional[HighresParams] = None, highres: Optional[HighresParams] = None,

View File

@ -1,10 +1,10 @@
from typing import Tuple from typing import Optional, Tuple
from PIL import ImageDraw from PIL import ImageDraw
from ..params import ImageParams, SizeChart, StageParams from ..params import ImageParams, SizeChart, StageParams
from ..server import ServerContext from ..server import ServerContext
from ..worker import WorkerContext from ..worker import ProgressCallback, WorkerContext
from .base import BaseStage from .base import BaseStage
from .result import StageResult from .result import StageResult
@ -25,6 +25,7 @@ class EditTextStage(BaseStage):
fill: str = "white", fill: str = "white",
stroke: str = "black", stroke: str = "black",
stroke_width: int = 1, stroke_width: int = 1,
callback: Optional[ProgressCallback] = None,
**kwargs, **kwargs,
) -> StageResult: ) -> StageResult:
# Add text to each image in source at the given position # Add text to each image in source at the given position

View File

@ -6,7 +6,7 @@ from PIL import Image
from ..output import save_result from ..output import save_result
from ..params import ImageParams, SizeChart, StageParams from ..params import ImageParams, SizeChart, StageParams
from ..server import ServerContext from ..server import ServerContext
from ..worker import WorkerContext from ..worker import ProgressCallback, WorkerContext
from .base import BaseStage from .base import BaseStage
from .result import StageResult from .result import StageResult
@ -24,6 +24,7 @@ class PersistDiskStage(BaseStage):
params: ImageParams, params: ImageParams,
sources: StageResult, sources: StageResult,
*, *,
callback: Optional[ProgressCallback] = None,
stage_source: Optional[Image.Image] = None, stage_source: Optional[Image.Image] = None,
**kwargs, **kwargs,
) -> StageResult: ) -> StageResult:

View File

@ -9,7 +9,7 @@ from PIL import Image
from ..output import make_output_names from ..output import make_output_names
from ..params import ImageParams, StageParams from ..params import ImageParams, StageParams
from ..server import ServerContext from ..server import ServerContext
from ..worker import WorkerContext from ..worker import ProgressCallback, WorkerContext
from .base import BaseStage from .base import BaseStage
from .result import StageResult from .result import StageResult
@ -29,6 +29,7 @@ class PersistS3Stage(BaseStage):
endpoint_url: Optional[str] = None, endpoint_url: Optional[str] = None,
profile_name: Optional[str] = None, profile_name: Optional[str] = None,
stage_source: Optional[Image.Image] = None, stage_source: Optional[Image.Image] = None,
callback: Optional[ProgressCallback] = None,
**kwargs, **kwargs,
) -> StageResult: ) -> StageResult:
session = Session(profile_name=profile_name) session = Session(profile_name=profile_name)

View File

@ -5,7 +5,7 @@ from PIL import Image
from ..params import ImageParams, Size, StageParams from ..params import ImageParams, Size, StageParams
from ..server import ServerContext from ..server import ServerContext
from ..worker import WorkerContext from ..worker import ProgressCallback, WorkerContext
from .base import BaseStage from .base import BaseStage
from .result import StageResult from .result import StageResult
@ -24,6 +24,7 @@ class ReduceCropStage(BaseStage):
origin: Size, origin: Size,
size: Size, size: Size,
stage_source: Optional[Image.Image] = None, stage_source: Optional[Image.Image] = None,
callback: Optional[ProgressCallback] = None,
**kwargs, **kwargs,
) -> StageResult: ) -> StageResult:
outputs = [] outputs = []

View File

@ -1,10 +1,11 @@
from logging import getLogger from logging import getLogger
from typing import Optional
from PIL import Image from PIL import Image
from ..params import ImageParams, Size, StageParams from ..params import ImageParams, Size, StageParams
from ..server import ServerContext from ..server import ServerContext
from ..worker import WorkerContext from ..worker import ProgressCallback, WorkerContext
from .base import BaseStage from .base import BaseStage
from .result import StageResult from .result import StageResult
@ -22,6 +23,7 @@ class ReduceThumbnailStage(BaseStage):
*, *,
size: Size, size: Size,
stage_source: Image.Image, stage_source: Image.Image,
callback: Optional[ProgressCallback] = None,
**kwargs, **kwargs,
) -> StageResult: ) -> StageResult:
outputs = [] outputs = []

View File

@ -5,7 +5,7 @@ from PIL import Image
from ..params import ImageParams, Size, StageParams from ..params import ImageParams, Size, StageParams
from ..server import ServerContext from ..server import ServerContext
from ..worker import WorkerContext from ..worker import ProgressCallback, WorkerContext
from .base import BaseStage from .base import BaseStage
from .result import StageResult from .result import StageResult
@ -24,6 +24,7 @@ class SourceNoiseStage(BaseStage):
size: Size, size: Size,
noise_source: Callable, noise_source: Callable,
stage_source: Optional[Image.Image] = None, stage_source: Optional[Image.Image] = None,
callback: Optional[ProgressCallback] = None,
**kwargs, **kwargs,
) -> StageResult: ) -> StageResult:
logger.info("generating image from noise source") logger.info("generating image from noise source")

View File

@ -6,7 +6,7 @@ from boto3 import Session
from PIL import Image from PIL import Image
from ..params import ImageParams, StageParams from ..params import ImageParams, StageParams
from ..server import ServerContext from ..server import ProgressCallback, ServerContext
from ..worker import WorkerContext from ..worker import WorkerContext
from .base import BaseStage from .base import BaseStage
from .result import ImageMetadata, StageResult from .result import ImageMetadata, StageResult
@ -27,6 +27,7 @@ class SourceS3Stage(BaseStage):
bucket: str, bucket: str,
endpoint_url: Optional[str] = None, endpoint_url: Optional[str] = None,
profile_name: Optional[str] = None, profile_name: Optional[str] = None,
callback: Optional[ProgressCallback] = None,
**kwargs, **kwargs,
) -> StageResult: ) -> StageResult:
session = Session(profile_name=profile_name) session = Session(profile_name=profile_name)

View File

@ -7,7 +7,7 @@ from PIL import Image
from ..params import ImageParams, StageParams from ..params import ImageParams, StageParams
from ..server import ServerContext from ..server import ServerContext
from ..worker import WorkerContext from ..worker import ProgressCallback, WorkerContext
from .base import BaseStage from .base import BaseStage
from .result import ImageMetadata, StageResult from .result import ImageMetadata, StageResult
@ -25,6 +25,7 @@ class SourceURLStage(BaseStage):
*, *,
source_urls: List[str], source_urls: List[str],
stage_source: Optional[Image.Image] = None, stage_source: Optional[Image.Image] = None,
callback: Optional[ProgressCallback] = None,
**kwargs, **kwargs,
) -> StageResult: ) -> StageResult:
logger.info("loading image from URL source") logger.info("loading image from URL source")

View File

@ -17,7 +17,7 @@ from ..params import (
) )
from ..server import ModelTypes, ServerContext from ..server import ModelTypes, ServerContext
from ..utils import run_gc from ..utils import run_gc
from ..worker import WorkerContext from ..worker import ProgressCallback, WorkerContext
from .base import BaseStage from .base import BaseStage
from .result import StageResult from .result import StageResult
@ -68,6 +68,7 @@ class UpscaleBSRGANStage(BaseStage):
upscale: UpscaleParams, upscale: UpscaleParams,
highres: Optional[HighresParams] = None, highres: Optional[HighresParams] = None,
stage_source: Optional[Image.Image] = None, stage_source: Optional[Image.Image] = None,
callback: Optional[ProgressCallback] = None,
**kwargs, **kwargs,
) -> StageResult: ) -> StageResult:
upscale = upscale.with_args(**kwargs) upscale = upscale.with_args(**kwargs)

View File

@ -14,7 +14,7 @@ from ..params import (
) )
from ..server import ModelTypes, ServerContext from ..server import ModelTypes, ServerContext
from ..utils import run_gc from ..utils import run_gc
from ..worker import WorkerContext from ..worker import ProgressCallback, WorkerContext
from .base import BaseStage from .base import BaseStage
from .result import StageResult from .result import StageResult
@ -110,6 +110,7 @@ class UpscaleRealESRGANStage(BaseStage):
upscale: UpscaleParams, upscale: UpscaleParams,
highres: Optional[HighresParams] = None, highres: Optional[HighresParams] = None,
stage_source: Optional[Image.Image] = None, stage_source: Optional[Image.Image] = None,
callback: Optional[ProgressCallback] = None,
**kwargs, **kwargs,
) -> StageResult: ) -> StageResult:
logger.info("upscaling image with Real ESRGAN: x%s", upscale.scale) logger.info("upscaling image with Real ESRGAN: x%s", upscale.scale)

View File

@ -3,9 +3,9 @@ from typing import Optional
from PIL import Image from PIL import Image
from ..params import ImageParams, StageParams, UpscaleParams from ..params import ImageParams, SizeChart, StageParams, UpscaleParams
from ..server import ServerContext from ..server import ServerContext
from ..worker import WorkerContext from ..worker import ProgressCallback, WorkerContext
from .base import BaseStage from .base import BaseStage
from .result import StageResult from .result import StageResult
@ -13,6 +13,8 @@ logger = getLogger(__name__)
class UpscaleSimpleStage(BaseStage): class UpscaleSimpleStage(BaseStage):
max_tile = SizeChart.max
def run( def run(
self, self,
_worker: WorkerContext, _worker: WorkerContext,
@ -24,6 +26,7 @@ class UpscaleSimpleStage(BaseStage):
method: str, method: str,
upscale: UpscaleParams, upscale: UpscaleParams,
stage_source: Optional[Image.Image] = None, stage_source: Optional[Image.Image] = None,
callback: Optional[ProgressCallback] = None,
**kwargs, **kwargs,
) -> StageResult: ) -> StageResult:
if upscale.scale <= 1: if upscale.scale <= 1:

View File

@ -16,7 +16,7 @@ from ..params import (
) )
from ..server import ModelTypes, ServerContext from ..server import ModelTypes, ServerContext
from ..utils import run_gc from ..utils import run_gc
from ..worker import WorkerContext from ..worker import ProgressCallback, WorkerContext
from .base import BaseStage from .base import BaseStage
from .result import StageResult from .result import StageResult
@ -67,6 +67,7 @@ class UpscaleSwinIRStage(BaseStage):
upscale: UpscaleParams, upscale: UpscaleParams,
highres: Optional[HighresParams] = None, highres: Optional[HighresParams] = None,
stage_source: Optional[Image.Image] = None, stage_source: Optional[Image.Image] = None,
callback: Optional[ProgressCallback] = None,
**kwargs, **kwargs,
) -> StageResult: ) -> StageResult:
upscale = upscale.with_args(**kwargs) upscale = upscale.with_args(**kwargs)