
feat(api): add reduce stages, noise source

Sean Sube 2023-01-29 15:23:01 -06:00
parent 8d346cbed0
commit c905fbb728
10 changed files with 229 additions and 80 deletions

View File

@@ -19,6 +19,15 @@ from .persist_disk import (
 from .persist_s3 import (
     persist_s3,
 )
+from .reduce_crop import (
+    reduce_crop,
+)
+from .reduce_thumbnail import (
+    reduce_thumbnail,
+)
+from .source_noise import (
+    source_noise,
+)
 from .source_txt2img import (
     source_txt2img,
 )

View File

@@ -70,11 +70,11 @@ class ChainPipeline:
             kwargs = stage_kwargs or {}
             kwargs = {**pipeline_kwargs, **kwargs}

-            logger.info('running stage %s on result image with dimensions %sx%s, %s',
+            logger.info('running stage %s on image with dimensions %sx%s, %s',
                         name, image.width, image.height, kwargs.keys())

             if image.width > stage_params.tile_size or image.height > stage_params.tile_size:
-                logger.info('source image larger than tile size of %s, tiling stage',
+                logger.info('image larger than tile size of %s, tiling stage',
                             stage_params.tile_size)

                 def stage_tile(tile: Image.Image, _dims) -> Image.Image:
@@ -89,7 +89,7 @@ class ChainPipeline:
                 image = process_tile_grid(
                     image, stage_params.tile_size, stage_params.outscale, [stage_tile])
             else:
-                logger.info('image within tile size, running stage')
+                logger.info('image within tile size, running stage')
                 image = stage_pipe(ctx, stage_params, params, image,
                                    **kwargs)
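
For context on the tiling branch above: process_tile_grid is defined elsewhere in onnx-web and only called here, so the sketch below is inferred from the call site process_tile_grid(image, tile_size, outscale, [stage_tile]) and the stage_tile callback signature. It is a minimal illustration of the technique, not the repo's implementation.

    # Minimal sketch of grid tiling, inferred from the call site above;
    # the real process_tile_grid lives elsewhere in onnx-web.
    from typing import Callable, List, Tuple

    from PIL import Image

    TileFilter = Callable[[Image.Image, Tuple[int, int]], Image.Image]


    def process_tile_grid_sketch(
        source: Image.Image,
        tile: int,
        scale: int,
        filters: List[TileFilter],
    ) -> Image.Image:
        width, height = source.size
        # the output image grows by the stage's outscale factor
        output = Image.new('RGB', (width * scale, height * scale))

        for top in range(0, height, tile):
            for left in range(0, width, tile):
                # cut out one tile, run each filter over it, then paste the
                # (possibly upscaled) result into the output grid
                image = source.crop((left, top, left + tile, top + tile))
                for image_filter in filters:
                    image = image_filter(image, (left, top))
                output.paste(image, (left * scale, top * scale))

        return output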

View File

@@ -1,7 +1,6 @@
 from logging import getLogger
 from PIL import Image
 from ..params import (
     ImageParams,
     StageParams,

View File

@@ -0,0 +1,30 @@
+from logging import getLogger
+
+from PIL import Image
+
+from ..params import (
+    ImageParams,
+    Size,
+    StageParams,
+)
+from ..utils import (
+    ServerContext,
+)
+
+logger = getLogger(__name__)
+
+
+def reduce_crop(
+    ctx: ServerContext,
+    _stage: StageParams,
+    _params: ImageParams,
+    source_image: Image.Image,
+    *,
+    origin: Size,
+    size: Size,
+    **kwargs,
+) -> Image.Image:
+    # PIL crop boxes are corner coordinates, so the far corner is origin plus size
+    image = source_image.crop(
+        (origin.width, origin.height, origin.width + size.width, origin.height + size.height))
+    logger.info('created crop with dimensions: %sx%s',
+                image.width, image.height)
+    return image
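
Worth noting for reduce_crop: PIL crop boxes are corner coordinates, (left, upper, right, lower), not an origin plus an extent, which is why the stage adds the origin to the size. A quick standalone check of that arithmetic:

    # PIL crop boxes are (left, upper, right, lower) pixel coordinates,
    # so a 512x512 region starting at (128, 64) ends at (640, 576).
    from PIL import Image

    source = Image.new('RGB', (1024, 768))
    left, top = 128, 64
    width, height = 512, 512

    cropped = source.crop((left, top, left + width, top + height))
    assert cropped.size == (512, 512)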

View File

@@ -0,0 +1,28 @@
+from logging import getLogger
+
+from PIL import Image
+
+from ..params import (
+    ImageParams,
+    Size,
+    StageParams,
+)
+from ..utils import (
+    ServerContext,
+)
+
+logger = getLogger(__name__)
+
+
+def reduce_thumbnail(
+    ctx: ServerContext,
+    _stage: StageParams,
+    _params: ImageParams,
+    source_image: Image.Image,
+    *,
+    size: Size,
+    **kwargs,
+) -> Image.Image:
+    # Image.thumbnail resizes in place and returns None, so work on a copy
+    image = source_image.copy()
+    image.thumbnail((size.width, size.height))
+    logger.info('created thumbnail with dimensions: %sx%s',
+                image.width, image.height)
+    return image
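
A detail that matters for reduce_thumbnail: Image.thumbnail resizes in place, returns None, and preserves aspect ratio, so it only guarantees the result fits inside the requested box; the logged dimensions can be smaller than requested on one axis. For example:

    # Image.thumbnail mutates its image, returns None, and keeps the
    # aspect ratio, so the result merely fits inside the requested box.
    from PIL import Image

    source = Image.new('RGB', (1024, 768))
    image = source.copy()
    image.thumbnail((256, 256))

    # 1024x768 shrinks to fit within 256x256 while keeping its 4:3 ratio
    assert image.size == (256, 192)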

View File

@@ -0,0 +1,38 @@
+from logging import getLogger
+from typing import Callable
+
+from PIL import Image
+
+from ..params import (
+    ImageParams,
+    Size,
+    StageParams,
+)
+from ..utils import (
+    ServerContext,
+)
+
+logger = getLogger(__name__)
+
+
+def source_noise(
+    ctx: ServerContext,
+    stage: StageParams,
+    params: ImageParams,
+    source_image: Image.Image,
+    *,
+    size: Size,
+    noise_source: Callable,
+    **kwargs,
+) -> Image.Image:
+    logger.info('generating image from noise source')
+
+    if source_image is not None:
+        logger.warning(
+            'a source image was passed to a noise stage, but will be discarded')
+
+    output = noise_source(source_image, (size.width, size.height), (0, 0))
+
+    logger.info('final output image size: %sx%s', output.width, output.height)
+    return output
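
The diff does not include any concrete noise_source implementation; from the call site, it is a callable taking the source image, the output dimensions, and an origin. A minimal uniform-noise sketch under those assumptions (the function name is mine):

    # Sketch of a noise_source callable matching the call
    # noise_source(source_image, (width, height), (x, y)); the real
    # implementations live elsewhere in onnx-web and may differ.
    from typing import Tuple

    import numpy as np
    from PIL import Image


    def uniform_noise_source(
        _source: Image.Image,
        dims: Tuple[int, int],
        _origin: Tuple[int, int],
    ) -> Image.Image:
        width, height = dims
        pixels = np.random.randint(0, 256, (height, width, 3), dtype=np.uint8)
        return Image.fromarray(pixels, 'RGB')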

View File

@@ -1,6 +1,4 @@
 from diffusers import (
-    AutoencoderKL,
-    DDPMScheduler,
     StableDiffusionUpscalePipeline,
 )
 from logging import getLogger
@@ -40,19 +38,9 @@ def load_stable_diffusion(ctx: ServerContext, upscale: UpscaleParams):
         return last_pipeline_instance

     if upscale.format == 'onnx':
-        # ValueError: Pipeline <class 'onnx_web.onnx.pipeline_onnx_stable_diffusion_upscale.OnnxStableDiffusionUpscalePipeline'>
-        # expected {'vae', 'unet', 'text_encoder', 'tokenizer', 'scheduler', 'low_res_scheduler'},
-        # but only {'scheduler', 'tokenizer', 'text_encoder', 'unet'} were passed.
-        pipeline = OnnxStableDiffusionUpscalePipeline.from_pretrained(
-            model_path,
-            vae=AutoencoderKL.from_pretrained(
-                model_path, subfolder='vae_encoder'),
-            low_res_scheduler=DDPMScheduler.from_pretrained(
-                model_path, subfolder='scheduler'),
-        )
+        pipeline = OnnxStableDiffusionUpscalePipeline.from_pretrained(model_path)
     else:
-        pipeline = StableDiffusionUpscalePipeline.from_pretrained(
-            'stabilityai/stable-diffusion-x4-upscaler')
+        pipeline = StableDiffusionUpscalePipeline.from_pretrained(model_path)

     last_pipeline_instance = pipeline
     last_pipeline_params = cache_params
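
The deleted workaround above existed because the ONNX export previously shipped only split vae_encoder/vae_decoder parts, while the x4 upscaler pipeline expects a single vae plus a low_res_scheduler; with the converter changes below, from_pretrained(model_path) can assemble the pipeline on its own. For reference, usage follows the standard diffusers StableDiffusionUpscalePipeline interface, which the ONNX variant here is assumed to mirror:

    # Rough usage, based on the standard diffusers upscale pipeline
    # interface; the file names and prompt are placeholders.
    from PIL import Image

    low_res = Image.open('input.png').convert('RGB').resize((128, 128))
    result = pipeline(
        prompt='a white cat',
        image=low_res,
        num_inference_steps=25,
    ).images[0]
    result.save('output.png')  # 4x the input resolution: 512x512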

View File

@@ -6,6 +6,7 @@ from diffusers import (
     OnnxRuntimeModel,
     OnnxStableDiffusionPipeline,
     StableDiffusionPipeline,
+    StableDiffusionUpscalePipeline,
 )
 from logging import getLogger
 from onnx import load, save_model
@@ -202,7 +203,7 @@ def onnx_export(
 @torch.no_grad()
-def convert_diffuser(name: str, url: str, opset: int, half: bool, token: str):
+def convert_diffuser(name: str, url: str, opset: int, half: bool, token: str, single_vae: bool = False):
     '''
     From https://github.com/huggingface/diffusers/blob/main/scripts/convert_stable_diffusion_checkpoint_to_onnx.py
     '''
@@ -212,6 +213,9 @@ def convert_diffuser(name: str, url: str, opset: int, half: bool, token: str):
     # diffusers go into a directory rather than .onnx file
     logger.info('converting Diffusers model: %s -> %s/', name, dest_path)

+    if single_vae:
+        logger.info('converting model with single VAE')
+
     if path.isdir(dest_path):
         logger.info('ONNX model already exists, skipping.')
         return
@@ -295,50 +299,75 @@ def convert_diffuser(name: str, url: str, opset: int, half: bool, token: str):
     )
     del pipeline.unet

-    # VAE ENCODER
-    vae_encoder = pipeline.vae
-    vae_in_channels = vae_encoder.config.in_channels
-    vae_sample_size = vae_encoder.config.sample_size
-    # need to get the raw tensor output (sample) from the encoder
-    vae_encoder.forward = lambda sample, return_dict: vae_encoder.encode(
-        sample, return_dict)[0].sample()
-    onnx_export(
-        vae_encoder,
-        model_args=(
-            torch.randn(1, vae_in_channels, vae_sample_size, vae_sample_size).to(
-                device=training_device, dtype=dtype),
-            False,
-        ),
-        output_path=output_path / "vae_encoder" / "model.onnx",
-        ordered_input_names=["sample", "return_dict"],
-        output_names=["latent_sample"],
-        dynamic_axes={
-            "sample": {0: "batch", 1: "channels", 2: "height", 3: "width"},
-        },
-        opset=opset,
-    )
-
-    # VAE DECODER
-    vae_decoder = pipeline.vae
-    vae_latent_channels = vae_decoder.config.latent_channels
-    vae_out_channels = vae_decoder.config.out_channels
-    # forward only through the decoder part
-    vae_decoder.forward = vae_encoder.decode
-    onnx_export(
-        vae_decoder,
-        model_args=(
-            torch.randn(1, vae_latent_channels, unet_sample_size, unet_sample_size).to(
-                device=training_device, dtype=dtype),
-            False,
-        ),
-        output_path=output_path / "vae_decoder" / "model.onnx",
-        ordered_input_names=["latent_sample", "return_dict"],
-        output_names=["sample"],
-        dynamic_axes={
-            "latent_sample": {0: "batch", 1: "channels", 2: "height", 3: "width"},
-        },
-        opset=opset,
-    )
+    if single_vae:
+        # SINGLE VAE
+        vae_only = pipeline.vae
+        vae_in_channels = vae_only.config.in_channels
+        vae_sample_size = vae_only.config.sample_size
+        # need to get the raw tensor output (sample) from the encoder
+        vae_only.forward = lambda sample, return_dict: vae_only.encode(
+            sample, return_dict)[0].sample()
+        onnx_export(
+            vae_only,
+            model_args=(
+                torch.randn(1, vae_in_channels, vae_sample_size, vae_sample_size).to(
+                    device=training_device, dtype=dtype),
+                False,
+            ),
+            output_path=output_path / "vae" / "model.onnx",
+            ordered_input_names=["sample", "return_dict"],
+            output_names=["latent_sample"],
+            dynamic_axes={
+                "sample": {0: "batch", 1: "channels", 2: "height", 3: "width"},
+            },
+            opset=opset,
+        )
+    else:
+        # VAE ENCODER
+        vae_encoder = pipeline.vae
+        vae_in_channels = vae_encoder.config.in_channels
+        vae_sample_size = vae_encoder.config.sample_size
+        # need to get the raw tensor output (sample) from the encoder
+        vae_encoder.forward = lambda sample, return_dict: vae_encoder.encode(
+            sample, return_dict)[0].sample()
+        onnx_export(
+            vae_encoder,
+            model_args=(
+                torch.randn(1, vae_in_channels, vae_sample_size, vae_sample_size).to(
+                    device=training_device, dtype=dtype),
+                False,
+            ),
+            output_path=output_path / "vae_encoder" / "model.onnx",
+            ordered_input_names=["sample", "return_dict"],
+            output_names=["latent_sample"],
+            dynamic_axes={
+                "sample": {0: "batch", 1: "channels", 2: "height", 3: "width"},
+            },
+            opset=opset,
+        )

+        # VAE DECODER
+        vae_decoder = pipeline.vae
+        vae_latent_channels = vae_decoder.config.latent_channels
+        vae_out_channels = vae_decoder.config.out_channels
+        # forward only through the decoder part
+        vae_decoder.forward = vae_encoder.decode
+        onnx_export(
+            vae_decoder,
+            model_args=(
+                torch.randn(1, vae_latent_channels, unet_sample_size, unet_sample_size).to(
+                    device=training_device, dtype=dtype),
+                False,
+            ),
+            output_path=output_path / "vae_decoder" / "model.onnx",
+            ordered_input_names=["latent_sample", "return_dict"],
+            output_names=["sample"],
+            dynamic_axes={
+                "latent_sample": {0: "batch", 1: "channels", 2: "height", 3: "width"},
+            },
+            opset=opset,
+        )

     del pipeline.vae

     # SAFETY CHECKER
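
onnx_export is this file's own wrapper; its definition sits above this hunk and appears here only in a hunk header. Judging by the arguments, it forwards to torch.onnx.export, where dynamic_axes marks the input dimensions that stay symbolic so one exported VAE handles any batch size and resolution. A sketch of the presumed mapping:

    # Sketch of what the onnx_export wrapper presumably forwards to, based
    # on the standard torch.onnx.export interface rather than on this
    # file's actual definition.
    from pathlib import Path

    import torch


    def onnx_export_sketch(model, model_args, output_path: Path,
                           ordered_input_names, output_names,
                           dynamic_axes, opset: int):
        output_path.parent.mkdir(parents=True, exist_ok=True)
        torch.onnx.export(
            model,
            model_args,
            f=output_path.as_posix(),
            input_names=ordered_input_names,
            output_names=output_names,
            # axes listed here stay symbolic in the exported graph
            dynamic_axes=dynamic_axes,
            opset_version=opset,
            do_constant_folding=True,
        )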
@@ -376,20 +405,32 @@ def convert_diffuser(name: str, url: str, opset: int, half: bool, token: str):
         safety_checker = None
         feature_extractor = None

-    onnx_pipeline = OnnxStableDiffusionPipeline(
-        vae_encoder=OnnxRuntimeModel.from_pretrained(
-            output_path / "vae_encoder"),
-        vae_decoder=OnnxRuntimeModel.from_pretrained(
-            output_path / "vae_decoder"),
-        text_encoder=OnnxRuntimeModel.from_pretrained(
-            output_path / "text_encoder"),
-        tokenizer=pipeline.tokenizer,
-        unet=OnnxRuntimeModel.from_pretrained(output_path / "unet"),
-        scheduler=pipeline.scheduler,
-        safety_checker=safety_checker,
-        feature_extractor=feature_extractor,
-        requires_safety_checker=safety_checker is not None,
-    )
+    if single_vae:
+        onnx_pipeline = StableDiffusionUpscalePipeline(
+            vae=OnnxRuntimeModel.from_pretrained(
+                output_path / "vae"),
+            text_encoder=OnnxRuntimeModel.from_pretrained(
+                output_path / "text_encoder"),
+            tokenizer=pipeline.tokenizer,
+            low_res_scheduler=pipeline.scheduler,
+            unet=OnnxRuntimeModel.from_pretrained(output_path / "unet"),
+            scheduler=pipeline.scheduler,
+        )
+    else:
+        onnx_pipeline = OnnxStableDiffusionPipeline(
+            vae_encoder=OnnxRuntimeModel.from_pretrained(
+                output_path / "vae_encoder"),
+            vae_decoder=OnnxRuntimeModel.from_pretrained(
+                output_path / "vae_decoder"),
+            text_encoder=OnnxRuntimeModel.from_pretrained(
+                output_path / "text_encoder"),
+            tokenizer=pipeline.tokenizer,
+            unet=OnnxRuntimeModel.from_pretrained(output_path / "unet"),
+            scheduler=pipeline.scheduler,
+            safety_checker=safety_checker,
+            feature_extractor=feature_extractor,
+            requires_safety_checker=safety_checker is not None,
+        )

     logger.info('exporting ONNX model')
@@ -398,8 +439,15 @@ def convert_diffuser(name: str, url: str, opset: int, half: bool, token: str):
     del pipeline
     del onnx_pipeline

-    _ = OnnxStableDiffusionPipeline.from_pretrained(
-        output_path, provider="CPUExecutionProvider")
+    if single_vae:
+        _ = StableDiffusionUpscalePipeline.from_pretrained(
+            output_path, provider="CPUExecutionProvider"
+        )
+    else:
+        _ = OnnxStableDiffusionPipeline.from_pretrained(
+            output_path, provider="CPUExecutionProvider")
+
     logger.info("ONNX pipeline is loadable")
@@ -409,7 +457,8 @@ def load_models(args, models: Models):
             if source[0] in args.skip:
                 logger.info('Skipping model: %s', source[0])
             else:
-                convert_diffuser(*source, args.opset, args.half, args.token)
+                single_vae = 'upscaling' in source[0]
+                convert_diffuser(*source, args.opset, args.half, args.token, single_vae=single_vae)

     if args.upscaling:
         for source in models.get('upscaling'):
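
load_models derives single_vae from the model name, so any diffusion entry whose name contains 'upscaling' gets the combined-VAE export. A direct call with the new flag would look roughly like this; the name is a hypothetical example, the URL is the upscaler checkpoint referenced elsewhere in this commit, and the remaining values are placeholders:

    # Hypothetical direct invocation of the converter with the new flag;
    # opset, half, and token values here are placeholders.
    convert_diffuser(
        'upscaling-stable-diffusion-x4',             # name containing 'upscaling'
        'stabilityai/stable-diffusion-x4-upscaler',  # source checkpoint
        14,                                          # ONNX opset
        False,                                       # half precision
        None,                                        # HF access token
        single_vae=True,
    )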

View File

@@ -32,6 +32,9 @@ from .chain import (
     correct_gfpgan,
     persist_disk,
     persist_s3,
+    reduce_thumbnail,
+    reduce_crop,
+    source_noise,
     source_txt2img,
     upscale_outpaint,
     upscale_resrgan,
@@ -61,7 +64,6 @@ from .params import (
     Border,
     ImageParams,
     Size,
-    SizeChart,
     StageParams,
     UpscaleParams,
 )
@@ -129,6 +131,9 @@ chain_stages = {
     'correct-gfpgan': correct_gfpgan,
     'persist-disk': persist_disk,
     'persist-s3': persist_s3,
+    'reduce-crop': reduce_crop,
+    'reduce-thumbnail': reduce_thumbnail,
+    'source-noise': source_noise,
     'source-txt2img': source_txt2img,
     'upscale-outpaint': upscale_outpaint,
     'upscale-resrgan': upscale_resrgan,
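
chain_stages maps hyphenated stage names to their implementations, so the three new stages become addressable from chain pipeline definitions. A sketch of a definition that exercises them follows; the exact schema the server accepts is not shown in this diff, so treat the shape as illustrative:

    # Illustrative chain definition using the new stage names; the exact
    # request schema is not part of this diff.
    chain = {
        'stages': [
            {
                'name': 'noise',
                'type': 'source-noise',
                'params': {'width': 512, 'height': 512},
            },
            {
                'name': 'shrink',
                'type': 'reduce-thumbnail',
                'params': {'width': 256, 'height': 256},
            },
        ],
    }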

View File

@@ -56,6 +56,9 @@
     "spinalcase",
     "stabilityai",
     "stringcase",
+    "uncond",
+    "unet",
+    "untruncated",
     "upsampler",
     "upscaling",
     "venv",