1
0
Fork 0
onnx-web/api/onnx_web/diffusers/run.py

378 lines
8.8 KiB
Python
Raw Normal View History

2023-01-28 23:09:19 +00:00
from logging import getLogger
from typing import Any, List, Optional
2023-01-16 00:46:00 +00:00
2023-02-16 03:01:25 +00:00
from PIL import Image
2023-02-05 13:53:26 +00:00
from ..chain import (
blend_img2img,
blend_mask,
source_txt2img,
upscale_highres,
upscale_outpaint,
)
from ..chain.base import ChainPipeline
from ..output import save_image
2023-04-01 17:06:31 +00:00
from ..params import (
Border,
HighresParams,
ImageParams,
Size,
StageParams,
UpscaleParams,
)
2023-02-26 05:49:39 +00:00
from ..server import ServerContext
from ..server.load import get_source_filters
from ..utils import run_gc, show_system_toast
2023-02-26 20:15:30 +00:00
from ..worker import WorkerContext
2023-07-01 02:42:24 +00:00
from .upscale import split_upscale, stage_upscale_correction
from .utils import parse_prompt
2023-01-28 23:09:19 +00:00
# module-level logger used by each of the pipeline runners in this module
logger = getLogger(__name__)
def run_txt2img_pipeline(
    job: WorkerContext,
    server: ServerContext,
    params: ImageParams,
    size: Size,
    outputs: List[str],
    upscale: UpscaleParams,
    highres: HighresParams,
) -> None:
    """
    Generate an image from the prompt in ``params`` and save it to ``outputs[0]``.

    Stage order: txt2img source at ``size``, the pre-highres share of
    upscaling/correction (only when ``split_upscale`` returns one),
    ``highres.iterations`` highres passes, then the post-highres share of
    upscaling/correction. The saved image carries upscale, highres, LoRA and
    textual-inversion metadata.
    """
    pipeline = ChainPipeline()
    base_stage = StageParams()
    pipeline.stage(source_txt2img, base_stage, size=size)

    # divide the upscale settings into the work done before and after highres
    pre_highres, post_highres = split_upscale(upscale)
    if pre_highres:
        stage_upscale_correction(
            base_stage, params, upscale=pre_highres, chain=pipeline
        )

    # one highres stage per iteration, each with a fresh StageParams whose
    # outscale is the highres scale
    for _ in range(highres.iterations):
        pipeline.stage(
            upscale_highres,
            StageParams(outscale=highres.scale),
            highres=highres,
            upscale=upscale,
            overlap=params.overlap,
        )

    stage_upscale_correction(
        base_stage, params, upscale=post_highres, chain=pipeline
    )

    # txt2img has no input image, so the chain starts from None
    result = pipeline(
        job, server, params, None, callback=job.get_progress_callback()
    )

    _, loras, inversions = parse_prompt(params)
    output_path = save_image(
        server,
        outputs[0],
        result,
        params,
        size,
        upscale=upscale,
        highres=highres,
        inversions=inversions,
        loras=loras,
    )

    run_gc([job.get_device()])

    show_system_toast(f"finished txt2img job: {output_path}")
    logger.info("finished txt2img job: %s", output_path)
