feat(api): attempt to calculate total steps for chain pipelines
This commit is contained in:
parent
4ccdedba89
commit
55ddb9fdac
|
@ -7,7 +7,7 @@ from PIL import Image
|
|||
|
||||
from ..errors import RetryException
|
||||
from ..output import save_image
|
||||
from ..params import ImageParams, StageParams
|
||||
from ..params import ImageParams, Size, StageParams
|
||||
from ..server import ServerContext
|
||||
from ..utils import is_debug, run_gc
|
||||
from ..worker import ProgressCallback, WorkerContext
|
||||
|
@ -85,6 +85,14 @@ class ChainPipeline:
|
|||
self.stages.append((callback, params, kwargs))
|
||||
return self
|
||||
|
||||
def steps(self, params: ImageParams, size: Size):
    """Return the total diffusion step count for the whole chain.

    Each stage reports its own estimate via ``stage.steps``; the pipeline
    total is simply the sum across every registered stage.
    """
    return sum(
        stage.steps(params, size) for stage, _stage_params, _stage_kwargs in self.stages
    )
|
||||
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
worker: WorkerContext,
|
||||
|
|
|
@ -48,3 +48,10 @@ class BlendGridStage(BaseStage):
|
|||
output.paste(sources[n], (x * size[0], y * size[1]))
|
||||
|
||||
return [*sources, output]
|
||||
|
||||
def outputs(
    self,
    params: ImageParams,
    sources: int,
) -> int:
    """Grid blending appends one composite image, so the count grows by one."""
    return 1 + sources
|
|
@ -103,3 +103,17 @@ class BlendImg2ImgStage(BaseStage):
|
|||
outputs.extend(result.images)
|
||||
|
||||
return outputs
|
||||
|
||||
def steps(
    self,
    params: ImageParams,
    *args,
) -> int:
    """Estimate the scheduler steps this img2img stage will consume.

    TODO: multiply by strength — img2img only runs part of the schedule.
    """
    return params.steps
|
||||
|
||||
def outputs(
    self,
    params: ImageParams,
    sources: int,
) -> int:
    """Img2img blending adds one result image on top of the sources."""
    return 1 + sources
|
||||
|
|
|
@ -42,3 +42,10 @@ class SourceNoiseStage(BaseStage):
|
|||
outputs.append(output)
|
||||
|
||||
return outputs
|
||||
|
||||
def outputs(
    self,
    params: ImageParams,
    sources: int,
) -> int:
    """The noise stage appends one generated image to the source list."""
    return 1 + sources
|
|
@ -49,3 +49,10 @@ class SourceS3Stage(BaseStage):
|
|||
logger.exception("error loading image from S3")
|
||||
|
||||
return outputs
|
||||
|
||||
def outputs(
    self,
    params: ImageParams,
    sources: int,
) -> int:
    """One image is fetched from S3 and appended.

    TODO: should be ``len(source_keys)`` once multiple keys are counted.
    """
    return 1 + sources
|
|
@ -126,3 +126,17 @@ class SourceTxt2ImgStage(BaseStage):
|
|||
output = list(sources)
|
||||
output.extend(result.images)
|
||||
return output
|
||||
|
||||
def steps(
    self,
    params: ImageParams,
    size: Size,
) -> int:
    """Txt2img runs the scheduler for the full configured step count."""
    return params.steps
|
||||
|
||||
def outputs(
    self,
    params: ImageParams,
    sources: int,
) -> int:
    """Txt2img appends one newly generated image to the sources."""
    return 1 + sources
|
|
@ -42,3 +42,10 @@ class SourceURLStage(BaseStage):
|
|||
outputs.append(output)
|
||||
|
||||
return outputs
|
||||
|
||||
def outputs(
    self,
    params: ImageParams,
    sources: int,
) -> int:
    """One image is downloaded from the URL and appended to the sources."""
    return 1 + sources
|
|
@ -25,7 +25,14 @@ class BaseStage:
|
|||
|
||||
def steps(
    self,
    params: ImageParams,
    size: Size,
) -> int:
    """Default step estimate for stages that do not run a diffusion loop.

    Stages that actually invoke the scheduler (txt2img, img2img) override
    this with ``params.steps``-based estimates; everything else counts as
    a single step so the chain total stays well-defined.
    """
    return 1
