lint(api): apply black and isort style
This commit is contained in:
parent
c0e5f435ee
commit
54dd34d211
13
api/Makefile
13
api/Makefile
|
@ -23,3 +23,16 @@ package-dist:
|
||||||
|
|
||||||
package-upload:
|
package-upload:
|
||||||
twine upload dist/*
|
twine upload dist/*
|
||||||
|
|
||||||
|
lint-check:
|
||||||
|
black --check --preview onnx_web
|
||||||
|
isort --check-only --skip __init__.py --filter-files onnx_web
|
||||||
|
flake8 --per-file-ignores="__init__.py:F401" onnx_web
|
||||||
|
|
||||||
|
lint-fix:
|
||||||
|
black onnx_web
|
||||||
|
isort --skip __init__.py --filter-files onnx_web
|
||||||
|
flake8 --per-file-ignores="__init__.py:F401" onnx_web
|
||||||
|
|
||||||
|
typecheck:
|
||||||
|
mypy -m onnx_web.serve
|
||||||
|
|
|
@ -1,3 +1,6 @@
|
||||||
|
black
|
||||||
|
flake8
|
||||||
|
isort
|
||||||
mypy
|
mypy
|
||||||
|
|
||||||
types-Flask-Cors
|
types-Flask-Cors
|
||||||
|
|
|
@ -1,49 +1,31 @@
|
||||||
from . import logging
|
from . import logging
|
||||||
|
from .chain import correct_gfpgan, upscale_resrgan, upscale_stable_diffusion
|
||||||
from .chain import (
|
from .diffusion.load import get_latents_from_seed, load_pipeline
|
||||||
correct_gfpgan,
|
|
||||||
upscale_resrgan,
|
|
||||||
upscale_stable_diffusion,
|
|
||||||
)
|
|
||||||
from .diffusion.load import (
|
|
||||||
get_latents_from_seed,
|
|
||||||
load_pipeline,
|
|
||||||
)
|
|
||||||
from .diffusion.run import (
|
from .diffusion.run import (
|
||||||
run_img2img_pipeline,
|
run_img2img_pipeline,
|
||||||
run_inpaint_pipeline,
|
run_inpaint_pipeline,
|
||||||
run_txt2img_pipeline,
|
run_txt2img_pipeline,
|
||||||
)
|
)
|
||||||
from .image import (
|
from .image import (
|
||||||
expand_image,
|
expand_image,
|
||||||
mask_filter_gaussian_multiply,
|
mask_filter_gaussian_multiply,
|
||||||
mask_filter_gaussian_screen,
|
mask_filter_gaussian_screen,
|
||||||
mask_filter_none,
|
mask_filter_none,
|
||||||
noise_source_fill_edge,
|
noise_source_fill_edge,
|
||||||
noise_source_fill_mask,
|
noise_source_fill_mask,
|
||||||
noise_source_gaussian,
|
noise_source_gaussian,
|
||||||
noise_source_histogram,
|
noise_source_histogram,
|
||||||
noise_source_normal,
|
noise_source_normal,
|
||||||
noise_source_uniform,
|
noise_source_uniform,
|
||||||
)
|
|
||||||
from .params import (
|
|
||||||
Param,
|
|
||||||
Point,
|
|
||||||
Border,
|
|
||||||
Size,
|
|
||||||
ImageParams,
|
|
||||||
StageParams,
|
|
||||||
UpscaleParams,
|
|
||||||
)
|
|
||||||
from .upscale import (
|
|
||||||
run_upscale_correction,
|
|
||||||
)
|
)
|
||||||
|
from .params import Border, ImageParams, Param, Point, Size, StageParams, UpscaleParams
|
||||||
|
from .upscale import run_upscale_correction
|
||||||
from .utils import (
|
from .utils import (
|
||||||
get_and_clamp_float,
|
ServerContext,
|
||||||
get_and_clamp_int,
|
base_join,
|
||||||
get_from_list,
|
get_and_clamp_float,
|
||||||
get_from_map,
|
get_and_clamp_int,
|
||||||
get_not_empty,
|
get_from_list,
|
||||||
base_join,
|
get_from_map,
|
||||||
ServerContext,
|
get_not_empty,
|
||||||
)
|
)
|
||||||
|
|
|
@ -1,42 +1,13 @@
|
||||||
from .base import (
|
from .base import ChainPipeline, PipelineStage, StageCallback, StageParams
|
||||||
ChainPipeline,
|
from .blend_img2img import blend_img2img
|
||||||
PipelineStage,
|
from .blend_inpaint import blend_inpaint
|
||||||
StageCallback,
|
from .correct_gfpgan import correct_gfpgan
|
||||||
StageParams,
|
from .persist_disk import persist_disk
|
||||||
)
|
from .persist_s3 import persist_s3
|
||||||
from .blend_img2img import (
|
from .reduce_crop import reduce_crop
|
||||||
blend_img2img,
|
from .reduce_thumbnail import reduce_thumbnail
|
||||||
)
|
from .source_noise import source_noise
|
||||||
from .blend_inpaint import (
|
from .source_txt2img import source_txt2img
|
||||||
blend_inpaint,
|
from .upscale_outpaint import upscale_outpaint
|
||||||
)
|
from .upscale_resrgan import upscale_resrgan
|
||||||
from .correct_gfpgan import (
|
from .upscale_stable_diffusion import upscale_stable_diffusion
|
||||||
correct_gfpgan,
|
|
||||||
)
|
|
||||||
from .persist_disk import (
|
|
||||||
persist_disk,
|
|
||||||
)
|
|
||||||
from .persist_s3 import (
|
|
||||||
persist_s3,
|
|
||||||
)
|
|
||||||
from .reduce_crop import (
|
|
||||||
reduce_crop,
|
|
||||||
)
|
|
||||||
from .reduce_thumbnail import (
|
|
||||||
reduce_thumbnail,
|
|
||||||
)
|
|
||||||
from .source_noise import (
|
|
||||||
source_noise,
|
|
||||||
)
|
|
||||||
from .source_txt2img import (
|
|
||||||
source_txt2img,
|
|
||||||
)
|
|
||||||
from .upscale_outpaint import (
|
|
||||||
upscale_outpaint,
|
|
||||||
)
|
|
||||||
from .upscale_resrgan import (
|
|
||||||
upscale_resrgan,
|
|
||||||
)
|
|
||||||
from .upscale_stable_diffusion import (
|
|
||||||
upscale_stable_diffusion,
|
|
||||||
)
|
|
||||||
|
|
|
@ -1,26 +1,15 @@
|
||||||
from datetime import timedelta
|
from datetime import timedelta
|
||||||
from logging import getLogger
|
from logging import getLogger
|
||||||
from PIL import Image
|
|
||||||
from time import monotonic
|
from time import monotonic
|
||||||
from typing import Any, List, Optional, Protocol, Tuple
|
from typing import Any, List, Optional, Protocol, Tuple
|
||||||
|
|
||||||
from ..device_pool import (
|
from PIL import Image
|
||||||
JobContext,
|
|
||||||
)
|
from ..device_pool import JobContext
|
||||||
from ..params import (
|
from ..output import save_image
|
||||||
ImageParams,
|
from ..params import ImageParams, StageParams
|
||||||
StageParams,
|
from ..utils import ServerContext, is_debug
|
||||||
)
|
from .utils import process_tile_grid
|
||||||
from ..output import (
|
|
||||||
save_image,
|
|
||||||
)
|
|
||||||
from ..utils import (
|
|
||||||
is_debug,
|
|
||||||
ServerContext,
|
|
||||||
)
|
|
||||||
from .utils import (
|
|
||||||
process_tile_grid,
|
|
||||||
)
|
|
||||||
|
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
@ -42,33 +31,43 @@ PipelineStage = Tuple[StageCallback, StageParams, Optional[dict]]
|
||||||
|
|
||||||
|
|
||||||
class ChainPipeline:
|
class ChainPipeline:
|
||||||
'''
|
"""
|
||||||
Run many stages in series, passing the image results from each to the next, and processing
|
Run many stages in series, passing the image results from each to the next, and processing
|
||||||
tiles as needed.
|
tiles as needed.
|
||||||
'''
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
stages: List[PipelineStage] = [],
|
stages: List[PipelineStage] = [],
|
||||||
):
|
):
|
||||||
'''
|
"""
|
||||||
Create a new pipeline that will run the given stages.
|
Create a new pipeline that will run the given stages.
|
||||||
'''
|
"""
|
||||||
self.stages = list(stages)
|
self.stages = list(stages)
|
||||||
|
|
||||||
def append(self, stage: PipelineStage):
|
def append(self, stage: PipelineStage):
|
||||||
'''
|
"""
|
||||||
Append an additional stage to this pipeline.
|
Append an additional stage to this pipeline.
|
||||||
'''
|
"""
|
||||||
self.stages.append(stage)
|
self.stages.append(stage)
|
||||||
|
|
||||||
def __call__(self, job: JobContext, server: ServerContext, params: ImageParams, source: Image.Image, **pipeline_kwargs) -> Image.Image:
|
def __call__(
|
||||||
'''
|
self,
|
||||||
|
job: JobContext,
|
||||||
|
server: ServerContext,
|
||||||
|
params: ImageParams,
|
||||||
|
source: Image.Image,
|
||||||
|
**pipeline_kwargs
|
||||||
|
) -> Image.Image:
|
||||||
|
"""
|
||||||
TODO: handle List[Image] outputs
|
TODO: handle List[Image] outputs
|
||||||
'''
|
"""
|
||||||
start = monotonic()
|
start = monotonic()
|
||||||
logger.info('running pipeline on source image with dimensions %sx%s',
|
logger.info(
|
||||||
source.width, source.height)
|
"running pipeline on source image with dimensions %sx%s",
|
||||||
|
source.width,
|
||||||
|
source.height,
|
||||||
|
)
|
||||||
image = source
|
image = source
|
||||||
|
|
||||||
for stage_pipe, stage_params, stage_kwargs in self.stages:
|
for stage_pipe, stage_params, stage_kwargs in self.stages:
|
||||||
|
@ -76,37 +75,51 @@ class ChainPipeline:
|
||||||
kwargs = stage_kwargs or {}
|
kwargs = stage_kwargs or {}
|
||||||
kwargs = {**pipeline_kwargs, **kwargs}
|
kwargs = {**pipeline_kwargs, **kwargs}
|
||||||
|
|
||||||
logger.info('running stage %s on image with dimensions %sx%s, %s',
|
logger.info(
|
||||||
name, image.width, image.height, kwargs.keys())
|
"running stage %s on image with dimensions %sx%s, %s",
|
||||||
|
name,
|
||||||
|
image.width,
|
||||||
|
image.height,
|
||||||
|
kwargs.keys(),
|
||||||
|
)
|
||||||
|
|
||||||
if image.width > stage_params.tile_size or image.height > stage_params.tile_size:
|
if (
|
||||||
logger.info('image larger than tile size of %s, tiling stage',
|
image.width > stage_params.tile_size
|
||||||
stage_params.tile_size)
|
or image.height > stage_params.tile_size
|
||||||
|
):
|
||||||
|
logger.info(
|
||||||
|
"image larger than tile size of %s, tiling stage",
|
||||||
|
stage_params.tile_size,
|
||||||
|
)
|
||||||
|
|
||||||
def stage_tile(tile: Image.Image, _dims) -> Image.Image:
|
def stage_tile(tile: Image.Image, _dims) -> Image.Image:
|
||||||
tile = stage_pipe(job, server, stage_params, params, tile,
|
tile = stage_pipe(job, server, stage_params, params, tile, **kwargs)
|
||||||
**kwargs)
|
|
||||||
|
|
||||||
if is_debug():
|
if is_debug():
|
||||||
save_image(server, 'last-tile.png', tile)
|
save_image(server, "last-tile.png", tile)
|
||||||
|
|
||||||
return tile
|
return tile
|
||||||
|
|
||||||
image = process_tile_grid(
|
image = process_tile_grid(
|
||||||
image, stage_params.tile_size, stage_params.outscale, [stage_tile])
|
image, stage_params.tile_size, stage_params.outscale, [stage_tile]
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
logger.info('image within tile size, running stage')
|
logger.info("image within tile size, running stage")
|
||||||
image = stage_pipe(job, server, stage_params, params, image,
|
image = stage_pipe(job, server, stage_params, params, image, **kwargs)
|
||||||
**kwargs)
|
|
||||||
|
|
||||||
logger.info('finished stage %s, result size: %sx%s',
|
logger.info(
|
||||||
name, image.width, image.height)
|
"finished stage %s, result size: %sx%s", name, image.width, image.height
|
||||||
|
)
|
||||||
|
|
||||||
if is_debug():
|
if is_debug():
|
||||||
save_image(server, 'last-stage.png', image)
|
save_image(server, "last-stage.png", image)
|
||||||
|
|
||||||
end = monotonic()
|
end = monotonic()
|
||||||
duration = timedelta(seconds=(end - start))
|
duration = timedelta(seconds=(end - start))
|
||||||
logger.info('finished pipeline in %s, result size: %sx%s',
|
logger.info(
|
||||||
duration, image.width, image.height)
|
"finished pipeline in %s, result size: %sx%s",
|
||||||
|
duration,
|
||||||
|
image.width,
|
||||||
|
image.height,
|
||||||
|
)
|
||||||
return image
|
return image
|
||||||
|
|
|
@ -1,24 +1,13 @@
|
||||||
from diffusers import (
|
|
||||||
OnnxStableDiffusionImg2ImgPipeline,
|
|
||||||
)
|
|
||||||
from logging import getLogger
|
from logging import getLogger
|
||||||
from PIL import Image
|
|
||||||
|
|
||||||
from ..device_pool import (
|
|
||||||
JobContext,
|
|
||||||
)
|
|
||||||
from ..diffusion.load import (
|
|
||||||
load_pipeline,
|
|
||||||
)
|
|
||||||
from ..params import (
|
|
||||||
ImageParams,
|
|
||||||
StageParams,
|
|
||||||
)
|
|
||||||
from ..utils import (
|
|
||||||
ServerContext,
|
|
||||||
)
|
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
from diffusers import OnnxStableDiffusionImg2ImgPipeline
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
from ..device_pool import JobContext
|
||||||
|
from ..diffusion.load import load_pipeline
|
||||||
|
from ..params import ImageParams, StageParams
|
||||||
|
from ..utils import ServerContext
|
||||||
|
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
@ -35,10 +24,14 @@ def blend_img2img(
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> Image.Image:
|
) -> Image.Image:
|
||||||
prompt = prompt or params.prompt
|
prompt = prompt or params.prompt
|
||||||
logger.info('generating image using img2img, %s steps: %s', params.steps, prompt)
|
logger.info("generating image using img2img, %s steps: %s", params.steps, prompt)
|
||||||
|
|
||||||
pipe = load_pipeline(OnnxStableDiffusionImg2ImgPipeline,
|
pipe = load_pipeline(
|
||||||
params.model, params.scheduler, job.get_device())
|
OnnxStableDiffusionImg2ImgPipeline,
|
||||||
|
params.model,
|
||||||
|
params.scheduler,
|
||||||
|
job.get_device(),
|
||||||
|
)
|
||||||
|
|
||||||
rng = np.random.RandomState(params.seed)
|
rng = np.random.RandomState(params.seed)
|
||||||
|
|
||||||
|
@ -53,6 +46,5 @@ def blend_img2img(
|
||||||
)
|
)
|
||||||
output = result.images[0]
|
output = result.images[0]
|
||||||
|
|
||||||
logger.info('final output image size: %sx%s', output.width, output.height)
|
logger.info("final output image size: %sx%s", output.width, output.height)
|
||||||
return output
|
return output
|
||||||
|
|
||||||
|
|
|
@ -1,41 +1,17 @@
|
||||||
from diffusers import (
|
|
||||||
OnnxStableDiffusionInpaintPipeline,
|
|
||||||
)
|
|
||||||
from logging import getLogger
|
from logging import getLogger
|
||||||
from PIL import Image
|
|
||||||
from typing import Callable, Tuple
|
from typing import Callable, Tuple
|
||||||
|
|
||||||
from ..device_pool import (
|
|
||||||
JobContext,
|
|
||||||
)
|
|
||||||
from ..diffusion.load import (
|
|
||||||
get_latents_from_seed,
|
|
||||||
load_pipeline,
|
|
||||||
)
|
|
||||||
from ..image import (
|
|
||||||
expand_image,
|
|
||||||
mask_filter_none,
|
|
||||||
noise_source_histogram,
|
|
||||||
)
|
|
||||||
from ..params import (
|
|
||||||
Border,
|
|
||||||
ImageParams,
|
|
||||||
Size,
|
|
||||||
SizeChart,
|
|
||||||
StageParams,
|
|
||||||
)
|
|
||||||
from ..output import (
|
|
||||||
save_image,
|
|
||||||
)
|
|
||||||
from ..utils import (
|
|
||||||
is_debug,
|
|
||||||
ServerContext,
|
|
||||||
)
|
|
||||||
from .utils import (
|
|
||||||
process_tile_grid,
|
|
||||||
)
|
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
from diffusers import OnnxStableDiffusionInpaintPipeline
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
from ..device_pool import JobContext
|
||||||
|
from ..diffusion.load import get_latents_from_seed, load_pipeline
|
||||||
|
from ..image import expand_image, mask_filter_none, noise_source_histogram
|
||||||
|
from ..output import save_image
|
||||||
|
from ..params import Border, ImageParams, Size, SizeChart, StageParams
|
||||||
|
from ..utils import ServerContext, is_debug
|
||||||
|
from .utils import process_tile_grid
|
||||||
|
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
@ -49,16 +25,16 @@ def blend_inpaint(
|
||||||
*,
|
*,
|
||||||
expand: Border,
|
expand: Border,
|
||||||
mask_image: Image.Image = None,
|
mask_image: Image.Image = None,
|
||||||
fill_color: str = 'white',
|
fill_color: str = "white",
|
||||||
mask_filter: Callable = mask_filter_none,
|
mask_filter: Callable = mask_filter_none,
|
||||||
noise_source: Callable = noise_source_histogram,
|
noise_source: Callable = noise_source_histogram,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> Image.Image:
|
) -> Image.Image:
|
||||||
logger.info('upscaling image by expanding borders', expand)
|
logger.info("upscaling image by expanding borders", expand)
|
||||||
|
|
||||||
if mask_image is None:
|
if mask_image is None:
|
||||||
# if no mask was provided, keep the full source image
|
# if no mask was provided, keep the full source image
|
||||||
mask_image = Image.new('RGB', source_image.size, 'black')
|
mask_image = Image.new("RGB", source_image.size, "black")
|
||||||
|
|
||||||
source_image, mask_image, noise_image, _full_dims = expand_image(
|
source_image, mask_image, noise_image, _full_dims = expand_image(
|
||||||
source_image,
|
source_image,
|
||||||
|
@ -66,12 +42,13 @@ def blend_inpaint(
|
||||||
expand,
|
expand,
|
||||||
fill=fill_color,
|
fill=fill_color,
|
||||||
noise_source=noise_source,
|
noise_source=noise_source,
|
||||||
mask_filter=mask_filter)
|
mask_filter=mask_filter,
|
||||||
|
)
|
||||||
|
|
||||||
if is_debug():
|
if is_debug():
|
||||||
save_image(server, 'last-source.png', source_image)
|
save_image(server, "last-source.png", source_image)
|
||||||
save_image(server, 'last-mask.png', mask_image)
|
save_image(server, "last-mask.png", mask_image)
|
||||||
save_image(server, 'last-noise.png', noise_image)
|
save_image(server, "last-noise.png", noise_image)
|
||||||
|
|
||||||
def outpaint(image: Image.Image, dims: Tuple[int, int, int]):
|
def outpaint(image: Image.Image, dims: Tuple[int, int, int]):
|
||||||
left, top, tile = dims
|
left, top, tile = dims
|
||||||
|
@ -79,11 +56,15 @@ def blend_inpaint(
|
||||||
mask = mask_image.crop((left, top, left + tile, top + tile))
|
mask = mask_image.crop((left, top, left + tile, top + tile))
|
||||||
|
|
||||||
if is_debug():
|
if is_debug():
|
||||||
save_image(server, 'tile-source.png', image)
|
save_image(server, "tile-source.png", image)
|
||||||
save_image(server, 'tile-mask.png', mask)
|
save_image(server, "tile-mask.png", mask)
|
||||||
|
|
||||||
pipe = load_pipeline(OnnxStableDiffusionInpaintPipeline,
|
pipe = load_pipeline(
|
||||||
params.model, params.scheduler, job.get_device())
|
OnnxStableDiffusionInpaintPipeline,
|
||||||
|
params.model,
|
||||||
|
params.scheduler,
|
||||||
|
job.get_device(),
|
||||||
|
)
|
||||||
|
|
||||||
latents = get_latents_from_seed(params.seed, size)
|
latents = get_latents_from_seed(params.seed, size)
|
||||||
rng = np.random.RandomState(params.seed)
|
rng = np.random.RandomState(params.seed)
|
||||||
|
@ -104,5 +85,5 @@ def blend_inpaint(
|
||||||
|
|
||||||
output = process_tile_grid(source_image, SizeChart.auto, 1, [outpaint])
|
output = process_tile_grid(source_image, SizeChart.auto, 1, [outpaint])
|
||||||
|
|
||||||
logger.info('final output image size', output.size)
|
logger.info("final output image size", output.size)
|
||||||
return output
|
return output
|
||||||
|
|
|
@ -1,30 +1,53 @@
|
||||||
|
from logging import getLogger
|
||||||
|
|
||||||
|
import torch
|
||||||
from basicsr.utils import img2tensor, tensor2img
|
from basicsr.utils import img2tensor, tensor2img
|
||||||
from basicsr.utils.download_util import load_file_from_url
|
from basicsr.utils.download_util import load_file_from_url
|
||||||
from facexlib.utils.face_restoration_helper import FaceRestoreHelper
|
from facexlib.utils.face_restoration_helper import FaceRestoreHelper
|
||||||
from logging import getLogger
|
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from torchvision.transforms.functional import normalize
|
from torchvision.transforms.functional import normalize
|
||||||
|
|
||||||
import torch
|
from ..device_pool import JobContext
|
||||||
|
from ..params import ImageParams, StageParams
|
||||||
|
from ..utils import ServerContext
|
||||||
|
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
pretrain_model_url = {
|
pretrain_model_url = (
|
||||||
'restoration': 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth',
|
"https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth"
|
||||||
}
|
)
|
||||||
|
|
||||||
device = 'cpu'
|
device = "cpu"
|
||||||
upscale = 2
|
upscale = 2
|
||||||
|
|
||||||
def correct_codeformer(image: Image.Image) -> Image.Image:
|
|
||||||
|
def correct_codeformer(
|
||||||
|
job: JobContext,
|
||||||
|
server: ServerContext,
|
||||||
|
stage: StageParams,
|
||||||
|
params: ImageParams,
|
||||||
|
source_image: Image.Image,
|
||||||
|
**kwargs,
|
||||||
|
) -> Image.Image:
|
||||||
|
model = "TODO"
|
||||||
|
|
||||||
# ------------------ set up CodeFormer restorer -------------------
|
# ------------------ set up CodeFormer restorer -------------------
|
||||||
net = ARCH_REGISTRY.get('CodeFormer')(dim_embd=512, codebook_size=1024, n_head=8, n_layers=9,
|
net = ARCH_REGISTRY.get("CodeFormer")(
|
||||||
connect_list=['32', '64', '128', '256']).to(device)
|
dim_embd=512,
|
||||||
|
codebook_size=1024,
|
||||||
|
n_head=8,
|
||||||
|
n_layers=9,
|
||||||
|
connect_list=["32", "64", "128", "256"],
|
||||||
|
).to(device)
|
||||||
|
|
||||||
# ckpt_path = 'weights/CodeFormer/codeformer.pth'
|
# ckpt_path = 'weights/CodeFormer/codeformer.pth'
|
||||||
ckpt_path = load_file_from_url(url=pretrain_model_url['restoration'],
|
ckpt_path = load_file_from_url(
|
||||||
model_dir='weights/CodeFormer', progress=True, file_name=None)
|
url=pretrain_model_url,
|
||||||
checkpoint = torch.load(ckpt_path)['params_ema']
|
model_dir="weights/CodeFormer",
|
||||||
|
progress=True,
|
||||||
|
file_name=None,
|
||||||
|
)
|
||||||
|
checkpoint = torch.load(ckpt_path)["params_ema"]
|
||||||
net.load_state_dict(checkpoint)
|
net.load_state_dict(checkpoint)
|
||||||
net.eval()
|
net.eval()
|
||||||
|
|
||||||
|
@ -36,22 +59,24 @@ def correct_codeformer(image: Image.Image) -> Image.Image:
|
||||||
upscale,
|
upscale,
|
||||||
face_size=512,
|
face_size=512,
|
||||||
crop_ratio=(1, 1),
|
crop_ratio=(1, 1),
|
||||||
det_model = args.detection_model,
|
det_model=model,
|
||||||
save_ext='png',
|
save_ext="png",
|
||||||
use_parse=True,
|
use_parse=True,
|
||||||
device=device)
|
device=device,
|
||||||
|
)
|
||||||
|
|
||||||
# get face landmarks for each face
|
# get face landmarks for each face
|
||||||
num_det_faces = face_helper.get_face_landmarks_5(
|
num_det_faces = face_helper.get_face_landmarks_5(
|
||||||
only_center_face=args.only_center_face, resize=640, eye_dist_threshold=5)
|
only_center_face=False, resize=640, eye_dist_threshold=5
|
||||||
logger.info('detect %s faces', num_det_faces)
|
)
|
||||||
|
logger.info("detect %s faces", num_det_faces)
|
||||||
# align and warp each face
|
# align and warp each face
|
||||||
face_helper.align_warp_face()
|
face_helper.align_warp_face()
|
||||||
|
|
||||||
# face restoration for each cropped face
|
# face restoration for each cropped face
|
||||||
for idx, cropped_face in enumerate(face_helper.cropped_faces):
|
for idx, cropped_face in enumerate(face_helper.cropped_faces):
|
||||||
# prepare data
|
# prepare data
|
||||||
cropped_face_t = img2tensor(cropped_face / 255., bgr2rgb=True, float32=True)
|
cropped_face_t = img2tensor(cropped_face / 255.0, bgr2rgb=True, float32=True)
|
||||||
normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
|
normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
|
||||||
cropped_face_t = cropped_face_t.unsqueeze(0).to(device)
|
cropped_face_t = cropped_face_t.unsqueeze(0).to(device)
|
||||||
|
|
||||||
|
@ -62,10 +87,10 @@ def correct_codeformer(image: Image.Image) -> Image.Image:
|
||||||
del output
|
del output
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
except Exception as error:
|
except Exception as error:
|
||||||
logger.error('Failed inference for CodeFormer: %s', error)
|
logger.error("Failed inference for CodeFormer: %s", error)
|
||||||
restored_face = tensor2img(cropped_face_t, rgb2bgr=True, min_max=(-1, 1))
|
restored_face = tensor2img(cropped_face_t, rgb2bgr=True, min_max=(-1, 1))
|
||||||
|
|
||||||
restored_face = restored_face.astype('uint8')
|
restored_face = restored_face.astype("uint8")
|
||||||
face_helper.add_restored_face(restored_face, cropped_face)
|
face_helper.add_restored_face(restored_face, cropped_face)
|
||||||
|
|
||||||
# upsample the background
|
# upsample the background
|
||||||
|
@ -75,13 +100,16 @@ def correct_codeformer(image: Image.Image) -> Image.Image:
|
||||||
else:
|
else:
|
||||||
bg_img = None
|
bg_img = None
|
||||||
|
|
||||||
|
|
||||||
# paste_back
|
# paste_back
|
||||||
face_helper.get_inverse_affine(None)
|
face_helper.get_inverse_affine(None)
|
||||||
# paste each restored face to the input image
|
# paste each restored face to the input image
|
||||||
if face_upsampler is not None:
|
if face_upsampler is not None:
|
||||||
restored_img = face_helper.paste_faces_to_input_image(upsample_img=bg_img, draw_box=False, face_upsampler=face_upsampler)
|
restored_img = face_helper.paste_faces_to_input_image(
|
||||||
|
upsample_img=bg_img, draw_box=False, face_upsampler=face_upsampler
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
restored_img = face_helper.paste_faces_to_input_image(upsample_img=bg_img, draw_box=False)
|
restored_img = face_helper.paste_faces_to_input_image(
|
||||||
|
upsample_img=bg_img, draw_box=False
|
||||||
|
)
|
||||||
|
|
||||||
return restored_img
|
return restored_img
|
||||||
|
|
|
@ -1,27 +1,16 @@
|
||||||
from gfpgan import GFPGANer
|
|
||||||
from logging import getLogger
|
from logging import getLogger
|
||||||
from os import path
|
from os import path
|
||||||
from PIL import Image
|
|
||||||
from realesrgan import RealESRGANer
|
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from ..device_pool import (
|
|
||||||
JobContext,
|
|
||||||
)
|
|
||||||
from ..params import (
|
|
||||||
ImageParams,
|
|
||||||
StageParams,
|
|
||||||
UpscaleParams,
|
|
||||||
)
|
|
||||||
from ..utils import (
|
|
||||||
run_gc,
|
|
||||||
ServerContext,
|
|
||||||
)
|
|
||||||
from .upscale_resrgan import (
|
|
||||||
load_resrgan,
|
|
||||||
)
|
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
from gfpgan import GFPGANer
|
||||||
|
from PIL import Image
|
||||||
|
from realesrgan import RealESRGANer
|
||||||
|
|
||||||
|
from ..device_pool import JobContext
|
||||||
|
from ..params import ImageParams, StageParams, UpscaleParams
|
||||||
|
from ..utils import ServerContext, run_gc
|
||||||
|
from .upscale_resrgan import load_resrgan
|
||||||
|
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
@ -30,7 +19,9 @@ last_pipeline_instance = None
|
||||||
last_pipeline_params = None
|
last_pipeline_params = None
|
||||||
|
|
||||||
|
|
||||||
def load_gfpgan(ctx: ServerContext, upscale: UpscaleParams, upsampler: Optional[RealESRGANer] = None):
|
def load_gfpgan(
|
||||||
|
ctx: ServerContext, upscale: UpscaleParams, upsampler: Optional[RealESRGANer] = None
|
||||||
|
):
|
||||||
global last_pipeline_instance
|
global last_pipeline_instance
|
||||||
global last_pipeline_params
|
global last_pipeline_params
|
||||||
|
|
||||||
|
@ -38,22 +29,22 @@ def load_gfpgan(ctx: ServerContext, upscale: UpscaleParams, upsampler: Optional[
|
||||||
bg_upscale = upscale.rescale(upscale.outscale)
|
bg_upscale = upscale.rescale(upscale.outscale)
|
||||||
upsampler = load_resrgan(ctx, bg_upscale)
|
upsampler = load_resrgan(ctx, bg_upscale)
|
||||||
|
|
||||||
face_path = path.join(ctx.model_path, '%s.pth' %
|
face_path = path.join(ctx.model_path, "%s.pth" % (upscale.correction_model))
|
||||||
(upscale.correction_model))
|
|
||||||
|
|
||||||
if last_pipeline_instance != None and face_path == last_pipeline_params:
|
if last_pipeline_instance is not None and face_path == last_pipeline_params:
|
||||||
logger.info('reusing existing GFPGAN pipeline')
|
logger.info("reusing existing GFPGAN pipeline")
|
||||||
return last_pipeline_instance
|
return last_pipeline_instance
|
||||||
|
|
||||||
logger.debug('loading GFPGAN model from %s', face_path)
|
logger.debug("loading GFPGAN model from %s", face_path)
|
||||||
|
|
||||||
# TODO: find a way to pass the ONNX model to underlying architectures
|
# TODO: find a way to pass the ONNX model to underlying architectures
|
||||||
gfpgan = GFPGANer(
|
gfpgan = GFPGANer(
|
||||||
model_path=face_path,
|
model_path=face_path,
|
||||||
upscale=upscale.outscale,
|
upscale=upscale.outscale,
|
||||||
arch='clean',
|
arch="clean",
|
||||||
channel_multiplier=2,
|
channel_multiplier=2,
|
||||||
bg_upsampler=upsampler)
|
bg_upsampler=upsampler,
|
||||||
|
)
|
||||||
|
|
||||||
last_pipeline_instance = gfpgan
|
last_pipeline_instance = gfpgan
|
||||||
last_pipeline_params = face_path
|
last_pipeline_params = face_path
|
||||||
|
@ -74,15 +65,20 @@ def correct_gfpgan(
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> Image.Image:
|
) -> Image.Image:
|
||||||
if upscale.correction_model is None:
|
if upscale.correction_model is None:
|
||||||
logger.warn('no face model given, skipping')
|
logger.warn("no face model given, skipping")
|
||||||
return source_image
|
return source_image
|
||||||
|
|
||||||
logger.info('correcting faces with GFPGAN model: %s', upscale.correction_model)
|
logger.info("correcting faces with GFPGAN model: %s", upscale.correction_model)
|
||||||
gfpgan = load_gfpgan(server, upscale, upsampler=upsampler)
|
gfpgan = load_gfpgan(server, upscale, upsampler=upsampler)
|
||||||
|
|
||||||
output = np.array(source_image)
|
output = np.array(source_image)
|
||||||
_, _, output = gfpgan.enhance(
|
_, _, output = gfpgan.enhance(
|
||||||
output, has_aligned=False, only_center_face=False, paste_back=True, weight=upscale.face_strength)
|
output,
|
||||||
output = Image.fromarray(output, 'RGB')
|
has_aligned=False,
|
||||||
|
only_center_face=False,
|
||||||
|
paste_back=True,
|
||||||
|
weight=upscale.face_strength,
|
||||||
|
)
|
||||||
|
output = Image.fromarray(output, "RGB")
|
||||||
|
|
||||||
return output
|
return output
|
||||||
|
|
|
@ -1,26 +1,18 @@
|
||||||
from logging import getLogger
|
from logging import getLogger
|
||||||
|
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
from ..device_pool import (
|
from ..device_pool import JobContext
|
||||||
JobContext,
|
from ..output import save_image
|
||||||
)
|
from ..params import ImageParams, StageParams
|
||||||
from ..params import (
|
from ..utils import ServerContext
|
||||||
ImageParams,
|
|
||||||
StageParams,
|
|
||||||
)
|
|
||||||
from ..output import (
|
|
||||||
save_image,
|
|
||||||
)
|
|
||||||
from ..utils import (
|
|
||||||
ServerContext,
|
|
||||||
)
|
|
||||||
|
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def persist_disk(
|
def persist_disk(
|
||||||
_job: JobContext,
|
_job: JobContext,
|
||||||
ctx: ServerContext,
|
server: ServerContext,
|
||||||
_stage: StageParams,
|
_stage: StageParams,
|
||||||
_params: ImageParams,
|
_params: ImageParams,
|
||||||
source_image: Image.Image,
|
source_image: Image.Image,
|
||||||
|
@ -28,6 +20,6 @@ def persist_disk(
|
||||||
output: str,
|
output: str,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> Image.Image:
|
) -> Image.Image:
|
||||||
dest = save_image(ctx, output, source_image)
|
dest = save_image(server, output, source_image)
|
||||||
logger.info('saved image to %s', dest)
|
logger.info("saved image to %s", dest)
|
||||||
return source_image
|
return source_image
|
||||||
|
|
|
@ -1,26 +1,19 @@
|
||||||
from boto3 import (
|
|
||||||
Session,
|
|
||||||
)
|
|
||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
from logging import getLogger
|
from logging import getLogger
|
||||||
|
|
||||||
|
from boto3 import Session
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
from ..device_pool import (
|
from ..device_pool import JobContext
|
||||||
JobContext,
|
from ..params import ImageParams, StageParams
|
||||||
)
|
from ..utils import ServerContext
|
||||||
from ..params import (
|
|
||||||
ImageParams,
|
|
||||||
StageParams,
|
|
||||||
)
|
|
||||||
from ..utils import (
|
|
||||||
ServerContext,
|
|
||||||
)
|
|
||||||
|
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def persist_s3(
|
def persist_s3(
|
||||||
ctx: ServerContext,
|
_job: JobContext,
|
||||||
|
server: ServerContext,
|
||||||
_stage: StageParams,
|
_stage: StageParams,
|
||||||
_params: ImageParams,
|
_params: ImageParams,
|
||||||
source_image: Image.Image,
|
source_image: Image.Image,
|
||||||
|
@ -32,16 +25,16 @@ def persist_s3(
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> Image.Image:
|
) -> Image.Image:
|
||||||
session = Session(profile_name=profile_name)
|
session = Session(profile_name=profile_name)
|
||||||
s3 = session.client('s3', endpoint_url=endpoint_url)
|
s3 = session.client("s3", endpoint_url=endpoint_url)
|
||||||
|
|
||||||
data = BytesIO()
|
data = BytesIO()
|
||||||
source_image.save(data, format=ctx.image_format)
|
source_image.save(data, format=server.image_format)
|
||||||
data.seek(0)
|
data.seek(0)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
s3.upload_fileobj(data, bucket, output)
|
s3.upload_fileobj(data, bucket, output)
|
||||||
logger.info('saved image to %s/%s', bucket, output)
|
logger.info("saved image to %s/%s", bucket, output)
|
||||||
except Exception as err:
|
except Exception as err:
|
||||||
logger.error('error saving image to S3: %s', err)
|
logger.error("error saving image to S3: %s", err)
|
||||||
|
|
||||||
return source_image
|
return source_image
|
||||||
|
|
|
@ -1,23 +1,17 @@
|
||||||
from logging import getLogger
|
from logging import getLogger
|
||||||
|
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
from ..device_pool import (
|
from ..device_pool import JobContext
|
||||||
JobContext,
|
from ..params import ImageParams, Size, StageParams
|
||||||
)
|
from ..utils import ServerContext
|
||||||
from ..params import (
|
|
||||||
ImageParams,
|
|
||||||
Size,
|
|
||||||
StageParams,
|
|
||||||
)
|
|
||||||
from ..utils import (
|
|
||||||
ServerContext,
|
|
||||||
)
|
|
||||||
|
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def reduce_crop(
|
def reduce_crop(
|
||||||
ctx: ServerContext,
|
_job: JobContext,
|
||||||
|
_server: ServerContext,
|
||||||
_stage: StageParams,
|
_stage: StageParams,
|
||||||
_params: ImageParams,
|
_params: ImageParams,
|
||||||
source_image: Image.Image,
|
source_image: Image.Image,
|
||||||
|
@ -26,8 +20,6 @@ def reduce_crop(
|
||||||
size: Size,
|
size: Size,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> Image.Image:
|
) -> Image.Image:
|
||||||
image = source_image.crop(
|
image = source_image.crop((origin.width, origin.height, size.width, size.height))
|
||||||
(origin.width, origin.height, size.width, size.height))
|
logger.info("created thumbnail with dimensions: %sx%s", image.width, image.height)
|
||||||
logger.info('created thumbnail with dimensions: %sx%s',
|
|
||||||
image.width, image.height)
|
|
||||||
return image
|
return image
|
||||||
|
|
|
@ -1,23 +1,17 @@
|
||||||
from logging import getLogger
|
from logging import getLogger
|
||||||
|
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
from ..device_pool import (
|
from ..device_pool import JobContext
|
||||||
JobContext,
|
from ..params import ImageParams, Size, StageParams
|
||||||
)
|
from ..utils import ServerContext
|
||||||
from ..params import (
|
|
||||||
ImageParams,
|
|
||||||
Size,
|
|
||||||
StageParams,
|
|
||||||
)
|
|
||||||
from ..utils import (
|
|
||||||
ServerContext,
|
|
||||||
)
|
|
||||||
|
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def reduce_thumbnail(
|
def reduce_thumbnail(
|
||||||
ctx: ServerContext,
|
_job: JobContext,
|
||||||
|
_server: ServerContext,
|
||||||
_stage: StageParams,
|
_stage: StageParams,
|
||||||
_params: ImageParams,
|
_params: ImageParams,
|
||||||
source_image: Image.Image,
|
source_image: Image.Image,
|
||||||
|
@ -26,6 +20,5 @@ def reduce_thumbnail(
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> Image.Image:
|
) -> Image.Image:
|
||||||
image = source_image.thumbnail((size.width, size.height))
|
image = source_image.thumbnail((size.width, size.height))
|
||||||
logger.info('created thumbnail with dimensions: %sx%s',
|
logger.info("created thumbnail with dimensions: %sx%s", image.width, image.height)
|
||||||
image.width, image.height)
|
|
||||||
return image
|
return image
|
||||||
|
|
|
@ -1,26 +1,19 @@
|
||||||
from logging import getLogger
|
from logging import getLogger
|
||||||
from PIL import Image
|
|
||||||
from typing import Callable
|
from typing import Callable
|
||||||
|
|
||||||
from ..device_pool import (
|
from PIL import Image
|
||||||
JobContext,
|
|
||||||
)
|
|
||||||
from ..params import (
|
|
||||||
ImageParams,
|
|
||||||
Size,
|
|
||||||
StageParams,
|
|
||||||
)
|
|
||||||
from ..utils import (
|
|
||||||
ServerContext,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
from ..device_pool import JobContext
|
||||||
|
from ..params import ImageParams, Size, StageParams
|
||||||
|
from ..utils import ServerContext
|
||||||
|
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def source_noise(
|
def source_noise(
|
||||||
ctx: ServerContext,
|
_job: JobContext,
|
||||||
stage: StageParams,
|
_server: ServerContext,
|
||||||
|
_stage: StageParams,
|
||||||
params: ImageParams,
|
params: ImageParams,
|
||||||
source_image: Image.Image,
|
source_image: Image.Image,
|
||||||
*,
|
*,
|
||||||
|
@ -28,14 +21,12 @@ def source_noise(
|
||||||
noise_source: Callable,
|
noise_source: Callable,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> Image.Image:
|
) -> Image.Image:
|
||||||
prompt = prompt or params.prompt
|
logger.info("generating image from noise source")
|
||||||
logger.info('generating image from noise source')
|
|
||||||
|
|
||||||
if source_image is not None:
|
if source_image is not None:
|
||||||
logger.warn(
|
logger.warn("a source image was passed to a noise stage, but will be discarded")
|
||||||
'a source image was passed to a noise stage, but will be discarded')
|
|
||||||
|
|
||||||
output = noise_source(source_image, (size.width, size.height), (0, 0))
|
output = noise_source(source_image, (size.width, size.height), (0, 0))
|
||||||
|
|
||||||
logger.info('final output image size: %sx%s', output.width, output.height)
|
logger.info("final output image size: %sx%s", output.width, output.height)
|
||||||
return output
|
return output
|
||||||
|
|
|
@ -1,26 +1,13 @@
|
||||||
from diffusers import (
|
|
||||||
OnnxStableDiffusionPipeline,
|
|
||||||
)
|
|
||||||
from logging import getLogger
|
from logging import getLogger
|
||||||
from PIL import Image
|
|
||||||
|
|
||||||
from ..device_pool import (
|
|
||||||
JobContext,
|
|
||||||
)
|
|
||||||
from ..diffusion.load import (
|
|
||||||
get_latents_from_seed,
|
|
||||||
load_pipeline,
|
|
||||||
)
|
|
||||||
from ..params import (
|
|
||||||
ImageParams,
|
|
||||||
Size,
|
|
||||||
StageParams,
|
|
||||||
)
|
|
||||||
from ..utils import (
|
|
||||||
ServerContext,
|
|
||||||
)
|
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
from diffusers import OnnxStableDiffusionPipeline
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
from ..device_pool import JobContext
|
||||||
|
from ..diffusion.load import get_latents_from_seed, load_pipeline
|
||||||
|
from ..params import ImageParams, Size, StageParams
|
||||||
|
from ..utils import ServerContext
|
||||||
|
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
@ -37,13 +24,16 @@ def source_txt2img(
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> Image.Image:
|
) -> Image.Image:
|
||||||
prompt = prompt or params.prompt
|
prompt = prompt or params.prompt
|
||||||
logger.info('generating image using txt2img, %s steps: %s', params.steps, prompt)
|
logger.info("generating image using txt2img, %s steps: %s", params.steps, prompt)
|
||||||
|
|
||||||
if source_image is not None:
|
if source_image is not None:
|
||||||
logger.warn('a source image was passed to a txt2img stage, but will be discarded')
|
logger.warn(
|
||||||
|
"a source image was passed to a txt2img stage, but will be discarded"
|
||||||
|
)
|
||||||
|
|
||||||
pipe = load_pipeline(OnnxStableDiffusionPipeline,
|
pipe = load_pipeline(
|
||||||
params.model, params.scheduler, job.get_device())
|
OnnxStableDiffusionPipeline, params.model, params.scheduler, job.get_device()
|
||||||
|
)
|
||||||
|
|
||||||
latents = get_latents_from_seed(params.seed, size)
|
latents = get_latents_from_seed(params.seed, size)
|
||||||
rng = np.random.RandomState(params.seed)
|
rng = np.random.RandomState(params.seed)
|
||||||
|
@ -60,5 +50,5 @@ def source_txt2img(
|
||||||
)
|
)
|
||||||
output = result.images[0]
|
output = result.images[0]
|
||||||
|
|
||||||
logger.info('final output image size: %sx%s', output.width, output.height)
|
logger.info("final output image size: %sx%s", output.width, output.height)
|
||||||
return output
|
return output
|
||||||
|
|
|
@ -1,43 +1,17 @@
|
||||||
from diffusers import (
|
|
||||||
OnnxStableDiffusionInpaintPipeline,
|
|
||||||
)
|
|
||||||
from logging import getLogger
|
from logging import getLogger
|
||||||
from PIL import Image, ImageDraw
|
|
||||||
from typing import Callable, Tuple
|
from typing import Callable, Tuple
|
||||||
|
|
||||||
from ..device_pool import (
|
|
||||||
JobContext,
|
|
||||||
)
|
|
||||||
from ..diffusion.load import (
|
|
||||||
get_latents_from_seed,
|
|
||||||
get_tile_latents,
|
|
||||||
load_pipeline,
|
|
||||||
)
|
|
||||||
from ..image import (
|
|
||||||
expand_image,
|
|
||||||
mask_filter_none,
|
|
||||||
noise_source_histogram,
|
|
||||||
)
|
|
||||||
from ..params import (
|
|
||||||
Border,
|
|
||||||
ImageParams,
|
|
||||||
Size,
|
|
||||||
SizeChart,
|
|
||||||
StageParams,
|
|
||||||
)
|
|
||||||
from ..output import (
|
|
||||||
save_image,
|
|
||||||
)
|
|
||||||
from ..utils import (
|
|
||||||
base_join,
|
|
||||||
is_debug,
|
|
||||||
ServerContext,
|
|
||||||
)
|
|
||||||
from .utils import (
|
|
||||||
process_tile_spiral,
|
|
||||||
)
|
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
from diffusers import OnnxStableDiffusionInpaintPipeline
|
||||||
|
from PIL import Image, ImageDraw
|
||||||
|
|
||||||
|
from ..device_pool import JobContext
|
||||||
|
from ..diffusion.load import get_latents_from_seed, get_tile_latents, load_pipeline
|
||||||
|
from ..image import expand_image, mask_filter_none, noise_source_histogram
|
||||||
|
from ..output import save_image
|
||||||
|
from ..params import Border, ImageParams, Size, SizeChart, StageParams
|
||||||
|
from ..utils import ServerContext, is_debug
|
||||||
|
from .utils import process_tile_spiral
|
||||||
|
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
@ -52,17 +26,17 @@ def upscale_outpaint(
|
||||||
border: Border,
|
border: Border,
|
||||||
prompt: str = None,
|
prompt: str = None,
|
||||||
mask_image: Image.Image = None,
|
mask_image: Image.Image = None,
|
||||||
fill_color: str = 'white',
|
fill_color: str = "white",
|
||||||
mask_filter: Callable = mask_filter_none,
|
mask_filter: Callable = mask_filter_none,
|
||||||
noise_source: Callable = noise_source_histogram,
|
noise_source: Callable = noise_source_histogram,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> Image.Image:
|
) -> Image.Image:
|
||||||
prompt = prompt or params.prompt
|
prompt = prompt or params.prompt
|
||||||
logger.info('upscaling image by expanding borders: %s', border)
|
logger.info("upscaling image by expanding borders: %s", border)
|
||||||
|
|
||||||
if mask_image is None:
|
if mask_image is None:
|
||||||
# if no mask was provided, keep the full source image
|
# if no mask was provided, keep the full source image
|
||||||
mask_image = Image.new('RGB', source_image.size, 'black')
|
mask_image = Image.new("RGB", source_image.size, "black")
|
||||||
|
|
||||||
source_image, mask_image, noise_image, full_dims = expand_image(
|
source_image, mask_image, noise_image, full_dims = expand_image(
|
||||||
source_image,
|
source_image,
|
||||||
|
@ -70,16 +44,17 @@ def upscale_outpaint(
|
||||||
border,
|
border,
|
||||||
fill=fill_color,
|
fill=fill_color,
|
||||||
noise_source=noise_source,
|
noise_source=noise_source,
|
||||||
mask_filter=mask_filter)
|
mask_filter=mask_filter,
|
||||||
|
)
|
||||||
|
|
||||||
draw_mask = ImageDraw.Draw(mask_image)
|
draw_mask = ImageDraw.Draw(mask_image)
|
||||||
full_size = Size(*full_dims)
|
full_size = Size(*full_dims)
|
||||||
full_latents = get_latents_from_seed(params.seed, full_size)
|
full_latents = get_latents_from_seed(params.seed, full_size)
|
||||||
|
|
||||||
if is_debug():
|
if is_debug():
|
||||||
save_image(server, 'last-source.png', source_image)
|
save_image(server, "last-source.png", source_image)
|
||||||
save_image(server, 'last-mask.png', mask_image)
|
save_image(server, "last-mask.png", mask_image)
|
||||||
save_image(server, 'last-noise.png', noise_image)
|
save_image(server, "last-noise.png", noise_image)
|
||||||
|
|
||||||
def outpaint(image: Image.Image, dims: Tuple[int, int, int]):
|
def outpaint(image: Image.Image, dims: Tuple[int, int, int]):
|
||||||
left, top, tile = dims
|
left, top, tile = dims
|
||||||
|
@ -87,11 +62,15 @@ def upscale_outpaint(
|
||||||
mask = mask_image.crop((left, top, left + tile, top + tile))
|
mask = mask_image.crop((left, top, left + tile, top + tile))
|
||||||
|
|
||||||
if is_debug():
|
if is_debug():
|
||||||
save_image(server, 'tile-source.png', image)
|
save_image(server, "tile-source.png", image)
|
||||||
save_image(server, 'tile-mask.png', mask)
|
save_image(server, "tile-mask.png", mask)
|
||||||
|
|
||||||
pipe = load_pipeline(OnnxStableDiffusionInpaintPipeline,
|
pipe = load_pipeline(
|
||||||
params.model, params.scheduler, job.get_device())
|
OnnxStableDiffusionInpaintPipeline,
|
||||||
|
params.model,
|
||||||
|
params.scheduler,
|
||||||
|
job.get_device(),
|
||||||
|
)
|
||||||
|
|
||||||
latents = get_tile_latents(full_latents, dims)
|
latents = get_tile_latents(full_latents, dims)
|
||||||
rng = np.random.RandomState(params.seed)
|
rng = np.random.RandomState(params.seed)
|
||||||
|
@ -110,10 +89,10 @@ def upscale_outpaint(
|
||||||
)
|
)
|
||||||
|
|
||||||
# once part of the image has been drawn, keep it
|
# once part of the image has been drawn, keep it
|
||||||
draw_mask.rectangle((left, top, left + tile, top + tile), fill='black')
|
draw_mask.rectangle((left, top, left + tile, top + tile), fill="black")
|
||||||
return result.images[0]
|
return result.images[0]
|
||||||
|
|
||||||
output = process_tile_spiral(source_image, SizeChart.auto, 1, [outpaint])
|
output = process_tile_spiral(source_image, SizeChart.auto, 1, [outpaint])
|
||||||
|
|
||||||
logger.info('final output image size: %sx%s', output.width, output.height)
|
logger.info("final output image size: %sx%s", output.width, output.height)
|
||||||
return output
|
return output
|
||||||
|
|
|
@ -1,27 +1,15 @@
|
||||||
from basicsr.archs.rrdbnet_arch import RRDBNet
|
|
||||||
from logging import getLogger
|
from logging import getLogger
|
||||||
from os import path
|
from os import path
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from basicsr.archs.rrdbnet_arch import RRDBNet
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from realesrgan import RealESRGANer
|
from realesrgan import RealESRGANer
|
||||||
|
|
||||||
from ..device_pool import (
|
from ..device_pool import JobContext
|
||||||
JobContext,
|
from ..onnx import OnnxNet
|
||||||
)
|
from ..params import DeviceParams, ImageParams, StageParams, UpscaleParams
|
||||||
from ..onnx import (
|
from ..utils import ServerContext, run_gc
|
||||||
OnnxNet,
|
|
||||||
)
|
|
||||||
from ..params import (
|
|
||||||
DeviceParams,
|
|
||||||
ImageParams,
|
|
||||||
StageParams,
|
|
||||||
UpscaleParams,
|
|
||||||
)
|
|
||||||
from ..utils import (
|
|
||||||
run_gc,
|
|
||||||
ServerContext,
|
|
||||||
)
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
@ -29,39 +17,50 @@ last_pipeline_instance = None
|
||||||
last_pipeline_params = (None, None)
|
last_pipeline_params = (None, None)
|
||||||
|
|
||||||
|
|
||||||
def load_resrgan(ctx: ServerContext, params: UpscaleParams, device: DeviceParams, tile=0):
|
def load_resrgan(
|
||||||
|
ctx: ServerContext, params: UpscaleParams, device: DeviceParams, tile=0
|
||||||
|
):
|
||||||
global last_pipeline_instance
|
global last_pipeline_instance
|
||||||
global last_pipeline_params
|
global last_pipeline_params
|
||||||
|
|
||||||
model_file = '%s.%s' % (params.upscale_model, params.format)
|
model_file = "%s.%s" % (params.upscale_model, params.format)
|
||||||
model_path = path.join(ctx.model_path, model_file)
|
model_path = path.join(ctx.model_path, model_file)
|
||||||
if not path.isfile(model_path):
|
if not path.isfile(model_path):
|
||||||
raise Exception('Real ESRGAN model not found at %s' % model_path)
|
raise Exception("Real ESRGAN model not found at %s" % model_path)
|
||||||
|
|
||||||
cache_params = (model_path, params.format)
|
cache_params = (model_path, params.format)
|
||||||
if last_pipeline_instance != None and cache_params == last_pipeline_params:
|
if last_pipeline_instance is not None and cache_params == last_pipeline_params:
|
||||||
logger.info('reusing existing Real ESRGAN pipeline')
|
logger.info("reusing existing Real ESRGAN pipeline")
|
||||||
return last_pipeline_instance
|
return last_pipeline_instance
|
||||||
|
|
||||||
# use ONNX acceleration, if available
|
# use ONNX acceleration, if available
|
||||||
if params.format == 'onnx':
|
if params.format == "onnx":
|
||||||
model = OnnxNet(ctx, model_file, provider=device.provider, provider_options=device.options)
|
model = OnnxNet(
|
||||||
elif params.format == 'pth':
|
ctx, model_file, provider=device.provider, provider_options=device.options
|
||||||
model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64,
|
)
|
||||||
num_block=23, num_grow_ch=32, scale=params.scale)
|
elif params.format == "pth":
|
||||||
raise Exception('unknown platform %s' % params.format)
|
model = RRDBNet(
|
||||||
|
num_in_ch=3,
|
||||||
|
num_out_ch=3,
|
||||||
|
num_feat=64,
|
||||||
|
num_block=23,
|
||||||
|
num_grow_ch=32,
|
||||||
|
scale=params.scale,
|
||||||
|
)
|
||||||
|
raise Exception("unknown platform %s" % params.format)
|
||||||
|
|
||||||
dni_weight = None
|
dni_weight = None
|
||||||
if params.upscale_model == 'realesr-general-x4v3' and params.denoise != 1:
|
if params.upscale_model == "realesr-general-x4v3" and params.denoise != 1:
|
||||||
wdn_model_path = model_path.replace(
|
wdn_model_path = model_path.replace(
|
||||||
'realesr-general-x4v3', 'realesr-general-wdn-x4v3')
|
"realesr-general-x4v3", "realesr-general-wdn-x4v3"
|
||||||
|
)
|
||||||
model_path = [model_path, wdn_model_path]
|
model_path = [model_path, wdn_model_path]
|
||||||
dni_weight = [params.denoise, 1 - params.denoise]
|
dni_weight = [params.denoise, 1 - params.denoise]
|
||||||
|
|
||||||
logger.debug('loading Real ESRGAN upscale model from %s', model_path)
|
logger.debug("loading Real ESRGAN upscale model from %s", model_path)
|
||||||
|
|
||||||
# TODO: shouldn't need the PTH file
|
# TODO: shouldn't need the PTH file
|
||||||
model_path_pth = path.join(ctx.model_path, '%s.pth' % params.upscale_model)
|
model_path_pth = path.join(ctx.model_path, "%s.pth" % params.upscale_model)
|
||||||
upsampler = RealESRGANer(
|
upsampler = RealESRGANer(
|
||||||
scale=params.scale,
|
scale=params.scale,
|
||||||
model_path=model_path_pth,
|
model_path=model_path_pth,
|
||||||
|
@ -70,7 +69,8 @@ def load_resrgan(ctx: ServerContext, params: UpscaleParams, device: DeviceParams
|
||||||
tile=tile,
|
tile=tile,
|
||||||
tile_pad=params.tile_pad,
|
tile_pad=params.tile_pad,
|
||||||
pre_pad=params.pre_pad,
|
pre_pad=params.pre_pad,
|
||||||
half=params.half)
|
half=params.half,
|
||||||
|
)
|
||||||
|
|
||||||
last_pipeline_instance = upsampler
|
last_pipeline_instance = upsampler
|
||||||
last_pipeline_params = cache_params
|
last_pipeline_params = cache_params
|
||||||
|
@ -89,13 +89,13 @@ def upscale_resrgan(
|
||||||
upscale: UpscaleParams,
|
upscale: UpscaleParams,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> Image.Image:
|
) -> Image.Image:
|
||||||
logger.info('upscaling image with Real ESRGAN: x%s', upscale.scale)
|
logger.info("upscaling image with Real ESRGAN: x%s", upscale.scale)
|
||||||
|
|
||||||
output = np.array(source_image)
|
output = np.array(source_image)
|
||||||
upsampler = load_resrgan(server, upscale, job.get_device(), tile=stage.tile_size)
|
upsampler = load_resrgan(server, upscale, job.get_device(), tile=stage.tile_size)
|
||||||
|
|
||||||
output, _ = upsampler.enhance(output, outscale=upscale.outscale)
|
output, _ = upsampler.enhance(output, outscale=upscale.outscale)
|
||||||
|
|
||||||
output = Image.fromarray(output, 'RGB')
|
output = Image.fromarray(output, "RGB")
|
||||||
logger.info('final output image size: %sx%s', output.width, output.height)
|
logger.info("final output image size: %sx%s", output.width, output.height)
|
||||||
return output
|
return output
|
||||||
|
|
|
@ -1,28 +1,16 @@
|
||||||
from diffusers import (
|
|
||||||
StableDiffusionUpscalePipeline,
|
|
||||||
)
|
|
||||||
from logging import getLogger
|
from logging import getLogger
|
||||||
from os import path
|
from os import path
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from diffusers import StableDiffusionUpscalePipeline
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
from ..device_pool import (
|
from ..device_pool import JobContext
|
||||||
JobContext,
|
|
||||||
)
|
|
||||||
from ..diffusion.pipeline_onnx_stable_diffusion_upscale import (
|
from ..diffusion.pipeline_onnx_stable_diffusion_upscale import (
|
||||||
OnnxStableDiffusionUpscalePipeline,
|
OnnxStableDiffusionUpscalePipeline,
|
||||||
)
|
)
|
||||||
from ..params import (
|
from ..params import DeviceParams, ImageParams, StageParams, UpscaleParams
|
||||||
DeviceParams,
|
from ..utils import ServerContext, run_gc
|
||||||
ImageParams,
|
|
||||||
StageParams,
|
|
||||||
UpscaleParams,
|
|
||||||
)
|
|
||||||
from ..utils import (
|
|
||||||
run_gc,
|
|
||||||
ServerContext,
|
|
||||||
)
|
|
||||||
|
|
||||||
import torch
|
|
||||||
|
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
@ -31,23 +19,37 @@ last_pipeline_instance = None
|
||||||
last_pipeline_params = (None, None)
|
last_pipeline_params = (None, None)
|
||||||
|
|
||||||
|
|
||||||
def load_stable_diffusion(ctx: ServerContext, upscale: UpscaleParams, device: DeviceParams):
|
def load_stable_diffusion(
|
||||||
|
ctx: ServerContext, upscale: UpscaleParams, device: DeviceParams
|
||||||
|
):
|
||||||
global last_pipeline_instance
|
global last_pipeline_instance
|
||||||
global last_pipeline_params
|
global last_pipeline_params
|
||||||
|
|
||||||
model_path = path.join(ctx.model_path, upscale.upscale_model)
|
model_path = path.join(ctx.model_path, upscale.upscale_model)
|
||||||
cache_params = (model_path, upscale.format)
|
cache_params = (model_path, upscale.format)
|
||||||
|
|
||||||
if last_pipeline_instance != None and cache_params == last_pipeline_params:
|
if last_pipeline_instance is not None and cache_params == last_pipeline_params:
|
||||||
logger.info('reusing existing Stable Diffusion upscale pipeline')
|
logger.info("reusing existing Stable Diffusion upscale pipeline")
|
||||||
return last_pipeline_instance
|
return last_pipeline_instance
|
||||||
|
|
||||||
if upscale.format == 'onnx':
|
if upscale.format == "onnx":
|
||||||
logger.debug('loading Stable Diffusion upscale ONNX model from %s, using provider %s', model_path, device.provider)
|
logger.debug(
|
||||||
pipeline = OnnxStableDiffusionUpscalePipeline.from_pretrained(model_path, provider=device.provider, provider_options=device.options)
|
"loading Stable Diffusion upscale ONNX model from %s, using provider %s",
|
||||||
|
model_path,
|
||||||
|
device.provider,
|
||||||
|
)
|
||||||
|
pipeline = OnnxStableDiffusionUpscalePipeline.from_pretrained(
|
||||||
|
model_path, provider=device.provider, provider_options=device.options
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
logger.debug('loading Stable Diffusion upscale model from %s, using provider %s', model_path, device.provider)
|
logger.debug(
|
||||||
pipeline = StableDiffusionUpscalePipeline.from_pretrained(model_path, provider=device.provider)
|
"loading Stable Diffusion upscale model from %s, using provider %s",
|
||||||
|
model_path,
|
||||||
|
device.provider,
|
||||||
|
)
|
||||||
|
pipeline = StableDiffusionUpscalePipeline.from_pretrained(
|
||||||
|
model_path, provider=device.provider
|
||||||
|
)
|
||||||
|
|
||||||
last_pipeline_instance = pipeline
|
last_pipeline_instance = pipeline
|
||||||
last_pipeline_params = cache_params
|
last_pipeline_params = cache_params
|
||||||
|
@ -68,7 +70,7 @@ def upscale_stable_diffusion(
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> Image.Image:
|
) -> Image.Image:
|
||||||
prompt = prompt or params.prompt
|
prompt = prompt or params.prompt
|
||||||
logger.info('upscaling with Stable Diffusion, %s steps: %s', params.steps, prompt)
|
logger.info("upscaling with Stable Diffusion, %s steps: %s", params.steps, prompt)
|
||||||
|
|
||||||
pipeline = load_stable_diffusion(server, upscale, job.get_device())
|
pipeline = load_stable_diffusion(server, upscale, job.get_device())
|
||||||
generator = torch.manual_seed(params.seed)
|
generator = torch.manual_seed(params.seed)
|
||||||
|
|
|
@ -1,7 +1,8 @@
|
||||||
from logging import getLogger
|
from logging import getLogger
|
||||||
from PIL import Image
|
|
||||||
from typing import List, Protocol, Tuple
|
from typing import List, Protocol, Tuple
|
||||||
|
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@ -17,7 +18,7 @@ def process_tile_grid(
|
||||||
filters: List[TileCallback],
|
filters: List[TileCallback],
|
||||||
) -> Image.Image:
|
) -> Image.Image:
|
||||||
width, height = source.size
|
width, height = source.size
|
||||||
image = Image.new('RGB', (width * scale, height * scale))
|
image = Image.new("RGB", (width * scale, height * scale))
|
||||||
|
|
||||||
tiles_x = width // tile
|
tiles_x = width // tile
|
||||||
tiles_y = height // tile
|
tiles_y = height // tile
|
||||||
|
@ -28,7 +29,7 @@ def process_tile_grid(
|
||||||
idx = (y * tiles_x) + x
|
idx = (y * tiles_x) + x
|
||||||
left = x * tile
|
left = x * tile
|
||||||
top = y * tile
|
top = y * tile
|
||||||
logger.info('processing tile %s of %s, %s.%s', idx + 1, total, y, x)
|
logger.info("processing tile %s of %s, %s.%s", idx + 1, total, y, x)
|
||||||
tile_image = source.crop((left, top, left + tile, top + tile))
|
tile_image = source.crop((left, top, left + tile, top + tile))
|
||||||
|
|
||||||
for filter in filters:
|
for filter in filters:
|
||||||
|
@ -47,10 +48,10 @@ def process_tile_spiral(
|
||||||
overlap: float = 0.5,
|
overlap: float = 0.5,
|
||||||
) -> Image.Image:
|
) -> Image.Image:
|
||||||
if scale != 1:
|
if scale != 1:
|
||||||
raise Exception('unsupported scale')
|
raise Exception("unsupported scale")
|
||||||
|
|
||||||
width, height = source.size
|
width, height = source.size
|
||||||
image = Image.new('RGB', (width * scale, height * scale))
|
image = Image.new("RGB", (width * scale, height * scale))
|
||||||
image.paste(source, (0, 0, width, height))
|
image.paste(source, (0, 0, width, height))
|
||||||
|
|
||||||
center_x = (width // 2) - (tile // 2)
|
center_x = (width // 2) - (tile // 2)
|
||||||
|
@ -76,7 +77,7 @@ def process_tile_spiral(
|
||||||
top = center_y + int(top)
|
top = center_y + int(top)
|
||||||
|
|
||||||
counter += 1
|
counter += 1
|
||||||
logger.info('processing tile %s of %s, %sx%s', counter, len(tiles), left, top)
|
logger.info("processing tile %s of %s, %sx%s", counter, len(tiles), left, top)
|
||||||
|
|
||||||
# TODO: only valid for scale == 1, resize source for others
|
# TODO: only valid for scale == 1, resize source for others
|
||||||
tile_image = image.crop((left, top, left + tile, top + tile))
|
tile_image = image.crop((left, top, left + tile, top + tile))
|
||||||
|
|
|
@ -1,5 +1,14 @@
|
||||||
from . import logging
|
import warnings
|
||||||
from argparse import ArgumentParser
|
from argparse import ArgumentParser
|
||||||
|
from json import loads
|
||||||
|
from logging import getLogger
|
||||||
|
from os import environ, makedirs, mkdir, path
|
||||||
|
from pathlib import Path
|
||||||
|
from shutil import copyfile, rmtree
|
||||||
|
from sys import exit
|
||||||
|
from typing import Dict, List, Optional, Tuple
|
||||||
|
|
||||||
|
import torch
|
||||||
from basicsr.archs.rrdbnet_arch import RRDBNet
|
from basicsr.archs.rrdbnet_arch import RRDBNet
|
||||||
from basicsr.utils.download_util import load_file_from_url
|
from basicsr.utils.download_util import load_file_from_url
|
||||||
from diffusers import (
|
from diffusers import (
|
||||||
|
@ -8,25 +17,20 @@ from diffusers import (
|
||||||
StableDiffusionPipeline,
|
StableDiffusionPipeline,
|
||||||
StableDiffusionUpscalePipeline,
|
StableDiffusionUpscalePipeline,
|
||||||
)
|
)
|
||||||
from json import loads
|
|
||||||
from logging import getLogger
|
|
||||||
from onnx import load, save_model
|
from onnx import load, save_model
|
||||||
from os import environ, makedirs, mkdir, path
|
|
||||||
from pathlib import Path
|
|
||||||
from shutil import copyfile, rmtree
|
|
||||||
from sys import exit
|
|
||||||
from torch.onnx import export
|
from torch.onnx import export
|
||||||
from typing import Dict, List, Optional, Tuple
|
|
||||||
|
|
||||||
import torch
|
from . import logging
|
||||||
import warnings
|
|
||||||
|
|
||||||
# suppress common but harmless warnings, https://github.com/ssube/onnx-web/issues/75
|
# suppress common but harmless warnings, https://github.com/ssube/onnx-web/issues/75
|
||||||
warnings.filterwarnings(
|
warnings.filterwarnings(
|
||||||
'ignore', '.*The shape inference of prim::Constant type is missing.*')
|
"ignore", ".*The shape inference of prim::Constant type is missing.*"
|
||||||
warnings.filterwarnings('ignore', '.*Only steps=1 can be constant folded.*')
|
)
|
||||||
|
warnings.filterwarnings("ignore", ".*Only steps=1 can be constant folded.*")
|
||||||
warnings.filterwarnings(
|
warnings.filterwarnings(
|
||||||
'ignore', '.*Converting a tensor to a Python boolean might cause the trace to be incorrect.*')
|
"ignore",
|
||||||
|
".*Converting a tensor to a Python boolean might cause the trace to be incorrect.*",
|
||||||
|
)
|
||||||
|
|
||||||
Models = Dict[str, List[Tuple[str, str, Optional[int]]]]
|
Models = Dict[str, List[Tuple[str, str, Optional[int]]]]
|
||||||
|
|
||||||
|
@ -35,74 +39,95 @@ logger = getLogger(__name__)
|
||||||
|
|
||||||
# recommended models
|
# recommended models
|
||||||
base_models: Models = {
|
base_models: Models = {
|
||||||
'diffusion': [
|
"diffusion": [
|
||||||
# v1.x
|
# v1.x
|
||||||
('stable-diffusion-onnx-v1-5', 'runwayml/stable-diffusion-v1-5'),
|
("stable-diffusion-onnx-v1-5", "runwayml/stable-diffusion-v1-5"),
|
||||||
('stable-diffusion-onnx-v1-inpainting',
|
("stable-diffusion-onnx-v1-inpainting", "runwayml/stable-diffusion-inpainting"),
|
||||||
'runwayml/stable-diffusion-inpainting'),
|
|
||||||
# v2.x
|
# v2.x
|
||||||
('stable-diffusion-onnx-v2-1', 'stabilityai/stable-diffusion-2-1'),
|
("stable-diffusion-onnx-v2-1", "stabilityai/stable-diffusion-2-1"),
|
||||||
('stable-diffusion-onnx-v2-inpainting',
|
(
|
||||||
'stabilityai/stable-diffusion-2-inpainting'),
|
"stable-diffusion-onnx-v2-inpainting",
|
||||||
|
"stabilityai/stable-diffusion-2-inpainting",
|
||||||
|
),
|
||||||
# TODO: should have its own converter
|
# TODO: should have its own converter
|
||||||
('upscaling-stable-diffusion-x4', 'stabilityai/stable-diffusion-x4-upscaler'),
|
("upscaling-stable-diffusion-x4", "stabilityai/stable-diffusion-x4-upscaler"),
|
||||||
],
|
],
|
||||||
'correction': [
|
"correction": [
|
||||||
('correction-gfpgan-v1-3',
|
(
|
||||||
'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth', 4),
|
"correction-gfpgan-v1-3",
|
||||||
|
"https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth",
|
||||||
|
4,
|
||||||
|
),
|
||||||
],
|
],
|
||||||
'upscaling': [
|
"upscaling": [
|
||||||
('upscaling-real-esrgan-x2-plus',
|
(
|
||||||
'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth', 2),
|
"upscaling-real-esrgan-x2-plus",
|
||||||
('upscaling-real-esrgan-x4-plus',
|
"https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth",
|
||||||
'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth', 4),
|
2,
|
||||||
('upscaling-real-esrgan-x4-v3',
|
),
|
||||||
'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth', 4),
|
(
|
||||||
|
"upscaling-real-esrgan-x4-plus",
|
||||||
|
"https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth",
|
||||||
|
4,
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"upscaling-real-esrgan-x4-v3",
|
||||||
|
"https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth",
|
||||||
|
4,
|
||||||
|
),
|
||||||
],
|
],
|
||||||
}
|
}
|
||||||
|
|
||||||
model_path = environ.get('ONNX_WEB_MODEL_PATH',
|
model_path = environ.get("ONNX_WEB_MODEL_PATH", path.join("..", "models"))
|
||||||
path.join('..', 'models'))
|
training_device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||||
training_device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
|
||||||
map_location = torch.device(training_device)
|
map_location = torch.device(training_device)
|
||||||
|
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def convert_real_esrgan(name: str, url: str, scale: int, opset: int):
|
def convert_real_esrgan(name: str, url: str, scale: int, opset: int):
|
||||||
dest_path = path.join(model_path, name + '.pth')
|
dest_path = path.join(model_path, name + ".pth")
|
||||||
dest_onnx = path.join(model_path, name + '.onnx')
|
dest_onnx = path.join(model_path, name + ".onnx")
|
||||||
logger.info('converting Real ESRGAN model: %s -> %s', name, dest_onnx)
|
logger.info("converting Real ESRGAN model: %s -> %s", name, dest_onnx)
|
||||||
|
|
||||||
if path.isfile(dest_onnx):
|
if path.isfile(dest_onnx):
|
||||||
logger.info('ONNX model already exists, skipping.')
|
logger.info("ONNX model already exists, skipping.")
|
||||||
return
|
return
|
||||||
|
|
||||||
if not path.isfile(dest_path):
|
if not path.isfile(dest_path):
|
||||||
logger.info('PTH model not found, downloading...')
|
logger.info("PTH model not found, downloading...")
|
||||||
download_path = load_file_from_url(
|
download_path = load_file_from_url(
|
||||||
url=url, model_dir=dest_path + '-cache', progress=True, file_name=None)
|
url=url, model_dir=dest_path + "-cache", progress=True, file_name=None
|
||||||
|
)
|
||||||
copyfile(download_path, dest_path)
|
copyfile(download_path, dest_path)
|
||||||
|
|
||||||
logger.info('loading and training model')
|
logger.info("loading and training model")
|
||||||
model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64,
|
model = RRDBNet(
|
||||||
num_block=23, num_grow_ch=32, scale=scale)
|
num_in_ch=3,
|
||||||
|
num_out_ch=3,
|
||||||
|
num_feat=64,
|
||||||
|
num_block=23,
|
||||||
|
num_grow_ch=32,
|
||||||
|
scale=scale,
|
||||||
|
)
|
||||||
|
|
||||||
torch_model = torch.load(dest_path, map_location=map_location)
|
torch_model = torch.load(dest_path, map_location=map_location)
|
||||||
if 'params_ema' in torch_model:
|
if "params_ema" in torch_model:
|
||||||
model.load_state_dict(torch_model['params_ema'])
|
model.load_state_dict(torch_model["params_ema"])
|
||||||
else:
|
else:
|
||||||
model.load_state_dict(torch_model['params'], strict=False)
|
model.load_state_dict(torch_model["params"], strict=False)
|
||||||
|
|
||||||
model.to(training_device).train(False)
|
model.to(training_device).train(False)
|
||||||
model.eval()
|
model.eval()
|
||||||
|
|
||||||
rng = torch.rand(1, 3, 64, 64, device=map_location)
|
rng = torch.rand(1, 3, 64, 64, device=map_location)
|
||||||
input_names = ['data']
|
input_names = ["data"]
|
||||||
output_names = ['output']
|
output_names = ["output"]
|
||||||
dynamic_axes = {'data': {2: 'width', 3: 'height'},
|
dynamic_axes = {
|
||||||
'output': {2: 'width', 3: 'height'}}
|
"data": {2: "width", 3: "height"},
|
||||||
|
"output": {2: "width", 3: "height"},
|
||||||
|
}
|
||||||
|
|
||||||
logger.info('exporting ONNX model to %s', dest_onnx)
|
logger.info("exporting ONNX model to %s", dest_onnx)
|
||||||
export(
|
export(
|
||||||
model,
|
model,
|
||||||
rng,
|
rng,
|
||||||
|
@ -111,48 +136,57 @@ def convert_real_esrgan(name: str, url: str, scale: int, opset: int):
|
||||||
output_names=output_names,
|
output_names=output_names,
|
||||||
dynamic_axes=dynamic_axes,
|
dynamic_axes=dynamic_axes,
|
||||||
opset_version=opset,
|
opset_version=opset,
|
||||||
export_params=True
|
export_params=True,
|
||||||
)
|
)
|
||||||
logger.info('Real ESRGAN exported to ONNX successfully.')
|
logger.info("Real ESRGAN exported to ONNX successfully.")
|
||||||
|
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def convert_gfpgan(name: str, url: str, scale: int, opset: int):
|
def convert_gfpgan(name: str, url: str, scale: int, opset: int):
|
||||||
dest_path = path.join(model_path, name + '.pth')
|
dest_path = path.join(model_path, name + ".pth")
|
||||||
dest_onnx = path.join(model_path, name + '.onnx')
|
dest_onnx = path.join(model_path, name + ".onnx")
|
||||||
logger.info('converting GFPGAN model: %s -> %s', name, dest_onnx)
|
logger.info("converting GFPGAN model: %s -> %s", name, dest_onnx)
|
||||||
|
|
||||||
if path.isfile(dest_onnx):
|
if path.isfile(dest_onnx):
|
||||||
logger.info('ONNX model already exists, skipping.')
|
logger.info("ONNX model already exists, skipping.")
|
||||||
return
|
return
|
||||||
|
|
||||||
if not path.isfile(dest_path):
|
if not path.isfile(dest_path):
|
||||||
logger.info('PTH model not found, downloading...')
|
logger.info("PTH model not found, downloading...")
|
||||||
download_path = load_file_from_url(
|
download_path = load_file_from_url(
|
||||||
url=url, model_dir=dest_path + '-cache', progress=True, file_name=None)
|
url=url, model_dir=dest_path + "-cache", progress=True, file_name=None
|
||||||
|
)
|
||||||
copyfile(download_path, dest_path)
|
copyfile(download_path, dest_path)
|
||||||
|
|
||||||
logger.info('loading and training model')
|
logger.info("loading and training model")
|
||||||
model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64,
|
model = RRDBNet(
|
||||||
num_block=23, num_grow_ch=32, scale=scale)
|
num_in_ch=3,
|
||||||
|
num_out_ch=3,
|
||||||
|
num_feat=64,
|
||||||
|
num_block=23,
|
||||||
|
num_grow_ch=32,
|
||||||
|
scale=scale,
|
||||||
|
)
|
||||||
|
|
||||||
torch_model = torch.load(dest_path, map_location=map_location)
|
torch_model = torch.load(dest_path, map_location=map_location)
|
||||||
# TODO: make sure strict=False is safe here
|
# TODO: make sure strict=False is safe here
|
||||||
if 'params_ema' in torch_model:
|
if "params_ema" in torch_model:
|
||||||
model.load_state_dict(torch_model['params_ema'], strict=False)
|
model.load_state_dict(torch_model["params_ema"], strict=False)
|
||||||
else:
|
else:
|
||||||
model.load_state_dict(torch_model['params'], strict=False)
|
model.load_state_dict(torch_model["params"], strict=False)
|
||||||
|
|
||||||
model.to(training_device).train(False)
|
model.to(training_device).train(False)
|
||||||
model.eval()
|
model.eval()
|
||||||
|
|
||||||
rng = torch.rand(1, 3, 64, 64, device=map_location)
|
rng = torch.rand(1, 3, 64, 64, device=map_location)
|
||||||
input_names = ['data']
|
input_names = ["data"]
|
||||||
output_names = ['output']
|
output_names = ["output"]
|
||||||
dynamic_axes = {'data': {2: 'width', 3: 'height'},
|
dynamic_axes = {
|
||||||
'output': {2: 'width', 3: 'height'}}
|
"data": {2: "width", 3: "height"},
|
||||||
|
"output": {2: "width", 3: "height"},
|
||||||
|
}
|
||||||
|
|
||||||
logger.info('exporting ONNX model to %s', dest_onnx)
|
logger.info("exporting ONNX model to %s", dest_onnx)
|
||||||
export(
|
export(
|
||||||
model,
|
model,
|
||||||
rng,
|
rng,
|
||||||
|
@ -161,9 +195,9 @@ def convert_gfpgan(name: str, url: str, scale: int, opset: int):
|
||||||
output_names=output_names,
|
output_names=output_names,
|
||||||
dynamic_axes=dynamic_axes,
|
dynamic_axes=dynamic_axes,
|
||||||
opset_version=opset,
|
opset_version=opset,
|
||||||
export_params=True
|
export_params=True,
|
||||||
)
|
)
|
||||||
logger.info('GFPGAN exported to ONNX successfully.')
|
logger.info("GFPGAN exported to ONNX successfully.")
|
||||||
|
|
||||||
|
|
||||||
def onnx_export(
|
def onnx_export(
|
||||||
|
@ -176,9 +210,9 @@ def onnx_export(
|
||||||
opset,
|
opset,
|
||||||
use_external_data_format=False,
|
use_external_data_format=False,
|
||||||
):
|
):
|
||||||
'''
|
"""
|
||||||
From https://github.com/huggingface/diffusers/blob/main/scripts/convert_stable_diffusion_checkpoint_to_onnx.py
|
From https://github.com/huggingface/diffusers/blob/main/scripts/convert_stable_diffusion_checkpoint_to_onnx.py
|
||||||
'''
|
"""
|
||||||
|
|
||||||
output_path.parent.mkdir(parents=True, exist_ok=True)
|
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
export(
|
export(
|
||||||
|
@ -194,29 +228,33 @@ def onnx_export(
|
||||||
|
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def convert_diffuser(name: str, url: str, opset: int, half: bool, token: str, single_vae: bool = False):
|
def convert_diffuser(
|
||||||
'''
|
name: str, url: str, opset: int, half: bool, token: str, single_vae: bool = False
|
||||||
|
):
|
||||||
|
"""
|
||||||
From https://github.com/huggingface/diffusers/blob/main/scripts/convert_stable_diffusion_checkpoint_to_onnx.py
|
From https://github.com/huggingface/diffusers/blob/main/scripts/convert_stable_diffusion_checkpoint_to_onnx.py
|
||||||
'''
|
"""
|
||||||
dtype = torch.float16 if half else torch.float32
|
dtype = torch.float16 if half else torch.float32
|
||||||
dest_path = path.join(model_path, name)
|
dest_path = path.join(model_path, name)
|
||||||
|
|
||||||
# diffusers go into a directory rather than .onnx file
|
# diffusers go into a directory rather than .onnx file
|
||||||
logger.info('converting Diffusers model: %s -> %s/', name, dest_path)
|
logger.info("converting Diffusers model: %s -> %s/", name, dest_path)
|
||||||
|
|
||||||
if single_vae:
|
if single_vae:
|
||||||
logger.info('converting model with single VAE')
|
logger.info("converting model with single VAE")
|
||||||
|
|
||||||
if path.isdir(dest_path):
|
if path.isdir(dest_path):
|
||||||
logger.info('ONNX model already exists, skipping.')
|
logger.info("ONNX model already exists, skipping.")
|
||||||
return
|
return
|
||||||
|
|
||||||
if half and training_device != 'cuda':
|
if half and training_device != "cuda":
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
'Half precision model export is only supported on GPUs with CUDA')
|
"Half precision model export is only supported on GPUs with CUDA"
|
||||||
|
)
|
||||||
|
|
||||||
pipeline = StableDiffusionPipeline.from_pretrained(
|
pipeline = StableDiffusionPipeline.from_pretrained(
|
||||||
url, torch_dtype=dtype, use_auth_token=token).to(training_device)
|
url, torch_dtype=dtype, use_auth_token=token
|
||||||
|
).to(training_device)
|
||||||
output_path = Path(dest_path)
|
output_path = Path(dest_path)
|
||||||
|
|
||||||
# TEXT ENCODER
|
# TEXT ENCODER
|
||||||
|
@ -232,8 +270,7 @@ def convert_diffuser(name: str, url: str, opset: int, half: bool, token: str, si
|
||||||
onnx_export(
|
onnx_export(
|
||||||
pipeline.text_encoder,
|
pipeline.text_encoder,
|
||||||
# casting to torch.int32 until the CLIP fix is released: https://github.com/huggingface/transformers/pull/18515/files
|
# casting to torch.int32 until the CLIP fix is released: https://github.com/huggingface/transformers/pull/18515/files
|
||||||
model_args=(text_input.input_ids.to(
|
model_args=(text_input.input_ids.to(device=training_device, dtype=torch.int32)),
|
||||||
device=training_device, dtype=torch.int32)),
|
|
||||||
output_path=output_path / "text_encoder" / "model.onnx",
|
output_path=output_path / "text_encoder" / "model.onnx",
|
||||||
ordered_input_names=["input_ids"],
|
ordered_input_names=["input_ids"],
|
||||||
output_names=["last_hidden_state", "pooler_output"],
|
output_names=["last_hidden_state", "pooler_output"],
|
||||||
|
@ -244,7 +281,7 @@ def convert_diffuser(name: str, url: str, opset: int, half: bool, token: str, si
|
||||||
)
|
)
|
||||||
del pipeline.text_encoder
|
del pipeline.text_encoder
|
||||||
|
|
||||||
logger.debug('UNET config: %s', pipeline.unet.config)
|
logger.debug("UNET config: %s", pipeline.unet.config)
|
||||||
|
|
||||||
# UNET
|
# UNET
|
||||||
if single_vae:
|
if single_vae:
|
||||||
|
@ -262,10 +299,12 @@ def convert_diffuser(name: str, url: str, opset: int, half: bool, token: str, si
|
||||||
pipeline.unet,
|
pipeline.unet,
|
||||||
model_args=(
|
model_args=(
|
||||||
torch.randn(2, unet_in_channels, unet_sample_size, unet_sample_size).to(
|
torch.randn(2, unet_in_channels, unet_sample_size, unet_sample_size).to(
|
||||||
device=training_device, dtype=dtype),
|
device=training_device, dtype=dtype
|
||||||
|
),
|
||||||
torch.randn(2).to(device=training_device, dtype=dtype),
|
torch.randn(2).to(device=training_device, dtype=dtype),
|
||||||
torch.randn(2, num_tokens, text_hidden_size).to(
|
torch.randn(2, num_tokens, text_hidden_size).to(
|
||||||
device=training_device, dtype=dtype),
|
device=training_device, dtype=dtype
|
||||||
|
),
|
||||||
unet_scale,
|
unet_scale,
|
||||||
),
|
),
|
||||||
output_path=unet_path,
|
output_path=unet_path,
|
||||||
|
@ -298,7 +337,7 @@ def convert_diffuser(name: str, url: str, opset: int, half: bool, token: str, si
|
||||||
del pipeline.unet
|
del pipeline.unet
|
||||||
|
|
||||||
if single_vae:
|
if single_vae:
|
||||||
logger.debug('VAE config: %s', pipeline.vae.config)
|
logger.debug("VAE config: %s", pipeline.vae.config)
|
||||||
|
|
||||||
# SINGLE VAE
|
# SINGLE VAE
|
||||||
vae_only = pipeline.vae
|
vae_only = pipeline.vae
|
||||||
|
@ -309,8 +348,9 @@ def convert_diffuser(name: str, url: str, opset: int, half: bool, token: str, si
|
||||||
onnx_export(
|
onnx_export(
|
||||||
vae_only,
|
vae_only,
|
||||||
model_args=(
|
model_args=(
|
||||||
torch.randn(1, vae_latent_channels, unet_sample_size, unet_sample_size).to(
|
torch.randn(
|
||||||
device=training_device, dtype=dtype),
|
1, vae_latent_channels, unet_sample_size, unet_sample_size
|
||||||
|
).to(device=training_device, dtype=dtype),
|
||||||
False,
|
False,
|
||||||
),
|
),
|
||||||
output_path=output_path / "vae" / "model.onnx",
|
output_path=output_path / "vae" / "model.onnx",
|
||||||
|
@ -328,12 +368,14 @@ def convert_diffuser(name: str, url: str, opset: int, half: bool, token: str, si
|
||||||
vae_sample_size = vae_encoder.config.sample_size
|
vae_sample_size = vae_encoder.config.sample_size
|
||||||
# need to get the raw tensor output (sample) from the encoder
|
# need to get the raw tensor output (sample) from the encoder
|
||||||
vae_encoder.forward = lambda sample, return_dict: vae_encoder.encode(
|
vae_encoder.forward = lambda sample, return_dict: vae_encoder.encode(
|
||||||
sample, return_dict)[0].sample()
|
sample, return_dict
|
||||||
|
)[0].sample()
|
||||||
onnx_export(
|
onnx_export(
|
||||||
vae_encoder,
|
vae_encoder,
|
||||||
model_args=(
|
model_args=(
|
||||||
torch.randn(1, vae_in_channels, vae_sample_size, vae_sample_size).to(
|
torch.randn(1, vae_in_channels, vae_sample_size, vae_sample_size).to(
|
||||||
device=training_device, dtype=dtype),
|
device=training_device, dtype=dtype
|
||||||
|
),
|
||||||
False,
|
False,
|
||||||
),
|
),
|
||||||
output_path=output_path / "vae_encoder" / "model.onnx",
|
output_path=output_path / "vae_encoder" / "model.onnx",
|
||||||
|
@ -354,8 +396,9 @@ def convert_diffuser(name: str, url: str, opset: int, half: bool, token: str, si
|
||||||
onnx_export(
|
onnx_export(
|
||||||
vae_decoder,
|
vae_decoder,
|
||||||
model_args=(
|
model_args=(
|
||||||
torch.randn(1, vae_latent_channels, unet_sample_size, unet_sample_size).to(
|
torch.randn(
|
||||||
device=training_device, dtype=dtype),
|
1, vae_latent_channels, unet_sample_size, unet_sample_size
|
||||||
|
).to(device=training_device, dtype=dtype),
|
||||||
False,
|
False,
|
||||||
),
|
),
|
||||||
output_path=output_path / "vae_decoder" / "model.onnx",
|
output_path=output_path / "vae_decoder" / "model.onnx",
|
||||||
|
@ -385,7 +428,8 @@ def convert_diffuser(name: str, url: str, opset: int, half: bool, token: str, si
|
||||||
clip_image_size,
|
clip_image_size,
|
||||||
).to(device=training_device, dtype=dtype),
|
).to(device=training_device, dtype=dtype),
|
||||||
torch.randn(1, vae_sample_size, vae_sample_size, vae_out_channels).to(
|
torch.randn(1, vae_sample_size, vae_sample_size, vae_out_channels).to(
|
||||||
device=training_device, dtype=dtype),
|
device=training_device, dtype=dtype
|
||||||
|
),
|
||||||
),
|
),
|
||||||
output_path=output_path / "safety_checker" / "model.onnx",
|
output_path=output_path / "safety_checker" / "model.onnx",
|
||||||
ordered_input_names=["clip_input", "images"],
|
ordered_input_names=["clip_input", "images"],
|
||||||
|
@ -398,7 +442,8 @@ def convert_diffuser(name: str, url: str, opset: int, half: bool, token: str, si
|
||||||
)
|
)
|
||||||
del pipeline.safety_checker
|
del pipeline.safety_checker
|
||||||
safety_checker = OnnxRuntimeModel.from_pretrained(
|
safety_checker = OnnxRuntimeModel.from_pretrained(
|
||||||
output_path / "safety_checker")
|
output_path / "safety_checker"
|
||||||
|
)
|
||||||
feature_extractor = pipeline.feature_extractor
|
feature_extractor = pipeline.feature_extractor
|
||||||
else:
|
else:
|
||||||
safety_checker = None
|
safety_checker = None
|
||||||
|
@ -406,10 +451,8 @@ def convert_diffuser(name: str, url: str, opset: int, half: bool, token: str, si
|
||||||
|
|
||||||
if single_vae:
|
if single_vae:
|
||||||
onnx_pipeline = StableDiffusionUpscalePipeline(
|
onnx_pipeline = StableDiffusionUpscalePipeline(
|
||||||
vae=OnnxRuntimeModel.from_pretrained(
|
vae=OnnxRuntimeModel.from_pretrained(output_path / "vae"),
|
||||||
output_path / "vae"),
|
text_encoder=OnnxRuntimeModel.from_pretrained(output_path / "text_encoder"),
|
||||||
text_encoder=OnnxRuntimeModel.from_pretrained(
|
|
||||||
output_path / "text_encoder"),
|
|
||||||
tokenizer=pipeline.tokenizer,
|
tokenizer=pipeline.tokenizer,
|
||||||
low_res_scheduler=pipeline.scheduler,
|
low_res_scheduler=pipeline.scheduler,
|
||||||
unet=OnnxRuntimeModel.from_pretrained(output_path / "unet"),
|
unet=OnnxRuntimeModel.from_pretrained(output_path / "unet"),
|
||||||
|
@ -417,12 +460,9 @@ def convert_diffuser(name: str, url: str, opset: int, half: bool, token: str, si
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
onnx_pipeline = OnnxStableDiffusionPipeline(
|
onnx_pipeline = OnnxStableDiffusionPipeline(
|
||||||
vae_encoder=OnnxRuntimeModel.from_pretrained(
|
vae_encoder=OnnxRuntimeModel.from_pretrained(output_path / "vae_encoder"),
|
||||||
output_path / "vae_encoder"),
|
vae_decoder=OnnxRuntimeModel.from_pretrained(output_path / "vae_decoder"),
|
||||||
vae_decoder=OnnxRuntimeModel.from_pretrained(
|
text_encoder=OnnxRuntimeModel.from_pretrained(output_path / "text_encoder"),
|
||||||
output_path / "vae_decoder"),
|
|
||||||
text_encoder=OnnxRuntimeModel.from_pretrained(
|
|
||||||
output_path / "text_encoder"),
|
|
||||||
tokenizer=pipeline.tokenizer,
|
tokenizer=pipeline.tokenizer,
|
||||||
unet=OnnxRuntimeModel.from_pretrained(output_path / "unet"),
|
unet=OnnxRuntimeModel.from_pretrained(output_path / "unet"),
|
||||||
scheduler=pipeline.scheduler,
|
scheduler=pipeline.scheduler,
|
||||||
|
@ -431,7 +471,7 @@ def convert_diffuser(name: str, url: str, opset: int, half: bool, token: str, si
|
||||||
requires_safety_checker=safety_checker is not None,
|
requires_safety_checker=safety_checker is not None,
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.info('exporting ONNX model')
|
logger.info("exporting ONNX model")
|
||||||
|
|
||||||
onnx_pipeline.save_pretrained(output_path)
|
onnx_pipeline.save_pretrained(output_path)
|
||||||
logger.info("ONNX pipeline saved to %s", output_path)
|
logger.info("ONNX pipeline saved to %s", output_path)
|
||||||
|
@ -445,90 +485,93 @@ def convert_diffuser(name: str, url: str, opset: int, half: bool, token: str, si
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
_ = OnnxStableDiffusionPipeline.from_pretrained(
|
_ = OnnxStableDiffusionPipeline.from_pretrained(
|
||||||
output_path, provider="CPUExecutionProvider")
|
output_path, provider="CPUExecutionProvider"
|
||||||
|
)
|
||||||
|
|
||||||
logger.info("ONNX pipeline is loadable")
|
logger.info("ONNX pipeline is loadable")
|
||||||
|
|
||||||
|
|
||||||
def load_models(args, models: Models):
|
def load_models(args, models: Models):
|
||||||
if args.diffusion:
|
if args.diffusion:
|
||||||
for source in models.get('diffusion'):
|
for source in models.get("diffusion"):
|
||||||
if source[0] in args.skip:
|
if source[0] in args.skip:
|
||||||
logger.info('Skipping model: %s', source[0])
|
logger.info("Skipping model: %s", source[0])
|
||||||
else:
|
else:
|
||||||
single_vae = 'upscaling' in source[0]
|
single_vae = "upscaling" in source[0]
|
||||||
convert_diffuser(*source, args.opset, args.half, args.token, single_vae=single_vae)
|
convert_diffuser(
|
||||||
|
*source, args.opset, args.half, args.token, single_vae=single_vae
|
||||||
|
)
|
||||||
|
|
||||||
if args.upscaling:
|
if args.upscaling:
|
||||||
for source in models.get('upscaling'):
|
for source in models.get("upscaling"):
|
||||||
if source[0] in args.skip:
|
if source[0] in args.skip:
|
||||||
logger.info('Skipping model: %s', source[0])
|
logger.info("Skipping model: %s", source[0])
|
||||||
else:
|
else:
|
||||||
convert_real_esrgan(*source, args.opset)
|
convert_real_esrgan(*source, args.opset)
|
||||||
|
|
||||||
if args.correction:
|
if args.correction:
|
||||||
for source in models.get('correction'):
|
for source in models.get("correction"):
|
||||||
if source[0] in args.skip:
|
if source[0] in args.skip:
|
||||||
logger.info('Skipping model: %s', source[0])
|
logger.info("Skipping model: %s", source[0])
|
||||||
else:
|
else:
|
||||||
convert_gfpgan(*source, args.opset)
|
convert_gfpgan(*source, args.opset)
|
||||||
|
|
||||||
|
|
||||||
def main() -> int:
|
def main() -> int:
|
||||||
parser = ArgumentParser(
|
parser = ArgumentParser(
|
||||||
prog='onnx-web model converter',
|
prog="onnx-web model converter", description="convert checkpoint models to ONNX"
|
||||||
description='convert checkpoint models to ONNX')
|
)
|
||||||
|
|
||||||
# model groups
|
# model groups
|
||||||
parser.add_argument('--correction', action='store_true', default=False)
|
parser.add_argument("--correction", action="store_true", default=False)
|
||||||
parser.add_argument('--diffusion', action='store_true', default=False)
|
parser.add_argument("--diffusion", action="store_true", default=False)
|
||||||
parser.add_argument('--upscaling', action='store_true', default=False)
|
parser.add_argument("--upscaling", action="store_true", default=False)
|
||||||
|
|
||||||
# extra models
|
# extra models
|
||||||
parser.add_argument('--extras', nargs='*', type=str, default=[])
|
parser.add_argument("--extras", nargs="*", type=str, default=[])
|
||||||
parser.add_argument('--skip', nargs='*', type=str, default=[])
|
parser.add_argument("--skip", nargs="*", type=str, default=[])
|
||||||
|
|
||||||
# export options
|
# export options
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
'--half',
|
"--half",
|
||||||
action='store_true',
|
action="store_true",
|
||||||
default=False,
|
default=False,
|
||||||
help='Export models for half precision, faster on some Nvidia cards.'
|
help="Export models for half precision, faster on some Nvidia cards.",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
'--opset',
|
"--opset",
|
||||||
default=14,
|
default=14,
|
||||||
type=int,
|
type=int,
|
||||||
help="The version of the ONNX operator set to use.",
|
help="The version of the ONNX operator set to use.",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
'--token',
|
"--token",
|
||||||
type=str,
|
type=str,
|
||||||
help="HuggingFace token with read permissions for downloading models.",
|
help="HuggingFace token with read permissions for downloading models.",
|
||||||
)
|
)
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
logger.info('CLI arguments: %s', args)
|
logger.info("CLI arguments: %s", args)
|
||||||
|
|
||||||
if not path.exists(model_path):
|
if not path.exists(model_path):
|
||||||
logger.info('Model path does not existing, creating: %s', model_path)
|
logger.info("Model path does not existing, creating: %s", model_path)
|
||||||
makedirs(model_path)
|
makedirs(model_path)
|
||||||
|
|
||||||
logger.info('Converting base models.')
|
logger.info("Converting base models.")
|
||||||
load_models(args, base_models)
|
load_models(args, base_models)
|
||||||
|
|
||||||
for file in args.extras:
|
for file in args.extras:
|
||||||
logger.info('Loading extra models from %s', file)
|
logger.info("Loading extra models from %s", file)
|
||||||
try:
|
try:
|
||||||
with open(file, 'r') as f:
|
with open(file, "r") as f:
|
||||||
data = loads(f.read())
|
data = loads(f.read())
|
||||||
logger.info('Converting extra models.')
|
logger.info("Converting extra models.")
|
||||||
load_models(args, data)
|
load_models(args, data)
|
||||||
except Exception as err:
|
except Exception as err:
|
||||||
logger.error('Error converting extra models: %s', err)
|
logger.error("Error converting extra models: %s", err)
|
||||||
|
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == "__main__":
|
||||||
exit(main())
|
exit(main())
|
||||||
|
|
|
@ -1,12 +1,10 @@
|
||||||
from collections import Counter
|
from collections import Counter
|
||||||
from concurrent.futures import Future, ThreadPoolExecutor, ProcessPoolExecutor
|
from concurrent.futures import Future, ProcessPoolExecutor, ThreadPoolExecutor
|
||||||
from logging import getLogger
|
from logging import getLogger
|
||||||
from multiprocessing import Value
|
from multiprocessing import Value
|
||||||
from typing import Any, Callable, List, Optional, Tuple, Union
|
from typing import Any, Callable, List, Optional, Tuple, Union
|
||||||
|
|
||||||
from .params import (
|
from .params import DeviceParams
|
||||||
DeviceParams,
|
|
||||||
)
|
|
||||||
|
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
@ -28,24 +26,24 @@ class JobContext:
|
||||||
):
|
):
|
||||||
self.key = key
|
self.key = key
|
||||||
self.devices = list(devices)
|
self.devices = list(devices)
|
||||||
self.cancel = Value('B', cancel)
|
self.cancel = Value("B", cancel)
|
||||||
self.device_index = Value('i', device_index)
|
self.device_index = Value("i", device_index)
|
||||||
self.progress = Value('I', progress)
|
self.progress = Value("I", progress)
|
||||||
|
|
||||||
def is_cancelled(self) -> bool:
|
def is_cancelled(self) -> bool:
|
||||||
return self.cancel.value
|
return self.cancel.value
|
||||||
|
|
||||||
def get_device(self) -> DeviceParams:
|
def get_device(self) -> DeviceParams:
|
||||||
'''
|
"""
|
||||||
Get the device assigned to this job.
|
Get the device assigned to this job.
|
||||||
'''
|
"""
|
||||||
with self.device_index.get_lock():
|
with self.device_index.get_lock():
|
||||||
device_index = self.device_index.value
|
device_index = self.device_index.value
|
||||||
if device_index < 0:
|
if device_index < 0:
|
||||||
raise Exception('job has not been assigned to a device')
|
raise Exception("job has not been assigned to a device")
|
||||||
else:
|
else:
|
||||||
device = self.devices[device_index]
|
device = self.devices[device_index]
|
||||||
logger.debug('job %s assigned to device %s', self.key, device)
|
logger.debug("job %s assigned to device %s", self.key, device)
|
||||||
return device
|
return device
|
||||||
|
|
||||||
def get_progress(self) -> int:
|
def get_progress(self) -> int:
|
||||||
|
@ -54,10 +52,9 @@ class JobContext:
|
||||||
def get_progress_callback(self) -> Callable[..., None]:
|
def get_progress_callback(self) -> Callable[..., None]:
|
||||||
def on_progress(step: int, timestep: int, latents: Any):
|
def on_progress(step: int, timestep: int, latents: Any):
|
||||||
if self.is_cancelled():
|
if self.is_cancelled():
|
||||||
raise Exception('job has been cancelled')
|
raise Exception("job has been cancelled")
|
||||||
else:
|
else:
|
||||||
logger.debug('setting progress for job %s to %s',
|
logger.debug("setting progress for job %s to %s", self.key, step)
|
||||||
self.key, step)
|
|
||||||
self.set_progress(step)
|
self.set_progress(step)
|
||||||
|
|
||||||
return on_progress
|
return on_progress
|
||||||
|
@ -72,9 +69,9 @@ class JobContext:
|
||||||
|
|
||||||
|
|
||||||
class Job:
|
class Job:
|
||||||
'''
|
"""
|
||||||
Link a future to its context.
|
Link a future to its context.
|
||||||
'''
|
"""
|
||||||
|
|
||||||
context: JobContext = None
|
context: JobContext = None
|
||||||
future: Future = None
|
future: Future = None
|
||||||
|
@ -106,7 +103,11 @@ class DevicePoolExecutor:
|
||||||
next_device: int = 0
|
next_device: int = 0
|
||||||
pool: Union[ProcessPoolExecutor, ThreadPoolExecutor] = None
|
pool: Union[ProcessPoolExecutor, ThreadPoolExecutor] = None
|
||||||
|
|
||||||
def __init__(self, devices: List[DeviceParams], pool: Optional[Union[ProcessPoolExecutor, ThreadPoolExecutor]] = None):
|
def __init__(
|
||||||
|
self,
|
||||||
|
devices: List[DeviceParams],
|
||||||
|
pool: Optional[Union[ProcessPoolExecutor, ThreadPoolExecutor]] = None,
|
||||||
|
):
|
||||||
self.devices = devices
|
self.devices = devices
|
||||||
self.jobs = []
|
self.jobs = []
|
||||||
self.next_device = 0
|
self.next_device = 0
|
||||||
|
@ -114,19 +115,25 @@ class DevicePoolExecutor:
|
||||||
device_count = len(devices)
|
device_count = len(devices)
|
||||||
if pool is None:
|
if pool is None:
|
||||||
logger.info(
|
logger.info(
|
||||||
'creating thread pool executor for %s devices: %s', device_count, [d.device for d in devices])
|
"creating thread pool executor for %s devices: %s",
|
||||||
|
device_count,
|
||||||
|
[d.device for d in devices],
|
||||||
|
)
|
||||||
self.pool = ThreadPoolExecutor(device_count)
|
self.pool = ThreadPoolExecutor(device_count)
|
||||||
else:
|
else:
|
||||||
logger.info('using existing pool for %s devices: %s',
|
logger.info(
|
||||||
device_count, [d.device for d in devices])
|
"using existing pool for %s devices: %s",
|
||||||
|
device_count,
|
||||||
|
[d.device for d in devices],
|
||||||
|
)
|
||||||
self.pool = pool
|
self.pool = pool
|
||||||
|
|
||||||
def cancel(self, key: str) -> bool:
|
def cancel(self, key: str) -> bool:
|
||||||
'''
|
"""
|
||||||
Cancel a job. If the job has not been started, this will cancel
|
Cancel a job. If the job has not been started, this will cancel
|
||||||
the future and never execute it. If the job has been started, it
|
the future and never execute it. If the job has been started, it
|
||||||
should be cancelled on the next progress callback.
|
should be cancelled on the next progress callback.
|
||||||
'''
|
"""
|
||||||
for job in self.jobs:
|
for job in self.jobs:
|
||||||
if job.key == key:
|
if job.key == key:
|
||||||
if job.future.cancel():
|
if job.future.cancel():
|
||||||
|
@ -144,7 +151,7 @@ class DevicePoolExecutor:
|
||||||
progress = job.get_progress()
|
progress = job.get_progress()
|
||||||
return (done, progress)
|
return (done, progress)
|
||||||
|
|
||||||
logger.warn('checking status for unknown key: %s', key)
|
logger.warn("checking status for unknown key: %s", key)
|
||||||
return (None, 0)
|
return (None, 0)
|
||||||
|
|
||||||
def get_next_device(self):
|
def get_next_device(self):
|
||||||
|
@ -152,12 +159,14 @@ class DevicePoolExecutor:
|
||||||
if len(self.jobs) == 0:
|
if len(self.jobs) == 0:
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
job_devices = [job.context.device_index.value for job in self.jobs if not job.future.done()]
|
job_devices = [
|
||||||
|
job.context.device_index.value for job in self.jobs if not job.future.done()
|
||||||
|
]
|
||||||
job_counts = Counter(range(len(self.devices)))
|
job_counts = Counter(range(len(self.devices)))
|
||||||
job_counts.update(job_devices)
|
job_counts.update(job_devices)
|
||||||
|
|
||||||
queued = job_counts.most_common()
|
queued = job_counts.most_common()
|
||||||
logger.debug('jobs queued by device: %s', queued)
|
logger.debug("jobs queued by device: %s", queued)
|
||||||
|
|
||||||
lowest_count = queued[-1][1]
|
lowest_count = queued[-1][1]
|
||||||
lowest_devices = [d[0] for d in queued if d[1] == lowest_count]
|
lowest_devices = [d[0] for d in queued if d[1] == lowest_count]
|
||||||
|
@ -170,7 +179,7 @@ class DevicePoolExecutor:
|
||||||
|
|
||||||
def submit(self, key: str, fn: Callable[..., None], /, *args, **kwargs) -> None:
|
def submit(self, key: str, fn: Callable[..., None], /, *args, **kwargs) -> None:
|
||||||
device = self.get_next_device()
|
device = self.get_next_device()
|
||||||
logger.info('assigning job %s to device %s', key, device)
|
logger.info("assigning job %s to device %s", key, device)
|
||||||
|
|
||||||
context = JobContext(key, self.devices, device_index=device)
|
context = JobContext(key, self.devices, device_index=device)
|
||||||
future = self.pool.submit(fn, context, *args, **kwargs)
|
future = self.pool.submit(fn, context, *args, **kwargs)
|
||||||
|
@ -180,11 +189,19 @@ class DevicePoolExecutor:
|
||||||
def job_done(f: Future):
|
def job_done(f: Future):
|
||||||
try:
|
try:
|
||||||
f.result()
|
f.result()
|
||||||
logger.info('job %s finished successfully', key)
|
logger.info("job %s finished successfully", key)
|
||||||
except Exception as err:
|
except Exception as err:
|
||||||
logger.warn('job %s failed with an error: %s', key, err)
|
logger.warn("job %s failed with an error: %s", key, err)
|
||||||
|
|
||||||
future.add_done_callback(job_done)
|
future.add_done_callback(job_done)
|
||||||
|
|
||||||
def status(self) -> List[Tuple[str, int, bool, int]]:
|
def status(self) -> List[Tuple[str, int, bool, int]]:
|
||||||
return [(job.key, job.context.device_index.value, job.future.done(), job.get_progress()) for job in self.jobs]
|
return [
|
||||||
|
(
|
||||||
|
job.key,
|
||||||
|
job.context.device_index.value,
|
||||||
|
job.future.done(),
|
||||||
|
job.get_progress(),
|
||||||
|
)
|
||||||
|
for job in self.jobs
|
||||||
|
]
|
||||||
|
|
|
@ -1,18 +1,11 @@
|
||||||
from diffusers import (
|
|
||||||
DiffusionPipeline,
|
|
||||||
)
|
|
||||||
from logging import getLogger
|
from logging import getLogger
|
||||||
from typing import Any, Optional, Tuple
|
from typing import Any, Tuple
|
||||||
|
|
||||||
from ..params import (
|
|
||||||
DeviceParams,
|
|
||||||
Size,
|
|
||||||
)
|
|
||||||
from ..utils import (
|
|
||||||
run_gc,
|
|
||||||
)
|
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
from diffusers import DiffusionPipeline
|
||||||
|
|
||||||
|
from ..params import DeviceParams, Size
|
||||||
|
from ..utils import run_gc
|
||||||
|
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
@ -25,17 +18,23 @@ latent_factor = 8
|
||||||
|
|
||||||
|
|
||||||
def get_latents_from_seed(seed: int, size: Size, batch: int = 1) -> np.ndarray:
|
def get_latents_from_seed(seed: int, size: Size, batch: int = 1) -> np.ndarray:
|
||||||
'''
|
"""
|
||||||
From https://www.travelneil.com/stable-diffusion-updates.html
|
From https://www.travelneil.com/stable-diffusion-updates.html
|
||||||
'''
|
"""
|
||||||
latents_shape = (batch, latent_channels, size.height // latent_factor,
|
latents_shape = (
|
||||||
size.width // latent_factor)
|
batch,
|
||||||
|
latent_channels,
|
||||||
|
size.height // latent_factor,
|
||||||
|
size.width // latent_factor,
|
||||||
|
)
|
||||||
rng = np.random.default_rng(seed)
|
rng = np.random.default_rng(seed)
|
||||||
image_latents = rng.standard_normal(latents_shape).astype(np.float32)
|
image_latents = rng.standard_normal(latents_shape).astype(np.float32)
|
||||||
return image_latents
|
return image_latents
|
||||||
|
|
||||||
|
|
||||||
def get_tile_latents(full_latents: np.ndarray, dims: Tuple[int, int, int]) -> np.ndarray:
|
def get_tile_latents(
|
||||||
|
full_latents: np.ndarray, dims: Tuple[int, int, int]
|
||||||
|
) -> np.ndarray:
|
||||||
x, y, tile = dims
|
x, y, tile = dims
|
||||||
t = tile // latent_factor
|
t = tile // latent_factor
|
||||||
x = x // latent_factor
|
x = x // latent_factor
|
||||||
|
@ -46,27 +45,29 @@ def get_tile_latents(full_latents: np.ndarray, dims: Tuple[int, int, int]) -> np
|
||||||
return full_latents[:, :, y:yt, x:xt]
|
return full_latents[:, :, y:yt, x:xt]
|
||||||
|
|
||||||
|
|
||||||
def load_pipeline(pipeline: DiffusionPipeline, model: str, scheduler: Any, device: DeviceParams):
|
def load_pipeline(
|
||||||
|
pipeline: DiffusionPipeline, model: str, scheduler: Any, device: DeviceParams
|
||||||
|
):
|
||||||
global last_pipeline_instance
|
global last_pipeline_instance
|
||||||
global last_pipeline_scheduler
|
global last_pipeline_scheduler
|
||||||
global last_pipeline_options
|
global last_pipeline_options
|
||||||
|
|
||||||
options = (pipeline, model, device.provider)
|
options = (pipeline, model, device.provider)
|
||||||
if last_pipeline_instance != None and last_pipeline_options == options:
|
if last_pipeline_instance is not None and last_pipeline_options == options:
|
||||||
logger.debug('reusing existing diffusion pipeline')
|
logger.debug("reusing existing diffusion pipeline")
|
||||||
pipe = last_pipeline_instance
|
pipe = last_pipeline_instance
|
||||||
else:
|
else:
|
||||||
logger.debug('unloading previous diffusion pipeline')
|
logger.debug("unloading previous diffusion pipeline")
|
||||||
last_pipeline_instance = None
|
last_pipeline_instance = None
|
||||||
last_pipeline_scheduler = None
|
last_pipeline_scheduler = None
|
||||||
run_gc()
|
run_gc()
|
||||||
|
|
||||||
logger.debug('loading new diffusion pipeline from %s', model)
|
logger.debug("loading new diffusion pipeline from %s", model)
|
||||||
scheduler = scheduler.from_pretrained(
|
scheduler = scheduler.from_pretrained(
|
||||||
model,
|
model,
|
||||||
provider=device.provider,
|
provider=device.provider,
|
||||||
provider_options=device.options,
|
provider_options=device.options,
|
||||||
subfolder='scheduler',
|
subfolder="scheduler",
|
||||||
)
|
)
|
||||||
pipe = pipeline.from_pretrained(
|
pipe = pipeline.from_pretrained(
|
||||||
model,
|
model,
|
||||||
|
@ -76,7 +77,7 @@ def load_pipeline(pipeline: DiffusionPipeline, model: str, scheduler: Any, devic
|
||||||
scheduler=scheduler,
|
scheduler=scheduler,
|
||||||
)
|
)
|
||||||
|
|
||||||
if device is not None and hasattr(pipe, 'to'):
|
if device is not None and hasattr(pipe, "to"):
|
||||||
pipe = pipe.to(device)
|
pipe = pipe.to(device)
|
||||||
|
|
||||||
last_pipeline_instance = pipe
|
last_pipeline_instance = pipe
|
||||||
|
@ -84,15 +85,15 @@ def load_pipeline(pipeline: DiffusionPipeline, model: str, scheduler: Any, devic
|
||||||
last_pipeline_scheduler = scheduler
|
last_pipeline_scheduler = scheduler
|
||||||
|
|
||||||
if last_pipeline_scheduler != scheduler:
|
if last_pipeline_scheduler != scheduler:
|
||||||
logger.debug('loading new diffusion scheduler')
|
logger.debug("loading new diffusion scheduler")
|
||||||
scheduler = scheduler.from_pretrained(
|
scheduler = scheduler.from_pretrained(
|
||||||
model,
|
model,
|
||||||
provider=device.provider,
|
provider=device.provider,
|
||||||
provider_options=device.options,
|
provider_options=device.options,
|
||||||
subfolder='scheduler',
|
subfolder="scheduler",
|
||||||
)
|
)
|
||||||
|
|
||||||
if device is not None and hasattr(scheduler, 'to'):
|
if device is not None and hasattr(scheduler, "to"):
|
||||||
scheduler = scheduler.to(device)
|
scheduler = scheduler.to(device)
|
||||||
|
|
||||||
pipe.scheduler = scheduler
|
pipe.scheduler = scheduler
|
||||||
|
|
|
@ -1,28 +1,16 @@
|
||||||
from diffusers import (
|
|
||||||
DDPMScheduler,
|
|
||||||
OnnxRuntimeModel,
|
|
||||||
StableDiffusionUpscalePipeline,
|
|
||||||
)
|
|
||||||
from diffusers.pipeline_utils import (
|
|
||||||
ImagePipelineOutput,
|
|
||||||
)
|
|
||||||
from logging import getLogger
|
from logging import getLogger
|
||||||
from PIL import Image
|
from typing import Any, Callable, List, Optional, Union
|
||||||
from typing import (
|
|
||||||
Any,
|
|
||||||
Callable,
|
|
||||||
List,
|
|
||||||
Optional,
|
|
||||||
Union,
|
|
||||||
)
|
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
from diffusers import DDPMScheduler, OnnxRuntimeModel, StableDiffusionUpscalePipeline
|
||||||
|
from diffusers.pipeline_utils import ImagePipelineOutput
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
num_channels_latents = 4 # self.vae.config.latent_channels
|
num_channels_latents = 4 # self.vae.config.latent_channels
|
||||||
unet_in_channels = 7 # self.unet.config.in_channels
|
unet_in_channels = 7 # self.unet.config.in_channels
|
||||||
|
|
||||||
###
|
###
|
||||||
# This is based on a combination of the ONNX img2img pipeline and the PyTorch upscale pipeline:
|
# This is based on a combination of the ONNX img2img pipeline and the PyTorch upscale pipeline:
|
||||||
|
@ -30,6 +18,7 @@ unet_in_channels = 7 # self.unet.config.in_channels
|
||||||
# https://github.com/huggingface/diffusers/blob/v0.11.1/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py
|
# https://github.com/huggingface/diffusers/blob/v0.11.1/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py
|
||||||
###
|
###
|
||||||
|
|
||||||
|
|
||||||
def preprocess(image):
|
def preprocess(image):
|
||||||
if isinstance(image, torch.Tensor):
|
if isinstance(image, torch.Tensor):
|
||||||
return image
|
return image
|
||||||
|
@ -63,8 +52,15 @@ class OnnxStableDiffusionUpscalePipeline(StableDiffusionUpscalePipeline):
|
||||||
scheduler: Any,
|
scheduler: Any,
|
||||||
max_noise_level: int = 350,
|
max_noise_level: int = 350,
|
||||||
):
|
):
|
||||||
super().__init__(vae, text_encoder, tokenizer, unet,
|
super().__init__(
|
||||||
low_res_scheduler, scheduler, max_noise_level)
|
vae,
|
||||||
|
text_encoder,
|
||||||
|
tokenizer,
|
||||||
|
unet,
|
||||||
|
low_res_scheduler,
|
||||||
|
scheduler,
|
||||||
|
max_noise_level,
|
||||||
|
)
|
||||||
|
|
||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
|
@ -96,7 +92,11 @@ class OnnxStableDiffusionUpscalePipeline(StableDiffusionUpscalePipeline):
|
||||||
|
|
||||||
# 3. Encode input prompt
|
# 3. Encode input prompt
|
||||||
text_embeddings = self._encode_prompt(
|
text_embeddings = self._encode_prompt(
|
||||||
prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt
|
prompt,
|
||||||
|
device,
|
||||||
|
num_images_per_prompt,
|
||||||
|
do_classifier_free_guidance,
|
||||||
|
negative_prompt,
|
||||||
)
|
)
|
||||||
|
|
||||||
# 4. Preprocess image
|
# 4. Preprocess image
|
||||||
|
@ -111,7 +111,9 @@ class OnnxStableDiffusionUpscalePipeline(StableDiffusionUpscalePipeline):
|
||||||
text_embeddings_dtype = torch.float32
|
text_embeddings_dtype = torch.float32
|
||||||
|
|
||||||
noise_level = torch.tensor([noise_level], dtype=torch.long, device=device)
|
noise_level = torch.tensor([noise_level], dtype=torch.long, device=device)
|
||||||
noise = torch.randn(image.shape, generator=generator, device=device, dtype=text_embeddings_dtype)
|
noise = torch.randn(
|
||||||
|
image.shape, generator=generator, device=device, dtype=text_embeddings_dtype
|
||||||
|
)
|
||||||
image = self.low_res_scheduler.add_noise(image, noise, noise_level)
|
image = self.low_res_scheduler.add_noise(image, noise, noise_level)
|
||||||
|
|
||||||
batch_multiplier = 2 if do_classifier_free_guidance else 1
|
batch_multiplier = 2 if do_classifier_free_guidance else 1
|
||||||
|
@ -135,11 +137,11 @@ class OnnxStableDiffusionUpscalePipeline(StableDiffusionUpscalePipeline):
|
||||||
num_channels_image = image.shape[1]
|
num_channels_image = image.shape[1]
|
||||||
if num_channels_latents + num_channels_image != unet_in_channels:
|
if num_channels_latents + num_channels_image != unet_in_channels:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Incorrect configuration settings! The config of `pipeline.unet` expects"
|
"Incorrect configuration settings! The config of `pipeline.unet`"
|
||||||
f" {unet_in_channels} but received `num_channels_latents`: {num_channels_latents} +"
|
f" expects {unet_in_channels} but received `num_channels_latents`:"
|
||||||
f" `num_channels_image`: {num_channels_image} "
|
f" {num_channels_latents} + `num_channels_image`: {num_channels_image} "
|
||||||
f" = {num_channels_latents+num_channels_image}. Please verify the config of"
|
f" = {num_channels_latents+num_channels_image}. Please verify the"
|
||||||
" `pipeline.unet` or your `image` input."
|
" config of `pipeline.unet` or your `image` input."
|
||||||
)
|
)
|
||||||
|
|
||||||
# 8. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
|
# 8. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
|
||||||
|
@ -150,10 +152,16 @@ class OnnxStableDiffusionUpscalePipeline(StableDiffusionUpscalePipeline):
|
||||||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||||
for i, t in enumerate(timesteps):
|
for i, t in enumerate(timesteps):
|
||||||
# expand the latents if we are doing classifier free guidance
|
# expand the latents if we are doing classifier free guidance
|
||||||
latent_model_input = np.concatenate([latents] * 2) if do_classifier_free_guidance else latents
|
latent_model_input = (
|
||||||
|
np.concatenate([latents] * 2)
|
||||||
|
if do_classifier_free_guidance
|
||||||
|
else latents
|
||||||
|
)
|
||||||
|
|
||||||
# concat latents, mask, masked_image_latents in the channel dimension
|
# concat latents, mask, masked_image_latents in the channel dimension
|
||||||
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
latent_model_input = self.scheduler.scale_model_input(
|
||||||
|
latent_model_input, t
|
||||||
|
)
|
||||||
latent_model_input = np.concatenate([latent_model_input, image], axis=1)
|
latent_model_input = np.concatenate([latent_model_input, image], axis=1)
|
||||||
|
|
||||||
# timestep to tensor
|
# timestep to tensor
|
||||||
|
@ -164,19 +172,25 @@ class OnnxStableDiffusionUpscalePipeline(StableDiffusionUpscalePipeline):
|
||||||
sample=latent_model_input,
|
sample=latent_model_input,
|
||||||
timestep=timestep,
|
timestep=timestep,
|
||||||
encoder_hidden_states=text_embeddings,
|
encoder_hidden_states=text_embeddings,
|
||||||
class_labels=noise_level
|
class_labels=noise_level,
|
||||||
)[0]
|
)[0]
|
||||||
|
|
||||||
# perform guidance
|
# perform guidance
|
||||||
if do_classifier_free_guidance:
|
if do_classifier_free_guidance:
|
||||||
noise_pred_uncond, noise_pred_text = np.split(noise_pred, 2)
|
noise_pred_uncond, noise_pred_text = np.split(noise_pred, 2)
|
||||||
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
|
noise_pred = noise_pred_uncond + guidance_scale * (
|
||||||
|
noise_pred_text - noise_pred_uncond
|
||||||
|
)
|
||||||
|
|
||||||
# compute the previous noisy sample x_t -> x_t-1
|
# compute the previous noisy sample x_t -> x_t-1
|
||||||
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
|
latents = self.scheduler.step(
|
||||||
|
noise_pred, t, latents, **extra_step_kwargs
|
||||||
|
).prev_sample
|
||||||
|
|
||||||
# call the callback, if provided
|
# call the callback, if provided
|
||||||
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
if i == len(timesteps) - 1 or (
|
||||||
|
(i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
|
||||||
|
):
|
||||||
progress_bar.update()
|
progress_bar.update()
|
||||||
if callback is not None and i % callback_steps == 0:
|
if callback is not None and i % callback_steps == 0:
|
||||||
callback(i, t, latents)
|
callback(i, t, latents)
|
||||||
|
@ -200,7 +214,14 @@ class OnnxStableDiffusionUpscalePipeline(StableDiffusionUpscalePipeline):
|
||||||
image = image.transpose((0, 2, 3, 1))
|
image = image.transpose((0, 2, 3, 1))
|
||||||
return image
|
return image
|
||||||
|
|
||||||
def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt):
|
def _encode_prompt(
|
||||||
|
self,
|
||||||
|
prompt,
|
||||||
|
device,
|
||||||
|
num_images_per_prompt,
|
||||||
|
do_classifier_free_guidance,
|
||||||
|
negative_prompt,
|
||||||
|
):
|
||||||
batch_size = len(prompt) if isinstance(prompt, list) else 1
|
batch_size = len(prompt) if isinstance(prompt, list) else 1
|
||||||
|
|
||||||
text_inputs = self.tokenizer(
|
text_inputs = self.tokenizer(
|
||||||
|
@ -211,13 +232,20 @@ class OnnxStableDiffusionUpscalePipeline(StableDiffusionUpscalePipeline):
|
||||||
return_tensors="pt",
|
return_tensors="pt",
|
||||||
)
|
)
|
||||||
text_input_ids = text_inputs.input_ids
|
text_input_ids = text_inputs.input_ids
|
||||||
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
|
untruncated_ids = self.tokenizer(
|
||||||
|
prompt, padding="longest", return_tensors="pt"
|
||||||
|
).input_ids
|
||||||
|
|
||||||
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
|
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
|
||||||
removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])
|
text_input_ids, untruncated_ids
|
||||||
|
):
|
||||||
|
removed_text = self.tokenizer.batch_decode(
|
||||||
|
untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
|
||||||
|
)
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"The following part of your input was truncated because CLIP can only handle sequences up to"
|
"The following part of your input was truncated because CLIP can only"
|
||||||
f" {self.tokenizer.model_max_length} tokens: {removed_text}"
|
f" handle sequences up to {self.tokenizer.model_max_length} tokens:"
|
||||||
|
f" {removed_text}"
|
||||||
)
|
)
|
||||||
|
|
||||||
# no positional arguments to text_encoder
|
# no positional arguments to text_encoder
|
||||||
|
@ -240,16 +268,17 @@ class OnnxStableDiffusionUpscalePipeline(StableDiffusionUpscalePipeline):
|
||||||
uncond_tokens = [""] * batch_size
|
uncond_tokens = [""] * batch_size
|
||||||
elif type(prompt) is not type(negative_prompt):
|
elif type(prompt) is not type(negative_prompt):
|
||||||
raise TypeError(
|
raise TypeError(
|
||||||
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
|
"`negative_prompt` should be the same type to `prompt`, but got"
|
||||||
f" {type(prompt)}."
|
f" {type(negative_prompt)} != {type(prompt)}."
|
||||||
)
|
)
|
||||||
elif isinstance(negative_prompt, str):
|
elif isinstance(negative_prompt, str):
|
||||||
uncond_tokens = [negative_prompt]
|
uncond_tokens = [negative_prompt]
|
||||||
elif batch_size != len(negative_prompt):
|
elif batch_size != len(negative_prompt):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
|
f"`negative_prompt`: {negative_prompt} has batch size"
|
||||||
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
|
f" {len(negative_prompt)}, but `prompt`: {prompt} has batch size"
|
||||||
" the batch size of `prompt`."
|
f" {batch_size}. Please make sure that passed `negative_prompt`"
|
||||||
|
" matches the batch size of `prompt`."
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
uncond_tokens = negative_prompt
|
uncond_tokens = negative_prompt
|
||||||
|
|
|
@ -1,41 +1,17 @@
|
||||||
from diffusers import (
|
|
||||||
OnnxStableDiffusionPipeline,
|
|
||||||
OnnxStableDiffusionImg2ImgPipeline,
|
|
||||||
)
|
|
||||||
from logging import getLogger
|
from logging import getLogger
|
||||||
from PIL import Image, ImageChops
|
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from ..chain import (
|
|
||||||
upscale_outpaint,
|
|
||||||
)
|
|
||||||
from ..device_pool import (
|
|
||||||
JobContext,
|
|
||||||
)
|
|
||||||
from ..params import (
|
|
||||||
ImageParams,
|
|
||||||
Border,
|
|
||||||
Size,
|
|
||||||
StageParams,
|
|
||||||
)
|
|
||||||
from ..output import (
|
|
||||||
save_image,
|
|
||||||
save_params,
|
|
||||||
)
|
|
||||||
from ..upscale import (
|
|
||||||
run_upscale_correction,
|
|
||||||
UpscaleParams,
|
|
||||||
)
|
|
||||||
from ..utils import (
|
|
||||||
run_gc,
|
|
||||||
ServerContext,
|
|
||||||
)
|
|
||||||
from .load import (
|
|
||||||
get_latents_from_seed,
|
|
||||||
load_pipeline,
|
|
||||||
)
|
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
from diffusers import OnnxStableDiffusionImg2ImgPipeline, OnnxStableDiffusionPipeline
|
||||||
|
from PIL import Image, ImageChops
|
||||||
|
|
||||||
|
from ..chain import upscale_outpaint
|
||||||
|
from ..device_pool import JobContext
|
||||||
|
from ..output import save_image, save_params
|
||||||
|
from ..params import Border, ImageParams, Size, StageParams
|
||||||
|
from ..upscale import UpscaleParams, run_upscale_correction
|
||||||
|
from ..utils import ServerContext, run_gc
|
||||||
|
from .load import get_latents_from_seed, load_pipeline
|
||||||
|
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
@ -46,10 +22,11 @@ def run_txt2img_pipeline(
|
||||||
params: ImageParams,
|
params: ImageParams,
|
||||||
size: Size,
|
size: Size,
|
||||||
output: str,
|
output: str,
|
||||||
upscale: UpscaleParams
|
upscale: UpscaleParams,
|
||||||
) -> None:
|
) -> None:
|
||||||
pipe = load_pipeline(OnnxStableDiffusionPipeline,
|
pipe = load_pipeline(
|
||||||
params.model, params.scheduler, job.get_device())
|
OnnxStableDiffusionPipeline, params.model, params.scheduler, job.get_device()
|
||||||
|
)
|
||||||
|
|
||||||
latents = get_latents_from_seed(params.seed, size)
|
latents = get_latents_from_seed(params.seed, size)
|
||||||
rng = np.random.RandomState(params.seed)
|
rng = np.random.RandomState(params.seed)
|
||||||
|
@ -68,7 +45,8 @@ def run_txt2img_pipeline(
|
||||||
)
|
)
|
||||||
image = result.images[0]
|
image = result.images[0]
|
||||||
image = run_upscale_correction(
|
image = run_upscale_correction(
|
||||||
job, server, StageParams(), params, image, upscale=upscale)
|
job, server, StageParams(), params, image, upscale=upscale
|
||||||
|
)
|
||||||
|
|
||||||
dest = save_image(server, output, image)
|
dest = save_image(server, output, image)
|
||||||
save_params(server, output, params, size, upscale=upscale)
|
save_params(server, output, params, size, upscale=upscale)
|
||||||
|
@ -77,7 +55,7 @@ def run_txt2img_pipeline(
|
||||||
del result
|
del result
|
||||||
run_gc()
|
run_gc()
|
||||||
|
|
||||||
logger.info('finished txt2img job: %s', dest)
|
logger.info("finished txt2img job: %s", dest)
|
||||||
|
|
||||||
|
|
||||||
def run_img2img_pipeline(
|
def run_img2img_pipeline(
|
||||||
|
@ -89,8 +67,12 @@ def run_img2img_pipeline(
|
||||||
source_image: Image.Image,
|
source_image: Image.Image,
|
||||||
strength: float,
|
strength: float,
|
||||||
) -> None:
|
) -> None:
|
||||||
pipe = load_pipeline(OnnxStableDiffusionImg2ImgPipeline,
|
pipe = load_pipeline(
|
||||||
params.model, params.scheduler, job.get_device())
|
OnnxStableDiffusionImg2ImgPipeline,
|
||||||
|
params.model,
|
||||||
|
params.scheduler,
|
||||||
|
job.get_device(),
|
||||||
|
)
|
||||||
|
|
||||||
rng = np.random.RandomState(params.seed)
|
rng = np.random.RandomState(params.seed)
|
||||||
|
|
||||||
|
@ -107,7 +89,8 @@ def run_img2img_pipeline(
|
||||||
)
|
)
|
||||||
image = result.images[0]
|
image = result.images[0]
|
||||||
image = run_upscale_correction(
|
image = run_upscale_correction(
|
||||||
job, server, StageParams(), params, image, upscale=upscale)
|
job, server, StageParams(), params, image, upscale=upscale
|
||||||
|
)
|
||||||
|
|
||||||
dest = save_image(server, output, image)
|
dest = save_image(server, output, image)
|
||||||
size = Size(*source_image.size)
|
size = Size(*source_image.size)
|
||||||
|
@ -117,7 +100,7 @@ def run_img2img_pipeline(
|
||||||
del result
|
del result
|
||||||
run_gc()
|
run_gc()
|
||||||
|
|
||||||
logger.info('finished img2img job: %s', dest)
|
logger.info("finished img2img job: %s", dest)
|
||||||
|
|
||||||
|
|
||||||
def run_inpaint_pipeline(
|
def run_inpaint_pipeline(
|
||||||
|
@ -151,16 +134,14 @@ def run_inpaint_pipeline(
|
||||||
mask_filter=mask_filter,
|
mask_filter=mask_filter,
|
||||||
noise_source=noise_source,
|
noise_source=noise_source,
|
||||||
)
|
)
|
||||||
logger.info('applying mask filter and generating noise source')
|
logger.info("applying mask filter and generating noise source")
|
||||||
|
|
||||||
if image.size == source_image.size:
|
if image.size == source_image.size:
|
||||||
image = ImageChops.blend(source_image, image, strength)
|
image = ImageChops.blend(source_image, image, strength)
|
||||||
else:
|
else:
|
||||||
logger.info(
|
logger.info("output image size does not match source, skipping post-blend")
|
||||||
'output image size does not match source, skipping post-blend')
|
|
||||||
|
|
||||||
image = run_upscale_correction(
|
image = run_upscale_correction(job, server, stage, params, image, upscale=upscale)
|
||||||
job, server, stage, params, image, upscale=upscale)
|
|
||||||
|
|
||||||
dest = save_image(server, output, image)
|
dest = save_image(server, output, image)
|
||||||
save_params(server, output, params, size, upscale=upscale, border=border)
|
save_params(server, output, params, size, upscale=upscale, border=border)
|
||||||
|
@ -168,7 +149,7 @@ def run_inpaint_pipeline(
|
||||||
del image
|
del image
|
||||||
run_gc()
|
run_gc()
|
||||||
|
|
||||||
logger.info('finished inpaint job: %s', dest)
|
logger.info("finished inpaint job: %s", dest)
|
||||||
|
|
||||||
|
|
||||||
def run_upscale_pipeline(
|
def run_upscale_pipeline(
|
||||||
|
@ -185,7 +166,8 @@ def run_upscale_pipeline(
|
||||||
stage = StageParams()
|
stage = StageParams()
|
||||||
|
|
||||||
image = run_upscale_correction(
|
image = run_upscale_correction(
|
||||||
job, server, stage, params, source_image, upscale=upscale)
|
job, server, stage, params, source_image, upscale=upscale
|
||||||
|
)
|
||||||
|
|
||||||
dest = save_image(server, output, image)
|
dest = save_image(server, output, image)
|
||||||
save_params(server, output, params, size, upscale=upscale)
|
save_params(server, output, params, size, upscale=upscale)
|
||||||
|
@ -193,4 +175,4 @@ def run_upscale_pipeline(
|
||||||
del image
|
del image
|
||||||
run_gc()
|
run_gc()
|
||||||
|
|
||||||
logger.info('finished upscale job: %s', dest)
|
logger.info("finished upscale job: %s", dest)
|
||||||
|
|
|
@ -1,31 +1,31 @@
|
||||||
|
import numpy as np
|
||||||
from numpy import random
|
from numpy import random
|
||||||
from PIL import Image, ImageChops, ImageFilter
|
from PIL import Image, ImageChops, ImageFilter
|
||||||
|
|
||||||
import numpy as np
|
from .params import Border, Point
|
||||||
|
|
||||||
from .params import (
|
|
||||||
Border,
|
|
||||||
Point,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def get_pixel_index(x: int, y: int, width: int) -> int:
|
def get_pixel_index(x: int, y: int, width: int) -> int:
|
||||||
return (y * width) + x
|
return (y * width) + x
|
||||||
|
|
||||||
|
|
||||||
def mask_filter_none(mask_image: Image.Image, dims: Point, origin: Point, fill='white', **kw) -> Image.Image:
|
def mask_filter_none(
|
||||||
|
mask_image: Image.Image, dims: Point, origin: Point, fill="white", **kw
|
||||||
|
) -> Image.Image:
|
||||||
width, height = dims
|
width, height = dims
|
||||||
|
|
||||||
noise = Image.new('RGB', (width, height), fill)
|
noise = Image.new("RGB", (width, height), fill)
|
||||||
noise.paste(mask_image, origin)
|
noise.paste(mask_image, origin)
|
||||||
|
|
||||||
return noise
|
return noise
|
||||||
|
|
||||||
|
|
||||||
def mask_filter_gaussian_multiply(mask_image: Image.Image, dims: Point, origin: Point, rounds=3, **kw) -> Image.Image:
|
def mask_filter_gaussian_multiply(
|
||||||
'''
|
mask_image: Image.Image, dims: Point, origin: Point, rounds=3, **kw
|
||||||
|
) -> Image.Image:
|
||||||
|
"""
|
||||||
Gaussian blur with multiply, source image centered on white canvas.
|
Gaussian blur with multiply, source image centered on white canvas.
|
||||||
'''
|
"""
|
||||||
noise = mask_filter_none(mask_image, dims, origin)
|
noise = mask_filter_none(mask_image, dims, origin)
|
||||||
|
|
||||||
for i in range(rounds):
|
for i in range(rounds):
|
||||||
|
@ -35,10 +35,12 @@ def mask_filter_gaussian_multiply(mask_image: Image.Image, dims: Point, origin:
|
||||||
return noise
|
return noise
|
||||||
|
|
||||||
|
|
||||||
def mask_filter_gaussian_screen(mask_image: Image.Image, dims: Point, origin: Point, rounds=3, **kw) -> Image.Image:
|
def mask_filter_gaussian_screen(
|
||||||
'''
|
mask_image: Image.Image, dims: Point, origin: Point, rounds=3, **kw
|
||||||
|
) -> Image.Image:
|
||||||
|
"""
|
||||||
Gaussian blur, source image centered on white canvas.
|
Gaussian blur, source image centered on white canvas.
|
||||||
'''
|
"""
|
||||||
noise = mask_filter_none(mask_image, dims, origin)
|
noise = mask_filter_none(mask_image, dims, origin)
|
||||||
|
|
||||||
for i in range(rounds):
|
for i in range(rounds):
|
||||||
|
@ -48,33 +50,39 @@ def mask_filter_gaussian_screen(mask_image: Image.Image, dims: Point, origin: Po
|
||||||
return noise
|
return noise
|
||||||
|
|
||||||
|
|
||||||
def noise_source_fill_edge(source_image: Image.Image, dims: Point, origin: Point, fill='white', **kw) -> Image.Image:
|
def noise_source_fill_edge(
|
||||||
'''
|
source_image: Image.Image, dims: Point, origin: Point, fill="white", **kw
|
||||||
|
) -> Image.Image:
|
||||||
|
"""
|
||||||
Identity transform, source image centered on white canvas.
|
Identity transform, source image centered on white canvas.
|
||||||
'''
|
"""
|
||||||
width, height = dims
|
width, height = dims
|
||||||
|
|
||||||
noise = Image.new('RGB', (width, height), fill)
|
noise = Image.new("RGB", (width, height), fill)
|
||||||
noise.paste(source_image, origin)
|
noise.paste(source_image, origin)
|
||||||
|
|
||||||
return noise
|
return noise
|
||||||
|
|
||||||
|
|
||||||
def noise_source_fill_mask(source_image: Image.Image, dims: Point, origin: Point, fill='white', **kw) -> Image.Image:
|
def noise_source_fill_mask(
|
||||||
'''
|
source_image: Image.Image, dims: Point, origin: Point, fill="white", **kw
|
||||||
|
) -> Image.Image:
|
||||||
|
"""
|
||||||
Fill the whole canvas, no source or noise.
|
Fill the whole canvas, no source or noise.
|
||||||
'''
|
"""
|
||||||
width, height = dims
|
width, height = dims
|
||||||
|
|
||||||
noise = Image.new('RGB', (width, height), fill)
|
noise = Image.new("RGB", (width, height), fill)
|
||||||
|
|
||||||
return noise
|
return noise
|
||||||
|
|
||||||
|
|
||||||
def noise_source_gaussian(source_image: Image.Image, dims: Point, origin: Point, rounds=3, **kw) -> Image.Image:
|
def noise_source_gaussian(
|
||||||
'''
|
source_image: Image.Image, dims: Point, origin: Point, rounds=3, **kw
|
||||||
|
) -> Image.Image:
|
||||||
|
"""
|
||||||
Gaussian blur, source image centered on white canvas.
|
Gaussian blur, source image centered on white canvas.
|
||||||
'''
|
"""
|
||||||
noise = noise_source_uniform(source_image, dims, origin)
|
noise = noise_source_uniform(source_image, dims, origin)
|
||||||
noise.paste(source_image, origin)
|
noise.paste(source_image, origin)
|
||||||
|
|
||||||
|
@ -84,7 +92,9 @@ def noise_source_gaussian(source_image: Image.Image, dims: Point, origin: Point,
|
||||||
return noise
|
return noise
|
||||||
|
|
||||||
|
|
||||||
def noise_source_uniform(source_image: Image.Image, dims: Point, origin: Point, **kw) -> Image.Image:
|
def noise_source_uniform(
|
||||||
|
source_image: Image.Image, dims: Point, origin: Point, **kw
|
||||||
|
) -> Image.Image:
|
||||||
width, height = dims
|
width, height = dims
|
||||||
size = width * height
|
size = width * height
|
||||||
|
|
||||||
|
@ -92,21 +102,19 @@ def noise_source_uniform(source_image: Image.Image, dims: Point, origin: Point,
|
||||||
noise_g = random.uniform(0, 256, size=size)
|
noise_g = random.uniform(0, 256, size=size)
|
||||||
noise_b = random.uniform(0, 256, size=size)
|
noise_b = random.uniform(0, 256, size=size)
|
||||||
|
|
||||||
noise = Image.new('RGB', (width, height))
|
noise = Image.new("RGB", (width, height))
|
||||||
|
|
||||||
for x in range(width):
|
for x in range(width):
|
||||||
for y in range(height):
|
for y in range(height):
|
||||||
i = get_pixel_index(x, y, width)
|
i = get_pixel_index(x, y, width)
|
||||||
noise.putpixel((x, y), (
|
noise.putpixel((x, y), (int(noise_r[i]), int(noise_g[i]), int(noise_b[i])))
|
||||||
int(noise_r[i]),
|
|
||||||
int(noise_g[i]),
|
|
||||||
int(noise_b[i])
|
|
||||||
))
|
|
||||||
|
|
||||||
return noise
|
return noise
|
||||||
|
|
||||||
|
|
||||||
def noise_source_normal(source_image: Image.Image, dims: Point, origin: Point, **kw) -> Image.Image:
|
def noise_source_normal(
|
||||||
|
source_image: Image.Image, dims: Point, origin: Point, **kw
|
||||||
|
) -> Image.Image:
|
||||||
width, height = dims
|
width, height = dims
|
||||||
size = width * height
|
size = width * height
|
||||||
|
|
||||||
|
@ -114,21 +122,19 @@ def noise_source_normal(source_image: Image.Image, dims: Point, origin: Point, *
|
||||||
noise_g = random.normal(128, 32, size=size)
|
noise_g = random.normal(128, 32, size=size)
|
||||||
noise_b = random.normal(128, 32, size=size)
|
noise_b = random.normal(128, 32, size=size)
|
||||||
|
|
||||||
noise = Image.new('RGB', (width, height))
|
noise = Image.new("RGB", (width, height))
|
||||||
|
|
||||||
for x in range(width):
|
for x in range(width):
|
||||||
for y in range(height):
|
for y in range(height):
|
||||||
i = get_pixel_index(x, y, width)
|
i = get_pixel_index(x, y, width)
|
||||||
noise.putpixel((x, y), (
|
noise.putpixel((x, y), (int(noise_r[i]), int(noise_g[i]), int(noise_b[i])))
|
||||||
int(noise_r[i]),
|
|
||||||
int(noise_g[i]),
|
|
||||||
int(noise_b[i])
|
|
||||||
))
|
|
||||||
|
|
||||||
return noise
|
return noise
|
||||||
|
|
||||||
|
|
||||||
def noise_source_histogram(source_image: Image.Image, dims: Point, origin: Point, **kw) -> Image.Image:
|
def noise_source_histogram(
|
||||||
|
source_image: Image.Image, dims: Point, origin: Point, **kw
|
||||||
|
) -> Image.Image:
|
||||||
r, g, b = source_image.split()
|
r, g, b = source_image.split()
|
||||||
width, height = dims
|
width, height = dims
|
||||||
size = width * height
|
size = width * height
|
||||||
|
@ -137,35 +143,34 @@ def noise_source_histogram(source_image: Image.Image, dims: Point, origin: Point
|
||||||
hist_g = g.histogram()
|
hist_g = g.histogram()
|
||||||
hist_b = b.histogram()
|
hist_b = b.histogram()
|
||||||
|
|
||||||
noise_r = random.choice(256, p=np.divide(
|
noise_r = random.choice(
|
||||||
np.copy(hist_r), np.sum(hist_r)), size=size)
|
256, p=np.divide(np.copy(hist_r), np.sum(hist_r)), size=size
|
||||||
noise_g = random.choice(256, p=np.divide(
|
)
|
||||||
np.copy(hist_g), np.sum(hist_g)), size=size)
|
noise_g = random.choice(
|
||||||
noise_b = random.choice(256, p=np.divide(
|
256, p=np.divide(np.copy(hist_g), np.sum(hist_g)), size=size
|
||||||
np.copy(hist_b), np.sum(hist_b)), size=size)
|
)
|
||||||
|
noise_b = random.choice(
|
||||||
|
256, p=np.divide(np.copy(hist_b), np.sum(hist_b)), size=size
|
||||||
|
)
|
||||||
|
|
||||||
noise = Image.new('RGB', (width, height))
|
noise = Image.new("RGB", (width, height))
|
||||||
|
|
||||||
for x in range(width):
|
for x in range(width):
|
||||||
for y in range(height):
|
for y in range(height):
|
||||||
i = get_pixel_index(x, y, width)
|
i = get_pixel_index(x, y, width)
|
||||||
noise.putpixel((x, y), (
|
noise.putpixel((x, y), (noise_r[i], noise_g[i], noise_b[i]))
|
||||||
noise_r[i],
|
|
||||||
noise_g[i],
|
|
||||||
noise_b[i]
|
|
||||||
))
|
|
||||||
|
|
||||||
return noise
|
return noise
|
||||||
|
|
||||||
|
|
||||||
# very loosely based on https://github.com/AUTOMATIC1111/stable-diffusion-webui/blob/master/scripts/outpainting_mk_2.py#L175-L232
|
# very loosely based on https://github.com/AUTOMATIC1111/stable-diffusion-webui/blob/master/scripts/outpainting_mk_2.py#L175-L232
|
||||||
def expand_image(
|
def expand_image(
|
||||||
source_image: Image.Image,
|
source_image: Image.Image,
|
||||||
mask_image: Image.Image,
|
mask_image: Image.Image,
|
||||||
expand: Border,
|
expand: Border,
|
||||||
fill='white',
|
fill="white",
|
||||||
noise_source=noise_source_histogram,
|
noise_source=noise_source_histogram,
|
||||||
mask_filter=mask_filter_none,
|
mask_filter=mask_filter_none,
|
||||||
):
|
):
|
||||||
full_width = expand.left + source_image.width + expand.right
|
full_width = expand.left + source_image.width + expand.right
|
||||||
full_height = expand.top + source_image.height + expand.bottom
|
full_height = expand.top + source_image.height + expand.bottom
|
||||||
|
@ -173,14 +178,13 @@ def expand_image(
|
||||||
dims = (full_width, full_height)
|
dims = (full_width, full_height)
|
||||||
origin = (expand.left, expand.top)
|
origin = (expand.left, expand.top)
|
||||||
|
|
||||||
full_source = Image.new('RGB', dims, fill)
|
full_source = Image.new("RGB", dims, fill)
|
||||||
full_source.paste(source_image, origin)
|
full_source.paste(source_image, origin)
|
||||||
|
|
||||||
full_mask = mask_filter(mask_image, dims, origin, fill=fill)
|
full_mask = mask_filter(mask_image, dims, origin, fill=fill)
|
||||||
full_noise = noise_source(source_image, dims, origin, fill=fill)
|
full_noise = noise_source(source_image, dims, origin, fill=fill)
|
||||||
full_noise = ImageChops.multiply(full_noise, full_mask)
|
full_noise = ImageChops.multiply(full_noise, full_mask)
|
||||||
|
|
||||||
full_source = Image.composite(
|
full_source = Image.composite(full_noise, full_source, full_mask.convert("L"))
|
||||||
full_noise, full_source, full_mask.convert('L'))
|
|
||||||
|
|
||||||
return (full_source, full_mask, full_noise, (full_width, full_height))
|
return (full_source, full_mask, full_noise, (full_width, full_height))
|
||||||
|
|
|
@ -1,14 +1,15 @@
|
||||||
from logging.config import dictConfig
|
from logging.config import dictConfig
|
||||||
from os import environ, path
|
from os import environ, path
|
||||||
|
|
||||||
from yaml import safe_load
|
from yaml import safe_load
|
||||||
|
|
||||||
logging_path = environ.get('ONNX_WEB_LOGGING_PATH', './logging.yaml')
|
logging_path = environ.get("ONNX_WEB_LOGGING_PATH", "./logging.yaml")
|
||||||
|
|
||||||
# setup logging config before anything else loads
|
# setup logging config before anything else loads
|
||||||
try:
|
try:
|
||||||
if path.exists(logging_path):
|
if path.exists(logging_path):
|
||||||
with open(logging_path, 'r') as f:
|
with open(logging_path, "r") as f:
|
||||||
config_logging = safe_load(f)
|
config_logging = safe_load(f)
|
||||||
dictConfig(config_logging)
|
dictConfig(config_logging)
|
||||||
except Exception as err:
|
except Exception as err:
|
||||||
print('error loading logging config: %s' % (err))
|
print("error loading logging config: %s" % (err))
|
||||||
|
|
|
@ -1,4 +1 @@
|
||||||
from .onnx_net import (
|
from .onnx_net import OnnxImage, OnnxNet
|
||||||
OnnxImage,
|
|
||||||
OnnxNet,
|
|
||||||
)
|
|
||||||
|
|
|
@ -1,15 +1,14 @@
|
||||||
from onnxruntime import InferenceSession
|
|
||||||
from os import path
|
from os import path
|
||||||
from typing import Any, Optional
|
from typing import Any, Optional
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
from onnxruntime import InferenceSession
|
||||||
|
|
||||||
from ..utils import (
|
from ..utils import ServerContext
|
||||||
ServerContext,
|
|
||||||
)
|
|
||||||
|
|
||||||
class OnnxImage():
|
|
||||||
|
class OnnxImage:
|
||||||
def __init__(self, source) -> None:
|
def __init__(self, source) -> None:
|
||||||
self.source = source
|
self.source = source
|
||||||
self.data = self
|
self.data = self
|
||||||
|
@ -38,28 +37,27 @@ class OnnxImage():
|
||||||
return np.shape(self.source)
|
return np.shape(self.source)
|
||||||
|
|
||||||
|
|
||||||
class OnnxNet():
|
class OnnxNet:
|
||||||
'''
|
"""
|
||||||
Provides the RRDBNet interface using an ONNX session for DirectML acceleration.
|
Provides the RRDBNet interface using an ONNX session for DirectML acceleration.
|
||||||
'''
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
server: ServerContext,
|
server: ServerContext,
|
||||||
model: str,
|
model: str,
|
||||||
provider: str = 'DmlExecutionProvider',
|
provider: str = "DmlExecutionProvider",
|
||||||
provider_options: Optional[dict] = None,
|
provider_options: Optional[dict] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
model_path = path.join(server.model_path, model)
|
model_path = path.join(server.model_path, model)
|
||||||
self.session = InferenceSession(
|
self.session = InferenceSession(
|
||||||
model_path, providers=[provider], provider_options=provider_options)
|
model_path, providers=[provider], provider_options=provider_options
|
||||||
|
)
|
||||||
|
|
||||||
def __call__(self, image: Any) -> Any:
|
def __call__(self, image: Any) -> Any:
|
||||||
input_name = self.session.get_inputs()[0].name
|
input_name = self.session.get_inputs()[0].name
|
||||||
output_name = self.session.get_outputs()[0].name
|
output_name = self.session.get_outputs()[0].name
|
||||||
output = self.session.run([output_name], {
|
output = self.session.run([output_name], {input_name: image.cpu().numpy()})[0]
|
||||||
input_name: image.cpu().numpy()
|
|
||||||
})[0]
|
|
||||||
return OnnxImage(output)
|
return OnnxImage(output)
|
||||||
|
|
||||||
def eval(self) -> None:
|
def eval(self) -> None:
|
||||||
|
|
|
@ -1,22 +1,14 @@
|
||||||
from hashlib import sha256
|
from hashlib import sha256
|
||||||
from json import dumps
|
from json import dumps
|
||||||
from logging import getLogger
|
from logging import getLogger
|
||||||
from PIL import Image
|
|
||||||
from struct import pack
|
from struct import pack
|
||||||
from time import time
|
from time import time
|
||||||
from typing import Any, Optional, Tuple
|
from typing import Any, Optional, Tuple
|
||||||
|
|
||||||
from .params import (
|
from PIL import Image
|
||||||
Border,
|
|
||||||
ImageParams,
|
from .params import Border, ImageParams, Param, Size, UpscaleParams
|
||||||
Param,
|
from .utils import ServerContext, base_join
|
||||||
Size,
|
|
||||||
UpscaleParams,
|
|
||||||
)
|
|
||||||
from .utils import (
|
|
||||||
base_join,
|
|
||||||
ServerContext,
|
|
||||||
)
|
|
||||||
|
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
@ -25,13 +17,13 @@ def hash_value(sha, param: Param):
|
||||||
if param is None:
|
if param is None:
|
||||||
return
|
return
|
||||||
elif isinstance(param, float):
|
elif isinstance(param, float):
|
||||||
sha.update(bytearray(pack('!f', param)))
|
sha.update(bytearray(pack("!f", param)))
|
||||||
elif isinstance(param, int):
|
elif isinstance(param, int):
|
||||||
sha.update(bytearray(pack('!I', param)))
|
sha.update(bytearray(pack("!I", param)))
|
||||||
elif isinstance(param, str):
|
elif isinstance(param, str):
|
||||||
sha.update(param.encode('utf-8'))
|
sha.update(param.encode("utf-8"))
|
||||||
else:
|
else:
|
||||||
logger.warn('cannot hash param: %s, %s', param, type(param))
|
logger.warn("cannot hash param: %s, %s", param, type(param))
|
||||||
|
|
||||||
|
|
||||||
def json_params(
|
def json_params(
|
||||||
|
@ -42,22 +34,22 @@ def json_params(
|
||||||
border: Optional[Border] = None,
|
border: Optional[Border] = None,
|
||||||
) -> Any:
|
) -> Any:
|
||||||
json = {
|
json = {
|
||||||
'output': output,
|
"output": output,
|
||||||
'params': params.tojson(),
|
"params": params.tojson(),
|
||||||
}
|
}
|
||||||
|
|
||||||
if upscale is not None and border is not None:
|
if upscale is not None and border is not None:
|
||||||
size = upscale.resize(size.add_border(border))
|
size = upscale.resize(size.add_border(border))
|
||||||
|
|
||||||
if upscale is not None:
|
if upscale is not None:
|
||||||
json['upscale'] = upscale.tojson()
|
json["upscale"] = upscale.tojson()
|
||||||
size = upscale.resize(size)
|
size = upscale.resize(size)
|
||||||
|
|
||||||
if border is not None:
|
if border is not None:
|
||||||
json['border'] = border.tojson()
|
json["border"] = border.tojson()
|
||||||
size = size.add_border(border)
|
size = size.add_border(border)
|
||||||
|
|
||||||
json['size'] = size.tojson()
|
json["size"] = size.tojson()
|
||||||
|
|
||||||
return json
|
return json
|
||||||
|
|
||||||
|
@ -67,7 +59,7 @@ def make_output_name(
|
||||||
mode: str,
|
mode: str,
|
||||||
params: ImageParams,
|
params: ImageParams,
|
||||||
size: Size,
|
size: Size,
|
||||||
extras: Optional[Tuple[Param]] = None
|
extras: Optional[Tuple[Param]] = None,
|
||||||
) -> str:
|
) -> str:
|
||||||
now = int(time())
|
now = int(time())
|
||||||
sha = sha256()
|
sha = sha256()
|
||||||
|
@ -87,13 +79,19 @@ def make_output_name(
|
||||||
for param in extras:
|
for param in extras:
|
||||||
hash_value(sha, param)
|
hash_value(sha, param)
|
||||||
|
|
||||||
return '%s_%s_%s_%s.%s' % (mode, params.seed, sha.hexdigest(), now, ctx.image_format)
|
return "%s_%s_%s_%s.%s" % (
|
||||||
|
mode,
|
||||||
|
params.seed,
|
||||||
|
sha.hexdigest(),
|
||||||
|
now,
|
||||||
|
ctx.image_format,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def save_image(ctx: ServerContext, output: str, image: Image.Image) -> str:
|
def save_image(ctx: ServerContext, output: str, image: Image.Image) -> str:
|
||||||
path = base_join(ctx.output_path, output)
|
path = base_join(ctx.output_path, output)
|
||||||
image.save(path, format=ctx.image_format)
|
image.save(path, format=ctx.image_format)
|
||||||
logger.debug('saved output image to: %s', path)
|
logger.debug("saved output image to: %s", path)
|
||||||
return path
|
return path
|
||||||
|
|
||||||
|
|
||||||
|
@ -105,9 +103,9 @@ def save_params(
|
||||||
upscale: Optional[UpscaleParams] = None,
|
upscale: Optional[UpscaleParams] = None,
|
||||||
border: Optional[Border] = None,
|
border: Optional[Border] = None,
|
||||||
) -> str:
|
) -> str:
|
||||||
path = base_join(ctx.output_path, '%s.json' % (output))
|
path = base_join(ctx.output_path, "%s.json" % (output))
|
||||||
json = json_params(output, params, size, upscale=upscale, border=border)
|
json = json_params(output, params, size, upscale=upscale, border=border)
|
||||||
with open(path, 'w') as f:
|
with open(path, "w") as f:
|
||||||
f.write(dumps(json))
|
f.write(dumps(json))
|
||||||
logger.debug('saved image params to: %s', path)
|
logger.debug("saved image params to: %s", path)
|
||||||
return path
|
return path
|
||||||
|
|
|
@ -3,9 +3,9 @@ from typing import Any, Dict, Literal, Optional, Tuple, Union
|
||||||
|
|
||||||
|
|
||||||
class SizeChart(IntEnum):
|
class SizeChart(IntEnum):
|
||||||
mini = 128 # small tile for very expensive models
|
mini = 128 # small tile for very expensive models
|
||||||
half = 256 # half tile for outpainting
|
half = 256 # half tile for outpainting
|
||||||
auto = 512 # auto tile size
|
auto = 512 # auto tile size
|
||||||
hd1k = 2**10
|
hd1k = 2**10
|
||||||
hd2k = 2**11
|
hd2k = 2**11
|
||||||
hd4k = 2**12
|
hd4k = 2**12
|
||||||
|
@ -26,14 +26,14 @@ class Border:
|
||||||
self.bottom = bottom
|
self.bottom = bottom
|
||||||
|
|
||||||
def __str__(self) -> str:
|
def __str__(self) -> str:
|
||||||
return '%s %s %s %s' % (self.left, self.top, self.right, self.bottom)
|
return "%s %s %s %s" % (self.left, self.top, self.right, self.bottom)
|
||||||
|
|
||||||
def tojson(self):
|
def tojson(self):
|
||||||
return {
|
return {
|
||||||
'left': self.left,
|
"left": self.left,
|
||||||
'right': self.right,
|
"right": self.right,
|
||||||
'top': self.top,
|
"top": self.top,
|
||||||
'bottom': self.bottom,
|
"bottom": self.bottom,
|
||||||
}
|
}
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
@ -47,32 +47,37 @@ class Size:
|
||||||
self.height = height
|
self.height = height
|
||||||
|
|
||||||
def __str__(self) -> str:
|
def __str__(self) -> str:
|
||||||
return '%sx%s' % (self.width, self.height)
|
return "%sx%s" % (self.width, self.height)
|
||||||
|
|
||||||
def add_border(self, border: Border):
|
def add_border(self, border: Border):
|
||||||
return Size(border.left + self.width + border.right, border.top + self.height + border.right)
|
return Size(
|
||||||
|
border.left + self.width + border.right,
|
||||||
|
border.top + self.height + border.right,
|
||||||
|
)
|
||||||
|
|
||||||
def tojson(self) -> Dict[str, int]:
|
def tojson(self) -> Dict[str, int]:
|
||||||
return {
|
return {
|
||||||
'height': self.height,
|
"height": self.height,
|
||||||
'width': self.width,
|
"width": self.width,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
class DeviceParams:
|
class DeviceParams:
|
||||||
def __init__(self, device: str, provider: str, options: Optional[dict] = None) -> None:
|
def __init__(
|
||||||
|
self, device: str, provider: str, options: Optional[dict] = None
|
||||||
|
) -> None:
|
||||||
self.device = device
|
self.device = device
|
||||||
self.provider = provider
|
self.provider = provider
|
||||||
self.options = options
|
self.options = options
|
||||||
|
|
||||||
def __str__(self) -> str:
|
def __str__(self) -> str:
|
||||||
return '%s - %s (%s)' % (self.device, self.provider, self.options)
|
return "%s - %s (%s)" % (self.device, self.provider, self.options)
|
||||||
|
|
||||||
def torch_device(self) -> str:
|
def torch_device(self) -> str:
|
||||||
if self.device.startswith('cuda'):
|
if self.device.startswith("cuda"):
|
||||||
return self.device
|
return self.device
|
||||||
else:
|
else:
|
||||||
return 'cpu'
|
return "cpu"
|
||||||
|
|
||||||
|
|
||||||
class ImageParams:
|
class ImageParams:
|
||||||
|
@ -84,7 +89,7 @@ class ImageParams:
|
||||||
negative_prompt: Optional[str],
|
negative_prompt: Optional[str],
|
||||||
cfg: float,
|
cfg: float,
|
||||||
steps: int,
|
steps: int,
|
||||||
seed: int
|
seed: int,
|
||||||
) -> None:
|
) -> None:
|
||||||
self.model = model
|
self.model = model
|
||||||
self.scheduler = scheduler
|
self.scheduler = scheduler
|
||||||
|
@ -96,20 +101,20 @@ class ImageParams:
|
||||||
|
|
||||||
def tojson(self) -> Dict[str, Optional[Param]]:
|
def tojson(self) -> Dict[str, Optional[Param]]:
|
||||||
return {
|
return {
|
||||||
'model': self.model,
|
"model": self.model,
|
||||||
'scheduler': self.scheduler.__name__,
|
"scheduler": self.scheduler.__name__,
|
||||||
'seed': self.seed,
|
"seed": self.seed,
|
||||||
'prompt': self.prompt,
|
"prompt": self.prompt,
|
||||||
'cfg': self.cfg,
|
"cfg": self.cfg,
|
||||||
'negativePrompt': self.negative_prompt,
|
"negativePrompt": self.negative_prompt,
|
||||||
'steps': self.steps,
|
"steps": self.steps,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
class StageParams:
|
class StageParams:
|
||||||
'''
|
"""
|
||||||
Parameters for a chained pipeline stage
|
Parameters for a chained pipeline stage
|
||||||
'''
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
@ -123,7 +128,7 @@ class StageParams:
|
||||||
self.outscale = outscale
|
self.outscale = outscale
|
||||||
|
|
||||||
|
|
||||||
class UpscaleParams():
|
class UpscaleParams:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
upscale_model: str,
|
upscale_model: str,
|
||||||
|
@ -131,7 +136,7 @@ class UpscaleParams():
|
||||||
denoise: float = 0.5,
|
denoise: float = 0.5,
|
||||||
faces=True,
|
faces=True,
|
||||||
face_strength: float = 0.5,
|
face_strength: float = 0.5,
|
||||||
format: Literal['onnx', 'pth'] = 'onnx',
|
format: Literal["onnx", "pth"] = "onnx",
|
||||||
half=False,
|
half=False,
|
||||||
outscale: int = 1,
|
outscale: int = 1,
|
||||||
scale: int = 4,
|
scale: int = 4,
|
||||||
|
@ -170,8 +175,8 @@ class UpscaleParams():
|
||||||
|
|
||||||
def tojson(self):
|
def tojson(self):
|
||||||
return {
|
return {
|
||||||
'model': self.upscale_model,
|
"model": self.upscale_model,
|
||||||
'scale': self.scale,
|
"scale": self.scale,
|
||||||
'outscale': self.outscale,
|
"outscale": self.outscale,
|
||||||
# TODO: add more
|
# TODO: add more
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,62 +1,61 @@
|
||||||
from . import logging
|
import gc
|
||||||
|
from functools import cmp_to_key
|
||||||
|
from glob import glob
|
||||||
|
from io import BytesIO
|
||||||
|
from logging import getLogger
|
||||||
|
from os import makedirs, path
|
||||||
|
from typing import List, Tuple
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
import yaml
|
||||||
from diffusers import (
|
from diffusers import (
|
||||||
DDIMScheduler,
|
DDIMScheduler,
|
||||||
DDPMScheduler,
|
DDPMScheduler,
|
||||||
DPMSolverMultistepScheduler,
|
DPMSolverMultistepScheduler,
|
||||||
DPMSolverSinglestepScheduler,
|
DPMSolverSinglestepScheduler,
|
||||||
EulerDiscreteScheduler,
|
|
||||||
EulerAncestralDiscreteScheduler,
|
EulerAncestralDiscreteScheduler,
|
||||||
|
EulerDiscreteScheduler,
|
||||||
HeunDiscreteScheduler,
|
HeunDiscreteScheduler,
|
||||||
|
KarrasVeScheduler,
|
||||||
KDPM2AncestralDiscreteScheduler,
|
KDPM2AncestralDiscreteScheduler,
|
||||||
KDPM2DiscreteScheduler,
|
KDPM2DiscreteScheduler,
|
||||||
KarrasVeScheduler,
|
|
||||||
LMSDiscreteScheduler,
|
LMSDiscreteScheduler,
|
||||||
PNDMScheduler,
|
PNDMScheduler,
|
||||||
)
|
)
|
||||||
from flask import Flask, jsonify, make_response, request, send_from_directory, url_for
|
from flask import Flask, jsonify, make_response, request, send_from_directory, url_for
|
||||||
from flask_cors import CORS
|
from flask_cors import CORS
|
||||||
from functools import cmp_to_key
|
|
||||||
from glob import glob
|
|
||||||
from io import BytesIO
|
|
||||||
from jsonschema import validate
|
from jsonschema import validate
|
||||||
from logging import getLogger
|
|
||||||
from PIL import Image
|
|
||||||
from onnxruntime import get_available_providers
|
from onnxruntime import get_available_providers
|
||||||
from os import makedirs, path
|
from PIL import Image
|
||||||
from typing import List, Tuple
|
|
||||||
|
|
||||||
|
|
||||||
|
from . import logging
|
||||||
from .chain import (
|
from .chain import (
|
||||||
|
ChainPipeline,
|
||||||
blend_img2img,
|
blend_img2img,
|
||||||
blend_inpaint,
|
blend_inpaint,
|
||||||
correct_gfpgan,
|
correct_gfpgan,
|
||||||
persist_disk,
|
persist_disk,
|
||||||
persist_s3,
|
persist_s3,
|
||||||
reduce_thumbnail,
|
|
||||||
reduce_crop,
|
reduce_crop,
|
||||||
|
reduce_thumbnail,
|
||||||
source_noise,
|
source_noise,
|
||||||
source_txt2img,
|
source_txt2img,
|
||||||
upscale_outpaint,
|
upscale_outpaint,
|
||||||
upscale_resrgan,
|
upscale_resrgan,
|
||||||
upscale_stable_diffusion,
|
upscale_stable_diffusion,
|
||||||
ChainPipeline,
|
|
||||||
)
|
|
||||||
from .device_pool import (
|
|
||||||
DeviceParams,
|
|
||||||
DevicePoolExecutor,
|
|
||||||
)
|
)
|
||||||
|
from .device_pool import DevicePoolExecutor
|
||||||
from .diffusion.run import (
|
from .diffusion.run import (
|
||||||
run_img2img_pipeline,
|
run_img2img_pipeline,
|
||||||
run_inpaint_pipeline,
|
run_inpaint_pipeline,
|
||||||
run_txt2img_pipeline,
|
run_txt2img_pipeline,
|
||||||
run_upscale_pipeline,
|
run_upscale_pipeline,
|
||||||
)
|
)
|
||||||
from .image import (
|
from .image import ( # mask filters; noise sources
|
||||||
# mask filters
|
|
||||||
mask_filter_gaussian_multiply,
|
mask_filter_gaussian_multiply,
|
||||||
mask_filter_gaussian_screen,
|
mask_filter_gaussian_screen,
|
||||||
mask_filter_none,
|
mask_filter_none,
|
||||||
# noise sources
|
|
||||||
noise_source_fill_edge,
|
noise_source_fill_edge,
|
||||||
noise_source_fill_mask,
|
noise_source_fill_mask,
|
||||||
noise_source_gaussian,
|
noise_source_gaussian,
|
||||||
|
@ -64,35 +63,20 @@ from .image import (
|
||||||
noise_source_normal,
|
noise_source_normal,
|
||||||
noise_source_uniform,
|
noise_source_uniform,
|
||||||
)
|
)
|
||||||
from .output import (
|
from .output import json_params, make_output_name
|
||||||
json_params,
|
from .params import Border, DeviceParams, ImageParams, Size, StageParams, UpscaleParams
|
||||||
make_output_name,
|
|
||||||
)
|
|
||||||
from .params import (
|
|
||||||
Border,
|
|
||||||
DeviceParams,
|
|
||||||
ImageParams,
|
|
||||||
Size,
|
|
||||||
StageParams,
|
|
||||||
UpscaleParams,
|
|
||||||
)
|
|
||||||
from .utils import (
|
from .utils import (
|
||||||
|
ServerContext,
|
||||||
base_join,
|
base_join,
|
||||||
is_debug,
|
|
||||||
get_and_clamp_float,
|
get_and_clamp_float,
|
||||||
get_and_clamp_int,
|
get_and_clamp_int,
|
||||||
get_from_list,
|
get_from_list,
|
||||||
get_from_map,
|
get_from_map,
|
||||||
get_not_empty,
|
get_not_empty,
|
||||||
get_size,
|
get_size,
|
||||||
ServerContext,
|
is_debug,
|
||||||
)
|
)
|
||||||
|
|
||||||
import gc
|
|
||||||
import numpy as np
|
|
||||||
import torch
|
|
||||||
import yaml
|
|
||||||
|
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
# config caching
|
# config caching
|
||||||
|
@ -100,53 +84,53 @@ config_params = {}
|
||||||
|
|
||||||
# pipeline params
|
# pipeline params
|
||||||
platform_providers = {
|
platform_providers = {
|
||||||
'amd': 'DmlExecutionProvider',
|
"amd": "DmlExecutionProvider",
|
||||||
'cpu': 'CPUExecutionProvider',
|
"cpu": "CPUExecutionProvider",
|
||||||
'cuda': 'CUDAExecutionProvider',
|
"cuda": "CUDAExecutionProvider",
|
||||||
'directml': 'DmlExecutionProvider',
|
"directml": "DmlExecutionProvider",
|
||||||
'nvidia': 'CUDAExecutionProvider',
|
"nvidia": "CUDAExecutionProvider",
|
||||||
'rocm': 'ROCMExecutionProvider',
|
"rocm": "ROCMExecutionProvider",
|
||||||
}
|
}
|
||||||
pipeline_schedulers = {
|
pipeline_schedulers = {
|
||||||
'ddim': DDIMScheduler,
|
"ddim": DDIMScheduler,
|
||||||
'ddpm': DDPMScheduler,
|
"ddpm": DDPMScheduler,
|
||||||
'dpm-multi': DPMSolverMultistepScheduler,
|
"dpm-multi": DPMSolverMultistepScheduler,
|
||||||
'dpm-single': DPMSolverSinglestepScheduler,
|
"dpm-single": DPMSolverSinglestepScheduler,
|
||||||
'euler': EulerDiscreteScheduler,
|
"euler": EulerDiscreteScheduler,
|
||||||
'euler-a': EulerAncestralDiscreteScheduler,
|
"euler-a": EulerAncestralDiscreteScheduler,
|
||||||
'heun': HeunDiscreteScheduler,
|
"heun": HeunDiscreteScheduler,
|
||||||
'k-dpm-2-a': KDPM2AncestralDiscreteScheduler,
|
"k-dpm-2-a": KDPM2AncestralDiscreteScheduler,
|
||||||
'k-dpm-2': KDPM2DiscreteScheduler,
|
"k-dpm-2": KDPM2DiscreteScheduler,
|
||||||
'karras-ve': KarrasVeScheduler,
|
"karras-ve": KarrasVeScheduler,
|
||||||
'lms-discrete': LMSDiscreteScheduler,
|
"lms-discrete": LMSDiscreteScheduler,
|
||||||
'pndm': PNDMScheduler,
|
"pndm": PNDMScheduler,
|
||||||
}
|
}
|
||||||
noise_sources = {
|
noise_sources = {
|
||||||
'fill-edge': noise_source_fill_edge,
|
"fill-edge": noise_source_fill_edge,
|
||||||
'fill-mask': noise_source_fill_mask,
|
"fill-mask": noise_source_fill_mask,
|
||||||
'gaussian': noise_source_gaussian,
|
"gaussian": noise_source_gaussian,
|
||||||
'histogram': noise_source_histogram,
|
"histogram": noise_source_histogram,
|
||||||
'normal': noise_source_normal,
|
"normal": noise_source_normal,
|
||||||
'uniform': noise_source_uniform,
|
"uniform": noise_source_uniform,
|
||||||
}
|
}
|
||||||
mask_filters = {
|
mask_filters = {
|
||||||
'none': mask_filter_none,
|
"none": mask_filter_none,
|
||||||
'gaussian-multiply': mask_filter_gaussian_multiply,
|
"gaussian-multiply": mask_filter_gaussian_multiply,
|
||||||
'gaussian-screen': mask_filter_gaussian_screen,
|
"gaussian-screen": mask_filter_gaussian_screen,
|
||||||
}
|
}
|
||||||
chain_stages = {
|
chain_stages = {
|
||||||
'blend-img2img': blend_img2img,
|
"blend-img2img": blend_img2img,
|
||||||
'blend-inpaint': blend_inpaint,
|
"blend-inpaint": blend_inpaint,
|
||||||
'correct-gfpgan': correct_gfpgan,
|
"correct-gfpgan": correct_gfpgan,
|
||||||
'persist-disk': persist_disk,
|
"persist-disk": persist_disk,
|
||||||
'persist-s3': persist_s3,
|
"persist-s3": persist_s3,
|
||||||
'reduce-crop': reduce_crop,
|
"reduce-crop": reduce_crop,
|
||||||
'reduce-thumbnail': reduce_thumbnail,
|
"reduce-thumbnail": reduce_thumbnail,
|
||||||
'source-noise': source_noise,
|
"source-noise": source_noise,
|
||||||
'source-txt2img': source_txt2img,
|
"source-txt2img": source_txt2img,
|
||||||
'upscale-outpaint': upscale_outpaint,
|
"upscale-outpaint": upscale_outpaint,
|
||||||
'upscale-resrgan': upscale_resrgan,
|
"upscale-resrgan": upscale_resrgan,
|
||||||
'upscale-stable-diffusion': upscale_stable_diffusion,
|
"upscale-stable-diffusion": upscale_stable_diffusion,
|
||||||
}
|
}
|
||||||
|
|
||||||
# Available ORT providers
|
# Available ORT providers
|
||||||
|
@ -158,7 +142,7 @@ correction_models = []
|
||||||
upscaling_models = []
|
upscaling_models = []
|
||||||
|
|
||||||
|
|
||||||
def get_config_value(key: str, subkey: str = 'default'):
|
def get_config_value(key: str, subkey: str = "default"):
|
||||||
return config_params.get(key).get(subkey)
|
return config_params.get(key).get(subkey)
|
||||||
|
|
||||||
|
|
||||||
|
@ -174,7 +158,7 @@ def pipeline_from_request() -> Tuple[DeviceParams, ImageParams, Size]:
|
||||||
user = request.remote_addr
|
user = request.remote_addr
|
||||||
|
|
||||||
# platform stuff
|
# platform stuff
|
||||||
device_name = request.args.get('platform', available_platforms[0].device)
|
device_name = request.args.get("platform", available_platforms[0].device)
|
||||||
device = None
|
device = None
|
||||||
|
|
||||||
for platform in available_platforms:
|
for platform in available_platforms:
|
||||||
|
@ -182,78 +166,101 @@ def pipeline_from_request() -> Tuple[DeviceParams, ImageParams, Size]:
|
||||||
device = available_platforms[0]
|
device = available_platforms[0]
|
||||||
|
|
||||||
if device is None:
|
if device is None:
|
||||||
raise Exception('unknown device')
|
raise Exception("unknown device")
|
||||||
|
|
||||||
# pipeline stuff
|
# pipeline stuff
|
||||||
model = get_not_empty(request.args, 'model', get_config_value('model'))
|
model = get_not_empty(request.args, "model", get_config_value("model"))
|
||||||
model_path = get_model_path(model)
|
model_path = get_model_path(model)
|
||||||
scheduler = get_from_map(request.args, 'scheduler',
|
scheduler = get_from_map(
|
||||||
pipeline_schedulers, get_config_value('scheduler'))
|
request.args, "scheduler", pipeline_schedulers, get_config_value("scheduler")
|
||||||
|
)
|
||||||
|
|
||||||
# image params
|
# image params
|
||||||
prompt = get_not_empty(request.args,
|
prompt = get_not_empty(request.args, "prompt", get_config_value("prompt"))
|
||||||
'prompt', get_config_value('prompt'))
|
negative_prompt = request.args.get("negativePrompt", None)
|
||||||
negative_prompt = request.args.get('negativePrompt', None)
|
|
||||||
|
|
||||||
if negative_prompt is not None and negative_prompt.strip() == '':
|
if negative_prompt is not None and negative_prompt.strip() == "":
|
||||||
negative_prompt = None
|
negative_prompt = None
|
||||||
|
|
||||||
cfg = get_and_clamp_float(
|
cfg = get_and_clamp_float(
|
||||||
request.args, 'cfg',
|
request.args,
|
||||||
get_config_value('cfg'),
|
"cfg",
|
||||||
get_config_value('cfg', 'max'),
|
get_config_value("cfg"),
|
||||||
get_config_value('cfg', 'min'))
|
get_config_value("cfg", "max"),
|
||||||
|
get_config_value("cfg", "min"),
|
||||||
|
)
|
||||||
steps = get_and_clamp_int(
|
steps = get_and_clamp_int(
|
||||||
request.args, 'steps',
|
request.args,
|
||||||
get_config_value('steps'),
|
"steps",
|
||||||
get_config_value('steps', 'max'),
|
get_config_value("steps"),
|
||||||
get_config_value('steps', 'min'))
|
get_config_value("steps", "max"),
|
||||||
|
get_config_value("steps", "min"),
|
||||||
|
)
|
||||||
height = get_and_clamp_int(
|
height = get_and_clamp_int(
|
||||||
request.args, 'height',
|
request.args,
|
||||||
get_config_value('height'),
|
"height",
|
||||||
get_config_value('height', 'max'),
|
get_config_value("height"),
|
||||||
get_config_value('height', 'min'))
|
get_config_value("height", "max"),
|
||||||
|
get_config_value("height", "min"),
|
||||||
|
)
|
||||||
width = get_and_clamp_int(
|
width = get_and_clamp_int(
|
||||||
request.args, 'width',
|
request.args,
|
||||||
get_config_value('width'),
|
"width",
|
||||||
get_config_value('width', 'max'),
|
get_config_value("width"),
|
||||||
get_config_value('width', 'min'))
|
get_config_value("width", "max"),
|
||||||
|
get_config_value("width", "min"),
|
||||||
|
)
|
||||||
|
|
||||||
seed = int(request.args.get('seed', -1))
|
seed = int(request.args.get("seed", -1))
|
||||||
if seed == -1:
|
if seed == -1:
|
||||||
seed = np.random.randint(np.iinfo(np.int32).max)
|
seed = np.random.randint(np.iinfo(np.int32).max)
|
||||||
|
|
||||||
logger.info("request from %s: %s rounds of %s using %s on %s, %sx%s, %s, %s - %s",
|
logger.info(
|
||||||
user, steps, scheduler.__name__, model_path, device.provider, width, height, cfg, seed, prompt)
|
"request from %s: %s rounds of %s using %s on %s, %sx%s, %s, %s - %s",
|
||||||
|
user,
|
||||||
|
steps,
|
||||||
|
scheduler.__name__,
|
||||||
|
model_path,
|
||||||
|
device.provider,
|
||||||
|
width,
|
||||||
|
height,
|
||||||
|
cfg,
|
||||||
|
seed,
|
||||||
|
prompt,
|
||||||
|
)
|
||||||
|
|
||||||
params = ImageParams(model_path, scheduler, prompt,
|
params = ImageParams(
|
||||||
negative_prompt, cfg, steps, seed)
|
model_path, scheduler, prompt, negative_prompt, cfg, steps, seed
|
||||||
|
)
|
||||||
size = Size(width, height)
|
size = Size(width, height)
|
||||||
return (device, params, size)
|
return (device, params, size)
|
||||||
|
|
||||||
|
|
||||||
def border_from_request() -> Border:
|
def border_from_request() -> Border:
|
||||||
left = get_and_clamp_int(request.args, 'left', 0,
|
left = get_and_clamp_int(
|
||||||
get_config_value('width', 'max'), 0)
|
request.args, "left", 0, get_config_value("width", "max"), 0
|
||||||
right = get_and_clamp_int(request.args, 'right',
|
)
|
||||||
0, get_config_value('width', 'max'), 0)
|
right = get_and_clamp_int(
|
||||||
top = get_and_clamp_int(request.args, 'top', 0,
|
request.args, "right", 0, get_config_value("width", "max"), 0
|
||||||
get_config_value('height', 'max'), 0)
|
)
|
||||||
|
top = get_and_clamp_int(
|
||||||
|
request.args, "top", 0, get_config_value("height", "max"), 0
|
||||||
|
)
|
||||||
bottom = get_and_clamp_int(
|
bottom = get_and_clamp_int(
|
||||||
request.args, 'bottom', 0, get_config_value('height', 'max'), 0)
|
request.args, "bottom", 0, get_config_value("height", "max"), 0
|
||||||
|
)
|
||||||
|
|
||||||
return Border(left, right, top, bottom)
|
return Border(left, right, top, bottom)
|
||||||
|
|
||||||
|
|
||||||
def upscale_from_request() -> UpscaleParams:
|
def upscale_from_request() -> UpscaleParams:
|
||||||
denoise = get_and_clamp_float(request.args, 'denoise', 0.5, 1.0, 0.0)
|
denoise = get_and_clamp_float(request.args, "denoise", 0.5, 1.0, 0.0)
|
||||||
scale = get_and_clamp_int(request.args, 'scale', 1, 4, 1)
|
scale = get_and_clamp_int(request.args, "scale", 1, 4, 1)
|
||||||
outscale = get_and_clamp_int(request.args, 'outscale', 1, 4, 1)
|
outscale = get_and_clamp_int(request.args, "outscale", 1, 4, 1)
|
||||||
upscaling = get_from_list(request.args, 'upscaling', upscaling_models)
|
upscaling = get_from_list(request.args, "upscaling", upscaling_models)
|
||||||
correction = get_from_list(request.args, 'correction', correction_models)
|
correction = get_from_list(request.args, "correction", correction_models)
|
||||||
faces = get_not_empty(request.args, 'faces', 'false') == 'true'
|
faces = get_not_empty(request.args, "faces", "false") == "true"
|
||||||
face_strength = get_and_clamp_float(
|
face_strength = get_and_clamp_float(request.args, "faceStrength", 0.5, 1.0, 0.0)
|
||||||
request.args, 'faceStrength', 0.5, 1.0, 0.0)
|
|
||||||
|
|
||||||
return UpscaleParams(
|
return UpscaleParams(
|
||||||
upscaling,
|
upscaling,
|
||||||
|
@ -261,7 +268,7 @@ def upscale_from_request() -> UpscaleParams:
|
||||||
denoise=denoise,
|
denoise=denoise,
|
||||||
faces=faces,
|
faces=faces,
|
||||||
face_strength=face_strength,
|
face_strength=face_strength,
|
||||||
format='onnx',
|
format="onnx",
|
||||||
outscale=outscale,
|
outscale=outscale,
|
||||||
scale=scale,
|
scale=scale,
|
||||||
)
|
)
|
||||||
|
@ -269,7 +276,7 @@ def upscale_from_request() -> UpscaleParams:
|
||||||
|
|
||||||
def check_paths(context: ServerContext):
|
def check_paths(context: ServerContext):
|
||||||
if not path.exists(context.model_path):
|
if not path.exists(context.model_path):
|
||||||
raise RuntimeError('model path must exist')
|
raise RuntimeError("model path must exist")
|
||||||
|
|
||||||
if not path.exists(context.output_path):
|
if not path.exists(context.output_path):
|
||||||
makedirs(context.output_path)
|
makedirs(context.output_path)
|
||||||
|
@ -286,35 +293,41 @@ def load_models(context: ServerContext):
|
||||||
global correction_models
|
global correction_models
|
||||||
global upscaling_models
|
global upscaling_models
|
||||||
|
|
||||||
diffusion_models = [get_model_name(f) for f in glob(
|
diffusion_models = [
|
||||||
path.join(context.model_path, 'diffusion-*'))]
|
get_model_name(f) for f in glob(path.join(context.model_path, "diffusion-*"))
|
||||||
diffusion_models.extend([
|
]
|
||||||
get_model_name(f) for f in glob(path.join(context.model_path, 'stable-diffusion-*'))])
|
diffusion_models.extend(
|
||||||
|
[
|
||||||
|
get_model_name(f)
|
||||||
|
for f in glob(path.join(context.model_path, "stable-diffusion-*"))
|
||||||
|
]
|
||||||
|
)
|
||||||
diffusion_models = list(set(diffusion_models))
|
diffusion_models = list(set(diffusion_models))
|
||||||
diffusion_models.sort()
|
diffusion_models.sort()
|
||||||
|
|
||||||
correction_models = [
|
correction_models = [
|
||||||
get_model_name(f) for f in glob(path.join(context.model_path, 'correction-*'))]
|
get_model_name(f) for f in glob(path.join(context.model_path, "correction-*"))
|
||||||
|
]
|
||||||
correction_models = list(set(correction_models))
|
correction_models = list(set(correction_models))
|
||||||
correction_models.sort()
|
correction_models.sort()
|
||||||
|
|
||||||
upscaling_models = [
|
upscaling_models = [
|
||||||
get_model_name(f) for f in glob(path.join(context.model_path, 'upscaling-*'))]
|
get_model_name(f) for f in glob(path.join(context.model_path, "upscaling-*"))
|
||||||
|
]
|
||||||
upscaling_models = list(set(upscaling_models))
|
upscaling_models = list(set(upscaling_models))
|
||||||
upscaling_models.sort()
|
upscaling_models.sort()
|
||||||
|
|
||||||
|
|
||||||
def load_params(context: ServerContext):
|
def load_params(context: ServerContext):
|
||||||
global config_params
|
global config_params
|
||||||
params_file = path.join(context.params_path, 'params.json')
|
params_file = path.join(context.params_path, "params.json")
|
||||||
with open(params_file, 'r') as f:
|
with open(params_file, "r") as f:
|
||||||
config_params = yaml.safe_load(f)
|
config_params = yaml.safe_load(f)
|
||||||
|
|
||||||
if 'platform' in config_params and context.default_platform is not None:
|
if "platform" in config_params and context.default_platform is not None:
|
||||||
logger.info('overriding default platform to %s',
|
logger.info("overriding default platform to %s", context.default_platform)
|
||||||
context.default_platform)
|
config_platform = config_params.get("platform")
|
||||||
config_platform = config_params.get('platform')
|
config_platform["default"] = context.default_platform
|
||||||
config_platform['default'] = context.default_platform
|
|
||||||
|
|
||||||
|
|
||||||
def load_platforms():
|
def load_platforms():
|
||||||
|
@ -323,30 +336,42 @@ def load_platforms():
|
||||||
providers = get_available_providers()
|
providers = get_available_providers()
|
||||||
|
|
||||||
for potential in platform_providers:
|
for potential in platform_providers:
|
||||||
if platform_providers[potential] in providers and potential not in context.block_platforms:
|
if (
|
||||||
if potential == 'cuda':
|
platform_providers[potential] in providers
|
||||||
|
and potential not in context.block_platforms
|
||||||
|
):
|
||||||
|
if potential == "cuda":
|
||||||
for i in range(torch.cuda.device_count()):
|
for i in range(torch.cuda.device_count()):
|
||||||
available_platforms.append(DeviceParams(potential, platform_providers[potential], {
|
available_platforms.append(
|
||||||
'device_id': i,
|
DeviceParams(
|
||||||
}))
|
potential,
|
||||||
|
platform_providers[potential],
|
||||||
|
{
|
||||||
|
"device_id": i,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
available_platforms.append(DeviceParams(
|
available_platforms.append(
|
||||||
potential, platform_providers[potential]))
|
DeviceParams(potential, platform_providers[potential])
|
||||||
|
)
|
||||||
|
|
||||||
# make sure CPU is last on the list
|
# make sure CPU is last on the list
|
||||||
def cpu_last(a: DeviceParams, b: DeviceParams):
|
def cpu_last(a: DeviceParams, b: DeviceParams):
|
||||||
if a.device == 'cpu' and b.device == 'cpu':
|
if a.device == "cpu" and b.device == "cpu":
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
if a.device == 'cpu':
|
if a.device == "cpu":
|
||||||
return 1
|
return 1
|
||||||
|
|
||||||
return -1
|
return -1
|
||||||
|
|
||||||
available_platforms = sorted(available_platforms, key=cmp_to_key(cpu_last))
|
available_platforms = sorted(available_platforms, key=cmp_to_key(cpu_last))
|
||||||
|
|
||||||
logger.info('available acceleration platforms: %s',
|
logger.info(
|
||||||
', '.join([str(p) for p in available_platforms]))
|
"available acceleration platforms: %s",
|
||||||
|
", ".join([str(p) for p in available_platforms]),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
context = ServerContext.from_environ()
|
context = ServerContext.from_environ()
|
||||||
|
@ -365,16 +390,22 @@ if is_debug():
|
||||||
|
|
||||||
|
|
||||||
def ready_reply(ready: bool, progress: int = 0):
|
def ready_reply(ready: bool, progress: int = 0):
|
||||||
return jsonify({
|
return jsonify(
|
||||||
'progress': progress,
|
{
|
||||||
'ready': ready,
|
"progress": progress,
|
||||||
})
|
"ready": ready,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def error_reply(err: str):
|
def error_reply(err: str):
|
||||||
response = make_response(jsonify({
|
response = make_response(
|
||||||
'error': err,
|
jsonify(
|
||||||
}))
|
{
|
||||||
|
"error": err,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
)
|
||||||
response.status_code = 400
|
response.status_code = 400
|
||||||
return response
|
return response
|
||||||
|
|
||||||
|
@ -383,151 +414,154 @@ def get_model_path(model: str):
|
||||||
return base_join(context.model_path, model)
|
return base_join(context.model_path, model)
|
||||||
|
|
||||||
|
|
||||||
def serve_bundle_file(filename='index.html'):
|
def serve_bundle_file(filename="index.html"):
|
||||||
return send_from_directory(path.join('..', context.bundle_path), filename)
|
return send_from_directory(path.join("..", context.bundle_path), filename)
|
||||||
|
|
||||||
|
|
||||||
# routes
|
# routes
|
||||||
|
|
||||||
|
|
||||||
@app.route('/')
|
@app.route("/")
|
||||||
def index():
|
def index():
|
||||||
return serve_bundle_file()
|
return serve_bundle_file()
|
||||||
|
|
||||||
|
|
||||||
@app.route('/<path:filename>')
|
@app.route("/<path:filename>")
|
||||||
def index_path(filename):
|
def index_path(filename):
|
||||||
return serve_bundle_file(filename)
|
return serve_bundle_file(filename)
|
||||||
|
|
||||||
|
|
||||||
@app.route('/api')
|
@app.route("/api")
|
||||||
def introspect():
|
def introspect():
|
||||||
return {
|
return {
|
||||||
'name': 'onnx-web',
|
"name": "onnx-web",
|
||||||
'routes': [{
|
"routes": [
|
||||||
'path': url_from_rule(rule),
|
{"path": url_from_rule(rule), "methods": list(rule.methods).sort()}
|
||||||
'methods': list(rule.methods).sort()
|
for rule in app.url_map.iter_rules()
|
||||||
} for rule in app.url_map.iter_rules()]
|
],
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@app.route('/api/settings/masks')
|
@app.route("/api/settings/masks")
|
||||||
def list_mask_filters():
|
def list_mask_filters():
|
||||||
return jsonify(list(mask_filters.keys()))
|
return jsonify(list(mask_filters.keys()))
|
||||||
|
|
||||||
|
|
||||||
@app.route('/api/settings/models')
|
@app.route("/api/settings/models")
|
||||||
def list_models():
|
def list_models():
|
||||||
return jsonify({
|
return jsonify(
|
||||||
'diffusion': diffusion_models,
|
{
|
||||||
'correction': correction_models,
|
"diffusion": diffusion_models,
|
||||||
'upscaling': upscaling_models,
|
"correction": correction_models,
|
||||||
})
|
"upscaling": upscaling_models,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@app.route('/api/settings/noises')
|
@app.route("/api/settings/noises")
|
||||||
def list_noise_sources():
|
def list_noise_sources():
|
||||||
return jsonify(list(noise_sources.keys()))
|
return jsonify(list(noise_sources.keys()))
|
||||||
|
|
||||||
|
|
||||||
@app.route('/api/settings/params')
|
@app.route("/api/settings/params")
|
||||||
def list_params():
|
def list_params():
|
||||||
return jsonify(config_params)
|
return jsonify(config_params)
|
||||||
|
|
||||||
|
|
||||||
@app.route('/api/settings/platforms')
|
@app.route("/api/settings/platforms")
|
||||||
def list_platforms():
|
def list_platforms():
|
||||||
return jsonify([p.device for p in available_platforms])
|
return jsonify([p.device for p in available_platforms])
|
||||||
|
|
||||||
|
|
||||||
@app.route('/api/settings/schedulers')
|
@app.route("/api/settings/schedulers")
|
||||||
def list_schedulers():
|
def list_schedulers():
|
||||||
return jsonify(list(pipeline_schedulers.keys()))
|
return jsonify(list(pipeline_schedulers.keys()))
|
||||||
|
|
||||||
|
|
||||||
@app.route('/api/img2img', methods=['POST'])
|
@app.route("/api/img2img", methods=["POST"])
|
||||||
def img2img():
|
def img2img():
|
||||||
if 'source' not in request.files:
|
if "source" not in request.files:
|
||||||
return error_reply('source image is required')
|
return error_reply("source image is required")
|
||||||
|
|
||||||
source_file = request.files.get('source')
|
source_file = request.files.get("source")
|
||||||
source_image = Image.open(BytesIO(source_file.read())).convert('RGB')
|
source_image = Image.open(BytesIO(source_file.read())).convert("RGB")
|
||||||
|
|
||||||
device, params, size = pipeline_from_request()
|
device, params, size = pipeline_from_request()
|
||||||
upscale = upscale_from_request()
|
upscale = upscale_from_request()
|
||||||
|
|
||||||
strength = get_and_clamp_float(
|
strength = get_and_clamp_float(
|
||||||
request.args,
|
request.args,
|
||||||
'strength',
|
"strength",
|
||||||
get_config_value('strength'),
|
get_config_value("strength"),
|
||||||
get_config_value('strength', 'max'),
|
get_config_value("strength", "max"),
|
||||||
get_config_value('strength', 'min'))
|
get_config_value("strength", "min"),
|
||||||
|
)
|
||||||
|
|
||||||
output = make_output_name(
|
output = make_output_name(context, "img2img", params, size, extras=(strength,))
|
||||||
context,
|
|
||||||
'img2img',
|
|
||||||
params,
|
|
||||||
size,
|
|
||||||
extras=(strength,))
|
|
||||||
logger.info("img2img job queued for: %s", output)
|
logger.info("img2img job queued for: %s", output)
|
||||||
|
|
||||||
source_image.thumbnail((size.width, size.height))
|
source_image.thumbnail((size.width, size.height))
|
||||||
executor.submit(output, run_img2img_pipeline,
|
executor.submit(
|
||||||
context, params, output, upscale, source_image, strength)
|
output,
|
||||||
|
run_img2img_pipeline,
|
||||||
|
context,
|
||||||
|
params,
|
||||||
|
output,
|
||||||
|
upscale,
|
||||||
|
source_image,
|
||||||
|
strength,
|
||||||
|
)
|
||||||
|
|
||||||
return jsonify(json_params(output, params, size, upscale=upscale))
|
return jsonify(json_params(output, params, size, upscale=upscale))
|
||||||
|
|
||||||
|
|
||||||
@app.route('/api/txt2img', methods=['POST'])
|
@app.route("/api/txt2img", methods=["POST"])
|
||||||
def txt2img():
|
def txt2img():
|
||||||
device, params, size = pipeline_from_request()
|
device, params, size = pipeline_from_request()
|
||||||
upscale = upscale_from_request()
|
upscale = upscale_from_request()
|
||||||
|
|
||||||
output = make_output_name(
|
output = make_output_name(context, "txt2img", params, size)
|
||||||
context,
|
|
||||||
'txt2img',
|
|
||||||
params,
|
|
||||||
size)
|
|
||||||
logger.info("txt2img job queued for: %s", output)
|
logger.info("txt2img job queued for: %s", output)
|
||||||
|
|
||||||
executor.submit(
|
executor.submit(
|
||||||
output, run_txt2img_pipeline, context, params, size, output, upscale)
|
output, run_txt2img_pipeline, context, params, size, output, upscale
|
||||||
|
)
|
||||||
|
|
||||||
return jsonify(json_params(output, params, size, upscale=upscale))
|
return jsonify(json_params(output, params, size, upscale=upscale))
|
||||||
|
|
||||||
|
|
||||||
@app.route('/api/inpaint', methods=['POST'])
|
@app.route("/api/inpaint", methods=["POST"])
|
||||||
def inpaint():
|
def inpaint():
|
||||||
if 'source' not in request.files:
|
if "source" not in request.files:
|
||||||
return error_reply('source image is required')
|
return error_reply("source image is required")
|
||||||
|
|
||||||
if 'mask' not in request.files:
|
if "mask" not in request.files:
|
||||||
return error_reply('mask image is required')
|
return error_reply("mask image is required")
|
||||||
|
|
||||||
source_file = request.files.get('source')
|
source_file = request.files.get("source")
|
||||||
source_image = Image.open(BytesIO(source_file.read())).convert('RGB')
|
source_image = Image.open(BytesIO(source_file.read())).convert("RGB")
|
||||||
|
|
||||||
mask_file = request.files.get('mask')
|
mask_file = request.files.get("mask")
|
||||||
mask_image = Image.open(BytesIO(mask_file.read())).convert('RGB')
|
mask_image = Image.open(BytesIO(mask_file.read())).convert("RGB")
|
||||||
|
|
||||||
device, params, size = pipeline_from_request()
|
device, params, size = pipeline_from_request()
|
||||||
expand = border_from_request()
|
expand = border_from_request()
|
||||||
upscale = upscale_from_request()
|
upscale = upscale_from_request()
|
||||||
|
|
||||||
fill_color = get_not_empty(request.args, 'fillColor', 'white')
|
fill_color = get_not_empty(request.args, "fillColor", "white")
|
||||||
mask_filter = get_from_map(request.args, 'filter', mask_filters, 'none')
|
mask_filter = get_from_map(request.args, "filter", mask_filters, "none")
|
||||||
noise_source = get_from_map(
|
noise_source = get_from_map(request.args, "noise", noise_sources, "histogram")
|
||||||
request.args, 'noise', noise_sources, 'histogram')
|
|
||||||
strength = get_and_clamp_float(
|
strength = get_and_clamp_float(
|
||||||
request.args,
|
request.args,
|
||||||
'strength',
|
"strength",
|
||||||
get_config_value('strength'),
|
get_config_value("strength"),
|
||||||
get_config_value('strength', 'max'),
|
get_config_value("strength", "max"),
|
||||||
get_config_value('strength', 'min'))
|
get_config_value("strength", "min"),
|
||||||
|
)
|
||||||
|
|
||||||
output = make_output_name(
|
output = make_output_name(
|
||||||
context,
|
context,
|
||||||
'inpaint',
|
"inpaint",
|
||||||
params,
|
params,
|
||||||
size,
|
size,
|
||||||
extras=(
|
extras=(
|
||||||
|
@ -539,7 +573,7 @@ def inpaint():
|
||||||
noise_source.__name__,
|
noise_source.__name__,
|
||||||
strength,
|
strength,
|
||||||
fill_color,
|
fill_color,
|
||||||
)
|
),
|
||||||
)
|
)
|
||||||
logger.info("inpaint job queued for: %s", output)
|
logger.info("inpaint job queued for: %s", output)
|
||||||
|
|
||||||
|
@ -559,123 +593,131 @@ def inpaint():
|
||||||
noise_source,
|
noise_source,
|
||||||
mask_filter,
|
mask_filter,
|
||||||
strength,
|
strength,
|
||||||
fill_color)
|
fill_color,
|
||||||
|
)
|
||||||
|
|
||||||
return jsonify(json_params(output, params, size, upscale=upscale, border=expand))
|
return jsonify(json_params(output, params, size, upscale=upscale, border=expand))
|
||||||
|
|
||||||
|
|
||||||
@app.route('/api/upscale', methods=['POST'])
|
@app.route("/api/upscale", methods=["POST"])
|
||||||
def upscale():
|
def upscale():
|
||||||
if 'source' not in request.files:
|
if "source" not in request.files:
|
||||||
return error_reply('source image is required')
|
return error_reply("source image is required")
|
||||||
|
|
||||||
source_file = request.files.get('source')
|
source_file = request.files.get("source")
|
||||||
source_image = Image.open(BytesIO(source_file.read())).convert('RGB')
|
source_image = Image.open(BytesIO(source_file.read())).convert("RGB")
|
||||||
|
|
||||||
device, params, size = pipeline_from_request()
|
device, params, size = pipeline_from_request()
|
||||||
upscale = upscale_from_request()
|
upscale = upscale_from_request()
|
||||||
|
|
||||||
output = make_output_name(
|
output = make_output_name(context, "upscale", params, size)
|
||||||
context,
|
|
||||||
'upscale',
|
|
||||||
params,
|
|
||||||
size)
|
|
||||||
logger.info("upscale job queued for: %s", output)
|
logger.info("upscale job queued for: %s", output)
|
||||||
|
|
||||||
source_image.thumbnail((size.width, size.height))
|
source_image.thumbnail((size.width, size.height))
|
||||||
executor.submit(output, run_upscale_pipeline,
|
executor.submit(
|
||||||
context, params, size, output, upscale, source_image)
|
output,
|
||||||
|
run_upscale_pipeline,
|
||||||
|
context,
|
||||||
|
params,
|
||||||
|
size,
|
||||||
|
output,
|
||||||
|
upscale,
|
||||||
|
source_image,
|
||||||
|
)
|
||||||
|
|
||||||
return jsonify(json_params(output, params, size, upscale=upscale))
|
return jsonify(json_params(output, params, size, upscale=upscale))
|
||||||
|
|
||||||
|
|
||||||
@app.route('/api/chain', methods=['POST'])
|
@app.route("/api/chain", methods=["POST"])
|
||||||
def chain():
|
def chain():
|
||||||
logger.debug('chain pipeline request: %s, %s',
|
logger.debug(
|
||||||
request.form.keys(), request.files.keys())
|
"chain pipeline request: %s, %s", request.form.keys(), request.files.keys()
|
||||||
body = request.form.get('chain') or request.files.get('chain')
|
)
|
||||||
|
body = request.form.get("chain") or request.files.get("chain")
|
||||||
if body is None:
|
if body is None:
|
||||||
return error_reply('chain pipeline must have a body')
|
return error_reply("chain pipeline must have a body")
|
||||||
|
|
||||||
data = yaml.safe_load(body)
|
data = yaml.safe_load(body)
|
||||||
with open('./schema.yaml', 'r') as f:
|
with open("./schema.yaml", "r") as f:
|
||||||
schema = yaml.safe_load(f.read())
|
schema = yaml.safe_load(f.read())
|
||||||
|
|
||||||
logger.info('validating chain request: %s against %s', data, schema)
|
logger.info("validating chain request: %s against %s", data, schema)
|
||||||
validate(data, schema)
|
validate(data, schema)
|
||||||
|
|
||||||
# get defaults from the regular parameters
|
# get defaults from the regular parameters
|
||||||
device, params, size = pipeline_from_request()
|
device, params, size = pipeline_from_request()
|
||||||
output = make_output_name(
|
output = make_output_name(context, "chain", params, size)
|
||||||
context,
|
|
||||||
'chain',
|
|
||||||
params,
|
|
||||||
size)
|
|
||||||
|
|
||||||
pipeline = ChainPipeline()
|
pipeline = ChainPipeline()
|
||||||
for stage_data in data.get('stages', []):
|
for stage_data in data.get("stages", []):
|
||||||
callback = chain_stages[stage_data.get('type')]
|
callback = chain_stages[stage_data.get("type")]
|
||||||
kwargs = stage_data.get('params', {})
|
kwargs = stage_data.get("params", {})
|
||||||
logger.info('request stage: %s, %s', callback.__name__, kwargs)
|
logger.info("request stage: %s, %s", callback.__name__, kwargs)
|
||||||
|
|
||||||
stage = StageParams(
|
stage = StageParams(
|
||||||
stage_data.get('name', callback.__name__),
|
stage_data.get("name", callback.__name__),
|
||||||
tile_size=get_size(kwargs.get('tile_size')),
|
tile_size=get_size(kwargs.get("tile_size")),
|
||||||
outscale=get_and_clamp_int(kwargs, 'outscale', 1, 4),
|
outscale=get_and_clamp_int(kwargs, "outscale", 1, 4),
|
||||||
)
|
)
|
||||||
|
|
||||||
if 'border' in kwargs:
|
if "border" in kwargs:
|
||||||
border = Border.even(int(kwargs.get('border')))
|
border = Border.even(int(kwargs.get("border")))
|
||||||
kwargs['border'] = border
|
kwargs["border"] = border
|
||||||
|
|
||||||
if 'upscale' in kwargs:
|
if "upscale" in kwargs:
|
||||||
upscale = UpscaleParams(kwargs.get('upscale'))
|
upscale = UpscaleParams(kwargs.get("upscale"))
|
||||||
kwargs['upscale'] = upscale
|
kwargs["upscale"] = upscale
|
||||||
|
|
||||||
stage_source_name = 'source:%s' % (stage.name)
|
stage_source_name = "source:%s" % (stage.name)
|
||||||
stage_mask_name = 'mask:%s' % (stage.name)
|
stage_mask_name = "mask:%s" % (stage.name)
|
||||||
|
|
||||||
if stage_source_name in request.files:
|
if stage_source_name in request.files:
|
||||||
logger.debug('loading source image %s for pipeline stage %s',
|
logger.debug(
|
||||||
stage_source_name, stage.name)
|
"loading source image %s for pipeline stage %s",
|
||||||
|
stage_source_name,
|
||||||
|
stage.name,
|
||||||
|
)
|
||||||
source_file = request.files.get(stage_source_name)
|
source_file = request.files.get(stage_source_name)
|
||||||
source_image = Image.open(
|
source_image = Image.open(BytesIO(source_file.read())).convert("RGB")
|
||||||
BytesIO(source_file.read())).convert('RGB')
|
|
||||||
source_image = source_image.thumbnail((512, 512))
|
source_image = source_image.thumbnail((512, 512))
|
||||||
kwargs['source_image'] = source_image
|
kwargs["source_image"] = source_image
|
||||||
|
|
||||||
if stage_mask_name in request.files:
|
if stage_mask_name in request.files:
|
||||||
logger.debug('loading mask image %s for pipeline stage %s',
|
logger.debug(
|
||||||
stage_mask_name, stage.name)
|
"loading mask image %s for pipeline stage %s",
|
||||||
|
stage_mask_name,
|
||||||
|
stage.name,
|
||||||
|
)
|
||||||
mask_file = request.files.get(stage_mask_name)
|
mask_file = request.files.get(stage_mask_name)
|
||||||
mask_image = Image.open(BytesIO(mask_file.read())).convert('RGB')
|
mask_image = Image.open(BytesIO(mask_file.read())).convert("RGB")
|
||||||
mask_image = mask_image.thumbnail((512, 512))
|
mask_image = mask_image.thumbnail((512, 512))
|
||||||
kwargs['mask_image'] = mask_image
|
kwargs["mask_image"] = mask_image
|
||||||
|
|
||||||
pipeline.append((callback, stage, kwargs))
|
pipeline.append((callback, stage, kwargs))
|
||||||
|
|
||||||
logger.info('running chain pipeline with %s stages', len(pipeline.stages))
|
logger.info("running chain pipeline with %s stages", len(pipeline.stages))
|
||||||
|
|
||||||
# build and run chain pipeline
|
# build and run chain pipeline
|
||||||
empty_source = Image.new('RGB', (size.width, size.height))
|
empty_source = Image.new("RGB", (size.width, size.height))
|
||||||
executor.submit(output, pipeline, context,
|
executor.submit(
|
||||||
params, empty_source, output=output, size=size)
|
output, pipeline, context, params, empty_source, output=output, size=size
|
||||||
|
)
|
||||||
|
|
||||||
return jsonify(json_params(output, params, size))
|
return jsonify(json_params(output, params, size))
|
||||||
|
|
||||||
|
|
||||||
@app.route('/api/cancel', methods=['PUT'])
|
@app.route("/api/cancel", methods=["PUT"])
|
||||||
def cancel():
|
def cancel():
|
||||||
output_file = request.args.get('output', None)
|
output_file = request.args.get("output", None)
|
||||||
|
|
||||||
cancel = executor.cancel(output_file)
|
cancel = executor.cancel(output_file)
|
||||||
|
|
||||||
return ready_reply(cancel)
|
return ready_reply(cancel)
|
||||||
|
|
||||||
|
|
||||||
@app.route('/api/ready')
|
@app.route("/api/ready")
|
||||||
def ready():
|
def ready():
|
||||||
output_file = request.args.get('output', None)
|
output_file = request.args.get("output", None)
|
||||||
|
|
||||||
done, progress = executor.done(output_file)
|
done, progress = executor.done(output_file)
|
||||||
|
|
||||||
|
@ -687,11 +729,13 @@ def ready():
|
||||||
return ready_reply(done, progress=progress)
|
return ready_reply(done, progress=progress)
|
||||||
|
|
||||||
|
|
||||||
@app.route('/api/status')
|
@app.route("/api/status")
|
||||||
def status():
|
def status():
|
||||||
return jsonify(executor.status())
|
return jsonify(executor.status())
|
||||||
|
|
||||||
|
|
||||||
@app.route('/output/<path:filename>')
|
@app.route("/output/<path:filename>")
|
||||||
def output(filename: str):
|
def output(filename: str):
|
||||||
return send_from_directory(path.join('..', context.output_path), filename, as_attachment=False)
|
return send_from_directory(
|
||||||
|
path.join("..", context.output_path), filename, as_attachment=False
|
||||||
|
)
|
||||||
|
|
|
@ -1,24 +1,16 @@
|
||||||
from logging import getLogger
|
from logging import getLogger
|
||||||
|
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
from .chain import (
|
from .chain import (
|
||||||
correct_gfpgan,
|
|
||||||
upscale_stable_diffusion,
|
|
||||||
upscale_resrgan,
|
|
||||||
ChainPipeline,
|
ChainPipeline,
|
||||||
|
correct_gfpgan,
|
||||||
|
upscale_resrgan,
|
||||||
|
upscale_stable_diffusion,
|
||||||
)
|
)
|
||||||
from .device_pool import (
|
from .device_pool import JobContext
|
||||||
JobContext,
|
from .params import ImageParams, SizeChart, StageParams, UpscaleParams
|
||||||
)
|
from .utils import ServerContext
|
||||||
from .params import (
|
|
||||||
ImageParams,
|
|
||||||
SizeChart,
|
|
||||||
StageParams,
|
|
||||||
UpscaleParams,
|
|
||||||
)
|
|
||||||
from .utils import (
|
|
||||||
ServerContext,
|
|
||||||
)
|
|
||||||
|
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
@ -32,27 +24,25 @@ def run_upscale_correction(
|
||||||
*,
|
*,
|
||||||
upscale: UpscaleParams,
|
upscale: UpscaleParams,
|
||||||
) -> Image.Image:
|
) -> Image.Image:
|
||||||
'''
|
"""
|
||||||
This is a convenience method for a chain pipeline that will run upscaling and
|
This is a convenience method for a chain pipeline that will run upscaling and
|
||||||
correction, based on the `upscale` params.
|
correction, based on the `upscale` params.
|
||||||
'''
|
"""
|
||||||
logger.info('running upscaling and correction pipeline')
|
logger.info("running upscaling and correction pipeline")
|
||||||
|
|
||||||
chain = ChainPipeline()
|
chain = ChainPipeline()
|
||||||
|
|
||||||
if upscale.scale > 1:
|
if upscale.scale > 1:
|
||||||
if 'esrgan' in upscale.upscale_model:
|
if "esrgan" in upscale.upscale_model:
|
||||||
stage = StageParams(tile_size=stage.tile_size,
|
stage = StageParams(tile_size=stage.tile_size, outscale=upscale.outscale)
|
||||||
outscale=upscale.outscale)
|
|
||||||
chain.append((upscale_resrgan, stage, None))
|
chain.append((upscale_resrgan, stage, None))
|
||||||
elif 'stable-diffusion' in upscale.upscale_model:
|
elif "stable-diffusion" in upscale.upscale_model:
|
||||||
mini_tile = min(SizeChart.mini, stage.tile_size)
|
mini_tile = min(SizeChart.mini, stage.tile_size)
|
||||||
stage = StageParams(tile_size=mini_tile, outscale=upscale.outscale)
|
stage = StageParams(tile_size=mini_tile, outscale=upscale.outscale)
|
||||||
chain.append((upscale_stable_diffusion, stage, None))
|
chain.append((upscale_stable_diffusion, stage, None))
|
||||||
|
|
||||||
if upscale.faces:
|
if upscale.faces:
|
||||||
stage = StageParams(tile_size=stage.tile_size,
|
stage = StageParams(tile_size=stage.tile_size, outscale=1)
|
||||||
outscale=1)
|
|
||||||
chain.append((correct_gfpgan, stage, None))
|
chain.append((correct_gfpgan, stage, None))
|
||||||
|
|
||||||
return chain(job, server, params, image, prompt=params.prompt, upscale=upscale)
|
return chain(job, server, params, image, prompt=params.prompt, upscale=upscale)
|
||||||
|
|
|
@ -1,13 +1,11 @@
|
||||||
|
import gc
|
||||||
from logging import getLogger
|
from logging import getLogger
|
||||||
from os import environ, path
|
from os import environ, path
|
||||||
from typing import Any, Dict, List, Optional, Union
|
from typing import Any, Dict, List, Optional, Union
|
||||||
|
|
||||||
import gc
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from .params import (
|
from .params import SizeChart
|
||||||
SizeChart,
|
|
||||||
)
|
|
||||||
|
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
@ -15,15 +13,15 @@ logger = getLogger(__name__)
|
||||||
class ServerContext:
|
class ServerContext:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
bundle_path: str = '.',
|
bundle_path: str = ".",
|
||||||
model_path: str = '.',
|
model_path: str = ".",
|
||||||
output_path: str = '.',
|
output_path: str = ".",
|
||||||
params_path: str = '.',
|
params_path: str = ".",
|
||||||
cors_origin: str = '*',
|
cors_origin: str = "*",
|
||||||
num_workers: int = 1,
|
num_workers: int = 1,
|
||||||
block_platforms: List[str] = [],
|
block_platforms: List[str] = [],
|
||||||
default_platform: str = None,
|
default_platform: str = None,
|
||||||
image_format: str = 'png',
|
image_format: str = "png",
|
||||||
) -> None:
|
) -> None:
|
||||||
self.bundle_path = bundle_path
|
self.bundle_path = bundle_path
|
||||||
self.model_path = model_path
|
self.model_path = model_path
|
||||||
|
@ -38,40 +36,39 @@ class ServerContext:
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_environ(cls):
|
def from_environ(cls):
|
||||||
return ServerContext(
|
return ServerContext(
|
||||||
bundle_path=environ.get('ONNX_WEB_BUNDLE_PATH',
|
bundle_path=environ.get(
|
||||||
path.join('..', 'gui', 'out')),
|
"ONNX_WEB_BUNDLE_PATH", path.join("..", "gui", "out")
|
||||||
model_path=environ.get('ONNX_WEB_MODEL_PATH',
|
|
||||||
path.join('..', 'models')),
|
|
||||||
output_path=environ.get(
|
|
||||||
'ONNX_WEB_OUTPUT_PATH', path.join('..', 'outputs')),
|
|
||||||
params_path=environ.get('ONNX_WEB_PARAMS_PATH', '.'),
|
|
||||||
# others
|
|
||||||
cors_origin=environ.get('ONNX_WEB_CORS_ORIGIN', '*').split(','),
|
|
||||||
num_workers=int(environ.get('ONNX_WEB_NUM_WORKERS', 1)),
|
|
||||||
block_platforms=environ.get(
|
|
||||||
'ONNX_WEB_BLOCK_PLATFORMS', '').split(','),
|
|
||||||
default_platform=environ.get(
|
|
||||||
'ONNX_WEB_DEFAULT_PLATFORM', None),
|
|
||||||
image_format=environ.get(
|
|
||||||
'ONNX_WEB_IMAGE_FORMAT', 'png'
|
|
||||||
),
|
),
|
||||||
|
model_path=environ.get("ONNX_WEB_MODEL_PATH", path.join("..", "models")),
|
||||||
|
output_path=environ.get("ONNX_WEB_OUTPUT_PATH", path.join("..", "outputs")),
|
||||||
|
params_path=environ.get("ONNX_WEB_PARAMS_PATH", "."),
|
||||||
|
# others
|
||||||
|
cors_origin=environ.get("ONNX_WEB_CORS_ORIGIN", "*").split(","),
|
||||||
|
num_workers=int(environ.get("ONNX_WEB_NUM_WORKERS", 1)),
|
||||||
|
block_platforms=environ.get("ONNX_WEB_BLOCK_PLATFORMS", "").split(","),
|
||||||
|
default_platform=environ.get("ONNX_WEB_DEFAULT_PLATFORM", None),
|
||||||
|
image_format=environ.get("ONNX_WEB_IMAGE_FORMAT", "png"),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def base_join(base: str, tail: str) -> str:
|
def base_join(base: str, tail: str) -> str:
|
||||||
tail_path = path.relpath(path.normpath(path.join('/', tail)), '/')
|
tail_path = path.relpath(path.normpath(path.join("/", tail)), "/")
|
||||||
return path.join(base, tail_path)
|
return path.join(base, tail_path)
|
||||||
|
|
||||||
|
|
||||||
def is_debug() -> bool:
|
def is_debug() -> bool:
|
||||||
return environ.get('DEBUG') is not None
|
return environ.get("DEBUG") is not None
|
||||||
|
|
||||||
|
|
||||||
def get_and_clamp_float(args: Any, key: str, default_value: float, max_value: float, min_value=0.0) -> float:
|
def get_and_clamp_float(
|
||||||
|
args: Any, key: str, default_value: float, max_value: float, min_value=0.0
|
||||||
|
) -> float:
|
||||||
return min(max(float(args.get(key, default_value)), min_value), max_value)
|
return min(max(float(args.get(key, default_value)), min_value), max_value)
|
||||||
|
|
||||||
|
|
||||||
def get_and_clamp_int(args: Any, key: str, default_value: int, max_value: int, min_value=1) -> int:
|
def get_and_clamp_int(
|
||||||
|
args: Any, key: str, default_value: int, max_value: int, min_value=1
|
||||||
|
) -> int:
|
||||||
return min(max(int(args.get(key, default_value)), min_value), max_value)
|
return min(max(int(args.get(key, default_value)), min_value), max_value)
|
||||||
|
|
||||||
|
|
||||||
|
@ -80,7 +77,7 @@ def get_from_list(args: Any, key: str, values: List[Any]) -> Optional[Any]:
|
||||||
if selected in values:
|
if selected in values:
|
||||||
return selected
|
return selected
|
||||||
|
|
||||||
logger.warn('invalid selection: %s', selected)
|
logger.warn("invalid selection: %s", selected)
|
||||||
if len(values) > 0:
|
if len(values) > 0:
|
||||||
return values[0]
|
return values[0]
|
||||||
|
|
||||||
|
@ -118,10 +115,10 @@ def get_size(val: Union[int, str, None]) -> SizeChart:
|
||||||
|
|
||||||
return int(val)
|
return int(val)
|
||||||
|
|
||||||
raise Exception('invalid size')
|
raise Exception("invalid size")
|
||||||
|
|
||||||
|
|
||||||
def run_gc():
|
def run_gc():
|
||||||
logger.debug('running garbage collection')
|
logger.debug("running garbage collection")
|
||||||
gc.collect()
|
gc.collect()
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
|
|
|
@ -0,0 +1,2 @@
|
||||||
|
[tool.isort]
|
||||||
|
profile = "black"
|
Loading…
Reference in New Issue