
feat(api): attempt to calculate total steps for chain pipelines

Sean Sube 2023-09-12 18:16:16 -05:00
parent 4ccdedba89
commit 55ddb9fdac
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
11 changed files with 87 additions and 12 deletions

View File

@@ -7,7 +7,7 @@ from PIL import Image
 from ..errors import RetryException
 from ..output import save_image
-from ..params import ImageParams, StageParams
+from ..params import ImageParams, Size, StageParams
 from ..server import ServerContext
 from ..utils import is_debug, run_gc
 from ..worker import ProgressCallback, WorkerContext
@@ -85,6 +85,14 @@ class ChainPipeline:
         self.stages.append((callback, params, kwargs))
         return self
+    def steps(self, params: ImageParams, size: Size):
+        steps = 0
+        for callback, _params, _kwargs in self.stages:
+            steps += callback.steps(params, size)
+        return steps
     def __call__(
         self,
         worker: WorkerContext,
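
A minimal sketch of how this aggregation plays out, assuming a stage() helper that appends to self.stages (the method name is an assumption here) and the per-stage steps() overrides added in the files below; the stage mix and numbers are illustrative, not from this commit:

    # Hypothetical example, not part of this commit.
    pipeline = ChainPipeline()
    pipeline.stage(SourceTxt2ImgStage(), StageParams())  # steps() reports params.steps
    pipeline.stage(BlendGridStage(), StageParams())      # no override, BaseStage default of 1

    total = pipeline.steps(params, size)
    # with params.steps == 25, total == 25 + 1 == 26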

View File

@@ -48,3 +48,10 @@ class BlendGridStage(BaseStage):
             output.paste(sources[n], (x * size[0], y * size[1]))
         return [*sources, output]
+    def outputs(
+        self,
+        params: ImageParams,
+        sources: int,
+    ) -> int:
+        return sources + 1

View File

@@ -103,3 +103,17 @@ class BlendImg2ImgStage(BaseStage):
             outputs.extend(result.images)
         return outputs
+    def steps(
+        self,
+        params: ImageParams,
+        *args,
+    ) -> int:
+        return params.steps  # TODO: multiply by strength
+    def outputs(
+        self,
+        params: ImageParams,
+        sources: int,
+    ) -> int:
+        return sources + 1
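
The TODO above notes that img2img runs only a fraction of the configured scheduler steps, scaled by denoising strength. A hedged sketch of what that adjustment might look like; the strength lookup is an assumption, since strength is not shown as a field of ImageParams in this diff and may live in the stage kwargs instead:

    # Hypothetical refinement of the steps() method above, not part of this commit.
    def steps(self, params: ImageParams, *args) -> int:
        strength = getattr(params, "strength", 1.0)  # assumed field, see note above
        return max(1, int(params.steps * strength))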

View File

@@ -42,3 +42,10 @@ class SourceNoiseStage(BaseStage):
             outputs.append(output)
         return outputs
+    def outputs(
+        self,
+        params: ImageParams,
+        sources: int,
+    ) -> int:
+        return sources + 1

View File

@@ -49,3 +49,10 @@ class SourceS3Stage(BaseStage):
                 logger.exception("error loading image from S3")
         return outputs
+    def outputs(
+        self,
+        params: ImageParams,
+        sources: int,
+    ) -> int:
+        return sources + 1  # TODO: len(source_keys)
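
The TODO above suggests this stage should report one output per S3 key rather than a flat +1. A hedged sketch of that intent, assuming outputs() could be handed the same source_keys value the stage's run call receives; the extra parameter is an assumption and would need to be threaded through:

    # Hypothetical version of the TODO above, not part of this commit.
    def outputs(self, params: ImageParams, sources: int, source_keys=None) -> int:
        return sources + len(source_keys or [])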

View File

@@ -126,3 +126,17 @@ class SourceTxt2ImgStage(BaseStage):
         output = list(sources)
         output.extend(result.images)
         return output
+    def steps(
+        self,
+        params: ImageParams,
+        size: Size,
+    ) -> int:
+        return params.steps
+    def outputs(
+        self,
+        params: ImageParams,
+        sources: int,
+    ) -> int:
+        return sources + 1

View File

@@ -42,3 +42,10 @@ class SourceURLStage(BaseStage):
             outputs.append(output)
         return outputs
+    def outputs(
+        self,
+        params: ImageParams,
+        sources: int,
+    ) -> int:
+        return sources + 1

View File

@@ -25,7 +25,14 @@ class BaseStage:
     def steps(
         self,
-        _params: ImageParams,
+        params: ImageParams,
         size: Size,
     ) -> int:
-        raise NotImplementedError()
+        return 1
+    def outputs(
+        self,
+        params: ImageParams,
+        sources: int,
+    ) -> int:
+        return sources
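
With these defaults, a stage only needs an override when it actually runs the scheduler or changes the number of images in flight. A hedged sketch of a stage that multiplies its sources and therefore overrides outputs() while keeping the default steps() of 1; the stage itself is illustrative, not from this repository:

    # Hypothetical stage, not part of this commit.
    class SplitGridStage(BaseStage):
        def outputs(
            self,
            params: ImageParams,
            sources: int,
        ) -> int:
            return sources * 4  # each source would be split into a 2x2 grid of tiles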

View File

@@ -50,6 +50,7 @@ def needs_tile(
     source: Optional[Image.Image] = None,
 ) -> bool:
     tile = min(max_tile, stage_tile)
+    logger.debug("")
     if source is not None:
         return source.width > tile or source.height > tile

View File

@@ -376,7 +376,7 @@ def load_pipeline(
             provider=device.ort_provider("vae"),
             sess_options=device.sess_options(),
         )
-        components["vae_decoder_session"]._model_path = vae_decoder
+        components["vae_decoder_session"]._model_path = vae_decoder  # "#\\not a real path on any system"
         logger.debug("loading VAE encoder from %s", vae_encoder)
         components["vae_encoder_session"] = OnnxRuntimeModel.load_model(
@@ -384,7 +384,7 @@ def load_pipeline(
             provider=device.ort_provider("vae"),
             sess_options=device.sess_options(),
         )
-        components["vae_encoder_session"]._model_path = vae_encoder
+        components["vae_encoder_session"]._model_path = vae_encoder  # "#\\not a real path on any system"
     else:
         logger.debug("loading VAE decoder from %s", vae_decoder)
@@ -439,12 +439,14 @@ def load_pipeline(
     if "vae_decoder_session" in components:
         pipe.vae_decoder = ORTModelVaeDecoder(
-            components["vae_decoder_session"], vae_decoder
+            components["vae_decoder_session"],
+            pipe,  # TODO: find the right class to provide here. ORTModel is missing the dict json method
         )
     if "vae_encoder_session" in components:
         pipe.vae_encoder = ORTModelVaeEncoder(
-            components["vae_encoder_session"], vae_encoder
+            components["vae_encoder_session"],
+            pipe,  # TODO: find the right class to provide here. ORTModel is missing the dict json method
         )
     if not server.show_progress:

View File

@@ -387,7 +387,7 @@ def chain(server: ServerContext, pool: DevicePoolExecutor):
     validate(data, schema)
     # get defaults from the regular parameters
-    device, _params, _size = pipeline_from_request(server, data=data)
+    device, base_params, base_size = pipeline_from_request(server, data=data)
     pipeline = ChainPipeline()
     for stage_data in data.get("stages", []):
         stage_class = CHAIN_STAGES[stage_data.get("type")]
@@ -450,7 +450,7 @@ def chain(server: ServerContext, pool: DevicePoolExecutor):
     logger.info("running chain pipeline with %s stages", len(pipeline.stages))
-    output = make_output_name(server, "chain", params, size, count=len(pipeline.stages))
+    output = make_output_name(server, "chain", base_params, base_size, count=len(pipeline.stages))
     job_name = output[0]
     # build and run chain pipeline
@@ -458,14 +458,15 @@ def chain(server: ServerContext, pool: DevicePoolExecutor):
         job_name,
         pipeline,
         server,
-        params,
+        base_params,
         [],
         output=output,
-        size=size,
+        size=base_size,
         needs_device=device,
     )
-    return jsonify(json_params(output, params, size))
+    step_params = params.with_args(steps=pipeline.steps(base_params, base_size))
+    return jsonify(json_params(output, step_params, base_size))
 def blend(server: ServerContext, pool: DevicePoolExecutor):
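
A worked example of what the response now carries, with assumed numbers: for a chain of one source-txt2img stage and one blend-img2img stage, each reporting params.steps, and base_params.steps == 30, the pipeline total is 30 + 30 == 60, so json_params() advertises 60 total steps for progress tracking instead of the single-stage 30. A short sketch of that final step; the stage mix and values are illustrative, not from this commit:

    # Hypothetical walk-through of the lines above, not part of this commit.
    total_steps = pipeline.steps(base_params, base_size)   # 30 + 30 == 60 for the chain described above
    step_params = base_params.with_args(steps=total_steps)
    # json_params(output, step_params, base_size) then reports steps=60 to the client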