lint(api): apply black and isort style

Sean Sube 2023-02-05 07:53:26 -06:00
parent c0e5f435ee
commit 54dd34d211
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
34 changed files with 1271 additions and 1251 deletions

@@ -23,3 +23,16 @@ package-dist:
package-upload:
twine upload dist/*
lint-check:
black --check --preview onnx_web
isort --check-only --skip __init__.py --filter-files onnx_web
flake8 --per-file-ignores="__init__.py:F401" onnx_web
lint-fix:
black onnx_web
isort --skip __init__.py --filter-files onnx_web
flake8 --per-file-ignores="__init__.py:F401" onnx_web
typecheck:
mypy -m onnx_web.serve
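
For reference, the new targets are invoked from the repository root like any other make target:

    make lint-check   # fail if black, isort, or flake8 would report problems
    make lint-fix     # rewrite files in place with black and isort; flake8 still only reports
    make typecheck    # run mypy against onnx_web.serve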

@@ -1,3 +1,6 @@
black
flake8
isort
mypy
types-Flask-Cors

@@ -1,49 +1,31 @@
from . import logging
from .chain import (
correct_gfpgan,
upscale_resrgan,
upscale_stable_diffusion,
)
from .diffusion.load import (
get_latents_from_seed,
load_pipeline,
)
from .chain import correct_gfpgan, upscale_resrgan, upscale_stable_diffusion
from .diffusion.load import get_latents_from_seed, load_pipeline
from .diffusion.run import (
run_img2img_pipeline,
run_inpaint_pipeline,
run_txt2img_pipeline,
run_img2img_pipeline,
run_inpaint_pipeline,
run_txt2img_pipeline,
)
from .image import (
expand_image,
mask_filter_gaussian_multiply,
mask_filter_gaussian_screen,
mask_filter_none,
noise_source_fill_edge,
noise_source_fill_mask,
noise_source_gaussian,
noise_source_histogram,
noise_source_normal,
noise_source_uniform,
)
from .params import (
Param,
Point,
Border,
Size,
ImageParams,
StageParams,
UpscaleParams,
)
from .upscale import (
run_upscale_correction,
expand_image,
mask_filter_gaussian_multiply,
mask_filter_gaussian_screen,
mask_filter_none,
noise_source_fill_edge,
noise_source_fill_mask,
noise_source_gaussian,
noise_source_histogram,
noise_source_normal,
noise_source_uniform,
)
from .params import Border, ImageParams, Param, Point, Size, StageParams, UpscaleParams
from .upscale import run_upscale_correction
from .utils import (
get_and_clamp_float,
get_and_clamp_int,
get_from_list,
get_from_map,
get_not_empty,
base_join,
ServerContext,
)
ServerContext,
base_join,
get_and_clamp_float,
get_and_clamp_int,
get_from_list,
get_from_map,
get_not_empty,
)

@@ -1,42 +1,13 @@
from .base import (
ChainPipeline,
PipelineStage,
StageCallback,
StageParams,
)
from .blend_img2img import (
blend_img2img,
)
from .blend_inpaint import (
blend_inpaint,
)
from .correct_gfpgan import (
correct_gfpgan,
)
from .persist_disk import (
persist_disk,
)
from .persist_s3 import (
persist_s3,
)
from .reduce_crop import (
reduce_crop,
)
from .reduce_thumbnail import (
reduce_thumbnail,
)
from .source_noise import (
source_noise,
)
from .source_txt2img import (
source_txt2img,
)
from .upscale_outpaint import (
upscale_outpaint,
)
from .upscale_resrgan import (
upscale_resrgan,
)
from .upscale_stable_diffusion import (
upscale_stable_diffusion,
)
from .base import ChainPipeline, PipelineStage, StageCallback, StageParams
from .blend_img2img import blend_img2img
from .blend_inpaint import blend_inpaint
from .correct_gfpgan import correct_gfpgan
from .persist_disk import persist_disk
from .persist_s3 import persist_s3
from .reduce_crop import reduce_crop
from .reduce_thumbnail import reduce_thumbnail
from .source_noise import source_noise
from .source_txt2img import source_txt2img
from .upscale_outpaint import upscale_outpaint
from .upscale_resrgan import upscale_resrgan
from .upscale_stable_diffusion import upscale_stable_diffusion

@@ -1,26 +1,15 @@
from datetime import timedelta
from logging import getLogger
from PIL import Image
from time import monotonic
from typing import Any, List, Optional, Protocol, Tuple
from ..device_pool import (
JobContext,
)
from ..params import (
ImageParams,
StageParams,
)
from ..output import (
save_image,
)
from ..utils import (
is_debug,
ServerContext,
)
from .utils import (
process_tile_grid,
)
from PIL import Image
from ..device_pool import JobContext
from ..output import save_image
from ..params import ImageParams, StageParams
from ..utils import ServerContext, is_debug
from .utils import process_tile_grid
logger = getLogger(__name__)
@@ -42,33 +31,43 @@ PipelineStage = Tuple[StageCallback, StageParams, Optional[dict]]
class ChainPipeline:
'''
"""
Run many stages in series, passing the image results from each to the next, and processing
tiles as needed.
'''
"""
def __init__(
self,
stages: List[PipelineStage] = [],
):
'''
"""
Create a new pipeline that will run the given stages.
'''
"""
self.stages = list(stages)
def append(self, stage: PipelineStage):
'''
"""
Append an additional stage to this pipeline.
'''
"""
self.stages.append(stage)
def __call__(self, job: JobContext, server: ServerContext, params: ImageParams, source: Image.Image, **pipeline_kwargs) -> Image.Image:
'''
def __call__(
self,
job: JobContext,
server: ServerContext,
params: ImageParams,
source: Image.Image,
**pipeline_kwargs
) -> Image.Image:
"""
TODO: handle List[Image] outputs
'''
"""
start = monotonic()
logger.info('running pipeline on source image with dimensions %sx%s',
source.width, source.height)
logger.info(
"running pipeline on source image with dimensions %sx%s",
source.width,
source.height,
)
image = source
for stage_pipe, stage_params, stage_kwargs in self.stages:
@@ -76,37 +75,51 @@ class ChainPipeline:
kwargs = stage_kwargs or {}
kwargs = {**pipeline_kwargs, **kwargs}
logger.info('running stage %s on image with dimensions %sx%s, %s',
name, image.width, image.height, kwargs.keys())
logger.info(
"running stage %s on image with dimensions %sx%s, %s",
name,
image.width,
image.height,
kwargs.keys(),
)
if image.width > stage_params.tile_size or image.height > stage_params.tile_size:
logger.info('image larger than tile size of %s, tiling stage',
stage_params.tile_size)
if (
image.width > stage_params.tile_size
or image.height > stage_params.tile_size
):
logger.info(
"image larger than tile size of %s, tiling stage",
stage_params.tile_size,
)
def stage_tile(tile: Image.Image, _dims) -> Image.Image:
tile = stage_pipe(job, server, stage_params, params, tile,
**kwargs)
tile = stage_pipe(job, server, stage_params, params, tile, **kwargs)
if is_debug():
save_image(server, 'last-tile.png', tile)
save_image(server, "last-tile.png", tile)
return tile
image = process_tile_grid(
image, stage_params.tile_size, stage_params.outscale, [stage_tile])
image, stage_params.tile_size, stage_params.outscale, [stage_tile]
)
else:
logger.info('image within tile size, running stage')
image = stage_pipe(job, server, stage_params, params, image,
**kwargs)
logger.info("image within tile size, running stage")
image = stage_pipe(job, server, stage_params, params, image, **kwargs)
logger.info('finished stage %s, result size: %sx%s',
name, image.width, image.height)
logger.info(
"finished stage %s, result size: %sx%s", name, image.width, image.height
)
if is_debug():
save_image(server, 'last-stage.png', image)
save_image(server, "last-stage.png", image)
end = monotonic()
duration = timedelta(seconds=(end - start))
logger.info('finished pipeline in %s, result size: %sx%s',
duration, image.width, image.height)
logger.info(
"finished pipeline in %s, result size: %sx%s",
duration,
image.width,
image.height,
)
return image
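
To make the control flow above concrete: a pipeline is a list of (callback, stage params, optional kwargs) entries applied in order, with tiling handled per stage. A minimal sketch of building and running one, assuming the module layout shown in this diff, that StageParams can be constructed with defaults, and that the stage kwargs match the signatures shown elsewhere in this commit:

    from PIL import Image

    from onnx_web.chain import ChainPipeline, blend_img2img, upscale_resrgan
    from onnx_web.device_pool import JobContext
    from onnx_web.params import ImageParams, StageParams, UpscaleParams
    from onnx_web.utils import ServerContext


    def refine_and_upscale(
        job: JobContext,
        server: ServerContext,
        params: ImageParams,
        upscale: UpscaleParams,
        source: Image.Image,
    ) -> Image.Image:
        # each stage is (callback, stage params, optional kwargs for the callback)
        chain = ChainPipeline(
            stages=[(blend_img2img, StageParams(), {"prompt": params.prompt})]
        )
        chain.append((upscale_resrgan, StageParams(), {"upscale": upscale}))
        return chain(job, server, params, source)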

@@ -1,24 +1,13 @@
from diffusers import (
OnnxStableDiffusionImg2ImgPipeline,
)
from logging import getLogger
from PIL import Image
from ..device_pool import (
JobContext,
)
from ..diffusion.load import (
load_pipeline,
)
from ..params import (
ImageParams,
StageParams,
)
from ..utils import (
ServerContext,
)
import numpy as np
from diffusers import OnnxStableDiffusionImg2ImgPipeline
from PIL import Image
from ..device_pool import JobContext
from ..diffusion.load import load_pipeline
from ..params import ImageParams, StageParams
from ..utils import ServerContext
logger = getLogger(__name__)
@@ -35,10 +24,14 @@ def blend_img2img(
**kwargs,
) -> Image.Image:
prompt = prompt or params.prompt
logger.info('generating image using img2img, %s steps: %s', params.steps, prompt)
logger.info("generating image using img2img, %s steps: %s", params.steps, prompt)
pipe = load_pipeline(OnnxStableDiffusionImg2ImgPipeline,
params.model, params.scheduler, job.get_device())
pipe = load_pipeline(
OnnxStableDiffusionImg2ImgPipeline,
params.model,
params.scheduler,
job.get_device(),
)
rng = np.random.RandomState(params.seed)
@@ -53,6 +46,5 @@ def blend_img2img(
)
output = result.images[0]
logger.info('final output image size: %sx%s', output.width, output.height)
logger.info("final output image size: %sx%s", output.width, output.height)
return output

@@ -1,41 +1,17 @@
from diffusers import (
OnnxStableDiffusionInpaintPipeline,
)
from logging import getLogger
from PIL import Image
from typing import Callable, Tuple
from ..device_pool import (
JobContext,
)
from ..diffusion.load import (
get_latents_from_seed,
load_pipeline,
)
from ..image import (
expand_image,
mask_filter_none,
noise_source_histogram,
)
from ..params import (
Border,
ImageParams,
Size,
SizeChart,
StageParams,
)
from ..output import (
save_image,
)
from ..utils import (
is_debug,
ServerContext,
)
from .utils import (
process_tile_grid,
)
import numpy as np
from diffusers import OnnxStableDiffusionInpaintPipeline
from PIL import Image
from ..device_pool import JobContext
from ..diffusion.load import get_latents_from_seed, load_pipeline
from ..image import expand_image, mask_filter_none, noise_source_histogram
from ..output import save_image
from ..params import Border, ImageParams, Size, SizeChart, StageParams
from ..utils import ServerContext, is_debug
from .utils import process_tile_grid
logger = getLogger(__name__)
@@ -49,16 +25,16 @@ def blend_inpaint(
*,
expand: Border,
mask_image: Image.Image = None,
fill_color: str = 'white',
fill_color: str = "white",
mask_filter: Callable = mask_filter_none,
noise_source: Callable = noise_source_histogram,
**kwargs,
) -> Image.Image:
logger.info('upscaling image by expanding borders', expand)
logger.info("upscaling image by expanding borders", expand)
if mask_image is None:
# if no mask was provided, keep the full source image
mask_image = Image.new('RGB', source_image.size, 'black')
mask_image = Image.new("RGB", source_image.size, "black")
source_image, mask_image, noise_image, _full_dims = expand_image(
source_image,
@@ -66,12 +42,13 @@ def blend_inpaint(
expand,
fill=fill_color,
noise_source=noise_source,
mask_filter=mask_filter)
mask_filter=mask_filter,
)
if is_debug():
save_image(server, 'last-source.png', source_image)
save_image(server, 'last-mask.png', mask_image)
save_image(server, 'last-noise.png', noise_image)
save_image(server, "last-source.png", source_image)
save_image(server, "last-mask.png", mask_image)
save_image(server, "last-noise.png", noise_image)
def outpaint(image: Image.Image, dims: Tuple[int, int, int]):
left, top, tile = dims
@@ -79,11 +56,15 @@ def blend_inpaint(
mask = mask_image.crop((left, top, left + tile, top + tile))
if is_debug():
save_image(server, 'tile-source.png', image)
save_image(server, 'tile-mask.png', mask)
save_image(server, "tile-source.png", image)
save_image(server, "tile-mask.png", mask)
pipe = load_pipeline(OnnxStableDiffusionInpaintPipeline,
params.model, params.scheduler, job.get_device())
pipe = load_pipeline(
OnnxStableDiffusionInpaintPipeline,
params.model,
params.scheduler,
job.get_device(),
)
latents = get_latents_from_seed(params.seed, size)
rng = np.random.RandomState(params.seed)
@@ -104,5 +85,5 @@ def blend_inpaint(
output = process_tile_grid(source_image, SizeChart.auto, 1, [outpaint])
logger.info('final output image size', output.size)
logger.info("final output image size", output.size)
return output

@@ -1,30 +1,53 @@
from logging import getLogger
import torch
from basicsr.utils import img2tensor, tensor2img
from basicsr.utils.download_util import load_file_from_url
from facexlib.utils.face_restoration_helper import FaceRestoreHelper
from logging import getLogger
from PIL import Image
from torchvision.transforms.functional import normalize
import torch
from ..device_pool import JobContext
from ..params import ImageParams, StageParams
from ..utils import ServerContext
logger = getLogger(__name__)
pretrain_model_url = {
'restoration': 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth',
}
pretrain_model_url = (
"https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth"
)
device = 'cpu'
device = "cpu"
upscale = 2
def correct_codeformer(image: Image.Image) -> Image.Image:
def correct_codeformer(
job: JobContext,
server: ServerContext,
stage: StageParams,
params: ImageParams,
source_image: Image.Image,
**kwargs,
) -> Image.Image:
model = "TODO"
# ------------------ set up CodeFormer restorer -------------------
net = ARCH_REGISTRY.get('CodeFormer')(dim_embd=512, codebook_size=1024, n_head=8, n_layers=9,
connect_list=['32', '64', '128', '256']).to(device)
net = ARCH_REGISTRY.get("CodeFormer")(
dim_embd=512,
codebook_size=1024,
n_head=8,
n_layers=9,
connect_list=["32", "64", "128", "256"],
).to(device)
# ckpt_path = 'weights/CodeFormer/codeformer.pth'
ckpt_path = load_file_from_url(url=pretrain_model_url['restoration'],
model_dir='weights/CodeFormer', progress=True, file_name=None)
checkpoint = torch.load(ckpt_path)['params_ema']
ckpt_path = load_file_from_url(
url=pretrain_model_url,
model_dir="weights/CodeFormer",
progress=True,
file_name=None,
)
checkpoint = torch.load(ckpt_path)["params_ema"]
net.load_state_dict(checkpoint)
net.eval()
@@ -36,22 +59,24 @@ def correct_codeformer(image: Image.Image) -> Image.Image:
upscale,
face_size=512,
crop_ratio=(1, 1),
det_model = args.detection_model,
save_ext='png',
det_model=model,
save_ext="png",
use_parse=True,
device=device)
device=device,
)
# get face landmarks for each face
num_det_faces = face_helper.get_face_landmarks_5(
only_center_face=args.only_center_face, resize=640, eye_dist_threshold=5)
logger.info('detect %s faces', num_det_faces)
only_center_face=False, resize=640, eye_dist_threshold=5
)
logger.info("detect %s faces", num_det_faces)
# align and warp each face
face_helper.align_warp_face()
# face restoration for each cropped face
for idx, cropped_face in enumerate(face_helper.cropped_faces):
# prepare data
cropped_face_t = img2tensor(cropped_face / 255., bgr2rgb=True, float32=True)
cropped_face_t = img2tensor(cropped_face / 255.0, bgr2rgb=True, float32=True)
normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
cropped_face_t = cropped_face_t.unsqueeze(0).to(device)
@@ -62,10 +87,10 @@ def correct_codeformer(image: Image.Image) -> Image.Image:
del output
torch.cuda.empty_cache()
except Exception as error:
logger.error('Failed inference for CodeFormer: %s', error)
logger.error("Failed inference for CodeFormer: %s", error)
restored_face = tensor2img(cropped_face_t, rgb2bgr=True, min_max=(-1, 1))
restored_face = restored_face.astype('uint8')
restored_face = restored_face.astype("uint8")
face_helper.add_restored_face(restored_face, cropped_face)
# upsample the background
@@ -75,13 +100,16 @@ def correct_codeformer(image: Image.Image) -> Image.Image:
else:
bg_img = None
# paste_back
face_helper.get_inverse_affine(None)
# paste each restored face to the input image
if face_upsampler is not None:
restored_img = face_helper.paste_faces_to_input_image(upsample_img=bg_img, draw_box=False, face_upsampler=face_upsampler)
restored_img = face_helper.paste_faces_to_input_image(
upsample_img=bg_img, draw_box=False, face_upsampler=face_upsampler
)
else:
restored_img = face_helper.paste_faces_to_input_image(upsample_img=bg_img, draw_box=False)
restored_img = face_helper.paste_faces_to_input_image(
upsample_img=bg_img, draw_box=False
)
return restored_img
return restored_img

@@ -1,27 +1,16 @@
from gfpgan import GFPGANer
from logging import getLogger
from os import path
from PIL import Image
from realesrgan import RealESRGANer
from typing import Optional
from ..device_pool import (
JobContext,
)
from ..params import (
ImageParams,
StageParams,
UpscaleParams,
)
from ..utils import (
run_gc,
ServerContext,
)
from .upscale_resrgan import (
load_resrgan,
)
import numpy as np
from gfpgan import GFPGANer
from PIL import Image
from realesrgan import RealESRGANer
from ..device_pool import JobContext
from ..params import ImageParams, StageParams, UpscaleParams
from ..utils import ServerContext, run_gc
from .upscale_resrgan import load_resrgan
logger = getLogger(__name__)
@@ -30,7 +19,9 @@ last_pipeline_instance = None
last_pipeline_params = None
def load_gfpgan(ctx: ServerContext, upscale: UpscaleParams, upsampler: Optional[RealESRGANer] = None):
def load_gfpgan(
ctx: ServerContext, upscale: UpscaleParams, upsampler: Optional[RealESRGANer] = None
):
global last_pipeline_instance
global last_pipeline_params
@@ -38,22 +29,22 @@ def load_gfpgan(ctx: ServerContext, upscale: UpscaleParams, upsampler: Optional[
bg_upscale = upscale.rescale(upscale.outscale)
upsampler = load_resrgan(ctx, bg_upscale)
face_path = path.join(ctx.model_path, '%s.pth' %
(upscale.correction_model))
face_path = path.join(ctx.model_path, "%s.pth" % (upscale.correction_model))
if last_pipeline_instance != None and face_path == last_pipeline_params:
logger.info('reusing existing GFPGAN pipeline')
if last_pipeline_instance is not None and face_path == last_pipeline_params:
logger.info("reusing existing GFPGAN pipeline")
return last_pipeline_instance
logger.debug('loading GFPGAN model from %s', face_path)
logger.debug("loading GFPGAN model from %s", face_path)
# TODO: find a way to pass the ONNX model to underlying architectures
gfpgan = GFPGANer(
model_path=face_path,
upscale=upscale.outscale,
arch='clean',
arch="clean",
channel_multiplier=2,
bg_upsampler=upsampler)
bg_upsampler=upsampler,
)
last_pipeline_instance = gfpgan
last_pipeline_params = face_path
@@ -74,15 +65,20 @@ def correct_gfpgan(
**kwargs,
) -> Image.Image:
if upscale.correction_model is None:
logger.warn('no face model given, skipping')
logger.warn("no face model given, skipping")
return source_image
logger.info('correcting faces with GFPGAN model: %s', upscale.correction_model)
logger.info("correcting faces with GFPGAN model: %s", upscale.correction_model)
gfpgan = load_gfpgan(server, upscale, upsampler=upsampler)
output = np.array(source_image)
_, _, output = gfpgan.enhance(
output, has_aligned=False, only_center_face=False, paste_back=True, weight=upscale.face_strength)
output = Image.fromarray(output, 'RGB')
output,
has_aligned=False,
only_center_face=False,
paste_back=True,
weight=upscale.face_strength,
)
output = Image.fromarray(output, "RGB")
return output

@@ -1,26 +1,18 @@
from logging import getLogger
from PIL import Image
from ..device_pool import (
JobContext,
)
from ..params import (
ImageParams,
StageParams,
)
from ..output import (
save_image,
)
from ..utils import (
ServerContext,
)
from ..device_pool import JobContext
from ..output import save_image
from ..params import ImageParams, StageParams
from ..utils import ServerContext
logger = getLogger(__name__)
def persist_disk(
_job: JobContext,
ctx: ServerContext,
server: ServerContext,
_stage: StageParams,
_params: ImageParams,
source_image: Image.Image,
@@ -28,6 +20,6 @@ def persist_disk(
output: str,
**kwargs,
) -> Image.Image:
dest = save_image(ctx, output, source_image)
logger.info('saved image to %s', dest)
dest = save_image(server, output, source_image)
logger.info("saved image to %s", dest)
return source_image

@@ -1,26 +1,19 @@
from boto3 import (
Session,
)
from io import BytesIO
from logging import getLogger
from boto3 import Session
from PIL import Image
from ..device_pool import (
JobContext,
)
from ..params import (
ImageParams,
StageParams,
)
from ..utils import (
ServerContext,
)
from ..device_pool import JobContext
from ..params import ImageParams, StageParams
from ..utils import ServerContext
logger = getLogger(__name__)
def persist_s3(
ctx: ServerContext,
_job: JobContext,
server: ServerContext,
_stage: StageParams,
_params: ImageParams,
source_image: Image.Image,
@@ -32,16 +25,16 @@ def persist_s3(
**kwargs,
) -> Image.Image:
session = Session(profile_name=profile_name)
s3 = session.client('s3', endpoint_url=endpoint_url)
s3 = session.client("s3", endpoint_url=endpoint_url)
data = BytesIO()
source_image.save(data, format=ctx.image_format)
source_image.save(data, format=server.image_format)
data.seek(0)
try:
s3.upload_fileobj(data, bucket, output)
logger.info('saved image to %s/%s', bucket, output)
logger.info("saved image to %s/%s", bucket, output)
except Exception as err:
logger.error('error saving image to S3: %s', err)
logger.error("error saving image to S3: %s", err)
return source_image
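
Continuing the pipeline sketch shown after the ChainPipeline section, persist_s3 runs as an ordinary stage, with its destination supplied through stage kwargs. The kwarg names below match the names used in the function body above (part of the parameter list is elided between hunks); the bucket name and MinIO-style endpoint are purely illustrative:

    chain.append((
        persist_s3,
        StageParams(),
        {
            "output": "result.png",
            "bucket": "onnx-web-output",
            "endpoint_url": "http://localhost:9000",
            "profile_name": None,
        },
    ))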

@@ -1,23 +1,17 @@
from logging import getLogger
from PIL import Image
from ..device_pool import (
JobContext,
)
from ..params import (
ImageParams,
Size,
StageParams,
)
from ..utils import (
ServerContext,
)
from ..device_pool import JobContext
from ..params import ImageParams, Size, StageParams
from ..utils import ServerContext
logger = getLogger(__name__)
def reduce_crop(
ctx: ServerContext,
_job: JobContext,
_server: ServerContext,
_stage: StageParams,
_params: ImageParams,
source_image: Image.Image,
@@ -26,8 +20,6 @@ def reduce_crop(
size: Size,
**kwargs,
) -> Image.Image:
image = source_image.crop(
(origin.width, origin.height, size.width, size.height))
logger.info('created thumbnail with dimensions: %sx%s',
image.width, image.height)
image = source_image.crop((origin.width, origin.height, size.width, size.height))
logger.info("created thumbnail with dimensions: %sx%s", image.width, image.height)
return image

@@ -1,23 +1,17 @@
from logging import getLogger
from PIL import Image
from ..device_pool import (
JobContext,
)
from ..params import (
ImageParams,
Size,
StageParams,
)
from ..utils import (
ServerContext,
)
from ..device_pool import JobContext
from ..params import ImageParams, Size, StageParams
from ..utils import ServerContext
logger = getLogger(__name__)
def reduce_thumbnail(
ctx: ServerContext,
_job: JobContext,
_server: ServerContext,
_stage: StageParams,
_params: ImageParams,
source_image: Image.Image,
@@ -26,6 +20,5 @@ def reduce_thumbnail(
**kwargs,
) -> Image.Image:
image = source_image.thumbnail((size.width, size.height))
logger.info('created thumbnail with dimensions: %sx%s',
image.width, image.height)
logger.info("created thumbnail with dimensions: %sx%s", image.width, image.height)
return image

@@ -1,26 +1,19 @@
from logging import getLogger
from PIL import Image
from typing import Callable
from ..device_pool import (
JobContext,
)
from ..params import (
ImageParams,
Size,
StageParams,
)
from ..utils import (
ServerContext,
)
from PIL import Image
from ..device_pool import JobContext
from ..params import ImageParams, Size, StageParams
from ..utils import ServerContext
logger = getLogger(__name__)
def source_noise(
ctx: ServerContext,
stage: StageParams,
_job: JobContext,
_server: ServerContext,
_stage: StageParams,
params: ImageParams,
source_image: Image.Image,
*,
@@ -28,14 +21,12 @@ def source_noise(
noise_source: Callable,
**kwargs,
) -> Image.Image:
prompt = prompt or params.prompt
logger.info('generating image from noise source')
logger.info("generating image from noise source")
if source_image is not None:
logger.warn(
'a source image was passed to a noise stage, but will be discarded')
logger.warn("a source image was passed to a noise stage, but will be discarded")
output = noise_source(source_image, (size.width, size.height), (0, 0))
logger.info('final output image size: %sx%s', output.width, output.height)
logger.info("final output image size: %sx%s", output.width, output.height)
return output

@@ -1,26 +1,13 @@
from diffusers import (
OnnxStableDiffusionPipeline,
)
from logging import getLogger
from PIL import Image
from ..device_pool import (
JobContext,
)
from ..diffusion.load import (
get_latents_from_seed,
load_pipeline,
)
from ..params import (
ImageParams,
Size,
StageParams,
)
from ..utils import (
ServerContext,
)
import numpy as np
from diffusers import OnnxStableDiffusionPipeline
from PIL import Image
from ..device_pool import JobContext
from ..diffusion.load import get_latents_from_seed, load_pipeline
from ..params import ImageParams, Size, StageParams
from ..utils import ServerContext
logger = getLogger(__name__)
@@ -37,13 +24,16 @@ def source_txt2img(
**kwargs,
) -> Image.Image:
prompt = prompt or params.prompt
logger.info('generating image using txt2img, %s steps: %s', params.steps, prompt)
logger.info("generating image using txt2img, %s steps: %s", params.steps, prompt)
if source_image is not None:
logger.warn('a source image was passed to a txt2img stage, but will be discarded')
logger.warn(
"a source image was passed to a txt2img stage, but will be discarded"
)
pipe = load_pipeline(OnnxStableDiffusionPipeline,
params.model, params.scheduler, job.get_device())
pipe = load_pipeline(
OnnxStableDiffusionPipeline, params.model, params.scheduler, job.get_device()
)
latents = get_latents_from_seed(params.seed, size)
rng = np.random.RandomState(params.seed)
@@ -60,5 +50,5 @@ def source_txt2img(
)
output = result.images[0]
logger.info('final output image size: %sx%s', output.width, output.height)
logger.info("final output image size: %sx%s", output.width, output.height)
return output

@@ -1,43 +1,17 @@
from diffusers import (
OnnxStableDiffusionInpaintPipeline,
)
from logging import getLogger
from PIL import Image, ImageDraw
from typing import Callable, Tuple
from ..device_pool import (
JobContext,
)
from ..diffusion.load import (
get_latents_from_seed,
get_tile_latents,
load_pipeline,
)
from ..image import (
expand_image,
mask_filter_none,
noise_source_histogram,
)
from ..params import (
Border,
ImageParams,
Size,
SizeChart,
StageParams,
)
from ..output import (
save_image,
)
from ..utils import (
base_join,
is_debug,
ServerContext,
)
from .utils import (
process_tile_spiral,
)
import numpy as np
from diffusers import OnnxStableDiffusionInpaintPipeline
from PIL import Image, ImageDraw
from ..device_pool import JobContext
from ..diffusion.load import get_latents_from_seed, get_tile_latents, load_pipeline
from ..image import expand_image, mask_filter_none, noise_source_histogram
from ..output import save_image
from ..params import Border, ImageParams, Size, SizeChart, StageParams
from ..utils import ServerContext, is_debug
from .utils import process_tile_spiral
logger = getLogger(__name__)
@@ -52,17 +26,17 @@ def upscale_outpaint(
border: Border,
prompt: str = None,
mask_image: Image.Image = None,
fill_color: str = 'white',
fill_color: str = "white",
mask_filter: Callable = mask_filter_none,
noise_source: Callable = noise_source_histogram,
**kwargs,
) -> Image.Image:
prompt = prompt or params.prompt
logger.info('upscaling image by expanding borders: %s', border)
logger.info("upscaling image by expanding borders: %s", border)
if mask_image is None:
# if no mask was provided, keep the full source image
mask_image = Image.new('RGB', source_image.size, 'black')
mask_image = Image.new("RGB", source_image.size, "black")
source_image, mask_image, noise_image, full_dims = expand_image(
source_image,
@@ -70,16 +44,17 @@ def upscale_outpaint(
border,
fill=fill_color,
noise_source=noise_source,
mask_filter=mask_filter)
mask_filter=mask_filter,
)
draw_mask = ImageDraw.Draw(mask_image)
full_size = Size(*full_dims)
full_latents = get_latents_from_seed(params.seed, full_size)
if is_debug():
save_image(server, 'last-source.png', source_image)
save_image(server, 'last-mask.png', mask_image)
save_image(server, 'last-noise.png', noise_image)
save_image(server, "last-source.png", source_image)
save_image(server, "last-mask.png", mask_image)
save_image(server, "last-noise.png", noise_image)
def outpaint(image: Image.Image, dims: Tuple[int, int, int]):
left, top, tile = dims
@@ -87,11 +62,15 @@ def upscale_outpaint(
mask = mask_image.crop((left, top, left + tile, top + tile))
if is_debug():
save_image(server, 'tile-source.png', image)
save_image(server, 'tile-mask.png', mask)
save_image(server, "tile-source.png", image)
save_image(server, "tile-mask.png", mask)
pipe = load_pipeline(OnnxStableDiffusionInpaintPipeline,
params.model, params.scheduler, job.get_device())
pipe = load_pipeline(
OnnxStableDiffusionInpaintPipeline,
params.model,
params.scheduler,
job.get_device(),
)
latents = get_tile_latents(full_latents, dims)
rng = np.random.RandomState(params.seed)
@@ -110,10 +89,10 @@ def upscale_outpaint(
)
# once part of the image has been drawn, keep it
draw_mask.rectangle((left, top, left + tile, top + tile), fill='black')
draw_mask.rectangle((left, top, left + tile, top + tile), fill="black")
return result.images[0]
output = process_tile_spiral(source_image, SizeChart.auto, 1, [outpaint])
logger.info('final output image size: %sx%s', output.width, output.height)
logger.info("final output image size: %sx%s", output.width, output.height)
return output

@@ -1,27 +1,15 @@
from basicsr.archs.rrdbnet_arch import RRDBNet
from logging import getLogger
from os import path
import numpy as np
from basicsr.archs.rrdbnet_arch import RRDBNet
from PIL import Image
from realesrgan import RealESRGANer
from ..device_pool import (
JobContext,
)
from ..onnx import (
OnnxNet,
)
from ..params import (
DeviceParams,
ImageParams,
StageParams,
UpscaleParams,
)
from ..utils import (
run_gc,
ServerContext,
)
import numpy as np
from ..device_pool import JobContext
from ..onnx import OnnxNet
from ..params import DeviceParams, ImageParams, StageParams, UpscaleParams
from ..utils import ServerContext, run_gc
logger = getLogger(__name__)
@@ -29,39 +17,50 @@ last_pipeline_instance = None
last_pipeline_params = (None, None)
def load_resrgan(ctx: ServerContext, params: UpscaleParams, device: DeviceParams, tile=0):
def load_resrgan(
ctx: ServerContext, params: UpscaleParams, device: DeviceParams, tile=0
):
global last_pipeline_instance
global last_pipeline_params
model_file = '%s.%s' % (params.upscale_model, params.format)
model_file = "%s.%s" % (params.upscale_model, params.format)
model_path = path.join(ctx.model_path, model_file)
if not path.isfile(model_path):
raise Exception('Real ESRGAN model not found at %s' % model_path)
raise Exception("Real ESRGAN model not found at %s" % model_path)
cache_params = (model_path, params.format)
if last_pipeline_instance != None and cache_params == last_pipeline_params:
logger.info('reusing existing Real ESRGAN pipeline')
if last_pipeline_instance is not None and cache_params == last_pipeline_params:
logger.info("reusing existing Real ESRGAN pipeline")
return last_pipeline_instance
# use ONNX acceleration, if available
if params.format == 'onnx':
model = OnnxNet(ctx, model_file, provider=device.provider, provider_options=device.options)
elif params.format == 'pth':
model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64,
num_block=23, num_grow_ch=32, scale=params.scale)
raise Exception('unknown platform %s' % params.format)
if params.format == "onnx":
model = OnnxNet(
ctx, model_file, provider=device.provider, provider_options=device.options
)
elif params.format == "pth":
model = RRDBNet(
num_in_ch=3,
num_out_ch=3,
num_feat=64,
num_block=23,
num_grow_ch=32,
scale=params.scale,
)
raise Exception("unknown platform %s" % params.format)
dni_weight = None
if params.upscale_model == 'realesr-general-x4v3' and params.denoise != 1:
if params.upscale_model == "realesr-general-x4v3" and params.denoise != 1:
wdn_model_path = model_path.replace(
'realesr-general-x4v3', 'realesr-general-wdn-x4v3')
"realesr-general-x4v3", "realesr-general-wdn-x4v3"
)
model_path = [model_path, wdn_model_path]
dni_weight = [params.denoise, 1 - params.denoise]
logger.debug('loading Real ESRGAN upscale model from %s', model_path)
logger.debug("loading Real ESRGAN upscale model from %s", model_path)
# TODO: shouldn't need the PTH file
model_path_pth = path.join(ctx.model_path, '%s.pth' % params.upscale_model)
model_path_pth = path.join(ctx.model_path, "%s.pth" % params.upscale_model)
upsampler = RealESRGANer(
scale=params.scale,
model_path=model_path_pth,
@@ -70,7 +69,8 @@ def load_resrgan(ctx: ServerContext, params: UpscaleParams, device: DeviceParams
tile=tile,
tile_pad=params.tile_pad,
pre_pad=params.pre_pad,
half=params.half)
half=params.half,
)
last_pipeline_instance = upsampler
last_pipeline_params = cache_params
@@ -89,13 +89,13 @@ def upscale_resrgan(
upscale: UpscaleParams,
**kwargs,
) -> Image.Image:
logger.info('upscaling image with Real ESRGAN: x%s', upscale.scale)
logger.info("upscaling image with Real ESRGAN: x%s", upscale.scale)
output = np.array(source_image)
upsampler = load_resrgan(server, upscale, job.get_device(), tile=stage.tile_size)
output, _ = upsampler.enhance(output, outscale=upscale.outscale)
output = Image.fromarray(output, 'RGB')
logger.info('final output image size: %sx%s', output.width, output.height)
output = Image.fromarray(output, "RGB")
logger.info("final output image size: %sx%s", output.width, output.height)
return output
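
The load_resrgan caching above, like load_gfpgan earlier and load_stable_diffusion in the next file, uses the same module-level memoization pattern: remember the last pipeline and the parameters it was built from, and rebuild only when they change. Reduced to its core, as a generic sketch rather than code from this commit (build_pipeline stands in for any expensive constructor):

    last_pipeline_instance = None
    last_pipeline_params = None


    def load_cached(params):
        global last_pipeline_instance, last_pipeline_params

        if last_pipeline_instance is not None and params == last_pipeline_params:
            # reuse the existing pipeline instead of rebuilding it
            return last_pipeline_instance

        pipeline = build_pipeline(params)  # hypothetical expensive constructor
        last_pipeline_instance = pipeline
        last_pipeline_params = params
        return pipeline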

@@ -1,28 +1,16 @@
from diffusers import (
StableDiffusionUpscalePipeline,
)
from logging import getLogger
from os import path
import torch
from diffusers import StableDiffusionUpscalePipeline
from PIL import Image
from ..device_pool import (
JobContext,
)
from ..device_pool import JobContext
from ..diffusion.pipeline_onnx_stable_diffusion_upscale import (
OnnxStableDiffusionUpscalePipeline,
)
from ..params import (
DeviceParams,
ImageParams,
StageParams,
UpscaleParams,
)
from ..utils import (
run_gc,
ServerContext,
)
import torch
from ..params import DeviceParams, ImageParams, StageParams, UpscaleParams
from ..utils import ServerContext, run_gc
logger = getLogger(__name__)
@@ -31,23 +19,37 @@ last_pipeline_instance = None
last_pipeline_params = (None, None)
def load_stable_diffusion(ctx: ServerContext, upscale: UpscaleParams, device: DeviceParams):
def load_stable_diffusion(
ctx: ServerContext, upscale: UpscaleParams, device: DeviceParams
):
global last_pipeline_instance
global last_pipeline_params
model_path = path.join(ctx.model_path, upscale.upscale_model)
cache_params = (model_path, upscale.format)
if last_pipeline_instance != None and cache_params == last_pipeline_params:
logger.info('reusing existing Stable Diffusion upscale pipeline')
if last_pipeline_instance is not None and cache_params == last_pipeline_params:
logger.info("reusing existing Stable Diffusion upscale pipeline")
return last_pipeline_instance
if upscale.format == 'onnx':
logger.debug('loading Stable Diffusion upscale ONNX model from %s, using provider %s', model_path, device.provider)
pipeline = OnnxStableDiffusionUpscalePipeline.from_pretrained(model_path, provider=device.provider, provider_options=device.options)
if upscale.format == "onnx":
logger.debug(
"loading Stable Diffusion upscale ONNX model from %s, using provider %s",
model_path,
device.provider,
)
pipeline = OnnxStableDiffusionUpscalePipeline.from_pretrained(
model_path, provider=device.provider, provider_options=device.options
)
else:
logger.debug('loading Stable Diffusion upscale model from %s, using provider %s', model_path, device.provider)
pipeline = StableDiffusionUpscalePipeline.from_pretrained(model_path, provider=device.provider)
logger.debug(
"loading Stable Diffusion upscale model from %s, using provider %s",
model_path,
device.provider,
)
pipeline = StableDiffusionUpscalePipeline.from_pretrained(
model_path, provider=device.provider
)
last_pipeline_instance = pipeline
last_pipeline_params = cache_params
@@ -68,7 +70,7 @@ def upscale_stable_diffusion(
**kwargs,
) -> Image.Image:
prompt = prompt or params.prompt
logger.info('upscaling with Stable Diffusion, %s steps: %s', params.steps, prompt)
logger.info("upscaling with Stable Diffusion, %s steps: %s", params.steps, prompt)
pipeline = load_stable_diffusion(server, upscale, job.get_device())
generator = torch.manual_seed(params.seed)

@@ -1,7 +1,8 @@
from logging import getLogger
from PIL import Image
from typing import List, Protocol, Tuple
from PIL import Image
logger = getLogger(__name__)
@@ -17,7 +18,7 @@ def process_tile_grid(
filters: List[TileCallback],
) -> Image.Image:
width, height = source.size
image = Image.new('RGB', (width * scale, height * scale))
image = Image.new("RGB", (width * scale, height * scale))
tiles_x = width // tile
tiles_y = height // tile
@@ -28,7 +29,7 @@ def process_tile_grid(
idx = (y * tiles_x) + x
left = x * tile
top = y * tile
logger.info('processing tile %s of %s, %s.%s', idx + 1, total, y, x)
logger.info("processing tile %s of %s, %s.%s", idx + 1, total, y, x)
tile_image = source.crop((left, top, left + tile, top + tile))
for filter in filters:
@@ -47,10 +48,10 @@ def process_tile_spiral(
overlap: float = 0.5,
) -> Image.Image:
if scale != 1:
raise Exception('unsupported scale')
raise Exception("unsupported scale")
width, height = source.size
image = Image.new('RGB', (width * scale, height * scale))
image = Image.new("RGB", (width * scale, height * scale))
image.paste(source, (0, 0, width, height))
center_x = (width // 2) - (tile // 2)
@@ -76,7 +77,7 @@ def process_tile_spiral(
top = center_y + int(top)
counter += 1
logger.info('processing tile %s of %s, %sx%s', counter, len(tiles), left, top)
logger.info("processing tile %s of %s, %sx%s", counter, len(tiles), left, top)
# TODO: only valid for scale == 1, resize source for others
tile_image = image.crop((left, top, left + tile, top + tile))
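
Both helpers take a list of TileCallback filters, each called with a tile image and its placement, and paste the returned tile back into the output. A minimal sketch of a callback, assuming the signatures shown in this diff (the module path is inferred from the relative imports used elsewhere in the chain package):

    from PIL import Image, ImageFilter

    from onnx_web.chain.utils import process_tile_grid


    def blur_tile(tile: Image.Image, _dims) -> Image.Image:
        # any per-tile transform fits here; the chain stages wrap their
        # diffusion pipelines in callbacks of this shape
        return tile.filter(ImageFilter.GaussianBlur(2))


    source = Image.new("RGB", (1024, 1024), "white")
    output = process_tile_grid(source, 512, 1, [blur_tile])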

@@ -1,5 +1,14 @@
from . import logging
import warnings
from argparse import ArgumentParser
from json import loads
from logging import getLogger
from os import environ, makedirs, mkdir, path
from pathlib import Path
from shutil import copyfile, rmtree
from sys import exit
from typing import Dict, List, Optional, Tuple
import torch
from basicsr.archs.rrdbnet_arch import RRDBNet
from basicsr.utils.download_util import load_file_from_url
from diffusers import (
@@ -8,25 +17,20 @@ from diffusers import (
StableDiffusionPipeline,
StableDiffusionUpscalePipeline,
)
from json import loads
from logging import getLogger
from onnx import load, save_model
from os import environ, makedirs, mkdir, path
from pathlib import Path
from shutil import copyfile, rmtree
from sys import exit
from torch.onnx import export
from typing import Dict, List, Optional, Tuple
import torch
import warnings
from . import logging
# suppress common but harmless warnings, https://github.com/ssube/onnx-web/issues/75
warnings.filterwarnings(
'ignore', '.*The shape inference of prim::Constant type is missing.*')
warnings.filterwarnings('ignore', '.*Only steps=1 can be constant folded.*')
"ignore", ".*The shape inference of prim::Constant type is missing.*"
)
warnings.filterwarnings("ignore", ".*Only steps=1 can be constant folded.*")
warnings.filterwarnings(
'ignore', '.*Converting a tensor to a Python boolean might cause the trace to be incorrect.*')
"ignore",
".*Converting a tensor to a Python boolean might cause the trace to be incorrect.*",
)
Models = Dict[str, List[Tuple[str, str, Optional[int]]]]
@@ -35,74 +39,95 @@ logger = getLogger(__name__)
# recommended models
base_models: Models = {
'diffusion': [
"diffusion": [
# v1.x
('stable-diffusion-onnx-v1-5', 'runwayml/stable-diffusion-v1-5'),
('stable-diffusion-onnx-v1-inpainting',
'runwayml/stable-diffusion-inpainting'),
("stable-diffusion-onnx-v1-5", "runwayml/stable-diffusion-v1-5"),
("stable-diffusion-onnx-v1-inpainting", "runwayml/stable-diffusion-inpainting"),
# v2.x
('stable-diffusion-onnx-v2-1', 'stabilityai/stable-diffusion-2-1'),
('stable-diffusion-onnx-v2-inpainting',
'stabilityai/stable-diffusion-2-inpainting'),
("stable-diffusion-onnx-v2-1", "stabilityai/stable-diffusion-2-1"),
(
"stable-diffusion-onnx-v2-inpainting",
"stabilityai/stable-diffusion-2-inpainting",
),
# TODO: should have its own converter
('upscaling-stable-diffusion-x4', 'stabilityai/stable-diffusion-x4-upscaler'),
("upscaling-stable-diffusion-x4", "stabilityai/stable-diffusion-x4-upscaler"),
],
'correction': [
('correction-gfpgan-v1-3',
'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth', 4),
"correction": [
(
"correction-gfpgan-v1-3",
"https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth",
4,
),
],
'upscaling': [
('upscaling-real-esrgan-x2-plus',
'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth', 2),
('upscaling-real-esrgan-x4-plus',
'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth', 4),
('upscaling-real-esrgan-x4-v3',
'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth', 4),
"upscaling": [
(
"upscaling-real-esrgan-x2-plus",
"https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth",
2,
),
(
"upscaling-real-esrgan-x4-plus",
"https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth",
4,
),
(
"upscaling-real-esrgan-x4-v3",
"https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth",
4,
),
],
}
model_path = environ.get('ONNX_WEB_MODEL_PATH',
path.join('..', 'models'))
training_device = 'cuda' if torch.cuda.is_available() else 'cpu'
model_path = environ.get("ONNX_WEB_MODEL_PATH", path.join("..", "models"))
training_device = "cuda" if torch.cuda.is_available() else "cpu"
map_location = torch.device(training_device)
@torch.no_grad()
def convert_real_esrgan(name: str, url: str, scale: int, opset: int):
dest_path = path.join(model_path, name + '.pth')
dest_onnx = path.join(model_path, name + '.onnx')
logger.info('converting Real ESRGAN model: %s -> %s', name, dest_onnx)
dest_path = path.join(model_path, name + ".pth")
dest_onnx = path.join(model_path, name + ".onnx")
logger.info("converting Real ESRGAN model: %s -> %s", name, dest_onnx)
if path.isfile(dest_onnx):
logger.info('ONNX model already exists, skipping.')
logger.info("ONNX model already exists, skipping.")
return
if not path.isfile(dest_path):
logger.info('PTH model not found, downloading...')
logger.info("PTH model not found, downloading...")
download_path = load_file_from_url(
url=url, model_dir=dest_path + '-cache', progress=True, file_name=None)
url=url, model_dir=dest_path + "-cache", progress=True, file_name=None
)
copyfile(download_path, dest_path)
logger.info('loading and training model')
model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64,
num_block=23, num_grow_ch=32, scale=scale)
logger.info("loading and training model")
model = RRDBNet(
num_in_ch=3,
num_out_ch=3,
num_feat=64,
num_block=23,
num_grow_ch=32,
scale=scale,
)
torch_model = torch.load(dest_path, map_location=map_location)
if 'params_ema' in torch_model:
model.load_state_dict(torch_model['params_ema'])
if "params_ema" in torch_model:
model.load_state_dict(torch_model["params_ema"])
else:
model.load_state_dict(torch_model['params'], strict=False)
model.load_state_dict(torch_model["params"], strict=False)
model.to(training_device).train(False)
model.eval()
rng = torch.rand(1, 3, 64, 64, device=map_location)
input_names = ['data']
output_names = ['output']
dynamic_axes = {'data': {2: 'width', 3: 'height'},
'output': {2: 'width', 3: 'height'}}
input_names = ["data"]
output_names = ["output"]
dynamic_axes = {
"data": {2: "width", 3: "height"},
"output": {2: "width", 3: "height"},
}
logger.info('exporting ONNX model to %s', dest_onnx)
logger.info("exporting ONNX model to %s", dest_onnx)
export(
model,
rng,
@@ -111,48 +136,57 @@ def convert_real_esrgan(name: str, url: str, scale: int, opset: int):
output_names=output_names,
dynamic_axes=dynamic_axes,
opset_version=opset,
export_params=True
export_params=True,
)
logger.info('Real ESRGAN exported to ONNX successfully.')
logger.info("Real ESRGAN exported to ONNX successfully.")
@torch.no_grad()
def convert_gfpgan(name: str, url: str, scale: int, opset: int):
dest_path = path.join(model_path, name + '.pth')
dest_onnx = path.join(model_path, name + '.onnx')
logger.info('converting GFPGAN model: %s -> %s', name, dest_onnx)
dest_path = path.join(model_path, name + ".pth")
dest_onnx = path.join(model_path, name + ".onnx")
logger.info("converting GFPGAN model: %s -> %s", name, dest_onnx)
if path.isfile(dest_onnx):
logger.info('ONNX model already exists, skipping.')
logger.info("ONNX model already exists, skipping.")
return
if not path.isfile(dest_path):
logger.info('PTH model not found, downloading...')
logger.info("PTH model not found, downloading...")
download_path = load_file_from_url(
url=url, model_dir=dest_path + '-cache', progress=True, file_name=None)
url=url, model_dir=dest_path + "-cache", progress=True, file_name=None
)
copyfile(download_path, dest_path)
logger.info('loading and training model')
model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64,
num_block=23, num_grow_ch=32, scale=scale)
logger.info("loading and training model")
model = RRDBNet(
num_in_ch=3,
num_out_ch=3,
num_feat=64,
num_block=23,
num_grow_ch=32,
scale=scale,
)
torch_model = torch.load(dest_path, map_location=map_location)
# TODO: make sure strict=False is safe here
if 'params_ema' in torch_model:
model.load_state_dict(torch_model['params_ema'], strict=False)
if "params_ema" in torch_model:
model.load_state_dict(torch_model["params_ema"], strict=False)
else:
model.load_state_dict(torch_model['params'], strict=False)
model.load_state_dict(torch_model["params"], strict=False)
model.to(training_device).train(False)
model.eval()
rng = torch.rand(1, 3, 64, 64, device=map_location)
input_names = ['data']
output_names = ['output']
dynamic_axes = {'data': {2: 'width', 3: 'height'},
'output': {2: 'width', 3: 'height'}}
input_names = ["data"]
output_names = ["output"]
dynamic_axes = {
"data": {2: "width", 3: "height"},
"output": {2: "width", 3: "height"},
}
logger.info('exporting ONNX model to %s', dest_onnx)
logger.info("exporting ONNX model to %s", dest_onnx)
export(
model,
rng,
@@ -161,9 +195,9 @@ def convert_gfpgan(name: str, url: str, scale: int, opset: int):
output_names=output_names,
dynamic_axes=dynamic_axes,
opset_version=opset,
export_params=True
export_params=True,
)
logger.info('GFPGAN exported to ONNX successfully.')
logger.info("GFPGAN exported to ONNX successfully.")
def onnx_export(
@@ -176,9 +210,9 @@ def onnx_export(
opset,
use_external_data_format=False,
):
'''
"""
From https://github.com/huggingface/diffusers/blob/main/scripts/convert_stable_diffusion_checkpoint_to_onnx.py
'''
"""
output_path.parent.mkdir(parents=True, exist_ok=True)
export(
@@ -194,29 +228,33 @@
@torch.no_grad()
def convert_diffuser(name: str, url: str, opset: int, half: bool, token: str, single_vae: bool = False):
'''
def convert_diffuser(
name: str, url: str, opset: int, half: bool, token: str, single_vae: bool = False
):
"""
From https://github.com/huggingface/diffusers/blob/main/scripts/convert_stable_diffusion_checkpoint_to_onnx.py
'''
"""
dtype = torch.float16 if half else torch.float32
dest_path = path.join(model_path, name)
# diffusers go into a directory rather than .onnx file
logger.info('converting Diffusers model: %s -> %s/', name, dest_path)
logger.info("converting Diffusers model: %s -> %s/", name, dest_path)
if single_vae:
logger.info('converting model with single VAE')
logger.info("converting model with single VAE")
if path.isdir(dest_path):
logger.info('ONNX model already exists, skipping.')
logger.info("ONNX model already exists, skipping.")
return
if half and training_device != 'cuda':
if half and training_device != "cuda":
raise ValueError(
'Half precision model export is only supported on GPUs with CUDA')
"Half precision model export is only supported on GPUs with CUDA"
)
pipeline = StableDiffusionPipeline.from_pretrained(
url, torch_dtype=dtype, use_auth_token=token).to(training_device)
url, torch_dtype=dtype, use_auth_token=token
).to(training_device)
output_path = Path(dest_path)
# TEXT ENCODER
@@ -232,8 +270,7 @@ def convert_diffuser(name: str, url: str, opset: int, half: bool, token: str, si
onnx_export(
pipeline.text_encoder,
# casting to torch.int32 until the CLIP fix is released: https://github.com/huggingface/transformers/pull/18515/files
model_args=(text_input.input_ids.to(
device=training_device, dtype=torch.int32)),
model_args=(text_input.input_ids.to(device=training_device, dtype=torch.int32)),
output_path=output_path / "text_encoder" / "model.onnx",
ordered_input_names=["input_ids"],
output_names=["last_hidden_state", "pooler_output"],
@@ -244,7 +281,7 @@ def convert_diffuser(name: str, url: str, opset: int, half: bool, token: str, si
)
del pipeline.text_encoder
logger.debug('UNET config: %s', pipeline.unet.config)
logger.debug("UNET config: %s", pipeline.unet.config)
# UNET
if single_vae:
@@ -262,10 +299,12 @@ def convert_diffuser(name: str, url: str, opset: int, half: bool, token: str, si
pipeline.unet,
model_args=(
torch.randn(2, unet_in_channels, unet_sample_size, unet_sample_size).to(
device=training_device, dtype=dtype),
device=training_device, dtype=dtype
),
torch.randn(2).to(device=training_device, dtype=dtype),
torch.randn(2, num_tokens, text_hidden_size).to(
device=training_device, dtype=dtype),
device=training_device, dtype=dtype
),
unet_scale,
),
output_path=unet_path,
@@ -298,7 +337,7 @@ def convert_diffuser(name: str, url: str, opset: int, half: bool, token: str, si
del pipeline.unet
if single_vae:
logger.debug('VAE config: %s', pipeline.vae.config)
logger.debug("VAE config: %s", pipeline.vae.config)
# SINGLE VAE
vae_only = pipeline.vae
@@ -309,8 +348,9 @@ def convert_diffuser(name: str, url: str, opset: int, half: bool, token: str, si
onnx_export(
vae_only,
model_args=(
torch.randn(1, vae_latent_channels, unet_sample_size, unet_sample_size).to(
device=training_device, dtype=dtype),
torch.randn(
1, vae_latent_channels, unet_sample_size, unet_sample_size
).to(device=training_device, dtype=dtype),
False,
),
output_path=output_path / "vae" / "model.onnx",
@@ -328,12 +368,14 @@ def convert_diffuser(name: str, url: str, opset: int, half: bool, token: str, si
vae_sample_size = vae_encoder.config.sample_size
# need to get the raw tensor output (sample) from the encoder
vae_encoder.forward = lambda sample, return_dict: vae_encoder.encode(
sample, return_dict)[0].sample()
sample, return_dict
)[0].sample()
onnx_export(
vae_encoder,
model_args=(
torch.randn(1, vae_in_channels, vae_sample_size, vae_sample_size).to(
device=training_device, dtype=dtype),
device=training_device, dtype=dtype
),
False,
),
output_path=output_path / "vae_encoder" / "model.onnx",
@@ -354,8 +396,9 @@ def convert_diffuser(name: str, url: str, opset: int, half: bool, token: str, si
onnx_export(
vae_decoder,
model_args=(
torch.randn(1, vae_latent_channels, unet_sample_size, unet_sample_size).to(
device=training_device, dtype=dtype),
torch.randn(
1, vae_latent_channels, unet_sample_size, unet_sample_size
).to(device=training_device, dtype=dtype),
False,
),
output_path=output_path / "vae_decoder" / "model.onnx",
@@ -385,7 +428,8 @@ def convert_diffuser(name: str, url: str, opset: int, half: bool, token: str, si
clip_image_size,
).to(device=training_device, dtype=dtype),
torch.randn(1, vae_sample_size, vae_sample_size, vae_out_channels).to(
device=training_device, dtype=dtype),
device=training_device, dtype=dtype
),
),
output_path=output_path / "safety_checker" / "model.onnx",
ordered_input_names=["clip_input", "images"],
@@ -398,7 +442,8 @@ def convert_diffuser(name: str, url: str, opset: int, half: bool, token: str, si
)
del pipeline.safety_checker
safety_checker = OnnxRuntimeModel.from_pretrained(
output_path / "safety_checker")
output_path / "safety_checker"
)
feature_extractor = pipeline.feature_extractor
else:
safety_checker = None
@@ -406,10 +451,8 @@ def convert_diffuser(name: str, url: str, opset: int, half: bool, token: str, si
if single_vae:
onnx_pipeline = StableDiffusionUpscalePipeline(
vae=OnnxRuntimeModel.from_pretrained(
output_path / "vae"),
text_encoder=OnnxRuntimeModel.from_pretrained(
output_path / "text_encoder"),
vae=OnnxRuntimeModel.from_pretrained(output_path / "vae"),
text_encoder=OnnxRuntimeModel.from_pretrained(output_path / "text_encoder"),
tokenizer=pipeline.tokenizer,
low_res_scheduler=pipeline.scheduler,
unet=OnnxRuntimeModel.from_pretrained(output_path / "unet"),
@@ -417,12 +460,9 @@ def convert_diffuser(name: str, url: str, opset: int, half: bool, token: str, si
)
else:
onnx_pipeline = OnnxStableDiffusionPipeline(
vae_encoder=OnnxRuntimeModel.from_pretrained(
output_path / "vae_encoder"),
vae_decoder=OnnxRuntimeModel.from_pretrained(
output_path / "vae_decoder"),
text_encoder=OnnxRuntimeModel.from_pretrained(
output_path / "text_encoder"),
vae_encoder=OnnxRuntimeModel.from_pretrained(output_path / "vae_encoder"),
vae_decoder=OnnxRuntimeModel.from_pretrained(output_path / "vae_decoder"),
text_encoder=OnnxRuntimeModel.from_pretrained(output_path / "text_encoder"),
tokenizer=pipeline.tokenizer,
unet=OnnxRuntimeModel.from_pretrained(output_path / "unet"),
scheduler=pipeline.scheduler,
@@ -431,7 +471,7 @@ def convert_diffuser(name: str, url: str, opset: int, half: bool, token: str, si
requires_safety_checker=safety_checker is not None,
)
logger.info('exporting ONNX model')
logger.info("exporting ONNX model")
onnx_pipeline.save_pretrained(output_path)
logger.info("ONNX pipeline saved to %s", output_path)
@@ -445,90 +485,93 @@ def convert_diffuser(name: str, url: str, opset: int, half: bool, token: str, si
)
else:
_ = OnnxStableDiffusionPipeline.from_pretrained(
output_path, provider="CPUExecutionProvider")
output_path, provider="CPUExecutionProvider"
)
logger.info("ONNX pipeline is loadable")
def load_models(args, models: Models):
if args.diffusion:
for source in models.get('diffusion'):
for source in models.get("diffusion"):
if source[0] in args.skip:
logger.info('Skipping model: %s', source[0])
logger.info("Skipping model: %s", source[0])
else:
single_vae = 'upscaling' in source[0]
convert_diffuser(*source, args.opset, args.half, args.token, single_vae=single_vae)
single_vae = "upscaling" in source[0]
convert_diffuser(
*source, args.opset, args.half, args.token, single_vae=single_vae
)
if args.upscaling:
for source in models.get('upscaling'):
for source in models.get("upscaling"):
if source[0] in args.skip:
logger.info('Skipping model: %s', source[0])
logger.info("Skipping model: %s", source[0])
else:
convert_real_esrgan(*source, args.opset)
if args.correction:
for source in models.get('correction'):
for source in models.get("correction"):
if source[0] in args.skip:
logger.info('Skipping model: %s', source[0])
logger.info("Skipping model: %s", source[0])
else:
convert_gfpgan(*source, args.opset)
def main() -> int:
parser = ArgumentParser(
prog='onnx-web model converter',
description='convert checkpoint models to ONNX')
prog="onnx-web model converter", description="convert checkpoint models to ONNX"
)
# model groups
parser.add_argument('--correction', action='store_true', default=False)
parser.add_argument('--diffusion', action='store_true', default=False)
parser.add_argument('--upscaling', action='store_true', default=False)
parser.add_argument("--correction", action="store_true", default=False)
parser.add_argument("--diffusion", action="store_true", default=False)
parser.add_argument("--upscaling", action="store_true", default=False)
# extra models
parser.add_argument('--extras', nargs='*', type=str, default=[])
parser.add_argument('--skip', nargs='*', type=str, default=[])
parser.add_argument("--extras", nargs="*", type=str, default=[])
parser.add_argument("--skip", nargs="*", type=str, default=[])
# export options
parser.add_argument(
'--half',
action='store_true',
"--half",
action="store_true",
default=False,
help='Export models for half precision, faster on some Nvidia cards.'
help="Export models for half precision, faster on some Nvidia cards.",
)
parser.add_argument(
'--opset',
"--opset",
default=14,
type=int,
help="The version of the ONNX operator set to use.",
)
parser.add_argument(
'--token',
"--token",
type=str,
help="HuggingFace token with read permissions for downloading models.",
)
args = parser.parse_args()
logger.info('CLI arguments: %s', args)
logger.info("CLI arguments: %s", args)
if not path.exists(model_path):
logger.info('Model path does not existing, creating: %s', model_path)
logger.info("Model path does not existing, creating: %s", model_path)
makedirs(model_path)
logger.info('Converting base models.')
logger.info("Converting base models.")
load_models(args, base_models)
for file in args.extras:
logger.info('Loading extra models from %s', file)
logger.info("Loading extra models from %s", file)
try:
with open(file, 'r') as f:
with open(file, "r") as f:
data = loads(f.read())
logger.info('Converting extra models.')
logger.info("Converting extra models.")
load_models(args, data)
except Exception as err:
logger.error('Error converting extra models: %s', err)
logger.error("Error converting extra models: %s", err)
return 0
if __name__ == '__main__':
if __name__ == "__main__":
exit(main())
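
The --extras flag above loads additional model definitions from a JSON file. The file's exact schema is not shown in this commit; a plausible sketch, assuming it mirrors the base_models groups with (name, source) pairs, is:

# Hypothetical extras file, shown as the dict that loads() would return.
# Group keys match the --diffusion/--upscaling/--correction flags; each
# entry is assumed to be a (name, source) pair for the convert_* helpers.
extra_models = {
    "diffusion": [
        ["diffusion-custom", "runwayml/stable-diffusion-v1-5"],
    ],
    "upscaling": [],
    "correction": [],
}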

View File

@ -1,12 +1,10 @@
from collections import Counter
from concurrent.futures import Future, ThreadPoolExecutor, ProcessPoolExecutor
from concurrent.futures import Future, ProcessPoolExecutor, ThreadPoolExecutor
from logging import getLogger
from multiprocessing import Value
from typing import Any, Callable, List, Optional, Tuple, Union
from .params import (
DeviceParams,
)
from .params import DeviceParams
logger = getLogger(__name__)
@ -28,24 +26,24 @@ class JobContext:
):
self.key = key
self.devices = list(devices)
self.cancel = Value('B', cancel)
self.device_index = Value('i', device_index)
self.progress = Value('I', progress)
self.cancel = Value("B", cancel)
self.device_index = Value("i", device_index)
self.progress = Value("I", progress)
def is_cancelled(self) -> bool:
return self.cancel.value
def get_device(self) -> DeviceParams:
'''
"""
Get the device assigned to this job.
'''
"""
with self.device_index.get_lock():
device_index = self.device_index.value
if device_index < 0:
raise Exception('job has not been assigned to a device')
raise Exception("job has not been assigned to a device")
else:
device = self.devices[device_index]
logger.debug('job %s assigned to device %s', self.key, device)
logger.debug("job %s assigned to device %s", self.key, device)
return device
def get_progress(self) -> int:
@ -54,10 +52,9 @@ class JobContext:
def get_progress_callback(self) -> Callable[..., None]:
def on_progress(step: int, timestep: int, latents: Any):
if self.is_cancelled():
raise Exception('job has been cancelled')
raise Exception("job has been cancelled")
else:
logger.debug('setting progress for job %s to %s',
self.key, step)
logger.debug("setting progress for job %s to %s", self.key, step)
self.set_progress(step)
return on_progress
@ -72,9 +69,9 @@ class JobContext:
class Job:
'''
"""
Link a future to its context.
'''
"""
context: JobContext = None
future: Future = None
@ -106,7 +103,11 @@ class DevicePoolExecutor:
next_device: int = 0
pool: Union[ProcessPoolExecutor, ThreadPoolExecutor] = None
def __init__(self, devices: List[DeviceParams], pool: Optional[Union[ProcessPoolExecutor, ThreadPoolExecutor]] = None):
def __init__(
self,
devices: List[DeviceParams],
pool: Optional[Union[ProcessPoolExecutor, ThreadPoolExecutor]] = None,
):
self.devices = devices
self.jobs = []
self.next_device = 0
@ -114,19 +115,25 @@ class DevicePoolExecutor:
device_count = len(devices)
if pool is None:
logger.info(
'creating thread pool executor for %s devices: %s', device_count, [d.device for d in devices])
"creating thread pool executor for %s devices: %s",
device_count,
[d.device for d in devices],
)
self.pool = ThreadPoolExecutor(device_count)
else:
logger.info('using existing pool for %s devices: %s',
device_count, [d.device for d in devices])
logger.info(
"using existing pool for %s devices: %s",
device_count,
[d.device for d in devices],
)
self.pool = pool
def cancel(self, key: str) -> bool:
'''
"""
Cancel a job. If the job has not been started, this will cancel
the future and never execute it. If the job has been started, it
should be cancelled on the next progress callback.
'''
"""
for job in self.jobs:
if job.key == key:
if job.future.cancel():
@ -144,7 +151,7 @@ class DevicePoolExecutor:
progress = job.get_progress()
return (done, progress)
logger.warn('checking status for unknown key: %s', key)
logger.warn("checking status for unknown key: %s", key)
return (None, 0)
def get_next_device(self):
@ -152,12 +159,14 @@ class DevicePoolExecutor:
if len(self.jobs) == 0:
return 0
job_devices = [job.context.device_index.value for job in self.jobs if not job.future.done()]
job_devices = [
job.context.device_index.value for job in self.jobs if not job.future.done()
]
job_counts = Counter(range(len(self.devices)))
job_counts.update(job_devices)
queued = job_counts.most_common()
logger.debug('jobs queued by device: %s', queued)
logger.debug("jobs queued by device: %s", queued)
lowest_count = queued[-1][1]
lowest_devices = [d[0] for d in queued if d[1] == lowest_count]
@ -170,7 +179,7 @@ class DevicePoolExecutor:
def submit(self, key: str, fn: Callable[..., None], /, *args, **kwargs) -> None:
device = self.get_next_device()
logger.info('assigning job %s to device %s', key, device)
logger.info("assigning job %s to device %s", key, device)
context = JobContext(key, self.devices, device_index=device)
future = self.pool.submit(fn, context, *args, **kwargs)
@ -180,11 +189,19 @@ class DevicePoolExecutor:
def job_done(f: Future):
try:
f.result()
logger.info('job %s finished successfully', key)
logger.info("job %s finished successfully", key)
except Exception as err:
logger.warn('job %s failed with an error: %s', key, err)
logger.warn("job %s failed with an error: %s", key, err)
future.add_done_callback(job_done)
def status(self) -> List[Tuple[str, int, bool, int]]:
return [(job.key, job.context.device_index.value, job.future.done(), job.get_progress()) for job in self.jobs]
return [
(
job.key,
job.context.device_index.value,
job.future.done(),
job.get_progress(),
)
for job in self.jobs
]
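
A standalone sketch of the least-loaded selection implemented in get_next_device above; seeding the Counter with every device index keeps idle devices in the tally, and the job list here is illustrative:

from collections import Counter

def next_device(active_jobs, device_count):
    # every device starts with a count of one, so idle devices still appear
    counts = Counter(range(device_count))
    counts.update(active_jobs)     # one increment per queued job
    queued = counts.most_common()  # sorted most- to least-loaded
    lowest = queued[-1][1]
    return [d for d, n in queued if n == lowest][0]

print(next_device([0, 0, 1], 3))  # 2: the only device with no jobs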

View File

@ -1,18 +1,11 @@
from diffusers import (
DiffusionPipeline,
)
from logging import getLogger
from typing import Any, Optional, Tuple
from ..params import (
DeviceParams,
Size,
)
from ..utils import (
run_gc,
)
from typing import Any, Tuple
import numpy as np
from diffusers import DiffusionPipeline
from ..params import DeviceParams, Size
from ..utils import run_gc
logger = getLogger(__name__)
@ -25,17 +18,23 @@ latent_factor = 8
def get_latents_from_seed(seed: int, size: Size, batch: int = 1) -> np.ndarray:
'''
"""
From https://www.travelneil.com/stable-diffusion-updates.html
'''
latents_shape = (batch, latent_channels, size.height // latent_factor,
size.width // latent_factor)
"""
latents_shape = (
batch,
latent_channels,
size.height // latent_factor,
size.width // latent_factor,
)
rng = np.random.default_rng(seed)
image_latents = rng.standard_normal(latents_shape).astype(np.float32)
return image_latents
def get_tile_latents(full_latents: np.ndarray, dims: Tuple[int, int, int]) -> np.ndarray:
def get_tile_latents(
full_latents: np.ndarray, dims: Tuple[int, int, int]
) -> np.ndarray:
x, y, tile = dims
t = tile // latent_factor
x = x // latent_factor
@ -46,27 +45,29 @@ def get_tile_latents(full_latents: np.ndarray, dims: Tuple[int, int, int]) -> np
return full_latents[:, :, y:yt, x:xt]
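
A worked example of the latent sizing above: with the factor of 8 used here, a 512x512 image maps to a (1, 4, 64, 64) tensor, and a 256px tile at pixel offset (64, 64) becomes a 32x32 latent window at (8, 8):

import numpy as np

rng = np.random.default_rng(42)
latents = rng.standard_normal((1, 4, 512 // 8, 512 // 8)).astype(np.float32)
assert latents.shape == (1, 4, 64, 64)

# tile slicing as in get_tile_latents: every pixel dimension divides by 8
x, y, tile = 64 // 8, 64 // 8, 256 // 8
window = latents[:, :, y : y + tile, x : x + tile]
assert window.shape == (1, 4, 32, 32)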
def load_pipeline(pipeline: DiffusionPipeline, model: str, scheduler: Any, device: DeviceParams):
def load_pipeline(
pipeline: DiffusionPipeline, model: str, scheduler: Any, device: DeviceParams
):
global last_pipeline_instance
global last_pipeline_scheduler
global last_pipeline_options
options = (pipeline, model, device.provider)
if last_pipeline_instance != None and last_pipeline_options == options:
logger.debug('reusing existing diffusion pipeline')
if last_pipeline_instance is not None and last_pipeline_options == options:
logger.debug("reusing existing diffusion pipeline")
pipe = last_pipeline_instance
else:
logger.debug('unloading previous diffusion pipeline')
logger.debug("unloading previous diffusion pipeline")
last_pipeline_instance = None
last_pipeline_scheduler = None
run_gc()
logger.debug('loading new diffusion pipeline from %s', model)
logger.debug("loading new diffusion pipeline from %s", model)
scheduler = scheduler.from_pretrained(
model,
provider=device.provider,
provider_options=device.options,
subfolder='scheduler',
subfolder="scheduler",
)
pipe = pipeline.from_pretrained(
model,
@ -76,7 +77,7 @@ def load_pipeline(pipeline: DiffusionPipeline, model: str, scheduler: Any, devic
scheduler=scheduler,
)
if device is not None and hasattr(pipe, 'to'):
if device is not None and hasattr(pipe, "to"):
pipe = pipe.to(device)
last_pipeline_instance = pipe
@ -84,15 +85,15 @@ def load_pipeline(pipeline: DiffusionPipeline, model: str, scheduler: Any, devic
last_pipeline_scheduler = scheduler
if last_pipeline_scheduler != scheduler:
logger.debug('loading new diffusion scheduler')
logger.debug("loading new diffusion scheduler")
scheduler = scheduler.from_pretrained(
model,
provider=device.provider,
provider_options=device.options,
subfolder='scheduler',
subfolder="scheduler",
)
if device is not None and hasattr(scheduler, 'to'):
if device is not None and hasattr(scheduler, "to"):
scheduler = scheduler.to(device)
pipe.scheduler = scheduler
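
The reuse branch above follows a simple module-level memoization pattern: the pipeline is keyed on a tuple of everything that would force a reload, and the old instance is dropped before any new load so only one pipeline holds memory at a time. A minimal sketch with illustrative names:

_cached = None
_cached_options = None

def load_cached(pipeline_cls, model, provider):
    global _cached, _cached_options
    options = (pipeline_cls, model, provider)
    if _cached is not None and _cached_options == options:
        return _cached
    _cached = None  # release the previous pipeline before loading the next
    _cached = pipeline_cls.from_pretrained(model, provider=provider)
    _cached_options = options
    return _cached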

View File

@ -1,28 +1,16 @@
from diffusers import (
DDPMScheduler,
OnnxRuntimeModel,
StableDiffusionUpscalePipeline,
)
from diffusers.pipeline_utils import (
ImagePipelineOutput,
)
from logging import getLogger
from PIL import Image
from typing import (
Any,
Callable,
List,
Optional,
Union,
)
from typing import Any, Callable, List, Optional, Union
import numpy as np
import torch
from diffusers import DDPMScheduler, OnnxRuntimeModel, StableDiffusionUpscalePipeline
from diffusers.pipeline_utils import ImagePipelineOutput
from PIL import Image
logger = getLogger(__name__)
num_channels_latents = 4 # self.vae.config.latent_channels
unet_in_channels = 7 # self.unet.config.in_channels
num_channels_latents = 4 # self.vae.config.latent_channels
unet_in_channels = 7 # self.unet.config.in_channels
###
# This is based on a combination of the ONNX img2img pipeline and the PyTorch upscale pipeline:
@ -30,6 +18,7 @@ unet_in_channels = 7 # self.unet.config.in_channels
# https://github.com/huggingface/diffusers/blob/v0.11.1/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py
###
def preprocess(image):
if isinstance(image, torch.Tensor):
return image
@ -63,8 +52,15 @@ class OnnxStableDiffusionUpscalePipeline(StableDiffusionUpscalePipeline):
scheduler: Any,
max_noise_level: int = 350,
):
super().__init__(vae, text_encoder, tokenizer, unet,
low_res_scheduler, scheduler, max_noise_level)
super().__init__(
vae,
text_encoder,
tokenizer,
unet,
low_res_scheduler,
scheduler,
max_noise_level,
)
def __call__(
self,
@ -96,7 +92,11 @@ class OnnxStableDiffusionUpscalePipeline(StableDiffusionUpscalePipeline):
# 3. Encode input prompt
text_embeddings = self._encode_prompt(
prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt
prompt,
device,
num_images_per_prompt,
do_classifier_free_guidance,
negative_prompt,
)
# 4. Preprocess image
@ -111,7 +111,9 @@ class OnnxStableDiffusionUpscalePipeline(StableDiffusionUpscalePipeline):
text_embeddings_dtype = torch.float32
noise_level = torch.tensor([noise_level], dtype=torch.long, device=device)
noise = torch.randn(image.shape, generator=generator, device=device, dtype=text_embeddings_dtype)
noise = torch.randn(
image.shape, generator=generator, device=device, dtype=text_embeddings_dtype
)
image = self.low_res_scheduler.add_noise(image, noise, noise_level)
batch_multiplier = 2 if do_classifier_free_guidance else 1
@ -135,11 +137,11 @@ class OnnxStableDiffusionUpscalePipeline(StableDiffusionUpscalePipeline):
num_channels_image = image.shape[1]
if num_channels_latents + num_channels_image != unet_in_channels:
raise ValueError(
f"Incorrect configuration settings! The config of `pipeline.unet` expects"
f" {unet_in_channels} but received `num_channels_latents`: {num_channels_latents} +"
f" `num_channels_image`: {num_channels_image} "
f" = {num_channels_latents+num_channels_image}. Please verify the config of"
" `pipeline.unet` or your `image` input."
"Incorrect configuration settings! The config of `pipeline.unet`"
f" expects {unet_in_channels} but received `num_channels_latents`:"
f" {num_channels_latents} + `num_channels_image`: {num_channels_image} "
f" = {num_channels_latents+num_channels_image}. Please verify the"
" config of `pipeline.unet` or your `image` input."
)
# 8. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
@ -150,10 +152,16 @@ class OnnxStableDiffusionUpscalePipeline(StableDiffusionUpscalePipeline):
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
# expand the latents if we are doing classifier free guidance
latent_model_input = np.concatenate([latents] * 2) if do_classifier_free_guidance else latents
latent_model_input = (
np.concatenate([latents] * 2)
if do_classifier_free_guidance
else latents
)
# concat latents, mask, masked_image_latents in the channel dimension
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
latent_model_input = self.scheduler.scale_model_input(
latent_model_input, t
)
latent_model_input = np.concatenate([latent_model_input, image], axis=1)
# timestep to tensor
@ -164,19 +172,25 @@ class OnnxStableDiffusionUpscalePipeline(StableDiffusionUpscalePipeline):
sample=latent_model_input,
timestep=timestep,
encoder_hidden_states=text_embeddings,
class_labels=noise_level
class_labels=noise_level,
)[0]
# perform guidance
if do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = np.split(noise_pred, 2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
noise_pred = noise_pred_uncond + guidance_scale * (
noise_pred_text - noise_pred_uncond
)
# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
latents = self.scheduler.step(
noise_pred, t, latents, **extra_step_kwargs
).prev_sample
# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
if i == len(timesteps) - 1 or (
(i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
):
progress_bar.update()
if callback is not None and i % callback_steps == 0:
callback(i, t, latents)
@ -200,7 +214,14 @@ class OnnxStableDiffusionUpscalePipeline(StableDiffusionUpscalePipeline):
image = image.transpose((0, 2, 3, 1))
return image
def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt):
def _encode_prompt(
self,
prompt,
device,
num_images_per_prompt,
do_classifier_free_guidance,
negative_prompt,
):
batch_size = len(prompt) if isinstance(prompt, list) else 1
text_inputs = self.tokenizer(
@ -211,13 +232,20 @@ class OnnxStableDiffusionUpscalePipeline(StableDiffusionUpscalePipeline):
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
untruncated_ids = self.tokenizer(
prompt, padding="longest", return_tensors="pt"
).input_ids
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
text_input_ids, untruncated_ids
):
removed_text = self.tokenizer.batch_decode(
untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
)
logger.warning(
"The following part of your input was truncated because CLIP can only handle sequences up to"
f" {self.tokenizer.model_max_length} tokens: {removed_text}"
"The following part of your input was truncated because CLIP can only"
f" handle sequences up to {self.tokenizer.model_max_length} tokens:"
f" {removed_text}"
)
# no positional arguments to text_encoder
@ -240,16 +268,17 @@ class OnnxStableDiffusionUpscalePipeline(StableDiffusionUpscalePipeline):
uncond_tokens = [""] * batch_size
elif type(prompt) is not type(negative_prompt):
raise TypeError(
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
f" {type(prompt)}."
"`negative_prompt` should be the same type to `prompt`, but got"
f" {type(negative_prompt)} != {type(prompt)}."
)
elif isinstance(negative_prompt, str):
uncond_tokens = [negative_prompt]
elif batch_size != len(negative_prompt):
raise ValueError(
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
" the batch size of `prompt`."
f"`negative_prompt`: {negative_prompt} has batch size"
f" {len(negative_prompt)}, but `prompt`: {prompt} has batch size"
f" {batch_size}. Please make sure that passed `negative_prompt`"
" matches the batch size of `prompt`."
)
else:
uncond_tokens = negative_prompt
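
The guidance arithmetic used in the denoising loop above can be checked in isolation; a minimal sketch of the classifier-free guidance step, where the UNet ran on a doubled batch and the final prediction steers from the unconditional half toward the text-conditioned half:

import numpy as np

def guide(noise_pred, guidance_scale):
    uncond, text = np.split(noise_pred, 2)
    return uncond + guidance_scale * (text - uncond)

batch = np.concatenate([np.zeros((1, 4)), np.ones((1, 4))])
assert np.allclose(guide(batch, 7.5), 7.5)  # 0 + 7.5 * (1 - 0)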

View File

@ -1,41 +1,17 @@
from diffusers import (
OnnxStableDiffusionPipeline,
OnnxStableDiffusionImg2ImgPipeline,
)
from logging import getLogger
from PIL import Image, ImageChops
from typing import Any
from ..chain import (
upscale_outpaint,
)
from ..device_pool import (
JobContext,
)
from ..params import (
ImageParams,
Border,
Size,
StageParams,
)
from ..output import (
save_image,
save_params,
)
from ..upscale import (
run_upscale_correction,
UpscaleParams,
)
from ..utils import (
run_gc,
ServerContext,
)
from .load import (
get_latents_from_seed,
load_pipeline,
)
import numpy as np
from diffusers import OnnxStableDiffusionImg2ImgPipeline, OnnxStableDiffusionPipeline
from PIL import Image, ImageChops
from ..chain import upscale_outpaint
from ..device_pool import JobContext
from ..output import save_image, save_params
from ..params import Border, ImageParams, Size, StageParams
from ..upscale import UpscaleParams, run_upscale_correction
from ..utils import ServerContext, run_gc
from .load import get_latents_from_seed, load_pipeline
logger = getLogger(__name__)
@ -46,10 +22,11 @@ def run_txt2img_pipeline(
params: ImageParams,
size: Size,
output: str,
upscale: UpscaleParams
upscale: UpscaleParams,
) -> None:
pipe = load_pipeline(OnnxStableDiffusionPipeline,
params.model, params.scheduler, job.get_device())
pipe = load_pipeline(
OnnxStableDiffusionPipeline, params.model, params.scheduler, job.get_device()
)
latents = get_latents_from_seed(params.seed, size)
rng = np.random.RandomState(params.seed)
@ -68,7 +45,8 @@ def run_txt2img_pipeline(
)
image = result.images[0]
image = run_upscale_correction(
job, server, StageParams(), params, image, upscale=upscale)
job, server, StageParams(), params, image, upscale=upscale
)
dest = save_image(server, output, image)
save_params(server, output, params, size, upscale=upscale)
@ -77,7 +55,7 @@ def run_txt2img_pipeline(
del result
run_gc()
logger.info('finished txt2img job: %s', dest)
logger.info("finished txt2img job: %s", dest)
def run_img2img_pipeline(
@ -89,8 +67,12 @@ def run_img2img_pipeline(
source_image: Image.Image,
strength: float,
) -> None:
pipe = load_pipeline(OnnxStableDiffusionImg2ImgPipeline,
params.model, params.scheduler, job.get_device())
pipe = load_pipeline(
OnnxStableDiffusionImg2ImgPipeline,
params.model,
params.scheduler,
job.get_device(),
)
rng = np.random.RandomState(params.seed)
@ -107,7 +89,8 @@ def run_img2img_pipeline(
)
image = result.images[0]
image = run_upscale_correction(
job, server, StageParams(), params, image, upscale=upscale)
job, server, StageParams(), params, image, upscale=upscale
)
dest = save_image(server, output, image)
size = Size(*source_image.size)
@ -117,7 +100,7 @@ def run_img2img_pipeline(
del result
run_gc()
logger.info('finished img2img job: %s', dest)
logger.info("finished img2img job: %s", dest)
def run_inpaint_pipeline(
@ -151,16 +134,14 @@ def run_inpaint_pipeline(
mask_filter=mask_filter,
noise_source=noise_source,
)
logger.info('applying mask filter and generating noise source')
logger.info("applying mask filter and generating noise source")
if image.size == source_image.size:
image = ImageChops.blend(source_image, image, strength)
else:
logger.info(
'output image size does not match source, skipping post-blend')
logger.info("output image size does not match source, skipping post-blend")
image = run_upscale_correction(
job, server, stage, params, image, upscale=upscale)
image = run_upscale_correction(job, server, stage, params, image, upscale=upscale)
dest = save_image(server, output, image)
save_params(server, output, params, size, upscale=upscale, border=border)
@ -168,7 +149,7 @@ def run_inpaint_pipeline(
del image
run_gc()
logger.info('finished inpaint job: %s', dest)
logger.info("finished inpaint job: %s", dest)
def run_upscale_pipeline(
@ -185,7 +166,8 @@ def run_upscale_pipeline(
stage = StageParams()
image = run_upscale_correction(
job, server, stage, params, source_image, upscale=upscale)
job, server, stage, params, source_image, upscale=upscale
)
dest = save_image(server, output, image)
save_params(server, output, params, size, upscale=upscale)
@ -193,4 +175,4 @@ def run_upscale_pipeline(
del image
run_gc()
logger.info('finished upscale job: %s', dest)
logger.info("finished upscale job: %s", dest)

View File

@ -1,31 +1,31 @@
import numpy as np
from numpy import random
from PIL import Image, ImageChops, ImageFilter
import numpy as np
from .params import (
Border,
Point,
)
from .params import Border, Point
def get_pixel_index(x: int, y: int, width: int) -> int:
return (y * width) + x
def mask_filter_none(mask_image: Image.Image, dims: Point, origin: Point, fill='white', **kw) -> Image.Image:
def mask_filter_none(
mask_image: Image.Image, dims: Point, origin: Point, fill="white", **kw
) -> Image.Image:
width, height = dims
noise = Image.new('RGB', (width, height), fill)
noise = Image.new("RGB", (width, height), fill)
noise.paste(mask_image, origin)
return noise
def mask_filter_gaussian_multiply(mask_image: Image.Image, dims: Point, origin: Point, rounds=3, **kw) -> Image.Image:
'''
def mask_filter_gaussian_multiply(
mask_image: Image.Image, dims: Point, origin: Point, rounds=3, **kw
) -> Image.Image:
"""
Gaussian blur with multiply, source image centered on white canvas.
'''
"""
noise = mask_filter_none(mask_image, dims, origin)
for i in range(rounds):
@ -35,10 +35,12 @@ def mask_filter_gaussian_multiply(mask_image: Image.Image, dims: Point, origin:
return noise
def mask_filter_gaussian_screen(mask_image: Image.Image, dims: Point, origin: Point, rounds=3, **kw) -> Image.Image:
'''
def mask_filter_gaussian_screen(
mask_image: Image.Image, dims: Point, origin: Point, rounds=3, **kw
) -> Image.Image:
"""
Gaussian blur, source image centered on white canvas.
'''
"""
noise = mask_filter_none(mask_image, dims, origin)
for i in range(rounds):
@ -48,33 +50,39 @@ def mask_filter_gaussian_screen(mask_image: Image.Image, dims: Point, origin: Po
return noise
def noise_source_fill_edge(source_image: Image.Image, dims: Point, origin: Point, fill='white', **kw) -> Image.Image:
'''
def noise_source_fill_edge(
source_image: Image.Image, dims: Point, origin: Point, fill="white", **kw
) -> Image.Image:
"""
Identity transform, source image centered on white canvas.
'''
"""
width, height = dims
noise = Image.new('RGB', (width, height), fill)
noise = Image.new("RGB", (width, height), fill)
noise.paste(source_image, origin)
return noise
def noise_source_fill_mask(source_image: Image.Image, dims: Point, origin: Point, fill='white', **kw) -> Image.Image:
'''
def noise_source_fill_mask(
source_image: Image.Image, dims: Point, origin: Point, fill="white", **kw
) -> Image.Image:
"""
Fill the whole canvas, no source or noise.
'''
"""
width, height = dims
noise = Image.new('RGB', (width, height), fill)
noise = Image.new("RGB", (width, height), fill)
return noise
def noise_source_gaussian(source_image: Image.Image, dims: Point, origin: Point, rounds=3, **kw) -> Image.Image:
'''
def noise_source_gaussian(
source_image: Image.Image, dims: Point, origin: Point, rounds=3, **kw
) -> Image.Image:
"""
Gaussian blur, source image centered on white canvas.
'''
"""
noise = noise_source_uniform(source_image, dims, origin)
noise.paste(source_image, origin)
@ -84,7 +92,9 @@ def noise_source_gaussian(source_image: Image.Image, dims: Point, origin: Point,
return noise
def noise_source_uniform(source_image: Image.Image, dims: Point, origin: Point, **kw) -> Image.Image:
def noise_source_uniform(
source_image: Image.Image, dims: Point, origin: Point, **kw
) -> Image.Image:
width, height = dims
size = width * height
@ -92,21 +102,19 @@ def noise_source_uniform(source_image: Image.Image, dims: Point, origin: Point,
noise_g = random.uniform(0, 256, size=size)
noise_b = random.uniform(0, 256, size=size)
noise = Image.new('RGB', (width, height))
noise = Image.new("RGB", (width, height))
for x in range(width):
for y in range(height):
i = get_pixel_index(x, y, width)
noise.putpixel((x, y), (
int(noise_r[i]),
int(noise_g[i]),
int(noise_b[i])
))
noise.putpixel((x, y), (int(noise_r[i]), int(noise_g[i]), int(noise_b[i])))
return noise
def noise_source_normal(source_image: Image.Image, dims: Point, origin: Point, **kw) -> Image.Image:
def noise_source_normal(
source_image: Image.Image, dims: Point, origin: Point, **kw
) -> Image.Image:
width, height = dims
size = width * height
@ -114,21 +122,19 @@ def noise_source_normal(source_image: Image.Image, dims: Point, origin: Point, *
noise_g = random.normal(128, 32, size=size)
noise_b = random.normal(128, 32, size=size)
noise = Image.new('RGB', (width, height))
noise = Image.new("RGB", (width, height))
for x in range(width):
for y in range(height):
i = get_pixel_index(x, y, width)
noise.putpixel((x, y), (
int(noise_r[i]),
int(noise_g[i]),
int(noise_b[i])
))
noise.putpixel((x, y), (int(noise_r[i]), int(noise_g[i]), int(noise_b[i])))
return noise
def noise_source_histogram(source_image: Image.Image, dims: Point, origin: Point, **kw) -> Image.Image:
def noise_source_histogram(
source_image: Image.Image, dims: Point, origin: Point, **kw
) -> Image.Image:
r, g, b = source_image.split()
width, height = dims
size = width * height
@ -137,35 +143,34 @@ def noise_source_histogram(source_image: Image.Image, dims: Point, origin: Point
hist_g = g.histogram()
hist_b = b.histogram()
noise_r = random.choice(256, p=np.divide(
np.copy(hist_r), np.sum(hist_r)), size=size)
noise_g = random.choice(256, p=np.divide(
np.copy(hist_g), np.sum(hist_g)), size=size)
noise_b = random.choice(256, p=np.divide(
np.copy(hist_b), np.sum(hist_b)), size=size)
noise_r = random.choice(
256, p=np.divide(np.copy(hist_r), np.sum(hist_r)), size=size
)
noise_g = random.choice(
256, p=np.divide(np.copy(hist_g), np.sum(hist_g)), size=size
)
noise_b = random.choice(
256, p=np.divide(np.copy(hist_b), np.sum(hist_b)), size=size
)
noise = Image.new('RGB', (width, height))
noise = Image.new("RGB", (width, height))
for x in range(width):
for y in range(height):
i = get_pixel_index(x, y, width)
noise.putpixel((x, y), (
noise_r[i],
noise_g[i],
noise_b[i]
))
noise.putpixel((x, y), (noise_r[i], noise_g[i], noise_b[i]))
return noise
# very loosely based on https://github.com/AUTOMATIC1111/stable-diffusion-webui/blob/master/scripts/outpainting_mk_2.py#L175-L232
def expand_image(
source_image: Image.Image,
mask_image: Image.Image,
expand: Border,
fill='white',
noise_source=noise_source_histogram,
mask_filter=mask_filter_none,
source_image: Image.Image,
mask_image: Image.Image,
expand: Border,
fill="white",
noise_source=noise_source_histogram,
mask_filter=mask_filter_none,
):
full_width = expand.left + source_image.width + expand.right
full_height = expand.top + source_image.height + expand.bottom
@ -173,14 +178,13 @@ def expand_image(
dims = (full_width, full_height)
origin = (expand.left, expand.top)
full_source = Image.new('RGB', dims, fill)
full_source = Image.new("RGB", dims, fill)
full_source.paste(source_image, origin)
full_mask = mask_filter(mask_image, dims, origin, fill=fill)
full_noise = noise_source(source_image, dims, origin, fill=fill)
full_noise = ImageChops.multiply(full_noise, full_mask)
full_source = Image.composite(
full_noise, full_source, full_mask.convert('L'))
full_source = Image.composite(full_noise, full_source, full_mask.convert("L"))
return (full_source, full_mask, full_noise, (full_width, full_height))
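
A hypothetical call mirroring the function above: grow a 512x512 source by 128px on every side, keeping the source where the mask is black and filling the margin with histogram-matched noise. Border.even is assumed here from its use in the chain endpoint, and the sketch reuses this module's own imports:

from PIL import Image

source = Image.new("RGB", (512, 512), "gray")
mask = Image.new("RGB", (512, 512), "black")  # black = keep source pixels
expanded, full_mask, noise, dims = expand_image(
    source, mask, Border.even(128), noise_source=noise_source_histogram
)
assert dims == (768, 768)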

View File

@ -1,14 +1,15 @@
from logging.config import dictConfig
from os import environ, path
from yaml import safe_load
logging_path = environ.get('ONNX_WEB_LOGGING_PATH', './logging.yaml')
logging_path = environ.get("ONNX_WEB_LOGGING_PATH", "./logging.yaml")
# setup logging config before anything else loads
try:
if path.exists(logging_path):
with open(logging_path, 'r') as f:
config_logging = safe_load(f)
dictConfig(config_logging)
if path.exists(logging_path):
with open(logging_path, "r") as f:
config_logging = safe_load(f)
dictConfig(config_logging)
except Exception as err:
print('error loading logging config: %s' % (err))
print("error loading logging config: %s" % (err))

View File

@ -1,4 +1 @@
from .onnx_net import (
OnnxImage,
OnnxNet,
)
from .onnx_net import OnnxImage, OnnxNet

View File

@ -1,15 +1,14 @@
from onnxruntime import InferenceSession
from os import path
from typing import Any, Optional
import numpy as np
import torch
from onnxruntime import InferenceSession
from ..utils import (
ServerContext,
)
from ..utils import ServerContext
class OnnxImage():
class OnnxImage:
def __init__(self, source) -> None:
self.source = source
self.data = self
@ -38,28 +37,27 @@ class OnnxImage():
return np.shape(self.source)
class OnnxNet():
'''
class OnnxNet:
"""
Provides the RRDBNet interface using an ONNX session for DirectML acceleration.
'''
"""
def __init__(
self,
server: ServerContext,
model: str,
provider: str = 'DmlExecutionProvider',
provider: str = "DmlExecutionProvider",
provider_options: Optional[dict] = None,
) -> None:
model_path = path.join(server.model_path, model)
self.session = InferenceSession(
model_path, providers=[provider], provider_options=provider_options)
model_path, providers=[provider], provider_options=provider_options
)
def __call__(self, image: Any) -> Any:
input_name = self.session.get_inputs()[0].name
output_name = self.session.get_outputs()[0].name
output = self.session.run([output_name], {
input_name: image.cpu().numpy()
})[0]
output = self.session.run([output_name], {input_name: image.cpu().numpy()})[0]
return OnnxImage(output)
def eval(self) -> None:

View File

@ -1,22 +1,14 @@
from hashlib import sha256
from json import dumps
from logging import getLogger
from PIL import Image
from struct import pack
from time import time
from typing import Any, Optional, Tuple
from .params import (
Border,
ImageParams,
Param,
Size,
UpscaleParams,
)
from .utils import (
base_join,
ServerContext,
)
from PIL import Image
from .params import Border, ImageParams, Param, Size, UpscaleParams
from .utils import ServerContext, base_join
logger = getLogger(__name__)
@ -25,13 +17,13 @@ def hash_value(sha, param: Param):
if param is None:
return
elif isinstance(param, float):
sha.update(bytearray(pack('!f', param)))
sha.update(bytearray(pack("!f", param)))
elif isinstance(param, int):
sha.update(bytearray(pack('!I', param)))
sha.update(bytearray(pack("!I", param)))
elif isinstance(param, str):
sha.update(param.encode('utf-8'))
sha.update(param.encode("utf-8"))
else:
logger.warn('cannot hash param: %s, %s', param, type(param))
logger.warn("cannot hash param: %s, %s", param, type(param))
def json_params(
@ -42,22 +34,22 @@ def json_params(
border: Optional[Border] = None,
) -> Any:
json = {
'output': output,
'params': params.tojson(),
"output": output,
"params": params.tojson(),
}
if upscale is not None and border is not None:
size = upscale.resize(size.add_border(border))
if upscale is not None:
json['upscale'] = upscale.tojson()
json["upscale"] = upscale.tojson()
size = upscale.resize(size)
if border is not None:
json['border'] = border.tojson()
json["border"] = border.tojson()
size = size.add_border(border)
json['size'] = size.tojson()
json["size"] = size.tojson()
return json
@ -67,7 +59,7 @@ def make_output_name(
mode: str,
params: ImageParams,
size: Size,
extras: Optional[Tuple[Param]] = None
extras: Optional[Tuple[Param]] = None,
) -> str:
now = int(time())
sha = sha256()
@ -87,13 +79,19 @@ def make_output_name(
for param in extras:
hash_value(sha, param)
return '%s_%s_%s_%s.%s' % (mode, params.seed, sha.hexdigest(), now, ctx.image_format)
return "%s_%s_%s_%s.%s" % (
mode,
params.seed,
sha.hexdigest(),
now,
ctx.image_format,
)
def save_image(ctx: ServerContext, output: str, image: Image.Image) -> str:
path = base_join(ctx.output_path, output)
image.save(path, format=ctx.image_format)
logger.debug('saved output image to: %s', path)
logger.debug("saved output image to: %s", path)
return path
@ -105,9 +103,9 @@ def save_params(
upscale: Optional[UpscaleParams] = None,
border: Optional[Border] = None,
) -> str:
path = base_join(ctx.output_path, '%s.json' % (output))
path = base_join(ctx.output_path, "%s.json" % (output))
json = json_params(output, params, size, upscale=upscale, border=border)
with open(path, 'w') as f:
with open(path, "w") as f:
f.write(dumps(json))
logger.debug('saved image params to: %s', path)
logger.debug("saved image params to: %s", path)
return path
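
A condensed sketch of the naming scheme in make_output_name above: hash every parameter that defines the image, so identical settings produce an identical digest, then append a timestamp to keep files unique. The values here are illustrative:

from hashlib import sha256
from struct import pack
from time import time

sha = sha256()
sha.update("txt2img".encode("utf-8"))   # mode, hashed as a string
sha.update(bytearray(pack("!f", 7.5)))  # cfg, hashed as a float
sha.update(bytearray(pack("!I", 25)))   # steps, hashed as an unsigned int
name = "%s_%s_%s_%s.%s" % ("txt2img", 42, sha.hexdigest(), int(time()), "png")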

View File

@ -3,9 +3,9 @@ from typing import Any, Dict, Literal, Optional, Tuple, Union
class SizeChart(IntEnum):
mini = 128 # small tile for very expensive models
half = 256 # half tile for outpainting
auto = 512 # auto tile size
mini = 128 # small tile for very expensive models
half = 256 # half tile for outpainting
auto = 512 # auto tile size
hd1k = 2**10
hd2k = 2**11
hd4k = 2**12
@ -26,14 +26,14 @@ class Border:
self.bottom = bottom
def __str__(self) -> str:
return '%s %s %s %s' % (self.left, self.top, self.right, self.bottom)
return "%s %s %s %s" % (self.left, self.top, self.right, self.bottom)
def tojson(self):
return {
'left': self.left,
'right': self.right,
'top': self.top,
'bottom': self.bottom,
"left": self.left,
"right": self.right,
"top": self.top,
"bottom": self.bottom,
}
@classmethod
@ -47,32 +47,37 @@ class Size:
self.height = height
def __str__(self) -> str:
return '%sx%s' % (self.width, self.height)
return "%sx%s" % (self.width, self.height)
def add_border(self, border: Border):
return Size(border.left + self.width + border.right, border.top + self.height + border.bottom)
return Size(
border.left + self.width + border.right,
border.top + self.height + border.bottom,
)
def tojson(self) -> Dict[str, int]:
return {
'height': self.height,
'width': self.width,
"height": self.height,
"width": self.width,
}
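
With the bottom term in add_border corrected above, the border arithmetic can be sanity-checked; Border.even is assumed to build an even border on all four sides, as its use in the chain endpoint suggests:

size = Size(512, 512)
print(size.add_border(Border.even(64)))  # 640x640: 64 + 512 + 64 each way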
class DeviceParams:
def __init__(self, device: str, provider: str, options: Optional[dict] = None) -> None:
def __init__(
self, device: str, provider: str, options: Optional[dict] = None
) -> None:
self.device = device
self.provider = provider
self.options = options
def __str__(self) -> str:
return '%s - %s (%s)' % (self.device, self.provider, self.options)
return "%s - %s (%s)" % (self.device, self.provider, self.options)
def torch_device(self) -> str:
if self.device.startswith('cuda'):
if self.device.startswith("cuda"):
return self.device
else:
return 'cpu'
return "cpu"
class ImageParams:
@ -84,7 +89,7 @@ class ImageParams:
negative_prompt: Optional[str],
cfg: float,
steps: int,
seed: int
seed: int,
) -> None:
self.model = model
self.scheduler = scheduler
@ -96,20 +101,20 @@ class ImageParams:
def tojson(self) -> Dict[str, Optional[Param]]:
return {
'model': self.model,
'scheduler': self.scheduler.__name__,
'seed': self.seed,
'prompt': self.prompt,
'cfg': self.cfg,
'negativePrompt': self.negative_prompt,
'steps': self.steps,
"model": self.model,
"scheduler": self.scheduler.__name__,
"seed": self.seed,
"prompt": self.prompt,
"cfg": self.cfg,
"negativePrompt": self.negative_prompt,
"steps": self.steps,
}
class StageParams:
'''
"""
Parameters for a chained pipeline stage
'''
"""
def __init__(
self,
@ -123,7 +128,7 @@ class StageParams:
self.outscale = outscale
class UpscaleParams():
class UpscaleParams:
def __init__(
self,
upscale_model: str,
@ -131,7 +136,7 @@ class UpscaleParams():
denoise: float = 0.5,
faces=True,
face_strength: float = 0.5,
format: Literal['onnx', 'pth'] = 'onnx',
format: Literal["onnx", "pth"] = "onnx",
half=False,
outscale: int = 1,
scale: int = 4,
@ -170,8 +175,8 @@ class UpscaleParams():
def tojson(self):
return {
'model': self.upscale_model,
'scale': self.scale,
'outscale': self.outscale,
"model": self.upscale_model,
"scale": self.scale,
"outscale": self.outscale,
# TODO: add more
}

View File

@ -1,62 +1,61 @@
from . import logging
import gc
from functools import cmp_to_key
from glob import glob
from io import BytesIO
from logging import getLogger
from os import makedirs, path
from typing import List, Tuple
import numpy as np
import torch
import yaml
from diffusers import (
DDIMScheduler,
DDPMScheduler,
DPMSolverMultistepScheduler,
DPMSolverSinglestepScheduler,
EulerDiscreteScheduler,
EulerAncestralDiscreteScheduler,
EulerDiscreteScheduler,
HeunDiscreteScheduler,
KarrasVeScheduler,
KDPM2AncestralDiscreteScheduler,
KDPM2DiscreteScheduler,
KarrasVeScheduler,
LMSDiscreteScheduler,
PNDMScheduler,
)
from flask import Flask, jsonify, make_response, request, send_from_directory, url_for
from flask_cors import CORS
from functools import cmp_to_key
from glob import glob
from io import BytesIO
from jsonschema import validate
from logging import getLogger
from PIL import Image
from onnxruntime import get_available_providers
from os import makedirs, path
from typing import List, Tuple
from PIL import Image
from . import logging
from .chain import (
ChainPipeline,
blend_img2img,
blend_inpaint,
correct_gfpgan,
persist_disk,
persist_s3,
reduce_thumbnail,
reduce_crop,
reduce_thumbnail,
source_noise,
source_txt2img,
upscale_outpaint,
upscale_resrgan,
upscale_stable_diffusion,
ChainPipeline,
)
from .device_pool import (
DeviceParams,
DevicePoolExecutor,
)
from .device_pool import DevicePoolExecutor
from .diffusion.run import (
run_img2img_pipeline,
run_inpaint_pipeline,
run_txt2img_pipeline,
run_upscale_pipeline,
)
from .image import (
# mask filters
from .image import ( # mask filters; noise sources
mask_filter_gaussian_multiply,
mask_filter_gaussian_screen,
mask_filter_none,
# noise sources
noise_source_fill_edge,
noise_source_fill_mask,
noise_source_gaussian,
@ -64,35 +63,20 @@ from .image import (
noise_source_normal,
noise_source_uniform,
)
from .output import (
json_params,
make_output_name,
)
from .params import (
Border,
DeviceParams,
ImageParams,
Size,
StageParams,
UpscaleParams,
)
from .output import json_params, make_output_name
from .params import Border, DeviceParams, ImageParams, Size, StageParams, UpscaleParams
from .utils import (
ServerContext,
base_join,
is_debug,
get_and_clamp_float,
get_and_clamp_int,
get_from_list,
get_from_map,
get_not_empty,
get_size,
ServerContext,
is_debug,
)
import gc
import numpy as np
import torch
import yaml
logger = getLogger(__name__)
# config caching
@ -100,53 +84,53 @@ config_params = {}
# pipeline params
platform_providers = {
'amd': 'DmlExecutionProvider',
'cpu': 'CPUExecutionProvider',
'cuda': 'CUDAExecutionProvider',
'directml': 'DmlExecutionProvider',
'nvidia': 'CUDAExecutionProvider',
'rocm': 'ROCMExecutionProvider',
"amd": "DmlExecutionProvider",
"cpu": "CPUExecutionProvider",
"cuda": "CUDAExecutionProvider",
"directml": "DmlExecutionProvider",
"nvidia": "CUDAExecutionProvider",
"rocm": "ROCMExecutionProvider",
}
pipeline_schedulers = {
'ddim': DDIMScheduler,
'ddpm': DDPMScheduler,
'dpm-multi': DPMSolverMultistepScheduler,
'dpm-single': DPMSolverSinglestepScheduler,
'euler': EulerDiscreteScheduler,
'euler-a': EulerAncestralDiscreteScheduler,
'heun': HeunDiscreteScheduler,
'k-dpm-2-a': KDPM2AncestralDiscreteScheduler,
'k-dpm-2': KDPM2DiscreteScheduler,
'karras-ve': KarrasVeScheduler,
'lms-discrete': LMSDiscreteScheduler,
'pndm': PNDMScheduler,
"ddim": DDIMScheduler,
"ddpm": DDPMScheduler,
"dpm-multi": DPMSolverMultistepScheduler,
"dpm-single": DPMSolverSinglestepScheduler,
"euler": EulerDiscreteScheduler,
"euler-a": EulerAncestralDiscreteScheduler,
"heun": HeunDiscreteScheduler,
"k-dpm-2-a": KDPM2AncestralDiscreteScheduler,
"k-dpm-2": KDPM2DiscreteScheduler,
"karras-ve": KarrasVeScheduler,
"lms-discrete": LMSDiscreteScheduler,
"pndm": PNDMScheduler,
}
noise_sources = {
'fill-edge': noise_source_fill_edge,
'fill-mask': noise_source_fill_mask,
'gaussian': noise_source_gaussian,
'histogram': noise_source_histogram,
'normal': noise_source_normal,
'uniform': noise_source_uniform,
"fill-edge": noise_source_fill_edge,
"fill-mask": noise_source_fill_mask,
"gaussian": noise_source_gaussian,
"histogram": noise_source_histogram,
"normal": noise_source_normal,
"uniform": noise_source_uniform,
}
mask_filters = {
'none': mask_filter_none,
'gaussian-multiply': mask_filter_gaussian_multiply,
'gaussian-screen': mask_filter_gaussian_screen,
"none": mask_filter_none,
"gaussian-multiply": mask_filter_gaussian_multiply,
"gaussian-screen": mask_filter_gaussian_screen,
}
chain_stages = {
'blend-img2img': blend_img2img,
'blend-inpaint': blend_inpaint,
'correct-gfpgan': correct_gfpgan,
'persist-disk': persist_disk,
'persist-s3': persist_s3,
'reduce-crop': reduce_crop,
'reduce-thumbnail': reduce_thumbnail,
'source-noise': source_noise,
'source-txt2img': source_txt2img,
'upscale-outpaint': upscale_outpaint,
'upscale-resrgan': upscale_resrgan,
'upscale-stable-diffusion': upscale_stable_diffusion,
"blend-img2img": blend_img2img,
"blend-inpaint": blend_inpaint,
"correct-gfpgan": correct_gfpgan,
"persist-disk": persist_disk,
"persist-s3": persist_s3,
"reduce-crop": reduce_crop,
"reduce-thumbnail": reduce_thumbnail,
"source-noise": source_noise,
"source-txt2img": source_txt2img,
"upscale-outpaint": upscale_outpaint,
"upscale-resrgan": upscale_resrgan,
"upscale-stable-diffusion": upscale_stable_diffusion,
}
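
The chain endpoint later in this file accepts a YAML body validated against schema.yaml. Shown here as the Python structure safe_load would produce, with stage types drawn from the chain_stages table above and only the params keys the handler itself reads (border, tile_size, outscale); anything else is stage-specific and not covered by this diff:

chain_request = {
    "stages": [
        {"type": "source-txt2img", "name": "generate", "params": {}},
        {
            "type": "upscale-outpaint",
            "name": "expand",
            "params": {"border": 128, "tile_size": 512},
        },
        {"type": "persist-disk", "name": "save", "params": {}},
    ]
}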
# Available ORT providers
@ -158,7 +142,7 @@ correction_models = []
upscaling_models = []
def get_config_value(key: str, subkey: str = 'default'):
def get_config_value(key: str, subkey: str = "default"):
return config_params.get(key).get(subkey)
@ -174,7 +158,7 @@ def pipeline_from_request() -> Tuple[DeviceParams, ImageParams, Size]:
user = request.remote_addr
# platform stuff
device_name = request.args.get('platform', available_platforms[0].device)
device_name = request.args.get("platform", available_platforms[0].device)
device = None
for platform in available_platforms:
@ -182,78 +166,101 @@ def pipeline_from_request() -> Tuple[DeviceParams, ImageParams, Size]:
device = available_platforms[0]
if device is None:
raise Exception('unknown device')
raise Exception("unknown device")
# pipeline stuff
model = get_not_empty(request.args, 'model', get_config_value('model'))
model = get_not_empty(request.args, "model", get_config_value("model"))
model_path = get_model_path(model)
scheduler = get_from_map(request.args, 'scheduler',
pipeline_schedulers, get_config_value('scheduler'))
scheduler = get_from_map(
request.args, "scheduler", pipeline_schedulers, get_config_value("scheduler")
)
# image params
prompt = get_not_empty(request.args,
'prompt', get_config_value('prompt'))
negative_prompt = request.args.get('negativePrompt', None)
prompt = get_not_empty(request.args, "prompt", get_config_value("prompt"))
negative_prompt = request.args.get("negativePrompt", None)
if negative_prompt is not None and negative_prompt.strip() == '':
if negative_prompt is not None and negative_prompt.strip() == "":
negative_prompt = None
cfg = get_and_clamp_float(
request.args, 'cfg',
get_config_value('cfg'),
get_config_value('cfg', 'max'),
get_config_value('cfg', 'min'))
request.args,
"cfg",
get_config_value("cfg"),
get_config_value("cfg", "max"),
get_config_value("cfg", "min"),
)
steps = get_and_clamp_int(
request.args, 'steps',
get_config_value('steps'),
get_config_value('steps', 'max'),
get_config_value('steps', 'min'))
request.args,
"steps",
get_config_value("steps"),
get_config_value("steps", "max"),
get_config_value("steps", "min"),
)
height = get_and_clamp_int(
request.args, 'height',
get_config_value('height'),
get_config_value('height', 'max'),
get_config_value('height', 'min'))
request.args,
"height",
get_config_value("height"),
get_config_value("height", "max"),
get_config_value("height", "min"),
)
width = get_and_clamp_int(
request.args, 'width',
get_config_value('width'),
get_config_value('width', 'max'),
get_config_value('width', 'min'))
request.args,
"width",
get_config_value("width"),
get_config_value("width", "max"),
get_config_value("width", "min"),
)
seed = int(request.args.get('seed', -1))
seed = int(request.args.get("seed", -1))
if seed == -1:
seed = np.random.randint(np.iinfo(np.int32).max)
logger.info("request from %s: %s rounds of %s using %s on %s, %sx%s, %s, %s - %s",
user, steps, scheduler.__name__, model_path, device.provider, width, height, cfg, seed, prompt)
logger.info(
"request from %s: %s rounds of %s using %s on %s, %sx%s, %s, %s - %s",
user,
steps,
scheduler.__name__,
model_path,
device.provider,
width,
height,
cfg,
seed,
prompt,
)
params = ImageParams(model_path, scheduler, prompt,
negative_prompt, cfg, steps, seed)
params = ImageParams(
model_path, scheduler, prompt, negative_prompt, cfg, steps, seed
)
size = Size(width, height)
return (device, params, size)
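
pipeline_from_request leans on the clamp helpers from .utils, whose implementations are not part of this diff. The call sites imply behavior like the following sketch; the minimum defaults are assumptions:

def get_and_clamp_float(args, key, default_value, max_value, min_value=0.0):
    return min(max(float(args.get(key, default_value)), min_value), max_value)

def get_and_clamp_int(args, key, default_value, max_value, min_value=1):
    return min(max(int(args.get(key, default_value)), min_value), max_value)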
def border_from_request() -> Border:
left = get_and_clamp_int(request.args, 'left', 0,
get_config_value('width', 'max'), 0)
right = get_and_clamp_int(request.args, 'right',
0, get_config_value('width', 'max'), 0)
top = get_and_clamp_int(request.args, 'top', 0,
get_config_value('height', 'max'), 0)
left = get_and_clamp_int(
request.args, "left", 0, get_config_value("width", "max"), 0
)
right = get_and_clamp_int(
request.args, "right", 0, get_config_value("width", "max"), 0
)
top = get_and_clamp_int(
request.args, "top", 0, get_config_value("height", "max"), 0
)
bottom = get_and_clamp_int(
request.args, 'bottom', 0, get_config_value('height', 'max'), 0)
request.args, "bottom", 0, get_config_value("height", "max"), 0
)
return Border(left, right, top, bottom)
def upscale_from_request() -> UpscaleParams:
denoise = get_and_clamp_float(request.args, 'denoise', 0.5, 1.0, 0.0)
scale = get_and_clamp_int(request.args, 'scale', 1, 4, 1)
outscale = get_and_clamp_int(request.args, 'outscale', 1, 4, 1)
upscaling = get_from_list(request.args, 'upscaling', upscaling_models)
correction = get_from_list(request.args, 'correction', correction_models)
faces = get_not_empty(request.args, 'faces', 'false') == 'true'
face_strength = get_and_clamp_float(
request.args, 'faceStrength', 0.5, 1.0, 0.0)
denoise = get_and_clamp_float(request.args, "denoise", 0.5, 1.0, 0.0)
scale = get_and_clamp_int(request.args, "scale", 1, 4, 1)
outscale = get_and_clamp_int(request.args, "outscale", 1, 4, 1)
upscaling = get_from_list(request.args, "upscaling", upscaling_models)
correction = get_from_list(request.args, "correction", correction_models)
faces = get_not_empty(request.args, "faces", "false") == "true"
face_strength = get_and_clamp_float(request.args, "faceStrength", 0.5, 1.0, 0.0)
return UpscaleParams(
upscaling,
@ -261,7 +268,7 @@ def upscale_from_request() -> UpscaleParams:
denoise=denoise,
faces=faces,
face_strength=face_strength,
format='onnx',
format="onnx",
outscale=outscale,
scale=scale,
)
@ -269,7 +276,7 @@ def upscale_from_request() -> UpscaleParams:
def check_paths(context: ServerContext):
if not path.exists(context.model_path):
raise RuntimeError('model path must exist')
raise RuntimeError("model path must exist")
if not path.exists(context.output_path):
makedirs(context.output_path)
@ -286,35 +293,41 @@ def load_models(context: ServerContext):
global correction_models
global upscaling_models
diffusion_models = [get_model_name(f) for f in glob(
path.join(context.model_path, 'diffusion-*'))]
diffusion_models.extend([
get_model_name(f) for f in glob(path.join(context.model_path, 'stable-diffusion-*'))])
diffusion_models = [
get_model_name(f) for f in glob(path.join(context.model_path, "diffusion-*"))
]
diffusion_models.extend(
[
get_model_name(f)
for f in glob(path.join(context.model_path, "stable-diffusion-*"))
]
)
diffusion_models = list(set(diffusion_models))
diffusion_models.sort()
correction_models = [
get_model_name(f) for f in glob(path.join(context.model_path, 'correction-*'))]
get_model_name(f) for f in glob(path.join(context.model_path, "correction-*"))
]
correction_models = list(set(correction_models))
correction_models.sort()
upscaling_models = [
get_model_name(f) for f in glob(path.join(context.model_path, 'upscaling-*'))]
get_model_name(f) for f in glob(path.join(context.model_path, "upscaling-*"))
]
upscaling_models = list(set(upscaling_models))
upscaling_models.sort()
def load_params(context: ServerContext):
global config_params
params_file = path.join(context.params_path, 'params.json')
with open(params_file, 'r') as f:
params_file = path.join(context.params_path, "params.json")
with open(params_file, "r") as f:
config_params = yaml.safe_load(f)
if 'platform' in config_params and context.default_platform is not None:
logger.info('overriding default platform to %s',
context.default_platform)
config_platform = config_params.get('platform')
config_platform['default'] = context.default_platform
if "platform" in config_params and context.default_platform is not None:
logger.info("overriding default platform to %s", context.default_platform)
config_platform = config_params.get("platform")
config_platform["default"] = context.default_platform
def load_platforms():
@ -323,30 +336,42 @@ def load_platforms():
providers = get_available_providers()
for potential in platform_providers:
if platform_providers[potential] in providers and potential not in context.block_platforms:
if potential == 'cuda':
if (
platform_providers[potential] in providers
and potential not in context.block_platforms
):
if potential == "cuda":
for i in range(torch.cuda.device_count()):
available_platforms.append(DeviceParams(potential, platform_providers[potential], {
'device_id': i,
}))
available_platforms.append(
DeviceParams(
potential,
platform_providers[potential],
{
"device_id": i,
},
)
)
else:
available_platforms.append(DeviceParams(
potential, platform_providers[potential]))
available_platforms.append(
DeviceParams(potential, platform_providers[potential])
)
# make sure CPU is last on the list
def cpu_last(a: DeviceParams, b: DeviceParams):
if a.device == 'cpu' and b.device == 'cpu':
if a.device == "cpu" and b.device == "cpu":
return 0
if a.device == 'cpu':
if a.device == "cpu":
return 1
return -1
available_platforms = sorted(available_platforms, key=cmp_to_key(cpu_last))
logger.info('available acceleration platforms: %s',
', '.join([str(p) for p in available_platforms]))
logger.info(
"available acceleration platforms: %s",
", ".join([str(p) for p in available_platforms]),
)
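
The comparator above only has to push CPU devices to the end of the list; a standalone check with plain strings in place of DeviceParams:

from functools import cmp_to_key

def cpu_last(a, b):
    if a == "cpu" and b == "cpu":
        return 0
    if a == "cpu":
        return 1
    return -1

print(sorted(["cpu", "cuda", "rocm"], key=cmp_to_key(cpu_last)))
# ends with 'cpu'; relative order of the others is not guaranteed here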
context = ServerContext.from_environ()
@ -365,16 +390,22 @@ if is_debug():
def ready_reply(ready: bool, progress: int = 0):
return jsonify({
'progress': progress,
'ready': ready,
})
return jsonify(
{
"progress": progress,
"ready": ready,
}
)
def error_reply(err: str):
response = make_response(jsonify({
'error': err,
}))
response = make_response(
jsonify(
{
"error": err,
}
)
)
response.status_code = 400
return response
@ -383,151 +414,154 @@ def get_model_path(model: str):
return base_join(context.model_path, model)
def serve_bundle_file(filename='index.html'):
return send_from_directory(path.join('..', context.bundle_path), filename)
def serve_bundle_file(filename="index.html"):
return send_from_directory(path.join("..", context.bundle_path), filename)
# routes
@app.route('/')
@app.route("/")
def index():
return serve_bundle_file()
@app.route('/<path:filename>')
@app.route("/<path:filename>")
def index_path(filename):
return serve_bundle_file(filename)
@app.route('/api')
@app.route("/api")
def introspect():
return {
'name': 'onnx-web',
'routes': [{
'path': url_from_rule(rule),
'methods': sorted(rule.methods)
} for rule in app.url_map.iter_rules()]
"name": "onnx-web",
"routes": [
{"path": url_from_rule(rule), "methods": list(rule.methods).sort()}
for rule in app.url_map.iter_rules()
],
}
@app.route('/api/settings/masks')
@app.route("/api/settings/masks")
def list_mask_filters():
return jsonify(list(mask_filters.keys()))
@app.route('/api/settings/models')
@app.route("/api/settings/models")
def list_models():
return jsonify({
'diffusion': diffusion_models,
'correction': correction_models,
'upscaling': upscaling_models,
})
return jsonify(
{
"diffusion": diffusion_models,
"correction": correction_models,
"upscaling": upscaling_models,
}
)
@app.route('/api/settings/noises')
@app.route("/api/settings/noises")
def list_noise_sources():
return jsonify(list(noise_sources.keys()))
@app.route('/api/settings/params')
@app.route("/api/settings/params")
def list_params():
return jsonify(config_params)
@app.route('/api/settings/platforms')
@app.route("/api/settings/platforms")
def list_platforms():
return jsonify([p.device for p in available_platforms])
@app.route('/api/settings/schedulers')
@app.route("/api/settings/schedulers")
def list_schedulers():
return jsonify(list(pipeline_schedulers.keys()))
@app.route('/api/img2img', methods=['POST'])
@app.route("/api/img2img", methods=["POST"])
def img2img():
if 'source' not in request.files:
return error_reply('source image is required')
if "source" not in request.files:
return error_reply("source image is required")
source_file = request.files.get('source')
source_image = Image.open(BytesIO(source_file.read())).convert('RGB')
source_file = request.files.get("source")
source_image = Image.open(BytesIO(source_file.read())).convert("RGB")
device, params, size = pipeline_from_request()
upscale = upscale_from_request()
strength = get_and_clamp_float(
request.args,
'strength',
get_config_value('strength'),
get_config_value('strength', 'max'),
get_config_value('strength', 'min'))
"strength",
get_config_value("strength"),
get_config_value("strength", "max"),
get_config_value("strength", "min"),
)
output = make_output_name(
context,
'img2img',
params,
size,
extras=(strength,))
output = make_output_name(context, "img2img", params, size, extras=(strength,))
logger.info("img2img job queued for: %s", output)
source_image.thumbnail((size.width, size.height))
executor.submit(output, run_img2img_pipeline,
context, params, output, upscale, source_image, strength)
executor.submit(
output,
run_img2img_pipeline,
context,
params,
output,
upscale,
source_image,
strength,
)
return jsonify(json_params(output, params, size, upscale=upscale))
@app.route('/api/txt2img', methods=['POST'])
@app.route("/api/txt2img", methods=["POST"])
def txt2img():
device, params, size = pipeline_from_request()
upscale = upscale_from_request()
output = make_output_name(
context,
'txt2img',
params,
size)
output = make_output_name(context, "txt2img", params, size)
logger.info("txt2img job queued for: %s", output)
executor.submit(
output, run_txt2img_pipeline, context, params, size, output, upscale)
output, run_txt2img_pipeline, context, params, size, output, upscale
)
return jsonify(json_params(output, params, size, upscale=upscale))
@app.route('/api/inpaint', methods=['POST'])
@app.route("/api/inpaint", methods=["POST"])
def inpaint():
if 'source' not in request.files:
return error_reply('source image is required')
if "source" not in request.files:
return error_reply("source image is required")
if 'mask' not in request.files:
return error_reply('mask image is required')
if "mask" not in request.files:
return error_reply("mask image is required")
source_file = request.files.get('source')
source_image = Image.open(BytesIO(source_file.read())).convert('RGB')
source_file = request.files.get("source")
source_image = Image.open(BytesIO(source_file.read())).convert("RGB")
mask_file = request.files.get('mask')
mask_image = Image.open(BytesIO(mask_file.read())).convert('RGB')
mask_file = request.files.get("mask")
mask_image = Image.open(BytesIO(mask_file.read())).convert("RGB")
device, params, size = pipeline_from_request()
expand = border_from_request()
upscale = upscale_from_request()
fill_color = get_not_empty(request.args, 'fillColor', 'white')
mask_filter = get_from_map(request.args, 'filter', mask_filters, 'none')
noise_source = get_from_map(
request.args, 'noise', noise_sources, 'histogram')
fill_color = get_not_empty(request.args, "fillColor", "white")
mask_filter = get_from_map(request.args, "filter", mask_filters, "none")
noise_source = get_from_map(request.args, "noise", noise_sources, "histogram")
strength = get_and_clamp_float(
request.args,
'strength',
get_config_value('strength'),
get_config_value('strength', 'max'),
get_config_value('strength', 'min'))
"strength",
get_config_value("strength"),
get_config_value("strength", "max"),
get_config_value("strength", "min"),
)
output = make_output_name(
context,
'inpaint',
"inpaint",
params,
size,
extras=(
@ -539,7 +573,7 @@ def inpaint():
noise_source.__name__,
strength,
fill_color,
)
),
)
logger.info("inpaint job queued for: %s", output)
@ -559,123 +593,131 @@ def inpaint():
noise_source,
mask_filter,
strength,
fill_color)
fill_color,
)
return jsonify(json_params(output, params, size, upscale=upscale, border=expand))
@app.route('/api/upscale', methods=['POST'])
@app.route("/api/upscale", methods=["POST"])
def upscale():
if 'source' not in request.files:
return error_reply('source image is required')
if "source" not in request.files:
return error_reply("source image is required")
source_file = request.files.get('source')
source_image = Image.open(BytesIO(source_file.read())).convert('RGB')
source_file = request.files.get("source")
source_image = Image.open(BytesIO(source_file.read())).convert("RGB")
device, params, size = pipeline_from_request()
upscale = upscale_from_request()
output = make_output_name(
context,
'upscale',
params,
size)
output = make_output_name(context, "upscale", params, size)
logger.info("upscale job queued for: %s", output)
source_image.thumbnail((size.width, size.height))
executor.submit(output, run_upscale_pipeline,
context, params, size, output, upscale, source_image)
executor.submit(
output,
run_upscale_pipeline,
context,
params,
size,
output,
upscale,
source_image,
)
return jsonify(json_params(output, params, size, upscale=upscale))
@app.route('/api/chain', methods=['POST'])
@app.route("/api/chain", methods=["POST"])
def chain():
logger.debug('chain pipeline request: %s, %s',
request.form.keys(), request.files.keys())
body = request.form.get('chain') or request.files.get('chain')
logger.debug(
"chain pipeline request: %s, %s", request.form.keys(), request.files.keys()
)
body = request.form.get("chain") or request.files.get("chain")
if body is None:
return error_reply('chain pipeline must have a body')
return error_reply("chain pipeline must have a body")
data = yaml.safe_load(body)
with open('./schema.yaml', 'r') as f:
with open("./schema.yaml", "r") as f:
schema = yaml.safe_load(f.read())
logger.info('validating chain request: %s against %s', data, schema)
logger.info("validating chain request: %s against %s", data, schema)
validate(data, schema)
# get defaults from the regular parameters
device, params, size = pipeline_from_request()
output = make_output_name(
context,
'chain',
params,
size)
output = make_output_name(context, "chain", params, size)
pipeline = ChainPipeline()
for stage_data in data.get('stages', []):
callback = chain_stages[stage_data.get('type')]
kwargs = stage_data.get('params', {})
logger.info('request stage: %s, %s', callback.__name__, kwargs)
for stage_data in data.get("stages", []):
callback = chain_stages[stage_data.get("type")]
kwargs = stage_data.get("params", {})
logger.info("request stage: %s, %s", callback.__name__, kwargs)
stage = StageParams(
stage_data.get('name', callback.__name__),
tile_size=get_size(kwargs.get('tile_size')),
outscale=get_and_clamp_int(kwargs, 'outscale', 1, 4),
stage_data.get("name", callback.__name__),
tile_size=get_size(kwargs.get("tile_size")),
outscale=get_and_clamp_int(kwargs, "outscale", 1, 4),
)
if 'border' in kwargs:
border = Border.even(int(kwargs.get('border')))
kwargs['border'] = border
if "border" in kwargs:
border = Border.even(int(kwargs.get("border")))
kwargs["border"] = border
if 'upscale' in kwargs:
upscale = UpscaleParams(kwargs.get('upscale'))
kwargs['upscale'] = upscale
if "upscale" in kwargs:
upscale = UpscaleParams(kwargs.get("upscale"))
kwargs["upscale"] = upscale
stage_source_name = 'source:%s' % (stage.name)
stage_mask_name = 'mask:%s' % (stage.name)
stage_source_name = "source:%s" % (stage.name)
stage_mask_name = "mask:%s" % (stage.name)
if stage_source_name in request.files:
logger.debug('loading source image %s for pipeline stage %s',
stage_source_name, stage.name)
logger.debug(
"loading source image %s for pipeline stage %s",
stage_source_name,
stage.name,
)
source_file = request.files.get(stage_source_name)
source_image = Image.open(
BytesIO(source_file.read())).convert('RGB')
source_image = Image.open(BytesIO(source_file.read())).convert("RGB")
source_image.thumbnail((512, 512))  # Image.thumbnail resizes in place and returns None
kwargs['source_image'] = source_image
kwargs["source_image"] = source_image
if stage_mask_name in request.files:
logger.debug('loading mask image %s for pipeline stage %s',
stage_mask_name, stage.name)
logger.debug(
"loading mask image %s for pipeline stage %s",
stage_mask_name,
stage.name,
)
mask_file = request.files.get(stage_mask_name)
mask_image = Image.open(BytesIO(mask_file.read())).convert('RGB')
mask_image = Image.open(BytesIO(mask_file.read())).convert("RGB")
mask_image.thumbnail((512, 512))  # Image.thumbnail resizes in place and returns None
kwargs['mask_image'] = mask_image
kwargs["mask_image"] = mask_image
pipeline.append((callback, stage, kwargs))
logger.info('running chain pipeline with %s stages', len(pipeline.stages))
logger.info("running chain pipeline with %s stages", len(pipeline.stages))
# build and run chain pipeline
empty_source = Image.new('RGB', (size.width, size.height))
executor.submit(output, pipeline, context,
params, empty_source, output=output, size=size)
empty_source = Image.new("RGB", (size.width, size.height))
executor.submit(
output, pipeline, context, params, empty_source, output=output, size=size
)
return jsonify(json_params(output, params, size))
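For reference, a minimal client sketch for this chain endpoint might look like the following. The YAML layout (a stages list with name, type, and params entries) mirrors the handler above, but the server address and the stage type strings are illustrative assumptions, since the chain_stages registry keys are not shown in this diff:

import requests

# hypothetical chain body: the stages/name/type/params layout follows the
# handler above; the type values are guesses at chain_stages registry keys
chain_body = """
stages:
  - name: generate
    type: source-txt2img
    params:
      tile_size: 512
  - name: upscale
    type: upscale-resrgan
    params:
      outscale: 2
"""

resp = requests.post(
    "http://localhost:5000/api/chain",  # assumed development server address
    data={"chain": chain_body},
)
print(resp.json())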
@app.route('/api/cancel', methods=['PUT'])
@app.route("/api/cancel", methods=["PUT"])
def cancel():
output_file = request.args.get('output', None)
output_file = request.args.get("output", None)
cancel = executor.cancel(output_file)
return ready_reply(cancel)
@app.route('/api/ready')
@app.route("/api/ready")
def ready():
output_file = request.args.get('output', None)
output_file = request.args.get("output", None)
done, progress = executor.done(output_file)
@@ -687,11 +729,13 @@ def ready():
return ready_reply(done, progress=progress)
@app.route('/api/status')
@app.route("/api/status")
def status():
return jsonify(executor.status())
@app.route('/output/<path:filename>')
@app.route("/output/<path:filename>")
def output(filename: str):
return send_from_directory(path.join('..', context.output_path), filename, as_attachment=False)
return send_from_directory(
path.join("..", context.output_path), filename, as_attachment=False
)

View File

@@ -1,24 +1,16 @@
from logging import getLogger
from PIL import Image
from .chain import (
correct_gfpgan,
upscale_stable_diffusion,
upscale_resrgan,
ChainPipeline,
correct_gfpgan,
upscale_resrgan,
upscale_stable_diffusion,
)
from .device_pool import (
JobContext,
)
from .params import (
ImageParams,
SizeChart,
StageParams,
UpscaleParams,
)
from .utils import (
ServerContext,
)
from .device_pool import JobContext
from .params import ImageParams, SizeChart, StageParams, UpscaleParams
from .utils import ServerContext
logger = getLogger(__name__)
@@ -32,27 +24,25 @@ def run_upscale_correction(
*,
upscale: UpscaleParams,
) -> Image.Image:
'''
"""
This is a convenience method for a chain pipeline that will run upscaling and
correction, based on the `upscale` params.
'''
logger.info('running upscaling and correction pipeline')
"""
logger.info("running upscaling and correction pipeline")
chain = ChainPipeline()
if upscale.scale > 1:
if 'esrgan' in upscale.upscale_model:
stage = StageParams(tile_size=stage.tile_size,
outscale=upscale.outscale)
if "esrgan" in upscale.upscale_model:
stage = StageParams(tile_size=stage.tile_size, outscale=upscale.outscale)
chain.append((upscale_resrgan, stage, None))
elif 'stable-diffusion' in upscale.upscale_model:
elif "stable-diffusion" in upscale.upscale_model:
mini_tile = min(SizeChart.mini, stage.tile_size)
stage = StageParams(tile_size=mini_tile, outscale=upscale.outscale)
chain.append((upscale_stable_diffusion, stage, None))
if upscale.faces:
stage = StageParams(tile_size=stage.tile_size,
outscale=1)
stage = StageParams(tile_size=stage.tile_size, outscale=1)
chain.append((correct_gfpgan, stage, None))
return chain(job, server, params, image, prompt=params.prompt, upscale=upscale)
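As the docstring says, this is a convenience wrapper over ChainPipeline. A hypothetical call site is sketched below; the full positional signature is not visible in this hunk, so the argument order (job, server, stage, params, image) is inferred from the return statement and the stage references above rather than confirmed:

from PIL import Image

# sketch only: job, server, stage, params, and upscale_params are assumed to
# be constructed elsewhere by the request pipeline
source = Image.open("input.png").convert("RGB")
result = run_upscale_correction(
    job,      # JobContext for the active worker
    server,   # ServerContext with model and output paths
    stage,    # StageParams supplying the base tile size
    params,   # ImageParams carrying the prompt
    source,
    upscale=upscale_params,  # UpscaleParams; scale > 1 adds an upscale stage
)
result.save("output.png")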

View File

@@ -1,13 +1,11 @@
import gc
from logging import getLogger
from os import environ, path
from typing import Any, Dict, List, Optional, Union
import gc
import torch
from .params import (
SizeChart,
)
from .params import SizeChart
logger = getLogger(__name__)
@@ -15,15 +13,15 @@ logger = getLogger(__name__)
class ServerContext:
def __init__(
self,
bundle_path: str = '.',
model_path: str = '.',
output_path: str = '.',
params_path: str = '.',
cors_origin: str = '*',
bundle_path: str = ".",
model_path: str = ".",
output_path: str = ".",
params_path: str = ".",
cors_origin: str = "*",
num_workers: int = 1,
block_platforms: List[str] = [],
default_platform: Optional[str] = None,
image_format: str = 'png',
image_format: str = "png",
) -> None:
self.bundle_path = bundle_path
self.model_path = model_path
@@ -38,40 +36,39 @@ class ServerContext:
@classmethod
def from_environ(cls):
return ServerContext(
bundle_path=environ.get('ONNX_WEB_BUNDLE_PATH',
path.join('..', 'gui', 'out')),
model_path=environ.get('ONNX_WEB_MODEL_PATH',
path.join('..', 'models')),
output_path=environ.get(
'ONNX_WEB_OUTPUT_PATH', path.join('..', 'outputs')),
params_path=environ.get('ONNX_WEB_PARAMS_PATH', '.'),
# others
cors_origin=environ.get('ONNX_WEB_CORS_ORIGIN', '*').split(','),
num_workers=int(environ.get('ONNX_WEB_NUM_WORKERS', 1)),
block_platforms=environ.get(
'ONNX_WEB_BLOCK_PLATFORMS', '').split(','),
default_platform=environ.get(
'ONNX_WEB_DEFAULT_PLATFORM', None),
image_format=environ.get(
'ONNX_WEB_IMAGE_FORMAT', 'png'
bundle_path=environ.get(
"ONNX_WEB_BUNDLE_PATH", path.join("..", "gui", "out")
),
model_path=environ.get("ONNX_WEB_MODEL_PATH", path.join("..", "models")),
output_path=environ.get("ONNX_WEB_OUTPUT_PATH", path.join("..", "outputs")),
params_path=environ.get("ONNX_WEB_PARAMS_PATH", "."),
# others
cors_origin=environ.get("ONNX_WEB_CORS_ORIGIN", "*").split(","),
num_workers=int(environ.get("ONNX_WEB_NUM_WORKERS", 1)),
block_platforms=environ.get("ONNX_WEB_BLOCK_PLATFORMS", "").split(","),
default_platform=environ.get("ONNX_WEB_DEFAULT_PLATFORM", None),
image_format=environ.get("ONNX_WEB_IMAGE_FORMAT", "png"),
)
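Since every setting here comes from an environment variable, overriding a couple of them before building the context is straightforward; a short sketch with illustrative values:

from os import environ

# variable names match from_environ above; the values are examples only
environ["ONNX_WEB_MODEL_PATH"] = "/data/models"
environ["ONNX_WEB_NUM_WORKERS"] = "2"

context = ServerContext.from_environ()
print(context.model_path, context.num_workers)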
def base_join(base: str, tail: str) -> str:
tail_path = path.relpath(path.normpath(path.join('/', tail)), '/')
tail_path = path.relpath(path.normpath(path.join("/", tail)), "/")
return path.join(base, tail_path)
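base_join anchors tail at the filesystem root before relativizing it, which strips any directory-traversal segments from the result. A small worked example, repeating the definition above so the snippet runs on its own (paths are illustrative):

from os import path

def base_join(base: str, tail: str) -> str:
    tail_path = path.relpath(path.normpath(path.join("/", tail)), "/")
    return path.join(base, tail_path)

print(base_join("outputs", "image.png"))         # outputs/image.png
print(base_join("outputs", "../../etc/passwd"))  # outputs/etc/passwd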
def is_debug() -> bool:
return environ.get('DEBUG') is not None
return environ.get("DEBUG") is not None
def get_and_clamp_float(args: Any, key: str, default_value: float, max_value: float, min_value=0.0) -> float:
def get_and_clamp_float(
args: Any, key: str, default_value: float, max_value: float, min_value=0.0
) -> float:
return min(max(float(args.get(key, default_value)), min_value), max_value)
def get_and_clamp_int(args: Any, key: str, default_value: int, max_value: int, min_value=1) -> int:
def get_and_clamp_int(
args: Any, key: str, default_value: int, max_value: int, min_value=1
) -> int:
return min(max(int(args.get(key, default_value)), min_value), max_value)
@@ -80,7 +77,7 @@ def get_from_list(args: Any, key: str, values: List[Any]) -> Optional[Any]:
if selected in values:
return selected
logger.warn('invalid selection: %s', selected)
logger.warn("invalid selection: %s", selected)
if len(values) > 0:
return values[0]
@@ -118,10 +115,10 @@ def get_size(val: Union[int, str, None]) -> SizeChart:
return int(val)
raise Exception('invalid size')
raise Exception("invalid size")
def run_gc():
logger.debug('running garbage collection')
logger.debug("running garbage collection")
gc.collect()
torch.cuda.empty_cache()

api/pyproject.toml Normal file
View File

@@ -0,0 +1,2 @@
[tool.isort]
profile = "black"