1
0
Fork 0
onnx-web/api/onnx_web/diffusers/run.py

394 lines
9.3 KiB
Python
Raw Normal View History

2023-01-28 23:09:19 +00:00
from logging import getLogger
from typing import Any, List, Optional
2023-01-16 00:46:00 +00:00
2023-02-16 03:01:25 +00:00
from PIL import Image
2023-02-05 13:53:26 +00:00
from onnx_web.chain.highres import stage_highres
from ..chain import (
BlendImg2ImgStage,
BlendMaskStage,
ChainPipeline,
SourceTxt2ImgStage,
UpscaleOutpaintStage,
)
from ..chain.upscale import split_upscale, stage_upscale_correction
from ..output import save_image
2023-04-01 17:06:31 +00:00
from ..params import (
Border,
HighresParams,
ImageParams,
Size,
StageParams,
UpscaleParams,
)
2023-02-26 05:49:39 +00:00
from ..server import ServerContext
from ..server.load import get_source_filters
from ..utils import run_gc, show_system_toast
2023-02-26 20:15:30 +00:00
from ..worker import WorkerContext
from .utils import parse_prompt
2023-01-28 23:09:19 +00:00
logger = getLogger(__name__)
def run_txt2img_pipeline(
    job: WorkerContext,
    server: ServerContext,
    params: ImageParams,
    size: Size,
    outputs: List[str],
    upscale: UpscaleParams,
    highres: HighresParams,
) -> None:
    """
    Generate images from a text prompt and save them to the requested outputs.

    Builds a chain of: a txt2img source stage, optional upscaling/correction
    before highres, the highres stage, and upscaling/correction after highres.
    It then runs the chain and saves each resulting image with its metadata.

    Args:
        job: worker context; provides the progress callback and the device
            that is garbage-collected at the end.
        server: server context, passed to the chain and to save_image.
        params: image parameters; ``params.tiles`` sets the stage tile size.
        size: size given to the txt2img stage and recorded when saving.
        outputs: output names, paired in order with the resulting images.
        upscale: upscaling parameters, split into before/after-highres parts.
        highres: highres parameters.
    """
    # prepare the chain pipeline and first stage
    chain = ChainPipeline()
    stage = StageParams(
        tile_size=params.tiles,
    )
    chain.stage(
        SourceTxt2ImgStage(),
        stage,
        size=size,
    )

    # apply upscaling and correction, before highres
    first_upscale, after_upscale = split_upscale(upscale)
    if first_upscale:
        stage_upscale_correction(
            stage,
            params,
            upscale=first_upscale,
            chain=chain,
        )

    # apply highres
    stage_highres(
        stage,
        params,
        highres,
        upscale,
        chain=chain,
    )

    # apply upscaling and correction, after highres
    stage_upscale_correction(
        stage,
        params,
        upscale=after_upscale,
        chain=chain,
    )

    # run and save (txt2img starts from no input images)
    progress = job.get_progress_callback()
    images = chain(job, server, params, [], callback=progress)

    _prompt_pairs, loras, inversions = parse_prompt(params, use_input=True)

    # bind dest up front so the notification below cannot raise NameError
    # when the chain produced no images or no outputs were requested
    dest = None
    # zip stops at the shorter sequence: extra images or outputs are ignored
    for image, output in zip(images, outputs):
        dest = save_image(
            server,
            output,
            image,
            params,
            size,
            upscale=upscale,
            highres=highres,
            inversions=inversions,
            loras=loras,
        )

    # clean up
    run_gc([job.get_device()])

    # notify the user (reports the result of the last save)
    show_system_toast(f"finished txt2img job: {dest}")
    logger.info("finished txt2img job: %s", dest)