|
||||
|
||||
def outputs(
    self,
    params: ImageParams,
    sources: int,
) -> int:
    """Default: a stage passes its sources through without adding images."""
    return sources
|
||||
|
|
|
@ -50,6 +50,7 @@ def needs_tile(
|
|||
source: Optional[Image.Image] = None,
|
||||
) -> bool:
|
||||
tile = min(max_tile, stage_tile)
|
||||
logger.debug("")
|
||||
|
||||
if source is not None:
|
||||
return source.width > tile or source.height > tile
|
||||
|
|
|
@ -376,7 +376,7 @@ def load_pipeline(
|
|||
provider=device.ort_provider("vae"),
|
||||
sess_options=device.sess_options(),
|
||||
)
|
||||
components["vae_decoder_session"]._model_path = vae_decoder
|
||||
components["vae_decoder_session"]._model_path = vae_decoder # "#\\not a real path on any system"
|
||||
|
||||
logger.debug("loading VAE encoder from %s", vae_encoder)
|
||||
components["vae_encoder_session"] = OnnxRuntimeModel.load_model(
|
||||
|
@ -384,7 +384,7 @@ def load_pipeline(
|
|||
provider=device.ort_provider("vae"),
|
||||
sess_options=device.sess_options(),
|
||||
)
|
||||
components["vae_encoder_session"]._model_path = vae_encoder
|
||||
components["vae_encoder_session"]._model_path = vae_encoder # "#\\not a real path on any system"
|
||||
|
||||
else:
|
||||
logger.debug("loading VAE decoder from %s", vae_decoder)
|
||||
|
@ -439,12 +439,14 @@ def load_pipeline(
|
|||
|
||||
if "vae_decoder_session" in components:
|
||||
pipe.vae_decoder = ORTModelVaeDecoder(
|
||||
components["vae_decoder_session"], vae_decoder
|
||||
components["vae_decoder_session"],
|
||||
pipe, # TODO: find the right class to provide here. ORTModel is missing the dict json method
|
||||
)
|
||||
|
||||
if "vae_encoder_session" in components:
|
||||
pipe.vae_encoder = ORTModelVaeEncoder(
|
||||
components["vae_encoder_session"], vae_encoder
|
||||
components["vae_encoder_session"],
|
||||
pipe, # TODO: find the right class to provide here. ORTModel is missing the dict json method
|
||||
)
|
||||
|
||||
if not server.show_progress:
|
||||
|
|
|
@ -387,7 +387,7 @@ def chain(server: ServerContext, pool: DevicePoolExecutor):
|
|||
validate(data, schema)
|
||||
|
||||
# get defaults from the regular parameters
|
||||
device, _params, _size = pipeline_from_request(server, data=data)
|
||||
device, base_params, base_size = pipeline_from_request(server, data=data)
|
||||
pipeline = ChainPipeline()
|
||||
for stage_data in data.get("stages", []):
|
||||
stage_class = CHAIN_STAGES[stage_data.get("type")]
|
||||
|
@ -450,7 +450,7 @@ def chain(server: ServerContext, pool: DevicePoolExecutor):
|
|||
|
||||
logger.info("running chain pipeline with %s stages", len(pipeline.stages))
|
||||
|
||||
output = make_output_name(server, "chain", params, size, count=len(pipeline.stages))
|
||||
output = make_output_name(server, "chain", base_params, base_size, count=len(pipeline.stages))
|
||||
job_name = output[0]
|
||||
|
||||
# build and run chain pipeline
|
||||
|
@ -458,14 +458,15 @@ def chain(server: ServerContext, pool: DevicePoolExecutor):
|
|||
job_name,
|
||||
pipeline,
|
||||
server,
|
||||
params,
|
||||
base_params,
|
||||
[],
|
||||
output=output,
|
||||
size=size,
|
||||
size=base_size,
|
||||
needs_device=device,
|
||||
)
|
||||
|
||||
return jsonify(json_params(output, params, size))
|
||||
step_params = params.with_args(steps=pipeline.steps(base_params, base_size))
|
||||
return jsonify(json_params(output, step_params, base_size))
|
||||
|
||||
|
||||
def blend(server: ServerContext, pool: DevicePoolExecutor):
|
||||
|
|
Loading…
Reference in New Issue