diff --git a/api/onnx_web/chain/base.py b/api/onnx_web/chain/base.py index a62d6fcb..b939f028 100644 --- a/api/onnx_web/chain/base.py +++ b/api/onnx_web/chain/base.py @@ -7,7 +7,7 @@ from PIL import Image from ..errors import RetryException from ..output import save_image -from ..params import ImageParams, StageParams +from ..params import ImageParams, Size, StageParams from ..server import ServerContext from ..utils import is_debug, run_gc from ..worker import ProgressCallback, WorkerContext @@ -85,6 +85,14 @@ class ChainPipeline: self.stages.append((callback, params, kwargs)) return self + def steps(self, params: ImageParams, size: Size): + steps = 0 + for callback, _params, _kwargs in self.stages: + steps += callback.steps(params, size) + + return steps + + def __call__( self, worker: WorkerContext, diff --git a/api/onnx_web/chain/blend_grid.py b/api/onnx_web/chain/blend_grid.py index 5a23f779..561319ab 100644 --- a/api/onnx_web/chain/blend_grid.py +++ b/api/onnx_web/chain/blend_grid.py @@ -48,3 +48,10 @@ class BlendGridStage(BaseStage): output.paste(sources[n], (x * size[0], y * size[1])) return [*sources, output] + + def outputs( + self, + params: ImageParams, + sources: int, + ) -> int: + return sources + 1 \ No newline at end of file diff --git a/api/onnx_web/chain/blend_img2img.py b/api/onnx_web/chain/blend_img2img.py index e3c249a9..fa2023d4 100644 --- a/api/onnx_web/chain/blend_img2img.py +++ b/api/onnx_web/chain/blend_img2img.py @@ -103,3 +103,17 @@ class BlendImg2ImgStage(BaseStage): outputs.extend(result.images) return outputs + + def steps( + self, + params: ImageParams, + *args, + ) -> int: + return params.steps # TODO: multiply by strength + + def outputs( + self, + params: ImageParams, + sources: int, + ) -> int: + return sources + 1 diff --git a/api/onnx_web/chain/source_noise.py b/api/onnx_web/chain/source_noise.py index 5e6035d8..930599e2 100644 --- a/api/onnx_web/chain/source_noise.py +++ b/api/onnx_web/chain/source_noise.py @@ -42,3 +42,10 @@ class SourceNoiseStage(BaseStage): outputs.append(output) return outputs + + def outputs( + self, + params: ImageParams, + sources: int, + ) -> int: + return sources + 1 \ No newline at end of file diff --git a/api/onnx_web/chain/source_s3.py b/api/onnx_web/chain/source_s3.py index 55f8f228..6fa50b7e 100644 --- a/api/onnx_web/chain/source_s3.py +++ b/api/onnx_web/chain/source_s3.py @@ -49,3 +49,10 @@ class SourceS3Stage(BaseStage): logger.exception("error loading image from S3") return outputs + + def outputs( + self, + params: ImageParams, + sources: int, + ) -> int: + return sources + 1 # TODO: len(source_keys) \ No newline at end of file diff --git a/api/onnx_web/chain/source_txt2img.py b/api/onnx_web/chain/source_txt2img.py index cda81b3d..ed81f8db 100644 --- a/api/onnx_web/chain/source_txt2img.py +++ b/api/onnx_web/chain/source_txt2img.py @@ -126,3 +126,17 @@ class SourceTxt2ImgStage(BaseStage): output = list(sources) output.extend(result.images) return output + + def steps( + self, + params: ImageParams, + size: Size, + ) -> int: + return params.steps + + def outputs( + self, + params: ImageParams, + sources: int, + ) -> int: + return sources + 1 \ No newline at end of file diff --git a/api/onnx_web/chain/source_url.py b/api/onnx_web/chain/source_url.py index 54f86c54..29d88350 100644 --- a/api/onnx_web/chain/source_url.py +++ b/api/onnx_web/chain/source_url.py @@ -42,3 +42,10 @@ class SourceURLStage(BaseStage): outputs.append(output) return outputs + + def outputs( + self, + params: ImageParams, + sources: int, + ) -> int: + return sources + 1 \ No newline at end of file diff --git a/api/onnx_web/chain/stage.py b/api/onnx_web/chain/stage.py index 781b65de..fff56ba7 100644 --- a/api/onnx_web/chain/stage.py +++ b/api/onnx_web/chain/stage.py @@ -25,7 +25,14 @@ class BaseStage: def steps( self, - _params: ImageParams, + params: ImageParams, size: Size, ) -> int: - raise NotImplementedError() + return 1 + + def outputs( + self, + params: ImageParams, + sources: int, + ) -> int: + return sources diff --git a/api/onnx_web/chain/tile.py b/api/onnx_web/chain/tile.py index c80a5719..8f228392 100644 --- a/api/onnx_web/chain/tile.py +++ b/api/onnx_web/chain/tile.py @@ -50,6 +50,7 @@ def needs_tile( source: Optional[Image.Image] = None, ) -> bool: tile = min(max_tile, stage_tile) + logger.debug("") if source is not None: return source.width > tile or source.height > tile diff --git a/api/onnx_web/diffusers/load.py b/api/onnx_web/diffusers/load.py index c7bc953b..1bb8a00e 100644 --- a/api/onnx_web/diffusers/load.py +++ b/api/onnx_web/diffusers/load.py @@ -376,7 +376,7 @@ def load_pipeline( provider=device.ort_provider("vae"), sess_options=device.sess_options(), ) - components["vae_decoder_session"]._model_path = vae_decoder + components["vae_decoder_session"]._model_path = vae_decoder # "#\\not a real path on any system" logger.debug("loading VAE encoder from %s", vae_encoder) components["vae_encoder_session"] = OnnxRuntimeModel.load_model( @@ -384,7 +384,7 @@ def load_pipeline( provider=device.ort_provider("vae"), sess_options=device.sess_options(), ) - components["vae_encoder_session"]._model_path = vae_encoder + components["vae_encoder_session"]._model_path = vae_encoder # "#\\not a real path on any system" else: logger.debug("loading VAE decoder from %s", vae_decoder) @@ -439,12 +439,14 @@ def load_pipeline( if "vae_decoder_session" in components: pipe.vae_decoder = ORTModelVaeDecoder( - components["vae_decoder_session"], vae_decoder + components["vae_decoder_session"], + pipe, # TODO: find the right class to provide here. ORTModel is missing the dict json method ) if "vae_encoder_session" in components: pipe.vae_encoder = ORTModelVaeEncoder( - components["vae_encoder_session"], vae_encoder + components["vae_encoder_session"], + pipe, # TODO: find the right class to provide here. ORTModel is missing the dict json method ) if not server.show_progress: diff --git a/api/onnx_web/server/api.py b/api/onnx_web/server/api.py index 85103bba..e5142756 100644 --- a/api/onnx_web/server/api.py +++ b/api/onnx_web/server/api.py @@ -387,7 +387,7 @@ def chain(server: ServerContext, pool: DevicePoolExecutor): validate(data, schema) # get defaults from the regular parameters - device, _params, _size = pipeline_from_request(server, data=data) + device, base_params, base_size = pipeline_from_request(server, data=data) pipeline = ChainPipeline() for stage_data in data.get("stages", []): stage_class = CHAIN_STAGES[stage_data.get("type")] @@ -450,7 +450,7 @@ def chain(server: ServerContext, pool: DevicePoolExecutor): logger.info("running chain pipeline with %s stages", len(pipeline.stages)) - output = make_output_name(server, "chain", params, size, count=len(pipeline.stages)) + output = make_output_name(server, "chain", base_params, base_size, count=len(pipeline.stages)) job_name = output[0] # build and run chain pipeline @@ -458,14 +458,15 @@ def chain(server: ServerContext, pool: DevicePoolExecutor): job_name, pipeline, server, - params, + base_params, [], output=output, - size=size, + size=base_size, needs_device=device, ) - return jsonify(json_params(output, params, size)) + step_params = params.with_args(steps=pipeline.steps(base_params, base_size)) + return jsonify(json_params(output, step_params, base_size)) def blend(server: ServerContext, pool: DevicePoolExecutor):