def run_img2img_pipeline(
    job: WorkerContext,
    server: ServerContext,
    params: ImageParams,
    outputs: List[str],
    upscale: UpscaleParams,
    highres: HighresParams,
    source: Image.Image,
    strength: float,
    source_filter: Optional[str] = None,
) -> None:
    """
    Run img2img on ``source`` and save the result.

    Stage order: one ``blend_img2img`` pass at ``strength``, the pre-highres
    share of upscaling/correction (when ``split_upscale`` returns one),
    ``params.loopback`` further img2img passes, ``highres.iterations``
    highres passes, then the post-highres share of upscaling/correction.

    If ``source_filter`` names a known filter, it is applied to ``source``
    before the chain runs. When ``source_filter`` is set and is not
    ``"none"``, the (filtered) source is also saved. Images are paired with
    ``outputs`` by position, so that second image is only written when a
    second output name is supplied.
    """
    # run filter on the source image; an unknown filter name leaves it as-is
    if source_filter is not None:
        f = get_source_filters().get(source_filter, None)
        if f is not None:
            logger.debug("running source filter: %s", f.__name__)
            source = f(server, source)

    # prepare the chain pipeline and first stage
    chain = ChainPipeline()
    stage = StageParams()
    chain.stage(
        blend_img2img,
        stage,
        strength=strength,
    )

    # apply upscaling and correction, before highres
    first_upscale, after_upscale = split_upscale(upscale)
    if first_upscale:
        stage_upscale_correction(
            stage,
            params,
            upscale=first_upscale,
            chain=chain,
        )

    # loopback through multiple img2img iterations (range() is empty for <= 0)
    for _i in range(params.loopback):
        chain.stage(
            blend_img2img,
            stage,
            strength=strength,
        )

    # highres, if selected. Each pass gets its own StageParams with
    # outscale=highres.scale and the tile overlap, matching
    # run_txt2img_pipeline; the shared base stage did not declare that scale.
    for _i in range(highres.iterations):
        chain.stage(
            upscale_highres,
            StageParams(
                outscale=highres.scale,
            ),
            highres=highres,
            upscale=upscale,
            overlap=params.overlap,
        )

    # apply upscaling and correction, after highres
    stage_upscale_correction(
        stage,
        params,
        upscale=after_upscale,
        chain=chain,
    )

    # run and append the filtered source
    progress = job.get_progress_callback()
    images = [
        chain(job, server, params, source, callback=progress),
    ]

    if source_filter is not None and source_filter != "none":
        images.append(source)

    # save with metadata; size records the (possibly filtered) source size
    _prompt_pairs, loras, inversions = parse_prompt(params)
    size = Size(*source.size)

    for image, output in zip(images, outputs):
        dest = save_image(
            server,
            output,
            image,
            params,
            size,
            upscale=upscale,
            highres=highres,
            inversions=inversions,
            loras=loras,
        )

    # clean up
    run_gc([job.get_device()])

    # notify the user (reports the last saved output)
    show_system_toast(f"finished img2img job: {dest}")
    logger.info("finished img2img job: %s", dest)
def run_inpaint_pipeline(
    job: WorkerContext,
    server: ServerContext,
    params: ImageParams,
    size: Size,
    outputs: List[str],
    upscale: UpscaleParams,
    highres: HighresParams,
    source: Image.Image,
    mask: Image.Image,
    border: Border,
    noise_source: Any,
    mask_filter: Any,
    fill_color: str,
    tile_order: str,
) -> None:
    """
    Inpaint/outpaint ``source`` and save the result to ``outputs[0]``.

    Stage order: ``upscale_outpaint`` with the given border, mask, fill
    color, mask filter and noise source, then ``highres.iterations`` highres
    passes, then upscaling/correction. All stages use ``tile_order``. The
    saved image records upscale, highres, border, LoRA and inversion
    metadata.
    """
    logger.debug("building inpaint pipeline")

    # set up the chain pipeline and base stage
    chain = ChainPipeline()
    stage = StageParams(tile_order=tile_order)
    chain.stage(
        upscale_outpaint,
        stage,
        border=border,
        stage_mask=mask,
        fill_color=fill_color,
        mask_filter=mask_filter,
        noise_source=noise_source,
    )

    # apply highres once per iteration, as the txt2img/img2img pipelines do.
    # Each pass keeps the tile order but declares outscale=highres.scale and
    # the tile overlap; the shared base stage did not declare that scale.
    for _i in range(highres.iterations):
        chain.stage(
            upscale_highres,
            StageParams(
                tile_order=tile_order,
                outscale=highres.scale,
            ),
            highres=highres,
            upscale=upscale,
            overlap=params.overlap,
        )

    # apply upscaling and correction
    stage_upscale_correction(
        stage,
        params,
        upscale=upscale,
        chain=chain,
    )

    # run and save
    progress = job.get_progress_callback()
    image = chain(job, server, params, source, callback=progress)

    _prompt_pairs, loras, inversions = parse_prompt(params)
    dest = save_image(
        server,
        outputs[0],
        image,
        params,
        size,
        upscale=upscale,
        highres=highres,
        border=border,
        inversions=inversions,
        loras=loras,
    )

    # clean up
    del image
    run_gc([job.get_device()])

    # notify the user
    show_system_toast(f"finished inpaint job: {dest}")
    logger.info("finished inpaint job: %s", dest)
