fix(api): fully switch between LPW and regular ONNX pipelines
This commit is contained in:
parent
f3983a7917
commit
5f35a2853b
|
@ -35,20 +35,28 @@ def blend_img2img(
|
|||
params.lpw,
|
||||
)
|
||||
if params.lpw:
|
||||
pipe = pipe.img2img
|
||||
rng = torch.manual_seed(params.seed)
|
||||
result = pipe.img2img(
|
||||
prompt,
|
||||
generator=rng,
|
||||
guidance_scale=params.cfg,
|
||||
image=source_image,
|
||||
negative_prompt=params.negative_prompt,
|
||||
num_inference_steps=params.steps,
|
||||
strength=strength,
|
||||
)
|
||||
else:
|
||||
rng = np.random.RandomState(params.seed)
|
||||
result = pipe(
|
||||
prompt,
|
||||
generator=rng,
|
||||
guidance_scale=params.cfg,
|
||||
image=source_image,
|
||||
negative_prompt=params.negative_prompt,
|
||||
num_inference_steps=params.steps,
|
||||
strength=strength,
|
||||
)
|
||||
|
||||
result = pipe(
|
||||
prompt,
|
||||
generator=rng,
|
||||
guidance_scale=params.cfg,
|
||||
image=source_image,
|
||||
negative_prompt=params.negative_prompt,
|
||||
num_inference_steps=params.steps,
|
||||
strength=strength,
|
||||
)
|
||||
output = result.images[0]
|
||||
|
||||
logger.info("final output image size: %sx%s", output.width, output.height)
|
||||
|
|
|
@ -60,6 +60,7 @@ def blend_inpaint(
|
|||
save_image(server, "tile-source.png", image)
|
||||
save_image(server, "tile-mask.png", mask)
|
||||
|
||||
latents = get_latents_from_seed(params.seed, size)
|
||||
pipe = load_pipeline(
|
||||
OnnxStableDiffusionInpaintPipeline,
|
||||
params.model,
|
||||
|
@ -67,26 +68,36 @@ def blend_inpaint(
|
|||
job.get_device(),
|
||||
params.lpw,
|
||||
)
|
||||
|
||||
if params.lpw:
|
||||
pipe = pipe.inpaint
|
||||
rng = torch.manual_seed(params.seed)
|
||||
result = pipe.inpaint(
|
||||
params.prompt,
|
||||
generator=rng,
|
||||
guidance_scale=params.cfg,
|
||||
height=size.height,
|
||||
image=image,
|
||||
latents=latents,
|
||||
mask_image=mask,
|
||||
negative_prompt=params.negative_prompt,
|
||||
num_inference_steps=params.steps,
|
||||
width=size.width,
|
||||
)
|
||||
else:
|
||||
rng = np.random.RandomState(params.seed)
|
||||
result = pipe(
|
||||
params.prompt,
|
||||
generator=rng,
|
||||
guidance_scale=params.cfg,
|
||||
height=size.height,
|
||||
image=image,
|
||||
latents=latents,
|
||||
mask_image=mask,
|
||||
negative_prompt=params.negative_prompt,
|
||||
num_inference_steps=params.steps,
|
||||
width=size.width,
|
||||
)
|
||||
|
||||
latents = get_latents_from_seed(params.seed, size)
|
||||
|
||||
result = pipe(
|
||||
params.prompt,
|
||||
generator=rng,
|
||||
guidance_scale=params.cfg,
|
||||
height=size.height,
|
||||
image=image,
|
||||
latents=latents,
|
||||
mask_image=mask,
|
||||
negative_prompt=params.negative_prompt,
|
||||
num_inference_steps=params.steps,
|
||||
width=size.width,
|
||||
)
|
||||
return result.images[0]
|
||||
|
||||
output = process_tile_grid(source_image, SizeChart.auto, 1, [outpaint])
|
||||
|
|
|
@ -32,27 +32,36 @@ def source_txt2img(
|
|||
"a source image was passed to a txt2img stage, but will be discarded"
|
||||
)
|
||||
|
||||
latents = get_latents_from_seed(params.seed, size)
|
||||
pipe = load_pipeline(
|
||||
OnnxStableDiffusionPipeline, params.model, params.scheduler, job.get_device(), params.lpw
|
||||
)
|
||||
|
||||
if params.lpw:
|
||||
pipe = pipe.text2img
|
||||
rng = torch.manual_seed(params.seed)
|
||||
result = pipe.text2img(
|
||||
prompt,
|
||||
height=size.height,
|
||||
width=size.width,
|
||||
generator=rng,
|
||||
guidance_scale=params.cfg,
|
||||
latents=latents,
|
||||
negative_prompt=params.negative_prompt,
|
||||
num_inference_steps=params.steps,
|
||||
)
|
||||
else:
|
||||
rng = np.random.RandomState(params.seed)
|
||||
result = pipe(
|
||||
prompt,
|
||||
height=size.height,
|
||||
width=size.width,
|
||||
generator=rng,
|
||||
guidance_scale=params.cfg,
|
||||
latents=latents,
|
||||
negative_prompt=params.negative_prompt,
|
||||
num_inference_steps=params.steps,
|
||||
)
|
||||
|
||||
latents = get_latents_from_seed(params.seed, size)
|
||||
|
||||
result = pipe(
|
||||
prompt,
|
||||
height=size.height,
|
||||
width=size.width,
|
||||
generator=rng,
|
||||
guidance_scale=params.cfg,
|
||||
latents=latents,
|
||||
negative_prompt=params.negative_prompt,
|
||||
num_inference_steps=params.steps,
|
||||
)
|
||||
output = result.images[0]
|
||||
|
||||
logger.info("final output image size: %sx%s", output.width, output.height)
|
||||
|
|
|
@ -66,6 +66,7 @@ def upscale_outpaint(
|
|||
save_image(server, "tile-source.png", image)
|
||||
save_image(server, "tile-mask.png", mask)
|
||||
|
||||
latents = get_tile_latents(full_latents, dims)
|
||||
pipe = load_pipeline(
|
||||
OnnxStableDiffusionInpaintPipeline,
|
||||
params.model,
|
||||
|
@ -73,25 +74,35 @@ def upscale_outpaint(
|
|||
job.get_device(),
|
||||
)
|
||||
if params.lpw:
|
||||
pipe = pipe.inpaint
|
||||
rng = torch.manual_seed(params.seed)
|
||||
result = pipe.inpaint(
|
||||
image,
|
||||
mask,
|
||||
prompt,
|
||||
generator=rng,
|
||||
guidance_scale=params.cfg,
|
||||
height=size.height,
|
||||
latents=latents,
|
||||
negative_prompt=params.negative_prompt,
|
||||
num_inference_steps=params.steps,
|
||||
width=size.width,
|
||||
)
|
||||
else:
|
||||
rng = np.random.RandomState(params.seed)
|
||||
result = pipe(
|
||||
prompt,
|
||||
image,
|
||||
generator=rng,
|
||||
guidance_scale=params.cfg,
|
||||
height=size.height,
|
||||
latents=latents,
|
||||
mask_image=mask,
|
||||
negative_prompt=params.negative_prompt,
|
||||
num_inference_steps=params.steps,
|
||||
width=size.width,
|
||||
|
||||
latents = get_tile_latents(full_latents, dims)
|
||||
)
|
||||
|
||||
result = pipe(
|
||||
image,
|
||||
mask,
|
||||
prompt,
|
||||
generator=rng,
|
||||
guidance_scale=params.cfg,
|
||||
height=size.height,
|
||||
latents=latents,
|
||||
negative_prompt=params.negative_prompt,
|
||||
num_inference_steps=params.steps,
|
||||
width=size.width,
|
||||
)
|
||||
|
||||
# once part of the image has been drawn, keep it
|
||||
draw_mask.rectangle((left, top, left + tile, top + tile), fill="black")
|
||||
|
|
|
@ -25,29 +25,40 @@ def run_txt2img_pipeline(
|
|||
output: str,
|
||||
upscale: UpscaleParams,
|
||||
) -> None:
|
||||
latents = get_latents_from_seed(params.seed, size)
|
||||
pipe = load_pipeline(
|
||||
OnnxStableDiffusionPipeline, params.model, params.scheduler, job.get_device(), params.lpw
|
||||
)
|
||||
progress = job.get_progress_callback()
|
||||
|
||||
if params.lpw:
|
||||
pipe = pipe.text2img
|
||||
rng = torch.manual_seed(params.seed)
|
||||
result = pipe.text2img(
|
||||
params.prompt,
|
||||
height=size.height,
|
||||
width=size.width,
|
||||
generator=rng,
|
||||
guidance_scale=params.cfg,
|
||||
latents=latents,
|
||||
negative_prompt=params.negative_prompt,
|
||||
num_inference_steps=params.steps,
|
||||
callback=progress,
|
||||
)
|
||||
else:
|
||||
rng = np.random.RandomState(params.seed)
|
||||
result = pipe(
|
||||
params.prompt,
|
||||
height=size.height,
|
||||
width=size.width,
|
||||
generator=rng,
|
||||
guidance_scale=params.cfg,
|
||||
latents=latents,
|
||||
negative_prompt=params.negative_prompt,
|
||||
num_inference_steps=params.steps,
|
||||
callback=progress,
|
||||
)
|
||||
|
||||
latents = get_latents_from_seed(params.seed, size)
|
||||
|
||||
progress = job.get_progress_callback()
|
||||
result = pipe(
|
||||
params.prompt,
|
||||
height=size.height,
|
||||
width=size.width,
|
||||
generator=rng,
|
||||
guidance_scale=params.cfg,
|
||||
latents=latents,
|
||||
negative_prompt=params.negative_prompt,
|
||||
num_inference_steps=params.steps,
|
||||
callback=progress,
|
||||
)
|
||||
image = result.images[0]
|
||||
image = run_upscale_correction(
|
||||
job, server, StageParams(), params, image, upscale=upscale
|
||||
|
@ -79,24 +90,33 @@ def run_img2img_pipeline(
|
|||
job.get_device(),
|
||||
params.lpw
|
||||
)
|
||||
progress = job.get_progress_callback()
|
||||
if params.lpw:
|
||||
pipe = pipe.img2img
|
||||
rng = torch.manual_seed(params.seed)
|
||||
result = pipe.img2img(
|
||||
source_image,
|
||||
params.prompt,
|
||||
generator=rng,
|
||||
guidance_scale=params.cfg,
|
||||
negative_prompt=params.negative_prompt,
|
||||
num_inference_steps=params.steps,
|
||||
strength=strength,
|
||||
callback=progress,
|
||||
)
|
||||
else:
|
||||
rng = np.random.RandomState(params.seed)
|
||||
result = pipe(
|
||||
source_image,
|
||||
params.prompt,
|
||||
generator=rng,
|
||||
guidance_scale=params.cfg,
|
||||
negative_prompt=params.negative_prompt,
|
||||
num_inference_steps=params.steps,
|
||||
strength=strength,
|
||||
callback=progress,
|
||||
)
|
||||
|
||||
|
||||
progress = job.get_progress_callback()
|
||||
result = pipe(
|
||||
source_image,
|
||||
params.prompt,
|
||||
generator=rng,
|
||||
guidance_scale=params.cfg,
|
||||
negative_prompt=params.negative_prompt,
|
||||
num_inference_steps=params.steps,
|
||||
strength=strength,
|
||||
callback=progress,
|
||||
)
|
||||
image = result.images[0]
|
||||
image = run_upscale_correction(
|
||||
job, server, StageParams(), params, image, upscale=upscale
|
||||
|
|
|
@ -143,8 +143,8 @@ correction_models = []
|
|||
upscaling_models = []
|
||||
|
||||
|
||||
def get_config_value(key: str, subkey: str = "default"):
|
||||
return config_params.get(key).get(subkey)
|
||||
def get_config_value(key: str, subkey: str = "default", default = None):
|
||||
return config_params.get(key, {}).get(subkey, default)
|
||||
|
||||
|
||||
def url_from_rule(rule) -> str:
|
||||
|
@ -234,7 +234,7 @@ def pipeline_from_request() -> Tuple[DeviceParams, ImageParams, Size]:
|
|||
)
|
||||
|
||||
params = ImageParams(
|
||||
model_path, scheduler, prompt, negative_prompt, cfg, steps, seed, lpw=lpw
|
||||
model_path, scheduler, prompt, cfg, steps, seed, lpw=lpw, negative_prompt=negative_prompt
|
||||
)
|
||||
size = Size(width, height)
|
||||
return (device, params, size)
|
||||
|
@ -330,7 +330,7 @@ def load_params(context: ServerContext):
|
|||
|
||||
if "platform" in config_params and context.default_platform is not None:
|
||||
logger.info("overriding default platform to %s", context.default_platform)
|
||||
config_platform = config_params.get("platform")
|
||||
config_platform = config_params.get("platform", {})
|
||||
config_platform["default"] = context.default_platform
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue