feat(api): make LPW an image parameter
This commit is contained in:
parent
6fe278c744
commit
fb376c6b62
|
@ -31,7 +31,10 @@ def blend_img2img(
|
|||
params.model,
|
||||
params.scheduler,
|
||||
job.get_device(),
|
||||
params.lpw,
|
||||
)
|
||||
if params.lpw:
|
||||
pipe = pipe.img2img
|
||||
|
||||
rng = torch.manual_seed(params.seed)
|
||||
|
||||
|
|
|
@ -64,7 +64,10 @@ def blend_inpaint(
|
|||
params.model,
|
||||
params.scheduler,
|
||||
job.get_device(),
|
||||
params.lpw,
|
||||
)
|
||||
if params.lpw:
|
||||
pipe = pipe.inpaint
|
||||
|
||||
latents = get_latents_from_seed(params.seed, size)
|
||||
rng = torch.manual_seed(params.seed)
|
||||
|
|
|
@ -32,8 +32,10 @@ def source_txt2img(
|
|||
)
|
||||
|
||||
pipe = load_pipeline(
|
||||
OnnxStableDiffusionPipeline, params.model, params.scheduler, job.get_device()
|
||||
OnnxStableDiffusionPipeline, params.model, params.scheduler, job.get_device(), params.lpw
|
||||
)
|
||||
if params.lpw:
|
||||
pipe = pipe.text2img
|
||||
|
||||
latents = get_latents_from_seed(params.seed, size)
|
||||
rng = torch.manual_seed(params.seed)
|
||||
|
|
|
@ -71,11 +71,13 @@ def upscale_outpaint(
|
|||
params.scheduler,
|
||||
job.get_device(),
|
||||
)
|
||||
if params.lpw:
|
||||
pipe = pipe.inpaint
|
||||
|
||||
latents = get_tile_latents(full_latents, dims)
|
||||
rng = torch.manual_seed(params.seed)
|
||||
|
||||
result = pipe.inpaint(
|
||||
result = pipe(
|
||||
image,
|
||||
mask,
|
||||
prompt,
|
||||
|
@ -96,7 +98,7 @@ def upscale_outpaint(
|
|||
margin_y = float(max(border.top, border.bottom))
|
||||
overlap = min(margin_x / source_image.width, margin_y / source_image.height)
|
||||
|
||||
if overlap > 0 and border.left == border.right and border.top == border.bottom:
|
||||
if border.left == border.right and border.top == border.bottom:
|
||||
logger.debug("outpainting with an even border, using spiral tiling")
|
||||
output = process_tile_spiral(source_image, SizeChart.auto, 1, [outpaint], overlap=overlap)
|
||||
else:
|
||||
|
|
|
@ -47,13 +47,13 @@ def get_tile_latents(
|
|||
|
||||
|
||||
def load_pipeline(
|
||||
pipeline: DiffusionPipeline, model: str, scheduler: Any, device: DeviceParams
|
||||
pipeline: DiffusionPipeline, model: str, scheduler: Any, device: DeviceParams, lpw: bool
|
||||
):
|
||||
global last_pipeline_instance
|
||||
global last_pipeline_scheduler
|
||||
global last_pipeline_options
|
||||
|
||||
options = (pipeline, model, device.provider)
|
||||
options = (pipeline, model, device.device, device.provider, lpw)
|
||||
if last_pipeline_instance is not None and last_pipeline_options == options:
|
||||
logger.debug("reusing existing diffusion pipeline")
|
||||
pipe = last_pipeline_instance
|
||||
|
@ -63,6 +63,11 @@ def load_pipeline(
|
|||
last_pipeline_scheduler = None
|
||||
run_gc()
|
||||
|
||||
if lpw:
|
||||
custom_pipeline = "./onnx_web/diffusion/lpw_stable_diffusion_onnx.py"
|
||||
else:
|
||||
custom_pipeline = None
|
||||
|
||||
logger.debug("loading new diffusion pipeline from %s", model)
|
||||
scheduler = scheduler.from_pretrained(
|
||||
model,
|
||||
|
@ -72,7 +77,7 @@ def load_pipeline(
|
|||
)
|
||||
pipe = pipeline.from_pretrained(
|
||||
model,
|
||||
custom_pipeline="./onnx_web/diffusion/lpw_stable_diffusion_onnx.py",
|
||||
custom_pipeline=custom_pipeline,
|
||||
provider=device.provider,
|
||||
provider_options=device.options,
|
||||
revision="onnx",
|
||||
|
|
|
@ -25,14 +25,16 @@ def run_txt2img_pipeline(
|
|||
upscale: UpscaleParams,
|
||||
) -> None:
|
||||
pipe = load_pipeline(
|
||||
OnnxStableDiffusionPipeline, params.model, params.scheduler, job.get_device()
|
||||
OnnxStableDiffusionPipeline, params.model, params.scheduler, job.get_device(), params.lpw
|
||||
)
|
||||
if params.lpw:
|
||||
pipe = pipe.text2img
|
||||
|
||||
latents = get_latents_from_seed(params.seed, size)
|
||||
rng = torch.manual_seed(params.seed)
|
||||
|
||||
progress = job.get_progress_callback()
|
||||
result = pipe.text2img(
|
||||
result = pipe(
|
||||
params.prompt,
|
||||
height=size.height,
|
||||
width=size.width,
|
||||
|
@ -72,12 +74,15 @@ def run_img2img_pipeline(
|
|||
params.model,
|
||||
params.scheduler,
|
||||
job.get_device(),
|
||||
params.lpw
|
||||
)
|
||||
if params.lpw:
|
||||
pipe = pipe.img2img
|
||||
|
||||
rng = torch.manual_seed(params.seed)
|
||||
|
||||
progress = job.get_progress_callback()
|
||||
result = pipe.img2img(
|
||||
result = pipe(
|
||||
source_image,
|
||||
params.prompt,
|
||||
generator=rng,
|
||||
|
|
|
@ -86,10 +86,11 @@ class ImageParams:
|
|||
model: str,
|
||||
scheduler: Any,
|
||||
prompt: str,
|
||||
negative_prompt: Optional[str],
|
||||
cfg: float,
|
||||
steps: int,
|
||||
seed: int,
|
||||
negative_prompt: Optional[str] = None,
|
||||
lpw: Optional[bool] = False,
|
||||
) -> None:
|
||||
self.model = model
|
||||
self.scheduler = scheduler
|
||||
|
@ -98,6 +99,7 @@ class ImageParams:
|
|||
self.cfg = cfg
|
||||
self.seed = seed
|
||||
self.steps = steps
|
||||
self.lpw = lpw or False
|
||||
|
||||
def tojson(self) -> Dict[str, Optional[Param]]:
|
||||
return {
|
||||
|
@ -108,6 +110,7 @@ class ImageParams:
|
|||
"cfg": self.cfg,
|
||||
"seed": self.seed,
|
||||
"steps": self.steps,
|
||||
"lpw": self.lpw,
|
||||
}
|
||||
|
||||
|
||||
|
|
|
@ -171,6 +171,7 @@ def pipeline_from_request() -> Tuple[DeviceParams, ImageParams, Size]:
|
|||
device = available_platforms[0]
|
||||
|
||||
# pipeline stuff
|
||||
lpw = get_not_empty(request.args, "lpw", "false") == "true"
|
||||
model = get_not_empty(request.args, "model", get_config_value("model"))
|
||||
model_path = get_model_path(model)
|
||||
scheduler = get_from_map(
|
||||
|
@ -233,7 +234,7 @@ def pipeline_from_request() -> Tuple[DeviceParams, ImageParams, Size]:
|
|||
)
|
||||
|
||||
params = ImageParams(
|
||||
model_path, scheduler, prompt, negative_prompt, cfg, steps, seed
|
||||
model_path, scheduler, prompt, negative_prompt, cfg, steps, seed, lpw=lpw
|
||||
)
|
||||
size = Size(width, height)
|
||||
return (device, params, size)
|
||||
|
|
Loading…
Reference in New Issue