1
0
Fork 0

feat(api): make LPW an image parameter

This commit is contained in:
Sean Sube 2023-02-05 17:15:37 -06:00
parent 6fe278c744
commit fb376c6b62
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
8 changed files with 35 additions and 11 deletions

View File

@ -31,7 +31,10 @@ def blend_img2img(
params.model, params.model,
params.scheduler, params.scheduler,
job.get_device(), job.get_device(),
params.lpw,
) )
if params.lpw:
pipe = pipe.img2img
rng = torch.manual_seed(params.seed) rng = torch.manual_seed(params.seed)

View File

@ -64,7 +64,10 @@ def blend_inpaint(
params.model, params.model,
params.scheduler, params.scheduler,
job.get_device(), job.get_device(),
params.lpw,
) )
if params.lpw:
pipe = pipe.inpaint
latents = get_latents_from_seed(params.seed, size) latents = get_latents_from_seed(params.seed, size)
rng = torch.manual_seed(params.seed) rng = torch.manual_seed(params.seed)

View File

@ -32,8 +32,10 @@ def source_txt2img(
) )
pipe = load_pipeline( pipe = load_pipeline(
OnnxStableDiffusionPipeline, params.model, params.scheduler, job.get_device() OnnxStableDiffusionPipeline, params.model, params.scheduler, job.get_device(), params.lpw
) )
if params.lpw:
pipe = pipe.text2img
latents = get_latents_from_seed(params.seed, size) latents = get_latents_from_seed(params.seed, size)
rng = torch.manual_seed(params.seed) rng = torch.manual_seed(params.seed)

View File

