
feat(api): add reduce stages, noise source

Sean Sube 2023-01-29 15:23:01 -06:00
parent 8d346cbed0
commit c905fbb728
10 changed files with 229 additions and 80 deletions

View File

@@ -19,6 +19,15 @@ from .persist_disk import (
from .persist_s3 import (
persist_s3,
)
from .reduce_crop import (
reduce_crop,
)
from .reduce_thumbnail import (
reduce_thumbnail,
)
from .source_noise import (
source_noise,
)
from .source_txt2img import (
source_txt2img,
)

View File

@@ -70,11 +70,11 @@ class ChainPipeline:
kwargs = stage_kwargs or {}
kwargs = {**pipeline_kwargs, **kwargs}
-logger.info('running stage %s on result image with dimensions %sx%s, %s',
+logger.info('running stage %s on image with dimensions %sx%s, %s',
name, image.width, image.height, kwargs.keys())
if image.width > stage_params.tile_size or image.height > stage_params.tile_size:
-logger.info('source image larger than tile size of %s, tiling stage',
+logger.info('image larger than tile size of %s, tiling stage',
stage_params.tile_size)
def stage_tile(tile: Image.Image, _dims) -> Image.Image:
@@ -89,7 +89,7 @@ class ChainPipeline:
image = process_tile_grid(
image, stage_params.tile_size, stage_params.outscale, [stage_tile])
else:
-logger.info('source image within tile size, running stage')
+logger.info('image within tile size, running stage')
image = stage_pipe(ctx, stage_params, params, image,
**kwargs)
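Tiling splits a large source into tile_size squares, runs the stage on each tile, and reassembles the result. The real process_tile_grid lives elsewhere in onnx_web and may differ; this is only a rough sketch of the idea, assuming each filter callback receives the tile and its position, as stage_tile does above:

from typing import Callable, List

from PIL import Image


def tile_grid_sketch(
    source: Image.Image,
    tile_size: int,
    scale: int,
    filters: List[Callable],
) -> Image.Image:
    # process the image one tile at a time, pasting each result into a
    # (possibly upscaled) output canvas
    width, height = source.size
    output = Image.new('RGB', (width * scale, height * scale))

    for top in range(0, height, tile_size):
        for left in range(0, width, tile_size):
            tile = source.crop((left, top, left + tile_size, top + tile_size))
            for image_filter in filters:
                tile = image_filter(tile, (left, top, tile_size))
            output.paste(tile, (left * scale, top * scale))

    return output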

View File

@@ -1,7 +1,6 @@
from logging import getLogger
from PIL import Image
from ..params import (
ImageParams,
StageParams,

View File

@@ -0,0 +1,30 @@
from logging import getLogger

from PIL import Image

from ..params import (
    ImageParams,
    Size,
    StageParams,
)
from ..utils import (
    ServerContext,
)

logger = getLogger(__name__)


def reduce_crop(
    ctx: ServerContext,
    _stage: StageParams,
    _params: ImageParams,
    source_image: Image.Image,
    *,
    origin: Size,
    size: Size,
    **kwargs,
) -> Image.Image:
    # PIL crop boxes are (left, upper, right, lower), so the region runs
    # from the origin to origin + size
    image = source_image.crop((
        origin.width,
        origin.height,
        origin.width + size.width,
        origin.height + size.height,
    ))
    logger.info('created crop with dimensions: %sx%s',
                image.width, image.height)
    return image
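A hypothetical invocation of the stage above, assuming Size is constructed as Size(width, height) and the context objects come from the running pipeline:

# crop a 256x256 region starting 64px in from the top-left corner
crop = reduce_crop(ctx, stage, params, image,
                   origin=Size(64, 64), size=Size(256, 256))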

View File

@@ -0,0 +1,28 @@
from logging import getLogger

from PIL import Image

from ..params import (
    ImageParams,
    Size,
    StageParams,
)
from ..utils import (
    ServerContext,
)

logger = getLogger(__name__)


def reduce_thumbnail(
    ctx: ServerContext,
    _stage: StageParams,
    _params: ImageParams,
    source_image: Image.Image,
    *,
    size: Size,
    **kwargs,
) -> Image.Image:
    # Image.thumbnail resizes in place and returns None, so copy first
    image = source_image.copy()
    image.thumbnail((size.width, size.height))
    logger.info('created thumbnail with dimensions: %sx%s',
                image.width, image.height)
    return image
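PIL's Image.thumbnail fits the image within the requested bounds while preserving aspect ratio, so one dimension of the result may come out smaller than asked for. For example:

from PIL import Image

image = Image.new('RGB', (1024, 512))
image.thumbnail((256, 256))
print(image.size)  # (256, 128): the aspect ratio is preserved within the bounds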

View File

@@ -0,0 +1,38 @@
from logging import getLogger
from typing import Callable

from PIL import Image

from ..params import (
    ImageParams,
    Size,
    StageParams,
)
from ..utils import (
    ServerContext,
)

logger = getLogger(__name__)


def source_noise(
    ctx: ServerContext,
    stage: StageParams,
    params: ImageParams,
    source_image: Image.Image,
    *,
    size: Size,
    noise_source: Callable,
    **kwargs,
) -> Image.Image:
    logger.info('generating image from noise source')

    if source_image is not None:
        logger.warning(
            'a source image was passed to a noise stage and will be discarded')

    output = noise_source(source_image, (size.width, size.height), (0, 0))

    logger.info('final output image size: %sx%s',
                output.width, output.height)
    return output
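The noise_source argument can be any callable taking the (possibly absent) source image, the output dimensions, and an origin, and returning an image. onnx_web ships its own noise sources; a minimal uniform-noise stand-in might look like this sketch:

import numpy as np
from PIL import Image


def noise_source_uniform(source_image, dims, origin):
    # fill the requested area with uniform RGB noise; the source image and
    # origin are ignored in this minimal sketch
    width, height = dims
    noise = np.random.randint(0, 256, (height, width, 3), dtype=np.uint8)
    return Image.fromarray(noise, 'RGB')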

View File

@@ -1,6 +1,4 @@
from diffusers import (
-AutoencoderKL,
-DDPMScheduler,
StableDiffusionUpscalePipeline,
)
from logging import getLogger
@@ -40,19 +38,9 @@ def load_stable_diffusion(ctx: ServerContext, upscale: UpscaleParams):
return last_pipeline_instance
if upscale.format == 'onnx':
-# ValueError: Pipeline <class 'onnx_web.onnx.pipeline_onnx_stable_diffusion_upscale.OnnxStableDiffusionUpscalePipeline'>
-# expected {'vae', 'unet', 'text_encoder', 'tokenizer', 'scheduler', 'low_res_scheduler'},
-# but only {'scheduler', 'tokenizer', 'text_encoder', 'unet'} were passed.
-pipeline = OnnxStableDiffusionUpscalePipeline.from_pretrained(
-model_path,
-vae=AutoencoderKL.from_pretrained(
-model_path, subfolder='vae_encoder'),
-low_res_scheduler=DDPMScheduler.from_pretrained(
-model_path, subfolder='scheduler'),
-)
+pipeline = OnnxStableDiffusionUpscalePipeline.from_pretrained(model_path)
else:
-pipeline = StableDiffusionUpscalePipeline.from_pretrained(
-'stabilityai/stable-diffusion-x4-upscaler')
+pipeline = StableDiffusionUpscalePipeline.from_pretrained(model_path)
last_pipeline_instance = pipeline
last_pipeline_params = cache_params
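load_stable_diffusion memoizes a single pipeline keyed by its parameters, so repeated requests with unchanged settings skip the expensive reload. A distilled sketch of that caching pattern, with make_pipeline standing in for the from_pretrained calls above:

last_pipeline_instance = None
last_pipeline_params = None


def load_cached_pipeline(cache_params, make_pipeline):
    # reuse the previous pipeline when the parameters have not changed
    global last_pipeline_instance, last_pipeline_params
    if last_pipeline_instance is not None and last_pipeline_params == cache_params:
        return last_pipeline_instance

    last_pipeline_instance = make_pipeline()
    last_pipeline_params = cache_params
    return last_pipeline_instance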

View File

@@ -6,6 +6,7 @@ from diffusers import (
OnnxRuntimeModel,
OnnxStableDiffusionPipeline,
StableDiffusionPipeline,
StableDiffusionUpscalePipeline,
)
from logging import getLogger
from onnx import load, save_model
@@ -202,7 +203,7 @@ def onnx_export(
@torch.no_grad()
-def convert_diffuser(name: str, url: str, opset: int, half: bool, token: str):
+def convert_diffuser(name: str, url: str, opset: int, half: bool, token: str, single_vae: bool = False):
'''
From https://github.com/huggingface/diffusers/blob/main/scripts/convert_stable_diffusion_checkpoint_to_onnx.py
'''
@@ -212,6 +213,9 @@ def convert_diffuser(name: str, url: str, opset: int, half: bool, token: str):
# diffusers go into a directory rather than .onnx file
logger.info('converting Diffusers model: %s -> %s/', name, dest_path)
if single_vae:
logger.info('converting model with single VAE')
if path.isdir(dest_path):
logger.info('ONNX model already exists, skipping.')
return
@@ -295,50 +299,75 @@ def convert_diffuser(name: str, url: str, opset: int, half: bool, token: str):
)
del pipeline.unet
-# VAE ENCODER
-vae_encoder = pipeline.vae
-vae_in_channels = vae_encoder.config.in_channels
-vae_sample_size = vae_encoder.config.sample_size
-# need to get the raw tensor output (sample) from the encoder
-vae_encoder.forward = lambda sample, return_dict: vae_encoder.encode(
-sample, return_dict)[0].sample()
-onnx_export(
-vae_encoder,
-model_args=(
-torch.randn(1, vae_in_channels, vae_sample_size, vae_sample_size).to(
-device=training_device, dtype=dtype),
-False,
-),
-output_path=output_path / "vae_encoder" / "model.onnx",
-ordered_input_names=["sample", "return_dict"],
-output_names=["latent_sample"],
-dynamic_axes={
-"sample": {0: "batch", 1: "channels", 2: "height", 3: "width"},
-},
-opset=opset,
-)
+if single_vae:
+# SINGLE VAE
+vae_only = pipeline.vae
+vae_in_channels = vae_only.config.in_channels
+vae_sample_size = vae_only.config.sample_size
+# need to get the raw tensor output (sample) from the encoder
+vae_only.forward = lambda sample, return_dict: vae_only.encode(
+sample, return_dict)[0].sample()
+onnx_export(
+vae_only,
+model_args=(
+torch.randn(1, vae_in_channels, vae_sample_size, vae_sample_size).to(
+device=training_device, dtype=dtype),
+False,
+),
+output_path=output_path / "vae" / "model.onnx",
+ordered_input_names=["sample", "return_dict"],
+output_names=["latent_sample"],
+dynamic_axes={
+"sample": {0: "batch", 1: "channels", 2: "height", 3: "width"},
+},
+opset=opset,
+)
+else:
+# VAE ENCODER
+vae_encoder = pipeline.vae
+vae_in_channels = vae_encoder.config.in_channels
+vae_sample_size = vae_encoder.config.sample_size
+# need to get the raw tensor output (sample) from the encoder
+vae_encoder.forward = lambda sample, return_dict: vae_encoder.encode(
+sample, return_dict)[0].sample()
+onnx_export(
+vae_encoder,
+model_args=(
+torch.randn(1, vae_in_channels, vae_sample_size, vae_sample_size).to(
+device=training_device, dtype=dtype),
+False,
+),
+output_path=output_path / "vae_encoder" / "model.onnx",
+ordered_input_names=["sample", "return_dict"],
+output_names=["latent_sample"],
+dynamic_axes={
+"sample": {0: "batch", 1: "channels", 2: "height", 3: "width"},
+},
+opset=opset,
+)
-# VAE DECODER
-vae_decoder = pipeline.vae
-vae_latent_channels = vae_decoder.config.latent_channels
-vae_out_channels = vae_decoder.config.out_channels
-# forward only through the decoder part
-vae_decoder.forward = vae_encoder.decode
-onnx_export(
-vae_decoder,
-model_args=(
-torch.randn(1, vae_latent_channels, unet_sample_size, unet_sample_size).to(
-device=training_device, dtype=dtype),
-False,
-),
-output_path=output_path / "vae_decoder" / "model.onnx",
-ordered_input_names=["latent_sample", "return_dict"],
-output_names=["sample"],
-dynamic_axes={
-"latent_sample": {0: "batch", 1: "channels", 2: "height", 3: "width"},
-},
-opset=opset,
-)
+# VAE DECODER
+vae_decoder = pipeline.vae
+vae_latent_channels = vae_decoder.config.latent_channels
+vae_out_channels = vae_decoder.config.out_channels
+# forward only through the decoder part
+vae_decoder.forward = vae_encoder.decode
+onnx_export(
+vae_decoder,
+model_args=(
+torch.randn(1, vae_latent_channels, unet_sample_size, unet_sample_size).to(
+device=training_device, dtype=dtype),
+False,
+),
+output_path=output_path / "vae_decoder" / "model.onnx",
+ordered_input_names=["latent_sample", "return_dict"],
+output_names=["sample"],
+dynamic_axes={
+"latent_sample": {0: "batch", 1: "channels", 2: "height", 3: "width"},
+},
+opset=opset,
+)
del pipeline.vae
# SAFETY CHECKER
@@ -376,20 +405,32 @@ def convert_diffuser(name: str, url: str, opset: int, half: bool, token: str):
safety_checker = None
feature_extractor = None
-onnx_pipeline = OnnxStableDiffusionPipeline(
-vae_encoder=OnnxRuntimeModel.from_pretrained(
-output_path / "vae_encoder"),
-vae_decoder=OnnxRuntimeModel.from_pretrained(
-output_path / "vae_decoder"),
-text_encoder=OnnxRuntimeModel.from_pretrained(
-output_path / "text_encoder"),
-tokenizer=pipeline.tokenizer,
-unet=OnnxRuntimeModel.from_pretrained(output_path / "unet"),
-scheduler=pipeline.scheduler,
-safety_checker=safety_checker,
-feature_extractor=feature_extractor,
-requires_safety_checker=safety_checker is not None,
-)
+if single_vae:
+onnx_pipeline = StableDiffusionUpscalePipeline(
+vae=OnnxRuntimeModel.from_pretrained(
+output_path / "vae"),
+text_encoder=OnnxRuntimeModel.from_pretrained(
+output_path / "text_encoder"),
+tokenizer=pipeline.tokenizer,
+low_res_scheduler=pipeline.scheduler,
+unet=OnnxRuntimeModel.from_pretrained(output_path / "unet"),
+scheduler=pipeline.scheduler,
+)
+else:
+onnx_pipeline = OnnxStableDiffusionPipeline(
+vae_encoder=OnnxRuntimeModel.from_pretrained(
+output_path / "vae_encoder"),
+vae_decoder=OnnxRuntimeModel.from_pretrained(
+output_path / "vae_decoder"),
+text_encoder=OnnxRuntimeModel.from_pretrained(
+output_path / "text_encoder"),
+tokenizer=pipeline.tokenizer,
+unet=OnnxRuntimeModel.from_pretrained(output_path / "unet"),
+scheduler=pipeline.scheduler,
+safety_checker=safety_checker,
+feature_extractor=feature_extractor,
+requires_safety_checker=safety_checker is not None,
+)
logger.info('exporting ONNX model')
@@ -398,8 +439,15 @@ def convert_diffuser(name: str, url: str, opset: int, half: bool, token: str):
del pipeline
del onnx_pipeline
-_ = OnnxStableDiffusionPipeline.from_pretrained(
-output_path, provider="CPUExecutionProvider")
+if single_vae:
+_ = StableDiffusionUpscalePipeline.from_pretrained(
+output_path, provider="CPUExecutionProvider"
+)
+else:
+_ = OnnxStableDiffusionPipeline.from_pretrained(
+output_path, provider="CPUExecutionProvider")
logger.info("ONNX pipeline is loadable")
@@ -409,7 +457,8 @@ def load_models(args, models: Models):
if source[0] in args.skip:
logger.info('Skipping model: %s', source[0])
else:
-convert_diffuser(*source, args.opset, args.half, args.token)
+single_vae = 'upscaling' in source[0]
+convert_diffuser(*source, args.opset, args.half, args.token, single_vae=single_vae)
if args.upscaling:
for source in models.get('upscaling'):
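The single-VAE export path is chosen purely by naming convention: any diffusion model whose name contains 'upscaling' gets the combined VAE. A hypothetical models list illustrating the convention (these entries are examples, not the shipped defaults):

# each source is (name, url); names containing 'upscaling' take the
# single-VAE export path above
diffusion_models = [
    ('stable-diffusion-onnx-v1-5', 'runwayml/stable-diffusion-v1-5'),
    ('upscaling-stable-diffusion-x4', 'stabilityai/stable-diffusion-x4-upscaler'),
]

for source in diffusion_models:
    single_vae = 'upscaling' in source[0]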

View File

@@ -32,6 +32,9 @@ from .chain import (
correct_gfpgan,
persist_disk,
persist_s3,
reduce_thumbnail,
reduce_crop,
source_noise,
source_txt2img,
upscale_outpaint,
upscale_resrgan,
@@ -61,7 +64,6 @@ from .params import (
Border,
ImageParams,
Size,
SizeChart,
StageParams,
UpscaleParams,
)
@@ -129,6 +131,9 @@ chain_stages = {
'correct-gfpgan': correct_gfpgan,
'persist-disk': persist_disk,
'persist-s3': persist_s3,
'reduce-crop': reduce_crop,
'reduce-thumbnail': reduce_thumbnail,
'source-noise': source_noise,
'source-txt2img': source_txt2img,
'upscale-outpaint': upscale_outpaint,
'upscale-resrgan': upscale_resrgan,

View File

@@ -56,6 +56,9 @@
"spinalcase",
"stabilityai",
"stringcase",
"uncond",
"unet",
"untruncated",
"upsampler",
"upscaling",
"venv",