fix(api): add missing callback params to stages
This commit is contained in:
parent
bdbe6549bc
commit
256714b661
|
@ -27,7 +27,7 @@ class BlendMaskStage(BaseStage):
|
|||
stage_source: Optional[Image.Image] = None,
|
||||
stage_mask: Optional[Image.Image] = None,
|
||||
tile_mask: Optional[Image.Image] = None,
|
||||
_callback: Optional[ProgressCallback] = None,
|
||||
callback: Optional[ProgressCallback] = None,
|
||||
**kwargs,
|
||||
) -> StageResult:
|
||||
logger.info("blending image using mask")
|
||||
|
|
|
@ -10,6 +10,7 @@ from torchvision.transforms.functional import normalize
|
|||
from ..params import HighresParams, ImageParams, StageParams, UpscaleParams
|
||||
from ..server import ServerContext
|
||||
from ..worker import WorkerContext
|
||||
from ..worker.context import ProgressCallback
|
||||
from .base import BaseStage
|
||||
from .result import StageResult
|
||||
|
||||
|
@ -29,6 +30,7 @@ class CorrectCodeformerStage(BaseStage):
|
|||
sources: StageResult,
|
||||
*,
|
||||
upscale: UpscaleParams,
|
||||
callback: Optional[ProgressCallback] = None,
|
||||
highres: Optional[HighresParams] = None,
|
||||
stage_source: Optional[Image.Image] = None,
|
||||
**kwargs,
|
||||
|
|
|
@ -13,7 +13,7 @@ from ..params import (
|
|||
)
|
||||
from ..server import ModelTypes, ServerContext
|
||||
from ..utils import run_gc
|
||||
from ..worker import WorkerContext
|
||||
from ..worker import ProgressCallback, WorkerContext
|
||||
from .base import BaseStage
|
||||
from .result import StageResult
|
||||
|
||||
|
@ -66,6 +66,7 @@ class CorrectGFPGANStage(BaseStage):
|
|||
sources: StageResult,
|
||||
*,
|
||||
upscale: UpscaleParams,
|
||||
callback: Optional[ProgressCallback] = None,
|
||||
highres: Optional[HighresParams] = None,
|
||||
stage_source: Optional[Image.Image] = None,
|
||||
**kwargs,
|
||||
|
|
|
@ -9,7 +9,7 @@ from ..params import (
|
|||
UpscaleParams,
|
||||
)
|
||||
from ..server import ServerContext
|
||||
from ..worker import WorkerContext
|
||||
from ..worker import ProgressCallback, WorkerContext
|
||||
from .base import BaseStage
|
||||
from .result import StageResult
|
||||
|
||||
|
@ -25,6 +25,7 @@ class EditMetadataStage(BaseStage):
|
|||
_params: ImageParams,
|
||||
source: StageResult,
|
||||
*,
|
||||
callback: Optional[ProgressCallback] = None,
|
||||
size: Optional[Size] = None,
|
||||
upscale: Optional[UpscaleParams] = None,
|
||||
highres: Optional[HighresParams] = None,
|
||||
|
|
|
@ -1,10 +1,10 @@
|
|||
from typing import Tuple
|
||||
from typing import Optional, Tuple
|
||||
|
||||
from PIL import ImageDraw
|
||||
|
||||
from ..params import ImageParams, SizeChart, StageParams
|
||||
from ..server import ServerContext
|
||||
from ..worker import WorkerContext
|
||||
from ..worker import ProgressCallback, WorkerContext
|
||||
from .base import BaseStage
|
||||
from .result import StageResult
|
||||
|
||||
|
@ -25,6 +25,7 @@ class EditTextStage(BaseStage):
|
|||
fill: str = "white",
|
||||
stroke: str = "black",
|
||||
stroke_width: int = 1,
|
||||
callback: Optional[ProgressCallback] = None,
|
||||
**kwargs,
|
||||
) -> StageResult:
|
||||
# Add text to each image in source at the given position
|
||||
|
|
|
@ -6,7 +6,7 @@ from PIL import Image
|
|||
from ..output import save_result
|
||||
from ..params import ImageParams, SizeChart, StageParams
|
||||
from ..server import ServerContext
|
||||
from ..worker import WorkerContext
|
||||
from ..worker import ProgressCallback, WorkerContext
|
||||
from .base import BaseStage
|
||||
from .result import StageResult
|
||||
|
||||
|
@ -24,6 +24,7 @@ class PersistDiskStage(BaseStage):
|
|||
params: ImageParams,
|
||||
sources: StageResult,
|
||||
*,
|
||||
callback: Optional[ProgressCallback] = None,
|
||||
stage_source: Optional[Image.Image] = None,
|
||||
**kwargs,
|
||||
) -> StageResult:
|
||||
|
|
|
@ -9,7 +9,7 @@ from PIL import Image
|
|||
from ..output import make_output_names
|
||||
from ..params import ImageParams, StageParams
|
||||
from ..server import ServerContext
|
||||
from ..worker import WorkerContext
|
||||
from ..worker import ProgressCallback, WorkerContext
|
||||
from .base import BaseStage
|
||||
from .result import StageResult
|
||||
|
||||
|
@ -29,6 +29,7 @@ class PersistS3Stage(BaseStage):
|
|||
endpoint_url: Optional[str] = None,
|
||||
profile_name: Optional[str] = None,
|
||||
stage_source: Optional[Image.Image] = None,
|
||||
callback: Optional[ProgressCallback] = None,
|
||||
**kwargs,
|
||||
) -> StageResult:
|
||||
session = Session(profile_name=profile_name)
|
||||
|
|
|
@ -5,7 +5,7 @@ from PIL import Image
|
|||
|
||||
from ..params import ImageParams, Size, StageParams
|
||||
from ..server import ServerContext
|
||||
from ..worker import WorkerContext
|
||||
from ..worker import ProgressCallback, WorkerContext
|
||||
from .base import BaseStage
|
||||
from .result import StageResult
|
||||
|
||||
|
@ -24,6 +24,7 @@ class ReduceCropStage(BaseStage):
|
|||
origin: Size,
|
||||
size: Size,
|
||||
stage_source: Optional[Image.Image] = None,
|
||||
callback: Optional[ProgressCallback] = None,
|
||||
**kwargs,
|
||||
) -> StageResult:
|
||||
outputs = []
|
||||
|
|
|
@ -1,10 +1,11 @@
|
|||
from logging import getLogger
|
||||
from typing import Optional
|
||||
|
||||
from PIL import Image
|
||||
|
||||
from ..params import ImageParams, Size, StageParams
|
||||
from ..server import ServerContext
|
||||
from ..worker import WorkerContext
|
||||
from ..worker import ProgressCallback, WorkerContext
|
||||
from .base import BaseStage
|
||||
from .result import StageResult
|
||||
|
||||
|
@ -22,6 +23,7 @@ class ReduceThumbnailStage(BaseStage):
|
|||
*,
|
||||
size: Size,
|
||||
stage_source: Image.Image,
|
||||
callback: Optional[ProgressCallback] = None,
|
||||
**kwargs,
|
||||
) -> StageResult:
|
||||
outputs = []
|
||||
|
|
|
@ -5,7 +5,7 @@ from PIL import Image
|
|||
|
||||
from ..params import ImageParams, Size, StageParams
|
||||
from ..server import ServerContext
|
||||
from ..worker import WorkerContext
|
||||
from ..worker import ProgressCallback, WorkerContext
|
||||
from .base import BaseStage
|
||||
from .result import StageResult
|
||||
|
||||
|
@ -24,6 +24,7 @@ class SourceNoiseStage(BaseStage):
|
|||
size: Size,
|
||||
noise_source: Callable,
|
||||
stage_source: Optional[Image.Image] = None,
|
||||
callback: Optional[ProgressCallback] = None,
|
||||
**kwargs,
|
||||
) -> StageResult:
|
||||
logger.info("generating image from noise source")
|
||||
|
|
|
@ -6,7 +6,7 @@ from boto3 import Session
|
|||
from PIL import Image
|
||||
|
||||
from ..params import ImageParams, StageParams
|
||||
from ..server import ServerContext
|
||||
from ..server import ProgressCallback, ServerContext
|
||||
from ..worker import WorkerContext
|
||||
from .base import BaseStage
|
||||
from .result import ImageMetadata, StageResult
|
||||
|
@ -27,6 +27,7 @@ class SourceS3Stage(BaseStage):
|
|||
bucket: str,
|
||||
endpoint_url: Optional[str] = None,
|
||||
profile_name: Optional[str] = None,
|
||||
callback: Optional[ProgressCallback] = None,
|
||||
**kwargs,
|
||||
) -> StageResult:
|
||||
session = Session(profile_name=profile_name)
|
||||
|
|
|
@ -7,7 +7,7 @@ from PIL import Image
|
|||
|
||||
from ..params import ImageParams, StageParams
|
||||
from ..server import ServerContext
|
||||
from ..worker import WorkerContext
|
||||
from ..worker import ProgressCallback, WorkerContext
|
||||
from .base import BaseStage
|
||||
from .result import ImageMetadata, StageResult
|
||||
|
||||
|
@ -25,6 +25,7 @@ class SourceURLStage(BaseStage):
|
|||
*,
|
||||
source_urls: List[str],
|
||||
stage_source: Optional[Image.Image] = None,
|
||||
callback: Optional[ProgressCallback] = None,
|
||||
**kwargs,
|
||||
) -> StageResult:
|
||||
logger.info("loading image from URL source")
|
||||
|
|
|
@ -17,7 +17,7 @@ from ..params import (
|
|||
)
|
||||
from ..server import ModelTypes, ServerContext
|
||||
from ..utils import run_gc
|
||||
from ..worker import WorkerContext
|
||||
from ..worker import ProgressCallback, WorkerContext
|
||||
from .base import BaseStage
|
||||
from .result import StageResult
|
||||
|
||||
|
@ -68,6 +68,7 @@ class UpscaleBSRGANStage(BaseStage):
|
|||
upscale: UpscaleParams,
|
||||
highres: Optional[HighresParams] = None,
|
||||
stage_source: Optional[Image.Image] = None,
|
||||
callback: Optional[ProgressCallback] = None,
|
||||
**kwargs,
|
||||
) -> StageResult:
|
||||
upscale = upscale.with_args(**kwargs)
|
||||
|
|
|
@ -14,7 +14,7 @@ from ..params import (
|
|||
)
|
||||
from ..server import ModelTypes, ServerContext
|
||||
from ..utils import run_gc
|
||||
from ..worker import WorkerContext
|
||||
from ..worker import ProgressCallback, WorkerContext
|
||||
from .base import BaseStage
|
||||
from .result import StageResult
|
||||
|
||||
|
@ -110,6 +110,7 @@ class UpscaleRealESRGANStage(BaseStage):
|
|||
upscale: UpscaleParams,
|
||||
highres: Optional[HighresParams] = None,
|
||||
stage_source: Optional[Image.Image] = None,
|
||||
callback: Optional[ProgressCallback] = None,
|
||||
**kwargs,
|
||||
) -> StageResult:
|
||||
logger.info("upscaling image with Real ESRGAN: x%s", upscale.scale)
|
||||
|
|
|
@ -3,9 +3,9 @@ from typing import Optional
|
|||
|
||||
from PIL import Image
|
||||
|
||||
from ..params import ImageParams, StageParams, UpscaleParams
|
||||
from ..params import ImageParams, SizeChart, StageParams, UpscaleParams
|
||||
from ..server import ServerContext
|
||||
from ..worker import WorkerContext
|
||||
from ..worker import ProgressCallback, WorkerContext
|
||||
from .base import BaseStage
|
||||
from .result import StageResult
|
||||
|
||||
|
@ -13,6 +13,8 @@ logger = getLogger(__name__)
|
|||
|
||||
|
||||
class UpscaleSimpleStage(BaseStage):
|
||||
max_tile = SizeChart.max
|
||||
|
||||
def run(
|
||||
self,
|
||||
_worker: WorkerContext,
|
||||
|
@ -24,6 +26,7 @@ class UpscaleSimpleStage(BaseStage):
|
|||
method: str,
|
||||
upscale: UpscaleParams,
|
||||
stage_source: Optional[Image.Image] = None,
|
||||
callback: Optional[ProgressCallback] = None,
|
||||
**kwargs,
|
||||
) -> StageResult:
|
||||
if upscale.scale <= 1:
|
||||
|
|
|
@ -16,7 +16,7 @@ from ..params import (
|
|||
)
|
||||
from ..server import ModelTypes, ServerContext
|
||||
from ..utils import run_gc
|
||||
from ..worker import WorkerContext
|
||||
from ..worker import ProgressCallback, WorkerContext
|
||||
from .base import BaseStage
|
||||
from .result import StageResult
|
||||
|
||||
|
@ -67,6 +67,7 @@ class UpscaleSwinIRStage(BaseStage):
|
|||
upscale: UpscaleParams,
|
||||
highres: Optional[HighresParams] = None,
|
||||
stage_source: Optional[Image.Image] = None,
|
||||
callback: Optional[ProgressCallback] = None,
|
||||
**kwargs,
|
||||
) -> StageResult:
|
||||
upscale = upscale.with_args(**kwargs)
|
||||
|
|
Loading…
Reference in New Issue