1
0
Fork 0

feat(api): attempt to calculate total steps for chain pipelines

This commit is contained in:
Sean Sube 2023-09-12 18:16:16 -05:00
parent 4ccdedba89
commit 55ddb9fdac
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
11 changed files with 87 additions and 12 deletions

View File

@ -7,7 +7,7 @@ from PIL import Image
from ..errors import RetryException
from ..output import save_image
from ..params import ImageParams, StageParams
from ..params import ImageParams, Size, StageParams
from ..server import ServerContext
from ..utils import is_debug, run_gc
from ..worker import ProgressCallback, WorkerContext
@ -85,6 +85,14 @@ class ChainPipeline:
self.stages.append((callback, params, kwargs))
return self
def steps(self, params: ImageParams, size: Size):
steps = 0
for callback, _params, _kwargs in self.stages:
steps += callback.steps(params, size)
return steps
def __call__(
self,
worker: WorkerContext,

View File

@ -48,3 +48,10 @@ class BlendGridStage(BaseStage):
output.paste(sources[n], (x * size[0], y * size[1]))
return [*sources, output]
def outputs(
self,
params: ImageParams,
sources: int,
) -> int:
return sources + 1

View File

@ -103,3 +103,17 @@ class BlendImg2ImgStage(BaseStage):
outputs.extend(result.images)
return outputs
def steps(
self,
params: ImageParams,
*args,
) -> int:
return params.steps # TODO: multiply by strength
def outputs(
self,
params: ImageParams,
sources: int,
) -> int:
return sources + 1

View File

@ -42,3 +42,10 @@ class SourceNoiseStage(BaseStage):
outputs.append(output)
return outputs
def outputs(
self,
params: ImageParams,
sources: int,
) -> int:
return sources + 1

View File

@ -49,3 +49,10 @@ class SourceS3Stage(BaseStage):
logger.exception("error loading image from S3")
return outputs
def outputs(
self,
params: ImageParams,
sources: int,
) -> int:
return sources + 1 # TODO: len(source_keys)

View File

@ -126,3 +126,17 @@ class SourceTxt2ImgStage(BaseStage):
output = list(sources)
output.extend(result.images)
return output
def steps(
self,
params: ImageParams,
size: Size,
) -> int:
return params.steps
def outputs(
self,
params: ImageParams,
sources: int,
) -> int:
return sources + 1

View File

@ -42,3 +42,10 @@ class SourceURLStage(BaseStage):
outputs.append(output)
return outputs
def outputs(
self,
params: ImageParams,
sources: int,
) -> int:
return sources + 1

View File

@ -25,7 +25,14 @@ class BaseStage:
def steps(
self,
_params: ImageParams,
params: ImageParams,
size: Size,
) -> int:
raise NotImplementedError()
return 1
def outputs(
self,
params: ImageParams,
sources: int,
) -> int:
return sources

View File

@ -50,6 +50,7 @@ def needs_tile(
source: Optional[Image.Image] = None,
) -> bool:
tile = min(max_tile, stage_tile)
logger.debug("")
if source is not None:
return source.width > tile or source.height > tile

View File

@ -376,7 +376,7 @@ def load_pipeline(
provider=device.ort_provider("vae"),
sess_options=device.sess_options(),
)
components["vae_decoder_session"]._model_path = vae_decoder
components["vae_decoder_session"]._model_path = vae_decoder # "#\\not a real path on any system"
logger.debug("loading VAE encoder from %s", vae_encoder)
components["vae_encoder_session"] = OnnxRuntimeModel.load_model(
@ -384,7 +384,7 @@ def load_pipeline(
provider=device.ort_provider("vae"),
sess_options=device.sess_options(),
)
components["vae_encoder_session"]._model_path = vae_encoder
components["vae_encoder_session"]._model_path = vae_encoder # "#\\not a real path on any system"
else:
logger.debug("loading VAE decoder from %s", vae_decoder)
@ -439,12 +439,14 @@ def load_pipeline(
if "vae_decoder_session" in components:
pipe.vae_decoder = ORTModelVaeDecoder(
components["vae_decoder_session"], vae_decoder
components["vae_decoder_session"],
pipe, # TODO: find the right class to provide here. ORTModel is missing the dict json method
)
if "vae_encoder_session" in components:
pipe.vae_encoder = ORTModelVaeEncoder(
components["vae_encoder_session"], vae_encoder
components["vae_encoder_session"],
pipe, # TODO: find the right class to provide here. ORTModel is missing the dict json method
)
if not server.show_progress:

View File

@ -387,7 +387,7 @@ def chain(server: ServerContext, pool: DevicePoolExecutor):
validate(data, schema)
# get defaults from the regular parameters
device, _params, _size = pipeline_from_request(server, data=data)
device, base_params, base_size = pipeline_from_request(server, data=data)
pipeline = ChainPipeline()
for stage_data in data.get("stages", []):
stage_class = CHAIN_STAGES[stage_data.get("type")]
@ -450,7 +450,7 @@ def chain(server: ServerContext, pool: DevicePoolExecutor):
logger.info("running chain pipeline with %s stages", len(pipeline.stages))
output = make_output_name(server, "chain", params, size, count=len(pipeline.stages))
output = make_output_name(server, "chain", base_params, base_size, count=len(pipeline.stages))
job_name = output[0]
# build and run chain pipeline
@ -458,14 +458,15 @@ def chain(server: ServerContext, pool: DevicePoolExecutor):
job_name,
pipeline,
server,
params,
base_params,
[],
output=output,
size=size,
size=base_size,
needs_device=device,
)
return jsonify(json_params(output, params, size))
step_params = params.with_args(steps=pipeline.steps(base_params, base_size))
return jsonify(json_params(output, step_params, base_size))
def blend(server: ServerContext, pool: DevicePoolExecutor):