1
0
Fork 0

pass params per stage

This commit is contained in:
Sean Sube 2023-09-10 22:19:29 -05:00
parent c6302946c0
commit dcc0063195
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
3 changed files with 7 additions and 2 deletions

View File

@ -159,7 +159,7 @@ class ChainPipeline:
worker,
server,
stage_params,
params,
kwargs["params"] if "params" in kwargs else params,
[source_tile],
tile_mask=tile_mask,
callback=callback,

View File

@ -398,10 +398,15 @@ def chain(server: ServerContext, pool: DevicePoolExecutor):
_device, params, size = pipeline_from_request(server, data=kwargs)
replace_wildcards(params, get_wildcard_data())
if "model" in kwargs:
kwargs.pop("model")
if "control" in kwargs:
logger.warning("TODO: resolve controlnet model")
kwargs.pop("control")
kwargs["params"] = params
stage = StageParams(
stage_data.get("name", stage_class.__name__),
tile_size=get_size(kwargs.get("tile_size")),

View File

@ -36,8 +36,8 @@ logger = getLogger(__name__)
def pipeline_from_request(
server: ServerContext,
data: Dict[str, str] = None,
default_pipeline: str = "txt2img",
data: Dict[str, str] = None,
) -> Tuple[DeviceParams, ImageParams, Size]:
user = request.remote_addr