fix(api): add missing callback params to stages
This commit is contained in:
parent
bdbe6549bc
commit
256714b661
|
@ -27,7 +27,7 @@ class BlendMaskStage(BaseStage):
|
||||||
stage_source: Optional[Image.Image] = None,
|
stage_source: Optional[Image.Image] = None,
|
||||||
stage_mask: Optional[Image.Image] = None,
|
stage_mask: Optional[Image.Image] = None,
|
||||||
tile_mask: Optional[Image.Image] = None,
|
tile_mask: Optional[Image.Image] = None,
|
||||||
_callback: Optional[ProgressCallback] = None,
|
callback: Optional[ProgressCallback] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> StageResult:
|
) -> StageResult:
|
||||||
logger.info("blending image using mask")
|
logger.info("blending image using mask")
|
||||||
|
|
|
@ -10,6 +10,7 @@ from torchvision.transforms.functional import normalize
|
||||||
from ..params import HighresParams, ImageParams, StageParams, UpscaleParams
|
from ..params import HighresParams, ImageParams, StageParams, UpscaleParams
|
||||||
from ..server import ServerContext
|
from ..server import ServerContext
|
||||||
from ..worker import WorkerContext
|
from ..worker import WorkerContext
|
||||||
|
from ..worker.context import ProgressCallback
|
||||||
from .base import BaseStage
|
from .base import BaseStage
|
||||||
from .result import StageResult
|
from .result import StageResult
|
||||||
|
|
||||||
|
@ -29,6 +30,7 @@ class CorrectCodeformerStage(BaseStage):
|
||||||
sources: StageResult,
|
sources: StageResult,
|
||||||
*,
|
*,
|
||||||
upscale: UpscaleParams,
|
upscale: UpscaleParams,
|
||||||
|
callback: Optional[ProgressCallback] = None,
|
||||||
highres: Optional[HighresParams] = None,
|
highres: Optional[HighresParams] = None,
|
||||||
stage_source: Optional[Image.Image] = None,
|
stage_source: Optional[Image.Image] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
|
|
|
@ -13,7 +13,7 @@ from ..params import (
|
||||||
)
|
)
|
||||||
from ..server import ModelTypes, ServerContext
|
from ..server import ModelTypes, ServerContext
|
||||||
from ..utils import run_gc
|
from ..utils import run_gc
|
||||||
from ..worker import WorkerContext
|
from ..worker import ProgressCallback, WorkerContext
|
||||||
from .base import BaseStage
|
from .base import BaseStage
|
||||||
from .result import StageResult
|
from .result import StageResult
|
||||||
|
|
||||||
|
@ -66,6 +66,7 @@ class CorrectGFPGANStage(BaseStage):
|
||||||
sources: StageResult,
|
sources: StageResult,
|
||||||
*,
|
*,
|
||||||
upscale: UpscaleParams,
|
upscale: UpscaleParams,
|
||||||
|
callback: Optional[ProgressCallback] = None,
|
||||||
highres: Optional[HighresParams] = None,
|
highres: Optional[HighresParams] = None,
|
||||||
stage_source: Optional[Image.Image] = None,
|
stage_source: Optional[Image.Image] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
|
|
|
@ -9,7 +9,7 @@ from ..params import (
|
||||||
UpscaleParams,
|
UpscaleParams,
|
||||||
)
|
)
|
||||||
from ..server import ServerContext
|
from ..server import ServerContext
|
||||||
from ..worker import WorkerContext
|
from ..worker import ProgressCallback, WorkerContext
|
||||||
from .base import BaseStage
|
from .base import BaseStage
|
||||||
from .result import StageResult
|
from .result import StageResult
|
||||||
|
|
||||||
|
@ -25,6 +25,7 @@ class EditMetadataStage(BaseStage):
|
||||||
_params: ImageParams,
|
_params: ImageParams,
|
||||||
source: StageResult,
|
source: StageResult,
|
||||||
*,
|
*,
|
||||||
|
callback: Optional[ProgressCallback] = None,
|
||||||
size: Optional[Size] = None,
|
size: Optional[Size] = None,
|
||||||
upscale: Optional[UpscaleParams] = None,
|
upscale: Optional[UpscaleParams] = None,
|
||||||
highres: Optional[HighresParams] = None,
|
highres: Optional[HighresParams] = None,
|
||||||
|
|
|
@ -1,10 +1,10 @@
|
||||||
from typing import Tuple
|
from typing import Optional, Tuple
|
||||||
|
|
||||||
from PIL import ImageDraw
|
from PIL import ImageDraw
|
||||||
|
|
||||||
from ..params import ImageParams, SizeChart, StageParams
|
from ..params import ImageParams, SizeChart, StageParams
|
||||||
from ..server import ServerContext
|
from ..server import ServerContext
|
||||||
from ..worker import WorkerContext
|
from ..worker import ProgressCallback, WorkerContext
|
||||||
from .base import BaseStage
|
from .base import BaseStage
|
||||||
from .result import StageResult
|
from .result import StageResult
|
||||||
|
|
||||||
|
@ -25,6 +25,7 @@ class EditTextStage(BaseStage):
|
||||||
fill: str = "white",
|
fill: str = "white",
|
||||||
stroke: str = "black",
|
stroke: str = "black",
|
||||||
stroke_width: int = 1,
|
stroke_width: int = 1,
|
||||||
|
callback: Optional[ProgressCallback] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> StageResult:
|
) -> StageResult:
|
||||||
# Add text to each image in source at the given position
|
# Add text to each image in source at the given position
|
||||||
|
|
|
@ -6,7 +6,7 @@ from PIL import Image
|
||||||
from ..output import save_result
|
from ..output import save_result
|
||||||
from ..params import ImageParams, SizeChart, StageParams
|
from ..params import ImageParams, SizeChart, StageParams
|
||||||
from ..server import ServerContext
|
from ..server import ServerContext
|
||||||
from ..worker import WorkerContext
|
from ..worker import ProgressCallback, WorkerContext
|
||||||
from .base import BaseStage
|
from .base import BaseStage
|
||||||
from .result import StageResult
|
from .result import StageResult
|
||||||
|
|
||||||
|
@ -24,6 +24,7 @@ class PersistDiskStage(BaseStage):
|
||||||
params: ImageParams,
|
params: ImageParams,
|
||||||
sources: StageResult,
|
sources: StageResult,
|
||||||
*,
|
*,
|
||||||
|
callback: Optional[ProgressCallback] = None,
|
||||||
stage_source: Optional[Image.Image] = None,
|
stage_source: Optional[Image.Image] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> StageResult:
|
) -> StageResult:
|
||||||
|
|
|
@ -9,7 +9,7 @@ from PIL import Image
|
||||||
from ..output import make_output_names
|
from ..output import make_output_names
|
||||||
from ..params import ImageParams, StageParams
|
from ..params import ImageParams, StageParams
|
||||||
from ..server import ServerContext
|
from ..server import ServerContext
|
||||||
from ..worker import WorkerContext
|
from ..worker import ProgressCallback, WorkerContext
|
||||||
from .base import BaseStage
|
from .base import BaseStage
|
||||||
from .result import StageResult
|
from .result import StageResult
|
||||||
|
|
||||||
|
@ -29,6 +29,7 @@ class PersistS3Stage(BaseStage):
|
||||||
endpoint_url: Optional[str] = None,
|
endpoint_url: Optional[str] = None,
|
||||||
profile_name: Optional[str] = None,
|
profile_name: Optional[str] = None,
|
||||||
stage_source: Optional[Image.Image] = None,
|
stage_source: Optional[Image.Image] = None,
|
||||||
|
callback: Optional[ProgressCallback] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> StageResult:
|
) -> StageResult:
|
||||||
session = Session(profile_name=profile_name)
|
session = Session(profile_name=profile_name)
|
||||||
|
|
|
@ -5,7 +5,7 @@ from PIL import Image
|
||||||
|
|
||||||
from ..params import ImageParams, Size, StageParams
|
from ..params import ImageParams, Size, StageParams
|
||||||
from ..server import ServerContext
|
from ..server import ServerContext
|
||||||
from ..worker import WorkerContext
|
from ..worker import ProgressCallback, WorkerContext
|
||||||
from .base import BaseStage
|
from .base import BaseStage
|
||||||
from .result import StageResult
|
from .result import StageResult
|
||||||
|
|
||||||
|
@ -24,6 +24,7 @@ class ReduceCropStage(BaseStage):
|
||||||
origin: Size,
|
origin: Size,
|
||||||
size: Size,
|
size: Size,
|
||||||
stage_source: Optional[Image.Image] = None,
|
stage_source: Optional[Image.Image] = None,
|
||||||
|
callback: Optional[ProgressCallback] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> StageResult:
|
) -> StageResult:
|
||||||
outputs = []
|
outputs = []
|
||||||
|
|
|
@ -1,10 +1,11 @@
|
||||||
from logging import getLogger
|
from logging import getLogger
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
from ..params import ImageParams, Size, StageParams
|
from ..params import ImageParams, Size, StageParams
|
||||||
from ..server import ServerContext
|
from ..server import ServerContext
|
||||||
from ..worker import WorkerContext
|
from ..worker import ProgressCallback, WorkerContext
|
||||||
from .base import BaseStage
|
from .base import BaseStage
|
||||||
from .result import StageResult
|
from .result import StageResult
|
||||||
|
|
||||||
|
@ -22,6 +23,7 @@ class ReduceThumbnailStage(BaseStage):
|
||||||
*,
|
*,
|
||||||
size: Size,
|
size: Size,
|
||||||
stage_source: Image.Image,
|
stage_source: Image.Image,
|
||||||
|
callback: Optional[ProgressCallback] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> StageResult:
|
) -> StageResult:
|
||||||
outputs = []
|
outputs = []
|
||||||
|
|
|
@ -5,7 +5,7 @@ from PIL import Image
|
||||||
|
|
||||||
from ..params import ImageParams, Size, StageParams
|
from ..params import ImageParams, Size, StageParams
|
||||||
from ..server import ServerContext
|
from ..server import ServerContext
|
||||||
from ..worker import WorkerContext
|
from ..worker import ProgressCallback, WorkerContext
|
||||||
from .base import BaseStage
|
from .base import BaseStage
|
||||||
from .result import StageResult
|
from .result import StageResult
|
||||||
|
|
||||||
|
@ -24,6 +24,7 @@ class SourceNoiseStage(BaseStage):
|
||||||
size: Size,
|
size: Size,
|
||||||
noise_source: Callable,
|
noise_source: Callable,
|
||||||
stage_source: Optional[Image.Image] = None,
|
stage_source: Optional[Image.Image] = None,
|
||||||
|
callback: Optional[ProgressCallback] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> StageResult:
|
) -> StageResult:
|
||||||
logger.info("generating image from noise source")
|
logger.info("generating image from noise source")
|
||||||
|
|
|
@ -6,7 +6,7 @@ from boto3 import Session
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
from ..params import ImageParams, StageParams
|
from ..params import ImageParams, StageParams
|
||||||
from ..server import ServerContext
|
from ..server import ProgressCallback, ServerContext
|
||||||
from ..worker import WorkerContext
|
from ..worker import WorkerContext
|
||||||
from .base import BaseStage
|
from .base import BaseStage
|
||||||
from .result import ImageMetadata, StageResult
|
from .result import ImageMetadata, StageResult
|
||||||
|
@ -27,6 +27,7 @@ class SourceS3Stage(BaseStage):
|
||||||
bucket: str,
|
bucket: str,
|
||||||
endpoint_url: Optional[str] = None,
|
endpoint_url: Optional[str] = None,
|
||||||
profile_name: Optional[str] = None,
|
profile_name: Optional[str] = None,
|
||||||
|
callback: Optional[ProgressCallback] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> StageResult:
|
) -> StageResult:
|
||||||
session = Session(profile_name=profile_name)
|
session = Session(profile_name=profile_name)
|
||||||
|
|
|
@ -7,7 +7,7 @@ from PIL import Image
|
||||||
|
|
||||||
from ..params import ImageParams, StageParams
|
from ..params import ImageParams, StageParams
|
||||||
from ..server import ServerContext
|
from ..server import ServerContext
|
||||||
from ..worker import WorkerContext
|
from ..worker import ProgressCallback, WorkerContext
|
||||||
from .base import BaseStage
|
from .base import BaseStage
|
||||||
from .result import ImageMetadata, StageResult
|
from .result import ImageMetadata, StageResult
|
||||||
|
|
||||||
|
@ -25,6 +25,7 @@ class SourceURLStage(BaseStage):
|
||||||
*,
|
*,
|
||||||
source_urls: List[str],
|
source_urls: List[str],
|
||||||
stage_source: Optional[Image.Image] = None,
|
stage_source: Optional[Image.Image] = None,
|
||||||
|
callback: Optional[ProgressCallback] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> StageResult:
|
) -> StageResult:
|
||||||
logger.info("loading image from URL source")
|
logger.info("loading image from URL source")
|
||||||
|
|
|
@ -17,7 +17,7 @@ from ..params import (
|
||||||
)
|
)
|
||||||
from ..server import ModelTypes, ServerContext
|
from ..server import ModelTypes, ServerContext
|
||||||
from ..utils import run_gc
|
from ..utils import run_gc
|
||||||
from ..worker import WorkerContext
|
from ..worker import ProgressCallback, WorkerContext
|
||||||
from .base import BaseStage
|
from .base import BaseStage
|
||||||
from .result import StageResult
|
from .result import StageResult
|
||||||
|
|
||||||
|
@ -68,6 +68,7 @@ class UpscaleBSRGANStage(BaseStage):
|
||||||
upscale: UpscaleParams,
|
upscale: UpscaleParams,
|
||||||
highres: Optional[HighresParams] = None,
|
highres: Optional[HighresParams] = None,
|
||||||
stage_source: Optional[Image.Image] = None,
|
stage_source: Optional[Image.Image] = None,
|
||||||
|
callback: Optional[ProgressCallback] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> StageResult:
|
) -> StageResult:
|
||||||
upscale = upscale.with_args(**kwargs)
|
upscale = upscale.with_args(**kwargs)
|
||||||
|
|
|
@ -14,7 +14,7 @@ from ..params import (
|
||||||
)
|
)
|
||||||
from ..server import ModelTypes, ServerContext
|
from ..server import ModelTypes, ServerContext
|
||||||
from ..utils import run_gc
|
from ..utils import run_gc
|
||||||
from ..worker import WorkerContext
|
from ..worker import ProgressCallback, WorkerContext
|
||||||
from .base import BaseStage
|
from .base import BaseStage
|
||||||
from .result import StageResult
|
from .result import StageResult
|
||||||
|
|
||||||
|
@ -110,6 +110,7 @@ class UpscaleRealESRGANStage(BaseStage):
|
||||||
upscale: UpscaleParams,
|
upscale: UpscaleParams,
|
||||||
highres: Optional[HighresParams] = None,
|
highres: Optional[HighresParams] = None,
|
||||||
stage_source: Optional[Image.Image] = None,
|
stage_source: Optional[Image.Image] = None,
|
||||||
|
callback: Optional[ProgressCallback] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> StageResult:
|
) -> StageResult:
|
||||||
logger.info("upscaling image with Real ESRGAN: x%s", upscale.scale)
|
logger.info("upscaling image with Real ESRGAN: x%s", upscale.scale)
|
||||||
|
|
|
@ -3,9 +3,9 @@ from typing import Optional
|
||||||
|
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
from ..params import ImageParams, StageParams, UpscaleParams
|
from ..params import ImageParams, SizeChart, StageParams, UpscaleParams
|
||||||
from ..server import ServerContext
|
from ..server import ServerContext
|
||||||
from ..worker import WorkerContext
|
from ..worker import ProgressCallback, WorkerContext
|
||||||
from .base import BaseStage
|
from .base import BaseStage
|
||||||
from .result import StageResult
|
from .result import StageResult
|
||||||
|
|
||||||
|
@ -13,6 +13,8 @@ logger = getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class UpscaleSimpleStage(BaseStage):
|
class UpscaleSimpleStage(BaseStage):
|
||||||
|
max_tile = SizeChart.max
|
||||||
|
|
||||||
def run(
|
def run(
|
||||||
self,
|
self,
|
||||||
_worker: WorkerContext,
|
_worker: WorkerContext,
|
||||||
|
@ -24,6 +26,7 @@ class UpscaleSimpleStage(BaseStage):
|
||||||
method: str,
|
method: str,
|
||||||
upscale: UpscaleParams,
|
upscale: UpscaleParams,
|
||||||
stage_source: Optional[Image.Image] = None,
|
stage_source: Optional[Image.Image] = None,
|
||||||
|
callback: Optional[ProgressCallback] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> StageResult:
|
) -> StageResult:
|
||||||
if upscale.scale <= 1:
|
if upscale.scale <= 1:
|
||||||
|
|
|
@ -16,7 +16,7 @@ from ..params import (
|
||||||
)
|
)
|
||||||
from ..server import ModelTypes, ServerContext
|
from ..server import ModelTypes, ServerContext
|
||||||
from ..utils import run_gc
|
from ..utils import run_gc
|
||||||
from ..worker import WorkerContext
|
from ..worker import ProgressCallback, WorkerContext
|
||||||
from .base import BaseStage
|
from .base import BaseStage
|
||||||
from .result import StageResult
|
from .result import StageResult
|
||||||
|
|
||||||
|
@ -67,6 +67,7 @@ class UpscaleSwinIRStage(BaseStage):
|
||||||
upscale: UpscaleParams,
|
upscale: UpscaleParams,
|
||||||
highres: Optional[HighresParams] = None,
|
highres: Optional[HighresParams] = None,
|
||||||
stage_source: Optional[Image.Image] = None,
|
stage_source: Optional[Image.Image] = None,
|
||||||
|
callback: Optional[ProgressCallback] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> StageResult:
|
) -> StageResult:
|
||||||
upscale = upscale.with_args(**kwargs)
|
upscale = upscale.with_args(**kwargs)
|
||||||
|
|
Loading…
Reference in New Issue