@ -71,11 +71,13 @@ def upscale_outpaint(
params.scheduler, params.scheduler,
job.get_device(), job.get_device(),
) )
if params.lpw:
pipe = pipe.inpaint
latents = get_tile_latents(full_latents, dims) latents = get_tile_latents(full_latents, dims)
rng = torch.manual_seed(params.seed) rng = torch.manual_seed(params.seed)
result = pipe.inpaint( result = pipe(
image, image,
mask, mask,
prompt, prompt,
@ -96,7 +98,7 @@ def upscale_outpaint(
margin_y = float(max(border.top, border.bottom)) margin_y = float(max(border.top, border.bottom))
overlap = min(margin_x / source_image.width, margin_y / source_image.height) overlap = min(margin_x / source_image.width, margin_y / source_image.height)
if overlap > 0 and border.left == border.right and border.top == border.bottom: if border.left == border.right and border.top == border.bottom:
logger.debug("outpainting with an even border, using spiral tiling") logger.debug("outpainting with an even border, using spiral tiling")
output = process_tile_spiral(source_image, SizeChart.auto, 1, [outpaint], overlap=overlap) output = process_tile_spiral(source_image, SizeChart.auto, 1, [outpaint], overlap=overlap)
else: else:

View File

@ -47,13 +47,13 @@ def get_tile_latents(
def load_pipeline( def load_pipeline(
pipeline: DiffusionPipeline, model: str, scheduler: Any, device: DeviceParams pipeline: DiffusionPipeline, model: str, scheduler: Any, device: DeviceParams, lpw: bool
): ):
global last_pipeline_instance global last_pipeline_instance
global last_pipeline_scheduler global last_pipeline_scheduler
global last_pipeline_options global last_pipeline_options
options = (pipeline, model, device.provider) options = (pipeline, model, device.device, device.provider, lpw)
if last_pipeline_instance is not None and last_pipeline_options == options: if last_pipeline_instance is not None and last_pipeline_options == options:
logger.debug("reusing existing diffusion pipeline") logger.debug("reusing existing diffusion pipeline")
pipe = last_pipeline_instance pipe = last_pipeline_instance
@ -63,6 +63,11 @@ def load_pipeline(
last_pipeline_scheduler = None last_pipeline_scheduler = None
run_gc() run_gc()
if lpw:
custom_pipeline = "./onnx_web/diffusion/lpw_stable_diffusion_onnx.py"
else:
custom_pipeline = None
logger.debug("loading new diffusion pipeline from %s", model) logger.debug("loading new diffusion pipeline from %s", model)
scheduler = scheduler.from_pretrained( scheduler = scheduler.from_pretrained(
model, model,
@ -72,7 +77,7 @@ def load_pipeline(
) )
pipe = pipeline.from_pretrained( pipe = pipeline.from_pretrained(
model, model,
custom_pipeline="./onnx_web/diffusion/lpw_stable_diffusion_onnx.py", custom_pipeline=custom_pipeline,
provider=device.provider, provider=device.provider,
provider_options=device.options, provider_options=device.options,
revision="onnx", revision="onnx",

View File

@ -25,14 +25,16 @@ def run_txt2img_pipeline(
upscale: UpscaleParams, upscale: UpscaleParams,
) -> None: ) -> None:
pipe = load_pipeline( pipe = load_pipeline(
OnnxStableDiffusionPipeline, params.model, params.scheduler, job.get_device() OnnxStableDiffusionPipeline, params.model, params.scheduler, job.get_device(), params.lpw
) )
if params.lpw:
pipe = pipe.text2img
latents = get_latents_from_seed(params.seed, size) latents = get_latents_from_seed(params.seed, size)
rng = torch.manual_seed(params.seed) rng = torch.manual_seed(params.seed)
progress = job.get_progress_callback() progress = job.get_progress_callback()
result = pipe.text2img( result = pipe(
params.prompt, params.prompt,
height=size.height, height=size.height,
width=size.width, width=size.width,
@ -72,12 +74,15 @@ def run_img2img_pipeline(
params.model, params.model,
params.scheduler, params.scheduler,
job.get_device(), job.get_device(),
params.lpw
) )
if params.lpw:
pipe = pipe.img2img
rng = torch.manual_seed(params.seed) rng = torch.manual_seed(params.seed)
progress = job.get_progress_callback() progress = job.get_progress_callback()
result = pipe.img2img( result = pipe(
source_image, source_image,
params.prompt, params.prompt,
generator=rng, generator=rng,

View File

@ -86,10 +86,11 @@ class ImageParams:
model: str, model: str,
scheduler: Any, scheduler: Any,
prompt: str, prompt: str,
negative_prompt: Optional[str],
cfg: float, cfg: float,
steps: int, steps: int,
seed: int, seed: int,
negative_prompt: Optional[str] = None,
lpw: Optional[bool] = False,
) -> None: ) -> None:
self.model = model self.model = model
self.scheduler = scheduler self.scheduler = scheduler
@ -98,6 +99,7 @@ class ImageParams:
self.cfg = cfg self.cfg = cfg
self.seed = seed self.seed = seed
self.steps = steps self.steps = steps
self.lpw = lpw or False
def tojson(self) -> Dict[str, Optional[Param]]: def tojson(self) -> Dict[str, Optional[Param]]:
return { return {
@ -108,6 +110,7 @@ class ImageParams:
"cfg": self.cfg, "cfg": self.cfg,
"seed": self.seed, "seed": self.seed,
"steps": self.steps, "steps": self.steps,
"lpw": self.lpw,
} }

View File

@ -171,6 +171,7 @@ def pipeline_from_request() -> Tuple[DeviceParams, ImageParams, Size]:
device = available_platforms[0] device = available_platforms[0]
# pipeline stuff # pipeline stuff
lpw = get_not_empty(request.args, "lpw", "false") == "true"
model = get_not_empty(request.args, "model", get_config_value("model")) model = get_not_empty(request.args, "model", get_config_value("model"))
model_path = get_model_path(model) model_path = get_model_path(model)
scheduler = get_from_map( scheduler = get_from_map(
@ -233,7 +234,7 @@ def pipeline_from_request() -> Tuple[DeviceParams, ImageParams, Size]:
) )
params = ImageParams( params = ImageParams(
model_path, scheduler, prompt, negative_prompt, cfg, steps, seed model_path, scheduler, prompt, negative_prompt, cfg, steps, seed, lpw=lpw
) )
size = Size(width, height) size = Size(width, height)
return (device, params, size) return (device, params, size)