2023-01-16 00:54:20 +00:00
def run_img2img_pipeline(
    job: WorkerContext,
    server: ServerContext,
    params: ImageParams,
    outputs: List[str],
    upscale: UpscaleParams,
    highres: HighresParams,
    source: Image.Image,
    strength: float,
    source_filter: Optional[str] = None,
) -> None:
    """
    Transform a source image with img2img and save the results.

    Processing steps:
    1. Optionally run a named source filter over ``source``.
    2. Build a chain: an img2img blend stage, optional upscaling/correction
       before highres, ``params.loopback`` extra img2img passes, the highres
       stage, and upscaling/correction after highres.
    3. Run the chain and save each image with its metadata.

    When ``source_filter`` is set and is not ``"none"``, the (possibly
    filtered) source image is appended to the results. It is saved too if
    ``outputs`` has an entry for it.

    Args:
        job: worker context; provides the progress callback and the device
            that is garbage-collected at the end.
        server: server context, passed to filters, the chain and save_image.
        params: image parameters; ``params.tiles`` sets the stage tile size
            and ``params.loopback`` the number of extra img2img passes.
        outputs: output names, paired in order with the resulting images.
        upscale: upscaling parameters, split into before/after-highres parts.
        highres: highres parameters.
        source: input image; its size is recorded when saving.
        strength: img2img strength for every blend stage.
        source_filter: optional name looked up in ``get_source_filters()``;
            unknown names are silently skipped.
    """
    # run filter on the source image
    if source_filter is not None:
        f = get_source_filters().get(source_filter, None)
        if f is not None:
            logger.debug("running source filter: %s", f.__name__)
            source = f(server, source)

    # prepare the chain pipeline and first stage
    chain = ChainPipeline()
    stage = StageParams(
        tile_size=params.tiles,
    )
    chain.stage(
        BlendImg2ImgStage(),
        stage,
        strength=strength,
    )

    # apply upscaling and correction, before highres
    first_upscale, after_upscale = split_upscale(upscale)
    if first_upscale:
        stage_upscale_correction(
            stage,
            params,
            upscale=first_upscale,
            chain=chain,
        )

    # loopback through multiple img2img iterations
    for _i in range(params.loopback):
        chain.stage(
            BlendImg2ImgStage(),
            stage,
            strength=strength,
        )

    # highres, if selected
    stage_highres(
        stage,
        params,
        highres,
        upscale,
        chain=chain,
    )

    # apply upscaling and correction, after highres
    stage_upscale_correction(
        stage,
        params,
        upscale=after_upscale,
        chain=chain,
    )

    # run and append the filtered source; no trailing comma here, which would
    # wrap the result in a 1-tuple and break both append() and the save loop
    progress = job.get_progress_callback()
    images = list(chain(job, server, params, [source], callback=progress))
    if source_filter is not None and source_filter != "none":
        images.append(source)

    # save with metadata
    _prompt_pairs, loras, inversions = parse_prompt(params, use_input=True)
    size = Size(*source.size)

    # bind dest up front so the notification below cannot raise NameError
    # when there is nothing to save
    dest = None
    # zip stops at the shorter sequence: extra images or outputs are ignored
    for image, output in zip(images, outputs):
        dest = save_image(
            server,
            output,
            image,
            params,
            size,
            upscale=upscale,
            highres=highres,
            inversions=inversions,
            loras=loras,
        )

    # clean up
    run_gc([job.get_device()])

    # notify the user (reports the result of the last save)
    show_system_toast(f"finished img2img job: {dest}")
    logger.info("finished img2img job: %s", dest)
2023-01-16 00:54:20 +00:00
def run_inpaint_pipeline(
    job: WorkerContext,
    server: ServerContext,
    params: ImageParams,
    size: Size,
    outputs: List[str],
    upscale: UpscaleParams,
    highres: HighresParams,
    source: Image.Image,
    mask: Image.Image,
    border: Border,
    noise_source: Any,
    mask_filter: Any,
    fill_color: str,
    tile_order: str,
) -> None:
    """
    Inpaint/outpaint a source image under a mask and save the results.

    Builds a chain of: an UpscaleOutpaintStage configured with the border,
    mask, fill color, mask filter and noise source; optional
    upscaling/correction before highres; the highres stage; and
    upscaling/correction after highres. It then runs the chain on
    ``source`` and saves each result with its metadata.

    Args:
        job: worker context; provides the progress callback and the device
            that is garbage-collected at the end.
        server: server context, passed to the chain and to save_image.
        params: image parameters; ``params.tiles`` sets the stage tile size.
        size: size recorded when saving.
        outputs: output names, paired in order with the resulting images.
        upscale: upscaling parameters, split into before/after-highres parts.
        highres: highres parameters.
        source: input image fed to the chain.
        mask: mask image passed to the outpaint stage as ``stage_mask``.
        border: border passed to the outpaint stage and recorded when saving.
        noise_source: passed through to the outpaint stage.
        mask_filter: passed through to the outpaint stage.
        fill_color: passed through to the outpaint stage.
        tile_order: tile order for the stage parameters.
    """
    logger.debug("building inpaint pipeline")

    # set up the chain pipeline and base stage
    chain = ChainPipeline()
    stage = StageParams(tile_order=tile_order, tile_size=params.tiles)
    chain.stage(
        UpscaleOutpaintStage(),
        stage,
        border=border,
        stage_mask=mask,
        fill_color=fill_color,
        mask_filter=mask_filter,
        noise_source=noise_source,
    )

    # apply upscaling and correction, before highres
    first_upscale, after_upscale = split_upscale(upscale)
    if first_upscale:
        stage_upscale_correction(
            stage,
            params,
            upscale=first_upscale,
            chain=chain,
        )

    # apply highres
    stage_highres(
        stage,
        params,
        highres,
        upscale,
        chain=chain,
    )

    # apply upscaling and correction, after highres
    stage_upscale_correction(
        stage,
        params,
        upscale=after_upscale,
        chain=chain,
    )

    # run and save
    progress = job.get_progress_callback()
    images = chain(job, server, params, [source], callback=progress)

    _prompt_pairs, loras, inversions = parse_prompt(params, use_input=True)

    # bind dest up front so the notification below cannot raise NameError
    # when the chain produced no images or no outputs were requested
    dest = None
    # zip stops at the shorter sequence: extra images or outputs are ignored
    for image, output in zip(images, outputs):
        dest = save_image(
            server,
            output,
            image,
            params,
            size,
            upscale=upscale,
            highres=highres,  # highres was applied above, record it like txt2img/img2img
            border=border,
            inversions=inversions,
            loras=loras,
        )

    # clean up: drop local references to the results before collecting
    # garbage; assigning (rather than `del image`) is safe even when the loop
    # never ran and `image` was never bound
    images = image = None
    run_gc([job.get_device()])

    # notify the user (reports the result of the last save)
    show_system_toast(f"finished inpaint job: {dest}")
    logger.info("finished inpaint job: %s", dest)
