1
0
Fork 0

read chain params per stage

This commit is contained in:
Sean Sube 2023-09-10 21:21:57 -05:00
parent 07662b22df
commit 441a47a885
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
2 changed files with 23 additions and 169 deletions

View File

@ -388,18 +388,16 @@ def chain(server: ServerContext, pool: DevicePoolExecutor):
validate(data, schema)
# get defaults from the regular parameters
device, params, size = pipeline_from_request(server)
output = make_output_name(server, "chain", params, size)
job_name = output[0]
replace_wildcards(params, get_wildcard_data())
device, _params, _size = pipeline_from_request(server)
pipeline = ChainPipeline()
for stage_data in data.get("stages", []):
stage_class = CHAIN_STAGES[stage_data.get("type")]
kwargs: Dict[str, Any] = stage_data.get("params", {})
logger.info("request stage: %s, %s", stage_class.__name__, kwargs)
_device, params, size = pipeline_from_request(server, data=kwargs)
replace_wildcards(params, get_wildcard_data())
if "control" in kwargs:
logger.warning("TODO: resolve controlnet model")
kwargs.pop("control")
@ -447,6 +445,9 @@ def chain(server: ServerContext, pool: DevicePoolExecutor):
logger.info("running chain pipeline with %s stages", len(pipeline.stages))
output = make_output_name(server, "chain", params, size, count=len(pipeline.stages))
job_name = output[0]
# build and run chain pipeline
pool.submit(
job_name,

View File

@ -34,11 +34,17 @@ from .utils import get_model_path
logger = getLogger(__name__)
def pipeline_from_json(
server: ServerContext,
data: Dict[str, Any],
default_pipeline: str = "txt2img",
def pipeline_from_request(
server: ServerContext,
data: Dict[str, str] = None,
default_pipeline: str = "txt2img",
) -> Tuple[DeviceParams, ImageParams, Size]:
user = request.remote_addr
if data is None:
data = request.args
# platform stuff
device = None
device_name = data.get("platform")
@ -51,6 +57,12 @@ def pipeline_from_json(
model = get_not_empty(data, "model", get_config_value("model"))
model_path = get_model_path(server, model)
control = None
control_name = data.get("control")
for network in get_network_models():
if network.name == control_name:
control = network
# pipeline stuff
pipeline = get_from_list(
data, "pipeline", get_available_pipelines(), default_pipeline
@ -149,165 +161,6 @@ def pipeline_from_json(
# this one can safely use np.random because it produces a single value
seed = np.random.randint(np.iinfo(np.int32).max)
logger.debug(
"parsed parameters for %s steps of %s using %s in %s on %s, %sx%s, %s, %s - %s",
steps,
scheduler,
model_path,
pipeline,
device or "any device",
width,
height,
cfg,
seed,
prompt,
)
params = ImageParams(
model_path,
pipeline,
scheduler,
prompt,
cfg,
steps,
seed,
eta=eta,
negative_prompt=negative_prompt,
batch=batch,
# TODO: control=control,
loopback=loopback,
tiled_vae=tiled_vae,
tiles=tiles,
overlap=overlap,
stride=stride,
)
size = Size(width, height)
return (device, params, size)
def pipeline_from_request(
server: ServerContext,
default_pipeline: str = "txt2img",
) -> Tuple[DeviceParams, ImageParams, Size]:
user = request.remote_addr
# platform stuff
device = None
device_name = request.args.get("platform")
if device_name is not None and device_name != "any":
for platform in get_available_platforms():
if platform.device == device_name:
device = platform
# diffusion model
model = get_not_empty(request.args, "model", get_config_value("model"))
model_path = get_model_path(server, model)
control = None
control_name = request.args.get("control")
for network in get_network_models():
if network.name == control_name:
control = network
# pipeline stuff
pipeline = get_from_list(
request.args, "pipeline", get_available_pipelines(), default_pipeline
)
scheduler = get_from_list(request.args, "scheduler", get_pipeline_schedulers())
if scheduler is None:
scheduler = get_config_value("scheduler")
# prompt does not come from config
prompt = request.args.get("prompt", "")
negative_prompt = request.args.get("negativePrompt", None)
if negative_prompt is not None and negative_prompt.strip() == "":
negative_prompt = None
# image params
batch = get_and_clamp_int(
request.args,
"batch",
get_config_value("batch"),
get_config_value("batch", "max"),
get_config_value("batch", "min"),
)
cfg = get_and_clamp_float(
request.args,
"cfg",
get_config_value("cfg"),
get_config_value("cfg", "max"),
get_config_value("cfg", "min"),
)
eta = get_and_clamp_float(
request.args,
"eta",
get_config_value("eta"),
get_config_value("eta", "max"),
get_config_value("eta", "min"),
)
loopback = get_and_clamp_int(
request.args,
"loopback",
get_config_value("loopback"),
get_config_value("loopback", "max"),
get_config_value("loopback", "min"),
)
steps = get_and_clamp_int(
request.args,
"steps",
get_config_value("steps"),
get_config_value("steps", "max"),
get_config_value("steps", "min"),
)
height = get_and_clamp_int(
request.args,
"height",
get_config_value("height"),
get_config_value("height", "max"),
get_config_value("height", "min"),
)
width = get_and_clamp_int(
request.args,
"width",
get_config_value("width"),
get_config_value("width", "max"),
get_config_value("width", "min"),
)
tiled_vae = get_boolean(request.args, "tiledVAE", get_config_value("tiledVAE"))
tiles = get_and_clamp_int(
request.args,
"tiles",
get_config_value("tiles"),
get_config_value("tiles", "max"),
get_config_value("tiles", "min"),
)
overlap = get_and_clamp_float(
request.args,
"overlap",
get_config_value("overlap"),
get_config_value("overlap", "max"),
get_config_value("overlap", "min"),
)
stride = get_and_clamp_int(
request.args,
"stride",
get_config_value("stride"),
get_config_value("stride", "max"),
get_config_value("stride", "min"),
)
if stride > tiles:
logger.info("limiting stride to tile size, %s > %s", stride, tiles)
stride = tiles
seed = int(request.args.get("seed", -1))
if seed == -1:
# this one can safely use np.random because it produces a single value
seed = np.random.randint(np.iinfo(np.int32).max)
logger.info(
"request from %s: %s steps of %s using %s in %s on %s, %sx%s, %s, %s - %s",
user,