feat(api): attempt to calculate total steps for chain pipelines
This commit is contained in:
parent
4ccdedba89
commit
55ddb9fdac
|
@ -7,7 +7,7 @@ from PIL import Image
|
||||||
|
|
||||||
from ..errors import RetryException
|
from ..errors import RetryException
|
||||||
from ..output import save_image
|
from ..output import save_image
|
||||||
from ..params import ImageParams, StageParams
|
from ..params import ImageParams, Size, StageParams
|
||||||
from ..server import ServerContext
|
from ..server import ServerContext
|
||||||
from ..utils import is_debug, run_gc
|
from ..utils import is_debug, run_gc
|
||||||
from ..worker import ProgressCallback, WorkerContext
|
from ..worker import ProgressCallback, WorkerContext
|
||||||
|
@ -85,6 +85,14 @@ class ChainPipeline:
|
||||||
self.stages.append((callback, params, kwargs))
|
self.stages.append((callback, params, kwargs))
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
def steps(self, params: ImageParams, size: Size):
|
||||||
|
steps = 0
|
||||||
|
for callback, _params, _kwargs in self.stages:
|
||||||
|
steps += callback.steps(params, size)
|
||||||
|
|
||||||
|
return steps
|
||||||
|
|
||||||
|
|
||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
worker: WorkerContext,
|
worker: WorkerContext,
|
||||||
|
|
|
@ -48,3 +48,10 @@ class BlendGridStage(BaseStage):
|
||||||
output.paste(sources[n], (x * size[0], y * size[1]))
|
output.paste(sources[n], (x * size[0], y * size[1]))
|
||||||
|
|
||||||
return [*sources, output]
|
return [*sources, output]
|
||||||
|
|
||||||
|
def outputs(
|
||||||
|
self,
|
||||||
|
params: ImageParams,
|
||||||
|
sources: int,
|
||||||
|
) -> int:
|
||||||
|
return sources + 1
|
|
@ -103,3 +103,17 @@ class BlendImg2ImgStage(BaseStage):
|
||||||
outputs.extend(result.images)
|
outputs.extend(result.images)
|
||||||
|
|
||||||
return outputs
|
return outputs
|
||||||
|
|
||||||
|
def steps(
|
||||||
|
self,
|
||||||
|
params: ImageParams,
|
||||||
|
*args,
|
||||||
|
) -> int:
|
||||||
|
return params.steps # TODO: multiply by strength
|
||||||
|
|
||||||
|
def outputs(
|
||||||
|
self,
|
||||||
|
params: ImageParams,
|
||||||
|
sources: int,
|
||||||
|
) -> int:
|
||||||
|
return sources + 1
|
||||||
|
|
|
@ -42,3 +42,10 @@ class SourceNoiseStage(BaseStage):
|
||||||
outputs.append(output)
|
outputs.append(output)
|
||||||
|
|
||||||
return outputs
|
return outputs
|
||||||
|
|
||||||
|
def outputs(
|
||||||
|
self,
|
||||||
|
params: ImageParams,
|
||||||
|
sources: int,
|
||||||
|
) -> int:
|
||||||
|
return sources + 1
|
|
@ -49,3 +49,10 @@ class SourceS3Stage(BaseStage):
|
||||||
logger.exception("error loading image from S3")
|
logger.exception("error loading image from S3")
|
||||||
|
|
||||||
return outputs
|
return outputs
|
||||||
|
|
||||||
|
def outputs(
|
||||||
|
self,
|
||||||
|
params: ImageParams,
|
||||||
|
sources: int,
|
||||||
|
) -> int:
|
||||||
|
return sources + 1 # TODO: len(source_keys)
|
|
@ -126,3 +126,17 @@ class SourceTxt2ImgStage(BaseStage):
|
||||||
output = list(sources)
|
output = list(sources)
|
||||||
output.extend(result.images)
|
output.extend(result.images)
|
||||||
return output
|
return output
|
||||||
|
|
||||||
|
def steps(
|
||||||
|
self,
|
||||||
|
params: ImageParams,
|
||||||
|
size: Size,
|
||||||
|
) -> int:
|
||||||
|
return params.steps
|
||||||
|
|
||||||
|
def outputs(
|
||||||
|
self,
|
||||||
|
params: ImageParams,
|
||||||
|
sources: int,
|
||||||
|
) -> int:
|
||||||
|
return sources + 1
|
|
@ -42,3 +42,10 @@ class SourceURLStage(BaseStage):
|
||||||
outputs.append(output)
|
outputs.append(output)
|
||||||
|
|
||||||
return outputs
|
return outputs
|
||||||
|
|
||||||
|
def outputs(
|
||||||
|
self,
|
||||||
|
params: ImageParams,
|
||||||
|
sources: int,
|
||||||
|
) -> int:
|
||||||
|
return sources + 1
|
|
@ -25,7 +25,14 @@ class BaseStage:
|
||||||
|
|
||||||
def steps(
|
def steps(
|
||||||
self,
|
self,
|
||||||
_params: ImageParams,
|
params: ImageParams,
|
||||||
size: Size,
|
size: Size,
|
||||||
) -> int:
|
) -> int:
|
||||||
raise NotImplementedError()
|
return 1
|
||||||
|
|
||||||
|
def outputs(
|
||||||
|
self,
|
||||||
|
params: ImageParams,
|
||||||
|
sources: int,
|
||||||
|
) -> int:
|
||||||
|
return sources
|
||||||
|
|
|
@ -50,6 +50,7 @@ def needs_tile(
|
||||||
source: Optional[Image.Image] = None,
|
source: Optional[Image.Image] = None,
|
||||||
) -> bool:
|
) -> bool:
|
||||||
tile = min(max_tile, stage_tile)
|
tile = min(max_tile, stage_tile)
|
||||||
|
logger.debug("")
|
||||||
|
|
||||||
if source is not None:
|
if source is not None:
|
||||||
return source.width > tile or source.height > tile
|
return source.width > tile or source.height > tile
|
||||||
|
|
|
@ -376,7 +376,7 @@ def load_pipeline(
|
||||||
provider=device.ort_provider("vae"),
|
provider=device.ort_provider("vae"),
|
||||||
sess_options=device.sess_options(),
|
sess_options=device.sess_options(),
|
||||||
)
|
)
|
||||||
components["vae_decoder_session"]._model_path = vae_decoder
|
components["vae_decoder_session"]._model_path = vae_decoder # "#\\not a real path on any system"
|
||||||
|
|
||||||
logger.debug("loading VAE encoder from %s", vae_encoder)
|
logger.debug("loading VAE encoder from %s", vae_encoder)
|
||||||
components["vae_encoder_session"] = OnnxRuntimeModel.load_model(
|
components["vae_encoder_session"] = OnnxRuntimeModel.load_model(
|
||||||
|
@ -384,7 +384,7 @@ def load_pipeline(
|
||||||
provider=device.ort_provider("vae"),
|
provider=device.ort_provider("vae"),
|
||||||
sess_options=device.sess_options(),
|
sess_options=device.sess_options(),
|
||||||
)
|
)
|
||||||
components["vae_encoder_session"]._model_path = vae_encoder
|
components["vae_encoder_session"]._model_path = vae_encoder # "#\\not a real path on any system"
|
||||||
|
|
||||||
else:
|
else:
|
||||||
logger.debug("loading VAE decoder from %s", vae_decoder)
|
logger.debug("loading VAE decoder from %s", vae_decoder)
|
||||||
|
@ -439,12 +439,14 @@ def load_pipeline(
|
||||||
|
|
||||||
if "vae_decoder_session" in components:
|
if "vae_decoder_session" in components:
|
||||||
pipe.vae_decoder = ORTModelVaeDecoder(
|
pipe.vae_decoder = ORTModelVaeDecoder(
|
||||||
components["vae_decoder_session"], vae_decoder
|
components["vae_decoder_session"],
|
||||||
|
pipe, # TODO: find the right class to provide here. ORTModel is missing the dict json method
|
||||||
)
|
)
|
||||||
|
|
||||||
if "vae_encoder_session" in components:
|
if "vae_encoder_session" in components:
|
||||||
pipe.vae_encoder = ORTModelVaeEncoder(
|
pipe.vae_encoder = ORTModelVaeEncoder(
|
||||||
components["vae_encoder_session"], vae_encoder
|
components["vae_encoder_session"],
|
||||||
|
pipe, # TODO: find the right class to provide here. ORTModel is missing the dict json method
|
||||||
)
|
)
|
||||||
|
|
||||||
if not server.show_progress:
|
if not server.show_progress:
|
||||||
|
|
|
@ -387,7 +387,7 @@ def chain(server: ServerContext, pool: DevicePoolExecutor):
|
||||||
validate(data, schema)
|
validate(data, schema)
|
||||||
|
|
||||||
# get defaults from the regular parameters
|
# get defaults from the regular parameters
|
||||||
device, _params, _size = pipeline_from_request(server, data=data)
|
device, base_params, base_size = pipeline_from_request(server, data=data)
|
||||||
pipeline = ChainPipeline()
|
pipeline = ChainPipeline()
|
||||||
for stage_data in data.get("stages", []):
|
for stage_data in data.get("stages", []):
|
||||||
stage_class = CHAIN_STAGES[stage_data.get("type")]
|
stage_class = CHAIN_STAGES[stage_data.get("type")]
|
||||||
|
@ -450,7 +450,7 @@ def chain(server: ServerContext, pool: DevicePoolExecutor):
|
||||||
|
|
||||||
logger.info("running chain pipeline with %s stages", len(pipeline.stages))
|
logger.info("running chain pipeline with %s stages", len(pipeline.stages))
|
||||||
|
|
||||||
output = make_output_name(server, "chain", params, size, count=len(pipeline.stages))
|
output = make_output_name(server, "chain", base_params, base_size, count=len(pipeline.stages))
|
||||||
job_name = output[0]
|
job_name = output[0]
|
||||||
|
|
||||||
# build and run chain pipeline
|
# build and run chain pipeline
|
||||||
|
@ -458,14 +458,15 @@ def chain(server: ServerContext, pool: DevicePoolExecutor):
|
||||||
job_name,
|
job_name,
|
||||||
pipeline,
|
pipeline,
|
||||||
server,
|
server,
|
||||||
params,
|
base_params,
|
||||||
[],
|
[],
|
||||||
output=output,
|
output=output,
|
||||||
size=size,
|
size=base_size,
|
||||||
needs_device=device,
|
needs_device=device,
|
||||||
)
|
)
|
||||||
|
|
||||||
return jsonify(json_params(output, params, size))
|
step_params = params.with_args(steps=pipeline.steps(base_params, base_size))
|
||||||
|
return jsonify(json_params(output, step_params, base_size))
|
||||||
|
|
||||||
|
|
||||||
def blend(server: ServerContext, pool: DevicePoolExecutor):
|
def blend(server: ServerContext, pool: DevicePoolExecutor):
|
||||||
|
|
Loading…
Reference in New Issue