feat(api): add reduce stages, noise source
This commit is contained in:
parent
8d346cbed0
commit
c905fbb728
|
@ -19,6 +19,15 @@ from .persist_disk import (
|
||||||
from .persist_s3 import (
|
from .persist_s3 import (
|
||||||
persist_s3,
|
persist_s3,
|
||||||
)
|
)
|
||||||
|
from .reduce_crop import (
|
||||||
|
reduce_crop,
|
||||||
|
)
|
||||||
|
from .reduce_thumbnail import (
|
||||||
|
reduce_thumbnail,
|
||||||
|
)
|
||||||
|
from .source_noise import (
|
||||||
|
source_noise,
|
||||||
|
)
|
||||||
from .source_txt2img import (
|
from .source_txt2img import (
|
||||||
source_txt2img,
|
source_txt2img,
|
||||||
)
|
)
|
||||||
|
|
|
@ -70,11 +70,11 @@ class ChainPipeline:
|
||||||
kwargs = stage_kwargs or {}
|
kwargs = stage_kwargs or {}
|
||||||
kwargs = {**pipeline_kwargs, **kwargs}
|
kwargs = {**pipeline_kwargs, **kwargs}
|
||||||
|
|
||||||
logger.info('running stage %s on result image with dimensions %sx%s, %s',
|
logger.info('running stage %s on image with dimensions %sx%s, %s',
|
||||||
name, image.width, image.height, kwargs.keys())
|
name, image.width, image.height, kwargs.keys())
|
||||||
|
|
||||||
if image.width > stage_params.tile_size or image.height > stage_params.tile_size:
|
if image.width > stage_params.tile_size or image.height > stage_params.tile_size:
|
||||||
logger.info('source image larger than tile size of %s, tiling stage',
|
logger.info('image larger than tile size of %s, tiling stage',
|
||||||
stage_params.tile_size)
|
stage_params.tile_size)
|
||||||
|
|
||||||
def stage_tile(tile: Image.Image, _dims) -> Image.Image:
|
def stage_tile(tile: Image.Image, _dims) -> Image.Image:
|
||||||
|
@ -89,7 +89,7 @@ class ChainPipeline:
|
||||||
image = process_tile_grid(
|
image = process_tile_grid(
|
||||||
image, stage_params.tile_size, stage_params.outscale, [stage_tile])
|
image, stage_params.tile_size, stage_params.outscale, [stage_tile])
|
||||||
else:
|
else:
|
||||||
logger.info('source image within tile size, running stage')
|
logger.info('image within tile size, running stage')
|
||||||
image = stage_pipe(ctx, stage_params, params, image,
|
image = stage_pipe(ctx, stage_params, params, image,
|
||||||
**kwargs)
|
**kwargs)
|
||||||
|
|
||||||
|
|
|
@ -1,7 +1,6 @@
|
||||||
from logging import getLogger
|
from logging import getLogger
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
|
|
||||||
from ..params import (
|
from ..params import (
|
||||||
ImageParams,
|
ImageParams,
|
||||||
StageParams,
|
StageParams,
|
||||||
|
|
|
@ -0,0 +1,30 @@
|
||||||
|
from logging import getLogger
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
from ..params import (
|
||||||
|
ImageParams,
|
||||||
|
Size,
|
||||||
|
StageParams,
|
||||||
|
)
|
||||||
|
from ..utils import (
|
||||||
|
ServerContext,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def reduce_crop(
|
||||||
|
ctx: ServerContext,
|
||||||
|
_stage: StageParams,
|
||||||
|
_params: ImageParams,
|
||||||
|
source_image: Image.Image,
|
||||||
|
*,
|
||||||
|
origin: Size,
|
||||||
|
size: Size,
|
||||||
|
**kwargs,
|
||||||
|
) -> Image.Image:
|
||||||
|
image = source_image.crop(
|
||||||
|
(origin.width, origin.height, size.width, size.height))
|
||||||
|
logger.info('created thumbnail with dimensions: %sx%s',
|
||||||
|
image.width, image.height)
|
||||||
|
return image
|
|
@ -0,0 +1,28 @@
|
||||||
|
from logging import getLogger
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
from ..params import (
|
||||||
|
ImageParams,
|
||||||
|
Size,
|
||||||
|
StageParams,
|
||||||
|
)
|
||||||
|
from ..utils import (
|
||||||
|
ServerContext,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def reduce_thumbnail(
|
||||||
|
ctx: ServerContext,
|
||||||
|
_stage: StageParams,
|
||||||
|
_params: ImageParams,
|
||||||
|
source_image: Image.Image,
|
||||||
|
*,
|
||||||
|
size: Size,
|
||||||
|
**kwargs,
|
||||||
|
) -> Image.Image:
|
||||||
|
image = source_image.thumbnail((size.width, size.height))
|
||||||
|
logger.info('created thumbnail with dimensions: %sx%s',
|
||||||
|
image.width, image.height)
|
||||||
|
return image
|
|
@ -0,0 +1,38 @@
|
||||||
|
from logging import getLogger
|
||||||
|
from PIL import Image
|
||||||
|
from typing import Callable
|
||||||
|
|
||||||
|
from ..params import (
|
||||||
|
ImageParams,
|
||||||
|
Size,
|
||||||
|
StageParams,
|
||||||
|
)
|
||||||
|
from ..utils import (
|
||||||
|
ServerContext,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def source_noise(
|
||||||
|
ctx: ServerContext,
|
||||||
|
stage: StageParams,
|
||||||
|
params: ImageParams,
|
||||||
|
source_image: Image.Image,
|
||||||
|
*,
|
||||||
|
size: Size,
|
||||||
|
noise_source: Callable,
|
||||||
|
**kwargs,
|
||||||
|
) -> Image.Image:
|
||||||
|
prompt = prompt or params.prompt
|
||||||
|
logger.info('generating image from noise source')
|
||||||
|
|
||||||
|
if source_image is not None:
|
||||||
|
logger.warn(
|
||||||
|
'a source image was passed to a noise stage, but will be discarded')
|
||||||
|
|
||||||
|
output = noise_source(source_image, (size.width, size.height), (0, 0))
|
||||||
|
|
||||||
|
logger.info('final output image size: %sx%s', output.width, output.height)
|
||||||
|
return output
|
|
@ -1,6 +1,4 @@
|
||||||
from diffusers import (
|
from diffusers import (
|
||||||
AutoencoderKL,
|
|
||||||
DDPMScheduler,
|
|
||||||
StableDiffusionUpscalePipeline,
|
StableDiffusionUpscalePipeline,
|
||||||
)
|
)
|
||||||
from logging import getLogger
|
from logging import getLogger
|
||||||
|
@ -40,19 +38,9 @@ def load_stable_diffusion(ctx: ServerContext, upscale: UpscaleParams):
|
||||||
return last_pipeline_instance
|
return last_pipeline_instance
|
||||||
|
|
||||||
if upscale.format == 'onnx':
|
if upscale.format == 'onnx':
|
||||||
# ValueError: Pipeline <class 'onnx_web.onnx.pipeline_onnx_stable_diffusion_upscale.OnnxStableDiffusionUpscalePipeline'>
|
pipeline = OnnxStableDiffusionUpscalePipeline.from_pretrained(model_path)
|
||||||
# expected {'vae', 'unet', 'text_encoder', 'tokenizer', 'scheduler', 'low_res_scheduler'},
|
|
||||||
# but only {'scheduler', 'tokenizer', 'text_encoder', 'unet'} were passed.
|
|
||||||
pipeline = OnnxStableDiffusionUpscalePipeline.from_pretrained(
|
|
||||||
model_path,
|
|
||||||
vae=AutoencoderKL.from_pretrained(
|
|
||||||
model_path, subfolder='vae_encoder'),
|
|
||||||
low_res_scheduler=DDPMScheduler.from_pretrained(
|
|
||||||
model_path, subfolder='scheduler'),
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
pipeline = StableDiffusionUpscalePipeline.from_pretrained(
|
pipeline = StableDiffusionUpscalePipeline.from_pretrained(model_path)
|
||||||
'stabilityai/stable-diffusion-x4-upscaler')
|
|
||||||
|
|
||||||
last_pipeline_instance = pipeline
|
last_pipeline_instance = pipeline
|
||||||
last_pipeline_params = cache_params
|
last_pipeline_params = cache_params
|
||||||
|
|
|
@ -6,6 +6,7 @@ from diffusers import (
|
||||||
OnnxRuntimeModel,
|
OnnxRuntimeModel,
|
||||||
OnnxStableDiffusionPipeline,
|
OnnxStableDiffusionPipeline,
|
||||||
StableDiffusionPipeline,
|
StableDiffusionPipeline,
|
||||||
|
StableDiffusionUpscalePipeline,
|
||||||
)
|
)
|
||||||
from logging import getLogger
|
from logging import getLogger
|
||||||
from onnx import load, save_model
|
from onnx import load, save_model
|
||||||
|
@ -202,7 +203,7 @@ def onnx_export(
|
||||||
|
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def convert_diffuser(name: str, url: str, opset: int, half: bool, token: str):
|
def convert_diffuser(name: str, url: str, opset: int, half: bool, token: str, single_vae: bool = False):
|
||||||
'''
|
'''
|
||||||
From https://github.com/huggingface/diffusers/blob/main/scripts/convert_stable_diffusion_checkpoint_to_onnx.py
|
From https://github.com/huggingface/diffusers/blob/main/scripts/convert_stable_diffusion_checkpoint_to_onnx.py
|
||||||
'''
|
'''
|
||||||
|
@ -212,6 +213,9 @@ def convert_diffuser(name: str, url: str, opset: int, half: bool, token: str):
|
||||||
# diffusers go into a directory rather than .onnx file
|
# diffusers go into a directory rather than .onnx file
|
||||||
logger.info('converting Diffusers model: %s -> %s/', name, dest_path)
|
logger.info('converting Diffusers model: %s -> %s/', name, dest_path)
|
||||||
|
|
||||||
|
if single_vae:
|
||||||
|
logger.info('converting model with single VAE')
|
||||||
|
|
||||||
if path.isdir(dest_path):
|
if path.isdir(dest_path):
|
||||||
logger.info('ONNX model already exists, skipping.')
|
logger.info('ONNX model already exists, skipping.')
|
||||||
return
|
return
|
||||||
|
@ -295,50 +299,75 @@ def convert_diffuser(name: str, url: str, opset: int, half: bool, token: str):
|
||||||
)
|
)
|
||||||
del pipeline.unet
|
del pipeline.unet
|
||||||
|
|
||||||
# VAE ENCODER
|
if single_vae:
|
||||||
vae_encoder = pipeline.vae
|
# SINGLE VAE
|
||||||
vae_in_channels = vae_encoder.config.in_channels
|
vae_only = pipeline.vae
|
||||||
vae_sample_size = vae_encoder.config.sample_size
|
vae_in_channels = vae_only.config.in_channels
|
||||||
# need to get the raw tensor output (sample) from the encoder
|
vae_sample_size = vae_only.config.sample_size
|
||||||
vae_encoder.forward = lambda sample, return_dict: vae_encoder.encode(
|
# need to get the raw tensor output (sample) from the encoder
|
||||||
sample, return_dict)[0].sample()
|
vae_only.forward = lambda sample, return_dict: vae_only.encode(
|
||||||
onnx_export(
|
sample, return_dict)[0].sample()
|
||||||
vae_encoder,
|
onnx_export(
|
||||||
model_args=(
|
vae_only,
|
||||||
torch.randn(1, vae_in_channels, vae_sample_size, vae_sample_size).to(
|
model_args=(
|
||||||
device=training_device, dtype=dtype),
|
torch.randn(1, vae_in_channels, vae_sample_size, vae_sample_size).to(
|
||||||
False,
|
device=training_device, dtype=dtype),
|
||||||
),
|
False,
|
||||||
output_path=output_path / "vae_encoder" / "model.onnx",
|
),
|
||||||
ordered_input_names=["sample", "return_dict"],
|
output_path=output_path / "vae" / "model.onnx",
|
||||||
output_names=["latent_sample"],
|
ordered_input_names=["sample", "return_dict"],
|
||||||
dynamic_axes={
|
output_names=["latent_sample"],
|
||||||
"sample": {0: "batch", 1: "channels", 2: "height", 3: "width"},
|
dynamic_axes={
|
||||||
},
|
"sample": {0: "batch", 1: "channels", 2: "height", 3: "width"},
|
||||||
opset=opset,
|
},
|
||||||
)
|
opset=opset,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# VAE ENCODER
|
||||||
|
vae_encoder = pipeline.vae
|
||||||
|
vae_in_channels = vae_encoder.config.in_channels
|
||||||
|
vae_sample_size = vae_encoder.config.sample_size
|
||||||
|
# need to get the raw tensor output (sample) from the encoder
|
||||||
|
vae_encoder.forward = lambda sample, return_dict: vae_encoder.encode(
|
||||||
|
sample, return_dict)[0].sample()
|
||||||
|
onnx_export(
|
||||||
|
vae_encoder,
|
||||||
|
model_args=(
|
||||||
|
torch.randn(1, vae_in_channels, vae_sample_size, vae_sample_size).to(
|
||||||
|
device=training_device, dtype=dtype),
|
||||||
|
False,
|
||||||
|
),
|
||||||
|
output_path=output_path / "vae_encoder" / "model.onnx",
|
||||||
|
ordered_input_names=["sample", "return_dict"],
|
||||||
|
output_names=["latent_sample"],
|
||||||
|
dynamic_axes={
|
||||||
|
"sample": {0: "batch", 1: "channels", 2: "height", 3: "width"},
|
||||||
|
},
|
||||||
|
opset=opset,
|
||||||
|
)
|
||||||
|
|
||||||
|
# VAE DECODER
|
||||||
|
vae_decoder = pipeline.vae
|
||||||
|
vae_latent_channels = vae_decoder.config.latent_channels
|
||||||
|
vae_out_channels = vae_decoder.config.out_channels
|
||||||
|
# forward only through the decoder part
|
||||||
|
vae_decoder.forward = vae_encoder.decode
|
||||||
|
onnx_export(
|
||||||
|
vae_decoder,
|
||||||
|
model_args=(
|
||||||
|
torch.randn(1, vae_latent_channels, unet_sample_size, unet_sample_size).to(
|
||||||
|
device=training_device, dtype=dtype),
|
||||||
|
False,
|
||||||
|
),
|
||||||
|
output_path=output_path / "vae_decoder" / "model.onnx",
|
||||||
|
ordered_input_names=["latent_sample", "return_dict"],
|
||||||
|
output_names=["sample"],
|
||||||
|
dynamic_axes={
|
||||||
|
"latent_sample": {0: "batch", 1: "channels", 2: "height", 3: "width"},
|
||||||
|
},
|
||||||
|
opset=opset,
|
||||||
|
)
|
||||||
|
|
||||||
# VAE DECODER
|
|
||||||
vae_decoder = pipeline.vae
|
|
||||||
vae_latent_channels = vae_decoder.config.latent_channels
|
|
||||||
vae_out_channels = vae_decoder.config.out_channels
|
|
||||||
# forward only through the decoder part
|
|
||||||
vae_decoder.forward = vae_encoder.decode
|
|
||||||
onnx_export(
|
|
||||||
vae_decoder,
|
|
||||||
model_args=(
|
|
||||||
torch.randn(1, vae_latent_channels, unet_sample_size, unet_sample_size).to(
|
|
||||||
device=training_device, dtype=dtype),
|
|
||||||
False,
|
|
||||||
),
|
|
||||||
output_path=output_path / "vae_decoder" / "model.onnx",
|
|
||||||
ordered_input_names=["latent_sample", "return_dict"],
|
|
||||||
output_names=["sample"],
|
|
||||||
dynamic_axes={
|
|
||||||
"latent_sample": {0: "batch", 1: "channels", 2: "height", 3: "width"},
|
|
||||||
},
|
|
||||||
opset=opset,
|
|
||||||
)
|
|
||||||
del pipeline.vae
|
del pipeline.vae
|
||||||
|
|
||||||
# SAFETY CHECKER
|
# SAFETY CHECKER
|
||||||
|
@ -376,20 +405,32 @@ def convert_diffuser(name: str, url: str, opset: int, half: bool, token: str):
|
||||||
safety_checker = None
|
safety_checker = None
|
||||||
feature_extractor = None
|
feature_extractor = None
|
||||||
|
|
||||||
onnx_pipeline = OnnxStableDiffusionPipeline(
|
if single_vae:
|
||||||
vae_encoder=OnnxRuntimeModel.from_pretrained(
|
onnx_pipeline = StableDiffusionUpscalePipeline(
|
||||||
output_path / "vae_encoder"),
|
vae=OnnxRuntimeModel.from_pretrained(
|
||||||
vae_decoder=OnnxRuntimeModel.from_pretrained(
|
output_path / "vae"),
|
||||||
output_path / "vae_decoder"),
|
text_encoder=OnnxRuntimeModel.from_pretrained(
|
||||||
text_encoder=OnnxRuntimeModel.from_pretrained(
|
output_path / "text_encoder"),
|
||||||
output_path / "text_encoder"),
|
tokenizer=pipeline.tokenizer,
|
||||||
tokenizer=pipeline.tokenizer,
|
low_res_scheduler=pipeline.scheduler,
|
||||||
unet=OnnxRuntimeModel.from_pretrained(output_path / "unet"),
|
unet=OnnxRuntimeModel.from_pretrained(output_path / "unet"),
|
||||||
scheduler=pipeline.scheduler,
|
scheduler=pipeline.scheduler,
|
||||||
safety_checker=safety_checker,
|
)
|
||||||
feature_extractor=feature_extractor,
|
else:
|
||||||
requires_safety_checker=safety_checker is not None,
|
onnx_pipeline = OnnxStableDiffusionPipeline(
|
||||||
)
|
vae_encoder=OnnxRuntimeModel.from_pretrained(
|
||||||
|
output_path / "vae_encoder"),
|
||||||
|
vae_decoder=OnnxRuntimeModel.from_pretrained(
|
||||||
|
output_path / "vae_decoder"),
|
||||||
|
text_encoder=OnnxRuntimeModel.from_pretrained(
|
||||||
|
output_path / "text_encoder"),
|
||||||
|
tokenizer=pipeline.tokenizer,
|
||||||
|
unet=OnnxRuntimeModel.from_pretrained(output_path / "unet"),
|
||||||
|
scheduler=pipeline.scheduler,
|
||||||
|
safety_checker=safety_checker,
|
||||||
|
feature_extractor=feature_extractor,
|
||||||
|
requires_safety_checker=safety_checker is not None,
|
||||||
|
)
|
||||||
|
|
||||||
logger.info('exporting ONNX model')
|
logger.info('exporting ONNX model')
|
||||||
|
|
||||||
|
@ -398,8 +439,15 @@ def convert_diffuser(name: str, url: str, opset: int, half: bool, token: str):
|
||||||
|
|
||||||
del pipeline
|
del pipeline
|
||||||
del onnx_pipeline
|
del onnx_pipeline
|
||||||
_ = OnnxStableDiffusionPipeline.from_pretrained(
|
|
||||||
output_path, provider="CPUExecutionProvider")
|
if single_vae:
|
||||||
|
_ = StableDiffusionUpscalePipeline.from_pretrained(
|
||||||
|
output_path, provider="CPUExecutionProvider"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
_ = OnnxStableDiffusionPipeline.from_pretrained(
|
||||||
|
output_path, provider="CPUExecutionProvider")
|
||||||
|
|
||||||
logger.info("ONNX pipeline is loadable")
|
logger.info("ONNX pipeline is loadable")
|
||||||
|
|
||||||
|
|
||||||
|
@ -409,7 +457,8 @@ def load_models(args, models: Models):
|
||||||
if source[0] in args.skip:
|
if source[0] in args.skip:
|
||||||
logger.info('Skipping model: %s', source[0])
|
logger.info('Skipping model: %s', source[0])
|
||||||
else:
|
else:
|
||||||
convert_diffuser(*source, args.opset, args.half, args.token)
|
single_vae = 'upscaling' in source[0]
|
||||||
|
convert_diffuser(*source, args.opset, args.half, args.token, single_vae=single_vae)
|
||||||
|
|
||||||
if args.upscaling:
|
if args.upscaling:
|
||||||
for source in models.get('upscaling'):
|
for source in models.get('upscaling'):
|
||||||
|
|
|
@ -32,6 +32,9 @@ from .chain import (
|
||||||
correct_gfpgan,
|
correct_gfpgan,
|
||||||
persist_disk,
|
persist_disk,
|
||||||
persist_s3,
|
persist_s3,
|
||||||
|
reduce_thumbnail,
|
||||||
|
reduce_crop,
|
||||||
|
source_noise,
|
||||||
source_txt2img,
|
source_txt2img,
|
||||||
upscale_outpaint,
|
upscale_outpaint,
|
||||||
upscale_resrgan,
|
upscale_resrgan,
|
||||||
|
@ -61,7 +64,6 @@ from .params import (
|
||||||
Border,
|
Border,
|
||||||
ImageParams,
|
ImageParams,
|
||||||
Size,
|
Size,
|
||||||
SizeChart,
|
|
||||||
StageParams,
|
StageParams,
|
||||||
UpscaleParams,
|
UpscaleParams,
|
||||||
)
|
)
|
||||||
|
@ -129,6 +131,9 @@ chain_stages = {
|
||||||
'correct-gfpgan': correct_gfpgan,
|
'correct-gfpgan': correct_gfpgan,
|
||||||
'persist-disk': persist_disk,
|
'persist-disk': persist_disk,
|
||||||
'persist-s3': persist_s3,
|
'persist-s3': persist_s3,
|
||||||
|
'reduce-crop': reduce_crop,
|
||||||
|
'reduce-thumbnail': reduce_thumbnail,
|
||||||
|
'source-noise': source_noise,
|
||||||
'source-txt2img': source_txt2img,
|
'source-txt2img': source_txt2img,
|
||||||
'upscale-outpaint': upscale_outpaint,
|
'upscale-outpaint': upscale_outpaint,
|
||||||
'upscale-resrgan': upscale_resrgan,
|
'upscale-resrgan': upscale_resrgan,
|
||||||
|
|
|
@ -56,6 +56,9 @@
|
||||||
"spinalcase",
|
"spinalcase",
|
||||||
"stabilityai",
|
"stabilityai",
|
||||||
"stringcase",
|
"stringcase",
|
||||||
|
"uncond",
|
||||||
|
"unet",
|
||||||
|
"untruncated",
|
||||||
"upsampler",
|
"upsampler",
|
||||||
"upscaling",
|
"upscaling",
|
||||||
"venv",
|
"venv",
|
||||||
|
|
Loading…
Reference in New Issue