2023-01-17 05:45:54 +00:00
2023-01-17 05:45:54 +00:00
def run_upscale_pipeline(
    job: WorkerContext,
    server: ServerContext,
    params: ImageParams,
    size: Size,
    outputs: List[str],
    upscale: UpscaleParams,
    highres: HighresParams,
    source: Image.Image,
) -> None:
    """
    Upscale/correct a source image and save the results.

    There is no base stage. The chain consists only of optional
    upscaling/correction before highres, the highres stage, and
    upscaling/correction after highres, run on ``source``.

    Args:
        job: worker context; provides the progress callback and the device
            that is garbage-collected at the end.
        server: server context, passed to the chain and to save_image.
        params: image parameters; ``params.tiles`` sets the stage tile size.
        size: size recorded when saving.
        outputs: output names, paired in order with the resulting images.
        upscale: upscaling parameters, split into before/after-highres parts.
        highres: highres parameters.
        source: input image fed to the chain.
    """
    # set up the chain pipeline, no base stage for upscaling
    chain = ChainPipeline()
    stage = StageParams(tile_size=params.tiles)

    # apply upscaling and correction, before highres
    first_upscale, after_upscale = split_upscale(upscale)
    if first_upscale:
        stage_upscale_correction(
            stage,
            params,
            upscale=first_upscale,
            chain=chain,
        )

    # apply highres
    stage_highres(
        stage,
        params,
        highres,
        upscale,
        chain=chain,
    )

    # apply upscaling and correction, after highres
    stage_upscale_correction(
        stage,
        params,
        upscale=after_upscale,
        chain=chain,
    )

    # run and save
    progress = job.get_progress_callback()
    images = chain(job, server, params, [source], callback=progress)

    _prompt_pairs, loras, inversions = parse_prompt(params, use_input=True)

    # bind dest up front so the notification below cannot raise NameError
    # when the chain produced no images or no outputs were requested
    dest = None
    # zip stops at the shorter sequence: extra images or outputs are ignored
    for image, output in zip(images, outputs):
        dest = save_image(
            server,
            output,
            image,
            params,
            size,
            upscale=upscale,
            highres=highres,  # highres was applied above, record it like txt2img/img2img
            inversions=inversions,
            loras=loras,
        )

    # clean up: drop local references to the results before collecting
    # garbage; assigning (rather than `del image`) is safe even when the loop
    # never ran and `image` was never bound
    images = image = None
    run_gc([job.get_device()])

    # notify the user (reports the result of the last save)
    show_system_toast(f"finished upscale job: {dest}")
    logger.info("finished upscale job: %s", dest)
def run_blend_pipeline(
    job: WorkerContext,
    server: ServerContext,
    params: ImageParams,
    size: Size,
    outputs: List[str],
    upscale: UpscaleParams,
    # highres: HighresParams,
    sources: List[Image.Image],
    mask: Image.Image,
) -> None:
    """
    Blend source images through a mask, upscale/correct, and save the results.

    The chain consists of a BlendMaskStage followed by upscaling/correction.
    The stage receives ``sources[1]`` as its stage source and ``mask`` as its
    mask. The full ``sources`` list is the chain's input.

    Args:
        job: worker context; provides the progress callback and the device
            that is garbage-collected at the end.
        server: server context, passed to the chain and to save_image.
        params: image parameters.
        size: size recorded when saving.
        outputs: output names, paired in order with the resulting images.
        upscale: upscaling parameters, applied in full after blending.
        sources: input images; must contain at least two entries because
            ``sources[1]`` is indexed directly (IndexError otherwise).
        mask: blend mask image.
    """
    # set up the chain pipeline and base stage
    chain = ChainPipeline()
    stage = StageParams()
    chain.stage(BlendMaskStage(), stage, stage_source=sources[1], stage_mask=mask)

    # apply upscaling and correction
    stage_upscale_correction(
        stage,
        params,
        upscale=upscale,
        chain=chain,
    )

    # run and save
    progress = job.get_progress_callback()
    images = chain(job, server, params, sources, callback=progress)

    # bind dest up front so the notification below cannot raise NameError
    # when the chain produced no images or no outputs were requested
    dest = None
    # zip stops at the shorter sequence: extra images or outputs are ignored
    for image, output in zip(images, outputs):
        dest = save_image(server, output, image, params, size, upscale=upscale)

    # clean up: drop local references to the results before collecting
    # garbage; assigning (rather than `del image`) is safe even when the loop
    # never ran and `image` was never bound
    images = image = None
    run_gc([job.get_device()])

    # notify the user (reports the result of the last save)
    show_system_toast(f"finished blend job: {dest}")
    logger.info("finished blend job: %s", dest)