lint(api): apply black and isort style
This commit is contained in:
parent
c0e5f435ee
commit
54dd34d211
13
api/Makefile
13
api/Makefile
|
@ -23,3 +23,16 @@ package-dist:
|
|||
|
||||
package-upload:
|
||||
twine upload dist/*
|
||||
|
||||
lint-check:
|
||||
black --check --preview onnx_web
|
||||
isort --check-only --skip __init__.py --filter-files onnx_web
|
||||
flake8 --per-file-ignores="__init__.py:F401" onnx_web
|
||||
|
||||
lint-fix:
|
||||
black onnx_web
|
||||
isort --skip __init__.py --filter-files onnx_web
|
||||
flake8 --per-file-ignores="__init__.py:F401" onnx_web
|
||||
|
||||
typecheck:
|
||||
mypy -m onnx_web.serve
|
||||
|
|
|
@ -1,3 +1,6 @@
|
|||
black
|
||||
flake8
|
||||
isort
|
||||
mypy
|
||||
|
||||
types-Flask-Cors
|
||||
|
|
|
@ -1,14 +1,6 @@
|
|||
from . import logging
|
||||
|
||||
from .chain import (
|
||||
correct_gfpgan,
|
||||
upscale_resrgan,
|
||||
upscale_stable_diffusion,
|
||||
)
|
||||
from .diffusion.load import (
|
||||
get_latents_from_seed,
|
||||
load_pipeline,
|
||||
)
|
||||
from .chain import correct_gfpgan, upscale_resrgan, upscale_stable_diffusion
|
||||
from .diffusion.load import get_latents_from_seed, load_pipeline
|
||||
from .diffusion.run import (
|
||||
run_img2img_pipeline,
|
||||
run_inpaint_pipeline,
|
||||
|
@ -26,24 +18,14 @@ from .image import (
|
|||
noise_source_normal,
|
||||
noise_source_uniform,
|
||||
)
|
||||
from .params import (
|
||||
Param,
|
||||
Point,
|
||||
Border,
|
||||
Size,
|
||||
ImageParams,
|
||||
StageParams,
|
||||
UpscaleParams,
|
||||
)
|
||||
from .upscale import (
|
||||
run_upscale_correction,
|
||||
)
|
||||
from .params import Border, ImageParams, Param, Point, Size, StageParams, UpscaleParams
|
||||
from .upscale import run_upscale_correction
|
||||
from .utils import (
|
||||
ServerContext,
|
||||
base_join,
|
||||
get_and_clamp_float,
|
||||
get_and_clamp_int,
|
||||
get_from_list,
|
||||
get_from_map,
|
||||
get_not_empty,
|
||||
base_join,
|
||||
ServerContext,
|
||||
)
|
|
@ -1,42 +1,13 @@
|
|||
from .base import (
|
||||
ChainPipeline,
|
||||
PipelineStage,
|
||||
StageCallback,
|
||||
StageParams,
|
||||
)
|
||||
from .blend_img2img import (
|
||||
blend_img2img,
|
||||
)
|
||||
from .blend_inpaint import (
|
||||
blend_inpaint,
|
||||
)
|
||||
from .correct_gfpgan import (
|
||||
correct_gfpgan,
|
||||
)
|
||||
from .persist_disk import (
|
||||
persist_disk,
|
||||
)
|
||||
from .persist_s3 import (
|
||||
persist_s3,
|
||||
)
|
||||
from .reduce_crop import (
|
||||
reduce_crop,
|
||||
)
|
||||
from .reduce_thumbnail import (
|
||||
reduce_thumbnail,
|
||||
)
|
||||
from .source_noise import (
|
||||
source_noise,
|
||||
)
|
||||
from .source_txt2img import (
|
||||
source_txt2img,
|
||||
)
|
||||
from .upscale_outpaint import (
|
||||
upscale_outpaint,
|
||||
)
|
||||
from .upscale_resrgan import (
|
||||
upscale_resrgan,
|
||||
)
|
||||
from .upscale_stable_diffusion import (
|
||||
upscale_stable_diffusion,
|
||||
)
|
||||
from .base import ChainPipeline, PipelineStage, StageCallback, StageParams
|
||||
from .blend_img2img import blend_img2img
|
||||
from .blend_inpaint import blend_inpaint
|
||||
from .correct_gfpgan import correct_gfpgan
|
||||
from .persist_disk import persist_disk
|
||||
from .persist_s3 import persist_s3
|
||||
from .reduce_crop import reduce_crop
|
||||
from .reduce_thumbnail import reduce_thumbnail
|
||||
from .source_noise import source_noise
|
||||
from .source_txt2img import source_txt2img
|
||||
from .upscale_outpaint import upscale_outpaint
|
||||
from .upscale_resrgan import upscale_resrgan
|
||||
from .upscale_stable_diffusion import upscale_stable_diffusion
|
||||
|
|
|
@ -1,26 +1,15 @@
|
|||
from datetime import timedelta
|
||||
from logging import getLogger
|
||||
from PIL import Image
|
||||
from time import monotonic
|
||||
from typing import Any, List, Optional, Protocol, Tuple
|
||||
|
||||
from ..device_pool import (
|
||||
JobContext,
|
||||
)
|
||||
from ..params import (
|
||||
ImageParams,
|
||||
StageParams,
|
||||
)
|
||||
from ..output import (
|
||||
save_image,
|
||||
)
|
||||
from ..utils import (
|
||||
is_debug,
|
||||
ServerContext,
|
||||
)
|
||||
from .utils import (
|
||||
process_tile_grid,
|
||||
)
|
||||
from PIL import Image
|
||||
|
||||
from ..device_pool import JobContext
|
||||
from ..output import save_image
|
||||
from ..params import ImageParams, StageParams
|
||||
from ..utils import ServerContext, is_debug
|
||||
from .utils import process_tile_grid
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
|
@ -42,33 +31,43 @@ PipelineStage = Tuple[StageCallback, StageParams, Optional[dict]]
|
|||
|
||||
|
||||
class ChainPipeline:
|
||||
'''
|
||||
"""
|
||||
Run many stages in series, passing the image results from each to the next, and processing
|
||||
tiles as needed.
|
||||
'''
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
stages: List[PipelineStage] = [],
|
||||
):
|
||||
'''
|
||||
"""
|
||||
Create a new pipeline that will run the given stages.
|
||||
'''
|
||||
"""
|
||||
self.stages = list(stages)
|
||||
|
||||
def append(self, stage: PipelineStage):
|
||||
'''
|
||||
"""
|
||||
Append an additional stage to this pipeline.
|
||||
'''
|
||||
"""
|
||||
self.stages.append(stage)
|
||||
|
||||
def __call__(self, job: JobContext, server: ServerContext, params: ImageParams, source: Image.Image, **pipeline_kwargs) -> Image.Image:
|
||||
'''
|
||||
def __call__(
|
||||
self,
|
||||
job: JobContext,
|
||||
server: ServerContext,
|
||||
params: ImageParams,
|
||||
source: Image.Image,
|
||||
**pipeline_kwargs
|
||||
) -> Image.Image:
|
||||
"""
|
||||
TODO: handle List[Image] outputs
|
||||
'''
|
||||
"""
|
||||
start = monotonic()
|
||||
logger.info('running pipeline on source image with dimensions %sx%s',
|
||||
source.width, source.height)
|
||||
logger.info(
|
||||
"running pipeline on source image with dimensions %sx%s",
|
||||
source.width,
|
||||
source.height,
|
||||
)
|
||||
image = source
|
||||
|
||||
for stage_pipe, stage_params, stage_kwargs in self.stages:
|
||||
|
@ -76,37 +75,51 @@ class ChainPipeline:
|
|||
kwargs = stage_kwargs or {}
|
||||
kwargs = {**pipeline_kwargs, **kwargs}
|
||||
|
||||
logger.info('running stage %s on image with dimensions %sx%s, %s',
|
||||
name, image.width, image.height, kwargs.keys())
|
||||
logger.info(
|
||||
"running stage %s on image with dimensions %sx%s, %s",
|
||||
name,
|
||||
image.width,
|
||||
image.height,
|
||||
kwargs.keys(),
|
||||
)
|
||||
|
||||
if image.width > stage_params.tile_size or image.height > stage_params.tile_size:
|
||||
logger.info('image larger than tile size of %s, tiling stage',
|
||||
stage_params.tile_size)
|
||||
if (
|
||||
image.width > stage_params.tile_size
|
||||
or image.height > stage_params.tile_size
|
||||
):
|
||||
logger.info(
|
||||
"image larger than tile size of %s, tiling stage",
|
||||
stage_params.tile_size,
|
||||
)
|
||||
|
||||
def stage_tile(tile: Image.Image, _dims) -> Image.Image:
|
||||
tile = stage_pipe(job, server, stage_params, params, tile,
|
||||
**kwargs)
|
||||
tile = stage_pipe(job, server, stage_params, params, tile, **kwargs)
|
||||
|
||||
if is_debug():
|
||||
save_image(server, 'last-tile.png', tile)
|
||||
save_image(server, "last-tile.png", tile)
|
||||
|
||||
return tile
|
||||
|
||||
image = process_tile_grid(
|
||||
image, stage_params.tile_size, stage_params.outscale, [stage_tile])
|
||||
image, stage_params.tile_size, stage_params.outscale, [stage_tile]
|
||||
)
|
||||
else:
|
||||
logger.info('image within tile size, running stage')
|
||||
image = stage_pipe(job, server, stage_params, params, image,
|
||||
**kwargs)
|
||||
logger.info("image within tile size, running stage")
|
||||
image = stage_pipe(job, server, stage_params, params, image, **kwargs)
|
||||
|
||||
logger.info('finished stage %s, result size: %sx%s',
|
||||
name, image.width, image.height)
|
||||
logger.info(
|
||||
"finished stage %s, result size: %sx%s", name, image.width, image.height
|
||||
)
|
||||
|
||||
if is_debug():
|
||||
save_image(server, 'last-stage.png', image)
|
||||
save_image(server, "last-stage.png", image)
|
||||
|
||||
end = monotonic()
|
||||
duration = timedelta(seconds=(end - start))
|
||||
logger.info('finished pipeline in %s, result size: %sx%s',
|
||||
duration, image.width, image.height)
|
||||
logger.info(
|
||||
"finished pipeline in %s, result size: %sx%s",
|
||||
duration,
|
||||
image.width,
|
||||
image.height,
|
||||
)
|
||||
return image
|
||||
|
|
|
@ -1,24 +1,13 @@
|
|||
from diffusers import (
|
||||
OnnxStableDiffusionImg2ImgPipeline,
|
||||
)
|
||||
from logging import getLogger
|
||||
from PIL import Image
|
||||
|
||||
from ..device_pool import (
|
||||
JobContext,
|
||||
)
|
||||
from ..diffusion.load import (
|
||||
load_pipeline,
|
||||
)
|
||||
from ..params import (
|
||||
ImageParams,
|
||||
StageParams,
|
||||
)
|
||||
from ..utils import (
|
||||
ServerContext,
|
||||
)
|
||||
|
||||
import numpy as np
|
||||
from diffusers import OnnxStableDiffusionImg2ImgPipeline
|
||||
from PIL import Image
|
||||
|
||||
from ..device_pool import JobContext
|
||||
from ..diffusion.load import load_pipeline
|
||||
from ..params import ImageParams, StageParams
|
||||
from ..utils import ServerContext
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
|
@ -35,10 +24,14 @@ def blend_img2img(
|
|||
**kwargs,
|
||||
) -> Image.Image:
|
||||
prompt = prompt or params.prompt
|
||||
logger.info('generating image using img2img, %s steps: %s', params.steps, prompt)
|
||||
logger.info("generating image using img2img, %s steps: %s", params.steps, prompt)
|
||||
|
||||
pipe = load_pipeline(OnnxStableDiffusionImg2ImgPipeline,
|
||||
params.model, params.scheduler, job.get_device())
|
||||
pipe = load_pipeline(
|
||||
OnnxStableDiffusionImg2ImgPipeline,
|
||||
params.model,
|
||||
params.scheduler,
|
||||
job.get_device(),
|
||||
)
|
||||
|
||||
rng = np.random.RandomState(params.seed)
|
||||
|
||||
|
@ -53,6 +46,5 @@ def blend_img2img(
|
|||
)
|
||||
output = result.images[0]
|
||||
|
||||
logger.info('final output image size: %sx%s', output.width, output.height)
|
||||
logger.info("final output image size: %sx%s", output.width, output.height)
|
||||
return output
|
||||
|
||||
|
|
|
@ -1,41 +1,17 @@
|
|||
from diffusers import (
|
||||
OnnxStableDiffusionInpaintPipeline,
|
||||
)
|
||||
from logging import getLogger
|
||||
from PIL import Image
|
||||
from typing import Callable, Tuple
|
||||
|
||||
from ..device_pool import (
|
||||
JobContext,
|
||||
)
|
||||
from ..diffusion.load import (
|
||||
get_latents_from_seed,
|
||||
load_pipeline,
|
||||
)
|
||||
from ..image import (
|
||||
expand_image,
|
||||
mask_filter_none,
|
||||
noise_source_histogram,
|
||||
)
|
||||
from ..params import (
|
||||
Border,
|
||||
ImageParams,
|
||||
Size,
|
||||
SizeChart,
|
||||
StageParams,
|
||||
)
|
||||
from ..output import (
|
||||
save_image,
|
||||
)
|
||||
from ..utils import (
|
||||
is_debug,
|
||||
ServerContext,
|
||||
)
|
||||
from .utils import (
|
||||
process_tile_grid,
|
||||
)
|
||||
|
||||
import numpy as np
|
||||
from diffusers import OnnxStableDiffusionInpaintPipeline
|
||||
from PIL import Image
|
||||
|
||||
from ..device_pool import JobContext
|
||||
from ..diffusion.load import get_latents_from_seed, load_pipeline
|
||||
from ..image import expand_image, mask_filter_none, noise_source_histogram
|
||||
from ..output import save_image
|
||||
from ..params import Border, ImageParams, Size, SizeChart, StageParams
|
||||
from ..utils import ServerContext, is_debug
|
||||
from .utils import process_tile_grid
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
|
@ -49,16 +25,16 @@ def blend_inpaint(
|
|||
*,
|
||||
expand: Border,
|
||||
mask_image: Image.Image = None,
|
||||
fill_color: str = 'white',
|
||||
fill_color: str = "white",
|
||||
mask_filter: Callable = mask_filter_none,
|
||||
noise_source: Callable = noise_source_histogram,
|
||||
**kwargs,
|
||||
) -> Image.Image:
|
||||
logger.info('upscaling image by expanding borders', expand)
|
||||
logger.info("upscaling image by expanding borders", expand)
|
||||
|
||||
if mask_image is None:
|
||||
# if no mask was provided, keep the full source image
|
||||
mask_image = Image.new('RGB', source_image.size, 'black')
|
||||
mask_image = Image.new("RGB", source_image.size, "black")
|
||||
|
||||
source_image, mask_image, noise_image, _full_dims = expand_image(
|
||||
source_image,
|
||||
|
@ -66,12 +42,13 @@ def blend_inpaint(
|
|||
expand,
|
||||
fill=fill_color,
|
||||
noise_source=noise_source,
|
||||
mask_filter=mask_filter)
|
||||
mask_filter=mask_filter,
|
||||
)
|
||||
|
||||
if is_debug():
|
||||
save_image(server, 'last-source.png', source_image)
|
||||
save_image(server, 'last-mask.png', mask_image)
|
||||
save_image(server, 'last-noise.png', noise_image)
|
||||
save_image(server, "last-source.png", source_image)
|
||||
save_image(server, "last-mask.png", mask_image)
|
||||
save_image(server, "last-noise.png", noise_image)
|
||||
|
||||
def outpaint(image: Image.Image, dims: Tuple[int, int, int]):
|
||||
left, top, tile = dims
|
||||
|
@ -79,11 +56,15 @@ def blend_inpaint(
|
|||
mask = mask_image.crop((left, top, left + tile, top + tile))
|
||||
|
||||
if is_debug():
|
||||
save_image(server, 'tile-source.png', image)
|
||||
save_image(server, 'tile-mask.png', mask)
|
||||
save_image(server, "tile-source.png", image)
|
||||
save_image(server, "tile-mask.png", mask)
|
||||
|
||||
pipe = load_pipeline(OnnxStableDiffusionInpaintPipeline,
|
||||
params.model, params.scheduler, job.get_device())
|
||||
pipe = load_pipeline(
|
||||
OnnxStableDiffusionInpaintPipeline,
|
||||
params.model,
|
||||
params.scheduler,
|
||||
job.get_device(),
|
||||
)
|
||||
|
||||
latents = get_latents_from_seed(params.seed, size)
|
||||
rng = np.random.RandomState(params.seed)
|
||||
|
@ -104,5 +85,5 @@ def blend_inpaint(
|
|||
|
||||
output = process_tile_grid(source_image, SizeChart.auto, 1, [outpaint])
|
||||
|
||||
logger.info('final output image size', output.size)
|
||||
logger.info("final output image size", output.size)
|
||||
return output
|
||||
|
|
|
@ -1,30 +1,53 @@
|
|||
from logging import getLogger
|
||||
|
||||
import torch
|
||||
from basicsr.utils import img2tensor, tensor2img
|
||||
from basicsr.utils.download_util import load_file_from_url
|
||||
from facexlib.utils.face_restoration_helper import FaceRestoreHelper
|
||||
from logging import getLogger
|
||||
from PIL import Image
|
||||
from torchvision.transforms.functional import normalize
|
||||
|
||||
import torch
|
||||
from ..device_pool import JobContext
|
||||
from ..params import ImageParams, StageParams
|
||||
from ..utils import ServerContext
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
pretrain_model_url = {
|
||||
'restoration': 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth',
|
||||
}
|
||||
pretrain_model_url = (
|
||||
"https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth"
|
||||
)
|
||||
|
||||
device = 'cpu'
|
||||
device = "cpu"
|
||||
upscale = 2
|
||||
|
||||
def correct_codeformer(image: Image.Image) -> Image.Image:
|
||||
|
||||
def correct_codeformer(
|
||||
job: JobContext,
|
||||
server: ServerContext,
|
||||
stage: StageParams,
|
||||
params: ImageParams,
|
||||
source_image: Image.Image,
|
||||
**kwargs,
|
||||
) -> Image.Image:
|
||||
model = "TODO"
|
||||
|
||||
# ------------------ set up CodeFormer restorer -------------------
|
||||
net = ARCH_REGISTRY.get('CodeFormer')(dim_embd=512, codebook_size=1024, n_head=8, n_layers=9,
|
||||
connect_list=['32', '64', '128', '256']).to(device)
|
||||
net = ARCH_REGISTRY.get("CodeFormer")(
|
||||
dim_embd=512,
|
||||
codebook_size=1024,
|
||||
n_head=8,
|
||||
n_layers=9,
|
||||
connect_list=["32", "64", "128", "256"],
|
||||
).to(device)
|
||||
|
||||
# ckpt_path = 'weights/CodeFormer/codeformer.pth'
|
||||
ckpt_path = load_file_from_url(url=pretrain_model_url['restoration'],
|
||||
model_dir='weights/CodeFormer', progress=True, file_name=None)
|
||||
checkpoint = torch.load(ckpt_path)['params_ema']
|
||||
ckpt_path = load_file_from_url(
|
||||
url=pretrain_model_url,
|
||||
model_dir="weights/CodeFormer",
|
||||
progress=True,
|
||||
file_name=None,
|
||||
)
|
||||
checkpoint = torch.load(ckpt_path)["params_ema"]
|
||||
net.load_state_dict(checkpoint)
|
||||
net.eval()
|
||||
|
||||
|
@ -36,22 +59,24 @@ def correct_codeformer(image: Image.Image) -> Image.Image:
|
|||
upscale,
|
||||
face_size=512,
|
||||
crop_ratio=(1, 1),
|
||||
det_model = args.detection_model,
|
||||
save_ext='png',
|
||||
det_model=model,
|
||||
save_ext="png",
|
||||
use_parse=True,
|
||||
device=device)
|
||||
device=device,
|
||||
)
|
||||
|
||||
# get face landmarks for each face
|
||||
num_det_faces = face_helper.get_face_landmarks_5(
|
||||
only_center_face=args.only_center_face, resize=640, eye_dist_threshold=5)
|
||||
logger.info('detect %s faces', num_det_faces)
|
||||
only_center_face=False, resize=640, eye_dist_threshold=5
|
||||
)
|
||||
logger.info("detect %s faces", num_det_faces)
|
||||
# align and warp each face
|
||||
face_helper.align_warp_face()
|
||||
|
||||
# face restoration for each cropped face
|
||||
for idx, cropped_face in enumerate(face_helper.cropped_faces):
|
||||
# prepare data
|
||||
cropped_face_t = img2tensor(cropped_face / 255., bgr2rgb=True, float32=True)
|
||||
cropped_face_t = img2tensor(cropped_face / 255.0, bgr2rgb=True, float32=True)
|
||||
normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
|
||||
cropped_face_t = cropped_face_t.unsqueeze(0).to(device)
|
||||
|
||||
|
@ -62,10 +87,10 @@ def correct_codeformer(image: Image.Image) -> Image.Image:
|
|||
del output
|
||||
torch.cuda.empty_cache()
|
||||
except Exception as error:
|
||||
logger.error('Failed inference for CodeFormer: %s', error)
|
||||
logger.error("Failed inference for CodeFormer: %s", error)
|
||||
restored_face = tensor2img(cropped_face_t, rgb2bgr=True, min_max=(-1, 1))
|
||||
|
||||
restored_face = restored_face.astype('uint8')
|
||||
restored_face = restored_face.astype("uint8")
|
||||
face_helper.add_restored_face(restored_face, cropped_face)
|
||||
|
||||
# upsample the background
|
||||
|
@ -75,13 +100,16 @@ def correct_codeformer(image: Image.Image) -> Image.Image:
|
|||
else:
|
||||
bg_img = None
|
||||
|
||||
|
||||
# paste_back
|
||||
face_helper.get_inverse_affine(None)
|
||||
# paste each restored face to the input image
|
||||
if face_upsampler is not None:
|
||||
restored_img = face_helper.paste_faces_to_input_image(upsample_img=bg_img, draw_box=False, face_upsampler=face_upsampler)
|
||||
restored_img = face_helper.paste_faces_to_input_image(
|
||||
upsample_img=bg_img, draw_box=False, face_upsampler=face_upsampler
|
||||
)
|
||||
else:
|
||||
restored_img = face_helper.paste_faces_to_input_image(upsample_img=bg_img, draw_box=False)
|
||||
restored_img = face_helper.paste_faces_to_input_image(
|
||||
upsample_img=bg_img, draw_box=False
|
||||
)
|
||||
|
||||
return restored_img
|
|
@ -1,27 +1,16 @@
|
|||
from gfpgan import GFPGANer
|
||||
from logging import getLogger
|
||||
from os import path
|
||||
from PIL import Image
|
||||
from realesrgan import RealESRGANer
|
||||
from typing import Optional
|
||||
|
||||
from ..device_pool import (
|
||||
JobContext,
|
||||
)
|
||||
from ..params import (
|
||||
ImageParams,
|
||||
StageParams,
|
||||
UpscaleParams,
|
||||
)
|
||||
from ..utils import (
|
||||
run_gc,
|
||||
ServerContext,
|
||||
)
|
||||
from .upscale_resrgan import (
|
||||
load_resrgan,
|
||||
)
|
||||
|
||||
import numpy as np
|
||||
from gfpgan import GFPGANer
|
||||
from PIL import Image
|
||||
from realesrgan import RealESRGANer
|
||||
|
||||
from ..device_pool import JobContext
|
||||
from ..params import ImageParams, StageParams, UpscaleParams
|
||||
from ..utils import ServerContext, run_gc
|
||||
from .upscale_resrgan import load_resrgan
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
|
@ -30,7 +19,9 @@ last_pipeline_instance = None
|
|||
last_pipeline_params = None
|
||||
|
||||
|
||||
def load_gfpgan(ctx: ServerContext, upscale: UpscaleParams, upsampler: Optional[RealESRGANer] = None):
|
||||
def load_gfpgan(
|
||||
ctx: ServerContext, upscale: UpscaleParams, upsampler: Optional[RealESRGANer] = None
|
||||
):
|
||||
global last_pipeline_instance
|
||||
global last_pipeline_params
|
||||
|
||||
|
@ -38,22 +29,22 @@ def load_gfpgan(ctx: ServerContext, upscale: UpscaleParams, upsampler: Optional[
|
|||
bg_upscale = upscale.rescale(upscale.outscale)
|
||||
upsampler = load_resrgan(ctx, bg_upscale)
|
||||
|
||||
face_path = path.join(ctx.model_path, '%s.pth' %
|
||||
(upscale.correction_model))
|
||||
face_path = path.join(ctx.model_path, "%s.pth" % (upscale.correction_model))
|
||||
|
||||
if last_pipeline_instance != None and face_path == last_pipeline_params:
|
||||
logger.info('reusing existing GFPGAN pipeline')
|
||||
if last_pipeline_instance is not None and face_path == last_pipeline_params:
|
||||
logger.info("reusing existing GFPGAN pipeline")
|
||||
return last_pipeline_instance
|
||||
|
||||
logger.debug('loading GFPGAN model from %s', face_path)
|
||||
logger.debug("loading GFPGAN model from %s", face_path)
|
||||
|
||||
# TODO: find a way to pass the ONNX model to underlying architectures
|
||||
gfpgan = GFPGANer(
|
||||
model_path=face_path,
|
||||
upscale=upscale.outscale,
|
||||
arch='clean',
|
||||
arch="clean",
|
||||
channel_multiplier=2,
|
||||
bg_upsampler=upsampler)
|
||||
bg_upsampler=upsampler,
|
||||
)
|
||||
|
||||
last_pipeline_instance = gfpgan
|
||||
last_pipeline_params = face_path
|
||||
|
@ -74,15 +65,20 @@ def correct_gfpgan(
|
|||
**kwargs,
|
||||
) -> Image.Image:
|
||||
if upscale.correction_model is None:
|
||||
logger.warn('no face model given, skipping')
|
||||
logger.warn("no face model given, skipping")
|
||||
return source_image
|
||||
|
||||
logger.info('correcting faces with GFPGAN model: %s', upscale.correction_model)
|
||||
logger.info("correcting faces with GFPGAN model: %s", upscale.correction_model)
|
||||
gfpgan = load_gfpgan(server, upscale, upsampler=upsampler)
|
||||
|
||||
output = np.array(source_image)
|
||||
_, _, output = gfpgan.enhance(
|
||||
output, has_aligned=False, only_center_face=False, paste_back=True, weight=upscale.face_strength)
|
||||
output = Image.fromarray(output, 'RGB')
|
||||
output,
|
||||
has_aligned=False,
|
||||
only_center_face=False,
|
||||
paste_back=True,
|
||||
weight=upscale.face_strength,
|
||||
)
|
||||
output = Image.fromarray(output, "RGB")
|
||||
|
||||
return output
|
||||
|
|
|
@ -1,26 +1,18 @@
|
|||
from logging import getLogger
|
||||
|
||||
from PIL import Image
|
||||
|
||||
from ..device_pool import (
|
||||
JobContext,
|
||||
)
|
||||
from ..params import (
|
||||
ImageParams,
|
||||
StageParams,
|
||||
)
|
||||
from ..output import (
|
||||
save_image,
|
||||
)
|
||||
from ..utils import (
|
||||
ServerContext,
|
||||
)
|
||||
from ..device_pool import JobContext
|
||||
from ..output import save_image
|
||||
from ..params import ImageParams, StageParams
|
||||
from ..utils import ServerContext
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
|
||||
def persist_disk(
|
||||
_job: JobContext,
|
||||
ctx: ServerContext,
|
||||
server: ServerContext,
|
||||
_stage: StageParams,
|
||||
_params: ImageParams,
|
||||
source_image: Image.Image,
|
||||
|
@ -28,6 +20,6 @@ def persist_disk(
|
|||
output: str,
|
||||
**kwargs,
|
||||
) -> Image.Image:
|
||||
dest = save_image(ctx, output, source_image)
|
||||
logger.info('saved image to %s', dest)
|
||||
dest = save_image(server, output, source_image)
|
||||
logger.info("saved image to %s", dest)
|
||||
return source_image
|
||||
|
|
|
@ -1,26 +1,19 @@
|
|||
from boto3 import (
|
||||
Session,
|
||||
)
|
||||
from io import BytesIO
|
||||
from logging import getLogger
|
||||
|
||||
from boto3 import Session
|
||||
from PIL import Image
|
||||
|
||||
from ..device_pool import (
|
||||
JobContext,
|
||||
)
|
||||
from ..params import (
|
||||
ImageParams,
|
||||
StageParams,
|
||||
)
|
||||
from ..utils import (
|
||||
ServerContext,
|
||||
)
|
||||
from ..device_pool import JobContext
|
||||
from ..params import ImageParams, StageParams
|
||||
from ..utils import ServerContext
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
|
||||
def persist_s3(
|
||||
ctx: ServerContext,
|
||||
_job: JobContext,
|
||||
server: ServerContext,
|
||||
_stage: StageParams,
|
||||
_params: ImageParams,
|
||||
source_image: Image.Image,
|
||||
|
@ -32,16 +25,16 @@ def persist_s3(
|
|||
**kwargs,
|
||||
) -> Image.Image:
|
||||
session = Session(profile_name=profile_name)
|
||||
s3 = session.client('s3', endpoint_url=endpoint_url)
|
||||
s3 = session.client("s3", endpoint_url=endpoint_url)
|
||||
|
||||
data = BytesIO()
|
||||
source_image.save(data, format=ctx.image_format)
|
||||
source_image.save(data, format=server.image_format)
|
||||
data.seek(0)
|
||||
|
||||
try:
|
||||
s3.upload_fileobj(data, bucket, output)
|
||||
logger.info('saved image to %s/%s', bucket, output)
|
||||
logger.info("saved image to %s/%s", bucket, output)
|
||||
except Exception as err:
|
||||
logger.error('error saving image to S3: %s', err)
|
||||
logger.error("error saving image to S3: %s", err)
|
||||
|
||||
return source_image
|
||||
|
|
|
@ -1,23 +1,17 @@
|
|||
from logging import getLogger
|
||||
|
||||
from PIL import Image
|
||||
|
||||
from ..device_pool import (
|
||||
JobContext,
|
||||
)
|
||||
from ..params import (
|
||||
ImageParams,
|
||||
Size,
|
||||
StageParams,
|
||||
)
|
||||
from ..utils import (
|
||||
ServerContext,
|
||||
)
|
||||
from ..device_pool import JobContext
|
||||
from ..params import ImageParams, Size, StageParams
|
||||
from ..utils import ServerContext
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
|
||||
def reduce_crop(
|
||||
ctx: ServerContext,
|
||||
_job: JobContext,
|
||||
_server: ServerContext,
|
||||
_stage: StageParams,
|
||||
_params: ImageParams,
|
||||
source_image: Image.Image,
|
||||
|
@ -26,8 +20,6 @@ def reduce_crop(
|
|||
size: Size,
|
||||
**kwargs,
|
||||
) -> Image.Image:
|
||||
image = source_image.crop(
|
||||
(origin.width, origin.height, size.width, size.height))
|
||||
logger.info('created thumbnail with dimensions: %sx%s',
|
||||
image.width, image.height)
|
||||
image = source_image.crop((origin.width, origin.height, size.width, size.height))
|
||||
logger.info("created thumbnail with dimensions: %sx%s", image.width, image.height)
|
||||
return image
|
||||
|
|
|
@ -1,23 +1,17 @@
|
|||
from logging import getLogger
|
||||
|
||||
from PIL import Image
|
||||
|
||||
from ..device_pool import (
|
||||
JobContext,
|
||||
)
|
||||
from ..params import (
|
||||
ImageParams,
|
||||
Size,
|
||||
StageParams,
|
||||
)
|
||||
from ..utils import (
|
||||
ServerContext,
|
||||
)
|
||||
from ..device_pool import JobContext
|
||||
from ..params import ImageParams, Size, StageParams
|
||||
from ..utils import ServerContext
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
|
||||
def reduce_thumbnail(
|
||||
ctx: ServerContext,
|
||||
_job: JobContext,
|
||||
_server: ServerContext,
|
||||
_stage: StageParams,
|
||||
_params: ImageParams,
|
||||
source_image: Image.Image,
|
||||
|
@ -26,6 +20,5 @@ def reduce_thumbnail(
|
|||
**kwargs,
|
||||
) -> Image.Image:
|
||||
image = source_image.thumbnail((size.width, size.height))
|
||||
logger.info('created thumbnail with dimensions: %sx%s',
|
||||
image.width, image.height)
|
||||
logger.info("created thumbnail with dimensions: %sx%s", image.width, image.height)
|
||||
return image
|
||||
|
|
|
@ -1,26 +1,19 @@
|
|||
from logging import getLogger
|
||||
from PIL import Image
|
||||
from typing import Callable
|
||||
|
||||
from ..device_pool import (
|
||||
JobContext,
|
||||
)
|
||||
from ..params import (
|
||||
ImageParams,
|
||||
Size,
|
||||
StageParams,
|
||||
)
|
||||
from ..utils import (
|
||||
ServerContext,
|
||||
)
|
||||
from PIL import Image
|
||||
|
||||
from ..device_pool import JobContext
|
||||
from ..params import ImageParams, Size, StageParams
|
||||
from ..utils import ServerContext
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
|
||||
def source_noise(
|
||||
ctx: ServerContext,
|
||||
stage: StageParams,
|
||||
_job: JobContext,
|
||||
_server: ServerContext,
|
||||
_stage: StageParams,
|
||||
params: ImageParams,
|
||||
source_image: Image.Image,
|
||||
*,
|
||||
|
@ -28,14 +21,12 @@ def source_noise(
|
|||
noise_source: Callable,
|
||||
**kwargs,
|
||||
) -> Image.Image:
|
||||
prompt = prompt or params.prompt
|
||||
logger.info('generating image from noise source')
|
||||
logger.info("generating image from noise source")
|
||||
|
||||
if source_image is not None:
|
||||
logger.warn(
|
||||
'a source image was passed to a noise stage, but will be discarded')
|
||||
logger.warn("a source image was passed to a noise stage, but will be discarded")
|
||||
|
||||
output = noise_source(source_image, (size.width, size.height), (0, 0))
|
||||
|
||||
logger.info('final output image size: %sx%s', output.width, output.height)
|
||||
logger.info("final output image size: %sx%s", output.width, output.height)
|
||||
return output
|
||||
|
|
|
@ -1,26 +1,13 @@
|
|||
from diffusers import (
|
||||
OnnxStableDiffusionPipeline,
|
||||
)
|
||||
from logging import getLogger
|
||||
from PIL import Image
|
||||
|
||||
from ..device_pool import (
|
||||
JobContext,
|
||||
)
|
||||
from ..diffusion.load import (
|
||||
get_latents_from_seed,
|
||||
load_pipeline,
|
||||
)
|
||||
from ..params import (
|
||||
ImageParams,
|
||||
Size,
|
||||
StageParams,
|
||||
)
|
||||
from ..utils import (
|
||||
ServerContext,
|
||||
)
|
||||
|
||||
import numpy as np
|
||||
from diffusers import OnnxStableDiffusionPipeline
|
||||
from PIL import Image
|
||||
|
||||
from ..device_pool import JobContext
|
||||
from ..diffusion.load import get_latents_from_seed, load_pipeline
|
||||
from ..params import ImageParams, Size, StageParams
|
||||
from ..utils import ServerContext
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
|
@ -37,13 +24,16 @@ def source_txt2img(
|
|||
**kwargs,
|
||||
) -> Image.Image:
|
||||
prompt = prompt or params.prompt
|
||||
logger.info('generating image using txt2img, %s steps: %s', params.steps, prompt)
|
||||
logger.info("generating image using txt2img, %s steps: %s", params.steps, prompt)
|
||||
|
||||
if source_image is not None:
|
||||
logger.warn('a source image was passed to a txt2img stage, but will be discarded')
|
||||
logger.warn(
|
||||
"a source image was passed to a txt2img stage, but will be discarded"
|
||||
)
|
||||
|
||||
pipe = load_pipeline(OnnxStableDiffusionPipeline,
|
||||
params.model, params.scheduler, job.get_device())
|
||||
pipe = load_pipeline(
|
||||
OnnxStableDiffusionPipeline, params.model, params.scheduler, job.get_device()
|
||||
)
|
||||
|
||||
latents = get_latents_from_seed(params.seed, size)
|
||||
rng = np.random.RandomState(params.seed)
|
||||
|
@ -60,5 +50,5 @@ def source_txt2img(
|
|||
)
|
||||
output = result.images[0]
|
||||
|
||||
logger.info('final output image size: %sx%s', output.width, output.height)
|
||||
logger.info("final output image size: %sx%s", output.width, output.height)
|
||||
return output
|
||||
|
|
|
@ -1,43 +1,17 @@
|
|||
from diffusers import (
|
||||
OnnxStableDiffusionInpaintPipeline,
|
||||
)
|
||||
from logging import getLogger
|
||||
from PIL import Image, ImageDraw
|
||||
from typing import Callable, Tuple
|
||||
|
||||
from ..device_pool import (
|
||||
JobContext,
|
||||
)
|
||||
from ..diffusion.load import (
|
||||
get_latents_from_seed,
|
||||
get_tile_latents,
|
||||
load_pipeline,
|
||||
)
|
||||
from ..image import (
|
||||
expand_image,
|
||||
mask_filter_none,
|
||||
noise_source_histogram,
|
||||
)
|
||||
from ..params import (
|
||||
Border,
|
||||
ImageParams,
|
||||
Size,
|
||||
SizeChart,
|
||||
StageParams,
|
||||
)
|
||||
from ..output import (
|
||||
save_image,
|
||||
)
|
||||
from ..utils import (
|
||||
base_join,
|
||||
is_debug,
|
||||
ServerContext,
|
||||
)
|
||||
from .utils import (
|
||||
process_tile_spiral,
|
||||
)
|
||||
|
||||
import numpy as np
|
||||
from diffusers import OnnxStableDiffusionInpaintPipeline
|
||||
from PIL import Image, ImageDraw
|
||||
|
||||
from ..device_pool import JobContext
|
||||
from ..diffusion.load import get_latents_from_seed, get_tile_latents, load_pipeline
|
||||
from ..image import expand_image, mask_filter_none, noise_source_histogram
|
||||
from ..output import save_image
|
||||
from ..params import Border, ImageParams, Size, SizeChart, StageParams
|
||||
from ..utils import ServerContext, is_debug
|
||||
from .utils import process_tile_spiral
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
|
@ -52,17 +26,17 @@ def upscale_outpaint(
|
|||
border: Border,
|
||||
prompt: str = None,
|
||||
mask_image: Image.Image = None,
|
||||
fill_color: str = 'white',
|
||||
fill_color: str = "white",
|
||||
mask_filter: Callable = mask_filter_none,
|
||||
noise_source: Callable = noise_source_histogram,
|
||||
**kwargs,
|
||||
) -> Image.Image:
|
||||
prompt = prompt or params.prompt
|
||||
logger.info('upscaling image by expanding borders: %s', border)
|
||||
logger.info("upscaling image by expanding borders: %s", border)
|
||||
|
||||
if mask_image is None:
|
||||
# if no mask was provided, keep the full source image
|
||||
mask_image = Image.new('RGB', source_image.size, 'black')
|
||||
mask_image = Image.new("RGB", source_image.size, "black")
|
||||
|
||||
source_image, mask_image, noise_image, full_dims = expand_image(
|
||||
source_image,
|
||||
|
@ -70,16 +44,17 @@ def upscale_outpaint(
|
|||
border,
|
||||
fill=fill_color,
|
||||
noise_source=noise_source,
|
||||
mask_filter=mask_filter)
|
||||
mask_filter=mask_filter,
|
||||
)
|
||||
|
||||
draw_mask = ImageDraw.Draw(mask_image)
|
||||
full_size = Size(*full_dims)
|
||||
full_latents = get_latents_from_seed(params.seed, full_size)
|
||||
|
||||
if is_debug():
|
||||
save_image(server, 'last-source.png', source_image)
|
||||
save_image(server, 'last-mask.png', mask_image)
|
||||
save_image(server, 'last-noise.png', noise_image)
|
||||
save_image(server, "last-source.png", source_image)
|
||||
save_image(server, "last-mask.png", mask_image)
|
||||
save_image(server, "last-noise.png", noise_image)
|
||||
|
||||
def outpaint(image: Image.Image, dims: Tuple[int, int, int]):
|
||||
left, top, tile = dims
|
||||
|
@ -87,11 +62,15 @@ def upscale_outpaint(
|
|||
mask = mask_image.crop((left, top, left + tile, top + tile))
|
||||
|
||||
if is_debug():
|
||||
save_image(server, 'tile-source.png', image)
|
||||
save_image(server, 'tile-mask.png', mask)
|
||||
save_image(server, "tile-source.png", image)
|
||||
save_image(server, "tile-mask.png", mask)
|
||||
|
||||
pipe = load_pipeline(OnnxStableDiffusionInpaintPipeline,
|
||||
params.model, params.scheduler, job.get_device())
|
||||
pipe = load_pipeline(
|
||||
OnnxStableDiffusionInpaintPipeline,
|
||||
params.model,
|
||||
params.scheduler,
|
||||
job.get_device(),
|
||||
)
|
||||
|
||||
latents = get_tile_latents(full_latents, dims)
|
||||
rng = np.random.RandomState(params.seed)
|
||||
|
@ -110,10 +89,10 @@ def upscale_outpaint(
|
|||
)
|
||||
|
||||
# once part of the image has been drawn, keep it
|
||||
draw_mask.rectangle((left, top, left + tile, top + tile), fill='black')
|
||||
draw_mask.rectangle((left, top, left + tile, top + tile), fill="black")
|
||||
return result.images[0]
|
||||
|
||||
output = process_tile_spiral(source_image, SizeChart.auto, 1, [outpaint])
|
||||
|
||||
logger.info('final output image size: %sx%s', output.width, output.height)
|
||||
logger.info("final output image size: %sx%s", output.width, output.height)
|
||||
return output
|
||||
|
|
|
@ -1,27 +1,15 @@
|
|||
from basicsr.archs.rrdbnet_arch import RRDBNet
|
||||
from logging import getLogger
|
||||
from os import path
|
||||
|
||||
import numpy as np
|
||||
from basicsr.archs.rrdbnet_arch import RRDBNet
|
||||
from PIL import Image
|
||||
from realesrgan import RealESRGANer
|
||||
|
||||
from ..device_pool import (
|
||||
JobContext,
|
||||
)
|
||||
from ..onnx import (
|
||||
OnnxNet,
|
||||
)
|
||||
from ..params import (
|
||||
DeviceParams,
|
||||
ImageParams,
|
||||
StageParams,
|
||||
UpscaleParams,
|
||||
)
|
||||
from ..utils import (
|
||||
run_gc,
|
||||
ServerContext,
|
||||
)
|
||||
|
||||
import numpy as np
|
||||
from ..device_pool import JobContext
|
||||
from ..onnx import OnnxNet
|
||||
from ..params import DeviceParams, ImageParams, StageParams, UpscaleParams
|
||||
from ..utils import ServerContext, run_gc
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
|
@ -29,39 +17,50 @@ last_pipeline_instance = None
|
|||
last_pipeline_params = (None, None)
|
||||
|
||||
|
||||
def load_resrgan(ctx: ServerContext, params: UpscaleParams, device: DeviceParams, tile=0):
|
||||
def load_resrgan(
|
||||
ctx: ServerContext, params: UpscaleParams, device: DeviceParams, tile=0
|
||||
):
|
||||
global last_pipeline_instance
|
||||
global last_pipeline_params
|
||||
|
||||
model_file = '%s.%s' % (params.upscale_model, params.format)
|
||||
model_file = "%s.%s" % (params.upscale_model, params.format)
|
||||
model_path = path.join(ctx.model_path, model_file)
|
||||
if not path.isfile(model_path):
|
||||
raise Exception('Real ESRGAN model not found at %s' % model_path)
|
||||
raise Exception("Real ESRGAN model not found at %s" % model_path)
|
||||
|
||||
cache_params = (model_path, params.format)
|
||||
if last_pipeline_instance != None and cache_params == last_pipeline_params:
|
||||
logger.info('reusing existing Real ESRGAN pipeline')
|
||||
if last_pipeline_instance is not None and cache_params == last_pipeline_params:
|
||||
logger.info("reusing existing Real ESRGAN pipeline")
|
||||
return last_pipeline_instance
|
||||
|
||||
# use ONNX acceleration, if available
|
||||
if params.format == 'onnx':
|
||||
model = OnnxNet(ctx, model_file, provider=device.provider, provider_options=device.options)
|
||||
elif params.format == 'pth':
|
||||
model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64,
|
||||
num_block=23, num_grow_ch=32, scale=params.scale)
|
||||
raise Exception('unknown platform %s' % params.format)
|
||||
if params.format == "onnx":
|
||||
model = OnnxNet(
|
||||
ctx, model_file, provider=device.provider, provider_options=device.options
|
||||
)
|
||||
elif params.format == "pth":
|
||||
model = RRDBNet(
|
||||
num_in_ch=3,
|
||||
num_out_ch=3,
|
||||
num_feat=64,
|
||||
num_block=23,
|
||||
num_grow_ch=32,
|
||||
scale=params.scale,
|
||||
)
|
||||
raise Exception("unknown platform %s" % params.format)
|
||||
|
||||
dni_weight = None
|
||||
if params.upscale_model == 'realesr-general-x4v3' and params.denoise != 1:
|
||||
if params.upscale_model == "realesr-general-x4v3" and params.denoise != 1:
|
||||
wdn_model_path = model_path.replace(
|
||||
'realesr-general-x4v3', 'realesr-general-wdn-x4v3')
|
||||
"realesr-general-x4v3", "realesr-general-wdn-x4v3"
|
||||
)
|
||||
model_path = [model_path, wdn_model_path]
|
||||
dni_weight = [params.denoise, 1 - params.denoise]
|
||||
|
||||
logger.debug('loading Real ESRGAN upscale model from %s', model_path)
|
||||
logger.debug("loading Real ESRGAN upscale model from %s", model_path)
|
||||
|
||||
# TODO: shouldn't need the PTH file
|
||||
model_path_pth = path.join(ctx.model_path, '%s.pth' % params.upscale_model)
|
||||
model_path_pth = path.join(ctx.model_path, "%s.pth" % params.upscale_model)
|
||||
upsampler = RealESRGANer(
|
||||
scale=params.scale,
|
||||
model_path=model_path_pth,
|
||||
|
@ -70,7 +69,8 @@ def load_resrgan(ctx: ServerContext, params: UpscaleParams, device: DeviceParams
|
|||
tile=tile,
|
||||
tile_pad=params.tile_pad,
|
||||
pre_pad=params.pre_pad,
|
||||
half=params.half)
|
||||
half=params.half,
|
||||
)
|
||||
|
||||
last_pipeline_instance = upsampler
|
||||
last_pipeline_params = cache_params
|
||||
|
@ -89,13 +89,13 @@ def upscale_resrgan(
|
|||
upscale: UpscaleParams,
|
||||
**kwargs,
|
||||
) -> Image.Image:
|
||||
logger.info('upscaling image with Real ESRGAN: x%s', upscale.scale)
|
||||
logger.info("upscaling image with Real ESRGAN: x%s", upscale.scale)
|
||||
|
||||
output = np.array(source_image)
|
||||
upsampler = load_resrgan(server, upscale, job.get_device(), tile=stage.tile_size)
|
||||
|
||||
output, _ = upsampler.enhance(output, outscale=upscale.outscale)
|
||||
|
||||
output = Image.fromarray(output, 'RGB')
|
||||
logger.info('final output image size: %sx%s', output.width, output.height)
|
||||
output = Image.fromarray(output, "RGB")
|
||||
logger.info("final output image size: %sx%s", output.width, output.height)
|
||||
return output
|
||||
|
|
|
@ -1,28 +1,16 @@
|
|||
from diffusers import (
|
||||
StableDiffusionUpscalePipeline,
|
||||
)
|
||||
from logging import getLogger
|
||||
from os import path
|
||||
|
||||
import torch
|
||||
from diffusers import StableDiffusionUpscalePipeline
|
||||
from PIL import Image
|
||||
|
||||
from ..device_pool import (
|
||||
JobContext,
|
||||
)
|
||||
from ..device_pool import JobContext
|
||||
from ..diffusion.pipeline_onnx_stable_diffusion_upscale import (
|
||||
OnnxStableDiffusionUpscalePipeline,
|
||||
)
|
||||
from ..params import (
|
||||
DeviceParams,
|
||||
ImageParams,
|
||||
StageParams,
|
||||
UpscaleParams,
|
||||
)
|
||||
from ..utils import (
|
||||
run_gc,
|
||||
ServerContext,
|
||||
)
|
||||
|
||||
import torch
|
||||
from ..params import DeviceParams, ImageParams, StageParams, UpscaleParams
|
||||
from ..utils import ServerContext, run_gc
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
|
@ -31,23 +19,37 @@ last_pipeline_instance = None
|
|||
last_pipeline_params = (None, None)
|
||||
|
||||
|
||||
def load_stable_diffusion(ctx: ServerContext, upscale: UpscaleParams, device: DeviceParams):
|
||||
def load_stable_diffusion(
|
||||
ctx: ServerContext, upscale: UpscaleParams, device: DeviceParams
|
||||
):
|
||||
global last_pipeline_instance
|
||||
global last_pipeline_params
|
||||
|
||||
model_path = path.join(ctx.model_path, upscale.upscale_model)
|
||||
cache_params = (model_path, upscale.format)
|
||||
|
||||
if last_pipeline_instance != None and cache_params == last_pipeline_params:
|
||||
logger.info('reusing existing Stable Diffusion upscale pipeline')
|
||||
if last_pipeline_instance is not None and cache_params == last_pipeline_params:
|
||||
logger.info("reusing existing Stable Diffusion upscale pipeline")
|
||||
return last_pipeline_instance
|
||||
|
||||
if upscale.format == 'onnx':
|
||||
logger.debug('loading Stable Diffusion upscale ONNX model from %s, using provider %s', model_path, device.provider)
|
||||
pipeline = OnnxStableDiffusionUpscalePipeline.from_pretrained(model_path, provider=device.provider, provider_options=device.options)
|
||||
if upscale.format == "onnx":
|
||||
logger.debug(
|
||||
"loading Stable Diffusion upscale ONNX model from %s, using provider %s",
|
||||
model_path,
|
||||
device.provider,
|
||||
)
|
||||
pipeline = OnnxStableDiffusionUpscalePipeline.from_pretrained(
|
||||
model_path, provider=device.provider, provider_options=device.options
|
||||
)
|
||||
else:
|
||||
logger.debug('loading Stable Diffusion upscale model from %s, using provider %s', model_path, device.provider)
|
||||
pipeline = StableDiffusionUpscalePipeline.from_pretrained(model_path, provider=device.provider)
|
||||
logger.debug(
|
||||
"loading Stable Diffusion upscale model from %s, using provider %s",
|
||||
model_path,
|
||||
device.provider,
|
||||
)
|
||||
pipeline = StableDiffusionUpscalePipeline.from_pretrained(
|
||||
model_path, provider=device.provider
|
||||
)
|
||||
|
||||
last_pipeline_instance = pipeline
|
||||
last_pipeline_params = cache_params
|
||||
|
@ -68,7 +70,7 @@ def upscale_stable_diffusion(
|
|||
**kwargs,
|
||||
) -> Image.Image:
|
||||
prompt = prompt or params.prompt
|
||||
logger.info('upscaling with Stable Diffusion, %s steps: %s', params.steps, prompt)
|
||||
logger.info("upscaling with Stable Diffusion, %s steps: %s", params.steps, prompt)
|
||||
|
||||
pipeline = load_stable_diffusion(server, upscale, job.get_device())
|
||||
generator = torch.manual_seed(params.seed)
|
||||
|
|
|
@ -1,7 +1,8 @@
|
|||
from logging import getLogger
|
||||
from PIL import Image
|
||||
from typing import List, Protocol, Tuple
|
||||
|
||||
from PIL import Image
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
|
||||
|
@ -17,7 +18,7 @@ def process_tile_grid(
|
|||
filters: List[TileCallback],
|
||||
) -> Image.Image:
|
||||
width, height = source.size
|
||||
image = Image.new('RGB', (width * scale, height * scale))
|
||||
image = Image.new("RGB", (width * scale, height * scale))
|
||||
|
||||
tiles_x = width // tile
|
||||
tiles_y = height // tile
|
||||
|
@ -28,7 +29,7 @@ def process_tile_grid(
|
|||
idx = (y * tiles_x) + x
|
||||
left = x * tile
|
||||
top = y * tile
|
||||
logger.info('processing tile %s of %s, %s.%s', idx + 1, total, y, x)
|
||||
logger.info("processing tile %s of %s, %s.%s", idx + 1, total, y, x)
|
||||
tile_image = source.crop((left, top, left + tile, top + tile))
|
||||
|
||||
for filter in filters:
|
||||
|
@ -47,10 +48,10 @@ def process_tile_spiral(
|
|||
overlap: float = 0.5,
|
||||
) -> Image.Image:
|
||||
if scale != 1:
|
||||
raise Exception('unsupported scale')
|
||||
raise Exception("unsupported scale")
|
||||
|
||||
width, height = source.size
|
||||
image = Image.new('RGB', (width * scale, height * scale))
|
||||
image = Image.new("RGB", (width * scale, height * scale))
|
||||
image.paste(source, (0, 0, width, height))
|
||||
|
||||
center_x = (width // 2) - (tile // 2)
|
||||
|
@ -76,7 +77,7 @@ def process_tile_spiral(
|
|||
top = center_y + int(top)
|
||||
|
||||
counter += 1
|
||||
logger.info('processing tile %s of %s, %sx%s', counter, len(tiles), left, top)
|
||||
logger.info("processing tile %s of %s, %sx%s", counter, len(tiles), left, top)
|
||||
|
||||
# TODO: only valid for scale == 1, resize source for others
|
||||
tile_image = image.crop((left, top, left + tile, top + tile))
|
||||
|
|
|
@ -1,5 +1,14 @@
|
|||
from . import logging
|
||||
import warnings
|
||||
from argparse import ArgumentParser
|
||||
from json import loads
|
||||
from logging import getLogger
|
||||
from os import environ, makedirs, mkdir, path
|
||||
from pathlib import Path
|
||||
from shutil import copyfile, rmtree
|
||||
from sys import exit
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
from basicsr.archs.rrdbnet_arch import RRDBNet
|
||||
from basicsr.utils.download_util import load_file_from_url
|
||||
from diffusers import (
|
||||
|
@ -8,25 +17,20 @@ from diffusers import (
|
|||
StableDiffusionPipeline,
|
||||
StableDiffusionUpscalePipeline,
|
||||
)
|
||||
from json import loads
|
||||
from logging import getLogger
|
||||
from onnx import load, save_model
|
||||
from os import environ, makedirs, mkdir, path
|
||||
from pathlib import Path
|
||||
from shutil import copyfile, rmtree
|
||||
from sys import exit
|
||||
from torch.onnx import export
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
import warnings
|
||||
from . import logging
|
||||
|
||||
# suppress common but harmless warnings, https://github.com/ssube/onnx-web/issues/75
|
||||
warnings.filterwarnings(
|
||||
'ignore', '.*The shape inference of prim::Constant type is missing.*')
|
||||
warnings.filterwarnings('ignore', '.*Only steps=1 can be constant folded.*')
|
||||
"ignore", ".*The shape inference of prim::Constant type is missing.*"
|
||||
)
|
||||
warnings.filterwarnings("ignore", ".*Only steps=1 can be constant folded.*")
|
||||
warnings.filterwarnings(
|
||||
'ignore', '.*Converting a tensor to a Python boolean might cause the trace to be incorrect.*')
|
||||
"ignore",
|
||||
".*Converting a tensor to a Python boolean might cause the trace to be incorrect.*",
|
||||
)
|
||||
|
||||
Models = Dict[str, List[Tuple[str, str, Optional[int]]]]
|
||||
|
||||
|
@ -35,74 +39,95 @@ logger = getLogger(__name__)
|
|||
|
||||
# recommended models
|
||||
base_models: Models = {
|
||||
'diffusion': [
|
||||
"diffusion": [
|
||||
# v1.x
|
||||
('stable-diffusion-onnx-v1-5', 'runwayml/stable-diffusion-v1-5'),
|
||||
('stable-diffusion-onnx-v1-inpainting',
|
||||
'runwayml/stable-diffusion-inpainting'),
|
||||
("stable-diffusion-onnx-v1-5", "runwayml/stable-diffusion-v1-5"),
|
||||
("stable-diffusion-onnx-v1-inpainting", "runwayml/stable-diffusion-inpainting"),
|
||||
# v2.x
|
||||
('stable-diffusion-onnx-v2-1', 'stabilityai/stable-diffusion-2-1'),
|
||||
('stable-diffusion-onnx-v2-inpainting',
|
||||
'stabilityai/stable-diffusion-2-inpainting'),
|
||||
("stable-diffusion-onnx-v2-1", "stabilityai/stable-diffusion-2-1"),
|
||||
(
|
||||
"stable-diffusion-onnx-v2-inpainting",
|
||||
"stabilityai/stable-diffusion-2-inpainting",
|
||||
),
|
||||
# TODO: should have its own converter
|
||||
('upscaling-stable-diffusion-x4', 'stabilityai/stable-diffusion-x4-upscaler'),
|
||||
("upscaling-stable-diffusion-x4", "stabilityai/stable-diffusion-x4-upscaler"),
|
||||
],
|
||||
'correction': [
|
||||
('correction-gfpgan-v1-3',
|
||||
'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth', 4),
|
||||
"correction": [
|
||||
(
|
||||
"correction-gfpgan-v1-3",
|
||||
"https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth",
|
||||
4,
|
||||
),
|
||||
],
|
||||
'upscaling': [
|
||||
('upscaling-real-esrgan-x2-plus',
|
||||
'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth', 2),
|
||||
('upscaling-real-esrgan-x4-plus',
|
||||
'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth', 4),
|
||||
('upscaling-real-esrgan-x4-v3',
|
||||
'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth', 4),
|
||||
"upscaling": [
|
||||
(
|
||||
"upscaling-real-esrgan-x2-plus",
|
||||
"https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth",
|
||||
2,
|
||||
),
|
||||
(
|
||||
"upscaling-real-esrgan-x4-plus",
|
||||
"https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth",
|
||||
4,
|
||||
),
|
||||
(
|
||||
"upscaling-real-esrgan-x4-v3",
|
||||
"https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth",
|
||||
4,
|
||||
),
|
||||
],
|
||||
}
|
||||
|
||||
model_path = environ.get('ONNX_WEB_MODEL_PATH',
|
||||
path.join('..', 'models'))
|
||||
training_device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
||||
model_path = environ.get("ONNX_WEB_MODEL_PATH", path.join("..", "models"))
|
||||
training_device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
map_location = torch.device(training_device)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def convert_real_esrgan(name: str, url: str, scale: int, opset: int):
|
||||
dest_path = path.join(model_path, name + '.pth')
|
||||
dest_onnx = path.join(model_path, name + '.onnx')
|
||||
logger.info('converting Real ESRGAN model: %s -> %s', name, dest_onnx)
|
||||
dest_path = path.join(model_path, name + ".pth")
|
||||
dest_onnx = path.join(model_path, name + ".onnx")
|
||||
logger.info("converting Real ESRGAN model: %s -> %s", name, dest_onnx)
|
||||
|
||||
if path.isfile(dest_onnx):
|
||||
logger.info('ONNX model already exists, skipping.')
|
||||
logger.info("ONNX model already exists, skipping.")
|
||||
return
|
||||
|
||||
if not path.isfile(dest_path):
|
||||
logger.info('PTH model not found, downloading...')
|
||||
logger.info("PTH model not found, downloading...")
|
||||
download_path = load_file_from_url(
|
||||
url=url, model_dir=dest_path + '-cache', progress=True, file_name=None)
|
||||
url=url, model_dir=dest_path + "-cache", progress=True, file_name=None
|
||||
)
|
||||
copyfile(download_path, dest_path)
|
||||
|
||||
logger.info('loading and training model')
|
||||
model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64,
|
||||
num_block=23, num_grow_ch=32, scale=scale)
|
||||
logger.info("loading and training model")
|
||||
model = RRDBNet(
|
||||
num_in_ch=3,
|
||||
num_out_ch=3,
|
||||
num_feat=64,
|
||||
num_block=23,
|
||||
num_grow_ch=32,
|
||||
scale=scale,
|
||||
)
|
||||
|
||||
torch_model = torch.load(dest_path, map_location=map_location)
|
||||
if 'params_ema' in torch_model:
|
||||
model.load_state_dict(torch_model['params_ema'])
|
||||
if "params_ema" in torch_model:
|
||||
model.load_state_dict(torch_model["params_ema"])
|
||||
else:
|
||||
model.load_state_dict(torch_model['params'], strict=False)
|
||||
model.load_state_dict(torch_model["params"], strict=False)
|
||||
|
||||
model.to(training_device).train(False)
|
||||
model.eval()
|
||||
|
||||
rng = torch.rand(1, 3, 64, 64, device=map_location)
|
||||
input_names = ['data']
|
||||
output_names = ['output']
|
||||
dynamic_axes = {'data': {2: 'width', 3: 'height'},
|
||||
'output': {2: 'width', 3: 'height'}}
|
||||
input_names = ["data"]
|
||||
output_names = ["output"]
|
||||
dynamic_axes = {
|
||||
"data": {2: "width", 3: "height"},
|
||||
"output": {2: "width", 3: "height"},
|
||||
}
|
||||
|
||||
logger.info('exporting ONNX model to %s', dest_onnx)
|
||||
logger.info("exporting ONNX model to %s", dest_onnx)
|
||||
export(
|
||||
model,
|
||||
rng,
|
||||
|
@ -111,48 +136,57 @@ def convert_real_esrgan(name: str, url: str, scale: int, opset: int):
|
|||
output_names=output_names,
|
||||
dynamic_axes=dynamic_axes,
|
||||
opset_version=opset,
|
||||
export_params=True
|
||||
export_params=True,
|
||||
)
|
||||
logger.info('Real ESRGAN exported to ONNX successfully.')
|
||||
logger.info("Real ESRGAN exported to ONNX successfully.")
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def convert_gfpgan(name: str, url: str, scale: int, opset: int):
|
||||
dest_path = path.join(model_path, name + '.pth')
|
||||
dest_onnx = path.join(model_path, name + '.onnx')
|
||||
logger.info('converting GFPGAN model: %s -> %s', name, dest_onnx)
|
||||
dest_path = path.join(model_path, name + ".pth")
|
||||
dest_onnx = path.join(model_path, name + ".onnx")
|
||||
logger.info("converting GFPGAN model: %s -> %s", name, dest_onnx)
|
||||
|
||||
if path.isfile(dest_onnx):
|
||||
logger.info('ONNX model already exists, skipping.')
|
||||
logger.info("ONNX model already exists, skipping.")
|
||||
return
|
||||
|
||||
if not path.isfile(dest_path):
|
||||
logger.info('PTH model not found, downloading...')
|
||||
logger.info("PTH model not found, downloading...")
|
||||
download_path = load_file_from_url(
|
||||
url=url, model_dir=dest_path + '-cache', progress=True, file_name=None)
|
||||
url=url, model_dir=dest_path + "-cache", progress=True, file_name=None
|
||||
)
|
||||
copyfile(download_path, dest_path)
|
||||
|
||||
logger.info('loading and training model')
|
||||
model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64,
|
||||
num_block=23, num_grow_ch=32, scale=scale)
|
||||
logger.info("loading and training model")
|
||||
model = RRDBNet(
|
||||
num_in_ch=3,
|
||||
num_out_ch=3,
|
||||
num_feat=64,
|
||||
num_block=23,
|
||||
num_grow_ch=32,
|
||||
scale=scale,
|
||||
)
|
||||
|
||||
torch_model = torch.load(dest_path, map_location=map_location)
|
||||
# TODO: make sure strict=False is safe here
|
||||
if 'params_ema' in torch_model:
|
||||
model.load_state_dict(torch_model['params_ema'], strict=False)
|
||||
if "params_ema" in torch_model:
|
||||
model.load_state_dict(torch_model["params_ema"], strict=False)
|
||||
else:
|
||||
model.load_state_dict(torch_model['params'], strict=False)
|
||||
model.load_state_dict(torch_model["params"], strict=False)
|
||||
|
||||
model.to(training_device).train(False)
|
||||
model.eval()
|
||||
|
||||
rng = torch.rand(1, 3, 64, 64, device=map_location)
|
||||
input_names = ['data']
|
||||
output_names = ['output']
|
||||
dynamic_axes = {'data': {2: 'width', 3: 'height'},
|
||||
'output': {2: 'width', 3: 'height'}}
|
||||
input_names = ["data"]
|
||||
output_names = ["output"]
|
||||
dynamic_axes = {
|
||||
"data": {2: "width", 3: "height"},
|
||||
"output": {2: "width", 3: "height"},
|
||||
}
|
||||
|
||||
logger.info('exporting ONNX model to %s', dest_onnx)
|
||||
logger.info("exporting ONNX model to %s", dest_onnx)
|
||||
export(
|
||||
model,
|
||||
rng,
|
||||
|
@ -161,9 +195,9 @@ def convert_gfpgan(name: str, url: str, scale: int, opset: int):
|
|||
output_names=output_names,
|
||||
dynamic_axes=dynamic_axes,
|
||||
opset_version=opset,
|
||||
export_params=True
|
||||
export_params=True,
|
||||
)
|
||||
logger.info('GFPGAN exported to ONNX successfully.')
|
||||
logger.info("GFPGAN exported to ONNX successfully.")
|
||||
|
||||
|
||||
def onnx_export(
|
||||
|
@ -176,9 +210,9 @@ def onnx_export(
|
|||
opset,
|
||||
use_external_data_format=False,
|
||||
):
|
||||
'''
|
||||
"""
|
||||
From https://github.com/huggingface/diffusers/blob/main/scripts/convert_stable_diffusion_checkpoint_to_onnx.py
|
||||
'''
|
||||
"""
|
||||
|
||||
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
export(
|
||||
|
@ -194,29 +228,33 @@ def onnx_export(
|
|||
|
||||
|
||||
@torch.no_grad()
|
||||
def convert_diffuser(name: str, url: str, opset: int, half: bool, token: str, single_vae: bool = False):
|
||||
'''
|
||||
def convert_diffuser(
|
||||
name: str, url: str, opset: int, half: bool, token: str, single_vae: bool = False
|
||||
):
|
||||
"""
|
||||
From https://github.com/huggingface/diffusers/blob/main/scripts/convert_stable_diffusion_checkpoint_to_onnx.py
|
||||
'''
|
||||
"""
|
||||
dtype = torch.float16 if half else torch.float32
|
||||
dest_path = path.join(model_path, name)
|
||||
|
||||
# diffusers go into a directory rather than .onnx file
|
||||
logger.info('converting Diffusers model: %s -> %s/', name, dest_path)
|
||||
logger.info("converting Diffusers model: %s -> %s/", name, dest_path)
|
||||
|
||||
if single_vae:
|
||||
logger.info('converting model with single VAE')
|
||||
logger.info("converting model with single VAE")
|
||||
|
||||
if path.isdir(dest_path):
|
||||
logger.info('ONNX model already exists, skipping.')
|
||||
logger.info("ONNX model already exists, skipping.")
|
||||
return
|
||||
|
||||
if half and training_device != 'cuda':
|
||||
if half and training_device != "cuda":
|
||||
raise ValueError(
|
||||
'Half precision model export is only supported on GPUs with CUDA')
|
||||
"Half precision model export is only supported on GPUs with CUDA"
|
||||
)
|
||||
|
||||
pipeline = StableDiffusionPipeline.from_pretrained(
|
||||
url, torch_dtype=dtype, use_auth_token=token).to(training_device)
|
||||
url, torch_dtype=dtype, use_auth_token=token
|
||||
).to(training_device)
|
||||
output_path = Path(dest_path)
|
||||
|
||||
# TEXT ENCODER
|
||||
|
@ -232,8 +270,7 @@ def convert_diffuser(name: str, url: str, opset: int, half: bool, token: str, si
|
|||
onnx_export(
|
||||
pipeline.text_encoder,
|
||||
# casting to torch.int32 until the CLIP fix is released: https://github.com/huggingface/transformers/pull/18515/files
|
||||
model_args=(text_input.input_ids.to(
|
||||
device=training_device, dtype=torch.int32)),
|
||||
model_args=(text_input.input_ids.to(device=training_device, dtype=torch.int32)),
|
||||
output_path=output_path / "text_encoder" / "model.onnx",
|
||||
ordered_input_names=["input_ids"],
|
||||
output_names=["last_hidden_state", "pooler_output"],
|
||||
|
@ -244,7 +281,7 @@ def convert_diffuser(name: str, url: str, opset: int, half: bool, token: str, si
|
|||
)
|
||||
del pipeline.text_encoder
|
||||
|
||||
logger.debug('UNET config: %s', pipeline.unet.config)
|
||||
logger.debug("UNET config: %s", pipeline.unet.config)
|
||||
|
||||
# UNET
|
||||
if single_vae:
|
||||
|
@ -262,10 +299,12 @@ def convert_diffuser(name: str, url: str, opset: int, half: bool, token: str, si
|
|||
pipeline.unet,
|
||||
model_args=(
|
||||
torch.randn(2, unet_in_channels, unet_sample_size, unet_sample_size).to(
|
||||
device=training_device, dtype=dtype),
|
||||
device=training_device, dtype=dtype
|
||||
),
|
||||
torch.randn(2).to(device=training_device, dtype=dtype),
|
||||
torch.randn(2, num_tokens, text_hidden_size).to(
|
||||
device=training_device, dtype=dtype),
|
||||
device=training_device, dtype=dtype
|
||||
),
|
||||
unet_scale,
|
||||
),
|
||||
output_path=unet_path,
|
||||
|
@ -298,7 +337,7 @@ def convert_diffuser(name: str, url: str, opset: int, half: bool, token: str, si
|
|||
del pipeline.unet
|
||||
|
||||
if single_vae:
|
||||
logger.debug('VAE config: %s', pipeline.vae.config)
|
||||
logger.debug("VAE config: %s", pipeline.vae.config)
|
||||
|
||||
# SINGLE VAE
|
||||
vae_only = pipeline.vae
|
||||
|
@ -309,8 +348,9 @@ def convert_diffuser(name: str, url: str, opset: int, half: bool, token: str, si
|
|||
onnx_export(
|
||||
vae_only,
|
||||
model_args=(
|
||||
torch.randn(1, vae_latent_channels, unet_sample_size, unet_sample_size).to(
|
||||
device=training_device, dtype=dtype),
|
||||
torch.randn(
|
||||
1, vae_latent_channels, unet_sample_size, unet_sample_size
|
||||
).to(device=training_device, dtype=dtype),
|
||||
False,
|
||||
),
|
||||
output_path=output_path / "vae" / "model.onnx",
|
||||
|
@ -328,12 +368,14 @@ def convert_diffuser(name: str, url: str, opset: int, half: bool, token: str, si
|
|||
vae_sample_size = vae_encoder.config.sample_size
|
||||
# need to get the raw tensor output (sample) from the encoder
|
||||
vae_encoder.forward = lambda sample, return_dict: vae_encoder.encode(
|
||||
sample, return_dict)[0].sample()
|
||||
sample, return_dict
|
||||
)[0].sample()
|
||||
onnx_export(
|
||||
vae_encoder,
|
||||
model_args=(
|
||||
torch.randn(1, vae_in_channels, vae_sample_size, vae_sample_size).to(
|
||||
device=training_device, dtype=dtype),
|
||||
device=training_device, dtype=dtype
|
||||
),
|
||||
False,
|
||||
),
|
||||
output_path=output_path / "vae_encoder" / "model.onnx",
|
||||
|
@ -354,8 +396,9 @@ def convert_diffuser(name: str, url: str, opset: int, half: bool, token: str, si
|
|||
onnx_export(
|
||||
vae_decoder,
|
||||
model_args=(
|
||||
torch.randn(1, vae_latent_channels, unet_sample_size, unet_sample_size).to(
|
||||
device=training_device, dtype=dtype),
|
||||
torch.randn(
|
||||
1, vae_latent_channels, unet_sample_size, unet_sample_size
|
||||
).to(device=training_device, dtype=dtype),
|
||||
False,
|
||||
),
|
||||
output_path=output_path / "vae_decoder" / "model.onnx",
|
||||
|
@ -385,7 +428,8 @@ def convert_diffuser(name: str, url: str, opset: int, half: bool, token: str, si
|
|||
clip_image_size,
|
||||
).to(device=training_device, dtype=dtype),
|
||||
torch.randn(1, vae_sample_size, vae_sample_size, vae_out_channels).to(
|
||||
device=training_device, dtype=dtype),
|
||||
device=training_device, dtype=dtype
|
||||
),
|
||||
),
|
||||
output_path=output_path / "safety_checker" / "model.onnx",
|
||||
ordered_input_names=["clip_input", "images"],
|
||||
|
@ -398,7 +442,8 @@ def convert_diffuser(name: str, url: str, opset: int, half: bool, token: str, si
|
|||
)
|
||||
del pipeline.safety_checker
|
||||
safety_checker = OnnxRuntimeModel.from_pretrained(
|
||||
output_path / "safety_checker")
|
||||
output_path / "safety_checker"
|
||||
)
|
||||
feature_extractor = pipeline.feature_extractor
|
||||
else:
|
||||
safety_checker = None
|
||||
|
@ -406,10 +451,8 @@ def convert_diffuser(name: str, url: str, opset: int, half: bool, token: str, si
|
|||
|
||||
if single_vae:
|
||||
onnx_pipeline = StableDiffusionUpscalePipeline(
|
||||
vae=OnnxRuntimeModel.from_pretrained(
|
||||
output_path / "vae"),
|
||||
text_encoder=OnnxRuntimeModel.from_pretrained(
|
||||
output_path / "text_encoder"),
|
||||
vae=OnnxRuntimeModel.from_pretrained(output_path / "vae"),
|
||||
text_encoder=OnnxRuntimeModel.from_pretrained(output_path / "text_encoder"),
|
||||
tokenizer=pipeline.tokenizer,
|
||||
low_res_scheduler=pipeline.scheduler,
|
||||
unet=OnnxRuntimeModel.from_pretrained(output_path / "unet"),
|
||||
|
@ -417,12 +460,9 @@ def convert_diffuser(name: str, url: str, opset: int, half: bool, token: str, si
|
|||
)
|
||||
else:
|
||||
onnx_pipeline = OnnxStableDiffusionPipeline(
|
||||
vae_encoder=OnnxRuntimeModel.from_pretrained(
|
||||
output_path / "vae_encoder"),
|
||||
vae_decoder=OnnxRuntimeModel.from_pretrained(
|
||||
output_path / "vae_decoder"),
|
||||
text_encoder=OnnxRuntimeModel.from_pretrained(
|
||||
output_path / "text_encoder"),
|
||||
vae_encoder=OnnxRuntimeModel.from_pretrained(output_path / "vae_encoder"),
|
||||
vae_decoder=OnnxRuntimeModel.from_pretrained(output_path / "vae_decoder"),
|
||||
text_encoder=OnnxRuntimeModel.from_pretrained(output_path / "text_encoder"),
|
||||
tokenizer=pipeline.tokenizer,
|
||||
unet=OnnxRuntimeModel.from_pretrained(output_path / "unet"),
|
||||
scheduler=pipeline.scheduler,
|
||||
|
@ -431,7 +471,7 @@ def convert_diffuser(name: str, url: str, opset: int, half: bool, token: str, si
|
|||
requires_safety_checker=safety_checker is not None,
|
||||
)
|
||||
|
||||
logger.info('exporting ONNX model')
|
||||
logger.info("exporting ONNX model")
|
||||
|
||||
onnx_pipeline.save_pretrained(output_path)
|
||||
logger.info("ONNX pipeline saved to %s", output_path)
|
||||
|
@ -445,90 +485,93 @@ def convert_diffuser(name: str, url: str, opset: int, half: bool, token: str, si
|
|||
)
|
||||
else:
|
||||
_ = OnnxStableDiffusionPipeline.from_pretrained(
|
||||
output_path, provider="CPUExecutionProvider")
|
||||
output_path, provider="CPUExecutionProvider"
|
||||
)
|
||||
|
||||
logger.info("ONNX pipeline is loadable")
|
||||
|
||||
|
||||
def load_models(args, models: Models):
|
||||
if args.diffusion:
|
||||
for source in models.get('diffusion'):
|
||||
for source in models.get("diffusion"):
|
||||
if source[0] in args.skip:
|
||||
logger.info('Skipping model: %s', source[0])
|
||||
logger.info("Skipping model: %s", source[0])
|
||||
else:
|
||||
single_vae = 'upscaling' in source[0]
|
||||
convert_diffuser(*source, args.opset, args.half, args.token, single_vae=single_vae)
|
||||
single_vae = "upscaling" in source[0]
|
||||
convert_diffuser(
|
||||
*source, args.opset, args.half, args.token, single_vae=single_vae
|
||||
)
|
||||
|
||||
if args.upscaling:
|
||||
for source in models.get('upscaling'):
|
||||
for source in models.get("upscaling"):
|
||||
if source[0] in args.skip:
|
||||
logger.info('Skipping model: %s', source[0])
|
||||
logger.info("Skipping model: %s", source[0])
|
||||
else:
|
||||
convert_real_esrgan(*source, args.opset)
|
||||
|
||||
if args.correction:
|
||||
for source in models.get('correction'):
|
||||
for source in models.get("correction"):
|
||||
if source[0] in args.skip:
|
||||
logger.info('Skipping model: %s', source[0])
|
||||
logger.info("Skipping model: %s", source[0])
|
||||
else:
|
||||
convert_gfpgan(*source, args.opset)
|
||||
|
||||
|
||||
def main() -> int:
|
||||
parser = ArgumentParser(
|
||||
prog='onnx-web model converter',
|
||||
description='convert checkpoint models to ONNX')
|
||||
prog="onnx-web model converter", description="convert checkpoint models to ONNX"
|
||||
)
|
||||
|
||||
# model groups
|
||||
parser.add_argument('--correction', action='store_true', default=False)
|
||||
parser.add_argument('--diffusion', action='store_true', default=False)
|
||||
parser.add_argument('--upscaling', action='store_true', default=False)
|
||||
parser.add_argument("--correction", action="store_true", default=False)
|
||||
parser.add_argument("--diffusion", action="store_true", default=False)
|
||||
parser.add_argument("--upscaling", action="store_true", default=False)
|
||||
|
||||
# extra models
|
||||
parser.add_argument('--extras', nargs='*', type=str, default=[])
|
||||
parser.add_argument('--skip', nargs='*', type=str, default=[])
|
||||
parser.add_argument("--extras", nargs="*", type=str, default=[])
|
||||
parser.add_argument("--skip", nargs="*", type=str, default=[])
|
||||
|
||||
# export options
|
||||
parser.add_argument(
|
||||
'--half',
|
||||
action='store_true',
|
||||
"--half",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help='Export models for half precision, faster on some Nvidia cards.'
|
||||
help="Export models for half precision, faster on some Nvidia cards.",
|
||||
)
|
||||
parser.add_argument(
|
||||
'--opset',
|
||||
"--opset",
|
||||
default=14,
|
||||
type=int,
|
||||
help="The version of the ONNX operator set to use.",
|
||||
)
|
||||
parser.add_argument(
|
||||
'--token',
|
||||
"--token",
|
||||
type=str,
|
||||
help="HuggingFace token with read permissions for downloading models.",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
logger.info('CLI arguments: %s', args)
|
||||
logger.info("CLI arguments: %s", args)
|
||||
|
||||
if not path.exists(model_path):
|
||||
logger.info('Model path does not existing, creating: %s', model_path)
|
||||
logger.info("Model path does not existing, creating: %s", model_path)
|
||||
makedirs(model_path)
|
||||
|
||||
logger.info('Converting base models.')
|
||||
logger.info("Converting base models.")
|
||||
load_models(args, base_models)
|
||||
|
||||
for file in args.extras:
|
||||
logger.info('Loading extra models from %s', file)
|
||||
logger.info("Loading extra models from %s", file)
|
||||
try:
|
||||
with open(file, 'r') as f:
|
||||
with open(file, "r") as f:
|
||||
data = loads(f.read())
|
||||
logger.info('Converting extra models.')
|
||||
logger.info("Converting extra models.")
|
||||
load_models(args, data)
|
||||
except Exception as err:
|
||||
logger.error('Error converting extra models: %s', err)
|
||||
logger.error("Error converting extra models: %s", err)
|
||||
|
||||
return 0
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
exit(main())
|
||||
|
|
|
@ -1,12 +1,10 @@
|
|||
from collections import Counter
|
||||
from concurrent.futures import Future, ThreadPoolExecutor, ProcessPoolExecutor
|
||||
from concurrent.futures import Future, ProcessPoolExecutor, ThreadPoolExecutor
|
||||
from logging import getLogger
|
||||
from multiprocessing import Value
|
||||
from typing import Any, Callable, List, Optional, Tuple, Union
|
||||
|
||||
from .params import (
|
||||
DeviceParams,
|
||||
)
|
||||
from .params import DeviceParams
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
|
@ -28,24 +26,24 @@ class JobContext:
|
|||
):
|
||||
self.key = key
|
||||
self.devices = list(devices)
|
||||
self.cancel = Value('B', cancel)
|
||||
self.device_index = Value('i', device_index)
|
||||
self.progress = Value('I', progress)
|
||||
self.cancel = Value("B", cancel)
|
||||
self.device_index = Value("i", device_index)
|
||||
self.progress = Value("I", progress)
|
||||
|
||||
def is_cancelled(self) -> bool:
|
||||
return self.cancel.value
|
||||
|
||||
def get_device(self) -> DeviceParams:
|
||||
'''
|
||||
"""
|
||||
Get the device assigned to this job.
|
||||
'''
|
||||
"""
|
||||
with self.device_index.get_lock():
|
||||
device_index = self.device_index.value
|
||||
if device_index < 0:
|
||||
raise Exception('job has not been assigned to a device')
|
||||
raise Exception("job has not been assigned to a device")
|
||||
else:
|
||||
device = self.devices[device_index]
|
||||
logger.debug('job %s assigned to device %s', self.key, device)
|
||||
logger.debug("job %s assigned to device %s", self.key, device)
|
||||
return device
|
||||
|
||||
def get_progress(self) -> int:
|
||||
|
@ -54,10 +52,9 @@ class JobContext:
|
|||
def get_progress_callback(self) -> Callable[..., None]:
|
||||
def on_progress(step: int, timestep: int, latents: Any):
|
||||
if self.is_cancelled():
|
||||
raise Exception('job has been cancelled')
|
||||
raise Exception("job has been cancelled")
|
||||
else:
|
||||
logger.debug('setting progress for job %s to %s',
|
||||
self.key, step)
|
||||
logger.debug("setting progress for job %s to %s", self.key, step)
|
||||
self.set_progress(step)
|
||||
|
||||
return on_progress
|
||||
|
@ -72,9 +69,9 @@ class JobContext:
|
|||
|
||||
|
||||
class Job:
|
||||
'''
|
||||
"""
|
||||
Link a future to its context.
|
||||
'''
|
||||
"""
|
||||
|
||||
context: JobContext = None
|
||||
future: Future = None
|
||||
|
@ -106,7 +103,11 @@ class DevicePoolExecutor:
|
|||
next_device: int = 0
|
||||
pool: Union[ProcessPoolExecutor, ThreadPoolExecutor] = None
|
||||
|
||||
def __init__(self, devices: List[DeviceParams], pool: Optional[Union[ProcessPoolExecutor, ThreadPoolExecutor]] = None):
|
||||
def __init__(
|
||||
self,
|
||||
devices: List[DeviceParams],
|
||||
pool: Optional[Union[ProcessPoolExecutor, ThreadPoolExecutor]] = None,
|
||||
):
|
||||
self.devices = devices
|
||||
self.jobs = []
|
||||
self.next_device = 0
|
||||
|
@ -114,19 +115,25 @@ class DevicePoolExecutor:
|
|||
device_count = len(devices)
|
||||
if pool is None:
|
||||
logger.info(
|
||||
'creating thread pool executor for %s devices: %s', device_count, [d.device for d in devices])
|
||||
"creating thread pool executor for %s devices: %s",
|
||||
device_count,
|
||||
[d.device for d in devices],
|
||||
)
|
||||
self.pool = ThreadPoolExecutor(device_count)
|
||||
else:
|
||||
logger.info('using existing pool for %s devices: %s',
|
||||
device_count, [d.device for d in devices])
|
||||
logger.info(
|
||||
"using existing pool for %s devices: %s",
|
||||
device_count,
|
||||
[d.device for d in devices],
|
||||
)
|
||||
self.pool = pool
|
||||
|
||||
def cancel(self, key: str) -> bool:
|
||||
'''
|
||||
"""
|
||||
Cancel a job. If the job has not been started, this will cancel
|
||||
the future and never execute it. If the job has been started, it
|
||||
should be cancelled on the next progress callback.
|
||||
'''
|
||||
"""
|
||||
for job in self.jobs:
|
||||
if job.key == key:
|
||||
if job.future.cancel():
|
||||
|
@ -144,7 +151,7 @@ class DevicePoolExecutor:
|
|||
progress = job.get_progress()
|
||||
return (done, progress)
|
||||
|
||||
logger.warn('checking status for unknown key: %s', key)
|
||||
logger.warn("checking status for unknown key: %s", key)
|
||||
return (None, 0)
|
||||
|
||||
def get_next_device(self):
|
||||
|
@ -152,12 +159,14 @@ class DevicePoolExecutor:
|
|||
if len(self.jobs) == 0:
|
||||
return 0
|
||||
|
||||
job_devices = [job.context.device_index.value for job in self.jobs if not job.future.done()]
|
||||
job_devices = [
|
||||
job.context.device_index.value for job in self.jobs if not job.future.done()
|
||||
]
|
||||
job_counts = Counter(range(len(self.devices)))
|
||||
job_counts.update(job_devices)
|
||||
|
||||
queued = job_counts.most_common()
|
||||
logger.debug('jobs queued by device: %s', queued)
|
||||
logger.debug("jobs queued by device: %s", queued)
|
||||
|
||||
lowest_count = queued[-1][1]
|
||||
lowest_devices = [d[0] for d in queued if d[1] == lowest_count]
|
||||
|
@ -170,7 +179,7 @@ class DevicePoolExecutor:
|
|||
|
||||
def submit(self, key: str, fn: Callable[..., None], /, *args, **kwargs) -> None:
|
||||
device = self.get_next_device()
|
||||
logger.info('assigning job %s to device %s', key, device)
|
||||
logger.info("assigning job %s to device %s", key, device)
|
||||
|
||||
context = JobContext(key, self.devices, device_index=device)
|
||||
future = self.pool.submit(fn, context, *args, **kwargs)
|
||||
|
@ -180,11 +189,19 @@ class DevicePoolExecutor:
|
|||
def job_done(f: Future):
|
||||
try:
|
||||
f.result()
|
||||
logger.info('job %s finished successfully', key)
|
||||
logger.info("job %s finished successfully", key)
|
||||
except Exception as err:
|
||||
logger.warn('job %s failed with an error: %s', key, err)
|
||||
logger.warn("job %s failed with an error: %s", key, err)
|
||||
|
||||
future.add_done_callback(job_done)
|
||||
|
||||
def status(self) -> List[Tuple[str, int, bool, int]]:
|
||||
return [(job.key, job.context.device_index.value, job.future.done(), job.get_progress()) for job in self.jobs]
|
||||
return [
|
||||
(
|
||||
job.key,
|
||||
job.context.device_index.value,
|
||||
job.future.done(),
|
||||
job.get_progress(),
|
||||
)
|
||||
for job in self.jobs
|
||||
]
|
||||
|
|
|
@ -1,18 +1,11 @@
|
|||
from diffusers import (
|
||||
DiffusionPipeline,
|
||||
)
|
||||
from logging import getLogger
|
||||
from typing import Any, Optional, Tuple
|
||||
|
||||
from ..params import (
|
||||
DeviceParams,
|
||||
Size,
|
||||
)
|
||||
from ..utils import (
|
||||
run_gc,
|
||||
)
|
||||
from typing import Any, Tuple
|
||||
|
||||
import numpy as np
|
||||
from diffusers import DiffusionPipeline
|
||||
|
||||
from ..params import DeviceParams, Size
|
||||
from ..utils import run_gc
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
|
@ -25,17 +18,23 @@ latent_factor = 8
|
|||
|
||||
|
||||
def get_latents_from_seed(seed: int, size: Size, batch: int = 1) -> np.ndarray:
|
||||
'''
|
||||
"""
|
||||
From https://www.travelneil.com/stable-diffusion-updates.html
|
||||
'''
|
||||
latents_shape = (batch, latent_channels, size.height // latent_factor,
|
||||
size.width // latent_factor)
|
||||
"""
|
||||
latents_shape = (
|
||||
batch,
|
||||
latent_channels,
|
||||
size.height // latent_factor,
|
||||
size.width // latent_factor,
|
||||
)
|
||||
rng = np.random.default_rng(seed)
|
||||
image_latents = rng.standard_normal(latents_shape).astype(np.float32)
|
||||
return image_latents
|
||||
|
||||
|
||||
def get_tile_latents(full_latents: np.ndarray, dims: Tuple[int, int, int]) -> np.ndarray:
|
||||
def get_tile_latents(
|
||||
full_latents: np.ndarray, dims: Tuple[int, int, int]
|
||||
) -> np.ndarray:
|
||||
x, y, tile = dims
|
||||
t = tile // latent_factor
|
||||
x = x // latent_factor
|
||||
|
@ -46,27 +45,29 @@ def get_tile_latents(full_latents: np.ndarray, dims: Tuple[int, int, int]) -> np
|
|||
return full_latents[:, :, y:yt, x:xt]
|
||||
|
||||
|
||||
def load_pipeline(pipeline: DiffusionPipeline, model: str, scheduler: Any, device: DeviceParams):
|
||||
def load_pipeline(
|
||||
pipeline: DiffusionPipeline, model: str, scheduler: Any, device: DeviceParams
|
||||
):
|
||||
global last_pipeline_instance
|
||||
global last_pipeline_scheduler
|
||||
global last_pipeline_options
|
||||
|
||||
options = (pipeline, model, device.provider)
|
||||
if last_pipeline_instance != None and last_pipeline_options == options:
|
||||
logger.debug('reusing existing diffusion pipeline')
|
||||
if last_pipeline_instance is not None and last_pipeline_options == options:
|
||||
logger.debug("reusing existing diffusion pipeline")
|
||||
pipe = last_pipeline_instance
|
||||
else:
|
||||
logger.debug('unloading previous diffusion pipeline')
|
||||
logger.debug("unloading previous diffusion pipeline")
|
||||
last_pipeline_instance = None
|
||||
last_pipeline_scheduler = None
|
||||
run_gc()
|
||||
|
||||
logger.debug('loading new diffusion pipeline from %s', model)
|
||||
logger.debug("loading new diffusion pipeline from %s", model)
|
||||
scheduler = scheduler.from_pretrained(
|
||||
model,
|
||||
provider=device.provider,
|
||||
provider_options=device.options,
|
||||
subfolder='scheduler',
|
||||
subfolder="scheduler",
|
||||
)
|
||||
pipe = pipeline.from_pretrained(
|
||||
model,
|
||||
|
@ -76,7 +77,7 @@ def load_pipeline(pipeline: DiffusionPipeline, model: str, scheduler: Any, devic
|
|||
scheduler=scheduler,
|
||||
)
|
||||
|
||||
if device is not None and hasattr(pipe, 'to'):
|
||||
if device is not None and hasattr(pipe, "to"):
|
||||
pipe = pipe.to(device)
|
||||
|
||||
last_pipeline_instance = pipe
|
||||
|
@ -84,15 +85,15 @@ def load_pipeline(pipeline: DiffusionPipeline, model: str, scheduler: Any, devic
|
|||
last_pipeline_scheduler = scheduler
|
||||
|
||||
if last_pipeline_scheduler != scheduler:
|
||||
logger.debug('loading new diffusion scheduler')
|
||||
logger.debug("loading new diffusion scheduler")
|
||||
scheduler = scheduler.from_pretrained(
|
||||
model,
|
||||
provider=device.provider,
|
||||
provider_options=device.options,
|
||||
subfolder='scheduler',
|
||||
subfolder="scheduler",
|
||||
)
|
||||
|
||||
if device is not None and hasattr(scheduler, 'to'):
|
||||
if device is not None and hasattr(scheduler, "to"):
|
||||
scheduler = scheduler.to(device)
|
||||
|
||||
pipe.scheduler = scheduler
|
||||
|
|
|
@ -1,23 +1,11 @@
|
|||
from diffusers import (
|
||||
DDPMScheduler,
|
||||
OnnxRuntimeModel,
|
||||
StableDiffusionUpscalePipeline,
|
||||
)
|
||||
from diffusers.pipeline_utils import (
|
||||
ImagePipelineOutput,
|
||||
)
|
||||
from logging import getLogger
|
||||
from PIL import Image
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
List,
|
||||
Optional,
|
||||
Union,
|
||||
)
|
||||
from typing import Any, Callable, List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from diffusers import DDPMScheduler, OnnxRuntimeModel, StableDiffusionUpscalePipeline
|
||||
from diffusers.pipeline_utils import ImagePipelineOutput
|
||||
from PIL import Image
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
|
@ -30,6 +18,7 @@ unet_in_channels = 7 # self.unet.config.in_channels
|
|||
# https://github.com/huggingface/diffusers/blob/v0.11.1/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py
|
||||
###
|
||||
|
||||
|
||||
def preprocess(image):
|
||||
if isinstance(image, torch.Tensor):
|
||||
return image
|
||||
|
@ -63,8 +52,15 @@ class OnnxStableDiffusionUpscalePipeline(StableDiffusionUpscalePipeline):
|
|||
scheduler: Any,
|
||||
max_noise_level: int = 350,
|
||||
):
|
||||
super().__init__(vae, text_encoder, tokenizer, unet,
|
||||
low_res_scheduler, scheduler, max_noise_level)
|
||||
super().__init__(
|
||||
vae,
|
||||
text_encoder,
|
||||
tokenizer,
|
||||
unet,
|
||||
low_res_scheduler,
|
||||
scheduler,
|
||||
max_noise_level,
|
||||
)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
|
@ -96,7 +92,11 @@ class OnnxStableDiffusionUpscalePipeline(StableDiffusionUpscalePipeline):
|
|||
|
||||
# 3. Encode input prompt
|
||||
text_embeddings = self._encode_prompt(
|
||||
prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt
|
||||
prompt,
|
||||
device,
|
||||
num_images_per_prompt,
|
||||
do_classifier_free_guidance,
|
||||
negative_prompt,
|
||||
)
|
||||
|
||||
# 4. Preprocess image
|
||||
|
@ -111,7 +111,9 @@ class OnnxStableDiffusionUpscalePipeline(StableDiffusionUpscalePipeline):
|
|||
text_embeddings_dtype = torch.float32
|
||||
|
||||
noise_level = torch.tensor([noise_level], dtype=torch.long, device=device)
|
||||
noise = torch.randn(image.shape, generator=generator, device=device, dtype=text_embeddings_dtype)
|
||||
noise = torch.randn(
|
||||
image.shape, generator=generator, device=device, dtype=text_embeddings_dtype
|
||||
)
|
||||
image = self.low_res_scheduler.add_noise(image, noise, noise_level)
|
||||
|
||||
batch_multiplier = 2 if do_classifier_free_guidance else 1
|
||||
|
@ -135,11 +137,11 @@ class OnnxStableDiffusionUpscalePipeline(StableDiffusionUpscalePipeline):
|
|||
num_channels_image = image.shape[1]
|
||||
if num_channels_latents + num_channels_image != unet_in_channels:
|
||||
raise ValueError(
|
||||
f"Incorrect configuration settings! The config of `pipeline.unet` expects"
|
||||
f" {unet_in_channels} but received `num_channels_latents`: {num_channels_latents} +"
|
||||
f" `num_channels_image`: {num_channels_image} "
|
||||
f" = {num_channels_latents+num_channels_image}. Please verify the config of"
|
||||
" `pipeline.unet` or your `image` input."
|
||||
"Incorrect configuration settings! The config of `pipeline.unet`"
|
||||
f" expects {unet_in_channels} but received `num_channels_latents`:"
|
||||
f" {num_channels_latents} + `num_channels_image`: {num_channels_image} "
|
||||
f" = {num_channels_latents+num_channels_image}. Please verify the"
|
||||
" config of `pipeline.unet` or your `image` input."
|
||||
)
|
||||
|
||||
# 8. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
|
||||
|
@ -150,10 +152,16 @@ class OnnxStableDiffusionUpscalePipeline(StableDiffusionUpscalePipeline):
|
|||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||
for i, t in enumerate(timesteps):
|
||||
# expand the latents if we are doing classifier free guidance
|
||||
latent_model_input = np.concatenate([latents] * 2) if do_classifier_free_guidance else latents
|
||||
latent_model_input = (
|
||||
np.concatenate([latents] * 2)
|
||||
if do_classifier_free_guidance
|
||||
else latents
|
||||
)
|
||||
|
||||
# concat latents, mask, masked_image_latents in the channel dimension
|
||||
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
||||
latent_model_input = self.scheduler.scale_model_input(
|
||||
latent_model_input, t
|
||||
)
|
||||
latent_model_input = np.concatenate([latent_model_input, image], axis=1)
|
||||
|
||||
# timestep to tensor
|
||||
|
@ -164,19 +172,25 @@ class OnnxStableDiffusionUpscalePipeline(StableDiffusionUpscalePipeline):
|
|||
sample=latent_model_input,
|
||||
timestep=timestep,
|
||||
encoder_hidden_states=text_embeddings,
|
||||
class_labels=noise_level
|
||||
class_labels=noise_level,
|
||||
)[0]
|
||||
|
||||
# perform guidance
|
||||
if do_classifier_free_guidance:
|
||||
noise_pred_uncond, noise_pred_text = np.split(noise_pred, 2)
|
||||
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
|
||||
noise_pred = noise_pred_uncond + guidance_scale * (
|
||||
noise_pred_text - noise_pred_uncond
|
||||
)
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
|
||||
latents = self.scheduler.step(
|
||||
noise_pred, t, latents, **extra_step_kwargs
|
||||
).prev_sample
|
||||
|
||||
# call the callback, if provided
|
||||
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
||||
if i == len(timesteps) - 1 or (
|
||||
(i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
|
||||
):
|
||||
progress_bar.update()
|
||||
if callback is not None and i % callback_steps == 0:
|
||||
callback(i, t, latents)
|
||||
|
@ -200,7 +214,14 @@ class OnnxStableDiffusionUpscalePipeline(StableDiffusionUpscalePipeline):
|
|||
image = image.transpose((0, 2, 3, 1))
|
||||
return image
|
||||
|
||||
def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt):
|
||||
def _encode_prompt(
|
||||
self,
|
||||
prompt,
|
||||
device,
|
||||
num_images_per_prompt,
|
||||
do_classifier_free_guidance,
|
||||
negative_prompt,
|
||||
):
|
||||
batch_size = len(prompt) if isinstance(prompt, list) else 1
|
||||
|
||||
text_inputs = self.tokenizer(
|
||||
|
@ -211,13 +232,20 @@ class OnnxStableDiffusionUpscalePipeline(StableDiffusionUpscalePipeline):
|
|||
return_tensors="pt",
|
||||
)
|
||||
text_input_ids = text_inputs.input_ids
|
||||
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
|
||||
untruncated_ids = self.tokenizer(
|
||||
prompt, padding="longest", return_tensors="pt"
|
||||
).input_ids
|
||||
|
||||
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
|
||||
removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])
|
||||
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
|
||||
text_input_ids, untruncated_ids
|
||||
):
|
||||
removed_text = self.tokenizer.batch_decode(
|
||||
untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
|
||||
)
|
||||
logger.warning(
|
||||
"The following part of your input was truncated because CLIP can only handle sequences up to"
|
||||
f" {self.tokenizer.model_max_length} tokens: {removed_text}"
|
||||
"The following part of your input was truncated because CLIP can only"
|
||||
f" handle sequences up to {self.tokenizer.model_max_length} tokens:"
|
||||
f" {removed_text}"
|
||||
)
|
||||
|
||||
# no positional arguments to text_encoder
|
||||
|
@ -240,16 +268,17 @@ class OnnxStableDiffusionUpscalePipeline(StableDiffusionUpscalePipeline):
|
|||
uncond_tokens = [""] * batch_size
|
||||
elif type(prompt) is not type(negative_prompt):
|
||||
raise TypeError(
|
||||
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
|
||||
f" {type(prompt)}."
|
||||
"`negative_prompt` should be the same type to `prompt`, but got"
|
||||
f" {type(negative_prompt)} != {type(prompt)}."
|
||||
)
|
||||
elif isinstance(negative_prompt, str):
|
||||
uncond_tokens = [negative_prompt]
|
||||
elif batch_size != len(negative_prompt):
|
||||
raise ValueError(
|
||||
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
|
||||
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
|
||||
" the batch size of `prompt`."
|
||||
f"`negative_prompt`: {negative_prompt} has batch size"
|
||||
f" {len(negative_prompt)}, but `prompt`: {prompt} has batch size"
|
||||
f" {batch_size}. Please make sure that passed `negative_prompt`"
|
||||
" matches the batch size of `prompt`."
|
||||
)
|
||||
else:
|
||||
uncond_tokens = negative_prompt
|
||||
|
|
|
@ -1,41 +1,17 @@
|
|||
from diffusers import (
|
||||
OnnxStableDiffusionPipeline,
|
||||
OnnxStableDiffusionImg2ImgPipeline,
|
||||
)
|
||||
from logging import getLogger
|
||||
from PIL import Image, ImageChops
|
||||
from typing import Any
|
||||
|
||||
from ..chain import (
|
||||
upscale_outpaint,
|
||||
)
|
||||
from ..device_pool import (
|
||||
JobContext,
|
||||
)
|
||||
from ..params import (
|
||||
ImageParams,
|
||||
Border,
|
||||
Size,
|
||||
StageParams,
|
||||
)
|
||||
from ..output import (
|
||||
save_image,
|
||||
save_params,
|
||||
)
|
||||
from ..upscale import (
|
||||
run_upscale_correction,
|
||||
UpscaleParams,
|
||||
)
|
||||
from ..utils import (
|
||||
run_gc,
|
||||
ServerContext,
|
||||
)
|
||||
from .load import (
|
||||
get_latents_from_seed,
|
||||
load_pipeline,
|
||||
)
|
||||
|
||||
import numpy as np
|
||||
from diffusers import OnnxStableDiffusionImg2ImgPipeline, OnnxStableDiffusionPipeline
|
||||
from PIL import Image, ImageChops
|
||||
|
||||
from ..chain import upscale_outpaint
|
||||
from ..device_pool import JobContext
|
||||
from ..output import save_image, save_params
|
||||
from ..params import Border, ImageParams, Size, StageParams
|
||||
from ..upscale import UpscaleParams, run_upscale_correction
|
||||
from ..utils import ServerContext, run_gc
|
||||
from .load import get_latents_from_seed, load_pipeline
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
|
@ -46,10 +22,11 @@ def run_txt2img_pipeline(
|
|||
params: ImageParams,
|
||||
size: Size,
|
||||
output: str,
|
||||
upscale: UpscaleParams
|
||||
upscale: UpscaleParams,
|
||||
) -> None:
|
||||
pipe = load_pipeline(OnnxStableDiffusionPipeline,
|
||||
params.model, params.scheduler, job.get_device())
|
||||
pipe = load_pipeline(
|
||||
OnnxStableDiffusionPipeline, params.model, params.scheduler, job.get_device()
|
||||
)
|
||||
|
||||
latents = get_latents_from_seed(params.seed, size)
|
||||
rng = np.random.RandomState(params.seed)
|
||||
|
@ -68,7 +45,8 @@ def run_txt2img_pipeline(
|
|||
)
|
||||
image = result.images[0]
|
||||
image = run_upscale_correction(
|
||||
job, server, StageParams(), params, image, upscale=upscale)
|
||||
job, server, StageParams(), params, image, upscale=upscale
|
||||
)
|
||||
|
||||
dest = save_image(server, output, image)
|
||||
save_params(server, output, params, size, upscale=upscale)
|
||||
|
@ -77,7 +55,7 @@ def run_txt2img_pipeline(
|
|||
del result
|
||||
run_gc()
|
||||
|
||||
logger.info('finished txt2img job: %s', dest)
|
||||
logger.info("finished txt2img job: %s", dest)
|
||||
|
||||
|
||||
def run_img2img_pipeline(
|
||||
|
@ -89,8 +67,12 @@ def run_img2img_pipeline(
|
|||
source_image: Image.Image,
|
||||
strength: float,
|
||||
) -> None:
|
||||
pipe = load_pipeline(OnnxStableDiffusionImg2ImgPipeline,
|
||||
params.model, params.scheduler, job.get_device())
|
||||
pipe = load_pipeline(
|
||||
OnnxStableDiffusionImg2ImgPipeline,
|
||||
params.model,
|
||||
params.scheduler,
|
||||
job.get_device(),
|
||||
)
|
||||
|
||||
rng = np.random.RandomState(params.seed)
|
||||
|
||||
|
@ -107,7 +89,8 @@ def run_img2img_pipeline(
|
|||
)
|
||||
image = result.images[0]
|
||||
image = run_upscale_correction(
|
||||
job, server, StageParams(), params, image, upscale=upscale)
|
||||
job, server, StageParams(), params, image, upscale=upscale
|
||||
)
|
||||
|
||||
dest = save_image(server, output, image)
|
||||
size = Size(*source_image.size)
|
||||
|
@ -117,7 +100,7 @@ def run_img2img_pipeline(
|
|||
del result
|
||||
run_gc()
|
||||
|
||||
logger.info('finished img2img job: %s', dest)
|
||||
logger.info("finished img2img job: %s", dest)
|
||||
|
||||
|
||||
def run_inpaint_pipeline(
|
||||
|
@ -151,16 +134,14 @@ def run_inpaint_pipeline(
|
|||
mask_filter=mask_filter,
|
||||
noise_source=noise_source,
|
||||
)
|
||||
logger.info('applying mask filter and generating noise source')
|
||||
logger.info("applying mask filter and generating noise source")
|
||||
|
||||
if image.size == source_image.size:
|
||||
image = ImageChops.blend(source_image, image, strength)
|
||||
else:
|
||||
logger.info(
|
||||
'output image size does not match source, skipping post-blend')
|
||||
logger.info("output image size does not match source, skipping post-blend")
|
||||
|
||||
image = run_upscale_correction(
|
||||
job, server, stage, params, image, upscale=upscale)
|
||||
image = run_upscale_correction(job, server, stage, params, image, upscale=upscale)
|
||||
|
||||
dest = save_image(server, output, image)
|
||||
save_params(server, output, params, size, upscale=upscale, border=border)
|
||||
|
@ -168,7 +149,7 @@ def run_inpaint_pipeline(
|
|||
del image
|
||||
run_gc()
|
||||
|
||||
logger.info('finished inpaint job: %s', dest)
|
||||
logger.info("finished inpaint job: %s", dest)
|
||||
|
||||
|
||||
def run_upscale_pipeline(
|
||||
|
@ -185,7 +166,8 @@ def run_upscale_pipeline(
|
|||
stage = StageParams()
|
||||
|
||||
image = run_upscale_correction(
|
||||
job, server, stage, params, source_image, upscale=upscale)
|
||||
job, server, stage, params, source_image, upscale=upscale
|
||||
)
|
||||
|
||||
dest = save_image(server, output, image)
|
||||
save_params(server, output, params, size, upscale=upscale)
|
||||
|
@ -193,4 +175,4 @@ def run_upscale_pipeline(
|
|||
del image
|
||||
run_gc()
|
||||
|
||||
logger.info('finished upscale job: %s', dest)
|
||||
logger.info("finished upscale job: %s", dest)
|
||||
|
|
|
@ -1,31 +1,31 @@
|
|||
import numpy as np
|
||||
from numpy import random
|
||||
from PIL import Image, ImageChops, ImageFilter
|
||||
|
||||
import numpy as np
|
||||
|
||||
from .params import (
|
||||
Border,
|
||||
Point,
|
||||
)
|
||||
from .params import Border, Point
|
||||
|
||||
|
||||
def get_pixel_index(x: int, y: int, width: int) -> int:
|
||||
return (y * width) + x
|
||||
|
||||
|
||||
def mask_filter_none(mask_image: Image.Image, dims: Point, origin: Point, fill='white', **kw) -> Image.Image:
|
||||
def mask_filter_none(
|
||||
mask_image: Image.Image, dims: Point, origin: Point, fill="white", **kw
|
||||
) -> Image.Image:
|
||||
width, height = dims
|
||||
|
||||
noise = Image.new('RGB', (width, height), fill)
|
||||
noise = Image.new("RGB", (width, height), fill)
|
||||
noise.paste(mask_image, origin)
|
||||
|
||||
return noise
|
||||
|
||||
|
||||
def mask_filter_gaussian_multiply(mask_image: Image.Image, dims: Point, origin: Point, rounds=3, **kw) -> Image.Image:
|
||||
'''
|
||||
def mask_filter_gaussian_multiply(
|
||||
mask_image: Image.Image, dims: Point, origin: Point, rounds=3, **kw
|
||||
) -> Image.Image:
|
||||
"""
|
||||
Gaussian blur with multiply, source image centered on white canvas.
|
||||
'''
|
||||
"""
|
||||
noise = mask_filter_none(mask_image, dims, origin)
|
||||
|
||||
for i in range(rounds):
|
||||
|
@ -35,10 +35,12 @@ def mask_filter_gaussian_multiply(mask_image: Image.Image, dims: Point, origin:
|
|||
return noise
|
||||
|
||||
|
||||
def mask_filter_gaussian_screen(mask_image: Image.Image, dims: Point, origin: Point, rounds=3, **kw) -> Image.Image:
|
||||
'''
|
||||
def mask_filter_gaussian_screen(
|
||||
mask_image: Image.Image, dims: Point, origin: Point, rounds=3, **kw
|
||||
) -> Image.Image:
|
||||
"""
|
||||
Gaussian blur, source image centered on white canvas.
|
||||
'''
|
||||
"""
|
||||
noise = mask_filter_none(mask_image, dims, origin)
|
||||
|
||||
for i in range(rounds):
|
||||
|
@ -48,33 +50,39 @@ def mask_filter_gaussian_screen(mask_image: Image.Image, dims: Point, origin: Po
|
|||
return noise
|
||||
|
||||
|
||||
def noise_source_fill_edge(source_image: Image.Image, dims: Point, origin: Point, fill='white', **kw) -> Image.Image:
|
||||
'''
|
||||
def noise_source_fill_edge(
|
||||
source_image: Image.Image, dims: Point, origin: Point, fill="white", **kw
|
||||
) -> Image.Image:
|
||||
"""
|
||||
Identity transform, source image centered on white canvas.
|
||||
'''
|
||||
"""
|
||||
width, height = dims
|
||||
|
||||
noise = Image.new('RGB', (width, height), fill)
|
||||
noise = Image.new("RGB", (width, height), fill)
|
||||
noise.paste(source_image, origin)
|
||||
|
||||
return noise
|
||||
|
||||
|
||||
def noise_source_fill_mask(source_image: Image.Image, dims: Point, origin: Point, fill='white', **kw) -> Image.Image:
|
||||
'''
|
||||
def noise_source_fill_mask(
|
||||
source_image: Image.Image, dims: Point, origin: Point, fill="white", **kw
|
||||
) -> Image.Image:
|
||||
"""
|
||||
Fill the whole canvas, no source or noise.
|
||||
'''
|
||||
"""
|
||||
width, height = dims
|
||||
|
||||
noise = Image.new('RGB', (width, height), fill)
|
||||
noise = Image.new("RGB", (width, height), fill)
|
||||
|
||||
return noise
|
||||
|
||||
|
||||
def noise_source_gaussian(source_image: Image.Image, dims: Point, origin: Point, rounds=3, **kw) -> Image.Image:
|
||||
'''
|
||||
def noise_source_gaussian(
|
||||
source_image: Image.Image, dims: Point, origin: Point, rounds=3, **kw
|
||||
) -> Image.Image:
|
||||
"""
|
||||
Gaussian blur, source image centered on white canvas.
|
||||
'''
|
||||
"""
|
||||
noise = noise_source_uniform(source_image, dims, origin)
|
||||
noise.paste(source_image, origin)
|
||||
|
||||
|
@ -84,7 +92,9 @@ def noise_source_gaussian(source_image: Image.Image, dims: Point, origin: Point,
|
|||
return noise
|
||||
|
||||
|
||||
def noise_source_uniform(source_image: Image.Image, dims: Point, origin: Point, **kw) -> Image.Image:
|
||||
def noise_source_uniform(
|
||||
source_image: Image.Image, dims: Point, origin: Point, **kw
|
||||
) -> Image.Image:
|
||||
width, height = dims
|
||||
size = width * height
|
||||
|
||||
|
@ -92,21 +102,19 @@ def noise_source_uniform(source_image: Image.Image, dims: Point, origin: Point,
|
|||
noise_g = random.uniform(0, 256, size=size)
|
||||
noise_b = random.uniform(0, 256, size=size)
|
||||
|
||||
noise = Image.new('RGB', (width, height))
|
||||
noise = Image.new("RGB", (width, height))
|
||||
|
||||
for x in range(width):
|
||||
for y in range(height):
|
||||
i = get_pixel_index(x, y, width)
|
||||
noise.putpixel((x, y), (
|
||||
int(noise_r[i]),
|
||||
int(noise_g[i]),
|
||||
int(noise_b[i])
|
||||
))
|
||||
noise.putpixel((x, y), (int(noise_r[i]), int(noise_g[i]), int(noise_b[i])))
|
||||
|
||||
return noise
|
||||
|
||||
|
||||
def noise_source_normal(source_image: Image.Image, dims: Point, origin: Point, **kw) -> Image.Image:
|
||||
def noise_source_normal(
|
||||
source_image: Image.Image, dims: Point, origin: Point, **kw
|
||||
) -> Image.Image:
|
||||
width, height = dims
|
||||
size = width * height
|
||||
|
||||
|
@ -114,21 +122,19 @@ def noise_source_normal(source_image: Image.Image, dims: Point, origin: Point, *
|
|||
noise_g = random.normal(128, 32, size=size)
|
||||
noise_b = random.normal(128, 32, size=size)
|
||||
|
||||
noise = Image.new('RGB', (width, height))
|
||||
noise = Image.new("RGB", (width, height))
|
||||
|
||||
for x in range(width):
|
||||
for y in range(height):
|
||||
i = get_pixel_index(x, y, width)
|
||||
noise.putpixel((x, y), (
|
||||
int(noise_r[i]),
|
||||
int(noise_g[i]),
|
||||
int(noise_b[i])
|
||||
))
|
||||
noise.putpixel((x, y), (int(noise_r[i]), int(noise_g[i]), int(noise_b[i])))
|
||||
|
||||
return noise
|
||||
|
||||
|
||||
def noise_source_histogram(source_image: Image.Image, dims: Point, origin: Point, **kw) -> Image.Image:
|
||||
def noise_source_histogram(
|
||||
source_image: Image.Image, dims: Point, origin: Point, **kw
|
||||
) -> Image.Image:
|
||||
r, g, b = source_image.split()
|
||||
width, height = dims
|
||||
size = width * height
|
||||
|
@ -137,23 +143,22 @@ def noise_source_histogram(source_image: Image.Image, dims: Point, origin: Point
|
|||
hist_g = g.histogram()
|
||||
hist_b = b.histogram()
|
||||
|
||||
noise_r = random.choice(256, p=np.divide(
|
||||
np.copy(hist_r), np.sum(hist_r)), size=size)
|
||||
noise_g = random.choice(256, p=np.divide(
|
||||
np.copy(hist_g), np.sum(hist_g)), size=size)
|
||||
noise_b = random.choice(256, p=np.divide(
|
||||
np.copy(hist_b), np.sum(hist_b)), size=size)
|
||||
noise_r = random.choice(
|
||||
256, p=np.divide(np.copy(hist_r), np.sum(hist_r)), size=size
|
||||
)
|
||||
noise_g = random.choice(
|
||||
256, p=np.divide(np.copy(hist_g), np.sum(hist_g)), size=size
|
||||
)
|
||||
noise_b = random.choice(
|
||||
256, p=np.divide(np.copy(hist_b), np.sum(hist_b)), size=size
|
||||
)
|
||||
|
||||
noise = Image.new('RGB', (width, height))
|
||||
noise = Image.new("RGB", (width, height))
|
||||
|
||||
for x in range(width):
|
||||
for y in range(height):
|
||||
i = get_pixel_index(x, y, width)
|
||||
noise.putpixel((x, y), (
|
||||
noise_r[i],
|
||||
noise_g[i],
|
||||
noise_b[i]
|
||||
))
|
||||
noise.putpixel((x, y), (noise_r[i], noise_g[i], noise_b[i]))
|
||||
|
||||
return noise
|
||||
|
||||
|
@ -163,7 +168,7 @@ def expand_image(
|
|||
source_image: Image.Image,
|
||||
mask_image: Image.Image,
|
||||
expand: Border,
|
||||
fill='white',
|
||||
fill="white",
|
||||
noise_source=noise_source_histogram,
|
||||
mask_filter=mask_filter_none,
|
||||
):
|
||||
|
@ -173,14 +178,13 @@ def expand_image(
|
|||
dims = (full_width, full_height)
|
||||
origin = (expand.left, expand.top)
|
||||
|
||||
full_source = Image.new('RGB', dims, fill)
|
||||
full_source = Image.new("RGB", dims, fill)
|
||||
full_source.paste(source_image, origin)
|
||||
|
||||
full_mask = mask_filter(mask_image, dims, origin, fill=fill)
|
||||
full_noise = noise_source(source_image, dims, origin, fill=fill)
|
||||
full_noise = ImageChops.multiply(full_noise, full_mask)
|
||||
|
||||
full_source = Image.composite(
|
||||
full_noise, full_source, full_mask.convert('L'))
|
||||
full_source = Image.composite(full_noise, full_source, full_mask.convert("L"))
|
||||
|
||||
return (full_source, full_mask, full_noise, (full_width, full_height))
|
||||
|
|
|
@ -1,14 +1,15 @@
|
|||
from logging.config import dictConfig
|
||||
from os import environ, path
|
||||
|
||||
from yaml import safe_load
|
||||
|
||||
logging_path = environ.get('ONNX_WEB_LOGGING_PATH', './logging.yaml')
|
||||
logging_path = environ.get("ONNX_WEB_LOGGING_PATH", "./logging.yaml")
|
||||
|
||||
# setup logging config before anything else loads
|
||||
try:
|
||||
if path.exists(logging_path):
|
||||
with open(logging_path, 'r') as f:
|
||||
with open(logging_path, "r") as f:
|
||||
config_logging = safe_load(f)
|
||||
dictConfig(config_logging)
|
||||
except Exception as err:
|
||||
print('error loading logging config: %s' % (err))
|
||||
print("error loading logging config: %s" % (err))
|
||||
|
|
|
@ -1,4 +1 @@
|
|||
from .onnx_net import (
|
||||
OnnxImage,
|
||||
OnnxNet,
|
||||
)
|
||||
from .onnx_net import OnnxImage, OnnxNet
|
||||
|
|
|
@ -1,15 +1,14 @@
|
|||
from onnxruntime import InferenceSession
|
||||
from os import path
|
||||
from typing import Any, Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from onnxruntime import InferenceSession
|
||||
|
||||
from ..utils import (
|
||||
ServerContext,
|
||||
)
|
||||
from ..utils import ServerContext
|
||||
|
||||
class OnnxImage():
|
||||
|
||||
class OnnxImage:
|
||||
def __init__(self, source) -> None:
|
||||
self.source = source
|
||||
self.data = self
|
||||
|
@ -38,28 +37,27 @@ class OnnxImage():
|
|||
return np.shape(self.source)
|
||||
|
||||
|
||||
class OnnxNet():
|
||||
'''
|
||||
class OnnxNet:
|
||||
"""
|
||||
Provides the RRDBNet interface using an ONNX session for DirectML acceleration.
|
||||
'''
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
server: ServerContext,
|
||||
model: str,
|
||||
provider: str = 'DmlExecutionProvider',
|
||||
provider: str = "DmlExecutionProvider",
|
||||
provider_options: Optional[dict] = None,
|
||||
) -> None:
|
||||
model_path = path.join(server.model_path, model)
|
||||
self.session = InferenceSession(
|
||||
model_path, providers=[provider], provider_options=provider_options)
|
||||
model_path, providers=[provider], provider_options=provider_options
|
||||
)
|
||||
|
||||
def __call__(self, image: Any) -> Any:
|
||||
input_name = self.session.get_inputs()[0].name
|
||||
output_name = self.session.get_outputs()[0].name
|
||||
output = self.session.run([output_name], {
|
||||
input_name: image.cpu().numpy()
|
||||
})[0]
|
||||
output = self.session.run([output_name], {input_name: image.cpu().numpy()})[0]
|
||||
return OnnxImage(output)
|
||||
|
||||
def eval(self) -> None:
|
||||
|
|
|
@ -1,22 +1,14 @@
|
|||
from hashlib import sha256
|
||||
from json import dumps
|
||||
from logging import getLogger
|
||||
from PIL import Image
|
||||
from struct import pack
|
||||
from time import time
|
||||
from typing import Any, Optional, Tuple
|
||||
|
||||
from .params import (
|
||||
Border,
|
||||
ImageParams,
|
||||
Param,
|
||||
Size,
|
||||
UpscaleParams,
|
||||
)
|
||||
from .utils import (
|
||||
base_join,
|
||||
ServerContext,
|
||||
)
|
||||
from PIL import Image
|
||||
|
||||
from .params import Border, ImageParams, Param, Size, UpscaleParams
|
||||
from .utils import ServerContext, base_join
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
|
@ -25,13 +17,13 @@ def hash_value(sha, param: Param):
|
|||
if param is None:
|
||||
return
|
||||
elif isinstance(param, float):
|
||||
sha.update(bytearray(pack('!f', param)))
|
||||
sha.update(bytearray(pack("!f", param)))
|
||||
elif isinstance(param, int):
|
||||
sha.update(bytearray(pack('!I', param)))
|
||||
sha.update(bytearray(pack("!I", param)))
|
||||
elif isinstance(param, str):
|
||||
sha.update(param.encode('utf-8'))
|
||||
sha.update(param.encode("utf-8"))
|
||||
else:
|
||||
logger.warn('cannot hash param: %s, %s', param, type(param))
|
||||
logger.warn("cannot hash param: %s, %s", param, type(param))
|
||||
|
||||
|
||||
def json_params(
|
||||
|
@ -42,22 +34,22 @@ def json_params(
|
|||
border: Optional[Border] = None,
|
||||
) -> Any:
|
||||
json = {
|
||||
'output': output,
|
||||
'params': params.tojson(),
|
||||
"output": output,
|
||||
"params": params.tojson(),
|
||||
}
|
||||
|
||||
if upscale is not None and border is not None:
|
||||
size = upscale.resize(size.add_border(border))
|
||||
|
||||
if upscale is not None:
|
||||
json['upscale'] = upscale.tojson()
|
||||
json["upscale"] = upscale.tojson()
|
||||
size = upscale.resize(size)
|
||||
|
||||
if border is not None:
|
||||
json['border'] = border.tojson()
|
||||
json["border"] = border.tojson()
|
||||
size = size.add_border(border)
|
||||
|
||||
json['size'] = size.tojson()
|
||||
json["size"] = size.tojson()
|
||||
|
||||
return json
|
||||
|
||||
|
@ -67,7 +59,7 @@ def make_output_name(
|
|||
mode: str,
|
||||
params: ImageParams,
|
||||
size: Size,
|
||||
extras: Optional[Tuple[Param]] = None
|
||||
extras: Optional[Tuple[Param]] = None,
|
||||
) -> str:
|
||||
now = int(time())
|
||||
sha = sha256()
|
||||
|
@ -87,13 +79,19 @@ def make_output_name(
|
|||
for param in extras:
|
||||
hash_value(sha, param)
|
||||
|
||||
return '%s_%s_%s_%s.%s' % (mode, params.seed, sha.hexdigest(), now, ctx.image_format)
|
||||
return "%s_%s_%s_%s.%s" % (
|
||||
mode,
|
||||
params.seed,
|
||||
sha.hexdigest(),
|
||||
now,
|
||||
ctx.image_format,
|
||||
)
|
||||
|
||||
|
||||
def save_image(ctx: ServerContext, output: str, image: Image.Image) -> str:
|
||||
path = base_join(ctx.output_path, output)
|
||||
image.save(path, format=ctx.image_format)
|
||||
logger.debug('saved output image to: %s', path)
|
||||
logger.debug("saved output image to: %s", path)
|
||||
return path
|
||||
|
||||
|
||||
|
@ -105,9 +103,9 @@ def save_params(
|
|||
upscale: Optional[UpscaleParams] = None,
|
||||
border: Optional[Border] = None,
|
||||
) -> str:
|
||||
path = base_join(ctx.output_path, '%s.json' % (output))
|
||||
path = base_join(ctx.output_path, "%s.json" % (output))
|
||||
json = json_params(output, params, size, upscale=upscale, border=border)
|
||||
with open(path, 'w') as f:
|
||||
with open(path, "w") as f:
|
||||
f.write(dumps(json))
|
||||
logger.debug('saved image params to: %s', path)
|
||||
logger.debug("saved image params to: %s", path)
|
||||
return path
|
||||
|
|
|
@ -26,14 +26,14 @@ class Border:
|
|||
self.bottom = bottom
|
||||
|
||||
def __str__(self) -> str:
|
||||
return '%s %s %s %s' % (self.left, self.top, self.right, self.bottom)
|
||||
return "%s %s %s %s" % (self.left, self.top, self.right, self.bottom)
|
||||
|
||||
def tojson(self):
|
||||
return {
|
||||
'left': self.left,
|
||||
'right': self.right,
|
||||
'top': self.top,
|
||||
'bottom': self.bottom,
|
||||
"left": self.left,
|
||||
"right": self.right,
|
||||
"top": self.top,
|
||||
"bottom": self.bottom,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
|
@ -47,32 +47,37 @@ class Size:
|
|||
self.height = height
|
||||
|
||||
def __str__(self) -> str:
|
||||
return '%sx%s' % (self.width, self.height)
|
||||
return "%sx%s" % (self.width, self.height)
|
||||
|
||||
def add_border(self, border: Border):
|
||||
return Size(border.left + self.width + border.right, border.top + self.height + border.right)
|
||||
return Size(
|
||||
border.left + self.width + border.right,
|
||||
border.top + self.height + border.right,
|
||||
)
|
||||
|
||||
def tojson(self) -> Dict[str, int]:
|
||||
return {
|
||||
'height': self.height,
|
||||
'width': self.width,
|
||||
"height": self.height,
|
||||
"width": self.width,
|
||||
}
|
||||
|
||||
|
||||
class DeviceParams:
|
||||
def __init__(self, device: str, provider: str, options: Optional[dict] = None) -> None:
|
||||
def __init__(
|
||||
self, device: str, provider: str, options: Optional[dict] = None
|
||||
) -> None:
|
||||
self.device = device
|
||||
self.provider = provider
|
||||
self.options = options
|
||||
|
||||
def __str__(self) -> str:
|
||||
return '%s - %s (%s)' % (self.device, self.provider, self.options)
|
||||
return "%s - %s (%s)" % (self.device, self.provider, self.options)
|
||||
|
||||
def torch_device(self) -> str:
|
||||
if self.device.startswith('cuda'):
|
||||
if self.device.startswith("cuda"):
|
||||
return self.device
|
||||
else:
|
||||
return 'cpu'
|
||||
return "cpu"
|
||||
|
||||
|
||||
class ImageParams:
|
||||
|
@ -84,7 +89,7 @@ class ImageParams:
|
|||
negative_prompt: Optional[str],
|
||||
cfg: float,
|
||||
steps: int,
|
||||
seed: int
|
||||
seed: int,
|
||||
) -> None:
|
||||
self.model = model
|
||||
self.scheduler = scheduler
|
||||
|
@ -96,20 +101,20 @@ class ImageParams:
|
|||
|
||||
def tojson(self) -> Dict[str, Optional[Param]]:
|
||||
return {
|
||||
'model': self.model,
|
||||
'scheduler': self.scheduler.__name__,
|
||||
'seed': self.seed,
|
||||
'prompt': self.prompt,
|
||||
'cfg': self.cfg,
|
||||
'negativePrompt': self.negative_prompt,
|
||||
'steps': self.steps,
|
||||
"model": self.model,
|
||||
"scheduler": self.scheduler.__name__,
|
||||
"seed": self.seed,
|
||||
"prompt": self.prompt,
|
||||
"cfg": self.cfg,
|
||||
"negativePrompt": self.negative_prompt,
|
||||
"steps": self.steps,
|
||||
}
|
||||
|
||||
|
||||
class StageParams:
|
||||
'''
|
||||
"""
|
||||
Parameters for a chained pipeline stage
|
||||
'''
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
@ -123,7 +128,7 @@ class StageParams:
|
|||
self.outscale = outscale
|
||||
|
||||
|
||||
class UpscaleParams():
|
||||
class UpscaleParams:
|
||||
def __init__(
|
||||
self,
|
||||
upscale_model: str,
|
||||
|
@ -131,7 +136,7 @@ class UpscaleParams():
|
|||
denoise: float = 0.5,
|
||||
faces=True,
|
||||
face_strength: float = 0.5,
|
||||
format: Literal['onnx', 'pth'] = 'onnx',
|
||||
format: Literal["onnx", "pth"] = "onnx",
|
||||
half=False,
|
||||
outscale: int = 1,
|
||||
scale: int = 4,
|
||||
|
@ -170,8 +175,8 @@ class UpscaleParams():
|
|||
|
||||
def tojson(self):
|
||||
return {
|
||||
'model': self.upscale_model,
|
||||
'scale': self.scale,
|
||||
'outscale': self.outscale,
|
||||
"model": self.upscale_model,
|
||||
"scale": self.scale,
|
||||
"outscale": self.outscale,
|
||||
# TODO: add more
|
||||
}
|
||||
|
|
|
@ -1,62 +1,61 @@
|
|||
from . import logging
|
||||
import gc
|
||||
from functools import cmp_to_key
|
||||
from glob import glob
|
||||
from io import BytesIO
|
||||
from logging import getLogger
|
||||
from os import makedirs, path
|
||||
from typing import List, Tuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import yaml
|
||||
from diffusers import (
|
||||
DDIMScheduler,
|
||||
DDPMScheduler,
|
||||
DPMSolverMultistepScheduler,
|
||||
DPMSolverSinglestepScheduler,
|
||||
EulerDiscreteScheduler,
|
||||
EulerAncestralDiscreteScheduler,
|
||||
EulerDiscreteScheduler,
|
||||
HeunDiscreteScheduler,
|
||||
KarrasVeScheduler,
|
||||
KDPM2AncestralDiscreteScheduler,
|
||||
KDPM2DiscreteScheduler,
|
||||
KarrasVeScheduler,
|
||||
LMSDiscreteScheduler,
|
||||
PNDMScheduler,
|
||||
)
|
||||
from flask import Flask, jsonify, make_response, request, send_from_directory, url_for
|
||||
from flask_cors import CORS
|
||||
from functools import cmp_to_key
|
||||
from glob import glob
|
||||
from io import BytesIO
|
||||
from jsonschema import validate
|
||||
from logging import getLogger
|
||||
from PIL import Image
|
||||
from onnxruntime import get_available_providers
|
||||
from os import makedirs, path
|
||||
from typing import List, Tuple
|
||||
|
||||
from PIL import Image
|
||||
|
||||
from . import logging
|
||||
from .chain import (
|
||||
ChainPipeline,
|
||||
blend_img2img,
|
||||
blend_inpaint,
|
||||
correct_gfpgan,
|
||||
persist_disk,
|
||||
persist_s3,
|
||||
reduce_thumbnail,
|
||||
reduce_crop,
|
||||
reduce_thumbnail,
|
||||
source_noise,
|
||||
source_txt2img,
|
||||
upscale_outpaint,
|
||||
upscale_resrgan,
|
||||
upscale_stable_diffusion,
|
||||
ChainPipeline,
|
||||
)
|
||||
from .device_pool import (
|
||||
DeviceParams,
|
||||
DevicePoolExecutor,
|
||||
)
|
||||
from .device_pool import DevicePoolExecutor
|
||||
from .diffusion.run import (
|
||||
run_img2img_pipeline,
|
||||
run_inpaint_pipeline,
|
||||
run_txt2img_pipeline,
|
||||
run_upscale_pipeline,
|
||||
)
|
||||
from .image import (
|
||||
# mask filters
|
||||
from .image import ( # mask filters; noise sources
|
||||
mask_filter_gaussian_multiply,
|
||||
mask_filter_gaussian_screen,
|
||||
mask_filter_none,
|
||||
# noise sources
|
||||
noise_source_fill_edge,
|
||||
noise_source_fill_mask,
|
||||
noise_source_gaussian,
|
||||
|
@ -64,35 +63,20 @@ from .image import (
|
|||
noise_source_normal,
|
||||
noise_source_uniform,
|
||||
)
|
||||
from .output import (
|
||||
json_params,
|
||||
make_output_name,
|
||||
)
|
||||
from .params import (
|
||||
Border,
|
||||
DeviceParams,
|
||||
ImageParams,
|
||||
Size,
|
||||
StageParams,
|
||||
UpscaleParams,
|
||||
)
|
||||
from .output import json_params, make_output_name
|
||||
from .params import Border, DeviceParams, ImageParams, Size, StageParams, UpscaleParams
|
||||
from .utils import (
|
||||
ServerContext,
|
||||
base_join,
|
||||
is_debug,
|
||||
get_and_clamp_float,
|
||||
get_and_clamp_int,
|
||||
get_from_list,
|
||||
get_from_map,
|
||||
get_not_empty,
|
||||
get_size,
|
||||
ServerContext,
|
||||
is_debug,
|
||||
)
|
||||
|
||||
import gc
|
||||
import numpy as np
|
||||
import torch
|
||||
import yaml
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
# config caching
|
||||
|
@ -100,53 +84,53 @@ config_params = {}
|
|||
|
||||
# pipeline params
|
||||
platform_providers = {
|
||||
'amd': 'DmlExecutionProvider',
|
||||
'cpu': 'CPUExecutionProvider',
|
||||
'cuda': 'CUDAExecutionProvider',
|
||||
'directml': 'DmlExecutionProvider',
|
||||
'nvidia': 'CUDAExecutionProvider',
|
||||
'rocm': 'ROCMExecutionProvider',
|
||||
"amd": "DmlExecutionProvider",
|
||||
"cpu": "CPUExecutionProvider",
|
||||
"cuda": "CUDAExecutionProvider",
|
||||
"directml": "DmlExecutionProvider",
|
||||
"nvidia": "CUDAExecutionProvider",
|
||||
"rocm": "ROCMExecutionProvider",
|
||||
}
|
||||
pipeline_schedulers = {
|
||||
'ddim': DDIMScheduler,
|
||||
'ddpm': DDPMScheduler,
|
||||
'dpm-multi': DPMSolverMultistepScheduler,
|
||||
'dpm-single': DPMSolverSinglestepScheduler,
|
||||
'euler': EulerDiscreteScheduler,
|
||||
'euler-a': EulerAncestralDiscreteScheduler,
|
||||
'heun': HeunDiscreteScheduler,
|
||||
'k-dpm-2-a': KDPM2AncestralDiscreteScheduler,
|
||||
'k-dpm-2': KDPM2DiscreteScheduler,
|
||||
'karras-ve': KarrasVeScheduler,
|
||||
'lms-discrete': LMSDiscreteScheduler,
|
||||
'pndm': PNDMScheduler,
|
||||
"ddim": DDIMScheduler,
|
||||
"ddpm": DDPMScheduler,
|
||||
"dpm-multi": DPMSolverMultistepScheduler,
|
||||
"dpm-single": DPMSolverSinglestepScheduler,
|
||||
"euler": EulerDiscreteScheduler,
|
||||
"euler-a": EulerAncestralDiscreteScheduler,
|
||||
"heun": HeunDiscreteScheduler,
|
||||
"k-dpm-2-a": KDPM2AncestralDiscreteScheduler,
|
||||
"k-dpm-2": KDPM2DiscreteScheduler,
|
||||
"karras-ve": KarrasVeScheduler,
|
||||
"lms-discrete": LMSDiscreteScheduler,
|
||||
"pndm": PNDMScheduler,
|
||||
}
|
||||
noise_sources = {
|
||||
'fill-edge': noise_source_fill_edge,
|
||||
'fill-mask': noise_source_fill_mask,
|
||||
'gaussian': noise_source_gaussian,
|
||||
'histogram': noise_source_histogram,
|
||||
'normal': noise_source_normal,
|
||||
'uniform': noise_source_uniform,
|
||||
"fill-edge": noise_source_fill_edge,
|
||||
"fill-mask": noise_source_fill_mask,
|
||||
"gaussian": noise_source_gaussian,
|
||||
"histogram": noise_source_histogram,
|
||||
"normal": noise_source_normal,
|
||||
"uniform": noise_source_uniform,
|
||||
}
|
||||
mask_filters = {
|
||||
'none': mask_filter_none,
|
||||
'gaussian-multiply': mask_filter_gaussian_multiply,
|
||||
'gaussian-screen': mask_filter_gaussian_screen,
|
||||
"none": mask_filter_none,
|
||||
"gaussian-multiply": mask_filter_gaussian_multiply,
|
||||
"gaussian-screen": mask_filter_gaussian_screen,
|
||||
}
|
||||
chain_stages = {
|
||||
'blend-img2img': blend_img2img,
|
||||
'blend-inpaint': blend_inpaint,
|
||||
'correct-gfpgan': correct_gfpgan,
|
||||
'persist-disk': persist_disk,
|
||||
'persist-s3': persist_s3,
|
||||
'reduce-crop': reduce_crop,
|
||||
'reduce-thumbnail': reduce_thumbnail,
|
||||
'source-noise': source_noise,
|
||||
'source-txt2img': source_txt2img,
|
||||
'upscale-outpaint': upscale_outpaint,
|
||||
'upscale-resrgan': upscale_resrgan,
|
||||
'upscale-stable-diffusion': upscale_stable_diffusion,
|
||||
"blend-img2img": blend_img2img,
|
||||
"blend-inpaint": blend_inpaint,
|
||||
"correct-gfpgan": correct_gfpgan,
|
||||
"persist-disk": persist_disk,
|
||||
"persist-s3": persist_s3,
|
||||
"reduce-crop": reduce_crop,
|
||||
"reduce-thumbnail": reduce_thumbnail,
|
||||
"source-noise": source_noise,
|
||||
"source-txt2img": source_txt2img,
|
||||
"upscale-outpaint": upscale_outpaint,
|
||||
"upscale-resrgan": upscale_resrgan,
|
||||
"upscale-stable-diffusion": upscale_stable_diffusion,
|
||||
}
|
||||
|
||||
# Available ORT providers
|
||||
|
@ -158,7 +142,7 @@ correction_models = []
|
|||
upscaling_models = []
|
||||
|
||||
|
||||
def get_config_value(key: str, subkey: str = 'default'):
|
||||
def get_config_value(key: str, subkey: str = "default"):
|
||||
return config_params.get(key).get(subkey)
|
||||
|
||||
|
||||
|
@ -174,7 +158,7 @@ def pipeline_from_request() -> Tuple[DeviceParams, ImageParams, Size]:
|
|||
user = request.remote_addr
|
||||
|
||||
# platform stuff
|
||||
device_name = request.args.get('platform', available_platforms[0].device)
|
||||
device_name = request.args.get("platform", available_platforms[0].device)
|
||||
device = None
|
||||
|
||||
for platform in available_platforms:
|
||||
|
@ -182,78 +166,101 @@ def pipeline_from_request() -> Tuple[DeviceParams, ImageParams, Size]:
|
|||
device = available_platforms[0]
|
||||
|
||||
if device is None:
|
||||
raise Exception('unknown device')
|
||||
raise Exception("unknown device")
|
||||
|
||||
# pipeline stuff
|
||||
model = get_not_empty(request.args, 'model', get_config_value('model'))
|
||||
model = get_not_empty(request.args, "model", get_config_value("model"))
|
||||
model_path = get_model_path(model)
|
||||
scheduler = get_from_map(request.args, 'scheduler',
|
||||
pipeline_schedulers, get_config_value('scheduler'))
|
||||
scheduler = get_from_map(
|
||||
request.args, "scheduler", pipeline_schedulers, get_config_value("scheduler")
|
||||
)
|
||||
|
||||
# image params
|
||||
prompt = get_not_empty(request.args,
|
||||
'prompt', get_config_value('prompt'))
|
||||
negative_prompt = request.args.get('negativePrompt', None)
|
||||
prompt = get_not_empty(request.args, "prompt", get_config_value("prompt"))
|
||||
negative_prompt = request.args.get("negativePrompt", None)
|
||||
|
||||
if negative_prompt is not None and negative_prompt.strip() == '':
|
||||
if negative_prompt is not None and negative_prompt.strip() == "":
|
||||
negative_prompt = None
|
||||
|
||||
cfg = get_and_clamp_float(
|
||||
request.args, 'cfg',
|
||||
get_config_value('cfg'),
|
||||
get_config_value('cfg', 'max'),
|
||||
get_config_value('cfg', 'min'))
|
||||
request.args,
|
||||
"cfg",
|
||||
get_config_value("cfg"),
|
||||
get_config_value("cfg", "max"),
|
||||
get_config_value("cfg", "min"),
|
||||
)
|
||||
steps = get_and_clamp_int(
|
||||
request.args, 'steps',
|
||||
get_config_value('steps'),
|
||||
get_config_value('steps', 'max'),
|
||||
get_config_value('steps', 'min'))
|
||||
request.args,
|
||||
"steps",
|
||||
get_config_value("steps"),
|
||||
get_config_value("steps", "max"),
|
||||
get_config_value("steps", "min"),
|
||||
)
|
||||
height = get_and_clamp_int(
|
||||
request.args, 'height',
|
||||
get_config_value('height'),
|
||||
get_config_value('height', 'max'),
|
||||
get_config_value('height', 'min'))
|
||||
request.args,
|
||||
"height",
|
||||
get_config_value("height"),
|
||||
get_config_value("height", "max"),
|
||||
get_config_value("height", "min"),
|
||||
)
|
||||
width = get_and_clamp_int(
|
||||
request.args, 'width',
|
||||
get_config_value('width'),
|
||||
get_config_value('width', 'max'),
|
||||
get_config_value('width', 'min'))
|
||||
request.args,
|
||||
"width",
|
||||
get_config_value("width"),
|
||||
get_config_value("width", "max"),
|
||||
get_config_value("width", "min"),
|
||||
)
|
||||
|
||||
seed = int(request.args.get('seed', -1))
|
||||
seed = int(request.args.get("seed", -1))
|
||||
if seed == -1:
|
||||
seed = np.random.randint(np.iinfo(np.int32).max)
|
||||
|
||||
logger.info("request from %s: %s rounds of %s using %s on %s, %sx%s, %s, %s - %s",
|
||||
user, steps, scheduler.__name__, model_path, device.provider, width, height, cfg, seed, prompt)
|
||||
logger.info(
|
||||
"request from %s: %s rounds of %s using %s on %s, %sx%s, %s, %s - %s",
|
||||
user,
|
||||
steps,
|
||||
scheduler.__name__,
|
||||
model_path,
|
||||
device.provider,
|
||||
width,
|
||||
height,
|
||||
cfg,
|
||||
seed,
|
||||
prompt,
|
||||
)
|
||||
|
||||
params = ImageParams(model_path, scheduler, prompt,
|
||||
negative_prompt, cfg, steps, seed)
|
||||
params = ImageParams(
|
||||
model_path, scheduler, prompt, negative_prompt, cfg, steps, seed
|
||||
)
|
||||
size = Size(width, height)
|
||||
return (device, params, size)
|
||||
|
||||
|
||||
def border_from_request() -> Border:
|
||||
left = get_and_clamp_int(request.args, 'left', 0,
|
||||
get_config_value('width', 'max'), 0)
|
||||
right = get_and_clamp_int(request.args, 'right',
|
||||
0, get_config_value('width', 'max'), 0)
|
||||
top = get_and_clamp_int(request.args, 'top', 0,
|
||||
get_config_value('height', 'max'), 0)
|
||||
left = get_and_clamp_int(
|
||||
request.args, "left", 0, get_config_value("width", "max"), 0
|
||||
)
|
||||
right = get_and_clamp_int(
|
||||
request.args, "right", 0, get_config_value("width", "max"), 0
|
||||
)
|
||||
top = get_and_clamp_int(
|
||||
request.args, "top", 0, get_config_value("height", "max"), 0
|
||||
)
|
||||
bottom = get_and_clamp_int(
|
||||
request.args, 'bottom', 0, get_config_value('height', 'max'), 0)
|
||||
request.args, "bottom", 0, get_config_value("height", "max"), 0
|
||||
)
|
||||
|
||||
return Border(left, right, top, bottom)
|
||||
|
||||
|
||||
def upscale_from_request() -> UpscaleParams:
|
||||
denoise = get_and_clamp_float(request.args, 'denoise', 0.5, 1.0, 0.0)
|
||||
scale = get_and_clamp_int(request.args, 'scale', 1, 4, 1)
|
||||
outscale = get_and_clamp_int(request.args, 'outscale', 1, 4, 1)
|
||||
upscaling = get_from_list(request.args, 'upscaling', upscaling_models)
|
||||
correction = get_from_list(request.args, 'correction', correction_models)
|
||||
faces = get_not_empty(request.args, 'faces', 'false') == 'true'
|
||||
face_strength = get_and_clamp_float(
|
||||
request.args, 'faceStrength', 0.5, 1.0, 0.0)
|
||||
denoise = get_and_clamp_float(request.args, "denoise", 0.5, 1.0, 0.0)
|
||||
scale = get_and_clamp_int(request.args, "scale", 1, 4, 1)
|
||||
outscale = get_and_clamp_int(request.args, "outscale", 1, 4, 1)
|
||||
upscaling = get_from_list(request.args, "upscaling", upscaling_models)
|
||||
correction = get_from_list(request.args, "correction", correction_models)
|
||||
faces = get_not_empty(request.args, "faces", "false") == "true"
|
||||
face_strength = get_and_clamp_float(request.args, "faceStrength", 0.5, 1.0, 0.0)
|
||||
|
||||
return UpscaleParams(
|
||||
upscaling,
|
||||
|
@ -261,7 +268,7 @@ def upscale_from_request() -> UpscaleParams:
|
|||
denoise=denoise,
|
||||
faces=faces,
|
||||
face_strength=face_strength,
|
||||
format='onnx',
|
||||
format="onnx",
|
||||
outscale=outscale,
|
||||
scale=scale,
|
||||
)
|
||||
|
@ -269,7 +276,7 @@ def upscale_from_request() -> UpscaleParams:
|
|||
|
||||
def check_paths(context: ServerContext):
|
||||
if not path.exists(context.model_path):
|
||||
raise RuntimeError('model path must exist')
|
||||
raise RuntimeError("model path must exist")
|
||||
|
||||
if not path.exists(context.output_path):
|
||||
makedirs(context.output_path)
|
||||
|
@ -286,35 +293,41 @@ def load_models(context: ServerContext):
|
|||
global correction_models
|
||||
global upscaling_models
|
||||
|
||||
diffusion_models = [get_model_name(f) for f in glob(
|
||||
path.join(context.model_path, 'diffusion-*'))]
|
||||
diffusion_models.extend([
|
||||
get_model_name(f) for f in glob(path.join(context.model_path, 'stable-diffusion-*'))])
|
||||
diffusion_models = [
|
||||
get_model_name(f) for f in glob(path.join(context.model_path, "diffusion-*"))
|
||||
]
|
||||
diffusion_models.extend(
|
||||
[
|
||||
get_model_name(f)
|
||||
for f in glob(path.join(context.model_path, "stable-diffusion-*"))
|
||||
]
|
||||
)
|
||||
diffusion_models = list(set(diffusion_models))
|
||||
diffusion_models.sort()
|
||||
|
||||
correction_models = [
|
||||
get_model_name(f) for f in glob(path.join(context.model_path, 'correction-*'))]
|
||||
get_model_name(f) for f in glob(path.join(context.model_path, "correction-*"))
|
||||
]
|
||||
correction_models = list(set(correction_models))
|
||||
correction_models.sort()
|
||||
|
||||
upscaling_models = [
|
||||
get_model_name(f) for f in glob(path.join(context.model_path, 'upscaling-*'))]
|
||||
get_model_name(f) for f in glob(path.join(context.model_path, "upscaling-*"))
|
||||
]
|
||||
upscaling_models = list(set(upscaling_models))
|
||||
upscaling_models.sort()
|
||||
|
||||
|
||||
def load_params(context: ServerContext):
|
||||
global config_params
|
||||
params_file = path.join(context.params_path, 'params.json')
|
||||
with open(params_file, 'r') as f:
|
||||
params_file = path.join(context.params_path, "params.json")
|
||||
with open(params_file, "r") as f:
|
||||
config_params = yaml.safe_load(f)
|
||||
|
||||
if 'platform' in config_params and context.default_platform is not None:
|
||||
logger.info('overriding default platform to %s',
|
||||
context.default_platform)
|
||||
config_platform = config_params.get('platform')
|
||||
config_platform['default'] = context.default_platform
|
||||
if "platform" in config_params and context.default_platform is not None:
|
||||
logger.info("overriding default platform to %s", context.default_platform)
|
||||
config_platform = config_params.get("platform")
|
||||
config_platform["default"] = context.default_platform
|
||||
|
||||
|
||||
def load_platforms():
|
||||
|
@ -323,30 +336,42 @@ def load_platforms():
|
|||
providers = get_available_providers()
|
||||
|
||||
for potential in platform_providers:
|
||||
if platform_providers[potential] in providers and potential not in context.block_platforms:
|
||||
if potential == 'cuda':
|
||||
if (
|
||||
platform_providers[potential] in providers
|
||||
and potential not in context.block_platforms
|
||||
):
|
||||
if potential == "cuda":
|
||||
for i in range(torch.cuda.device_count()):
|
||||
available_platforms.append(DeviceParams(potential, platform_providers[potential], {
|
||||
'device_id': i,
|
||||
}))
|
||||
available_platforms.append(
|
||||
DeviceParams(
|
||||
potential,
|
||||
platform_providers[potential],
|
||||
{
|
||||
"device_id": i,
|
||||
},
|
||||
)
|
||||
)
|
||||
else:
|
||||
available_platforms.append(DeviceParams(
|
||||
potential, platform_providers[potential]))
|
||||
available_platforms.append(
|
||||
DeviceParams(potential, platform_providers[potential])
|
||||
)
|
||||
|
||||
# make sure CPU is last on the list
|
||||
def cpu_last(a: DeviceParams, b: DeviceParams):
|
||||
if a.device == 'cpu' and b.device == 'cpu':
|
||||
if a.device == "cpu" and b.device == "cpu":
|
||||
return 0
|
||||
|
||||
if a.device == 'cpu':
|
||||
if a.device == "cpu":
|
||||
return 1
|
||||
|
||||
return -1
|
||||
|
||||
available_platforms = sorted(available_platforms, key=cmp_to_key(cpu_last))
|
||||
|
||||
logger.info('available acceleration platforms: %s',
|
||||
', '.join([str(p) for p in available_platforms]))
|
||||
logger.info(
|
||||
"available acceleration platforms: %s",
|
||||
", ".join([str(p) for p in available_platforms]),
|
||||
)
|
||||
|
||||
|
||||
context = ServerContext.from_environ()
|
||||
|
@ -365,16 +390,22 @@ if is_debug():
|
|||
|
||||
|
||||
def ready_reply(ready: bool, progress: int = 0):
|
||||
return jsonify({
|
||||
'progress': progress,
|
||||
'ready': ready,
|
||||
})
|
||||
return jsonify(
|
||||
{
|
||||
"progress": progress,
|
||||
"ready": ready,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def error_reply(err: str):
|
||||
response = make_response(jsonify({
|
||||
'error': err,
|
||||
}))
|
||||
response = make_response(
|
||||
jsonify(
|
||||
{
|
||||
"error": err,
|
||||
}
|
||||
)
|
||||
)
|
||||
response.status_code = 400
|
||||
return response
|
||||
|
||||
|
@ -383,151 +414,154 @@ def get_model_path(model: str):
|
|||
return base_join(context.model_path, model)
|
||||
|
||||
|
||||
def serve_bundle_file(filename='index.html'):
|
||||
return send_from_directory(path.join('..', context.bundle_path), filename)
|
||||
def serve_bundle_file(filename="index.html"):
|
||||
return send_from_directory(path.join("..", context.bundle_path), filename)
|
||||
|
||||
|
||||
# routes
|
||||
|
||||
|
||||
@app.route('/')
|
||||
@app.route("/")
|
||||
def index():
|
||||
return serve_bundle_file()
|
||||
|
||||
|
||||
@app.route('/<path:filename>')
|
||||
@app.route("/<path:filename>")
|
||||
def index_path(filename):
|
||||
return serve_bundle_file(filename)
|
||||
|
||||
|
||||
@app.route('/api')
|
||||
@app.route("/api")
|
||||
def introspect():
|
||||
return {
|
||||
'name': 'onnx-web',
|
||||
'routes': [{
|
||||
'path': url_from_rule(rule),
|
||||
'methods': list(rule.methods).sort()
|
||||
} for rule in app.url_map.iter_rules()]
|
||||
"name": "onnx-web",
|
||||
"routes": [
|
||||
{"path": url_from_rule(rule), "methods": list(rule.methods).sort()}
|
||||
for rule in app.url_map.iter_rules()
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
@app.route('/api/settings/masks')
|
||||
@app.route("/api/settings/masks")
|
||||
def list_mask_filters():
|
||||
return jsonify(list(mask_filters.keys()))
|
||||
|
||||
|
||||
@app.route('/api/settings/models')
|
||||
@app.route("/api/settings/models")
|
||||
def list_models():
|
||||
return jsonify({
|
||||
'diffusion': diffusion_models,
|
||||
'correction': correction_models,
|
||||
'upscaling': upscaling_models,
|
||||
})
|
||||
return jsonify(
|
||||
{
|
||||
"diffusion": diffusion_models,
|
||||
"correction": correction_models,
|
||||
"upscaling": upscaling_models,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@app.route('/api/settings/noises')
|
||||
@app.route("/api/settings/noises")
|
||||
def list_noise_sources():
|
||||
return jsonify(list(noise_sources.keys()))
|
||||
|
||||
|
||||
@app.route('/api/settings/params')
|
||||
@app.route("/api/settings/params")
|
||||
def list_params():
|
||||
return jsonify(config_params)
|
||||
|
||||
|
||||
@app.route('/api/settings/platforms')
|
||||
@app.route("/api/settings/platforms")
|
||||
def list_platforms():
|
||||
return jsonify([p.device for p in available_platforms])
|
||||
|
||||
|
||||
@app.route('/api/settings/schedulers')
|
||||
@app.route("/api/settings/schedulers")
|
||||
def list_schedulers():
|
||||
return jsonify(list(pipeline_schedulers.keys()))
|
||||
|
||||
|
||||
@app.route('/api/img2img', methods=['POST'])
|
||||
@app.route("/api/img2img", methods=["POST"])
|
||||
def img2img():
|
||||
if 'source' not in request.files:
|
||||
return error_reply('source image is required')
|
||||
if "source" not in request.files:
|
||||
return error_reply("source image is required")
|
||||
|
||||
source_file = request.files.get('source')
|
||||
source_image = Image.open(BytesIO(source_file.read())).convert('RGB')
|
||||
source_file = request.files.get("source")
|
||||
source_image = Image.open(BytesIO(source_file.read())).convert("RGB")
|
||||
|
||||
device, params, size = pipeline_from_request()
|
||||
upscale = upscale_from_request()
|
||||
|
||||
strength = get_and_clamp_float(
|
||||
request.args,
|
||||
'strength',
|
||||
get_config_value('strength'),
|
||||
get_config_value('strength', 'max'),
|
||||
get_config_value('strength', 'min'))
|
||||
"strength",
|
||||
get_config_value("strength"),
|
||||
get_config_value("strength", "max"),
|
||||
get_config_value("strength", "min"),
|
||||
)
|
||||
|
||||
output = make_output_name(
|
||||
context,
|
||||
'img2img',
|
||||
params,
|
||||
size,
|
||||
extras=(strength,))
|
||||
output = make_output_name(context, "img2img", params, size, extras=(strength,))
|
||||
logger.info("img2img job queued for: %s", output)
|
||||
|
||||
source_image.thumbnail((size.width, size.height))
|
||||
executor.submit(output, run_img2img_pipeline,
|
||||
context, params, output, upscale, source_image, strength)
|
||||
executor.submit(
|
||||
output,
|
||||
run_img2img_pipeline,
|
||||
context,
|
||||
params,
|
||||
output,
|
||||
upscale,
|
||||
source_image,
|
||||
strength,
|
||||
)
|
||||
|
||||
return jsonify(json_params(output, params, size, upscale=upscale))
|
||||
|
||||
|
||||
@app.route('/api/txt2img', methods=['POST'])
|
||||
@app.route("/api/txt2img", methods=["POST"])
|
||||
def txt2img():
|
||||
device, params, size = pipeline_from_request()
|
||||
upscale = upscale_from_request()
|
||||
|
||||
output = make_output_name(
|
||||
context,
|
||||
'txt2img',
|
||||
params,
|
||||
size)
|
||||
output = make_output_name(context, "txt2img", params, size)
|
||||
logger.info("txt2img job queued for: %s", output)
|
||||
|
||||
executor.submit(
|
||||
output, run_txt2img_pipeline, context, params, size, output, upscale)
|
||||
output, run_txt2img_pipeline, context, params, size, output, upscale
|
||||
)
|
||||
|
||||
return jsonify(json_params(output, params, size, upscale=upscale))
|
||||
|
||||
|
||||
@app.route('/api/inpaint', methods=['POST'])
|
||||
@app.route("/api/inpaint", methods=["POST"])
|
||||
def inpaint():
|
||||
if 'source' not in request.files:
|
||||
return error_reply('source image is required')
|
||||
if "source" not in request.files:
|
||||
return error_reply("source image is required")
|
||||
|
||||
if 'mask' not in request.files:
|
||||
return error_reply('mask image is required')
|
||||
if "mask" not in request.files:
|
||||
return error_reply("mask image is required")
|
||||
|
||||
source_file = request.files.get('source')
|
||||
source_image = Image.open(BytesIO(source_file.read())).convert('RGB')
|
||||
source_file = request.files.get("source")
|
||||
source_image = Image.open(BytesIO(source_file.read())).convert("RGB")
|
||||
|
||||
mask_file = request.files.get('mask')
|
||||
mask_image = Image.open(BytesIO(mask_file.read())).convert('RGB')
|
||||
mask_file = request.files.get("mask")
|
||||
mask_image = Image.open(BytesIO(mask_file.read())).convert("RGB")
|
||||
|
||||
device, params, size = pipeline_from_request()
|
||||
expand = border_from_request()
|
||||
upscale = upscale_from_request()
|
||||
|
||||
fill_color = get_not_empty(request.args, 'fillColor', 'white')
|
||||
mask_filter = get_from_map(request.args, 'filter', mask_filters, 'none')
|
||||
noise_source = get_from_map(
|
||||
request.args, 'noise', noise_sources, 'histogram')
|
||||
fill_color = get_not_empty(request.args, "fillColor", "white")
|
||||
mask_filter = get_from_map(request.args, "filter", mask_filters, "none")
|
||||
noise_source = get_from_map(request.args, "noise", noise_sources, "histogram")
|
||||
strength = get_and_clamp_float(
|
||||
request.args,
|
||||
'strength',
|
||||
get_config_value('strength'),
|
||||
get_config_value('strength', 'max'),
|
||||
get_config_value('strength', 'min'))
|
||||
"strength",
|
||||
get_config_value("strength"),
|
||||
get_config_value("strength", "max"),
|
||||
get_config_value("strength", "min"),
|
||||
)
|
||||
|
||||
output = make_output_name(
|
||||
context,
|
||||
'inpaint',
|
||||
"inpaint",
|
||||
params,
|
||||
size,
|
||||
extras=(
|
||||
|
@ -539,7 +573,7 @@ def inpaint():
|
|||
noise_source.__name__,
|
||||
strength,
|
||||
fill_color,
|
||||
)
|
||||
),
|
||||
)
|
||||
logger.info("inpaint job queued for: %s", output)
|
||||
|
||||
|
@ -559,123 +593,131 @@ def inpaint():
|
|||
noise_source,
|
||||
mask_filter,
|
||||
strength,
|
||||
fill_color)
|
||||
fill_color,
|
||||
)
|
||||
|
||||
return jsonify(json_params(output, params, size, upscale=upscale, border=expand))
|
||||
|
||||
|
||||
@app.route('/api/upscale', methods=['POST'])
|
||||
@app.route("/api/upscale", methods=["POST"])
|
||||
def upscale():
|
||||
if 'source' not in request.files:
|
||||
return error_reply('source image is required')
|
||||
if "source" not in request.files:
|
||||
return error_reply("source image is required")
|
||||
|
||||
source_file = request.files.get('source')
|
||||
source_image = Image.open(BytesIO(source_file.read())).convert('RGB')
|
||||
source_file = request.files.get("source")
|
||||
source_image = Image.open(BytesIO(source_file.read())).convert("RGB")
|
||||
|
||||
device, params, size = pipeline_from_request()
|
||||
upscale = upscale_from_request()
|
||||
|
||||
output = make_output_name(
|
||||
context,
|
||||
'upscale',
|
||||
params,
|
||||
size)
|
||||
output = make_output_name(context, "upscale", params, size)
|
||||
logger.info("upscale job queued for: %s", output)
|
||||
|
||||
source_image.thumbnail((size.width, size.height))
|
||||
executor.submit(output, run_upscale_pipeline,
|
||||
context, params, size, output, upscale, source_image)
|
||||
executor.submit(
|
||||
output,
|
||||
run_upscale_pipeline,
|
||||
context,
|
||||
params,
|
||||
size,
|
||||
output,
|
||||
upscale,
|
||||
source_image,
|
||||
)
|
||||
|
||||
return jsonify(json_params(output, params, size, upscale=upscale))
|
||||
|
||||
|
||||
@app.route('/api/chain', methods=['POST'])
|
||||
@app.route("/api/chain", methods=["POST"])
|
||||
def chain():
|
||||
logger.debug('chain pipeline request: %s, %s',
|
||||
request.form.keys(), request.files.keys())
|
||||
body = request.form.get('chain') or request.files.get('chain')
|
||||
logger.debug(
|
||||
"chain pipeline request: %s, %s", request.form.keys(), request.files.keys()
|
||||
)
|
||||
body = request.form.get("chain") or request.files.get("chain")
|
||||
if body is None:
|
||||
return error_reply('chain pipeline must have a body')
|
||||
return error_reply("chain pipeline must have a body")
|
||||
|
||||
data = yaml.safe_load(body)
|
||||
with open('./schema.yaml', 'r') as f:
|
||||
with open("./schema.yaml", "r") as f:
|
||||
schema = yaml.safe_load(f.read())
|
||||
|
||||
logger.info('validating chain request: %s against %s', data, schema)
|
||||
logger.info("validating chain request: %s against %s", data, schema)
|
||||
validate(data, schema)
|
||||
|
||||
# get defaults from the regular parameters
|
||||
device, params, size = pipeline_from_request()
|
||||
output = make_output_name(
|
||||
context,
|
||||
'chain',
|
||||
params,
|
||||
size)
|
||||
output = make_output_name(context, "chain", params, size)
|
||||
|
||||
pipeline = ChainPipeline()
|
||||
for stage_data in data.get('stages', []):
|
||||
callback = chain_stages[stage_data.get('type')]
|
||||
kwargs = stage_data.get('params', {})
|
||||
logger.info('request stage: %s, %s', callback.__name__, kwargs)
|
||||
for stage_data in data.get("stages", []):
|
||||
callback = chain_stages[stage_data.get("type")]
|
||||
kwargs = stage_data.get("params", {})
|
||||
logger.info("request stage: %s, %s", callback.__name__, kwargs)
|
||||
|
||||
stage = StageParams(
|
||||
stage_data.get('name', callback.__name__),
|
||||
tile_size=get_size(kwargs.get('tile_size')),
|
||||
outscale=get_and_clamp_int(kwargs, 'outscale', 1, 4),
|
||||
stage_data.get("name", callback.__name__),
|
||||
tile_size=get_size(kwargs.get("tile_size")),
|
||||
outscale=get_and_clamp_int(kwargs, "outscale", 1, 4),
|
||||
)
|
||||
|
||||
if 'border' in kwargs:
|
||||
border = Border.even(int(kwargs.get('border')))
|
||||
kwargs['border'] = border
|
||||
if "border" in kwargs:
|
||||
border = Border.even(int(kwargs.get("border")))
|
||||
kwargs["border"] = border
|
||||
|
||||
if 'upscale' in kwargs:
|
||||
upscale = UpscaleParams(kwargs.get('upscale'))
|
||||
kwargs['upscale'] = upscale
|
||||
if "upscale" in kwargs:
|
||||
upscale = UpscaleParams(kwargs.get("upscale"))
|
||||
kwargs["upscale"] = upscale
|
||||
|
||||
stage_source_name = 'source:%s' % (stage.name)
|
||||
stage_mask_name = 'mask:%s' % (stage.name)
|
||||
stage_source_name = "source:%s" % (stage.name)
|
||||
stage_mask_name = "mask:%s" % (stage.name)
|
||||
|
||||
if stage_source_name in request.files:
|
||||
logger.debug('loading source image %s for pipeline stage %s',
|
||||
stage_source_name, stage.name)
|
||||
logger.debug(
|
||||
"loading source image %s for pipeline stage %s",
|
||||
stage_source_name,
|
||||
stage.name,
|
||||
)
|
||||
source_file = request.files.get(stage_source_name)
|
||||
source_image = Image.open(
|
||||
BytesIO(source_file.read())).convert('RGB')
|
||||
source_image = Image.open(BytesIO(source_file.read())).convert("RGB")
|
||||
source_image = source_image.thumbnail((512, 512))
|
||||
kwargs['source_image'] = source_image
|
||||
kwargs["source_image"] = source_image
|
||||
|
||||
if stage_mask_name in request.files:
|
||||
logger.debug('loading mask image %s for pipeline stage %s',
|
||||
stage_mask_name, stage.name)
|
||||
logger.debug(
|
||||
"loading mask image %s for pipeline stage %s",
|
||||
stage_mask_name,
|
||||
stage.name,
|
||||
)
|
||||
mask_file = request.files.get(stage_mask_name)
|
||||
mask_image = Image.open(BytesIO(mask_file.read())).convert('RGB')
|
||||
mask_image = Image.open(BytesIO(mask_file.read())).convert("RGB")
|
||||
mask_image = mask_image.thumbnail((512, 512))
|
||||
kwargs['mask_image'] = mask_image
|
||||
kwargs["mask_image"] = mask_image
|
||||
|
||||
pipeline.append((callback, stage, kwargs))
|
||||
|
||||
logger.info('running chain pipeline with %s stages', len(pipeline.stages))
|
||||
logger.info("running chain pipeline with %s stages", len(pipeline.stages))
|
||||
|
||||
# build and run chain pipeline
|
||||
empty_source = Image.new('RGB', (size.width, size.height))
|
||||
executor.submit(output, pipeline, context,
|
||||
params, empty_source, output=output, size=size)
|
||||
empty_source = Image.new("RGB", (size.width, size.height))
|
||||
executor.submit(
|
||||
output, pipeline, context, params, empty_source, output=output, size=size
|
||||
)
|
||||
|
||||
return jsonify(json_params(output, params, size))
|
||||
|
||||
|
||||
@app.route('/api/cancel', methods=['PUT'])
|
||||
@app.route("/api/cancel", methods=["PUT"])
|
||||
def cancel():
|
||||
output_file = request.args.get('output', None)
|
||||
output_file = request.args.get("output", None)
|
||||
|
||||
cancel = executor.cancel(output_file)
|
||||
|
||||
return ready_reply(cancel)
|
||||
|
||||
|
||||
@app.route('/api/ready')
|
||||
@app.route("/api/ready")
|
||||
def ready():
|
||||
output_file = request.args.get('output', None)
|
||||
output_file = request.args.get("output", None)
|
||||
|
||||
done, progress = executor.done(output_file)
|
||||
|
||||
|
@ -687,11 +729,13 @@ def ready():
|
|||
return ready_reply(done, progress=progress)
|
||||
|
||||
|
||||
@app.route('/api/status')
|
||||
@app.route("/api/status")
|
||||
def status():
|
||||
return jsonify(executor.status())
|
||||
|
||||
|
||||
@app.route('/output/<path:filename>')
|
||||
@app.route("/output/<path:filename>")
|
||||
def output(filename: str):
|
||||
return send_from_directory(path.join('..', context.output_path), filename, as_attachment=False)
|
||||
return send_from_directory(
|
||||
path.join("..", context.output_path), filename, as_attachment=False
|
||||
)
|
||||
|
|
|
@ -1,24 +1,16 @@
|
|||
from logging import getLogger
|
||||
|
||||
from PIL import Image
|
||||
|
||||
from .chain import (
|
||||
correct_gfpgan,
|
||||
upscale_stable_diffusion,
|
||||
upscale_resrgan,
|
||||
ChainPipeline,
|
||||
correct_gfpgan,
|
||||
upscale_resrgan,
|
||||
upscale_stable_diffusion,
|
||||
)
|
||||
from .device_pool import (
|
||||
JobContext,
|
||||
)
|
||||
from .params import (
|
||||
ImageParams,
|
||||
SizeChart,
|
||||
StageParams,
|
||||
UpscaleParams,
|
||||
)
|
||||
from .utils import (
|
||||
ServerContext,
|
||||
)
|
||||
from .device_pool import JobContext
|
||||
from .params import ImageParams, SizeChart, StageParams, UpscaleParams
|
||||
from .utils import ServerContext
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
|
@ -32,27 +24,25 @@ def run_upscale_correction(
|
|||
*,
|
||||
upscale: UpscaleParams,
|
||||
) -> Image.Image:
|
||||
'''
|
||||
"""
|
||||
This is a convenience method for a chain pipeline that will run upscaling and
|
||||
correction, based on the `upscale` params.
|
||||
'''
|
||||
logger.info('running upscaling and correction pipeline')
|
||||
"""
|
||||
logger.info("running upscaling and correction pipeline")
|
||||
|
||||
chain = ChainPipeline()
|
||||
|
||||
if upscale.scale > 1:
|
||||
if 'esrgan' in upscale.upscale_model:
|
||||
stage = StageParams(tile_size=stage.tile_size,
|
||||
outscale=upscale.outscale)
|
||||
if "esrgan" in upscale.upscale_model:
|
||||
stage = StageParams(tile_size=stage.tile_size, outscale=upscale.outscale)
|
||||
chain.append((upscale_resrgan, stage, None))
|
||||
elif 'stable-diffusion' in upscale.upscale_model:
|
||||
elif "stable-diffusion" in upscale.upscale_model:
|
||||
mini_tile = min(SizeChart.mini, stage.tile_size)
|
||||
stage = StageParams(tile_size=mini_tile, outscale=upscale.outscale)
|
||||
chain.append((upscale_stable_diffusion, stage, None))
|
||||
|
||||
if upscale.faces:
|
||||
stage = StageParams(tile_size=stage.tile_size,
|
||||
outscale=1)
|
||||
stage = StageParams(tile_size=stage.tile_size, outscale=1)
|
||||
chain.append((correct_gfpgan, stage, None))
|
||||
|
||||
return chain(job, server, params, image, prompt=params.prompt, upscale=upscale)
|
||||
|
|
|
@ -1,13 +1,11 @@
|
|||
import gc
|
||||
from logging import getLogger
|
||||
from os import environ, path
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
import gc
|
||||
import torch
|
||||
|
||||
from .params import (
|
||||
SizeChart,
|
||||
)
|
||||
from .params import SizeChart
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
|
@ -15,15 +13,15 @@ logger = getLogger(__name__)
|
|||
class ServerContext:
|
||||
def __init__(
|
||||
self,
|
||||
bundle_path: str = '.',
|
||||
model_path: str = '.',
|
||||
output_path: str = '.',
|
||||
params_path: str = '.',
|
||||
cors_origin: str = '*',
|
||||
bundle_path: str = ".",
|
||||
model_path: str = ".",
|
||||
output_path: str = ".",
|
||||
params_path: str = ".",
|
||||
cors_origin: str = "*",
|
||||
num_workers: int = 1,
|
||||
block_platforms: List[str] = [],
|
||||
default_platform: str = None,
|
||||
image_format: str = 'png',
|
||||
image_format: str = "png",
|
||||
) -> None:
|
||||
self.bundle_path = bundle_path
|
||||
self.model_path = model_path
|
||||
|
@ -38,40 +36,39 @@ class ServerContext:
|
|||
@classmethod
|
||||
def from_environ(cls):
|
||||
return ServerContext(
|
||||
bundle_path=environ.get('ONNX_WEB_BUNDLE_PATH',
|
||||
path.join('..', 'gui', 'out')),
|
||||
model_path=environ.get('ONNX_WEB_MODEL_PATH',
|
||||
path.join('..', 'models')),
|
||||
output_path=environ.get(
|
||||
'ONNX_WEB_OUTPUT_PATH', path.join('..', 'outputs')),
|
||||
params_path=environ.get('ONNX_WEB_PARAMS_PATH', '.'),
|
||||
# others
|
||||
cors_origin=environ.get('ONNX_WEB_CORS_ORIGIN', '*').split(','),
|
||||
num_workers=int(environ.get('ONNX_WEB_NUM_WORKERS', 1)),
|
||||
block_platforms=environ.get(
|
||||
'ONNX_WEB_BLOCK_PLATFORMS', '').split(','),
|
||||
default_platform=environ.get(
|
||||
'ONNX_WEB_DEFAULT_PLATFORM', None),
|
||||
image_format=environ.get(
|
||||
'ONNX_WEB_IMAGE_FORMAT', 'png'
|
||||
bundle_path=environ.get(
|
||||
"ONNX_WEB_BUNDLE_PATH", path.join("..", "gui", "out")
|
||||
),
|
||||
model_path=environ.get("ONNX_WEB_MODEL_PATH", path.join("..", "models")),
|
||||
output_path=environ.get("ONNX_WEB_OUTPUT_PATH", path.join("..", "outputs")),
|
||||
params_path=environ.get("ONNX_WEB_PARAMS_PATH", "."),
|
||||
# others
|
||||
cors_origin=environ.get("ONNX_WEB_CORS_ORIGIN", "*").split(","),
|
||||
num_workers=int(environ.get("ONNX_WEB_NUM_WORKERS", 1)),
|
||||
block_platforms=environ.get("ONNX_WEB_BLOCK_PLATFORMS", "").split(","),
|
||||
default_platform=environ.get("ONNX_WEB_DEFAULT_PLATFORM", None),
|
||||
image_format=environ.get("ONNX_WEB_IMAGE_FORMAT", "png"),
|
||||
)
|
||||
|
||||
|
||||
def base_join(base: str, tail: str) -> str:
|
||||
tail_path = path.relpath(path.normpath(path.join('/', tail)), '/')
|
||||
tail_path = path.relpath(path.normpath(path.join("/", tail)), "/")
|
||||
return path.join(base, tail_path)
|
||||
|
||||
|
||||
def is_debug() -> bool:
|
||||
return environ.get('DEBUG') is not None
|
||||
return environ.get("DEBUG") is not None
|
||||
|
||||
|
||||
def get_and_clamp_float(args: Any, key: str, default_value: float, max_value: float, min_value=0.0) -> float:
|
||||
def get_and_clamp_float(
|
||||
args: Any, key: str, default_value: float, max_value: float, min_value=0.0
|
||||
) -> float:
|
||||
return min(max(float(args.get(key, default_value)), min_value), max_value)
|
||||
|
||||
|
||||
def get_and_clamp_int(args: Any, key: str, default_value: int, max_value: int, min_value=1) -> int:
|
||||
def get_and_clamp_int(
|
||||
args: Any, key: str, default_value: int, max_value: int, min_value=1
|
||||
) -> int:
|
||||
return min(max(int(args.get(key, default_value)), min_value), max_value)
|
||||
|
||||
|
||||
|
@ -80,7 +77,7 @@ def get_from_list(args: Any, key: str, values: List[Any]) -> Optional[Any]:
|
|||
if selected in values:
|
||||
return selected
|
||||
|
||||
logger.warn('invalid selection: %s', selected)
|
||||
logger.warn("invalid selection: %s", selected)
|
||||
if len(values) > 0:
|
||||
return values[0]
|
||||
|
||||
|
@ -118,10 +115,10 @@ def get_size(val: Union[int, str, None]) -> SizeChart:
|
|||
|
||||
return int(val)
|
||||
|
||||
raise Exception('invalid size')
|
||||
raise Exception("invalid size")
|
||||
|
||||
|
||||
def run_gc():
|
||||
logger.debug('running garbage collection')
|
||||
logger.debug("running garbage collection")
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
|
|
@ -0,0 +1,2 @@
|
|||
[tool.isort]
|
||||
profile = "black"
|
Loading…
Reference in New Issue