start replacing image output with results
This commit is contained in:
parent
5a517704ea
commit
a63669c76b
|
@ -17,7 +17,7 @@ class BaseStage:
|
||||||
_server: ServerContext,
|
_server: ServerContext,
|
||||||
_stage: StageParams,
|
_stage: StageParams,
|
||||||
_params: ImageParams,
|
_params: ImageParams,
|
||||||
_sources: List[Image.Image],
|
_sources: StageResult,
|
||||||
*args,
|
*args,
|
||||||
stage_source: Optional[Image.Image] = None,
|
stage_source: Optional[Image.Image] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
|
|
|
@ -9,6 +9,7 @@ from ..params import ImageParams, SizeChart, StageParams
|
||||||
from ..server import ServerContext
|
from ..server import ServerContext
|
||||||
from ..worker import ProgressCallback, WorkerContext
|
from ..worker import ProgressCallback, WorkerContext
|
||||||
from .base import BaseStage
|
from .base import BaseStage
|
||||||
|
from .result import StageResult
|
||||||
|
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
@ -22,19 +23,19 @@ class BlendDenoiseStage(BaseStage):
|
||||||
_server: ServerContext,
|
_server: ServerContext,
|
||||||
_stage: StageParams,
|
_stage: StageParams,
|
||||||
_params: ImageParams,
|
_params: ImageParams,
|
||||||
sources: List[Image.Image],
|
sources: StageResult,
|
||||||
*,
|
*,
|
||||||
strength: int = 3,
|
strength: int = 3,
|
||||||
stage_source: Optional[Image.Image] = None,
|
stage_source: Optional[Image.Image] = None,
|
||||||
callback: Optional[ProgressCallback] = None,
|
callback: Optional[ProgressCallback] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> List[Image.Image]:
|
) -> StageResult:
|
||||||
logger.info("denoising source images")
|
logger.info("denoising source images")
|
||||||
|
|
||||||
results = []
|
results = []
|
||||||
for source in sources:
|
for source in sources.as_numpy():
|
||||||
data = cv2.cvtColor(np.array(source), cv2.COLOR_RGB2BGR)
|
data = cv2.cvtColor(source, cv2.COLOR_RGB2BGR)
|
||||||
data = cv2.fastNlMeansDenoisingColored(data, None, strength, strength)
|
data = cv2.fastNlMeansDenoisingColored(data, None, strength, strength)
|
||||||
results.append(Image.fromarray(cv2.cvtColor(data, cv2.COLOR_BGR2RGB)))
|
results.append(cv2.cvtColor(data, cv2.COLOR_BGR2RGB))
|
||||||
|
|
||||||
return results
|
return StageResult(arrays=results)
|
||||||
|
|
|
@ -7,6 +7,7 @@ from ..params import ImageParams, SizeChart, StageParams
|
||||||
from ..server import ServerContext
|
from ..server import ServerContext
|
||||||
from ..worker import ProgressCallback, WorkerContext
|
from ..worker import ProgressCallback, WorkerContext
|
||||||
from .base import BaseStage
|
from .base import BaseStage
|
||||||
|
from .result import StageResult
|
||||||
|
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
@ -20,7 +21,7 @@ class BlendGridStage(BaseStage):
|
||||||
_server: ServerContext,
|
_server: ServerContext,
|
||||||
_stage: StageParams,
|
_stage: StageParams,
|
||||||
_params: ImageParams,
|
_params: ImageParams,
|
||||||
sources: List[Image.Image],
|
sources: StageResult,
|
||||||
*,
|
*,
|
||||||
height: int,
|
height: int,
|
||||||
width: int,
|
width: int,
|
||||||
|
@ -31,7 +32,7 @@ class BlendGridStage(BaseStage):
|
||||||
stage_source: Optional[Image.Image] = None,
|
stage_source: Optional[Image.Image] = None,
|
||||||
callback: Optional[ProgressCallback] = None,
|
callback: Optional[ProgressCallback] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> List[Image.Image]:
|
) -> StageResult:
|
||||||
logger.info("combining source images using grid layout")
|
logger.info("combining source images using grid layout")
|
||||||
|
|
||||||
size = sources[0].size
|
size = sources[0].size
|
||||||
|
@ -49,7 +50,7 @@ class BlendGridStage(BaseStage):
|
||||||
n = order[i]
|
n = order[i]
|
||||||
output.paste(sources[n], (x * size[0], y * size[1]))
|
output.paste(sources[n], (x * size[0], y * size[1]))
|
||||||
|
|
||||||
return [*sources, output]
|
return StageResult(images=[*sources, output])
|
||||||
|
|
||||||
def outputs(
|
def outputs(
|
||||||
self,
|
self,
|
||||||
|
|
|
@ -11,6 +11,7 @@ from ..params import ImageParams, SizeChart, StageParams
|
||||||
from ..server import ServerContext
|
from ..server import ServerContext
|
||||||
from ..worker import ProgressCallback, WorkerContext
|
from ..worker import ProgressCallback, WorkerContext
|
||||||
from .base import BaseStage
|
from .base import BaseStage
|
||||||
|
from .result import StageResult
|
||||||
|
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
@ -24,14 +25,14 @@ class BlendImg2ImgStage(BaseStage):
|
||||||
server: ServerContext,
|
server: ServerContext,
|
||||||
_stage: StageParams,
|
_stage: StageParams,
|
||||||
params: ImageParams,
|
params: ImageParams,
|
||||||
sources: List[Image.Image],
|
sources: StageResult,
|
||||||
*,
|
*,
|
||||||
strength: float,
|
strength: float,
|
||||||
callback: Optional[ProgressCallback] = None,
|
callback: Optional[ProgressCallback] = None,
|
||||||
stage_source: Optional[Image.Image] = None,
|
stage_source: Optional[Image.Image] = None,
|
||||||
prompt_index: Optional[int] = None,
|
prompt_index: Optional[int] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> List[Image.Image]:
|
) -> StageResult:
|
||||||
params = params.with_args(**kwargs)
|
params = params.with_args(**kwargs)
|
||||||
|
|
||||||
# multi-stage prompting
|
# multi-stage prompting
|
||||||
|
@ -65,7 +66,7 @@ class BlendImg2ImgStage(BaseStage):
|
||||||
pipe_params["strength"] = strength
|
pipe_params["strength"] = strength
|
||||||
|
|
||||||
outputs = []
|
outputs = []
|
||||||
for source in sources:
|
for source in sources.as_image():
|
||||||
if params.is_lpw():
|
if params.is_lpw():
|
||||||
logger.debug("using LPW pipeline for img2img")
|
logger.debug("using LPW pipeline for img2img")
|
||||||
rng = torch.manual_seed(params.seed)
|
rng = torch.manual_seed(params.seed)
|
||||||
|
@ -101,7 +102,7 @@ class BlendImg2ImgStage(BaseStage):
|
||||||
|
|
||||||
outputs.extend(result.images)
|
outputs.extend(result.images)
|
||||||
|
|
||||||
return outputs
|
return StageResult(images=outputs)
|
||||||
|
|
||||||
def steps(
|
def steps(
|
||||||
self,
|
self,
|
||||||
|
|
|
@ -7,6 +7,7 @@ from ..params import ImageParams, StageParams
|
||||||
from ..server import ServerContext
|
from ..server import ServerContext
|
||||||
from ..worker import ProgressCallback, WorkerContext
|
from ..worker import ProgressCallback, WorkerContext
|
||||||
from .base import BaseStage
|
from .base import BaseStage
|
||||||
|
from .result import StageResult
|
||||||
|
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
@ -18,13 +19,13 @@ class BlendLinearStage(BaseStage):
|
||||||
_server: ServerContext,
|
_server: ServerContext,
|
||||||
_stage: StageParams,
|
_stage: StageParams,
|
||||||
_params: ImageParams,
|
_params: ImageParams,
|
||||||
sources: List[Image.Image],
|
sources: StageResult,
|
||||||
*,
|
*,
|
||||||
alpha: float,
|
alpha: float,
|
||||||
stage_source: Optional[Image.Image] = None,
|
stage_source: Optional[Image.Image] = None,
|
||||||
_callback: Optional[ProgressCallback] = None,
|
_callback: Optional[ProgressCallback] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> List[Image.Image]:
|
) -> StageResult:
|
||||||
logger.info("blending source images using linear interpolation")
|
logger.info("blending source images using linear interpolation")
|
||||||
|
|
||||||
return [Image.blend(source, stage_source, alpha) for source in sources]
|
return StageResult(images=[Image.blend(source, stage_source, alpha) for source in sources])
|
||||||
|
|
|
@ -9,6 +9,7 @@ from ..server import ServerContext
|
||||||
from ..utils import is_debug
|
from ..utils import is_debug
|
||||||
from ..worker import ProgressCallback, WorkerContext
|
from ..worker import ProgressCallback, WorkerContext
|
||||||
from .base import BaseStage
|
from .base import BaseStage
|
||||||
|
from .result import StageResult
|
||||||
|
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
@ -20,13 +21,13 @@ class BlendMaskStage(BaseStage):
|
||||||
server: ServerContext,
|
server: ServerContext,
|
||||||
_stage: StageParams,
|
_stage: StageParams,
|
||||||
_params: ImageParams,
|
_params: ImageParams,
|
||||||
sources: List[Image.Image],
|
sources: StageResult,
|
||||||
*,
|
*,
|
||||||
stage_source: Optional[Image.Image] = None,
|
stage_source: Optional[Image.Image] = None,
|
||||||
stage_mask: Optional[Image.Image] = None,
|
stage_mask: Optional[Image.Image] = None,
|
||||||
_callback: Optional[ProgressCallback] = None,
|
_callback: Optional[ProgressCallback] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> List[Image.Image]:
|
) -> StageResult:
|
||||||
logger.info("blending image using mask")
|
logger.info("blending image using mask")
|
||||||
|
|
||||||
mult_mask = Image.new("RGBA", stage_mask.size, color="black")
|
mult_mask = Image.new("RGBA", stage_mask.size, color="black")
|
||||||
|
@ -37,4 +38,4 @@ class BlendMaskStage(BaseStage):
|
||||||
save_image(server, "last-mask.png", stage_mask)
|
save_image(server, "last-mask.png", stage_mask)
|
||||||
save_image(server, "last-mult-mask.png", mult_mask)
|
save_image(server, "last-mult-mask.png", mult_mask)
|
||||||
|
|
||||||
return [Image.composite(stage_source, source, mult_mask) for source in sources]
|
return StageResult(images=[Image.composite(stage_source, source, mult_mask) for source in sources])
|
||||||
|
|
|
@ -7,6 +7,7 @@ from ..params import ImageParams, StageParams, UpscaleParams
|
||||||
from ..server import ServerContext
|
from ..server import ServerContext
|
||||||
from ..worker import WorkerContext
|
from ..worker import WorkerContext
|
||||||
from .base import BaseStage
|
from .base import BaseStage
|
||||||
|
from .result import StageResult
|
||||||
|
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
@ -18,12 +19,12 @@ class CorrectCodeformerStage(BaseStage):
|
||||||
_server: ServerContext,
|
_server: ServerContext,
|
||||||
_stage: StageParams,
|
_stage: StageParams,
|
||||||
_params: ImageParams,
|
_params: ImageParams,
|
||||||
sources: List[Image.Image],
|
sources: StageResult,
|
||||||
*,
|
*,
|
||||||
stage_source: Optional[Image.Image] = None,
|
stage_source: Optional[Image.Image] = None,
|
||||||
upscale: UpscaleParams,
|
upscale: UpscaleParams,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> List[Image.Image]:
|
) -> StageResult:
|
||||||
# must be within the load function for patch to take effect
|
# must be within the load function for patch to take effect
|
||||||
# TODO: rewrite and remove
|
# TODO: rewrite and remove
|
||||||
from codeformer import CodeFormer
|
from codeformer import CodeFormer
|
||||||
|
@ -32,4 +33,4 @@ class CorrectCodeformerStage(BaseStage):
|
||||||
|
|
||||||
device = worker.get_device()
|
device = worker.get_device()
|
||||||
pipe = CodeFormer(upscale=upscale.face_outscale).to(device.torch_str())
|
pipe = CodeFormer(upscale=upscale.face_outscale).to(device.torch_str())
|
||||||
return [pipe(source) for source in sources]
|
return StageResult(images=[pipe(source) for source in sources])
|
||||||
|
|
|
@ -10,6 +10,7 @@ from ..server import ModelTypes, ServerContext
|
||||||
from ..utils import run_gc
|
from ..utils import run_gc
|
||||||
from ..worker import WorkerContext
|
from ..worker import WorkerContext
|
||||||
from .base import BaseStage
|
from .base import BaseStage
|
||||||
|
from .result import StageResult
|
||||||
|
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
@ -57,12 +58,12 @@ class CorrectGFPGANStage(BaseStage):
|
||||||
server: ServerContext,
|
server: ServerContext,
|
||||||
stage: StageParams,
|
stage: StageParams,
|
||||||
_params: ImageParams,
|
_params: ImageParams,
|
||||||
sources: List[Image.Image],
|
sources: StageResult,
|
||||||
*,
|
*,
|
||||||
upscale: UpscaleParams,
|
upscale: UpscaleParams,
|
||||||
stage_source: Optional[Image.Image] = None,
|
stage_source: Optional[Image.Image] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> List[Image.Image]:
|
) -> StageResult:
|
||||||
upscale = upscale.with_args(**kwargs)
|
upscale = upscale.with_args(**kwargs)
|
||||||
|
|
||||||
if upscale.correction_model is None:
|
if upscale.correction_model is None:
|
||||||
|
@ -73,16 +74,12 @@ class CorrectGFPGANStage(BaseStage):
|
||||||
device = worker.get_device()
|
device = worker.get_device()
|
||||||
gfpgan = self.load(server, stage, upscale, device)
|
gfpgan = self.load(server, stage, upscale, device)
|
||||||
|
|
||||||
outputs = []
|
outputs = [gfpgan.enhance(
|
||||||
for source in sources:
|
source,
|
||||||
output = np.array(source)
|
|
||||||
_, _, output = gfpgan.enhance(
|
|
||||||
output,
|
|
||||||
has_aligned=False,
|
has_aligned=False,
|
||||||
only_center_face=False,
|
only_center_face=False,
|
||||||
paste_back=True,
|
paste_back=True,
|
||||||
weight=upscale.face_strength,
|
weight=upscale.face_strength,
|
||||||
)
|
) for source in sources.as_numpy()]
|
||||||
outputs.append(Image.fromarray(output, "RGB"))
|
|
||||||
|
|
||||||
return outputs
|
return StageResult(images=outputs)
|
||||||
|
|
|
@ -8,6 +8,7 @@ from ..params import ImageParams, Size, SizeChart, StageParams
|
||||||
from ..server import ServerContext
|
from ..server import ServerContext
|
||||||
from ..worker import WorkerContext
|
from ..worker import WorkerContext
|
||||||
from .base import BaseStage
|
from .base import BaseStage
|
||||||
|
from .result import StageResult
|
||||||
|
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
@ -21,13 +22,13 @@ class PersistDiskStage(BaseStage):
|
||||||
server: ServerContext,
|
server: ServerContext,
|
||||||
_stage: StageParams,
|
_stage: StageParams,
|
||||||
params: ImageParams,
|
params: ImageParams,
|
||||||
sources: List[Image.Image],
|
sources: StageResult,
|
||||||
*,
|
*,
|
||||||
output: List[str],
|
output: List[str],
|
||||||
size: Optional[Size] = None,
|
size: Optional[Size] = None,
|
||||||
stage_source: Optional[Image.Image] = None,
|
stage_source: Optional[Image.Image] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> List[Image.Image]:
|
) -> StageResult:
|
||||||
logger.info(
|
logger.info(
|
||||||
"persisting images to disk: %s, %s", [s.size for s in sources], output
|
"persisting images to disk: %s, %s", [s.size for s in sources], output
|
||||||
)
|
)
|
||||||
|
|
|
@ -9,6 +9,7 @@ from ..params import ImageParams, StageParams
|
||||||
from ..server import ServerContext
|
from ..server import ServerContext
|
||||||
from ..worker import WorkerContext
|
from ..worker import WorkerContext
|
||||||
from .base import BaseStage
|
from .base import BaseStage
|
||||||
|
from .result import StageResult
|
||||||
|
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
@ -20,7 +21,7 @@ class PersistS3Stage(BaseStage):
|
||||||
server: ServerContext,
|
server: ServerContext,
|
||||||
_stage: StageParams,
|
_stage: StageParams,
|
||||||
_params: ImageParams,
|
_params: ImageParams,
|
||||||
sources: List[Image.Image],
|
sources: StageResult,
|
||||||
*,
|
*,
|
||||||
output: str,
|
output: str,
|
||||||
bucket: str,
|
bucket: str,
|
||||||
|
@ -28,11 +29,11 @@ class PersistS3Stage(BaseStage):
|
||||||
profile_name: Optional[str] = None,
|
profile_name: Optional[str] = None,
|
||||||
stage_source: Optional[Image.Image] = None,
|
stage_source: Optional[Image.Image] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> List[Image.Image]:
|
) -> StageResult:
|
||||||
session = Session(profile_name=profile_name)
|
session = Session(profile_name=profile_name)
|
||||||
s3 = session.client("s3", endpoint_url=endpoint_url)
|
s3 = session.client("s3", endpoint_url=endpoint_url)
|
||||||
|
|
||||||
for source in sources:
|
for source in sources.as_image():
|
||||||
data = BytesIO()
|
data = BytesIO()
|
||||||
source.save(data, format=server.image_format)
|
source.save(data, format=server.image_format)
|
||||||
data.seek(0)
|
data.seek(0)
|
||||||
|
|
|
@ -13,6 +13,7 @@ from ..utils import is_debug, run_gc
|
||||||
from ..worker import ProgressCallback, WorkerContext
|
from ..worker import ProgressCallback, WorkerContext
|
||||||
from .base import BaseStage
|
from .base import BaseStage
|
||||||
from .tile import needs_tile, process_tile_order
|
from .tile import needs_tile, process_tile_order
|
||||||
|
from .result import StageResult
|
||||||
|
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
@ -73,26 +74,27 @@ class ChainPipeline:
|
||||||
worker: WorkerContext,
|
worker: WorkerContext,
|
||||||
server: ServerContext,
|
server: ServerContext,
|
||||||
params: ImageParams,
|
params: ImageParams,
|
||||||
sources: List[Image.Image],
|
sources: StageResult,
|
||||||
callback: Optional[ProgressCallback],
|
callback: Optional[ProgressCallback],
|
||||||
**kwargs
|
**kwargs
|
||||||
) -> List[Image.Image]:
|
) -> StageResult:
|
||||||
return self(
|
result = self(
|
||||||
worker, server, params, sources=sources, callback=callback, **kwargs
|
worker, server, params, sources=sources, callback=callback, **kwargs
|
||||||
)
|
)
|
||||||
|
return result.as_image()
|
||||||
|
|
||||||
def stage(self, callback: BaseStage, params: StageParams, **kwargs):
|
def stage(self, callback: BaseStage, params: StageParams, **kwargs):
|
||||||
self.stages.append((callback, params, kwargs))
|
self.stages.append((callback, params, kwargs))
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def steps(self, params: ImageParams, size: Size):
|
def steps(self, params: ImageParams, size: Size) -> int:
|
||||||
steps = 0
|
steps = 0
|
||||||
for callback, _params, kwargs in self.stages:
|
for callback, _params, kwargs in self.stages:
|
||||||
steps += callback.steps(kwargs.get("params", params), size)
|
steps += callback.steps(kwargs.get("params", params), size)
|
||||||
|
|
||||||
return steps
|
return steps
|
||||||
|
|
||||||
def outputs(self, params: ImageParams, sources: int):
|
def outputs(self, params: ImageParams, sources: int) -> int:
|
||||||
outputs = sources
|
outputs = sources
|
||||||
for callback, _params, kwargs in self.stages:
|
for callback, _params, kwargs in self.stages:
|
||||||
outputs = callback.outputs(kwargs.get("params", params), outputs)
|
outputs = callback.outputs(kwargs.get("params", params), outputs)
|
||||||
|
@ -104,10 +106,10 @@ class ChainPipeline:
|
||||||
worker: WorkerContext,
|
worker: WorkerContext,
|
||||||
server: ServerContext,
|
server: ServerContext,
|
||||||
params: ImageParams,
|
params: ImageParams,
|
||||||
sources: List[Image.Image],
|
sources: StageResult,
|
||||||
callback: Optional[ProgressCallback] = None,
|
callback: Optional[ProgressCallback] = None,
|
||||||
**pipeline_kwargs
|
**pipeline_kwargs
|
||||||
) -> List[Image.Image]:
|
) -> StageResult:
|
||||||
"""
|
"""
|
||||||
DEPRECATED: use `run` instead
|
DEPRECATED: use `run` instead
|
||||||
"""
|
"""
|
||||||
|
@ -161,23 +163,21 @@ class ChainPipeline:
|
||||||
tile = min(stage_pipe.max_tile, stage_params.tile_size)
|
tile = min(stage_pipe.max_tile, stage_params.tile_size)
|
||||||
|
|
||||||
if stage_sources or must_tile:
|
if stage_sources or must_tile:
|
||||||
stage_outputs = []
|
stage_results = []
|
||||||
for source in stage_sources:
|
for source in stage_sources:
|
||||||
logger.info(
|
logger.info(
|
||||||
"image contains sources or is larger than tile size of %s, tiling stage",
|
"image contains sources or is larger than tile size of %s, tiling stage",
|
||||||
tile,
|
tile,
|
||||||
)
|
)
|
||||||
|
|
||||||
extra_tiles = []
|
|
||||||
|
|
||||||
def stage_tile(
|
def stage_tile(
|
||||||
source_tile: Image.Image,
|
source_tile: Image.Image,
|
||||||
tile_mask: Image.Image,
|
tile_mask: Image.Image,
|
||||||
dims: Tuple[int, int, int],
|
dims: Tuple[int, int, int],
|
||||||
) -> Image.Image:
|
) -> StageResult:
|
||||||
for _i in range(worker.retries):
|
for _i in range(worker.retries):
|
||||||
try:
|
try:
|
||||||
output_tile = stage_pipe.run(
|
tile_result = stage_pipe.run(
|
||||||
worker,
|
worker,
|
||||||
server,
|
server,
|
||||||
stage_params,
|
stage_params,
|
||||||
|
@ -189,17 +189,11 @@ class ChainPipeline:
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
if len(output_tile) > 1:
|
|
||||||
while len(extra_tiles) < len(output_tile):
|
|
||||||
extra_tiles.append([])
|
|
||||||
|
|
||||||
for tile, layer in zip(output_tile, extra_tiles):
|
|
||||||
layer.append((tile, dims))
|
|
||||||
|
|
||||||
if is_debug():
|
if is_debug():
|
||||||
save_image(server, "last-tile.png", output_tile[0])
|
for j, image in enumerate(tile_result.as_image()):
|
||||||
|
save_image(server, f"last-tile-{j}.png", image)
|
||||||
|
|
||||||
return output_tile[0]
|
return tile_result
|
||||||
except Exception:
|
except Exception:
|
||||||
worker.retries = worker.retries - 1
|
worker.retries = worker.retries - 1
|
||||||
logger.exception(
|
logger.exception(
|
||||||
|
@ -220,17 +214,9 @@ class ChainPipeline:
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
stage_outputs.append(output)
|
stage_results.append(output)
|
||||||
|
|
||||||
if len(extra_tiles) > 1:
|
stage_sources = StageResult(images=stage_results)
|
||||||
for layer in extra_tiles:
|
|
||||||
layer_output = Image.new("RGB", output.size)
|
|
||||||
for layer_tile, dims in layer:
|
|
||||||
layer_output.paste(layer_tile, (dims[0], dims[1]))
|
|
||||||
|
|
||||||
stage_outputs.append(layer_output)
|
|
||||||
|
|
||||||
stage_sources = stage_outputs
|
|
||||||
else:
|
else:
|
||||||
logger.debug(
|
logger.debug(
|
||||||
"image does not contain sources and is within tile size of %s, running stage",
|
"image does not contain sources and is within tile size of %s, running stage",
|
||||||
|
@ -238,7 +224,7 @@ class ChainPipeline:
|
||||||
)
|
)
|
||||||
for i in range(worker.retries):
|
for i in range(worker.retries):
|
||||||
try:
|
try:
|
||||||
stage_outputs = stage_pipe.run(
|
stage_result = stage_pipe.run(
|
||||||
worker,
|
worker,
|
||||||
server,
|
server,
|
||||||
stage_params,
|
stage_params,
|
||||||
|
@ -250,7 +236,7 @@ class ChainPipeline:
|
||||||
)
|
)
|
||||||
# doing this on the same line as stage_pipe.run can leave sources as None, which the pipeline
|
# doing this on the same line as stage_pipe.run can leave sources as None, which the pipeline
|
||||||
# does not like, so it throws
|
# does not like, so it throws
|
||||||
stage_sources = stage_outputs
|
stage_sources = stage_result
|
||||||
break
|
break
|
||||||
except Exception:
|
except Exception:
|
||||||
worker.retries = worker.retries - 1
|
worker.retries = worker.retries - 1
|
||||||
|
|
|
@ -7,6 +7,7 @@ from ..params import ImageParams, Size, StageParams
|
||||||
from ..server import ServerContext
|
from ..server import ServerContext
|
||||||
from ..worker import WorkerContext
|
from ..worker import WorkerContext
|
||||||
from .base import BaseStage
|
from .base import BaseStage
|
||||||
|
from .result import StageResult
|
||||||
|
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
@ -18,20 +19,20 @@ class ReduceCropStage(BaseStage):
|
||||||
_server: ServerContext,
|
_server: ServerContext,
|
||||||
_stage: StageParams,
|
_stage: StageParams,
|
||||||
_params: ImageParams,
|
_params: ImageParams,
|
||||||
sources: List[Image.Image],
|
sources: StageResult,
|
||||||
*,
|
*,
|
||||||
origin: Size,
|
origin: Size,
|
||||||
size: Size,
|
size: Size,
|
||||||
stage_source: Optional[Image.Image] = None,
|
stage_source: Optional[Image.Image] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> List[Image.Image]:
|
) -> StageResult:
|
||||||
outputs = []
|
outputs = []
|
||||||
|
|
||||||
for source in sources:
|
for source in sources.as_image():
|
||||||
image = source.crop((origin.width, origin.height, size.width, size.height))
|
image = source.crop((origin.width, origin.height, size.width, size.height))
|
||||||
logger.info(
|
logger.info(
|
||||||
"created thumbnail with dimensions: %sx%s", image.width, image.height
|
"created thumbnail with dimensions: %sx%s", image.width, image.height
|
||||||
)
|
)
|
||||||
outputs.append(image)
|
outputs.append(image)
|
||||||
|
|
||||||
return outputs
|
return StageResult(images=outputs)
|
||||||
|
|
|
@ -7,6 +7,7 @@ from ..params import ImageParams, Size, StageParams
|
||||||
from ..server import ServerContext
|
from ..server import ServerContext
|
||||||
from ..worker import WorkerContext
|
from ..worker import WorkerContext
|
||||||
from .base import BaseStage
|
from .base import BaseStage
|
||||||
|
from .result import StageResult
|
||||||
|
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
@ -18,15 +19,15 @@ class ReduceThumbnailStage(BaseStage):
|
||||||
_server: ServerContext,
|
_server: ServerContext,
|
||||||
_stage: StageParams,
|
_stage: StageParams,
|
||||||
_params: ImageParams,
|
_params: ImageParams,
|
||||||
sources: List[Image.Image],
|
sources: StageResult,
|
||||||
*,
|
*,
|
||||||
size: Size,
|
size: Size,
|
||||||
stage_source: Image.Image,
|
stage_source: Image.Image,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> List[Image.Image]:
|
) -> StageResult:
|
||||||
outputs = []
|
outputs = []
|
||||||
|
|
||||||
for source in sources:
|
for source in sources.as_image():
|
||||||
image = source.copy()
|
image = source.copy()
|
||||||
|
|
||||||
image = image.thumbnail((size.width, size.height))
|
image = image.thumbnail((size.width, size.height))
|
||||||
|
@ -37,4 +38,4 @@ class ReduceThumbnailStage(BaseStage):
|
||||||
|
|
||||||
outputs.append(image)
|
outputs.append(image)
|
||||||
|
|
||||||
return outputs
|
return StageResult(images=outputs)
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
from PIL.Image import Image, fromarray
|
from PIL import Image
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
@ -7,25 +7,35 @@ class StageResult:
|
||||||
"""
|
"""
|
||||||
Chain pipeline stage result.
|
Chain pipeline stage result.
|
||||||
Can contain PIL images or numpy arrays, with helpers to convert between them.
|
Can contain PIL images or numpy arrays, with helpers to convert between them.
|
||||||
|
This class intentionally does not provide `__iter__`, to ensure clients get results in the format
|
||||||
|
they are expected.
|
||||||
"""
|
"""
|
||||||
arrays: Optional[List[np.ndarray]]
|
arrays: Optional[List[np.ndarray]]
|
||||||
images: Optional[List[Image]]
|
images: Optional[List[Image.Image]]
|
||||||
|
|
||||||
def __init__(self, arrays = None, images = None) -> None:
|
def __init__(self, arrays = None, images = None) -> None:
|
||||||
if arrays is not None and images is not None:
|
if arrays is not None and images is not None:
|
||||||
raise ValueError("stages must only return one type of result")
|
raise ValueError("stages must only return one type of result")
|
||||||
|
elif arrays is None and images is None:
|
||||||
|
raise ValueError("stages must return results")
|
||||||
|
|
||||||
self.arrays = arrays
|
self.arrays = arrays
|
||||||
self.images = images
|
self.images = images
|
||||||
|
|
||||||
|
def __len__(self) -> int:
|
||||||
|
if self.arrays is not None:
|
||||||
|
return len(self.arrays)
|
||||||
|
else:
|
||||||
|
return len(self.images)
|
||||||
|
|
||||||
def as_numpy(self) -> List[np.ndarray]:
|
def as_numpy(self) -> List[np.ndarray]:
|
||||||
if self.arrays is not None:
|
if self.arrays is not None:
|
||||||
return self.arrays
|
return self.arrays
|
||||||
|
|
||||||
return [np.array(i) for i in self.images]
|
return [np.array(i) for i in self.images]
|
||||||
|
|
||||||
def as_image(self) -> List[Image]:
|
def as_image(self) -> List[Image.Image]:
|
||||||
if self.images is not None:
|
if self.images is not None:
|
||||||
return self.images
|
return self.images
|
||||||
|
|
||||||
return [fromarray(i) for i in self.arrays]
|
return [Image.fromarray(i, "RGB") for i in self.arrays]
|
||||||
|
|
|
@ -7,6 +7,7 @@ from ..params import ImageParams, Size, StageParams
|
||||||
from ..server import ServerContext
|
from ..server import ServerContext
|
||||||
from ..worker import WorkerContext
|
from ..worker import WorkerContext
|
||||||
from .base import BaseStage
|
from .base import BaseStage
|
||||||
|
from .result import StageResult
|
||||||
|
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
@ -18,13 +19,13 @@ class SourceNoiseStage(BaseStage):
|
||||||
_server: ServerContext,
|
_server: ServerContext,
|
||||||
_stage: StageParams,
|
_stage: StageParams,
|
||||||
_params: ImageParams,
|
_params: ImageParams,
|
||||||
sources: List[Image.Image],
|
sources: StageResult,
|
||||||
*,
|
*,
|
||||||
size: Size,
|
size: Size,
|
||||||
noise_source: Callable,
|
noise_source: Callable,
|
||||||
stage_source: Optional[Image.Image] = None,
|
stage_source: Optional[Image.Image] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> List[Image.Image]:
|
) -> StageResult:
|
||||||
logger.info("generating image from noise source")
|
logger.info("generating image from noise source")
|
||||||
|
|
||||||
if len(sources) > 0:
|
if len(sources) > 0:
|
||||||
|
@ -32,16 +33,16 @@ class SourceNoiseStage(BaseStage):
|
||||||
"source images were passed to a source stage, new images will be appended"
|
"source images were passed to a source stage, new images will be appended"
|
||||||
)
|
)
|
||||||
|
|
||||||
outputs = list(sources)
|
outputs = []
|
||||||
|
|
||||||
# TODO: looping over sources and ignoring params does not make much sense for a source stage
|
# TODO: looping over sources and ignoring params does not make much sense for a source stage
|
||||||
for source in sources:
|
for source in sources.as_image():
|
||||||
output = noise_source(source, (size.width, size.height), (0, 0))
|
output = noise_source(source, (size.width, size.height), (0, 0))
|
||||||
|
|
||||||
logger.info("final output image size: %sx%s", output.width, output.height)
|
logger.info("final output image size: %sx%s", output.width, output.height)
|
||||||
outputs.append(output)
|
outputs.append(output)
|
||||||
|
|
||||||
return outputs
|
return StageResult(images=outputs)
|
||||||
|
|
||||||
def outputs(
|
def outputs(
|
||||||
self,
|
self,
|
||||||
|
|
|
@ -9,6 +9,7 @@ from ..params import ImageParams, StageParams
|
||||||
from ..server import ServerContext
|
from ..server import ServerContext
|
||||||
from ..worker import WorkerContext
|
from ..worker import WorkerContext
|
||||||
from .base import BaseStage
|
from .base import BaseStage
|
||||||
|
from .result import StageResult
|
||||||
|
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
@ -20,14 +21,14 @@ class SourceS3Stage(BaseStage):
|
||||||
_server: ServerContext,
|
_server: ServerContext,
|
||||||
_stage: StageParams,
|
_stage: StageParams,
|
||||||
_params: ImageParams,
|
_params: ImageParams,
|
||||||
sources: List[Image.Image],
|
sources: StageResult,
|
||||||
*,
|
*,
|
||||||
source_keys: List[str],
|
source_keys: List[str],
|
||||||
bucket: str,
|
bucket: str,
|
||||||
endpoint_url: Optional[str] = None,
|
endpoint_url: Optional[str] = None,
|
||||||
profile_name: Optional[str] = None,
|
profile_name: Optional[str] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> List[Image.Image]:
|
) -> StageResult:
|
||||||
session = Session(profile_name=profile_name)
|
session = Session(profile_name=profile_name)
|
||||||
s3 = session.client("s3", endpoint_url=endpoint_url)
|
s3 = session.client("s3", endpoint_url=endpoint_url)
|
||||||
|
|
||||||
|
@ -36,7 +37,7 @@ class SourceS3Stage(BaseStage):
|
||||||
"source images were passed to a source stage, new images will be appended"
|
"source images were passed to a source stage, new images will be appended"
|
||||||
)
|
)
|
||||||
|
|
||||||
outputs = list(sources)
|
outputs = sources.as_image()
|
||||||
for key in source_keys:
|
for key in source_keys:
|
||||||
try:
|
try:
|
||||||
logger.info("loading image from s3://%s/%s", bucket, key)
|
logger.info("loading image from s3://%s/%s", bucket, key)
|
||||||
|
@ -48,7 +49,7 @@ class SourceS3Stage(BaseStage):
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.exception("error loading image from S3")
|
logger.exception("error loading image from S3")
|
||||||
|
|
||||||
return outputs
|
return StageResult(outputs)
|
||||||
|
|
||||||
def outputs(
|
def outputs(
|
||||||
self,
|
self,
|
||||||
|
|
|
@ -19,6 +19,7 @@ from ..params import ImageParams, Size, SizeChart, StageParams
|
||||||
from ..server import ServerContext
|
from ..server import ServerContext
|
||||||
from ..worker import ProgressCallback, WorkerContext
|
from ..worker import ProgressCallback, WorkerContext
|
||||||
from .base import BaseStage
|
from .base import BaseStage
|
||||||
|
from .result import StageResult
|
||||||
|
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
@ -32,7 +33,7 @@ class SourceTxt2ImgStage(BaseStage):
|
||||||
server: ServerContext,
|
server: ServerContext,
|
||||||
stage: StageParams,
|
stage: StageParams,
|
||||||
params: ImageParams,
|
params: ImageParams,
|
||||||
sources: List[Image.Image],
|
sources: StageResult,
|
||||||
*,
|
*,
|
||||||
dims: Tuple[int, int, int] = None,
|
dims: Tuple[int, int, int] = None,
|
||||||
size: Size,
|
size: Size,
|
||||||
|
@ -153,10 +154,10 @@ class SourceTxt2ImgStage(BaseStage):
|
||||||
callback=callback,
|
callback=callback,
|
||||||
)
|
)
|
||||||
|
|
||||||
output = list(sources)
|
outputs = list(sources)
|
||||||
output.extend(result.images)
|
outputs.extend(result.images)
|
||||||
logger.debug("produced %s outputs", len(output))
|
logger.debug("produced %s outputs", len(outputs))
|
||||||
return output
|
return StageResult(images=outputs)
|
||||||
|
|
||||||
def steps(
|
def steps(
|
||||||
self,
|
self,
|
||||||
|
|
|
@ -9,6 +9,7 @@ from ..params import ImageParams, StageParams
|
||||||
from ..server import ServerContext
|
from ..server import ServerContext
|
||||||
from ..worker import WorkerContext
|
from ..worker import WorkerContext
|
||||||
from .base import BaseStage
|
from .base import BaseStage
|
||||||
|
from .result import StageResult
|
||||||
|
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
@ -20,12 +21,12 @@ class SourceURLStage(BaseStage):
|
||||||
_server: ServerContext,
|
_server: ServerContext,
|
||||||
_stage: StageParams,
|
_stage: StageParams,
|
||||||
_params: ImageParams,
|
_params: ImageParams,
|
||||||
sources: List[Image.Image],
|
sources: StageResult,
|
||||||
*,
|
*,
|
||||||
source_urls: List[str],
|
source_urls: List[str],
|
||||||
stage_source: Optional[Image.Image] = None,
|
stage_source: Optional[Image.Image] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> List[Image.Image]:
|
) -> StageResult:
|
||||||
logger.info("loading image from URL source")
|
logger.info("loading image from URL source")
|
||||||
|
|
||||||
if len(sources) > 0:
|
if len(sources) > 0:
|
||||||
|
@ -41,7 +42,7 @@ class SourceURLStage(BaseStage):
|
||||||
logger.info("final output image size: %sx%s", output.width, output.height)
|
logger.info("final output image size: %sx%s", output.width, output.height)
|
||||||
outputs.append(output)
|
outputs.append(output)
|
||||||
|
|
||||||
return outputs
|
return StageResult(images=outputs)
|
||||||
|
|
||||||
def outputs(
|
def outputs(
|
||||||
self,
|
self,
|
||||||
|
|
|
@ -9,6 +9,7 @@ from PIL import Image
|
||||||
|
|
||||||
from ..image.noise_source import noise_source_histogram
|
from ..image.noise_source import noise_source_histogram
|
||||||
from ..params import Size, TileOrder
|
from ..params import Size, TileOrder
|
||||||
|
from .result import StageResult
|
||||||
|
|
||||||
# from skimage.exposure import match_histograms
|
# from skimage.exposure import match_histograms
|
||||||
|
|
||||||
|
@ -21,7 +22,7 @@ class TileCallback(Protocol):
|
||||||
Definition for a tile job function.
|
Definition for a tile job function.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __call__(self, image: Image.Image, dims: Tuple[int, int, int]) -> Image.Image:
|
def __call__(self, image: Image.Image, dims: Tuple[int, int, int]) -> StageResult:
|
||||||
"""
|
"""
|
||||||
Run this stage against a single tile.
|
Run this stage against a single tile.
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -11,6 +11,7 @@ from ..server import ModelTypes, ServerContext
|
||||||
from ..utils import run_gc
|
from ..utils import run_gc
|
||||||
from ..worker import WorkerContext
|
from ..worker import WorkerContext
|
||||||
from .base import BaseStage
|
from .base import BaseStage
|
||||||
|
from .result import StageResult
|
||||||
|
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
@ -54,12 +55,12 @@ class UpscaleBSRGANStage(BaseStage):
|
||||||
server: ServerContext,
|
server: ServerContext,
|
||||||
stage: StageParams,
|
stage: StageParams,
|
||||||
_params: ImageParams,
|
_params: ImageParams,
|
||||||
sources: List[Image.Image],
|
sources: StageResult,
|
||||||
*,
|
*,
|
||||||
upscale: UpscaleParams,
|
upscale: UpscaleParams,
|
||||||
stage_source: Optional[Image.Image] = None,
|
stage_source: Optional[Image.Image] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> List[Image.Image]:
|
) -> StageResult:
|
||||||
upscale = upscale.with_args(**kwargs)
|
upscale = upscale.with_args(**kwargs)
|
||||||
|
|
||||||
if upscale.upscale_model is None:
|
if upscale.upscale_model is None:
|
||||||
|
@ -71,8 +72,8 @@ class UpscaleBSRGANStage(BaseStage):
|
||||||
bsrgan = self.load(server, stage, upscale, device)
|
bsrgan = self.load(server, stage, upscale, device)
|
||||||
|
|
||||||
outputs = []
|
outputs = []
|
||||||
for source in sources:
|
for source in sources.as_numpy():
|
||||||
image = np.array(source) / 255.0
|
image = source / 255.0
|
||||||
image = image[:, :, [2, 1, 0]].astype(np.float32).transpose((2, 0, 1))
|
image = image[:, :, [2, 1, 0]].astype(np.float32).transpose((2, 0, 1))
|
||||||
image = np.expand_dims(image, axis=0)
|
image = np.expand_dims(image, axis=0)
|
||||||
logger.trace("BSRGAN input shape: %s", image.shape)
|
logger.trace("BSRGAN input shape: %s", image.shape)
|
||||||
|
@ -99,7 +100,7 @@ class UpscaleBSRGANStage(BaseStage):
|
||||||
|
|
||||||
outputs.append(output)
|
outputs.append(output)
|
||||||
|
|
||||||
return outputs
|
return StageResult(images=outputs)
|
||||||
|
|
||||||
def steps(
|
def steps(
|
||||||
self,
|
self,
|
||||||
|
|
|
@ -9,6 +9,7 @@ from ..server import ServerContext
|
||||||
from ..worker import WorkerContext
|
from ..worker import WorkerContext
|
||||||
from ..worker.context import ProgressCallback
|
from ..worker.context import ProgressCallback
|
||||||
from .base import BaseStage
|
from .base import BaseStage
|
||||||
|
from .result import StageResult
|
||||||
|
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
@ -20,20 +21,20 @@ class UpscaleHighresStage(BaseStage):
|
||||||
server: ServerContext,
|
server: ServerContext,
|
||||||
stage: StageParams,
|
stage: StageParams,
|
||||||
params: ImageParams,
|
params: ImageParams,
|
||||||
sources: List[Image.Image],
|
sources: StageResult,
|
||||||
*args,
|
*args,
|
||||||
highres: HighresParams,
|
highres: HighresParams,
|
||||||
upscale: UpscaleParams,
|
upscale: UpscaleParams,
|
||||||
stage_source: Optional[Image.Image] = None,
|
stage_source: Optional[Image.Image] = None,
|
||||||
callback: Optional[ProgressCallback] = None,
|
callback: Optional[ProgressCallback] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> List[Image.Image]:
|
) -> StageResult:
|
||||||
if highres.scale <= 1:
|
if highres.scale <= 1:
|
||||||
return sources
|
return sources
|
||||||
|
|
||||||
chain = stage_highres(stage, params, highres, upscale)
|
chain = stage_highres(stage, params, highres, upscale)
|
||||||
|
|
||||||
return [
|
outputs = [
|
||||||
chain(
|
chain(
|
||||||
worker,
|
worker,
|
||||||
server,
|
server,
|
||||||
|
@ -43,3 +44,5 @@ class UpscaleHighresStage(BaseStage):
|
||||||
)
|
)
|
||||||
for source in sources
|
for source in sources
|
||||||
]
|
]
|
||||||
|
|
||||||
|
return StageResult(images=outputs)
|
|
@ -19,6 +19,7 @@ from ..server import ServerContext
|
||||||
from ..utils import is_debug
|
from ..utils import is_debug
|
||||||
from ..worker import ProgressCallback, WorkerContext
|
from ..worker import ProgressCallback, WorkerContext
|
||||||
from .base import BaseStage
|
from .base import BaseStage
|
||||||
|
from .result import StageResult
|
||||||
|
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
@ -32,7 +33,7 @@ class UpscaleOutpaintStage(BaseStage):
|
||||||
server: ServerContext,
|
server: ServerContext,
|
||||||
stage: StageParams,
|
stage: StageParams,
|
||||||
params: ImageParams,
|
params: ImageParams,
|
||||||
sources: List[Image.Image],
|
sources: StageResult,
|
||||||
*,
|
*,
|
||||||
border: Border,
|
border: Border,
|
||||||
dims: Tuple[int, int, int],
|
dims: Tuple[int, int, int],
|
||||||
|
@ -45,7 +46,7 @@ class UpscaleOutpaintStage(BaseStage):
|
||||||
stage_source: Optional[Image.Image] = None,
|
stage_source: Optional[Image.Image] = None,
|
||||||
stage_mask: Optional[Image.Image] = None,
|
stage_mask: Optional[Image.Image] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> List[Image.Image]:
|
) -> StageResult:
|
||||||
prompt_pairs, loras, inversions, (prompt, negative_prompt) = parse_prompt(
|
prompt_pairs, loras, inversions, (prompt, negative_prompt) = parse_prompt(
|
||||||
params
|
params
|
||||||
)
|
)
|
||||||
|
@ -61,7 +62,7 @@ class UpscaleOutpaintStage(BaseStage):
|
||||||
)
|
)
|
||||||
|
|
||||||
outputs = []
|
outputs = []
|
||||||
for source in sources:
|
for source in sources.as_image():
|
||||||
if is_debug():
|
if is_debug():
|
||||||
save_image(server, "tile-source.png", source)
|
save_image(server, "tile-source.png", source)
|
||||||
save_image(server, "tile-mask.png", tile_mask)
|
save_image(server, "tile-mask.png", tile_mask)
|
||||||
|
@ -122,4 +123,4 @@ class UpscaleOutpaintStage(BaseStage):
|
||||||
|
|
||||||
outputs.extend(result.images)
|
outputs.extend(result.images)
|
||||||
|
|
||||||
return outputs
|
return StageResult(images=outputs)
|
||||||
|
|
|
@ -11,6 +11,7 @@ from ..server import ModelTypes, ServerContext
|
||||||
from ..utils import run_gc
|
from ..utils import run_gc
|
||||||
from ..worker import WorkerContext
|
from ..worker import WorkerContext
|
||||||
from .base import BaseStage
|
from .base import BaseStage
|
||||||
|
from .result import StageResult
|
||||||
|
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
@ -77,25 +78,22 @@ class UpscaleRealESRGANStage(BaseStage):
|
||||||
server: ServerContext,
|
server: ServerContext,
|
||||||
stage: StageParams,
|
stage: StageParams,
|
||||||
_params: ImageParams,
|
_params: ImageParams,
|
||||||
sources: List[Image.Image],
|
sources: StageResult,
|
||||||
*,
|
*,
|
||||||
upscale: UpscaleParams,
|
upscale: UpscaleParams,
|
||||||
stage_source: Optional[Image.Image] = None,
|
stage_source: Optional[Image.Image] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> List[Image.Image]:
|
) -> StageResult:
|
||||||
logger.info("upscaling image with Real ESRGAN: x%s", upscale.scale)
|
logger.info("upscaling image with Real ESRGAN: x%s", upscale.scale)
|
||||||
|
|
||||||
outputs = []
|
|
||||||
for source in sources:
|
|
||||||
output = np.array(source)
|
|
||||||
upsampler = self.load(
|
upsampler = self.load(
|
||||||
server, upscale, worker.get_device(), tile=stage.tile_size
|
server, upscale, worker.get_device(), tile=stage.tile_size
|
||||||
)
|
)
|
||||||
|
|
||||||
output, _ = upsampler.enhance(output, outscale=upscale.outscale)
|
outputs = []
|
||||||
|
for source in sources.as_numpy():
|
||||||
output = Image.fromarray(output, "RGB")
|
output, _ = upsampler.enhance(source, outscale=upscale.outscale)
|
||||||
logger.info("final output image size: %sx%s", output.width, output.height)
|
logger.info("final output image size: %s", output.shape)
|
||||||
outputs.append(output)
|
outputs.append(output)
|
||||||
|
|
||||||
return outputs
|
return StageResult(arrays=outputs)
|
||||||
|
|
|
@ -7,6 +7,7 @@ from ..params import ImageParams, StageParams, UpscaleParams
|
||||||
from ..server import ServerContext
|
from ..server import ServerContext
|
||||||
from ..worker import WorkerContext
|
from ..worker import WorkerContext
|
||||||
from .base import BaseStage
|
from .base import BaseStage
|
||||||
|
from .result import StageResult
|
||||||
|
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
@ -18,13 +19,13 @@ class UpscaleSimpleStage(BaseStage):
|
||||||
_server: ServerContext,
|
_server: ServerContext,
|
||||||
_stage: StageParams,
|
_stage: StageParams,
|
||||||
_params: ImageParams,
|
_params: ImageParams,
|
||||||
sources: List[Image.Image],
|
sources: StageResult,
|
||||||
*,
|
*,
|
||||||
method: str,
|
method: str,
|
||||||
upscale: UpscaleParams,
|
upscale: UpscaleParams,
|
||||||
stage_source: Optional[Image.Image] = None,
|
stage_source: Optional[Image.Image] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> List[Image.Image]:
|
) -> StageResult:
|
||||||
if upscale.scale <= 1:
|
if upscale.scale <= 1:
|
||||||
logger.debug(
|
logger.debug(
|
||||||
"simple upscale stage run with scale of %s, skipping", upscale.scale
|
"simple upscale stage run with scale of %s, skipping", upscale.scale
|
||||||
|
@ -32,7 +33,7 @@ class UpscaleSimpleStage(BaseStage):
|
||||||
return sources
|
return sources
|
||||||
|
|
||||||
outputs = []
|
outputs = []
|
||||||
for source in sources:
|
for source in sources.as_image():
|
||||||
scaled_size = (source.width * upscale.scale, source.height * upscale.scale)
|
scaled_size = (source.width * upscale.scale, source.height * upscale.scale)
|
||||||
|
|
||||||
if method == "bilinear":
|
if method == "bilinear":
|
||||||
|
|
|
@ -11,6 +11,7 @@ from ..params import ImageParams, StageParams, UpscaleParams
|
||||||
from ..server import ServerContext
|
from ..server import ServerContext
|
||||||
from ..worker import ProgressCallback, WorkerContext
|
from ..worker import ProgressCallback, WorkerContext
|
||||||
from .base import BaseStage
|
from .base import BaseStage
|
||||||
|
from .result import StageResult
|
||||||
|
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
@ -22,13 +23,13 @@ class UpscaleStableDiffusionStage(BaseStage):
|
||||||
server: ServerContext,
|
server: ServerContext,
|
||||||
_stage: StageParams,
|
_stage: StageParams,
|
||||||
params: ImageParams,
|
params: ImageParams,
|
||||||
sources: List[Image.Image],
|
sources: StageResult,
|
||||||
*,
|
*,
|
||||||
upscale: UpscaleParams,
|
upscale: UpscaleParams,
|
||||||
stage_source: Optional[Image.Image] = None,
|
stage_source: Optional[Image.Image] = None,
|
||||||
callback: Optional[ProgressCallback] = None,
|
callback: Optional[ProgressCallback] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> List[Image.Image]:
|
) -> StageResult:
|
||||||
params = params.with_args(**kwargs)
|
params = params.with_args(**kwargs)
|
||||||
upscale = upscale.with_args(**kwargs)
|
upscale = upscale.with_args(**kwargs)
|
||||||
logger.info(
|
logger.info(
|
||||||
|
@ -58,7 +59,7 @@ class UpscaleStableDiffusionStage(BaseStage):
|
||||||
pipeline.unet.set_prompts(prompt_embeds)
|
pipeline.unet.set_prompts(prompt_embeds)
|
||||||
|
|
||||||
outputs = []
|
outputs = []
|
||||||
for source in sources:
|
for source in sources.as_image():
|
||||||
result = pipeline(
|
result = pipeline(
|
||||||
prompt,
|
prompt,
|
||||||
source,
|
source,
|
||||||
|
|
|
@ -11,6 +11,7 @@ from ..server import ModelTypes, ServerContext
|
||||||
from ..utils import run_gc
|
from ..utils import run_gc
|
||||||
from ..worker import WorkerContext
|
from ..worker import WorkerContext
|
||||||
from .base import BaseStage
|
from .base import BaseStage
|
||||||
|
from .result import StageResult
|
||||||
|
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
@ -54,12 +55,12 @@ class UpscaleSwinIRStage(BaseStage):
|
||||||
server: ServerContext,
|
server: ServerContext,
|
||||||
stage: StageParams,
|
stage: StageParams,
|
||||||
_params: ImageParams,
|
_params: ImageParams,
|
||||||
sources: List[Image.Image],
|
sources: StageResult,
|
||||||
*,
|
*,
|
||||||
upscale: UpscaleParams,
|
upscale: UpscaleParams,
|
||||||
stage_source: Optional[Image.Image] = None,
|
stage_source: Optional[Image.Image] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> List[Image.Image]:
|
) -> StageResult:
|
||||||
upscale = upscale.with_args(**kwargs)
|
upscale = upscale.with_args(**kwargs)
|
||||||
|
|
||||||
if upscale.upscale_model is None:
|
if upscale.upscale_model is None:
|
||||||
|
@ -71,31 +72,27 @@ class UpscaleSwinIRStage(BaseStage):
|
||||||
swinir = self.load(server, stage, upscale, device)
|
swinir = self.load(server, stage, upscale, device)
|
||||||
|
|
||||||
outputs = []
|
outputs = []
|
||||||
for source in sources:
|
for source in sources.as_numpy():
|
||||||
# TODO: add support for grayscale (1-channel) images
|
# TODO: add support for grayscale (1-channel) images
|
||||||
image = np.array(source) / 255.0
|
image = source / 255.0
|
||||||
image = image[:, :, [2, 1, 0]].astype(np.float32).transpose((2, 0, 1))
|
image = image[:, :, [2, 1, 0]].astype(np.float32).transpose((2, 0, 1))
|
||||||
image = np.expand_dims(image, axis=0)
|
image = np.expand_dims(image, axis=0)
|
||||||
logger.trace("SwinIR input shape: %s", image.shape)
|
logger.trace("SwinIR input shape: %s", image.shape)
|
||||||
|
|
||||||
scale = upscale.outscale
|
scale = upscale.outscale
|
||||||
dest = np.zeros(
|
logger.trace("SwinIR output shape: %s", (
|
||||||
(
|
|
||||||
image.shape[0],
|
image.shape[0],
|
||||||
image.shape[1],
|
image.shape[1],
|
||||||
image.shape[2] * scale,
|
image.shape[2] * scale,
|
||||||
image.shape[3] * scale,
|
image.shape[3] * scale,
|
||||||
)
|
))
|
||||||
)
|
|
||||||
logger.trace("SwinIR output shape: %s", dest.shape)
|
|
||||||
|
|
||||||
dest = swinir(image)
|
output = swinir(image)
|
||||||
dest = np.clip(np.squeeze(dest, axis=0), 0, 1)
|
output = np.clip(np.squeeze(output, axis=0), 0, 1)
|
||||||
dest = dest[[2, 1, 0], :, :].transpose((1, 2, 0))
|
output = output[[2, 1, 0], :, :].transpose((1, 2, 0))
|
||||||
dest = (dest * 255.0).round().astype(np.uint8)
|
output = (output * 255.0).round().astype(np.uint8)
|
||||||
|
|
||||||
output = Image.fromarray(dest, "RGB")
|
logger.info("output image size: %s", output.shape)
|
||||||
logger.info("output image size: %s x %s", output.width, output.height)
|
|
||||||
outputs.append(output)
|
outputs.append(output)
|
||||||
|
|
||||||
return outputs
|
return StageResult(images=outputs)
|
||||||
|
|
|
@ -486,7 +486,7 @@ def run_blend_pipeline(
|
||||||
outputs: List[str],
|
outputs: List[str],
|
||||||
upscale: UpscaleParams,
|
upscale: UpscaleParams,
|
||||||
# highres: HighresParams,
|
# highres: HighresParams,
|
||||||
sources: List[Image.Image],
|
sources: StageResult,
|
||||||
mask: Image.Image,
|
mask: Image.Image,
|
||||||
) -> None:
|
) -> None:
|
||||||
# set up the chain pipeline and base stage
|
# set up the chain pipeline and base stage
|
||||||
|
|
Loading…
Reference in New Issue