pass params per stage
This commit is contained in:
parent
c6302946c0
commit
dcc0063195
|
@ -159,7 +159,7 @@ class ChainPipeline:
|
|||
worker,
|
||||
server,
|
||||
stage_params,
|
||||
params,
|
||||
kwargs["params"] if "params" in kwargs else params,
|
||||
[source_tile],
|
||||
tile_mask=tile_mask,
|
||||
callback=callback,
|
||||
|
|
|
@ -398,10 +398,15 @@ def chain(server: ServerContext, pool: DevicePoolExecutor):
|
|||
_device, params, size = pipeline_from_request(server, data=kwargs)
|
||||
replace_wildcards(params, get_wildcard_data())
|
||||
|
||||
if "model" in kwargs:
|
||||
kwargs.pop("model")
|
||||
|
||||
if "control" in kwargs:
|
||||
logger.warning("TODO: resolve controlnet model")
|
||||
kwargs.pop("control")
|
||||
|
||||
kwargs["params"] = params
|
||||
|
||||
stage = StageParams(
|
||||
stage_data.get("name", stage_class.__name__),
|
||||
tile_size=get_size(kwargs.get("tile_size")),
|
||||
|
|
|
@ -36,8 +36,8 @@ logger = getLogger(__name__)
|
|||
|
||||
def pipeline_from_request(
|
||||
server: ServerContext,
|
||||
data: Dict[str, str] = None,
|
||||
default_pipeline: str = "txt2img",
|
||||
data: Dict[str, str] = None,
|
||||
) -> Tuple[DeviceParams, ImageParams, Size]:
|
||||
user = request.remote_addr
|
||||
|
||||
|
|
Loading…
Reference in New Issue