read chain params per stage
This commit is contained in:
parent
07662b22df
commit
441a47a885
|
@ -388,18 +388,16 @@ def chain(server: ServerContext, pool: DevicePoolExecutor):
|
|||
validate(data, schema)
|
||||
|
||||
# get defaults from the regular parameters
|
||||
device, params, size = pipeline_from_request(server)
|
||||
output = make_output_name(server, "chain", params, size)
|
||||
job_name = output[0]
|
||||
|
||||
replace_wildcards(params, get_wildcard_data())
|
||||
|
||||
device, _params, _size = pipeline_from_request(server)
|
||||
pipeline = ChainPipeline()
|
||||
for stage_data in data.get("stages", []):
|
||||
stage_class = CHAIN_STAGES[stage_data.get("type")]
|
||||
kwargs: Dict[str, Any] = stage_data.get("params", {})
|
||||
logger.info("request stage: %s, %s", stage_class.__name__, kwargs)
|
||||
|
||||
_device, params, size = pipeline_from_request(server, data=kwargs)
|
||||
replace_wildcards(params, get_wildcard_data())
|
||||
|
||||
if "control" in kwargs:
|
||||
logger.warning("TODO: resolve controlnet model")
|
||||
kwargs.pop("control")
|
||||
|
@ -447,6 +445,9 @@ def chain(server: ServerContext, pool: DevicePoolExecutor):
|
|||
|
||||
logger.info("running chain pipeline with %s stages", len(pipeline.stages))
|
||||
|
||||
output = make_output_name(server, "chain", params, size, count=len(pipeline.stages))
|
||||
job_name = output[0]
|
||||
|
||||
# build and run chain pipeline
|
||||
pool.submit(
|
||||
job_name,
|
||||
|
|
|
@ -34,11 +34,17 @@ from .utils import get_model_path
|
|||
logger = getLogger(__name__)
|
||||
|
||||
|
||||
def pipeline_from_json(
|
||||
server: ServerContext,
|
||||
data: Dict[str, Any],
|
||||
default_pipeline: str = "txt2img",
|
||||
def pipeline_from_request(
|
||||
server: ServerContext,
|
||||
data: Dict[str, str] = None,
|
||||
default_pipeline: str = "txt2img",
|
||||
) -> Tuple[DeviceParams, ImageParams, Size]:
|
||||
user = request.remote_addr
|
||||
|
||||
if data is None:
|
||||
data = request.args
|
||||
|
||||
# platform stuff
|
||||
device = None
|
||||
device_name = data.get("platform")
|
||||
|
||||
|
@ -51,6 +57,12 @@ def pipeline_from_json(
|
|||
model = get_not_empty(data, "model", get_config_value("model"))
|
||||
model_path = get_model_path(server, model)
|
||||
|
||||
control = None
|
||||
control_name = data.get("control")
|
||||
for network in get_network_models():
|
||||
if network.name == control_name:
|
||||
control = network
|
||||
|
||||
# pipeline stuff
|
||||
pipeline = get_from_list(
|
||||
data, "pipeline", get_available_pipelines(), default_pipeline
|
||||
|
@ -149,165 +161,6 @@ def pipeline_from_json(
|
|||
# this one can safely use np.random because it produces a single value
|
||||
seed = np.random.randint(np.iinfo(np.int32).max)
|
||||
|
||||
logger.debug(
|
||||
"parsed parameters for %s steps of %s using %s in %s on %s, %sx%s, %s, %s - %s",
|
||||
steps,
|
||||
scheduler,
|
||||
model_path,
|
||||
pipeline,
|
||||
device or "any device",
|
||||
width,
|
||||
height,
|
||||
cfg,
|
||||
seed,
|
||||
prompt,
|
||||
)
|
||||
|
||||
params = ImageParams(
|
||||
model_path,
|
||||
pipeline,
|
||||
scheduler,
|
||||
prompt,
|
||||
cfg,
|
||||
steps,
|
||||
seed,
|
||||
eta=eta,
|
||||
negative_prompt=negative_prompt,
|
||||
batch=batch,
|
||||
# TODO: control=control,
|
||||
loopback=loopback,
|
||||
tiled_vae=tiled_vae,
|
||||
tiles=tiles,
|
||||
overlap=overlap,
|
||||
stride=stride,
|
||||
)
|
||||
size = Size(width, height)
|
||||
return (device, params, size)
|
||||
|
||||
|
||||
def pipeline_from_request(
|
||||
server: ServerContext,
|
||||
default_pipeline: str = "txt2img",
|
||||
) -> Tuple[DeviceParams, ImageParams, Size]:
|
||||
user = request.remote_addr
|
||||
|
||||
# platform stuff
|
||||
device = None
|
||||
device_name = request.args.get("platform")
|
||||
|
||||
if device_name is not None and device_name != "any":
|
||||
for platform in get_available_platforms():
|
||||
if platform.device == device_name:
|
||||
device = platform
|
||||
|
||||
# diffusion model
|
||||
model = get_not_empty(request.args, "model", get_config_value("model"))
|
||||
model_path = get_model_path(server, model)
|
||||
|
||||
control = None
|
||||
control_name = request.args.get("control")
|
||||
for network in get_network_models():
|
||||
if network.name == control_name:
|
||||
control = network
|
||||
|
||||
# pipeline stuff
|
||||
pipeline = get_from_list(
|
||||
request.args, "pipeline", get_available_pipelines(), default_pipeline
|
||||
)
|
||||
scheduler = get_from_list(request.args, "scheduler", get_pipeline_schedulers())
|
||||
|
||||
if scheduler is None:
|
||||
scheduler = get_config_value("scheduler")
|
||||
|
||||
# prompt does not come from config
|
||||
prompt = request.args.get("prompt", "")
|
||||
negative_prompt = request.args.get("negativePrompt", None)
|
||||
|
||||
if negative_prompt is not None and negative_prompt.strip() == "":
|
||||
negative_prompt = None
|
||||
|
||||
# image params
|
||||
batch = get_and_clamp_int(
|
||||
request.args,
|
||||
"batch",
|
||||
get_config_value("batch"),
|
||||
get_config_value("batch", "max"),
|
||||
get_config_value("batch", "min"),
|
||||
)
|
||||
cfg = get_and_clamp_float(
|
||||
request.args,
|
||||
"cfg",
|
||||
get_config_value("cfg"),
|
||||
get_config_value("cfg", "max"),
|
||||
get_config_value("cfg", "min"),
|
||||
)
|
||||
eta = get_and_clamp_float(
|
||||
request.args,
|
||||
"eta",
|
||||
get_config_value("eta"),
|
||||
get_config_value("eta", "max"),
|
||||
get_config_value("eta", "min"),
|
||||
)
|
||||
loopback = get_and_clamp_int(
|
||||
request.args,
|
||||
"loopback",
|
||||
get_config_value("loopback"),
|
||||
get_config_value("loopback", "max"),
|
||||
get_config_value("loopback", "min"),
|
||||
)
|
||||
steps = get_and_clamp_int(
|
||||
request.args,
|
||||
"steps",
|
||||
get_config_value("steps"),
|
||||
get_config_value("steps", "max"),
|
||||
get_config_value("steps", "min"),
|
||||
)
|
||||
height = get_and_clamp_int(
|
||||
request.args,
|
||||
"height",
|
||||
get_config_value("height"),
|
||||
get_config_value("height", "max"),
|
||||
get_config_value("height", "min"),
|
||||
)
|
||||
width = get_and_clamp_int(
|
||||
request.args,
|
||||
"width",
|
||||
get_config_value("width"),
|
||||
get_config_value("width", "max"),
|
||||
get_config_value("width", "min"),
|
||||
)
|
||||
tiled_vae = get_boolean(request.args, "tiledVAE", get_config_value("tiledVAE"))
|
||||
tiles = get_and_clamp_int(
|
||||
request.args,
|
||||
"tiles",
|
||||
get_config_value("tiles"),
|
||||
get_config_value("tiles", "max"),
|
||||
get_config_value("tiles", "min"),
|
||||
)
|
||||
overlap = get_and_clamp_float(
|
||||
request.args,
|
||||
"overlap",
|
||||
get_config_value("overlap"),
|
||||
get_config_value("overlap", "max"),
|
||||
get_config_value("overlap", "min"),
|
||||
)
|
||||
stride = get_and_clamp_int(
|
||||
request.args,
|
||||
"stride",
|
||||
get_config_value("stride"),
|
||||
get_config_value("stride", "max"),
|
||||
get_config_value("stride", "min"),
|
||||
)
|
||||
|
||||
if stride > tiles:
|
||||
logger.info("limiting stride to tile size, %s > %s", stride, tiles)
|
||||
stride = tiles
|
||||
|
||||
seed = int(request.args.get("seed", -1))
|
||||
if seed == -1:
|
||||
# this one can safely use np.random because it produces a single value
|
||||
seed = np.random.randint(np.iinfo(np.int32).max)
|
||||
|
||||
logger.info(
|
||||
"request from %s: %s steps of %s using %s in %s on %s, %sx%s, %s, %s - %s",
|
||||
user,
|
||||
|
|
Loading…
Reference in New Issue