diff --git a/api/onnx_web/chain/tile.py b/api/onnx_web/chain/tile.py index 5716fd22..a00a5b4d 100644 --- a/api/onnx_web/chain/tile.py +++ b/api/onnx_web/chain/tile.py @@ -25,7 +25,7 @@ class TileCallback(Protocol): Definition for a tile job function. """ - def __call__(self, image: Image.Image, dims: Tuple[int, int, int]) -> StageResult: + def __call__(self, sources: List[Image.Image], mask: Image.Image, dims: Tuple[int, int, int]) -> StageResult: """ Run this stage against a single tile. """ @@ -366,10 +366,7 @@ def process_tile_order( scale: int, filters: List[TileCallback], **kwargs, -) -> Image.Image: - """ - TODO: needs to handle more than one image - """ +) -> List[Image.Image]: if order == TileOrder.grid: logger.debug("using grid tile order with tile size: %s", tile) return process_tile_stack( @@ -483,7 +480,7 @@ def generate_tile_grid( height: int, tile: int, overlap: float = 0.0, -) -> List[Tuple[int, int]]: +) -> List[Tuple[int, int, Image.Image]]: adj_tile = int(float(tile) * (1.0 - overlap)) tiles_x = ceil(width / adj_tile) tiles_y = ceil(height / adj_tile) diff --git a/api/onnx_web/server/params.py b/api/onnx_web/server/params.py index e641b953..39468542 100644 --- a/api/onnx_web/server/params.py +++ b/api/onnx_web/server/params.py @@ -1,5 +1,5 @@ from logging import getLogger -from typing import Dict, Optional, Tuple +from typing import Dict, Optional, Tuple, Union from flask import request @@ -345,7 +345,7 @@ PipelineParams = Tuple[Optional[DeviceParams], ImageParams, Size] def pipeline_from_json( server: ServerContext, - data: Dict[str, str], + data: Dict[str, Union[str, Dict[str, str]]], default_pipeline: str = "txt2img", ) -> PipelineParams: """ diff --git a/api/pyproject.toml b/api/pyproject.toml index 5d69e906..b40218d0 100644 --- a/api/pyproject.toml +++ b/api/pyproject.toml @@ -23,9 +23,14 @@ module = [ "basicsr", "boto3", "codeformer", - "codeformer.facelib.utils.misc", - "codeformer.facelib.utils", + "codeformer.basicsr", + "codeformer.basicsr.utils", + "codeformer.basicsr.utils.download_util", + "codeformer.basicsr.utils.registry", "codeformer.facelib", + "codeformer.facelib.utils", + "codeformer.facelib.utils.misc", + "codeformer.facelib.utils.face_restoration_helper", "compel", "controlnet_aux", "cv2", @@ -73,6 +78,7 @@ module = [ "safetensors", "scipy", "timm.models.layers", + "torchvision.transforms.functional", "transformers", "win10toast" ]