def run_upscale_pipeline(
    job: WorkerContext,
    server: ServerContext,
    params: ImageParams,
    size: Size,
    outputs: List[str],
    upscale: UpscaleParams,
    highres: HighresParams,
    source: Image.Image,
) -> None:
    """
    Upscale/correct ``source`` and save the result to ``outputs[0]``.

    There is no generation base stage. The chain runs the pre-highres share
    of upscaling/correction (when ``split_upscale`` returns one), then
    ``highres.iterations`` highres passes, then the post-highres share.
    """
    # set up the chain pipeline, no base stage for upscaling
    chain = ChainPipeline()
    stage = StageParams()

    # apply upscaling and correction, before highres
    first_upscale, after_upscale = split_upscale(upscale)
    if first_upscale:
        stage_upscale_correction(
            stage,
            params,
            upscale=first_upscale,
            chain=chain,
        )

    # apply highres once per iteration, as the txt2img/img2img pipelines do.
    # Each pass declares outscale=highres.scale and the tile overlap; the
    # shared base stage did not declare that scale.
    for _i in range(highres.iterations):
        chain.stage(
            upscale_highres,
            StageParams(
                outscale=highres.scale,
            ),
            highres=highres,
            upscale=upscale,
            overlap=params.overlap,
        )

    # apply upscaling and correction, after highres
    stage_upscale_correction(
        stage,
        params,
        upscale=after_upscale,
        chain=chain,
    )

    # run and save
    progress = job.get_progress_callback()
    image = chain(job, server, params, source, callback=progress)

    _prompt_pairs, loras, inversions = parse_prompt(params)
    dest = save_image(
        server,
        outputs[0],
        image,
        params,
        size,
        upscale=upscale,
        highres=highres,
        inversions=inversions,
        loras=loras,
    )

    # clean up
    del image
    run_gc([job.get_device()])

    # notify the user
    show_system_toast(f"finished upscale job: {dest}")
    logger.info("finished upscale job: %s", dest)
def run_blend_pipeline(
    job: WorkerContext,
    server: ServerContext,
    params: ImageParams,
    size: Size,
    outputs: List[str],
    upscale: UpscaleParams,
    # highres: HighresParams,
    sources: List[Image.Image],
    mask: Image.Image,
) -> None:
    """
    Blend source images using ``mask``, apply upscaling/correction, and save
    the result to ``outputs[0]``.

    The chain starts from ``sources[0]``. The mask and ``sources[1]`` are
    handed to the ``blend_mask`` stage.
    """
    # set up the chain pipeline and base stage. Stages are registered on the
    # chain (previously this called append() on the StageParams object, so
    # blend_mask was never added to the chain), and the mask and second
    # source are now passed to the stage instead of going unused.
    # NOTE(review): assumes blend_mask reads `stage_source`/`stage_mask` and
    # that callers always supply at least two sources — confirm.
    chain = ChainPipeline()
    stage = StageParams()
    chain.stage(
        blend_mask,
        stage,
        stage_source=sources[1],
        stage_mask=mask,
    )

    # apply upscaling and correction
    stage_upscale_correction(
        stage,
        params,
        upscale=upscale,
        chain=chain,
    )

    # run and save
    progress = job.get_progress_callback()
    image = chain(job, server, params, sources[0], callback=progress)
    dest = save_image(server, outputs[0], image, params, size, upscale=upscale)

    # clean up
    del image
    run_gc([job.get_device()])

    # notify the user
    show_system_toast(f"finished blend job: {dest}")
    logger.info("finished blend job: %s", dest)