
fix(api): fully switch between LPW and regular ONNX pipelines

Sean Sube 2023-02-05 17:36:00 -06:00
parent f3983a7917
commit 5f35a2853b
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
6 changed files with 141 additions and 82 deletions
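The same fix is applied at every call site below: instead of aliasing the task method (pipe = pipe.img2img) and then funneling both cases through a single shared pipe(...) call, each branch now seeds and invokes its own pipeline. The two pipeline families cannot share one call because they disagree on generator types and argument order. A minimal sketch of the resulting pattern, with loading details elided and the exact onnx-web signatures treated as assumptions:

import numpy as np
import torch

def run_stage(pipe, params, prompt):
    # sketch only: params.lpw, params.seed, and params.steps mirror the
    # fields used in the hunks below; everything else is illustrative
    if params.lpw:
        # the lpw_stable_diffusion_onnx community pipeline is seeded with a
        # torch generator and called through a task method
        # (text2img, img2img, inpaint)
        rng = torch.manual_seed(params.seed)
        result = pipe.text2img(prompt, generator=rng, num_inference_steps=params.steps)
    else:
        # the stock ONNX pipelines expect a numpy RandomState generator
        rng = np.random.RandomState(params.seed)
        result = pipe(prompt, generator=rng, num_inference_steps=params.steps)
    return result.images[0]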

View File

@@ -35,11 +35,18 @@ def blend_img2img(
         params.lpw,
     )
     if params.lpw:
-        pipe = pipe.img2img
         rng = torch.manual_seed(params.seed)
+        result = pipe.img2img(
+            prompt,
+            generator=rng,
+            guidance_scale=params.cfg,
+            image=source_image,
+            negative_prompt=params.negative_prompt,
+            num_inference_steps=params.steps,
+            strength=strength,
+        )
     else:
         rng = np.random.RandomState(params.seed)
-
-    result = pipe(
-        prompt,
-        generator=rng,
+        result = pipe(
+            prompt,
+            generator=rng,
@@ -49,6 +56,7 @@ def blend_img2img(
-        num_inference_steps=params.steps,
-        strength=strength,
-    )
+            num_inference_steps=params.steps,
+            strength=strength,
+        )
+
     output = result.images[0]
     logger.info("final output image size: %sx%s", output.width, output.height)

View File

@@ -60,6 +60,7 @@ def blend_inpaint(
             save_image(server, "tile-source.png", image)
             save_image(server, "tile-mask.png", mask)

+        latents = get_latents_from_seed(params.seed, size)
         pipe = load_pipeline(
             OnnxStableDiffusionInpaintPipeline,
             params.model,
@@ -67,14 +68,23 @@ def blend_inpaint(
             job.get_device(),
             params.lpw,
         )
         if params.lpw:
-            pipe = pipe.inpaint
             rng = torch.manual_seed(params.seed)
+            result = pipe.inpaint(
+                params.prompt,
+                generator=rng,
+                guidance_scale=params.cfg,
+                height=size.height,
+                image=image,
+                latents=latents,
+                mask_image=mask,
+                negative_prompt=params.negative_prompt,
+                num_inference_steps=params.steps,
+                width=size.width,
+            )
         else:
             rng = np.random.RandomState(params.seed)
-
-        latents = get_latents_from_seed(params.seed, size)
-        result = pipe(
-            params.prompt,
-            generator=rng,
+            result = pipe(
+                params.prompt,
+                generator=rng,
@@ -87,6 +97,7 @@ def blend_inpaint(
-            num_inference_steps=params.steps,
-            width=size.width,
-        )
+                num_inference_steps=params.steps,
+                width=size.width,
+            )
+
         return result.images[0]

     output = process_tile_grid(source_image, SizeChart.auto, 1, [outpaint])
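Note that blend_inpaint now generates the latents once from the seed, before the pipeline is loaded, so both branches consume the same noise. A sketch of what a seed-to-latents helper typically looks like, assuming the standard Stable Diffusion latent layout of four channels at 1/8 resolution; the real get_latents_from_seed is an onnx-web helper and may differ:

import numpy as np

def latents_from_seed(seed: int, height: int, width: int) -> np.ndarray:
    # assumed layout: (batch, channels, height / 8, width / 8)
    shape = (1, 4, height // 8, width // 8)
    rng = np.random.default_rng(seed)
    return rng.standard_normal(shape).astype(np.float32)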

View File

@@ -32,17 +32,25 @@ def source_txt2img(
             "a source image was passed to a txt2img stage, but will be discarded"
         )

+    latents = get_latents_from_seed(params.seed, size)
     pipe = load_pipeline(
         OnnxStableDiffusionPipeline, params.model, params.scheduler, job.get_device(), params.lpw
     )
     if params.lpw:
-        pipe = pipe.text2img
         rng = torch.manual_seed(params.seed)
+        result = pipe.text2img(
+            prompt,
+            height=size.height,
+            width=size.width,
+            generator=rng,
+            guidance_scale=params.cfg,
+            latents=latents,
+            negative_prompt=params.negative_prompt,
+            num_inference_steps=params.steps,
+        )
     else:
         rng = np.random.RandomState(params.seed)
-
-    latents = get_latents_from_seed(params.seed, size)
-    result = pipe(
-        prompt,
-        height=size.height,
+        result = pipe(
+            prompt,
+            height=size.height,
@@ -53,6 +61,7 @@ def source_txt2img(
-        negative_prompt=params.negative_prompt,
-        num_inference_steps=params.steps,
-    )
+            negative_prompt=params.negative_prompt,
+            num_inference_steps=params.steps,
+        )
+
     output = result.images[0]
     logger.info("final output image size: %sx%s", output.width, output.height)

View File

@@ -66,6 +66,7 @@ def upscale_outpaint(
             save_image(server, "tile-source.png", image)
             save_image(server, "tile-mask.png", mask)

+        latents = get_tile_latents(full_latents, dims)
         pipe = load_pipeline(
             OnnxStableDiffusionInpaintPipeline,
             params.model,
@@ -73,14 +74,8 @@ def upscale_outpaint(
             job.get_device(),
         )
         if params.lpw:
-            pipe = pipe.inpaint
             rng = torch.manual_seed(params.seed)
-        else:
-            rng = np.random.RandomState(params.seed)
-
-        latents = get_tile_latents(full_latents, dims)
-        result = pipe(
+            result = pipe.inpaint(
                 image,
                 mask,
                 prompt,
@@ -92,6 +87,22 @@ def upscale_outpaint(
                 num_inference_steps=params.steps,
                 width=size.width,
             )
+        else:
+            rng = np.random.RandomState(params.seed)
+            result = pipe(
+                prompt,
+                image,
+                generator=rng,
+                guidance_scale=params.cfg,
+                height=size.height,
+                latents=latents,
+                mask_image=mask,
+                negative_prompt=params.negative_prompt,
+                num_inference_steps=params.steps,
+                width=size.width,
+            )

         # once part of the image has been drawn, keep it
         draw_mask.rectangle((left, top, left + tile, top + tile), fill="black")
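The outpaint stage shows most clearly why a single shared call could never serve both pipelines: the two inpaint entry points do not even agree on argument order. Side by side, with names taken from the hunk above and everything else assumed:

# LPW community pipeline: image and mask lead, prompt comes third
result = pipe.inpaint(image, mask, prompt, generator=rng, latents=latents)

# stock OnnxStableDiffusionInpaintPipeline: prompt leads, mask passed by keyword
result = pipe(prompt, image, mask_image=mask, generator=rng, latents=latents)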

View File

@@ -25,18 +25,27 @@ def run_txt2img_pipeline(
     output: str,
     upscale: UpscaleParams,
 ) -> None:
+    latents = get_latents_from_seed(params.seed, size)
     pipe = load_pipeline(
         OnnxStableDiffusionPipeline, params.model, params.scheduler, job.get_device(), params.lpw
     )
+    progress = job.get_progress_callback()
     if params.lpw:
-        pipe = pipe.text2img
         rng = torch.manual_seed(params.seed)
+        result = pipe.text2img(
+            params.prompt,
+            height=size.height,
+            width=size.width,
+            generator=rng,
+            guidance_scale=params.cfg,
+            latents=latents,
+            negative_prompt=params.negative_prompt,
+            num_inference_steps=params.steps,
+            callback=progress,
+        )
     else:
         rng = np.random.RandomState(params.seed)
-
-    latents = get_latents_from_seed(params.seed, size)
-    progress = job.get_progress_callback()
-    result = pipe(
-        params.prompt,
-        height=size.height,
+        result = pipe(
+            params.prompt,
+            height=size.height,
@@ -48,6 +57,8 @@ def run_txt2img_pipeline(
-        num_inference_steps=params.steps,
-        callback=progress,
-    )
+            num_inference_steps=params.steps,
+            callback=progress,
+        )
+
     image = result.images[0]
     image = run_upscale_correction(
         job, server, StageParams(), params, image, upscale=upscale
@@ -79,14 +90,21 @@ def run_img2img_pipeline(
         job.get_device(),
         params.lpw
     )
+    progress = job.get_progress_callback()
     if params.lpw:
-        pipe = pipe.img2img
        rng = torch.manual_seed(params.seed)
+        result = pipe.img2img(
+            source_image,
+            params.prompt,
+            generator=rng,
+            guidance_scale=params.cfg,
+            negative_prompt=params.negative_prompt,
+            num_inference_steps=params.steps,
+            strength=strength,
+            callback=progress,
+        )
     else:
         rng = np.random.RandomState(params.seed)
-
-    progress = job.get_progress_callback()
-    result = pipe(
-        source_image,
-        params.prompt,
+        result = pipe(
+            source_image,
+            params.prompt,
@@ -97,6 +115,8 @@ def run_img2img_pipeline(
-        strength=strength,
-        callback=progress,
-    )
+            strength=strength,
+            callback=progress,
+        )
+
     image = result.images[0]
     image = run_upscale_correction(
         job, server, StageParams(), params, image, upscale=upscale
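run.py also hoists progress = job.get_progress_callback() above the branch so both paths report steps identically. A minimal sketch of such a callback, assuming the (step, timestep, latents) signature that diffusers pipelines invoke during denoising; the queue plumbing behind job.get_progress_callback() is onnx-web internal:

def make_progress_callback(total_steps: int):
    # assumed signature: pipelines call back with the step index, the
    # scheduler timestep, and the current latents
    def on_step(step: int, timestep: int, latents) -> None:
        print("step %d/%d (timestep %d)" % (step + 1, total_steps, timestep))
    return on_step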

View File

@@ -143,8 +143,8 @@ correction_models = []
 upscaling_models = []


-def get_config_value(key: str, subkey: str = "default"):
-    return config_params.get(key).get(subkey)
+def get_config_value(key: str, subkey: str = "default", default = None):
+    return config_params.get(key, {}).get(subkey, default)


 def url_from_rule(rule) -> str:
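The hardened get_config_value closes a crash path: config_params.get(key) returns None for a missing key, and None.get(subkey) raises AttributeError. With the empty-dict fallback and a caller-supplied default, lookups degrade gracefully. A usage sketch (the keys and values here are assumptions, not the server's actual config):

config_params = {"width": {"default": 512, "min": 64}}

get_config_value("width")             # 512, from the "default" subkey
get_config_value("width", "min")      # 64
get_config_value("batch", default=1)  # 1, where the old version raised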
@@ -234,7 +234,7 @@ def pipeline_from_request() -> Tuple[DeviceParams, ImageParams, Size]:
     )
     params = ImageParams(
-        model_path, scheduler, prompt, negative_prompt, cfg, steps, seed, lpw=lpw
+        model_path, scheduler, prompt, cfg, steps, seed, lpw=lpw, negative_prompt=negative_prompt
     )
     size = Size(width, height)
     return (device, params, size)
@@ -330,7 +330,7 @@ def load_params(context: ServerContext):
     if "platform" in config_params and context.default_platform is not None:
         logger.info("overriding default platform to %s", context.default_platform)
-        config_platform = config_params.get("platform")
+        config_platform = config_params.get("platform", {})
         config_platform["default"] = context.default_platform