1
0
Fork 0

lint(api): apply black and isort style

This commit is contained in:
Sean Sube 2023-02-05 07:53:26 -06:00
parent c0e5f435ee
commit 54dd34d211
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
34 changed files with 1271 additions and 1251 deletions

View File

@ -23,3 +23,16 @@ package-dist:
package-upload: package-upload:
twine upload dist/* twine upload dist/*
lint-check:
black --check --preview onnx_web
isort --check-only --skip __init__.py --filter-files onnx_web
flake8 --per-file-ignores="__init__.py:F401" onnx_web
lint-fix:
black onnx_web
isort --skip __init__.py --filter-files onnx_web
flake8 --per-file-ignores="__init__.py:F401" onnx_web
typecheck:
mypy -m onnx_web.serve

View File

@ -1,3 +1,6 @@
black
flake8
isort
mypy mypy
types-Flask-Cors types-Flask-Cors

View File

@ -1,14 +1,6 @@
from . import logging from . import logging
from .chain import correct_gfpgan, upscale_resrgan, upscale_stable_diffusion
from .chain import ( from .diffusion.load import get_latents_from_seed, load_pipeline
correct_gfpgan,
upscale_resrgan,
upscale_stable_diffusion,
)
from .diffusion.load import (
get_latents_from_seed,
load_pipeline,
)
from .diffusion.run import ( from .diffusion.run import (
run_img2img_pipeline, run_img2img_pipeline,
run_inpaint_pipeline, run_inpaint_pipeline,
@ -26,24 +18,14 @@ from .image import (
noise_source_normal, noise_source_normal,
noise_source_uniform, noise_source_uniform,
) )
from .params import ( from .params import Border, ImageParams, Param, Point, Size, StageParams, UpscaleParams
Param, from .upscale import run_upscale_correction
Point,
Border,
Size,
ImageParams,
StageParams,
UpscaleParams,
)
from .upscale import (
run_upscale_correction,
)
from .utils import ( from .utils import (
ServerContext,
base_join,
get_and_clamp_float, get_and_clamp_float,
get_and_clamp_int, get_and_clamp_int,
get_from_list, get_from_list,
get_from_map, get_from_map,
get_not_empty, get_not_empty,
base_join,
ServerContext,
) )

View File

@ -1,42 +1,13 @@
from .base import ( from .base import ChainPipeline, PipelineStage, StageCallback, StageParams
ChainPipeline, from .blend_img2img import blend_img2img
PipelineStage, from .blend_inpaint import blend_inpaint
StageCallback, from .correct_gfpgan import correct_gfpgan
StageParams, from .persist_disk import persist_disk
) from .persist_s3 import persist_s3
from .blend_img2img import ( from .reduce_crop import reduce_crop
blend_img2img, from .reduce_thumbnail import reduce_thumbnail
) from .source_noise import source_noise
from .blend_inpaint import ( from .source_txt2img import source_txt2img
blend_inpaint, from .upscale_outpaint import upscale_outpaint
) from .upscale_resrgan import upscale_resrgan
from .correct_gfpgan import ( from .upscale_stable_diffusion import upscale_stable_diffusion
correct_gfpgan,
)
from .persist_disk import (
persist_disk,
)
from .persist_s3 import (
persist_s3,
)
from .reduce_crop import (
reduce_crop,
)
from .reduce_thumbnail import (
reduce_thumbnail,
)
from .source_noise import (
source_noise,
)
from .source_txt2img import (
source_txt2img,
)
from .upscale_outpaint import (
upscale_outpaint,
)
from .upscale_resrgan import (
upscale_resrgan,
)
from .upscale_stable_diffusion import (
upscale_stable_diffusion,
)

View File

@ -1,26 +1,15 @@
from datetime import timedelta from datetime import timedelta
from logging import getLogger from logging import getLogger
from PIL import Image
from time import monotonic from time import monotonic
from typing import Any, List, Optional, Protocol, Tuple from typing import Any, List, Optional, Protocol, Tuple
from ..device_pool import ( from PIL import Image
JobContext,
) from ..device_pool import JobContext
from ..params import ( from ..output import save_image
ImageParams, from ..params import ImageParams, StageParams
StageParams, from ..utils import ServerContext, is_debug
) from .utils import process_tile_grid
from ..output import (
save_image,
)
from ..utils import (
is_debug,
ServerContext,
)
from .utils import (
process_tile_grid,
)
logger = getLogger(__name__) logger = getLogger(__name__)
@ -42,33 +31,43 @@ PipelineStage = Tuple[StageCallback, StageParams, Optional[dict]]
class ChainPipeline: class ChainPipeline:
''' """
Run many stages in series, passing the image results from each to the next, and processing Run many stages in series, passing the image results from each to the next, and processing
tiles as needed. tiles as needed.
''' """
def __init__( def __init__(
self, self,
stages: List[PipelineStage] = [], stages: List[PipelineStage] = [],
): ):
''' """
Create a new pipeline that will run the given stages. Create a new pipeline that will run the given stages.
''' """
self.stages = list(stages) self.stages = list(stages)
def append(self, stage: PipelineStage): def append(self, stage: PipelineStage):
''' """
Append an additional stage to this pipeline. Append an additional stage to this pipeline.
''' """
self.stages.append(stage) self.stages.append(stage)
def __call__(self, job: JobContext, server: ServerContext, params: ImageParams, source: Image.Image, **pipeline_kwargs) -> Image.Image: def __call__(
''' self,
job: JobContext,
server: ServerContext,
params: ImageParams,
source: Image.Image,
**pipeline_kwargs
) -> Image.Image:
"""
TODO: handle List[Image] outputs TODO: handle List[Image] outputs
''' """
start = monotonic() start = monotonic()
logger.info('running pipeline on source image with dimensions %sx%s', logger.info(
source.width, source.height) "running pipeline on source image with dimensions %sx%s",
source.width,
source.height,
)
image = source image = source
for stage_pipe, stage_params, stage_kwargs in self.stages: for stage_pipe, stage_params, stage_kwargs in self.stages:
@ -76,37 +75,51 @@ class ChainPipeline:
kwargs = stage_kwargs or {} kwargs = stage_kwargs or {}
kwargs = {**pipeline_kwargs, **kwargs} kwargs = {**pipeline_kwargs, **kwargs}
logger.info('running stage %s on image with dimensions %sx%s, %s', logger.info(
name, image.width, image.height, kwargs.keys()) "running stage %s on image with dimensions %sx%s, %s",
name,
image.width,
image.height,
kwargs.keys(),
)
if image.width > stage_params.tile_size or image.height > stage_params.tile_size: if (
logger.info('image larger than tile size of %s, tiling stage', image.width > stage_params.tile_size
stage_params.tile_size) or image.height > stage_params.tile_size
):
logger.info(
"image larger than tile size of %s, tiling stage",
stage_params.tile_size,
)
def stage_tile(tile: Image.Image, _dims) -> Image.Image: def stage_tile(tile: Image.Image, _dims) -> Image.Image:
tile = stage_pipe(job, server, stage_params, params, tile, tile = stage_pipe(job, server, stage_params, params, tile, **kwargs)
**kwargs)
if is_debug(): if is_debug():
save_image(server, 'last-tile.png', tile) save_image(server, "last-tile.png", tile)
return tile return tile
image = process_tile_grid( image = process_tile_grid(
image, stage_params.tile_size, stage_params.outscale, [stage_tile]) image, stage_params.tile_size, stage_params.outscale, [stage_tile]
)
else: else:
logger.info('image within tile size, running stage') logger.info("image within tile size, running stage")
image = stage_pipe(job, server, stage_params, params, image, image = stage_pipe(job, server, stage_params, params, image, **kwargs)
**kwargs)
logger.info('finished stage %s, result size: %sx%s', logger.info(
name, image.width, image.height) "finished stage %s, result size: %sx%s", name, image.width, image.height
)
if is_debug(): if is_debug():
save_image(server, 'last-stage.png', image) save_image(server, "last-stage.png", image)
end = monotonic() end = monotonic()
duration = timedelta(seconds=(end - start)) duration = timedelta(seconds=(end - start))
logger.info('finished pipeline in %s, result size: %sx%s', logger.info(
duration, image.width, image.height) "finished pipeline in %s, result size: %sx%s",
duration,
image.width,
image.height,
)
return image return image

View File

@ -1,24 +1,13 @@
from diffusers import (
OnnxStableDiffusionImg2ImgPipeline,
)
from logging import getLogger from logging import getLogger
from PIL import Image
from ..device_pool import (
JobContext,
)
from ..diffusion.load import (
load_pipeline,
)
from ..params import (
ImageParams,
StageParams,
)
from ..utils import (
ServerContext,
)
import numpy as np import numpy as np
from diffusers import OnnxStableDiffusionImg2ImgPipeline
from PIL import Image
from ..device_pool import JobContext
from ..diffusion.load import load_pipeline
from ..params import ImageParams, StageParams
from ..utils import ServerContext
logger = getLogger(__name__) logger = getLogger(__name__)
@ -35,10 +24,14 @@ def blend_img2img(
**kwargs, **kwargs,
) -> Image.Image: ) -> Image.Image:
prompt = prompt or params.prompt prompt = prompt or params.prompt
logger.info('generating image using img2img, %s steps: %s', params.steps, prompt) logger.info("generating image using img2img, %s steps: %s", params.steps, prompt)
pipe = load_pipeline(OnnxStableDiffusionImg2ImgPipeline, pipe = load_pipeline(
params.model, params.scheduler, job.get_device()) OnnxStableDiffusionImg2ImgPipeline,
params.model,
params.scheduler,
job.get_device(),
)
rng = np.random.RandomState(params.seed) rng = np.random.RandomState(params.seed)
@ -53,6 +46,5 @@ def blend_img2img(
) )
output = result.images[0] output = result.images[0]
logger.info('final output image size: %sx%s', output.width, output.height) logger.info("final output image size: %sx%s", output.width, output.height)
return output return output

View File

@ -1,41 +1,17 @@
from diffusers import (
OnnxStableDiffusionInpaintPipeline,
)
from logging import getLogger from logging import getLogger
from PIL import Image
from typing import Callable, Tuple from typing import Callable, Tuple
from ..device_pool import (
JobContext,
)
from ..diffusion.load import (
get_latents_from_seed,
load_pipeline,
)
from ..image import (
expand_image,
mask_filter_none,
noise_source_histogram,
)
from ..params import (
Border,
ImageParams,
Size,
SizeChart,
StageParams,
)
from ..output import (
save_image,
)
from ..utils import (
is_debug,
ServerContext,
)
from .utils import (
process_tile_grid,
)
import numpy as np import numpy as np
from diffusers import OnnxStableDiffusionInpaintPipeline
from PIL import Image
from ..device_pool import JobContext
from ..diffusion.load import get_latents_from_seed, load_pipeline
from ..image import expand_image, mask_filter_none, noise_source_histogram
from ..output import save_image
from ..params import Border, ImageParams, Size, SizeChart, StageParams
from ..utils import ServerContext, is_debug
from .utils import process_tile_grid
logger = getLogger(__name__) logger = getLogger(__name__)
@ -49,16 +25,16 @@ def blend_inpaint(
*, *,
expand: Border, expand: Border,
mask_image: Image.Image = None, mask_image: Image.Image = None,
fill_color: str = 'white', fill_color: str = "white",
mask_filter: Callable = mask_filter_none, mask_filter: Callable = mask_filter_none,
noise_source: Callable = noise_source_histogram, noise_source: Callable = noise_source_histogram,
**kwargs, **kwargs,
) -> Image.Image: ) -> Image.Image:
logger.info('upscaling image by expanding borders', expand) logger.info("upscaling image by expanding borders", expand)
if mask_image is None: if mask_image is None:
# if no mask was provided, keep the full source image # if no mask was provided, keep the full source image
mask_image = Image.new('RGB', source_image.size, 'black') mask_image = Image.new("RGB", source_image.size, "black")
source_image, mask_image, noise_image, _full_dims = expand_image( source_image, mask_image, noise_image, _full_dims = expand_image(
source_image, source_image,
@ -66,12 +42,13 @@ def blend_inpaint(
expand, expand,
fill=fill_color, fill=fill_color,
noise_source=noise_source, noise_source=noise_source,
mask_filter=mask_filter) mask_filter=mask_filter,
)
if is_debug(): if is_debug():
save_image(server, 'last-source.png', source_image) save_image(server, "last-source.png", source_image)
save_image(server, 'last-mask.png', mask_image) save_image(server, "last-mask.png", mask_image)
save_image(server, 'last-noise.png', noise_image) save_image(server, "last-noise.png", noise_image)
def outpaint(image: Image.Image, dims: Tuple[int, int, int]): def outpaint(image: Image.Image, dims: Tuple[int, int, int]):
left, top, tile = dims left, top, tile = dims
@ -79,11 +56,15 @@ def blend_inpaint(
mask = mask_image.crop((left, top, left + tile, top + tile)) mask = mask_image.crop((left, top, left + tile, top + tile))
if is_debug(): if is_debug():
save_image(server, 'tile-source.png', image) save_image(server, "tile-source.png", image)
save_image(server, 'tile-mask.png', mask) save_image(server, "tile-mask.png", mask)
pipe = load_pipeline(OnnxStableDiffusionInpaintPipeline, pipe = load_pipeline(
params.model, params.scheduler, job.get_device()) OnnxStableDiffusionInpaintPipeline,
params.model,
params.scheduler,
job.get_device(),
)
latents = get_latents_from_seed(params.seed, size) latents = get_latents_from_seed(params.seed, size)
rng = np.random.RandomState(params.seed) rng = np.random.RandomState(params.seed)
@ -104,5 +85,5 @@ def blend_inpaint(
output = process_tile_grid(source_image, SizeChart.auto, 1, [outpaint]) output = process_tile_grid(source_image, SizeChart.auto, 1, [outpaint])
logger.info('final output image size', output.size) logger.info("final output image size", output.size)
return output return output

View File

@ -1,30 +1,53 @@
from logging import getLogger
import torch
from basicsr.utils import img2tensor, tensor2img from basicsr.utils import img2tensor, tensor2img
from basicsr.utils.download_util import load_file_from_url from basicsr.utils.download_util import load_file_from_url
from facexlib.utils.face_restoration_helper import FaceRestoreHelper from facexlib.utils.face_restoration_helper import FaceRestoreHelper
from logging import getLogger
from PIL import Image from PIL import Image
from torchvision.transforms.functional import normalize from torchvision.transforms.functional import normalize
import torch from ..device_pool import JobContext
from ..params import ImageParams, StageParams
from ..utils import ServerContext
logger = getLogger(__name__) logger = getLogger(__name__)
pretrain_model_url = { pretrain_model_url = (
'restoration': 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth', "https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth"
} )
device = 'cpu' device = "cpu"
upscale = 2 upscale = 2
def correct_codeformer(image: Image.Image) -> Image.Image:
def correct_codeformer(
job: JobContext,
server: ServerContext,
stage: StageParams,
params: ImageParams,
source_image: Image.Image,
**kwargs,
) -> Image.Image:
model = "TODO"
# ------------------ set up CodeFormer restorer ------------------- # ------------------ set up CodeFormer restorer -------------------
net = ARCH_REGISTRY.get('CodeFormer')(dim_embd=512, codebook_size=1024, n_head=8, n_layers=9, net = ARCH_REGISTRY.get("CodeFormer")(
connect_list=['32', '64', '128', '256']).to(device) dim_embd=512,
codebook_size=1024,
n_head=8,
n_layers=9,
connect_list=["32", "64", "128", "256"],
).to(device)
# ckpt_path = 'weights/CodeFormer/codeformer.pth' # ckpt_path = 'weights/CodeFormer/codeformer.pth'
ckpt_path = load_file_from_url(url=pretrain_model_url['restoration'], ckpt_path = load_file_from_url(
model_dir='weights/CodeFormer', progress=True, file_name=None) url=pretrain_model_url,
checkpoint = torch.load(ckpt_path)['params_ema'] model_dir="weights/CodeFormer",
progress=True,
file_name=None,
)
checkpoint = torch.load(ckpt_path)["params_ema"]
net.load_state_dict(checkpoint) net.load_state_dict(checkpoint)
net.eval() net.eval()
@ -36,22 +59,24 @@ def correct_codeformer(image: Image.Image) -> Image.Image:
upscale, upscale,
face_size=512, face_size=512,
crop_ratio=(1, 1), crop_ratio=(1, 1),
det_model = args.detection_model, det_model=model,
save_ext='png', save_ext="png",
use_parse=True, use_parse=True,
device=device) device=device,
)
# get face landmarks for each face # get face landmarks for each face
num_det_faces = face_helper.get_face_landmarks_5( num_det_faces = face_helper.get_face_landmarks_5(
only_center_face=args.only_center_face, resize=640, eye_dist_threshold=5) only_center_face=False, resize=640, eye_dist_threshold=5
logger.info('detect %s faces', num_det_faces) )
logger.info("detect %s faces", num_det_faces)
# align and warp each face # align and warp each face
face_helper.align_warp_face() face_helper.align_warp_face()
# face restoration for each cropped face # face restoration for each cropped face
for idx, cropped_face in enumerate(face_helper.cropped_faces): for idx, cropped_face in enumerate(face_helper.cropped_faces):
# prepare data # prepare data
cropped_face_t = img2tensor(cropped_face / 255., bgr2rgb=True, float32=True) cropped_face_t = img2tensor(cropped_face / 255.0, bgr2rgb=True, float32=True)
normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True) normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
cropped_face_t = cropped_face_t.unsqueeze(0).to(device) cropped_face_t = cropped_face_t.unsqueeze(0).to(device)
@ -62,10 +87,10 @@ def correct_codeformer(image: Image.Image) -> Image.Image:
del output del output
torch.cuda.empty_cache() torch.cuda.empty_cache()
except Exception as error: except Exception as error:
logger.error('Failed inference for CodeFormer: %s', error) logger.error("Failed inference for CodeFormer: %s", error)
restored_face = tensor2img(cropped_face_t, rgb2bgr=True, min_max=(-1, 1)) restored_face = tensor2img(cropped_face_t, rgb2bgr=True, min_max=(-1, 1))
restored_face = restored_face.astype('uint8') restored_face = restored_face.astype("uint8")
face_helper.add_restored_face(restored_face, cropped_face) face_helper.add_restored_face(restored_face, cropped_face)
# upsample the background # upsample the background
@ -75,13 +100,16 @@ def correct_codeformer(image: Image.Image) -> Image.Image:
else: else:
bg_img = None bg_img = None
# paste_back # paste_back
face_helper.get_inverse_affine(None) face_helper.get_inverse_affine(None)
# paste each restored face to the input image # paste each restored face to the input image
if face_upsampler is not None: if face_upsampler is not None:
restored_img = face_helper.paste_faces_to_input_image(upsample_img=bg_img, draw_box=False, face_upsampler=face_upsampler) restored_img = face_helper.paste_faces_to_input_image(
upsample_img=bg_img, draw_box=False, face_upsampler=face_upsampler
)
else: else:
restored_img = face_helper.paste_faces_to_input_image(upsample_img=bg_img, draw_box=False) restored_img = face_helper.paste_faces_to_input_image(
upsample_img=bg_img, draw_box=False
)
return restored_img return restored_img

View File

@ -1,27 +1,16 @@
from gfpgan import GFPGANer
from logging import getLogger from logging import getLogger
from os import path from os import path
from PIL import Image
from realesrgan import RealESRGANer
from typing import Optional from typing import Optional
from ..device_pool import (
JobContext,
)
from ..params import (
ImageParams,
StageParams,
UpscaleParams,
)
from ..utils import (
run_gc,
ServerContext,
)
from .upscale_resrgan import (
load_resrgan,
)
import numpy as np import numpy as np
from gfpgan import GFPGANer
from PIL import Image
from realesrgan import RealESRGANer
from ..device_pool import JobContext
from ..params import ImageParams, StageParams, UpscaleParams
from ..utils import ServerContext, run_gc
from .upscale_resrgan import load_resrgan
logger = getLogger(__name__) logger = getLogger(__name__)
@ -30,7 +19,9 @@ last_pipeline_instance = None
last_pipeline_params = None last_pipeline_params = None
def load_gfpgan(ctx: ServerContext, upscale: UpscaleParams, upsampler: Optional[RealESRGANer] = None): def load_gfpgan(
ctx: ServerContext, upscale: UpscaleParams, upsampler: Optional[RealESRGANer] = None
):
global last_pipeline_instance global last_pipeline_instance
global last_pipeline_params global last_pipeline_params
@ -38,22 +29,22 @@ def load_gfpgan(ctx: ServerContext, upscale: UpscaleParams, upsampler: Optional[
bg_upscale = upscale.rescale(upscale.outscale) bg_upscale = upscale.rescale(upscale.outscale)
upsampler = load_resrgan(ctx, bg_upscale) upsampler = load_resrgan(ctx, bg_upscale)
face_path = path.join(ctx.model_path, '%s.pth' % face_path = path.join(ctx.model_path, "%s.pth" % (upscale.correction_model))
(upscale.correction_model))
if last_pipeline_instance != None and face_path == last_pipeline_params: if last_pipeline_instance is not None and face_path == last_pipeline_params:
logger.info('reusing existing GFPGAN pipeline') logger.info("reusing existing GFPGAN pipeline")
return last_pipeline_instance return last_pipeline_instance
logger.debug('loading GFPGAN model from %s', face_path) logger.debug("loading GFPGAN model from %s", face_path)
# TODO: find a way to pass the ONNX model to underlying architectures # TODO: find a way to pass the ONNX model to underlying architectures
gfpgan = GFPGANer( gfpgan = GFPGANer(
model_path=face_path, model_path=face_path,
upscale=upscale.outscale, upscale=upscale.outscale,
arch='clean', arch="clean",
channel_multiplier=2, channel_multiplier=2,
bg_upsampler=upsampler) bg_upsampler=upsampler,
)
last_pipeline_instance = gfpgan last_pipeline_instance = gfpgan
last_pipeline_params = face_path last_pipeline_params = face_path
@ -74,15 +65,20 @@ def correct_gfpgan(
**kwargs, **kwargs,
) -> Image.Image: ) -> Image.Image:
if upscale.correction_model is None: if upscale.correction_model is None:
logger.warn('no face model given, skipping') logger.warn("no face model given, skipping")
return source_image return source_image
logger.info('correcting faces with GFPGAN model: %s', upscale.correction_model) logger.info("correcting faces with GFPGAN model: %s", upscale.correction_model)
gfpgan = load_gfpgan(server, upscale, upsampler=upsampler) gfpgan = load_gfpgan(server, upscale, upsampler=upsampler)
output = np.array(source_image) output = np.array(source_image)
_, _, output = gfpgan.enhance( _, _, output = gfpgan.enhance(
output, has_aligned=False, only_center_face=False, paste_back=True, weight=upscale.face_strength) output,
output = Image.fromarray(output, 'RGB') has_aligned=False,
only_center_face=False,
paste_back=True,
weight=upscale.face_strength,
)
output = Image.fromarray(output, "RGB")
return output return output

View File

@ -1,26 +1,18 @@
from logging import getLogger from logging import getLogger
from PIL import Image from PIL import Image
from ..device_pool import ( from ..device_pool import JobContext
JobContext, from ..output import save_image
) from ..params import ImageParams, StageParams
from ..params import ( from ..utils import ServerContext
ImageParams,
StageParams,
)
from ..output import (
save_image,
)
from ..utils import (
ServerContext,
)
logger = getLogger(__name__) logger = getLogger(__name__)
def persist_disk( def persist_disk(
_job: JobContext, _job: JobContext,
ctx: ServerContext, server: ServerContext,
_stage: StageParams, _stage: StageParams,
_params: ImageParams, _params: ImageParams,
source_image: Image.Image, source_image: Image.Image,
@ -28,6 +20,6 @@ def persist_disk(
output: str, output: str,
**kwargs, **kwargs,
) -> Image.Image: ) -> Image.Image:
dest = save_image(ctx, output, source_image) dest = save_image(server, output, source_image)
logger.info('saved image to %s', dest) logger.info("saved image to %s", dest)
return source_image return source_image

View File

@ -1,26 +1,19 @@
from boto3 import (
Session,
)
from io import BytesIO from io import BytesIO
from logging import getLogger from logging import getLogger
from boto3 import Session
from PIL import Image from PIL import Image
from ..device_pool import ( from ..device_pool import JobContext
JobContext, from ..params import ImageParams, StageParams
) from ..utils import ServerContext
from ..params import (
ImageParams,
StageParams,
)
from ..utils import (
ServerContext,
)
logger = getLogger(__name__) logger = getLogger(__name__)
def persist_s3( def persist_s3(
ctx: ServerContext, _job: JobContext,
server: ServerContext,
_stage: StageParams, _stage: StageParams,
_params: ImageParams, _params: ImageParams,
source_image: Image.Image, source_image: Image.Image,
@ -32,16 +25,16 @@ def persist_s3(
**kwargs, **kwargs,
) -> Image.Image: ) -> Image.Image:
session = Session(profile_name=profile_name) session = Session(profile_name=profile_name)
s3 = session.client('s3', endpoint_url=endpoint_url) s3 = session.client("s3", endpoint_url=endpoint_url)
data = BytesIO() data = BytesIO()
source_image.save(data, format=ctx.image_format) source_image.save(data, format=server.image_format)
data.seek(0) data.seek(0)
try: try:
s3.upload_fileobj(data, bucket, output) s3.upload_fileobj(data, bucket, output)
logger.info('saved image to %s/%s', bucket, output) logger.info("saved image to %s/%s", bucket, output)
except Exception as err: except Exception as err:
logger.error('error saving image to S3: %s', err) logger.error("error saving image to S3: %s", err)
return source_image return source_image

View File

@ -1,23 +1,17 @@
from logging import getLogger from logging import getLogger
from PIL import Image from PIL import Image
from ..device_pool import ( from ..device_pool import JobContext
JobContext, from ..params import ImageParams, Size, StageParams
) from ..utils import ServerContext
from ..params import (
ImageParams,
Size,
StageParams,
)
from ..utils import (
ServerContext,
)
logger = getLogger(__name__) logger = getLogger(__name__)
def reduce_crop( def reduce_crop(
ctx: ServerContext, _job: JobContext,
_server: ServerContext,
_stage: StageParams, _stage: StageParams,
_params: ImageParams, _params: ImageParams,
source_image: Image.Image, source_image: Image.Image,
@ -26,8 +20,6 @@ def reduce_crop(
size: Size, size: Size,
**kwargs, **kwargs,
) -> Image.Image: ) -> Image.Image:
image = source_image.crop( image = source_image.crop((origin.width, origin.height, size.width, size.height))
(origin.width, origin.height, size.width, size.height)) logger.info("created thumbnail with dimensions: %sx%s", image.width, image.height)
logger.info('created thumbnail with dimensions: %sx%s',
image.width, image.height)
return image return image

View File

@ -1,23 +1,17 @@
from logging import getLogger from logging import getLogger
from PIL import Image from PIL import Image
from ..device_pool import ( from ..device_pool import JobContext
JobContext, from ..params import ImageParams, Size, StageParams
) from ..utils import ServerContext
from ..params import (
ImageParams,
Size,
StageParams,
)
from ..utils import (
ServerContext,
)
logger = getLogger(__name__) logger = getLogger(__name__)
def reduce_thumbnail( def reduce_thumbnail(
ctx: ServerContext, _job: JobContext,
_server: ServerContext,
_stage: StageParams, _stage: StageParams,
_params: ImageParams, _params: ImageParams,
source_image: Image.Image, source_image: Image.Image,
@ -26,6 +20,5 @@ def reduce_thumbnail(
**kwargs, **kwargs,
) -> Image.Image: ) -> Image.Image:
image = source_image.thumbnail((size.width, size.height)) image = source_image.thumbnail((size.width, size.height))
logger.info('created thumbnail with dimensions: %sx%s', logger.info("created thumbnail with dimensions: %sx%s", image.width, image.height)
image.width, image.height)
return image return image

View File

@ -1,26 +1,19 @@
from logging import getLogger from logging import getLogger
from PIL import Image
from typing import Callable from typing import Callable
from ..device_pool import ( from PIL import Image
JobContext,
)
from ..params import (
ImageParams,
Size,
StageParams,
)
from ..utils import (
ServerContext,
)
from ..device_pool import JobContext
from ..params import ImageParams, Size, StageParams
from ..utils import ServerContext
logger = getLogger(__name__) logger = getLogger(__name__)
def source_noise( def source_noise(
ctx: ServerContext, _job: JobContext,
stage: StageParams, _server: ServerContext,
_stage: StageParams,
params: ImageParams, params: ImageParams,
source_image: Image.Image, source_image: Image.Image,
*, *,
@ -28,14 +21,12 @@ def source_noise(
noise_source: Callable, noise_source: Callable,
**kwargs, **kwargs,
) -> Image.Image: ) -> Image.Image:
prompt = prompt or params.prompt logger.info("generating image from noise source")
logger.info('generating image from noise source')
if source_image is not None: if source_image is not None:
logger.warn( logger.warn("a source image was passed to a noise stage, but will be discarded")
'a source image was passed to a noise stage, but will be discarded')
output = noise_source(source_image, (size.width, size.height), (0, 0)) output = noise_source(source_image, (size.width, size.height), (0, 0))
logger.info('final output image size: %sx%s', output.width, output.height) logger.info("final output image size: %sx%s", output.width, output.height)
return output return output

View File

@ -1,26 +1,13 @@
from diffusers import (
OnnxStableDiffusionPipeline,
)
from logging import getLogger from logging import getLogger
from PIL import Image
from ..device_pool import (
JobContext,
)
from ..diffusion.load import (
get_latents_from_seed,
load_pipeline,
)
from ..params import (
ImageParams,
Size,
StageParams,
)
from ..utils import (
ServerContext,
)
import numpy as np import numpy as np
from diffusers import OnnxStableDiffusionPipeline
from PIL import Image
from ..device_pool import JobContext
from ..diffusion.load import get_latents_from_seed, load_pipeline
from ..params import ImageParams, Size, StageParams
from ..utils import ServerContext
logger = getLogger(__name__) logger = getLogger(__name__)
@ -37,13 +24,16 @@ def source_txt2img(
**kwargs, **kwargs,
) -> Image.Image: ) -> Image.Image:
prompt = prompt or params.prompt prompt = prompt or params.prompt
logger.info('generating image using txt2img, %s steps: %s', params.steps, prompt) logger.info("generating image using txt2img, %s steps: %s", params.steps, prompt)
if source_image is not None: if source_image is not None:
logger.warn('a source image was passed to a txt2img stage, but will be discarded') logger.warn(
"a source image was passed to a txt2img stage, but will be discarded"
)
pipe = load_pipeline(OnnxStableDiffusionPipeline, pipe = load_pipeline(
params.model, params.scheduler, job.get_device()) OnnxStableDiffusionPipeline, params.model, params.scheduler, job.get_device()
)
latents = get_latents_from_seed(params.seed, size) latents = get_latents_from_seed(params.seed, size)
rng = np.random.RandomState(params.seed) rng = np.random.RandomState(params.seed)
@ -60,5 +50,5 @@ def source_txt2img(
) )
output = result.images[0] output = result.images[0]
logger.info('final output image size: %sx%s', output.width, output.height) logger.info("final output image size: %sx%s", output.width, output.height)
return output return output

View File

@ -1,43 +1,17 @@
from diffusers import (
OnnxStableDiffusionInpaintPipeline,
)
from logging import getLogger from logging import getLogger
from PIL import Image, ImageDraw
from typing import Callable, Tuple from typing import Callable, Tuple
from ..device_pool import (
JobContext,
)
from ..diffusion.load import (
get_latents_from_seed,
get_tile_latents,
load_pipeline,
)
from ..image import (
expand_image,
mask_filter_none,
noise_source_histogram,
)
from ..params import (
Border,
ImageParams,
Size,
SizeChart,
StageParams,
)
from ..output import (
save_image,
)
from ..utils import (
base_join,
is_debug,
ServerContext,
)
from .utils import (
process_tile_spiral,
)
import numpy as np import numpy as np
from diffusers import OnnxStableDiffusionInpaintPipeline
from PIL import Image, ImageDraw
from ..device_pool import JobContext
from ..diffusion.load import get_latents_from_seed, get_tile_latents, load_pipeline
from ..image import expand_image, mask_filter_none, noise_source_histogram
from ..output import save_image
from ..params import Border, ImageParams, Size, SizeChart, StageParams
from ..utils import ServerContext, is_debug
from .utils import process_tile_spiral
logger = getLogger(__name__) logger = getLogger(__name__)
@ -52,17 +26,17 @@ def upscale_outpaint(
border: Border, border: Border,
prompt: str = None, prompt: str = None,
mask_image: Image.Image = None, mask_image: Image.Image = None,
fill_color: str = 'white', fill_color: str = "white",
mask_filter: Callable = mask_filter_none, mask_filter: Callable = mask_filter_none,
noise_source: Callable = noise_source_histogram, noise_source: Callable = noise_source_histogram,
**kwargs, **kwargs,
) -> Image.Image: ) -> Image.Image:
prompt = prompt or params.prompt prompt = prompt or params.prompt
logger.info('upscaling image by expanding borders: %s', border) logger.info("upscaling image by expanding borders: %s", border)
if mask_image is None: if mask_image is None:
# if no mask was provided, keep the full source image # if no mask was provided, keep the full source image
mask_image = Image.new('RGB', source_image.size, 'black') mask_image = Image.new("RGB", source_image.size, "black")
source_image, mask_image, noise_image, full_dims = expand_image( source_image, mask_image, noise_image, full_dims = expand_image(
source_image, source_image,
@ -70,16 +44,17 @@ def upscale_outpaint(
border, border,
fill=fill_color, fill=fill_color,
noise_source=noise_source, noise_source=noise_source,
mask_filter=mask_filter) mask_filter=mask_filter,
)
draw_mask = ImageDraw.Draw(mask_image) draw_mask = ImageDraw.Draw(mask_image)
full_size = Size(*full_dims) full_size = Size(*full_dims)
full_latents = get_latents_from_seed(params.seed, full_size) full_latents = get_latents_from_seed(params.seed, full_size)
if is_debug(): if is_debug():
save_image(server, 'last-source.png', source_image) save_image(server, "last-source.png", source_image)
save_image(server, 'last-mask.png', mask_image) save_image(server, "last-mask.png", mask_image)
save_image(server, 'last-noise.png', noise_image) save_image(server, "last-noise.png", noise_image)
def outpaint(image: Image.Image, dims: Tuple[int, int, int]): def outpaint(image: Image.Image, dims: Tuple[int, int, int]):
left, top, tile = dims left, top, tile = dims
@ -87,11 +62,15 @@ def upscale_outpaint(
mask = mask_image.crop((left, top, left + tile, top + tile)) mask = mask_image.crop((left, top, left + tile, top + tile))
if is_debug(): if is_debug():
save_image(server, 'tile-source.png', image) save_image(server, "tile-source.png", image)
save_image(server, 'tile-mask.png', mask) save_image(server, "tile-mask.png", mask)
pipe = load_pipeline(OnnxStableDiffusionInpaintPipeline, pipe = load_pipeline(
params.model, params.scheduler, job.get_device()) OnnxStableDiffusionInpaintPipeline,
params.model,
params.scheduler,
job.get_device(),
)
latents = get_tile_latents(full_latents, dims) latents = get_tile_latents(full_latents, dims)
rng = np.random.RandomState(params.seed) rng = np.random.RandomState(params.seed)
@ -110,10 +89,10 @@ def upscale_outpaint(
) )
# once part of the image has been drawn, keep it # once part of the image has been drawn, keep it
draw_mask.rectangle((left, top, left + tile, top + tile), fill='black') draw_mask.rectangle((left, top, left + tile, top + tile), fill="black")
return result.images[0] return result.images[0]
output = process_tile_spiral(source_image, SizeChart.auto, 1, [outpaint]) output = process_tile_spiral(source_image, SizeChart.auto, 1, [outpaint])
logger.info('final output image size: %sx%s', output.width, output.height) logger.info("final output image size: %sx%s", output.width, output.height)
return output return output

View File

@ -1,27 +1,15 @@
from basicsr.archs.rrdbnet_arch import RRDBNet
from logging import getLogger from logging import getLogger
from os import path from os import path
import numpy as np
from basicsr.archs.rrdbnet_arch import RRDBNet
from PIL import Image from PIL import Image
from realesrgan import RealESRGANer from realesrgan import RealESRGANer
from ..device_pool import ( from ..device_pool import JobContext
JobContext, from ..onnx import OnnxNet
) from ..params import DeviceParams, ImageParams, StageParams, UpscaleParams
from ..onnx import ( from ..utils import ServerContext, run_gc
OnnxNet,
)
from ..params import (
DeviceParams,
ImageParams,
StageParams,
UpscaleParams,
)
from ..utils import (
run_gc,
ServerContext,
)
import numpy as np
logger = getLogger(__name__) logger = getLogger(__name__)
@ -29,39 +17,50 @@ last_pipeline_instance = None
last_pipeline_params = (None, None) last_pipeline_params = (None, None)
def load_resrgan(ctx: ServerContext, params: UpscaleParams, device: DeviceParams, tile=0): def load_resrgan(
ctx: ServerContext, params: UpscaleParams, device: DeviceParams, tile=0
):
global last_pipeline_instance global last_pipeline_instance
global last_pipeline_params global last_pipeline_params
model_file = '%s.%s' % (params.upscale_model, params.format) model_file = "%s.%s" % (params.upscale_model, params.format)
model_path = path.join(ctx.model_path, model_file) model_path = path.join(ctx.model_path, model_file)
if not path.isfile(model_path): if not path.isfile(model_path):
raise Exception('Real ESRGAN model not found at %s' % model_path) raise Exception("Real ESRGAN model not found at %s" % model_path)
cache_params = (model_path, params.format) cache_params = (model_path, params.format)
if last_pipeline_instance != None and cache_params == last_pipeline_params: if last_pipeline_instance is not None and cache_params == last_pipeline_params:
logger.info('reusing existing Real ESRGAN pipeline') logger.info("reusing existing Real ESRGAN pipeline")
return last_pipeline_instance return last_pipeline_instance
# use ONNX acceleration, if available # use ONNX acceleration, if available
if params.format == 'onnx': if params.format == "onnx":
model = OnnxNet(ctx, model_file, provider=device.provider, provider_options=device.options) model = OnnxNet(
elif params.format == 'pth': ctx, model_file, provider=device.provider, provider_options=device.options
model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, )
num_block=23, num_grow_ch=32, scale=params.scale) elif params.format == "pth":
raise Exception('unknown platform %s' % params.format) model = RRDBNet(
num_in_ch=3,
num_out_ch=3,
num_feat=64,
num_block=23,
num_grow_ch=32,
scale=params.scale,
)
raise Exception("unknown platform %s" % params.format)
dni_weight = None dni_weight = None
if params.upscale_model == 'realesr-general-x4v3' and params.denoise != 1: if params.upscale_model == "realesr-general-x4v3" and params.denoise != 1:
wdn_model_path = model_path.replace( wdn_model_path = model_path.replace(
'realesr-general-x4v3', 'realesr-general-wdn-x4v3') "realesr-general-x4v3", "realesr-general-wdn-x4v3"
)
model_path = [model_path, wdn_model_path] model_path = [model_path, wdn_model_path]
dni_weight = [params.denoise, 1 - params.denoise] dni_weight = [params.denoise, 1 - params.denoise]
logger.debug('loading Real ESRGAN upscale model from %s', model_path) logger.debug("loading Real ESRGAN upscale model from %s", model_path)
# TODO: shouldn't need the PTH file # TODO: shouldn't need the PTH file
model_path_pth = path.join(ctx.model_path, '%s.pth' % params.upscale_model) model_path_pth = path.join(ctx.model_path, "%s.pth" % params.upscale_model)
upsampler = RealESRGANer( upsampler = RealESRGANer(
scale=params.scale, scale=params.scale,
model_path=model_path_pth, model_path=model_path_pth,
@ -70,7 +69,8 @@ def load_resrgan(ctx: ServerContext, params: UpscaleParams, device: DeviceParams
tile=tile, tile=tile,
tile_pad=params.tile_pad, tile_pad=params.tile_pad,
pre_pad=params.pre_pad, pre_pad=params.pre_pad,
half=params.half) half=params.half,
)
last_pipeline_instance = upsampler last_pipeline_instance = upsampler
last_pipeline_params = cache_params last_pipeline_params = cache_params
@ -89,13 +89,13 @@ def upscale_resrgan(
upscale: UpscaleParams, upscale: UpscaleParams,
**kwargs, **kwargs,
) -> Image.Image: ) -> Image.Image:
logger.info('upscaling image with Real ESRGAN: x%s', upscale.scale) logger.info("upscaling image with Real ESRGAN: x%s", upscale.scale)
output = np.array(source_image) output = np.array(source_image)
upsampler = load_resrgan(server, upscale, job.get_device(), tile=stage.tile_size) upsampler = load_resrgan(server, upscale, job.get_device(), tile=stage.tile_size)
output, _ = upsampler.enhance(output, outscale=upscale.outscale) output, _ = upsampler.enhance(output, outscale=upscale.outscale)
output = Image.fromarray(output, 'RGB') output = Image.fromarray(output, "RGB")
logger.info('final output image size: %sx%s', output.width, output.height) logger.info("final output image size: %sx%s", output.width, output.height)
return output return output

View File

@ -1,28 +1,16 @@
from diffusers import (
StableDiffusionUpscalePipeline,
)
from logging import getLogger from logging import getLogger
from os import path from os import path
import torch
from diffusers import StableDiffusionUpscalePipeline
from PIL import Image from PIL import Image
from ..device_pool import ( from ..device_pool import JobContext
JobContext,
)
from ..diffusion.pipeline_onnx_stable_diffusion_upscale import ( from ..diffusion.pipeline_onnx_stable_diffusion_upscale import (
OnnxStableDiffusionUpscalePipeline, OnnxStableDiffusionUpscalePipeline,
) )
from ..params import ( from ..params import DeviceParams, ImageParams, StageParams, UpscaleParams
DeviceParams, from ..utils import ServerContext, run_gc
ImageParams,
StageParams,
UpscaleParams,
)
from ..utils import (
run_gc,
ServerContext,
)
import torch
logger = getLogger(__name__) logger = getLogger(__name__)
@ -31,23 +19,37 @@ last_pipeline_instance = None
last_pipeline_params = (None, None) last_pipeline_params = (None, None)
def load_stable_diffusion(ctx: ServerContext, upscale: UpscaleParams, device: DeviceParams): def load_stable_diffusion(
ctx: ServerContext, upscale: UpscaleParams, device: DeviceParams
):
global last_pipeline_instance global last_pipeline_instance
global last_pipeline_params global last_pipeline_params
model_path = path.join(ctx.model_path, upscale.upscale_model) model_path = path.join(ctx.model_path, upscale.upscale_model)
cache_params = (model_path, upscale.format) cache_params = (model_path, upscale.format)
if last_pipeline_instance != None and cache_params == last_pipeline_params: if last_pipeline_instance is not None and cache_params == last_pipeline_params:
logger.info('reusing existing Stable Diffusion upscale pipeline') logger.info("reusing existing Stable Diffusion upscale pipeline")
return last_pipeline_instance return last_pipeline_instance
if upscale.format == 'onnx': if upscale.format == "onnx":
logger.debug('loading Stable Diffusion upscale ONNX model from %s, using provider %s', model_path, device.provider) logger.debug(
pipeline = OnnxStableDiffusionUpscalePipeline.from_pretrained(model_path, provider=device.provider, provider_options=device.options) "loading Stable Diffusion upscale ONNX model from %s, using provider %s",
model_path,
device.provider,
)
pipeline = OnnxStableDiffusionUpscalePipeline.from_pretrained(
model_path, provider=device.provider, provider_options=device.options
)
else: else:
logger.debug('loading Stable Diffusion upscale model from %s, using provider %s', model_path, device.provider) logger.debug(
pipeline = StableDiffusionUpscalePipeline.from_pretrained(model_path, provider=device.provider) "loading Stable Diffusion upscale model from %s, using provider %s",
model_path,
device.provider,
)
pipeline = StableDiffusionUpscalePipeline.from_pretrained(
model_path, provider=device.provider
)
last_pipeline_instance = pipeline last_pipeline_instance = pipeline
last_pipeline_params = cache_params last_pipeline_params = cache_params
@ -68,7 +70,7 @@ def upscale_stable_diffusion(
**kwargs, **kwargs,
) -> Image.Image: ) -> Image.Image:
prompt = prompt or params.prompt prompt = prompt or params.prompt
logger.info('upscaling with Stable Diffusion, %s steps: %s', params.steps, prompt) logger.info("upscaling with Stable Diffusion, %s steps: %s", params.steps, prompt)
pipeline = load_stable_diffusion(server, upscale, job.get_device()) pipeline = load_stable_diffusion(server, upscale, job.get_device())
generator = torch.manual_seed(params.seed) generator = torch.manual_seed(params.seed)

View File

@ -1,7 +1,8 @@
from logging import getLogger from logging import getLogger
from PIL import Image
from typing import List, Protocol, Tuple from typing import List, Protocol, Tuple
from PIL import Image
logger = getLogger(__name__) logger = getLogger(__name__)
@ -17,7 +18,7 @@ def process_tile_grid(
filters: List[TileCallback], filters: List[TileCallback],
) -> Image.Image: ) -> Image.Image:
width, height = source.size width, height = source.size
image = Image.new('RGB', (width * scale, height * scale)) image = Image.new("RGB", (width * scale, height * scale))
tiles_x = width // tile tiles_x = width // tile
tiles_y = height // tile tiles_y = height // tile
@ -28,7 +29,7 @@ def process_tile_grid(
idx = (y * tiles_x) + x idx = (y * tiles_x) + x
left = x * tile left = x * tile
top = y * tile top = y * tile
logger.info('processing tile %s of %s, %s.%s', idx + 1, total, y, x) logger.info("processing tile %s of %s, %s.%s", idx + 1, total, y, x)
tile_image = source.crop((left, top, left + tile, top + tile)) tile_image = source.crop((left, top, left + tile, top + tile))
for filter in filters: for filter in filters:
@ -47,10 +48,10 @@ def process_tile_spiral(
overlap: float = 0.5, overlap: float = 0.5,
) -> Image.Image: ) -> Image.Image:
if scale != 1: if scale != 1:
raise Exception('unsupported scale') raise Exception("unsupported scale")
width, height = source.size width, height = source.size
image = Image.new('RGB', (width * scale, height * scale)) image = Image.new("RGB", (width * scale, height * scale))
image.paste(source, (0, 0, width, height)) image.paste(source, (0, 0, width, height))
center_x = (width // 2) - (tile // 2) center_x = (width // 2) - (tile // 2)
@ -76,7 +77,7 @@ def process_tile_spiral(
top = center_y + int(top) top = center_y + int(top)
counter += 1 counter += 1
logger.info('processing tile %s of %s, %sx%s', counter, len(tiles), left, top) logger.info("processing tile %s of %s, %sx%s", counter, len(tiles), left, top)
# TODO: only valid for scale == 1, resize source for others # TODO: only valid for scale == 1, resize source for others
tile_image = image.crop((left, top, left + tile, top + tile)) tile_image = image.crop((left, top, left + tile, top + tile))

View File

@ -1,5 +1,14 @@
from . import logging import warnings
from argparse import ArgumentParser from argparse import ArgumentParser
from json import loads
from logging import getLogger
from os import environ, makedirs, mkdir, path
from pathlib import Path
from shutil import copyfile, rmtree
from sys import exit
from typing import Dict, List, Optional, Tuple
import torch
from basicsr.archs.rrdbnet_arch import RRDBNet from basicsr.archs.rrdbnet_arch import RRDBNet
from basicsr.utils.download_util import load_file_from_url from basicsr.utils.download_util import load_file_from_url
from diffusers import ( from diffusers import (
@ -8,25 +17,20 @@ from diffusers import (
StableDiffusionPipeline, StableDiffusionPipeline,
StableDiffusionUpscalePipeline, StableDiffusionUpscalePipeline,
) )
from json import loads
from logging import getLogger
from onnx import load, save_model from onnx import load, save_model
from os import environ, makedirs, mkdir, path
from pathlib import Path
from shutil import copyfile, rmtree
from sys import exit
from torch.onnx import export from torch.onnx import export
from typing import Dict, List, Optional, Tuple
import torch from . import logging
import warnings
# suppress common but harmless warnings, https://github.com/ssube/onnx-web/issues/75 # suppress common but harmless warnings, https://github.com/ssube/onnx-web/issues/75
warnings.filterwarnings( warnings.filterwarnings(
'ignore', '.*The shape inference of prim::Constant type is missing.*') "ignore", ".*The shape inference of prim::Constant type is missing.*"
warnings.filterwarnings('ignore', '.*Only steps=1 can be constant folded.*') )
warnings.filterwarnings("ignore", ".*Only steps=1 can be constant folded.*")
warnings.filterwarnings( warnings.filterwarnings(
'ignore', '.*Converting a tensor to a Python boolean might cause the trace to be incorrect.*') "ignore",
".*Converting a tensor to a Python boolean might cause the trace to be incorrect.*",
)
Models = Dict[str, List[Tuple[str, str, Optional[int]]]] Models = Dict[str, List[Tuple[str, str, Optional[int]]]]
@ -35,74 +39,95 @@ logger = getLogger(__name__)
# recommended models # recommended models
base_models: Models = { base_models: Models = {
'diffusion': [ "diffusion": [
# v1.x # v1.x
('stable-diffusion-onnx-v1-5', 'runwayml/stable-diffusion-v1-5'), ("stable-diffusion-onnx-v1-5", "runwayml/stable-diffusion-v1-5"),
('stable-diffusion-onnx-v1-inpainting', ("stable-diffusion-onnx-v1-inpainting", "runwayml/stable-diffusion-inpainting"),
'runwayml/stable-diffusion-inpainting'),
# v2.x # v2.x
('stable-diffusion-onnx-v2-1', 'stabilityai/stable-diffusion-2-1'), ("stable-diffusion-onnx-v2-1", "stabilityai/stable-diffusion-2-1"),
('stable-diffusion-onnx-v2-inpainting', (
'stabilityai/stable-diffusion-2-inpainting'), "stable-diffusion-onnx-v2-inpainting",
"stabilityai/stable-diffusion-2-inpainting",
),
# TODO: should have its own converter # TODO: should have its own converter
('upscaling-stable-diffusion-x4', 'stabilityai/stable-diffusion-x4-upscaler'), ("upscaling-stable-diffusion-x4", "stabilityai/stable-diffusion-x4-upscaler"),
], ],
'correction': [ "correction": [
('correction-gfpgan-v1-3', (
'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth', 4), "correction-gfpgan-v1-3",
"https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth",
4,
),
], ],
'upscaling': [ "upscaling": [
('upscaling-real-esrgan-x2-plus', (
'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth', 2), "upscaling-real-esrgan-x2-plus",
('upscaling-real-esrgan-x4-plus', "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth",
'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth', 4), 2,
('upscaling-real-esrgan-x4-v3', ),
'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth', 4), (
"upscaling-real-esrgan-x4-plus",
"https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth",
4,
),
(
"upscaling-real-esrgan-x4-v3",
"https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth",
4,
),
], ],
} }
model_path = environ.get('ONNX_WEB_MODEL_PATH', model_path = environ.get("ONNX_WEB_MODEL_PATH", path.join("..", "models"))
path.join('..', 'models')) training_device = "cuda" if torch.cuda.is_available() else "cpu"
training_device = 'cuda' if torch.cuda.is_available() else 'cpu'
map_location = torch.device(training_device) map_location = torch.device(training_device)
@torch.no_grad() @torch.no_grad()
def convert_real_esrgan(name: str, url: str, scale: int, opset: int): def convert_real_esrgan(name: str, url: str, scale: int, opset: int):
dest_path = path.join(model_path, name + '.pth') dest_path = path.join(model_path, name + ".pth")
dest_onnx = path.join(model_path, name + '.onnx') dest_onnx = path.join(model_path, name + ".onnx")
logger.info('converting Real ESRGAN model: %s -> %s', name, dest_onnx) logger.info("converting Real ESRGAN model: %s -> %s", name, dest_onnx)
if path.isfile(dest_onnx): if path.isfile(dest_onnx):
logger.info('ONNX model already exists, skipping.') logger.info("ONNX model already exists, skipping.")
return return
if not path.isfile(dest_path): if not path.isfile(dest_path):
logger.info('PTH model not found, downloading...') logger.info("PTH model not found, downloading...")
download_path = load_file_from_url( download_path = load_file_from_url(
url=url, model_dir=dest_path + '-cache', progress=True, file_name=None) url=url, model_dir=dest_path + "-cache", progress=True, file_name=None
)
copyfile(download_path, dest_path) copyfile(download_path, dest_path)
logger.info('loading and training model') logger.info("loading and training model")
model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, model = RRDBNet(
num_block=23, num_grow_ch=32, scale=scale) num_in_ch=3,
num_out_ch=3,
num_feat=64,
num_block=23,
num_grow_ch=32,
scale=scale,
)
torch_model = torch.load(dest_path, map_location=map_location) torch_model = torch.load(dest_path, map_location=map_location)
if 'params_ema' in torch_model: if "params_ema" in torch_model:
model.load_state_dict(torch_model['params_ema']) model.load_state_dict(torch_model["params_ema"])
else: else:
model.load_state_dict(torch_model['params'], strict=False) model.load_state_dict(torch_model["params"], strict=False)
model.to(training_device).train(False) model.to(training_device).train(False)
model.eval() model.eval()
rng = torch.rand(1, 3, 64, 64, device=map_location) rng = torch.rand(1, 3, 64, 64, device=map_location)
input_names = ['data'] input_names = ["data"]
output_names = ['output'] output_names = ["output"]
dynamic_axes = {'data': {2: 'width', 3: 'height'}, dynamic_axes = {
'output': {2: 'width', 3: 'height'}} "data": {2: "width", 3: "height"},
"output": {2: "width", 3: "height"},
}
logger.info('exporting ONNX model to %s', dest_onnx) logger.info("exporting ONNX model to %s", dest_onnx)
export( export(
model, model,
rng, rng,
@ -111,48 +136,57 @@ def convert_real_esrgan(name: str, url: str, scale: int, opset: int):
output_names=output_names, output_names=output_names,
dynamic_axes=dynamic_axes, dynamic_axes=dynamic_axes,
opset_version=opset, opset_version=opset,
export_params=True export_params=True,
) )
logger.info('Real ESRGAN exported to ONNX successfully.') logger.info("Real ESRGAN exported to ONNX successfully.")
@torch.no_grad() @torch.no_grad()
def convert_gfpgan(name: str, url: str, scale: int, opset: int): def convert_gfpgan(name: str, url: str, scale: int, opset: int):
dest_path = path.join(model_path, name + '.pth') dest_path = path.join(model_path, name + ".pth")
dest_onnx = path.join(model_path, name + '.onnx') dest_onnx = path.join(model_path, name + ".onnx")
logger.info('converting GFPGAN model: %s -> %s', name, dest_onnx) logger.info("converting GFPGAN model: %s -> %s", name, dest_onnx)
if path.isfile(dest_onnx): if path.isfile(dest_onnx):
logger.info('ONNX model already exists, skipping.') logger.info("ONNX model already exists, skipping.")
return return
if not path.isfile(dest_path): if not path.isfile(dest_path):
logger.info('PTH model not found, downloading...') logger.info("PTH model not found, downloading...")
download_path = load_file_from_url( download_path = load_file_from_url(
url=url, model_dir=dest_path + '-cache', progress=True, file_name=None) url=url, model_dir=dest_path + "-cache", progress=True, file_name=None
)
copyfile(download_path, dest_path) copyfile(download_path, dest_path)
logger.info('loading and training model') logger.info("loading and training model")
model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, model = RRDBNet(
num_block=23, num_grow_ch=32, scale=scale) num_in_ch=3,
num_out_ch=3,
num_feat=64,
num_block=23,
num_grow_ch=32,
scale=scale,
)
torch_model = torch.load(dest_path, map_location=map_location) torch_model = torch.load(dest_path, map_location=map_location)
# TODO: make sure strict=False is safe here # TODO: make sure strict=False is safe here
if 'params_ema' in torch_model: if "params_ema" in torch_model:
model.load_state_dict(torch_model['params_ema'], strict=False) model.load_state_dict(torch_model["params_ema"], strict=False)
else: else:
model.load_state_dict(torch_model['params'], strict=False) model.load_state_dict(torch_model["params"], strict=False)
model.to(training_device).train(False) model.to(training_device).train(False)
model.eval() model.eval()
rng = torch.rand(1, 3, 64, 64, device=map_location) rng = torch.rand(1, 3, 64, 64, device=map_location)
input_names = ['data'] input_names = ["data"]
output_names = ['output'] output_names = ["output"]
dynamic_axes = {'data': {2: 'width', 3: 'height'}, dynamic_axes = {
'output': {2: 'width', 3: 'height'}} "data": {2: "width", 3: "height"},
"output": {2: "width", 3: "height"},
}
logger.info('exporting ONNX model to %s', dest_onnx) logger.info("exporting ONNX model to %s", dest_onnx)
export( export(
model, model,
rng, rng,
@ -161,9 +195,9 @@ def convert_gfpgan(name: str, url: str, scale: int, opset: int):
output_names=output_names, output_names=output_names,
dynamic_axes=dynamic_axes, dynamic_axes=dynamic_axes,
opset_version=opset, opset_version=opset,
export_params=True export_params=True,
) )
logger.info('GFPGAN exported to ONNX successfully.') logger.info("GFPGAN exported to ONNX successfully.")
def onnx_export( def onnx_export(
@ -176,9 +210,9 @@ def onnx_export(
opset, opset,
use_external_data_format=False, use_external_data_format=False,
): ):
''' """
From https://github.com/huggingface/diffusers/blob/main/scripts/convert_stable_diffusion_checkpoint_to_onnx.py From https://github.com/huggingface/diffusers/blob/main/scripts/convert_stable_diffusion_checkpoint_to_onnx.py
''' """
output_path.parent.mkdir(parents=True, exist_ok=True) output_path.parent.mkdir(parents=True, exist_ok=True)
export( export(
@ -194,29 +228,33 @@ def onnx_export(
@torch.no_grad() @torch.no_grad()
def convert_diffuser(name: str, url: str, opset: int, half: bool, token: str, single_vae: bool = False): def convert_diffuser(
''' name: str, url: str, opset: int, half: bool, token: str, single_vae: bool = False
):
"""
From https://github.com/huggingface/diffusers/blob/main/scripts/convert_stable_diffusion_checkpoint_to_onnx.py From https://github.com/huggingface/diffusers/blob/main/scripts/convert_stable_diffusion_checkpoint_to_onnx.py
''' """
dtype = torch.float16 if half else torch.float32 dtype = torch.float16 if half else torch.float32
dest_path = path.join(model_path, name) dest_path = path.join(model_path, name)
# diffusers go into a directory rather than .onnx file # diffusers go into a directory rather than .onnx file
logger.info('converting Diffusers model: %s -> %s/', name, dest_path) logger.info("converting Diffusers model: %s -> %s/", name, dest_path)
if single_vae: if single_vae:
logger.info('converting model with single VAE') logger.info("converting model with single VAE")
if path.isdir(dest_path): if path.isdir(dest_path):
logger.info('ONNX model already exists, skipping.') logger.info("ONNX model already exists, skipping.")
return return
if half and training_device != 'cuda': if half and training_device != "cuda":
raise ValueError( raise ValueError(
'Half precision model export is only supported on GPUs with CUDA') "Half precision model export is only supported on GPUs with CUDA"
)
pipeline = StableDiffusionPipeline.from_pretrained( pipeline = StableDiffusionPipeline.from_pretrained(
url, torch_dtype=dtype, use_auth_token=token).to(training_device) url, torch_dtype=dtype, use_auth_token=token
).to(training_device)
output_path = Path(dest_path) output_path = Path(dest_path)
# TEXT ENCODER # TEXT ENCODER
@ -232,8 +270,7 @@ def convert_diffuser(name: str, url: str, opset: int, half: bool, token: str, si
onnx_export( onnx_export(
pipeline.text_encoder, pipeline.text_encoder,
# casting to torch.int32 until the CLIP fix is released: https://github.com/huggingface/transformers/pull/18515/files # casting to torch.int32 until the CLIP fix is released: https://github.com/huggingface/transformers/pull/18515/files
model_args=(text_input.input_ids.to( model_args=(text_input.input_ids.to(device=training_device, dtype=torch.int32)),
device=training_device, dtype=torch.int32)),
output_path=output_path / "text_encoder" / "model.onnx", output_path=output_path / "text_encoder" / "model.onnx",
ordered_input_names=["input_ids"], ordered_input_names=["input_ids"],
output_names=["last_hidden_state", "pooler_output"], output_names=["last_hidden_state", "pooler_output"],
@ -244,7 +281,7 @@ def convert_diffuser(name: str, url: str, opset: int, half: bool, token: str, si
) )
del pipeline.text_encoder del pipeline.text_encoder
logger.debug('UNET config: %s', pipeline.unet.config) logger.debug("UNET config: %s", pipeline.unet.config)
# UNET # UNET
if single_vae: if single_vae:
@ -262,10 +299,12 @@ def convert_diffuser(name: str, url: str, opset: int, half: bool, token: str, si
pipeline.unet, pipeline.unet,
model_args=( model_args=(
torch.randn(2, unet_in_channels, unet_sample_size, unet_sample_size).to( torch.randn(2, unet_in_channels, unet_sample_size, unet_sample_size).to(
device=training_device, dtype=dtype), device=training_device, dtype=dtype
),
torch.randn(2).to(device=training_device, dtype=dtype), torch.randn(2).to(device=training_device, dtype=dtype),
torch.randn(2, num_tokens, text_hidden_size).to( torch.randn(2, num_tokens, text_hidden_size).to(
device=training_device, dtype=dtype), device=training_device, dtype=dtype
),
unet_scale, unet_scale,
), ),
output_path=unet_path, output_path=unet_path,
@ -298,7 +337,7 @@ def convert_diffuser(name: str, url: str, opset: int, half: bool, token: str, si
del pipeline.unet del pipeline.unet
if single_vae: if single_vae:
logger.debug('VAE config: %s', pipeline.vae.config) logger.debug("VAE config: %s", pipeline.vae.config)
# SINGLE VAE # SINGLE VAE
vae_only = pipeline.vae vae_only = pipeline.vae
@ -309,8 +348,9 @@ def convert_diffuser(name: str, url: str, opset: int, half: bool, token: str, si
onnx_export( onnx_export(
vae_only, vae_only,
model_args=( model_args=(
torch.randn(1, vae_latent_channels, unet_sample_size, unet_sample_size).to( torch.randn(
device=training_device, dtype=dtype), 1, vae_latent_channels, unet_sample_size, unet_sample_size
).to(device=training_device, dtype=dtype),
False, False,
), ),
output_path=output_path / "vae" / "model.onnx", output_path=output_path / "vae" / "model.onnx",
@ -328,12 +368,14 @@ def convert_diffuser(name: str, url: str, opset: int, half: bool, token: str, si
vae_sample_size = vae_encoder.config.sample_size vae_sample_size = vae_encoder.config.sample_size
# need to get the raw tensor output (sample) from the encoder # need to get the raw tensor output (sample) from the encoder
vae_encoder.forward = lambda sample, return_dict: vae_encoder.encode( vae_encoder.forward = lambda sample, return_dict: vae_encoder.encode(
sample, return_dict)[0].sample() sample, return_dict
)[0].sample()
onnx_export( onnx_export(
vae_encoder, vae_encoder,
model_args=( model_args=(
torch.randn(1, vae_in_channels, vae_sample_size, vae_sample_size).to( torch.randn(1, vae_in_channels, vae_sample_size, vae_sample_size).to(
device=training_device, dtype=dtype), device=training_device, dtype=dtype
),
False, False,
), ),
output_path=output_path / "vae_encoder" / "model.onnx", output_path=output_path / "vae_encoder" / "model.onnx",
@ -354,8 +396,9 @@ def convert_diffuser(name: str, url: str, opset: int, half: bool, token: str, si
onnx_export( onnx_export(
vae_decoder, vae_decoder,
model_args=( model_args=(
torch.randn(1, vae_latent_channels, unet_sample_size, unet_sample_size).to( torch.randn(
device=training_device, dtype=dtype), 1, vae_latent_channels, unet_sample_size, unet_sample_size
).to(device=training_device, dtype=dtype),
False, False,
), ),
output_path=output_path / "vae_decoder" / "model.onnx", output_path=output_path / "vae_decoder" / "model.onnx",
@ -385,7 +428,8 @@ def convert_diffuser(name: str, url: str, opset: int, half: bool, token: str, si
clip_image_size, clip_image_size,
).to(device=training_device, dtype=dtype), ).to(device=training_device, dtype=dtype),
torch.randn(1, vae_sample_size, vae_sample_size, vae_out_channels).to( torch.randn(1, vae_sample_size, vae_sample_size, vae_out_channels).to(
device=training_device, dtype=dtype), device=training_device, dtype=dtype
),
), ),
output_path=output_path / "safety_checker" / "model.onnx", output_path=output_path / "safety_checker" / "model.onnx",
ordered_input_names=["clip_input", "images"], ordered_input_names=["clip_input", "images"],
@ -398,7 +442,8 @@ def convert_diffuser(name: str, url: str, opset: int, half: bool, token: str, si
) )
del pipeline.safety_checker del pipeline.safety_checker
safety_checker = OnnxRuntimeModel.from_pretrained( safety_checker = OnnxRuntimeModel.from_pretrained(
output_path / "safety_checker") output_path / "safety_checker"
)
feature_extractor = pipeline.feature_extractor feature_extractor = pipeline.feature_extractor
else: else:
safety_checker = None safety_checker = None
@ -406,10 +451,8 @@ def convert_diffuser(name: str, url: str, opset: int, half: bool, token: str, si
if single_vae: if single_vae:
onnx_pipeline = StableDiffusionUpscalePipeline( onnx_pipeline = StableDiffusionUpscalePipeline(
vae=OnnxRuntimeModel.from_pretrained( vae=OnnxRuntimeModel.from_pretrained(output_path / "vae"),
output_path / "vae"), text_encoder=OnnxRuntimeModel.from_pretrained(output_path / "text_encoder"),
text_encoder=OnnxRuntimeModel.from_pretrained(
output_path / "text_encoder"),
tokenizer=pipeline.tokenizer, tokenizer=pipeline.tokenizer,
low_res_scheduler=pipeline.scheduler, low_res_scheduler=pipeline.scheduler,
unet=OnnxRuntimeModel.from_pretrained(output_path / "unet"), unet=OnnxRuntimeModel.from_pretrained(output_path / "unet"),
@ -417,12 +460,9 @@ def convert_diffuser(name: str, url: str, opset: int, half: bool, token: str, si
) )
else: else:
onnx_pipeline = OnnxStableDiffusionPipeline( onnx_pipeline = OnnxStableDiffusionPipeline(
vae_encoder=OnnxRuntimeModel.from_pretrained( vae_encoder=OnnxRuntimeModel.from_pretrained(output_path / "vae_encoder"),
output_path / "vae_encoder"), vae_decoder=OnnxRuntimeModel.from_pretrained(output_path / "vae_decoder"),
vae_decoder=OnnxRuntimeModel.from_pretrained( text_encoder=OnnxRuntimeModel.from_pretrained(output_path / "text_encoder"),
output_path / "vae_decoder"),
text_encoder=OnnxRuntimeModel.from_pretrained(
output_path / "text_encoder"),
tokenizer=pipeline.tokenizer, tokenizer=pipeline.tokenizer,
unet=OnnxRuntimeModel.from_pretrained(output_path / "unet"), unet=OnnxRuntimeModel.from_pretrained(output_path / "unet"),
scheduler=pipeline.scheduler, scheduler=pipeline.scheduler,
@ -431,7 +471,7 @@ def convert_diffuser(name: str, url: str, opset: int, half: bool, token: str, si
requires_safety_checker=safety_checker is not None, requires_safety_checker=safety_checker is not None,
) )
logger.info('exporting ONNX model') logger.info("exporting ONNX model")
onnx_pipeline.save_pretrained(output_path) onnx_pipeline.save_pretrained(output_path)
logger.info("ONNX pipeline saved to %s", output_path) logger.info("ONNX pipeline saved to %s", output_path)
@ -445,90 +485,93 @@ def convert_diffuser(name: str, url: str, opset: int, half: bool, token: str, si
) )
else: else:
_ = OnnxStableDiffusionPipeline.from_pretrained( _ = OnnxStableDiffusionPipeline.from_pretrained(
output_path, provider="CPUExecutionProvider") output_path, provider="CPUExecutionProvider"
)
logger.info("ONNX pipeline is loadable") logger.info("ONNX pipeline is loadable")
def load_models(args, models: Models): def load_models(args, models: Models):
if args.diffusion: if args.diffusion:
for source in models.get('diffusion'): for source in models.get("diffusion"):
if source[0] in args.skip: if source[0] in args.skip:
logger.info('Skipping model: %s', source[0]) logger.info("Skipping model: %s", source[0])
else: else:
single_vae = 'upscaling' in source[0] single_vae = "upscaling" in source[0]
convert_diffuser(*source, args.opset, args.half, args.token, single_vae=single_vae) convert_diffuser(
*source, args.opset, args.half, args.token, single_vae=single_vae
)
if args.upscaling: if args.upscaling:
for source in models.get('upscaling'): for source in models.get("upscaling"):
if source[0] in args.skip: if source[0] in args.skip:
logger.info('Skipping model: %s', source[0]) logger.info("Skipping model: %s", source[0])
else: else:
convert_real_esrgan(*source, args.opset) convert_real_esrgan(*source, args.opset)
if args.correction: if args.correction:
for source in models.get('correction'): for source in models.get("correction"):
if source[0] in args.skip: if source[0] in args.skip:
logger.info('Skipping model: %s', source[0]) logger.info("Skipping model: %s", source[0])
else: else:
convert_gfpgan(*source, args.opset) convert_gfpgan(*source, args.opset)
def main() -> int: def main() -> int:
parser = ArgumentParser( parser = ArgumentParser(
prog='onnx-web model converter', prog="onnx-web model converter", description="convert checkpoint models to ONNX"
description='convert checkpoint models to ONNX') )
# model groups # model groups
parser.add_argument('--correction', action='store_true', default=False) parser.add_argument("--correction", action="store_true", default=False)
parser.add_argument('--diffusion', action='store_true', default=False) parser.add_argument("--diffusion", action="store_true", default=False)
parser.add_argument('--upscaling', action='store_true', default=False) parser.add_argument("--upscaling", action="store_true", default=False)
# extra models # extra models
parser.add_argument('--extras', nargs='*', type=str, default=[]) parser.add_argument("--extras", nargs="*", type=str, default=[])
parser.add_argument('--skip', nargs='*', type=str, default=[]) parser.add_argument("--skip", nargs="*", type=str, default=[])
# export options # export options
parser.add_argument( parser.add_argument(
'--half', "--half",
action='store_true', action="store_true",
default=False, default=False,
help='Export models for half precision, faster on some Nvidia cards.' help="Export models for half precision, faster on some Nvidia cards.",
) )
parser.add_argument( parser.add_argument(
'--opset', "--opset",
default=14, default=14,
type=int, type=int,
help="The version of the ONNX operator set to use.", help="The version of the ONNX operator set to use.",
) )
parser.add_argument( parser.add_argument(
'--token', "--token",
type=str, type=str,
help="HuggingFace token with read permissions for downloading models.", help="HuggingFace token with read permissions for downloading models.",
) )
args = parser.parse_args() args = parser.parse_args()
logger.info('CLI arguments: %s', args) logger.info("CLI arguments: %s", args)
if not path.exists(model_path): if not path.exists(model_path):
logger.info('Model path does not existing, creating: %s', model_path) logger.info("Model path does not existing, creating: %s", model_path)
makedirs(model_path) makedirs(model_path)
logger.info('Converting base models.') logger.info("Converting base models.")
load_models(args, base_models) load_models(args, base_models)
for file in args.extras: for file in args.extras:
logger.info('Loading extra models from %s', file) logger.info("Loading extra models from %s", file)
try: try:
with open(file, 'r') as f: with open(file, "r") as f:
data = loads(f.read()) data = loads(f.read())
logger.info('Converting extra models.') logger.info("Converting extra models.")
load_models(args, data) load_models(args, data)
except Exception as err: except Exception as err:
logger.error('Error converting extra models: %s', err) logger.error("Error converting extra models: %s", err)
return 0 return 0
if __name__ == '__main__': if __name__ == "__main__":
exit(main()) exit(main())

View File

@ -1,12 +1,10 @@
from collections import Counter from collections import Counter
from concurrent.futures import Future, ThreadPoolExecutor, ProcessPoolExecutor from concurrent.futures import Future, ProcessPoolExecutor, ThreadPoolExecutor
from logging import getLogger from logging import getLogger
from multiprocessing import Value from multiprocessing import Value
from typing import Any, Callable, List, Optional, Tuple, Union from typing import Any, Callable, List, Optional, Tuple, Union
from .params import ( from .params import DeviceParams
DeviceParams,
)
logger = getLogger(__name__) logger = getLogger(__name__)
@ -28,24 +26,24 @@ class JobContext:
): ):
self.key = key self.key = key
self.devices = list(devices) self.devices = list(devices)
self.cancel = Value('B', cancel) self.cancel = Value("B", cancel)
self.device_index = Value('i', device_index) self.device_index = Value("i", device_index)
self.progress = Value('I', progress) self.progress = Value("I", progress)
def is_cancelled(self) -> bool: def is_cancelled(self) -> bool:
return self.cancel.value return self.cancel.value
def get_device(self) -> DeviceParams: def get_device(self) -> DeviceParams:
''' """
Get the device assigned to this job. Get the device assigned to this job.
''' """
with self.device_index.get_lock(): with self.device_index.get_lock():
device_index = self.device_index.value device_index = self.device_index.value
if device_index < 0: if device_index < 0:
raise Exception('job has not been assigned to a device') raise Exception("job has not been assigned to a device")
else: else:
device = self.devices[device_index] device = self.devices[device_index]
logger.debug('job %s assigned to device %s', self.key, device) logger.debug("job %s assigned to device %s", self.key, device)
return device return device
def get_progress(self) -> int: def get_progress(self) -> int:
@ -54,10 +52,9 @@ class JobContext:
def get_progress_callback(self) -> Callable[..., None]: def get_progress_callback(self) -> Callable[..., None]:
def on_progress(step: int, timestep: int, latents: Any): def on_progress(step: int, timestep: int, latents: Any):
if self.is_cancelled(): if self.is_cancelled():
raise Exception('job has been cancelled') raise Exception("job has been cancelled")
else: else:
logger.debug('setting progress for job %s to %s', logger.debug("setting progress for job %s to %s", self.key, step)
self.key, step)
self.set_progress(step) self.set_progress(step)
return on_progress return on_progress
@ -72,9 +69,9 @@ class JobContext:
class Job: class Job:
''' """
Link a future to its context. Link a future to its context.
''' """
context: JobContext = None context: JobContext = None
future: Future = None future: Future = None
@ -106,7 +103,11 @@ class DevicePoolExecutor:
next_device: int = 0 next_device: int = 0
pool: Union[ProcessPoolExecutor, ThreadPoolExecutor] = None pool: Union[ProcessPoolExecutor, ThreadPoolExecutor] = None
def __init__(self, devices: List[DeviceParams], pool: Optional[Union[ProcessPoolExecutor, ThreadPoolExecutor]] = None): def __init__(
self,
devices: List[DeviceParams],
pool: Optional[Union[ProcessPoolExecutor, ThreadPoolExecutor]] = None,
):
self.devices = devices self.devices = devices
self.jobs = [] self.jobs = []
self.next_device = 0 self.next_device = 0
@ -114,19 +115,25 @@ class DevicePoolExecutor:
device_count = len(devices) device_count = len(devices)
if pool is None: if pool is None:
logger.info( logger.info(
'creating thread pool executor for %s devices: %s', device_count, [d.device for d in devices]) "creating thread pool executor for %s devices: %s",
device_count,
[d.device for d in devices],
)
self.pool = ThreadPoolExecutor(device_count) self.pool = ThreadPoolExecutor(device_count)
else: else:
logger.info('using existing pool for %s devices: %s', logger.info(
device_count, [d.device for d in devices]) "using existing pool for %s devices: %s",
device_count,
[d.device for d in devices],
)
self.pool = pool self.pool = pool
def cancel(self, key: str) -> bool: def cancel(self, key: str) -> bool:
''' """
Cancel a job. If the job has not been started, this will cancel Cancel a job. If the job has not been started, this will cancel
the future and never execute it. If the job has been started, it the future and never execute it. If the job has been started, it
should be cancelled on the next progress callback. should be cancelled on the next progress callback.
''' """
for job in self.jobs: for job in self.jobs:
if job.key == key: if job.key == key:
if job.future.cancel(): if job.future.cancel():
@ -144,7 +151,7 @@ class DevicePoolExecutor:
progress = job.get_progress() progress = job.get_progress()
return (done, progress) return (done, progress)
logger.warn('checking status for unknown key: %s', key) logger.warn("checking status for unknown key: %s", key)
return (None, 0) return (None, 0)
def get_next_device(self): def get_next_device(self):
@ -152,12 +159,14 @@ class DevicePoolExecutor:
if len(self.jobs) == 0: if len(self.jobs) == 0:
return 0 return 0
job_devices = [job.context.device_index.value for job in self.jobs if not job.future.done()] job_devices = [
job.context.device_index.value for job in self.jobs if not job.future.done()
]
job_counts = Counter(range(len(self.devices))) job_counts = Counter(range(len(self.devices)))
job_counts.update(job_devices) job_counts.update(job_devices)
queued = job_counts.most_common() queued = job_counts.most_common()
logger.debug('jobs queued by device: %s', queued) logger.debug("jobs queued by device: %s", queued)
lowest_count = queued[-1][1] lowest_count = queued[-1][1]
lowest_devices = [d[0] for d in queued if d[1] == lowest_count] lowest_devices = [d[0] for d in queued if d[1] == lowest_count]
@ -170,7 +179,7 @@ class DevicePoolExecutor:
def submit(self, key: str, fn: Callable[..., None], /, *args, **kwargs) -> None: def submit(self, key: str, fn: Callable[..., None], /, *args, **kwargs) -> None:
device = self.get_next_device() device = self.get_next_device()
logger.info('assigning job %s to device %s', key, device) logger.info("assigning job %s to device %s", key, device)
context = JobContext(key, self.devices, device_index=device) context = JobContext(key, self.devices, device_index=device)
future = self.pool.submit(fn, context, *args, **kwargs) future = self.pool.submit(fn, context, *args, **kwargs)
@ -180,11 +189,19 @@ class DevicePoolExecutor:
def job_done(f: Future): def job_done(f: Future):
try: try:
f.result() f.result()
logger.info('job %s finished successfully', key) logger.info("job %s finished successfully", key)
except Exception as err: except Exception as err:
logger.warn('job %s failed with an error: %s', key, err) logger.warn("job %s failed with an error: %s", key, err)
future.add_done_callback(job_done) future.add_done_callback(job_done)
def status(self) -> List[Tuple[str, int, bool, int]]: def status(self) -> List[Tuple[str, int, bool, int]]:
return [(job.key, job.context.device_index.value, job.future.done(), job.get_progress()) for job in self.jobs] return [
(
job.key,
job.context.device_index.value,
job.future.done(),
job.get_progress(),
)
for job in self.jobs
]

View File

@ -1,18 +1,11 @@
from diffusers import (
DiffusionPipeline,
)
from logging import getLogger from logging import getLogger
from typing import Any, Optional, Tuple from typing import Any, Tuple
from ..params import (
DeviceParams,
Size,
)
from ..utils import (
run_gc,
)
import numpy as np import numpy as np
from diffusers import DiffusionPipeline
from ..params import DeviceParams, Size
from ..utils import run_gc
logger = getLogger(__name__) logger = getLogger(__name__)
@ -25,17 +18,23 @@ latent_factor = 8
def get_latents_from_seed(seed: int, size: Size, batch: int = 1) -> np.ndarray: def get_latents_from_seed(seed: int, size: Size, batch: int = 1) -> np.ndarray:
''' """
From https://www.travelneil.com/stable-diffusion-updates.html From https://www.travelneil.com/stable-diffusion-updates.html
''' """
latents_shape = (batch, latent_channels, size.height // latent_factor, latents_shape = (
size.width // latent_factor) batch,
latent_channels,
size.height // latent_factor,
size.width // latent_factor,
)
rng = np.random.default_rng(seed) rng = np.random.default_rng(seed)
image_latents = rng.standard_normal(latents_shape).astype(np.float32) image_latents = rng.standard_normal(latents_shape).astype(np.float32)
return image_latents return image_latents
def get_tile_latents(full_latents: np.ndarray, dims: Tuple[int, int, int]) -> np.ndarray: def get_tile_latents(
full_latents: np.ndarray, dims: Tuple[int, int, int]
) -> np.ndarray:
x, y, tile = dims x, y, tile = dims
t = tile // latent_factor t = tile // latent_factor
x = x // latent_factor x = x // latent_factor
@ -46,27 +45,29 @@ def get_tile_latents(full_latents: np.ndarray, dims: Tuple[int, int, int]) -> np
return full_latents[:, :, y:yt, x:xt] return full_latents[:, :, y:yt, x:xt]
def load_pipeline(pipeline: DiffusionPipeline, model: str, scheduler: Any, device: DeviceParams): def load_pipeline(
pipeline: DiffusionPipeline, model: str, scheduler: Any, device: DeviceParams
):
global last_pipeline_instance global last_pipeline_instance
global last_pipeline_scheduler global last_pipeline_scheduler
global last_pipeline_options global last_pipeline_options
options = (pipeline, model, device.provider) options = (pipeline, model, device.provider)
if last_pipeline_instance != None and last_pipeline_options == options: if last_pipeline_instance is not None and last_pipeline_options == options:
logger.debug('reusing existing diffusion pipeline') logger.debug("reusing existing diffusion pipeline")
pipe = last_pipeline_instance pipe = last_pipeline_instance
else: else:
logger.debug('unloading previous diffusion pipeline') logger.debug("unloading previous diffusion pipeline")
last_pipeline_instance = None last_pipeline_instance = None
last_pipeline_scheduler = None last_pipeline_scheduler = None
run_gc() run_gc()
logger.debug('loading new diffusion pipeline from %s', model) logger.debug("loading new diffusion pipeline from %s", model)
scheduler = scheduler.from_pretrained( scheduler = scheduler.from_pretrained(
model, model,
provider=device.provider, provider=device.provider,
provider_options=device.options, provider_options=device.options,
subfolder='scheduler', subfolder="scheduler",
) )
pipe = pipeline.from_pretrained( pipe = pipeline.from_pretrained(
model, model,
@ -76,7 +77,7 @@ def load_pipeline(pipeline: DiffusionPipeline, model: str, scheduler: Any, devic
scheduler=scheduler, scheduler=scheduler,
) )
if device is not None and hasattr(pipe, 'to'): if device is not None and hasattr(pipe, "to"):
pipe = pipe.to(device) pipe = pipe.to(device)
last_pipeline_instance = pipe last_pipeline_instance = pipe
@ -84,15 +85,15 @@ def load_pipeline(pipeline: DiffusionPipeline, model: str, scheduler: Any, devic
last_pipeline_scheduler = scheduler last_pipeline_scheduler = scheduler
if last_pipeline_scheduler != scheduler: if last_pipeline_scheduler != scheduler:
logger.debug('loading new diffusion scheduler') logger.debug("loading new diffusion scheduler")
scheduler = scheduler.from_pretrained( scheduler = scheduler.from_pretrained(
model, model,
provider=device.provider, provider=device.provider,
provider_options=device.options, provider_options=device.options,
subfolder='scheduler', subfolder="scheduler",
) )
if device is not None and hasattr(scheduler, 'to'): if device is not None and hasattr(scheduler, "to"):
scheduler = scheduler.to(device) scheduler = scheduler.to(device)
pipe.scheduler = scheduler pipe.scheduler = scheduler

View File

@ -1,23 +1,11 @@
from diffusers import (
DDPMScheduler,
OnnxRuntimeModel,
StableDiffusionUpscalePipeline,
)
from diffusers.pipeline_utils import (
ImagePipelineOutput,
)
from logging import getLogger from logging import getLogger
from PIL import Image from typing import Any, Callable, List, Optional, Union
from typing import (
Any,
Callable,
List,
Optional,
Union,
)
import numpy as np import numpy as np
import torch import torch
from diffusers import DDPMScheduler, OnnxRuntimeModel, StableDiffusionUpscalePipeline
from diffusers.pipeline_utils import ImagePipelineOutput
from PIL import Image
logger = getLogger(__name__) logger = getLogger(__name__)
@ -30,6 +18,7 @@ unet_in_channels = 7 # self.unet.config.in_channels
# https://github.com/huggingface/diffusers/blob/v0.11.1/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py # https://github.com/huggingface/diffusers/blob/v0.11.1/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py
### ###
def preprocess(image): def preprocess(image):
if isinstance(image, torch.Tensor): if isinstance(image, torch.Tensor):
return image return image
@ -63,8 +52,15 @@ class OnnxStableDiffusionUpscalePipeline(StableDiffusionUpscalePipeline):
scheduler: Any, scheduler: Any,
max_noise_level: int = 350, max_noise_level: int = 350,
): ):
super().__init__(vae, text_encoder, tokenizer, unet, super().__init__(
low_res_scheduler, scheduler, max_noise_level) vae,
text_encoder,
tokenizer,
unet,
low_res_scheduler,
scheduler,
max_noise_level,
)
def __call__( def __call__(
self, self,
@ -96,7 +92,11 @@ class OnnxStableDiffusionUpscalePipeline(StableDiffusionUpscalePipeline):
# 3. Encode input prompt # 3. Encode input prompt
text_embeddings = self._encode_prompt( text_embeddings = self._encode_prompt(
prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt prompt,
device,
num_images_per_prompt,
do_classifier_free_guidance,
negative_prompt,
) )
# 4. Preprocess image # 4. Preprocess image
@ -111,7 +111,9 @@ class OnnxStableDiffusionUpscalePipeline(StableDiffusionUpscalePipeline):
text_embeddings_dtype = torch.float32 text_embeddings_dtype = torch.float32
noise_level = torch.tensor([noise_level], dtype=torch.long, device=device) noise_level = torch.tensor([noise_level], dtype=torch.long, device=device)
noise = torch.randn(image.shape, generator=generator, device=device, dtype=text_embeddings_dtype) noise = torch.randn(
image.shape, generator=generator, device=device, dtype=text_embeddings_dtype
)
image = self.low_res_scheduler.add_noise(image, noise, noise_level) image = self.low_res_scheduler.add_noise(image, noise, noise_level)
batch_multiplier = 2 if do_classifier_free_guidance else 1 batch_multiplier = 2 if do_classifier_free_guidance else 1
@ -135,11 +137,11 @@ class OnnxStableDiffusionUpscalePipeline(StableDiffusionUpscalePipeline):
num_channels_image = image.shape[1] num_channels_image = image.shape[1]
if num_channels_latents + num_channels_image != unet_in_channels: if num_channels_latents + num_channels_image != unet_in_channels:
raise ValueError( raise ValueError(
f"Incorrect configuration settings! The config of `pipeline.unet` expects" "Incorrect configuration settings! The config of `pipeline.unet`"
f" {unet_in_channels} but received `num_channels_latents`: {num_channels_latents} +" f" expects {unet_in_channels} but received `num_channels_latents`:"
f" `num_channels_image`: {num_channels_image} " f" {num_channels_latents} + `num_channels_image`: {num_channels_image} "
f" = {num_channels_latents+num_channels_image}. Please verify the config of" f" = {num_channels_latents+num_channels_image}. Please verify the"
" `pipeline.unet` or your `image` input." " config of `pipeline.unet` or your `image` input."
) )
# 8. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline # 8. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
@ -150,10 +152,16 @@ class OnnxStableDiffusionUpscalePipeline(StableDiffusionUpscalePipeline):
with self.progress_bar(total=num_inference_steps) as progress_bar: with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps): for i, t in enumerate(timesteps):
# expand the latents if we are doing classifier free guidance # expand the latents if we are doing classifier free guidance
latent_model_input = np.concatenate([latents] * 2) if do_classifier_free_guidance else latents latent_model_input = (
np.concatenate([latents] * 2)
if do_classifier_free_guidance
else latents
)
# concat latents, mask, masked_image_latents in the channel dimension # concat latents, mask, masked_image_latents in the channel dimension
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) latent_model_input = self.scheduler.scale_model_input(
latent_model_input, t
)
latent_model_input = np.concatenate([latent_model_input, image], axis=1) latent_model_input = np.concatenate([latent_model_input, image], axis=1)
# timestep to tensor # timestep to tensor
@ -164,19 +172,25 @@ class OnnxStableDiffusionUpscalePipeline(StableDiffusionUpscalePipeline):
sample=latent_model_input, sample=latent_model_input,
timestep=timestep, timestep=timestep,
encoder_hidden_states=text_embeddings, encoder_hidden_states=text_embeddings,
class_labels=noise_level class_labels=noise_level,
)[0] )[0]
# perform guidance # perform guidance
if do_classifier_free_guidance: if do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = np.split(noise_pred, 2) noise_pred_uncond, noise_pred_text = np.split(noise_pred, 2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) noise_pred = noise_pred_uncond + guidance_scale * (
noise_pred_text - noise_pred_uncond
)
# compute the previous noisy sample x_t -> x_t-1 # compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample latents = self.scheduler.step(
noise_pred, t, latents, **extra_step_kwargs
).prev_sample
# call the callback, if provided # call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): if i == len(timesteps) - 1 or (
(i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
):
progress_bar.update() progress_bar.update()
if callback is not None and i % callback_steps == 0: if callback is not None and i % callback_steps == 0:
callback(i, t, latents) callback(i, t, latents)
@ -200,7 +214,14 @@ class OnnxStableDiffusionUpscalePipeline(StableDiffusionUpscalePipeline):
image = image.transpose((0, 2, 3, 1)) image = image.transpose((0, 2, 3, 1))
return image return image
def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt): def _encode_prompt(
self,
prompt,
device,
num_images_per_prompt,
do_classifier_free_guidance,
negative_prompt,
):
batch_size = len(prompt) if isinstance(prompt, list) else 1 batch_size = len(prompt) if isinstance(prompt, list) else 1
text_inputs = self.tokenizer( text_inputs = self.tokenizer(
@ -211,13 +232,20 @@ class OnnxStableDiffusionUpscalePipeline(StableDiffusionUpscalePipeline):
return_tensors="pt", return_tensors="pt",
) )
text_input_ids = text_inputs.input_ids text_input_ids = text_inputs.input_ids
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids untruncated_ids = self.tokenizer(
prompt, padding="longest", return_tensors="pt"
).input_ids
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]) text_input_ids, untruncated_ids
):
removed_text = self.tokenizer.batch_decode(
untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
)
logger.warning( logger.warning(
"The following part of your input was truncated because CLIP can only handle sequences up to" "The following part of your input was truncated because CLIP can only"
f" {self.tokenizer.model_max_length} tokens: {removed_text}" f" handle sequences up to {self.tokenizer.model_max_length} tokens:"
f" {removed_text}"
) )
# no positional arguments to text_encoder # no positional arguments to text_encoder
@ -240,16 +268,17 @@ class OnnxStableDiffusionUpscalePipeline(StableDiffusionUpscalePipeline):
uncond_tokens = [""] * batch_size uncond_tokens = [""] * batch_size
elif type(prompt) is not type(negative_prompt): elif type(prompt) is not type(negative_prompt):
raise TypeError( raise TypeError(
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" "`negative_prompt` should be the same type to `prompt`, but got"
f" {type(prompt)}." f" {type(negative_prompt)} != {type(prompt)}."
) )
elif isinstance(negative_prompt, str): elif isinstance(negative_prompt, str):
uncond_tokens = [negative_prompt] uncond_tokens = [negative_prompt]
elif batch_size != len(negative_prompt): elif batch_size != len(negative_prompt):
raise ValueError( raise ValueError(
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" f"`negative_prompt`: {negative_prompt} has batch size"
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" f" {len(negative_prompt)}, but `prompt`: {prompt} has batch size"
" the batch size of `prompt`." f" {batch_size}. Please make sure that passed `negative_prompt`"
" matches the batch size of `prompt`."
) )
else: else:
uncond_tokens = negative_prompt uncond_tokens = negative_prompt

View File

@ -1,41 +1,17 @@
from diffusers import (
OnnxStableDiffusionPipeline,
OnnxStableDiffusionImg2ImgPipeline,
)
from logging import getLogger from logging import getLogger
from PIL import Image, ImageChops
from typing import Any from typing import Any
from ..chain import (
upscale_outpaint,
)
from ..device_pool import (
JobContext,
)
from ..params import (
ImageParams,
Border,
Size,
StageParams,
)
from ..output import (
save_image,
save_params,
)
from ..upscale import (
run_upscale_correction,
UpscaleParams,
)
from ..utils import (
run_gc,
ServerContext,
)
from .load import (
get_latents_from_seed,
load_pipeline,
)
import numpy as np import numpy as np
from diffusers import OnnxStableDiffusionImg2ImgPipeline, OnnxStableDiffusionPipeline
from PIL import Image, ImageChops
from ..chain import upscale_outpaint
from ..device_pool import JobContext
from ..output import save_image, save_params
from ..params import Border, ImageParams, Size, StageParams
from ..upscale import UpscaleParams, run_upscale_correction
from ..utils import ServerContext, run_gc
from .load import get_latents_from_seed, load_pipeline
logger = getLogger(__name__) logger = getLogger(__name__)
@ -46,10 +22,11 @@ def run_txt2img_pipeline(
params: ImageParams, params: ImageParams,
size: Size, size: Size,
output: str, output: str,
upscale: UpscaleParams upscale: UpscaleParams,
) -> None: ) -> None:
pipe = load_pipeline(OnnxStableDiffusionPipeline, pipe = load_pipeline(
params.model, params.scheduler, job.get_device()) OnnxStableDiffusionPipeline, params.model, params.scheduler, job.get_device()
)
latents = get_latents_from_seed(params.seed, size) latents = get_latents_from_seed(params.seed, size)
rng = np.random.RandomState(params.seed) rng = np.random.RandomState(params.seed)
@ -68,7 +45,8 @@ def run_txt2img_pipeline(
) )
image = result.images[0] image = result.images[0]
image = run_upscale_correction( image = run_upscale_correction(
job, server, StageParams(), params, image, upscale=upscale) job, server, StageParams(), params, image, upscale=upscale
)
dest = save_image(server, output, image) dest = save_image(server, output, image)
save_params(server, output, params, size, upscale=upscale) save_params(server, output, params, size, upscale=upscale)
@ -77,7 +55,7 @@ def run_txt2img_pipeline(
del result del result
run_gc() run_gc()
logger.info('finished txt2img job: %s', dest) logger.info("finished txt2img job: %s", dest)
def run_img2img_pipeline( def run_img2img_pipeline(
@ -89,8 +67,12 @@ def run_img2img_pipeline(
source_image: Image.Image, source_image: Image.Image,
strength: float, strength: float,
) -> None: ) -> None:
pipe = load_pipeline(OnnxStableDiffusionImg2ImgPipeline, pipe = load_pipeline(
params.model, params.scheduler, job.get_device()) OnnxStableDiffusionImg2ImgPipeline,
params.model,
params.scheduler,
job.get_device(),
)
rng = np.random.RandomState(params.seed) rng = np.random.RandomState(params.seed)
@ -107,7 +89,8 @@ def run_img2img_pipeline(
) )
image = result.images[0] image = result.images[0]
image = run_upscale_correction( image = run_upscale_correction(
job, server, StageParams(), params, image, upscale=upscale) job, server, StageParams(), params, image, upscale=upscale
)
dest = save_image(server, output, image) dest = save_image(server, output, image)
size = Size(*source_image.size) size = Size(*source_image.size)
@ -117,7 +100,7 @@ def run_img2img_pipeline(
del result del result
run_gc() run_gc()
logger.info('finished img2img job: %s', dest) logger.info("finished img2img job: %s", dest)
def run_inpaint_pipeline( def run_inpaint_pipeline(
@ -151,16 +134,14 @@ def run_inpaint_pipeline(
mask_filter=mask_filter, mask_filter=mask_filter,
noise_source=noise_source, noise_source=noise_source,
) )
logger.info('applying mask filter and generating noise source') logger.info("applying mask filter and generating noise source")
if image.size == source_image.size: if image.size == source_image.size:
image = ImageChops.blend(source_image, image, strength) image = ImageChops.blend(source_image, image, strength)
else: else:
logger.info( logger.info("output image size does not match source, skipping post-blend")
'output image size does not match source, skipping post-blend')
image = run_upscale_correction( image = run_upscale_correction(job, server, stage, params, image, upscale=upscale)
job, server, stage, params, image, upscale=upscale)
dest = save_image(server, output, image) dest = save_image(server, output, image)
save_params(server, output, params, size, upscale=upscale, border=border) save_params(server, output, params, size, upscale=upscale, border=border)
@ -168,7 +149,7 @@ def run_inpaint_pipeline(
del image del image
run_gc() run_gc()
logger.info('finished inpaint job: %s', dest) logger.info("finished inpaint job: %s", dest)
def run_upscale_pipeline( def run_upscale_pipeline(
@ -185,7 +166,8 @@ def run_upscale_pipeline(
stage = StageParams() stage = StageParams()
image = run_upscale_correction( image = run_upscale_correction(
job, server, stage, params, source_image, upscale=upscale) job, server, stage, params, source_image, upscale=upscale
)
dest = save_image(server, output, image) dest = save_image(server, output, image)
save_params(server, output, params, size, upscale=upscale) save_params(server, output, params, size, upscale=upscale)
@ -193,4 +175,4 @@ def run_upscale_pipeline(
del image del image
run_gc() run_gc()
logger.info('finished upscale job: %s', dest) logger.info("finished upscale job: %s", dest)

View File

@ -1,31 +1,31 @@
import numpy as np
from numpy import random from numpy import random
from PIL import Image, ImageChops, ImageFilter from PIL import Image, ImageChops, ImageFilter
import numpy as np from .params import Border, Point
from .params import (
Border,
Point,
)
def get_pixel_index(x: int, y: int, width: int) -> int: def get_pixel_index(x: int, y: int, width: int) -> int:
return (y * width) + x return (y * width) + x
def mask_filter_none(mask_image: Image.Image, dims: Point, origin: Point, fill='white', **kw) -> Image.Image: def mask_filter_none(
mask_image: Image.Image, dims: Point, origin: Point, fill="white", **kw
) -> Image.Image:
width, height = dims width, height = dims
noise = Image.new('RGB', (width, height), fill) noise = Image.new("RGB", (width, height), fill)
noise.paste(mask_image, origin) noise.paste(mask_image, origin)
return noise return noise
def mask_filter_gaussian_multiply(mask_image: Image.Image, dims: Point, origin: Point, rounds=3, **kw) -> Image.Image: def mask_filter_gaussian_multiply(
''' mask_image: Image.Image, dims: Point, origin: Point, rounds=3, **kw
) -> Image.Image:
"""
Gaussian blur with multiply, source image centered on white canvas. Gaussian blur with multiply, source image centered on white canvas.
''' """
noise = mask_filter_none(mask_image, dims, origin) noise = mask_filter_none(mask_image, dims, origin)
for i in range(rounds): for i in range(rounds):
@ -35,10 +35,12 @@ def mask_filter_gaussian_multiply(mask_image: Image.Image, dims: Point, origin:
return noise return noise
def mask_filter_gaussian_screen(mask_image: Image.Image, dims: Point, origin: Point, rounds=3, **kw) -> Image.Image: def mask_filter_gaussian_screen(
''' mask_image: Image.Image, dims: Point, origin: Point, rounds=3, **kw
) -> Image.Image:
"""
Gaussian blur, source image centered on white canvas. Gaussian blur, source image centered on white canvas.
''' """
noise = mask_filter_none(mask_image, dims, origin) noise = mask_filter_none(mask_image, dims, origin)
for i in range(rounds): for i in range(rounds):
@ -48,33 +50,39 @@ def mask_filter_gaussian_screen(mask_image: Image.Image, dims: Point, origin: Po
return noise return noise
def noise_source_fill_edge(source_image: Image.Image, dims: Point, origin: Point, fill='white', **kw) -> Image.Image: def noise_source_fill_edge(
''' source_image: Image.Image, dims: Point, origin: Point, fill="white", **kw
) -> Image.Image:
"""
Identity transform, source image centered on white canvas. Identity transform, source image centered on white canvas.
''' """
width, height = dims width, height = dims
noise = Image.new('RGB', (width, height), fill) noise = Image.new("RGB", (width, height), fill)
noise.paste(source_image, origin) noise.paste(source_image, origin)
return noise return noise
def noise_source_fill_mask(source_image: Image.Image, dims: Point, origin: Point, fill='white', **kw) -> Image.Image: def noise_source_fill_mask(
''' source_image: Image.Image, dims: Point, origin: Point, fill="white", **kw
) -> Image.Image:
"""
Fill the whole canvas, no source or noise. Fill the whole canvas, no source or noise.
''' """
width, height = dims width, height = dims
noise = Image.new('RGB', (width, height), fill) noise = Image.new("RGB", (width, height), fill)
return noise return noise
def noise_source_gaussian(source_image: Image.Image, dims: Point, origin: Point, rounds=3, **kw) -> Image.Image: def noise_source_gaussian(
''' source_image: Image.Image, dims: Point, origin: Point, rounds=3, **kw
) -> Image.Image:
"""
Gaussian blur, source image centered on white canvas. Gaussian blur, source image centered on white canvas.
''' """
noise = noise_source_uniform(source_image, dims, origin) noise = noise_source_uniform(source_image, dims, origin)
noise.paste(source_image, origin) noise.paste(source_image, origin)
@ -84,7 +92,9 @@ def noise_source_gaussian(source_image: Image.Image, dims: Point, origin: Point,
return noise return noise
def noise_source_uniform(source_image: Image.Image, dims: Point, origin: Point, **kw) -> Image.Image: def noise_source_uniform(
source_image: Image.Image, dims: Point, origin: Point, **kw
) -> Image.Image:
width, height = dims width, height = dims
size = width * height size = width * height
@ -92,21 +102,19 @@ def noise_source_uniform(source_image: Image.Image, dims: Point, origin: Point,
noise_g = random.uniform(0, 256, size=size) noise_g = random.uniform(0, 256, size=size)
noise_b = random.uniform(0, 256, size=size) noise_b = random.uniform(0, 256, size=size)
noise = Image.new('RGB', (width, height)) noise = Image.new("RGB", (width, height))
for x in range(width): for x in range(width):
for y in range(height): for y in range(height):
i = get_pixel_index(x, y, width) i = get_pixel_index(x, y, width)
noise.putpixel((x, y), ( noise.putpixel((x, y), (int(noise_r[i]), int(noise_g[i]), int(noise_b[i])))
int(noise_r[i]),
int(noise_g[i]),
int(noise_b[i])
))
return noise return noise
def noise_source_normal(source_image: Image.Image, dims: Point, origin: Point, **kw) -> Image.Image: def noise_source_normal(
source_image: Image.Image, dims: Point, origin: Point, **kw
) -> Image.Image:
width, height = dims width, height = dims
size = width * height size = width * height
@ -114,21 +122,19 @@ def noise_source_normal(source_image: Image.Image, dims: Point, origin: Point, *
noise_g = random.normal(128, 32, size=size) noise_g = random.normal(128, 32, size=size)
noise_b = random.normal(128, 32, size=size) noise_b = random.normal(128, 32, size=size)
noise = Image.new('RGB', (width, height)) noise = Image.new("RGB", (width, height))
for x in range(width): for x in range(width):
for y in range(height): for y in range(height):
i = get_pixel_index(x, y, width) i = get_pixel_index(x, y, width)
noise.putpixel((x, y), ( noise.putpixel((x, y), (int(noise_r[i]), int(noise_g[i]), int(noise_b[i])))
int(noise_r[i]),
int(noise_g[i]),
int(noise_b[i])
))
return noise return noise
def noise_source_histogram(source_image: Image.Image, dims: Point, origin: Point, **kw) -> Image.Image: def noise_source_histogram(
source_image: Image.Image, dims: Point, origin: Point, **kw
) -> Image.Image:
r, g, b = source_image.split() r, g, b = source_image.split()
width, height = dims width, height = dims
size = width * height size = width * height
@ -137,23 +143,22 @@ def noise_source_histogram(source_image: Image.Image, dims: Point, origin: Point
hist_g = g.histogram() hist_g = g.histogram()
hist_b = b.histogram() hist_b = b.histogram()
noise_r = random.choice(256, p=np.divide( noise_r = random.choice(
np.copy(hist_r), np.sum(hist_r)), size=size) 256, p=np.divide(np.copy(hist_r), np.sum(hist_r)), size=size
noise_g = random.choice(256, p=np.divide( )
np.copy(hist_g), np.sum(hist_g)), size=size) noise_g = random.choice(
noise_b = random.choice(256, p=np.divide( 256, p=np.divide(np.copy(hist_g), np.sum(hist_g)), size=size
np.copy(hist_b), np.sum(hist_b)), size=size) )
noise_b = random.choice(
256, p=np.divide(np.copy(hist_b), np.sum(hist_b)), size=size
)
noise = Image.new('RGB', (width, height)) noise = Image.new("RGB", (width, height))
for x in range(width): for x in range(width):
for y in range(height): for y in range(height):
i = get_pixel_index(x, y, width) i = get_pixel_index(x, y, width)
noise.putpixel((x, y), ( noise.putpixel((x, y), (noise_r[i], noise_g[i], noise_b[i]))
noise_r[i],
noise_g[i],
noise_b[i]
))
return noise return noise
@ -163,7 +168,7 @@ def expand_image(
source_image: Image.Image, source_image: Image.Image,
mask_image: Image.Image, mask_image: Image.Image,
expand: Border, expand: Border,
fill='white', fill="white",
noise_source=noise_source_histogram, noise_source=noise_source_histogram,
mask_filter=mask_filter_none, mask_filter=mask_filter_none,
): ):
@ -173,14 +178,13 @@ def expand_image(
dims = (full_width, full_height) dims = (full_width, full_height)
origin = (expand.left, expand.top) origin = (expand.left, expand.top)
full_source = Image.new('RGB', dims, fill) full_source = Image.new("RGB", dims, fill)
full_source.paste(source_image, origin) full_source.paste(source_image, origin)
full_mask = mask_filter(mask_image, dims, origin, fill=fill) full_mask = mask_filter(mask_image, dims, origin, fill=fill)
full_noise = noise_source(source_image, dims, origin, fill=fill) full_noise = noise_source(source_image, dims, origin, fill=fill)
full_noise = ImageChops.multiply(full_noise, full_mask) full_noise = ImageChops.multiply(full_noise, full_mask)
full_source = Image.composite( full_source = Image.composite(full_noise, full_source, full_mask.convert("L"))
full_noise, full_source, full_mask.convert('L'))
return (full_source, full_mask, full_noise, (full_width, full_height)) return (full_source, full_mask, full_noise, (full_width, full_height))

View File

@ -1,14 +1,15 @@
from logging.config import dictConfig from logging.config import dictConfig
from os import environ, path from os import environ, path
from yaml import safe_load from yaml import safe_load
logging_path = environ.get('ONNX_WEB_LOGGING_PATH', './logging.yaml') logging_path = environ.get("ONNX_WEB_LOGGING_PATH", "./logging.yaml")
# setup logging config before anything else loads # setup logging config before anything else loads
try: try:
if path.exists(logging_path): if path.exists(logging_path):
with open(logging_path, 'r') as f: with open(logging_path, "r") as f:
config_logging = safe_load(f) config_logging = safe_load(f)
dictConfig(config_logging) dictConfig(config_logging)
except Exception as err: except Exception as err:
print('error loading logging config: %s' % (err)) print("error loading logging config: %s" % (err))

View File

@ -1,4 +1 @@
from .onnx_net import ( from .onnx_net import OnnxImage, OnnxNet
OnnxImage,
OnnxNet,
)

View File

@ -1,15 +1,14 @@
from onnxruntime import InferenceSession
from os import path from os import path
from typing import Any, Optional from typing import Any, Optional
import numpy as np import numpy as np
import torch import torch
from onnxruntime import InferenceSession
from ..utils import ( from ..utils import ServerContext
ServerContext,
)
class OnnxImage():
class OnnxImage:
def __init__(self, source) -> None: def __init__(self, source) -> None:
self.source = source self.source = source
self.data = self self.data = self
@ -38,28 +37,27 @@ class OnnxImage():
return np.shape(self.source) return np.shape(self.source)
class OnnxNet(): class OnnxNet:
''' """
Provides the RRDBNet interface using an ONNX session for DirectML acceleration. Provides the RRDBNet interface using an ONNX session for DirectML acceleration.
''' """
def __init__( def __init__(
self, self,
server: ServerContext, server: ServerContext,
model: str, model: str,
provider: str = 'DmlExecutionProvider', provider: str = "DmlExecutionProvider",
provider_options: Optional[dict] = None, provider_options: Optional[dict] = None,
) -> None: ) -> None:
model_path = path.join(server.model_path, model) model_path = path.join(server.model_path, model)
self.session = InferenceSession( self.session = InferenceSession(
model_path, providers=[provider], provider_options=provider_options) model_path, providers=[provider], provider_options=provider_options
)
def __call__(self, image: Any) -> Any: def __call__(self, image: Any) -> Any:
input_name = self.session.get_inputs()[0].name input_name = self.session.get_inputs()[0].name
output_name = self.session.get_outputs()[0].name output_name = self.session.get_outputs()[0].name
output = self.session.run([output_name], { output = self.session.run([output_name], {input_name: image.cpu().numpy()})[0]
input_name: image.cpu().numpy()
})[0]
return OnnxImage(output) return OnnxImage(output)
def eval(self) -> None: def eval(self) -> None:

View File

@ -1,22 +1,14 @@
from hashlib import sha256 from hashlib import sha256
from json import dumps from json import dumps
from logging import getLogger from logging import getLogger
from PIL import Image
from struct import pack from struct import pack
from time import time from time import time
from typing import Any, Optional, Tuple from typing import Any, Optional, Tuple
from .params import ( from PIL import Image
Border,
ImageParams, from .params import Border, ImageParams, Param, Size, UpscaleParams
Param, from .utils import ServerContext, base_join
Size,
UpscaleParams,
)
from .utils import (
base_join,
ServerContext,
)
logger = getLogger(__name__) logger = getLogger(__name__)
@ -25,13 +17,13 @@ def hash_value(sha, param: Param):
if param is None: if param is None:
return return
elif isinstance(param, float): elif isinstance(param, float):
sha.update(bytearray(pack('!f', param))) sha.update(bytearray(pack("!f", param)))
elif isinstance(param, int): elif isinstance(param, int):
sha.update(bytearray(pack('!I', param))) sha.update(bytearray(pack("!I", param)))
elif isinstance(param, str): elif isinstance(param, str):
sha.update(param.encode('utf-8')) sha.update(param.encode("utf-8"))
else: else:
logger.warn('cannot hash param: %s, %s', param, type(param)) logger.warn("cannot hash param: %s, %s", param, type(param))
def json_params( def json_params(
@ -42,22 +34,22 @@ def json_params(
border: Optional[Border] = None, border: Optional[Border] = None,
) -> Any: ) -> Any:
json = { json = {
'output': output, "output": output,
'params': params.tojson(), "params": params.tojson(),
} }
if upscale is not None and border is not None: if upscale is not None and border is not None:
size = upscale.resize(size.add_border(border)) size = upscale.resize(size.add_border(border))
if upscale is not None: if upscale is not None:
json['upscale'] = upscale.tojson() json["upscale"] = upscale.tojson()
size = upscale.resize(size) size = upscale.resize(size)
if border is not None: if border is not None:
json['border'] = border.tojson() json["border"] = border.tojson()
size = size.add_border(border) size = size.add_border(border)
json['size'] = size.tojson() json["size"] = size.tojson()
return json return json
@ -67,7 +59,7 @@ def make_output_name(
mode: str, mode: str,
params: ImageParams, params: ImageParams,
size: Size, size: Size,
extras: Optional[Tuple[Param]] = None extras: Optional[Tuple[Param]] = None,
) -> str: ) -> str:
now = int(time()) now = int(time())
sha = sha256() sha = sha256()
@ -87,13 +79,19 @@ def make_output_name(
for param in extras: for param in extras:
hash_value(sha, param) hash_value(sha, param)
return '%s_%s_%s_%s.%s' % (mode, params.seed, sha.hexdigest(), now, ctx.image_format) return "%s_%s_%s_%s.%s" % (
mode,
params.seed,
sha.hexdigest(),
now,
ctx.image_format,
)
def save_image(ctx: ServerContext, output: str, image: Image.Image) -> str: def save_image(ctx: ServerContext, output: str, image: Image.Image) -> str:
path = base_join(ctx.output_path, output) path = base_join(ctx.output_path, output)
image.save(path, format=ctx.image_format) image.save(path, format=ctx.image_format)
logger.debug('saved output image to: %s', path) logger.debug("saved output image to: %s", path)
return path return path
@ -105,9 +103,9 @@ def save_params(
upscale: Optional[UpscaleParams] = None, upscale: Optional[UpscaleParams] = None,
border: Optional[Border] = None, border: Optional[Border] = None,
) -> str: ) -> str:
path = base_join(ctx.output_path, '%s.json' % (output)) path = base_join(ctx.output_path, "%s.json" % (output))
json = json_params(output, params, size, upscale=upscale, border=border) json = json_params(output, params, size, upscale=upscale, border=border)
with open(path, 'w') as f: with open(path, "w") as f:
f.write(dumps(json)) f.write(dumps(json))
logger.debug('saved image params to: %s', path) logger.debug("saved image params to: %s", path)
return path return path

View File

@ -26,14 +26,14 @@ class Border:
self.bottom = bottom self.bottom = bottom
def __str__(self) -> str: def __str__(self) -> str:
return '%s %s %s %s' % (self.left, self.top, self.right, self.bottom) return "%s %s %s %s" % (self.left, self.top, self.right, self.bottom)
def tojson(self): def tojson(self):
return { return {
'left': self.left, "left": self.left,
'right': self.right, "right": self.right,
'top': self.top, "top": self.top,
'bottom': self.bottom, "bottom": self.bottom,
} }
@classmethod @classmethod
@ -47,32 +47,37 @@ class Size:
self.height = height self.height = height
def __str__(self) -> str: def __str__(self) -> str:
return '%sx%s' % (self.width, self.height) return "%sx%s" % (self.width, self.height)
def add_border(self, border: Border): def add_border(self, border: Border):
return Size(border.left + self.width + border.right, border.top + self.height + border.right) return Size(
border.left + self.width + border.right,
border.top + self.height + border.right,
)
def tojson(self) -> Dict[str, int]: def tojson(self) -> Dict[str, int]:
return { return {
'height': self.height, "height": self.height,
'width': self.width, "width": self.width,
} }
class DeviceParams: class DeviceParams:
def __init__(self, device: str, provider: str, options: Optional[dict] = None) -> None: def __init__(
self, device: str, provider: str, options: Optional[dict] = None
) -> None:
self.device = device self.device = device
self.provider = provider self.provider = provider
self.options = options self.options = options
def __str__(self) -> str: def __str__(self) -> str:
return '%s - %s (%s)' % (self.device, self.provider, self.options) return "%s - %s (%s)" % (self.device, self.provider, self.options)
def torch_device(self) -> str: def torch_device(self) -> str:
if self.device.startswith('cuda'): if self.device.startswith("cuda"):
return self.device return self.device
else: else:
return 'cpu' return "cpu"
class ImageParams: class ImageParams:
@ -84,7 +89,7 @@ class ImageParams:
negative_prompt: Optional[str], negative_prompt: Optional[str],
cfg: float, cfg: float,
steps: int, steps: int,
seed: int seed: int,
) -> None: ) -> None:
self.model = model self.model = model
self.scheduler = scheduler self.scheduler = scheduler
@ -96,20 +101,20 @@ class ImageParams:
def tojson(self) -> Dict[str, Optional[Param]]: def tojson(self) -> Dict[str, Optional[Param]]:
return { return {
'model': self.model, "model": self.model,
'scheduler': self.scheduler.__name__, "scheduler": self.scheduler.__name__,
'seed': self.seed, "seed": self.seed,
'prompt': self.prompt, "prompt": self.prompt,
'cfg': self.cfg, "cfg": self.cfg,
'negativePrompt': self.negative_prompt, "negativePrompt": self.negative_prompt,
'steps': self.steps, "steps": self.steps,
} }
class StageParams: class StageParams:
''' """
Parameters for a chained pipeline stage Parameters for a chained pipeline stage
''' """
def __init__( def __init__(
self, self,
@ -123,7 +128,7 @@ class StageParams:
self.outscale = outscale self.outscale = outscale
class UpscaleParams(): class UpscaleParams:
def __init__( def __init__(
self, self,
upscale_model: str, upscale_model: str,
@ -131,7 +136,7 @@ class UpscaleParams():
denoise: float = 0.5, denoise: float = 0.5,
faces=True, faces=True,
face_strength: float = 0.5, face_strength: float = 0.5,
format: Literal['onnx', 'pth'] = 'onnx', format: Literal["onnx", "pth"] = "onnx",
half=False, half=False,
outscale: int = 1, outscale: int = 1,
scale: int = 4, scale: int = 4,
@ -170,8 +175,8 @@ class UpscaleParams():
def tojson(self): def tojson(self):
return { return {
'model': self.upscale_model, "model": self.upscale_model,
'scale': self.scale, "scale": self.scale,
'outscale': self.outscale, "outscale": self.outscale,
# TODO: add more # TODO: add more
} }

View File

@ -1,62 +1,61 @@
from . import logging import gc
from functools import cmp_to_key
from glob import glob
from io import BytesIO
from logging import getLogger
from os import makedirs, path
from typing import List, Tuple
import numpy as np
import torch
import yaml
from diffusers import ( from diffusers import (
DDIMScheduler, DDIMScheduler,
DDPMScheduler, DDPMScheduler,
DPMSolverMultistepScheduler, DPMSolverMultistepScheduler,
DPMSolverSinglestepScheduler, DPMSolverSinglestepScheduler,
EulerDiscreteScheduler,
EulerAncestralDiscreteScheduler, EulerAncestralDiscreteScheduler,
EulerDiscreteScheduler,
HeunDiscreteScheduler, HeunDiscreteScheduler,
KarrasVeScheduler,
KDPM2AncestralDiscreteScheduler, KDPM2AncestralDiscreteScheduler,
KDPM2DiscreteScheduler, KDPM2DiscreteScheduler,
KarrasVeScheduler,
LMSDiscreteScheduler, LMSDiscreteScheduler,
PNDMScheduler, PNDMScheduler,
) )
from flask import Flask, jsonify, make_response, request, send_from_directory, url_for from flask import Flask, jsonify, make_response, request, send_from_directory, url_for
from flask_cors import CORS from flask_cors import CORS
from functools import cmp_to_key
from glob import glob
from io import BytesIO
from jsonschema import validate from jsonschema import validate
from logging import getLogger
from PIL import Image
from onnxruntime import get_available_providers from onnxruntime import get_available_providers
from os import makedirs, path from PIL import Image
from typing import List, Tuple
from . import logging
from .chain import ( from .chain import (
ChainPipeline,
blend_img2img, blend_img2img,
blend_inpaint, blend_inpaint,
correct_gfpgan, correct_gfpgan,
persist_disk, persist_disk,
persist_s3, persist_s3,
reduce_thumbnail,
reduce_crop, reduce_crop,
reduce_thumbnail,
source_noise, source_noise,
source_txt2img, source_txt2img,
upscale_outpaint, upscale_outpaint,
upscale_resrgan, upscale_resrgan,
upscale_stable_diffusion, upscale_stable_diffusion,
ChainPipeline,
)
from .device_pool import (
DeviceParams,
DevicePoolExecutor,
) )
from .device_pool import DevicePoolExecutor
from .diffusion.run import ( from .diffusion.run import (
run_img2img_pipeline, run_img2img_pipeline,
run_inpaint_pipeline, run_inpaint_pipeline,
run_txt2img_pipeline, run_txt2img_pipeline,
run_upscale_pipeline, run_upscale_pipeline,
) )
from .image import ( from .image import ( # mask filters; noise sources
# mask filters
mask_filter_gaussian_multiply, mask_filter_gaussian_multiply,
mask_filter_gaussian_screen, mask_filter_gaussian_screen,
mask_filter_none, mask_filter_none,
# noise sources
noise_source_fill_edge, noise_source_fill_edge,
noise_source_fill_mask, noise_source_fill_mask,
noise_source_gaussian, noise_source_gaussian,
@ -64,35 +63,20 @@ from .image import (
noise_source_normal, noise_source_normal,
noise_source_uniform, noise_source_uniform,
) )
from .output import ( from .output import json_params, make_output_name
json_params, from .params import Border, DeviceParams, ImageParams, Size, StageParams, UpscaleParams
make_output_name,
)
from .params import (
Border,
DeviceParams,
ImageParams,
Size,
StageParams,
UpscaleParams,
)
from .utils import ( from .utils import (
ServerContext,
base_join, base_join,
is_debug,
get_and_clamp_float, get_and_clamp_float,
get_and_clamp_int, get_and_clamp_int,
get_from_list, get_from_list,
get_from_map, get_from_map,
get_not_empty, get_not_empty,
get_size, get_size,
ServerContext, is_debug,
) )
import gc
import numpy as np
import torch
import yaml
logger = getLogger(__name__) logger = getLogger(__name__)
# config caching # config caching
@ -100,53 +84,53 @@ config_params = {}
# pipeline params # pipeline params
platform_providers = { platform_providers = {
'amd': 'DmlExecutionProvider', "amd": "DmlExecutionProvider",
'cpu': 'CPUExecutionProvider', "cpu": "CPUExecutionProvider",
'cuda': 'CUDAExecutionProvider', "cuda": "CUDAExecutionProvider",
'directml': 'DmlExecutionProvider', "directml": "DmlExecutionProvider",
'nvidia': 'CUDAExecutionProvider', "nvidia": "CUDAExecutionProvider",
'rocm': 'ROCMExecutionProvider', "rocm": "ROCMExecutionProvider",
} }
pipeline_schedulers = { pipeline_schedulers = {
'ddim': DDIMScheduler, "ddim": DDIMScheduler,
'ddpm': DDPMScheduler, "ddpm": DDPMScheduler,
'dpm-multi': DPMSolverMultistepScheduler, "dpm-multi": DPMSolverMultistepScheduler,
'dpm-single': DPMSolverSinglestepScheduler, "dpm-single": DPMSolverSinglestepScheduler,
'euler': EulerDiscreteScheduler, "euler": EulerDiscreteScheduler,
'euler-a': EulerAncestralDiscreteScheduler, "euler-a": EulerAncestralDiscreteScheduler,
'heun': HeunDiscreteScheduler, "heun": HeunDiscreteScheduler,
'k-dpm-2-a': KDPM2AncestralDiscreteScheduler, "k-dpm-2-a": KDPM2AncestralDiscreteScheduler,
'k-dpm-2': KDPM2DiscreteScheduler, "k-dpm-2": KDPM2DiscreteScheduler,
'karras-ve': KarrasVeScheduler, "karras-ve": KarrasVeScheduler,
'lms-discrete': LMSDiscreteScheduler, "lms-discrete": LMSDiscreteScheduler,
'pndm': PNDMScheduler, "pndm": PNDMScheduler,
} }
noise_sources = { noise_sources = {
'fill-edge': noise_source_fill_edge, "fill-edge": noise_source_fill_edge,
'fill-mask': noise_source_fill_mask, "fill-mask": noise_source_fill_mask,
'gaussian': noise_source_gaussian, "gaussian": noise_source_gaussian,
'histogram': noise_source_histogram, "histogram": noise_source_histogram,
'normal': noise_source_normal, "normal": noise_source_normal,
'uniform': noise_source_uniform, "uniform": noise_source_uniform,
} }
mask_filters = { mask_filters = {
'none': mask_filter_none, "none": mask_filter_none,
'gaussian-multiply': mask_filter_gaussian_multiply, "gaussian-multiply": mask_filter_gaussian_multiply,
'gaussian-screen': mask_filter_gaussian_screen, "gaussian-screen": mask_filter_gaussian_screen,
} }
chain_stages = { chain_stages = {
'blend-img2img': blend_img2img, "blend-img2img": blend_img2img,
'blend-inpaint': blend_inpaint, "blend-inpaint": blend_inpaint,
'correct-gfpgan': correct_gfpgan, "correct-gfpgan": correct_gfpgan,
'persist-disk': persist_disk, "persist-disk": persist_disk,
'persist-s3': persist_s3, "persist-s3": persist_s3,
'reduce-crop': reduce_crop, "reduce-crop": reduce_crop,
'reduce-thumbnail': reduce_thumbnail, "reduce-thumbnail": reduce_thumbnail,
'source-noise': source_noise, "source-noise": source_noise,
'source-txt2img': source_txt2img, "source-txt2img": source_txt2img,
'upscale-outpaint': upscale_outpaint, "upscale-outpaint": upscale_outpaint,
'upscale-resrgan': upscale_resrgan, "upscale-resrgan": upscale_resrgan,
'upscale-stable-diffusion': upscale_stable_diffusion, "upscale-stable-diffusion": upscale_stable_diffusion,
} }
# Available ORT providers # Available ORT providers
@ -158,7 +142,7 @@ correction_models = []
upscaling_models = [] upscaling_models = []
def get_config_value(key: str, subkey: str = 'default'): def get_config_value(key: str, subkey: str = "default"):
return config_params.get(key).get(subkey) return config_params.get(key).get(subkey)
@ -174,7 +158,7 @@ def pipeline_from_request() -> Tuple[DeviceParams, ImageParams, Size]:
user = request.remote_addr user = request.remote_addr
# platform stuff # platform stuff
device_name = request.args.get('platform', available_platforms[0].device) device_name = request.args.get("platform", available_platforms[0].device)
device = None device = None
for platform in available_platforms: for platform in available_platforms:
@ -182,78 +166,101 @@ def pipeline_from_request() -> Tuple[DeviceParams, ImageParams, Size]:
device = available_platforms[0] device = available_platforms[0]
if device is None: if device is None:
raise Exception('unknown device') raise Exception("unknown device")
# pipeline stuff # pipeline stuff
model = get_not_empty(request.args, 'model', get_config_value('model')) model = get_not_empty(request.args, "model", get_config_value("model"))
model_path = get_model_path(model) model_path = get_model_path(model)
scheduler = get_from_map(request.args, 'scheduler', scheduler = get_from_map(
pipeline_schedulers, get_config_value('scheduler')) request.args, "scheduler", pipeline_schedulers, get_config_value("scheduler")
)
# image params # image params
prompt = get_not_empty(request.args, prompt = get_not_empty(request.args, "prompt", get_config_value("prompt"))
'prompt', get_config_value('prompt')) negative_prompt = request.args.get("negativePrompt", None)
negative_prompt = request.args.get('negativePrompt', None)
if negative_prompt is not None and negative_prompt.strip() == '': if negative_prompt is not None and negative_prompt.strip() == "":
negative_prompt = None negative_prompt = None
cfg = get_and_clamp_float( cfg = get_and_clamp_float(
request.args, 'cfg', request.args,
get_config_value('cfg'), "cfg",
get_config_value('cfg', 'max'), get_config_value("cfg"),
get_config_value('cfg', 'min')) get_config_value("cfg", "max"),
get_config_value("cfg", "min"),
)
steps = get_and_clamp_int( steps = get_and_clamp_int(
request.args, 'steps', request.args,
get_config_value('steps'), "steps",
get_config_value('steps', 'max'), get_config_value("steps"),
get_config_value('steps', 'min')) get_config_value("steps", "max"),
get_config_value("steps", "min"),
)
height = get_and_clamp_int( height = get_and_clamp_int(
request.args, 'height', request.args,
get_config_value('height'), "height",
get_config_value('height', 'max'), get_config_value("height"),
get_config_value('height', 'min')) get_config_value("height", "max"),
get_config_value("height", "min"),
)
width = get_and_clamp_int( width = get_and_clamp_int(
request.args, 'width', request.args,
get_config_value('width'), "width",
get_config_value('width', 'max'), get_config_value("width"),
get_config_value('width', 'min')) get_config_value("width", "max"),
get_config_value("width", "min"),
)
seed = int(request.args.get('seed', -1)) seed = int(request.args.get("seed", -1))
if seed == -1: if seed == -1:
seed = np.random.randint(np.iinfo(np.int32).max) seed = np.random.randint(np.iinfo(np.int32).max)
logger.info("request from %s: %s rounds of %s using %s on %s, %sx%s, %s, %s - %s", logger.info(
user, steps, scheduler.__name__, model_path, device.provider, width, height, cfg, seed, prompt) "request from %s: %s rounds of %s using %s on %s, %sx%s, %s, %s - %s",
user,
steps,
scheduler.__name__,
model_path,
device.provider,
width,
height,
cfg,
seed,
prompt,
)
params = ImageParams(model_path, scheduler, prompt, params = ImageParams(
negative_prompt, cfg, steps, seed) model_path, scheduler, prompt, negative_prompt, cfg, steps, seed
)
size = Size(width, height) size = Size(width, height)
return (device, params, size) return (device, params, size)
def border_from_request() -> Border:
    """Build the image border from the current request's query arguments.

    Every side defaults to 0 and is clamped between 0 and the configured
    maximum: "width" for left/right and "height" for top/bottom.
    """

    def clamp_side(side: str, dimension: str) -> int:
        # the upper bound is looked up per side, in left/right/top/bottom order
        return get_and_clamp_int(
            request.args, side, 0, get_config_value(dimension, "max"), 0
        )

    return Border(
        clamp_side("left", "width"),
        clamp_side("right", "width"),
        clamp_side("top", "height"),
        clamp_side("bottom", "height"),
    )
def upscale_from_request() -> UpscaleParams: def upscale_from_request() -> UpscaleParams:
denoise = get_and_clamp_float(request.args, 'denoise', 0.5, 1.0, 0.0) denoise = get_and_clamp_float(request.args, "denoise", 0.5, 1.0, 0.0)
scale = get_and_clamp_int(request.args, 'scale', 1, 4, 1) scale = get_and_clamp_int(request.args, "scale", 1, 4, 1)
outscale = get_and_clamp_int(request.args, 'outscale', 1, 4, 1) outscale = get_and_clamp_int(request.args, "outscale", 1, 4, 1)
upscaling = get_from_list(request.args, 'upscaling', upscaling_models) upscaling = get_from_list(request.args, "upscaling", upscaling_models)
correction = get_from_list(request.args, 'correction', correction_models) correction = get_from_list(request.args, "correction", correction_models)
faces = get_not_empty(request.args, 'faces', 'false') == 'true' faces = get_not_empty(request.args, "faces", "false") == "true"
face_strength = get_and_clamp_float( face_strength = get_and_clamp_float(request.args, "faceStrength", 0.5, 1.0, 0.0)
request.args, 'faceStrength', 0.5, 1.0, 0.0)
return UpscaleParams( return UpscaleParams(
upscaling, upscaling,
@ -261,7 +268,7 @@ def upscale_from_request() -> UpscaleParams:
denoise=denoise, denoise=denoise,
faces=faces, faces=faces,
face_strength=face_strength, face_strength=face_strength,
format='onnx', format="onnx",
outscale=outscale, outscale=outscale,
scale=scale, scale=scale,
) )
@ -269,7 +276,7 @@ def upscale_from_request() -> UpscaleParams:
def check_paths(context: ServerContext): def check_paths(context: ServerContext):
if not path.exists(context.model_path): if not path.exists(context.model_path):
raise RuntimeError('model path must exist') raise RuntimeError("model path must exist")
if not path.exists(context.output_path): if not path.exists(context.output_path):
makedirs(context.output_path) makedirs(context.output_path)
@ -286,35 +293,41 @@ def load_models(context: ServerContext):
global correction_models global correction_models
global upscaling_models global upscaling_models
diffusion_models = [get_model_name(f) for f in glob( diffusion_models = [
path.join(context.model_path, 'diffusion-*'))] get_model_name(f) for f in glob(path.join(context.model_path, "diffusion-*"))
diffusion_models.extend([ ]
get_model_name(f) for f in glob(path.join(context.model_path, 'stable-diffusion-*'))]) diffusion_models.extend(
[
get_model_name(f)
for f in glob(path.join(context.model_path, "stable-diffusion-*"))
]
)
diffusion_models = list(set(diffusion_models)) diffusion_models = list(set(diffusion_models))
diffusion_models.sort() diffusion_models.sort()
correction_models = [ correction_models = [
get_model_name(f) for f in glob(path.join(context.model_path, 'correction-*'))] get_model_name(f) for f in glob(path.join(context.model_path, "correction-*"))
]
correction_models = list(set(correction_models)) correction_models = list(set(correction_models))
correction_models.sort() correction_models.sort()
upscaling_models = [ upscaling_models = [
get_model_name(f) for f in glob(path.join(context.model_path, 'upscaling-*'))] get_model_name(f) for f in glob(path.join(context.model_path, "upscaling-*"))
]
upscaling_models = list(set(upscaling_models)) upscaling_models = list(set(upscaling_models))
upscaling_models.sort() upscaling_models.sort()
def load_params(context: ServerContext):
    """Load params.json from the context's params path into config_params.

    When the loaded data has a "platform" entry and the context provides a
    default platform, that entry's "default" value is overridden.
    """
    global config_params

    with open(path.join(context.params_path, "params.json"), "r") as params_handle:
        config_params = yaml.safe_load(params_handle)

    if "platform" not in config_params or context.default_platform is None:
        return

    logger.info("overriding default platform to %s", context.default_platform)
    platform_section = config_params.get("platform")
    platform_section["default"] = context.default_platform
def load_platforms(): def load_platforms():
@ -323,30 +336,42 @@ def load_platforms():
providers = get_available_providers() providers = get_available_providers()
for potential in platform_providers: for potential in platform_providers:
if platform_providers[potential] in providers and potential not in context.block_platforms: if (
if potential == 'cuda': platform_providers[potential] in providers
and potential not in context.block_platforms
):
if potential == "cuda":
for i in range(torch.cuda.device_count()): for i in range(torch.cuda.device_count()):
available_platforms.append(DeviceParams(potential, platform_providers[potential], { available_platforms.append(
'device_id': i, DeviceParams(
})) potential,
platform_providers[potential],
{
"device_id": i,
},
)
)
else: else:
available_platforms.append(DeviceParams( available_platforms.append(
potential, platform_providers[potential])) DeviceParams(potential, platform_providers[potential])
)
# make sure CPU is last on the list # make sure CPU is last on the list
def cpu_last(a: DeviceParams, b: DeviceParams): def cpu_last(a: DeviceParams, b: DeviceParams):
if a.device == 'cpu' and b.device == 'cpu': if a.device == "cpu" and b.device == "cpu":
return 0 return 0
if a.device == 'cpu': if a.device == "cpu":
return 1 return 1
return -1 return -1
available_platforms = sorted(available_platforms, key=cmp_to_key(cpu_last)) available_platforms = sorted(available_platforms, key=cmp_to_key(cpu_last))
logger.info('available acceleration platforms: %s', logger.info(
', '.join([str(p) for p in available_platforms])) "available acceleration platforms: %s",
", ".join([str(p) for p in available_platforms]),
)
context = ServerContext.from_environ() context = ServerContext.from_environ()
@ -365,16 +390,22 @@ if is_debug():
def ready_reply(ready: bool, progress: int = 0):
    """Return a JSON response with the given ready flag and progress value."""
    payload = dict(progress=progress, ready=ready)
    return jsonify(payload)
def error_reply(err: str):
    """Return a JSON response carrying the error message with HTTP status 400."""
    error_body = jsonify(dict(error=err))
    reply = make_response(error_body)
    reply.status_code = 400
    return reply
@ -383,151 +414,154 @@ def get_model_path(model: str):
return base_join(context.model_path, model) return base_join(context.model_path, model)
def serve_bundle_file(filename="index.html"):
    """Send a file from the bundle directory, index.html unless told otherwise."""
    bundle_dir = path.join("..", context.bundle_path)
    return send_from_directory(bundle_dir, filename)
# routes # routes
@app.route("/")
def index():
    """Serve the bundle's default file for the site root."""
    bundle_response = serve_bundle_file()
    return bundle_response
@app.route("/<path:filename>")
def index_path(filename):
    """Serve the requested file from the bundle directory."""
    bundle_response = serve_bundle_file(filename)
    return bundle_response
@app.route("/api")
def introspect():
    """Describe the API: the server name and every registered route.

    Each route entry holds the URL produced by url_from_rule and the rule's
    HTTP methods in sorted order.
    """
    return {
        "name": "onnx-web",
        "routes": [
            {
                "path": url_from_rule(rule),
                # sorted() returns the sorted list; list.sort() sorts in place
                # and returns None, which made "methods" always null.
                "methods": sorted(rule.methods or ()),
            }
            for rule in app.url_map.iter_rules()
        ],
    }
@app.route("/api/settings/masks")
def list_mask_filters():
    """Return the names of the available mask filters as JSON."""
    filter_names = [name for name in mask_filters.keys()]
    return jsonify(filter_names)
@app.route("/api/settings/models")
def list_models():
    """Return the diffusion, correction, and upscaling model lists as JSON."""
    catalog = dict(
        diffusion=diffusion_models,
        correction=correction_models,
        upscaling=upscaling_models,
    )
    return jsonify(catalog)
@app.route("/api/settings/noises")
def list_noise_sources():
    """Return the names of the available noise sources as JSON."""
    source_names = [name for name in noise_sources.keys()]
    return jsonify(source_names)
@app.route("/api/settings/params")
def list_params():
    """Return the loaded config_params as JSON."""
    params_response = jsonify(config_params)
    return params_response
@app.route("/api/settings/platforms")
def list_platforms():
    """Return the device name of every available platform as JSON."""
    device_names = []
    for platform in available_platforms:
        device_names.append(platform.device)

    return jsonify(device_names)
@app.route("/api/settings/schedulers")
def list_schedulers():
    """Return the names of the available pipeline schedulers as JSON."""
    scheduler_names = [name for name in pipeline_schedulers.keys()]
    return jsonify(scheduler_names)
@app.route("/api/img2img", methods=["POST"])
def img2img():
    """Queue an img2img job for the uploaded "source" image.

    Replies with a 400 error when no source file is attached; otherwise
    submits run_img2img_pipeline to the executor and returns the job's
    parameters as JSON.
    """
    if "source" not in request.files:
        return error_reply("source image is required")

    upload = request.files.get("source")
    source_image = Image.open(BytesIO(upload.read())).convert("RGB")

    _device, params, size = pipeline_from_request()
    upscale = upscale_from_request()

    strength_default = get_config_value("strength")
    strength_max = get_config_value("strength", "max")
    strength_min = get_config_value("strength", "min")
    strength = get_and_clamp_float(
        request.args, "strength", strength_default, strength_max, strength_min
    )

    output = make_output_name(context, "img2img", params, size, extras=(strength,))
    logger.info("img2img job queued for: %s", output)

    # shrink the source in place so it fits within the requested size
    source_image.thumbnail((size.width, size.height))

    job_args = (context, params, output, upscale, source_image, strength)
    executor.submit(output, run_img2img_pipeline, *job_args)

    return jsonify(json_params(output, params, size, upscale=upscale))
@app.route("/api/txt2img", methods=["POST"])
def txt2img():
    """Queue a txt2img job and return the job's parameters as JSON."""
    _device, params, size = pipeline_from_request()
    upscale = upscale_from_request()

    output = make_output_name(context, "txt2img", params, size)
    logger.info("txt2img job queued for: %s", output)

    job_args = (context, params, size, output, upscale)
    executor.submit(output, run_txt2img_pipeline, *job_args)

    return jsonify(json_params(output, params, size, upscale=upscale))
@app.route('/api/inpaint', methods=['POST']) @app.route("/api/inpaint", methods=["POST"])
def inpaint(): def inpaint():
if 'source' not in request.files: if "source" not in request.files:
return error_reply('source image is required') return error_reply("source image is required")
if 'mask' not in request.files: if "mask" not in request.files:
return error_reply('mask image is required') return error_reply("mask image is required")
source_file = request.files.get('source') source_file = request.files.get("source")
source_image = Image.open(BytesIO(source_file.read())).convert('RGB') source_image = Image.open(BytesIO(source_file.read())).convert("RGB")
mask_file = request.files.get('mask') mask_file = request.files.get("mask")
mask_image = Image.open(BytesIO(mask_file.read())).convert('RGB') mask_image = Image.open(BytesIO(mask_file.read())).convert("RGB")
device, params, size = pipeline_from_request() device, params, size = pipeline_from_request()
expand = border_from_request() expand = border_from_request()
upscale = upscale_from_request() upscale = upscale_from_request()
fill_color = get_not_empty(request.args, 'fillColor', 'white') fill_color = get_not_empty(request.args, "fillColor", "white")
mask_filter = get_from_map(request.args, 'filter', mask_filters, 'none') mask_filter = get_from_map(request.args, "filter", mask_filters, "none")
noise_source = get_from_map( noise_source = get_from_map(request.args, "noise", noise_sources, "histogram")
request.args, 'noise', noise_sources, 'histogram')
strength = get_and_clamp_float( strength = get_and_clamp_float(
request.args, request.args,
'strength', "strength",
get_config_value('strength'), get_config_value("strength"),
get_config_value('strength', 'max'), get_config_value("strength", "max"),
get_config_value('strength', 'min')) get_config_value("strength", "min"),
)
output = make_output_name( output = make_output_name(
context, context,
'inpaint', "inpaint",
params, params,
size, size,
extras=( extras=(
@ -539,7 +573,7 @@ def inpaint():
noise_source.__name__, noise_source.__name__,
strength, strength,
fill_color, fill_color,
) ),
) )
logger.info("inpaint job queued for: %s", output) logger.info("inpaint job queued for: %s", output)
@ -559,123 +593,131 @@ def inpaint():
noise_source, noise_source,
mask_filter, mask_filter,
strength, strength,
fill_color) fill_color,
)
return jsonify(json_params(output, params, size, upscale=upscale, border=expand)) return jsonify(json_params(output, params, size, upscale=upscale, border=expand))
@app.route("/api/upscale", methods=["POST"])
def upscale():
    """Queue an upscale job for the uploaded "source" image.

    Replies with a 400 error when no source file is attached; otherwise
    submits run_upscale_pipeline to the executor and returns the job's
    parameters as JSON.
    """
    if "source" not in request.files:
        return error_reply("source image is required")

    upload = request.files.get("source")
    source_image = Image.open(BytesIO(upload.read())).convert("RGB")

    _device, params, size = pipeline_from_request()
    upscale_params = upscale_from_request()

    output = make_output_name(context, "upscale", params, size)
    logger.info("upscale job queued for: %s", output)

    # shrink the source in place so it fits within the requested size
    source_image.thumbnail((size.width, size.height))

    job_args = (context, params, size, output, upscale_params, source_image)
    executor.submit(output, run_upscale_pipeline, *job_args)

    return jsonify(json_params(output, params, size, upscale=upscale_params))
@app.route("/api/chain", methods=["POST"])
def chain():
    """Build a chain pipeline from the request body and queue it.

    The body is read from the "chain" form field or uploaded file, parsed
    with yaml.safe_load, and validated against ./schema.yaml. Each entry in
    "stages" becomes one pipeline stage; files uploaded as "source:<stage>"
    and "mask:<stage>" are passed to that stage as source_image and
    mask_image. Replies with a 400 error when there is no body.
    """
    logger.debug(
        "chain pipeline request: %s, %s", request.form.keys(), request.files.keys()
    )
    body = request.form.get("chain") or request.files.get("chain")
    if body is None:
        return error_reply("chain pipeline must have a body")

    data = yaml.safe_load(body)
    with open("./schema.yaml", "r") as f:
        schema = yaml.safe_load(f.read())

    logger.info("validating chain request: %s against %s", data, schema)
    validate(data, schema)

    # get defaults from the regular parameters
    device, params, size = pipeline_from_request()
    output = make_output_name(context, "chain", params, size)

    pipeline = ChainPipeline()
    for stage_data in data.get("stages", []):
        callback = chain_stages[stage_data.get("type")]
        kwargs = stage_data.get("params", {})
        logger.info("request stage: %s, %s", callback.__name__, kwargs)

        stage = StageParams(
            stage_data.get("name", callback.__name__),
            tile_size=get_size(kwargs.get("tile_size")),
            outscale=get_and_clamp_int(kwargs, "outscale", 1, 4),
        )

        if "border" in kwargs:
            border = Border.even(int(kwargs.get("border")))
            kwargs["border"] = border

        if "upscale" in kwargs:
            upscale = UpscaleParams(kwargs.get("upscale"))
            kwargs["upscale"] = upscale

        stage_source_name = "source:%s" % (stage.name)
        stage_mask_name = "mask:%s" % (stage.name)

        if stage_source_name in request.files:
            logger.debug(
                "loading source image %s for pipeline stage %s",
                stage_source_name,
                stage.name,
            )
            source_file = request.files.get(stage_source_name)
            source_image = Image.open(BytesIO(source_file.read())).convert("RGB")
            # thumbnail() resizes in place and returns None, so its result
            # must not be assigned back to the image variable
            source_image.thumbnail((512, 512))
            kwargs["source_image"] = source_image

        if stage_mask_name in request.files:
            logger.debug(
                "loading mask image %s for pipeline stage %s",
                stage_mask_name,
                stage.name,
            )
            mask_file = request.files.get(stage_mask_name)
            mask_image = Image.open(BytesIO(mask_file.read())).convert("RGB")
            # same in-place resize as the source image above
            mask_image.thumbnail((512, 512))
            kwargs["mask_image"] = mask_image

        pipeline.append((callback, stage, kwargs))

    logger.info("running chain pipeline with %s stages", len(pipeline.stages))

    # build and run chain pipeline
    empty_source = Image.new("RGB", (size.width, size.height))
    executor.submit(
        output, pipeline, context, params, empty_source, output=output, size=size
    )

    return jsonify(json_params(output, params, size))
@app.route("/api/cancel", methods=["PUT"])
def cancel():
    """Ask the executor to cancel the job named by the "output" query arg."""
    target = request.args.get("output", None)
    was_cancelled = executor.cancel(target)
    return ready_reply(was_cancelled)
@app.route('/api/ready') @app.route("/api/ready")
def ready(): def ready():
output_file = request.args.get('output', None) output_file = request.args.get("output", None)
done, progress = executor.done(output_file) done, progress = executor.done(output_file)
@ -687,11 +729,13 @@ def ready():
return ready_reply(done, progress=progress) return ready_reply(done, progress=progress)
@app.route("/api/status")
def status():
    """Return the executor's status report as JSON."""
    executor_status = executor.status()
    return jsonify(executor_status)
@app.route("/output/<path:filename>")
def output(filename: str):
    """Send a file from the output directory inline, not as an attachment."""
    output_dir = path.join("..", context.output_path)
    return send_from_directory(output_dir, filename, as_attachment=False)

View File

@ -1,24 +1,16 @@
from logging import getLogger from logging import getLogger
from PIL import Image from PIL import Image
from .chain import ( from .chain import (
correct_gfpgan,
upscale_stable_diffusion,
upscale_resrgan,
ChainPipeline, ChainPipeline,
correct_gfpgan,
upscale_resrgan,
upscale_stable_diffusion,
) )
from .device_pool import ( from .device_pool import JobContext
JobContext, from .params import ImageParams, SizeChart, StageParams, UpscaleParams
) from .utils import ServerContext
from .params import (
ImageParams,
SizeChart,
StageParams,
UpscaleParams,
)
from .utils import (
ServerContext,
)
logger = getLogger(__name__) logger = getLogger(__name__)
@ -32,27 +24,25 @@ def run_upscale_correction(
*, *,
upscale: UpscaleParams, upscale: UpscaleParams,
) -> Image.Image: ) -> Image.Image:
''' """
This is a convenience method for a chain pipeline that will run upscaling and This is a convenience method for a chain pipeline that will run upscaling and
correction, based on the `upscale` params. correction, based on the `upscale` params.
''' """
logger.info('running upscaling and correction pipeline') logger.info("running upscaling and correction pipeline")
chain = ChainPipeline() chain = ChainPipeline()
if upscale.scale > 1: if upscale.scale > 1:
if 'esrgan' in upscale.upscale_model: if "esrgan" in upscale.upscale_model:
stage = StageParams(tile_size=stage.tile_size, stage = StageParams(tile_size=stage.tile_size, outscale=upscale.outscale)
outscale=upscale.outscale)
chain.append((upscale_resrgan, stage, None)) chain.append((upscale_resrgan, stage, None))
elif 'stable-diffusion' in upscale.upscale_model: elif "stable-diffusion" in upscale.upscale_model:
mini_tile = min(SizeChart.mini, stage.tile_size) mini_tile = min(SizeChart.mini, stage.tile_size)
stage = StageParams(tile_size=mini_tile, outscale=upscale.outscale) stage = StageParams(tile_size=mini_tile, outscale=upscale.outscale)
chain.append((upscale_stable_diffusion, stage, None)) chain.append((upscale_stable_diffusion, stage, None))
if upscale.faces: if upscale.faces:
stage = StageParams(tile_size=stage.tile_size, stage = StageParams(tile_size=stage.tile_size, outscale=1)
outscale=1)
chain.append((correct_gfpgan, stage, None)) chain.append((correct_gfpgan, stage, None))
return chain(job, server, params, image, prompt=params.prompt, upscale=upscale) return chain(job, server, params, image, prompt=params.prompt, upscale=upscale)

View File

@ -1,13 +1,11 @@
import gc
from logging import getLogger from logging import getLogger
from os import environ, path from os import environ, path
from typing import Any, Dict, List, Optional, Union from typing import Any, Dict, List, Optional, Union
import gc
import torch import torch
from .params import ( from .params import SizeChart
SizeChart,
)
logger = getLogger(__name__) logger = getLogger(__name__)
@ -15,15 +13,15 @@ logger = getLogger(__name__)
class ServerContext: class ServerContext:
def __init__( def __init__(
self, self,
bundle_path: str = '.', bundle_path: str = ".",
model_path: str = '.', model_path: str = ".",
output_path: str = '.', output_path: str = ".",
params_path: str = '.', params_path: str = ".",
cors_origin: str = '*', cors_origin: str = "*",
num_workers: int = 1, num_workers: int = 1,
block_platforms: List[str] = [], block_platforms: List[str] = [],
default_platform: str = None, default_platform: str = None,
image_format: str = 'png', image_format: str = "png",
) -> None: ) -> None:
self.bundle_path = bundle_path self.bundle_path = bundle_path
self.model_path = model_path self.model_path = model_path
@ -38,40 +36,39 @@ class ServerContext:
@classmethod
def from_environ(cls):
    """Create a context from the ONNX_WEB_* environment variables.

    Unset variables fall back to the defaults written below.
    ONNX_WEB_CORS_ORIGIN and ONNX_WEB_BLOCK_PLATFORMS are comma-separated
    lists; blank block-platform entries are dropped, so an unset variable
    gives an empty list instead of [""].
    """
    blocked = [
        name.strip()
        for name in environ.get("ONNX_WEB_BLOCK_PLATFORMS", "").split(",")
        if name.strip()
    ]

    # build through cls so a subclass gets an instance of its own type
    return cls(
        bundle_path=environ.get(
            "ONNX_WEB_BUNDLE_PATH", path.join("..", "gui", "out")
        ),
        model_path=environ.get("ONNX_WEB_MODEL_PATH", path.join("..", "models")),
        output_path=environ.get("ONNX_WEB_OUTPUT_PATH", path.join("..", "outputs")),
        params_path=environ.get("ONNX_WEB_PARAMS_PATH", "."),
        # others
        cors_origin=environ.get("ONNX_WEB_CORS_ORIGIN", "*").split(","),
        num_workers=int(environ.get("ONNX_WEB_NUM_WORKERS", 1)),
        block_platforms=blocked,
        default_platform=environ.get("ONNX_WEB_DEFAULT_PLATFORM", None),
        image_format=environ.get("ONNX_WEB_IMAGE_FORMAT", "png"),
    )
def base_join(base: str, tail: str) -> str:
    """Join tail onto base after normalizing tail against the root.

    tail is anchored at "/" and normalized first, so ".." segments are
    resolved there and cannot climb above base.
    """
    anchored = path.join("/", tail)
    normalized = path.normpath(anchored)
    relative_tail = path.relpath(normalized, "/")
    return path.join(base, relative_tail)
def is_debug() -> bool:
    """Return True when the DEBUG environment variable is set to any value."""
    return "DEBUG" in environ
def get_and_clamp_float(
    args: Any, key: str, default_value: float, max_value: float, min_value=0.0
) -> float:
    """Read key from args as a float and clamp it to [min_value, max_value].

    default_value is used when key is absent. The upper bound is applied
    last, so it wins if the bounds cross.
    """
    parsed = float(args.get(key, default_value))
    floored = min_value if min_value > parsed else parsed
    return max_value if max_value < floored else floored
def get_and_clamp_int(
    args: Any, key: str, default_value: int, max_value: int, min_value=1
) -> int:
    """Read key from args as an int and clamp it to [min_value, max_value].

    default_value is used when key is absent. The upper bound is applied
    last, so it wins if the bounds cross.
    """
    parsed = int(args.get(key, default_value))
    floored = min_value if min_value > parsed else parsed
    return max_value if max_value < floored else floored
@ -80,7 +77,7 @@ def get_from_list(args: Any, key: str, values: List[Any]) -> Optional[Any]:
if selected in values: if selected in values:
return selected return selected
logger.warn('invalid selection: %s', selected) logger.warn("invalid selection: %s", selected)
if len(values) > 0: if len(values) > 0:
return values[0] return values[0]
@ -118,10 +115,10 @@ def get_size(val: Union[int, str, None]) -> SizeChart:
return int(val) return int(val)
raise Exception('invalid size') raise Exception("invalid size")
def run_gc():
    """Run Python garbage collection, then empty the torch CUDA cache."""
    logger.debug("running garbage collection")
    # collect first, then clear the cache
    gc.collect()
    torch.cuda.empty_cache()

2
api/pyproject.toml Normal file
View File

@ -0,0 +1,2 @@
[tool.isort]
profile = "black"