feat(api): add reduce stages, noise source
This commit is contained in:
parent
8d346cbed0
commit
c905fbb728
|
@ -19,6 +19,15 @@ from .persist_disk import (
|
|||
from .persist_s3 import (
|
||||
persist_s3,
|
||||
)
|
||||
from .reduce_crop import (
|
||||
reduce_crop,
|
||||
)
|
||||
from .reduce_thumbnail import (
|
||||
reduce_thumbnail,
|
||||
)
|
||||
from .source_noise import (
|
||||
source_noise,
|
||||
)
|
||||
from .source_txt2img import (
|
||||
source_txt2img,
|
||||
)
|
||||
|
|
|
@ -70,11 +70,11 @@ class ChainPipeline:
|
|||
kwargs = stage_kwargs or {}
|
||||
kwargs = {**pipeline_kwargs, **kwargs}
|
||||
|
||||
logger.info('running stage %s on result image with dimensions %sx%s, %s',
|
||||
logger.info('running stage %s on image with dimensions %sx%s, %s',
|
||||
name, image.width, image.height, kwargs.keys())
|
||||
|
||||
if image.width > stage_params.tile_size or image.height > stage_params.tile_size:
|
||||
logger.info('source image larger than tile size of %s, tiling stage',
|
||||
logger.info('image larger than tile size of %s, tiling stage',
|
||||
stage_params.tile_size)
|
||||
|
||||
def stage_tile(tile: Image.Image, _dims) -> Image.Image:
|
||||
|
@ -89,7 +89,7 @@ class ChainPipeline:
|
|||
image = process_tile_grid(
|
||||
image, stage_params.tile_size, stage_params.outscale, [stage_tile])
|
||||
else:
|
||||
logger.info('source image within tile size, running stage')
|
||||
logger.info('image within tile size, running stage')
|
||||
image = stage_pipe(ctx, stage_params, params, image,
|
||||
**kwargs)
|
||||
|
||||
|
|
|
@ -1,7 +1,6 @@
|
|||
from logging import getLogger
|
||||
from PIL import Image
|
||||
|
||||
|
||||
from ..params import (
|
||||
ImageParams,
|
||||
StageParams,
|
||||
|
|
|
@ -0,0 +1,30 @@
|
|||
from logging import getLogger
|
||||
from PIL import Image
|
||||
|
||||
from ..params import (
|
||||
ImageParams,
|
||||
Size,
|
||||
StageParams,
|
||||
)
|
||||
from ..utils import (
|
||||
ServerContext,
|
||||
)
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
|
||||
def reduce_crop(
|
||||
ctx: ServerContext,
|
||||
_stage: StageParams,
|
||||
_params: ImageParams,
|
||||
source_image: Image.Image,
|
||||
*,
|
||||
origin: Size,
|
||||
size: Size,
|
||||
**kwargs,
|
||||
) -> Image.Image:
|
||||
image = source_image.crop(
|
||||
(origin.width, origin.height, size.width, size.height))
|
||||
logger.info('created thumbnail with dimensions: %sx%s',
|
||||
image.width, image.height)
|
||||
return image
|
|
@ -0,0 +1,28 @@
|
|||
from logging import getLogger
|
||||
from PIL import Image
|
||||
|
||||
from ..params import (
|
||||
ImageParams,
|
||||
Size,
|
||||
StageParams,
|
||||
)
|
||||
from ..utils import (
|
||||
ServerContext,
|
||||
)
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
|
||||
def reduce_thumbnail(
|
||||
ctx: ServerContext,
|
||||
_stage: StageParams,
|
||||
_params: ImageParams,
|
||||
source_image: Image.Image,
|
||||
*,
|
||||
size: Size,
|
||||
**kwargs,
|
||||
) -> Image.Image:
|
||||
image = source_image.thumbnail((size.width, size.height))
|
||||
logger.info('created thumbnail with dimensions: %sx%s',
|
||||
image.width, image.height)
|
||||
return image
|
|
@ -0,0 +1,38 @@
|
|||
from logging import getLogger
|
||||
from PIL import Image
|
||||
from typing import Callable
|
||||
|
||||
from ..params import (
|
||||
ImageParams,
|
||||
Size,
|
||||
StageParams,
|
||||
)
|
||||
from ..utils import (
|
||||
ServerContext,
|
||||
)
|
||||
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
|
||||
def source_noise(
|
||||
ctx: ServerContext,
|
||||
stage: StageParams,
|
||||
params: ImageParams,
|
||||
source_image: Image.Image,
|
||||
*,
|
||||
size: Size,
|
||||
noise_source: Callable,
|
||||
**kwargs,
|
||||
) -> Image.Image:
|
||||
prompt = prompt or params.prompt
|
||||
logger.info('generating image from noise source')
|
||||
|
||||
if source_image is not None:
|
||||
logger.warn(
|
||||
'a source image was passed to a noise stage, but will be discarded')
|
||||
|
||||
output = noise_source(source_image, (size.width, size.height), (0, 0))
|
||||
|
||||
logger.info('final output image size: %sx%s', output.width, output.height)
|
||||
return output
|
|
@ -1,6 +1,4 @@
|
|||
from diffusers import (
|
||||
AutoencoderKL,
|
||||
DDPMScheduler,
|
||||
StableDiffusionUpscalePipeline,
|
||||
)
|
||||
from logging import getLogger
|
||||
|
@ -40,19 +38,9 @@ def load_stable_diffusion(ctx: ServerContext, upscale: UpscaleParams):
|
|||
return last_pipeline_instance
|
||||
|
||||
if upscale.format == 'onnx':
|
||||
# ValueError: Pipeline <class 'onnx_web.onnx.pipeline_onnx_stable_diffusion_upscale.OnnxStableDiffusionUpscalePipeline'>
|
||||
# expected {'vae', 'unet', 'text_encoder', 'tokenizer', 'scheduler', 'low_res_scheduler'},
|
||||
# but only {'scheduler', 'tokenizer', 'text_encoder', 'unet'} were passed.
|
||||
pipeline = OnnxStableDiffusionUpscalePipeline.from_pretrained(
|
||||
model_path,
|
||||
vae=AutoencoderKL.from_pretrained(
|
||||
model_path, subfolder='vae_encoder'),
|
||||
low_res_scheduler=DDPMScheduler.from_pretrained(
|
||||
model_path, subfolder='scheduler'),
|
||||
)
|
||||
pipeline = OnnxStableDiffusionUpscalePipeline.from_pretrained(model_path)
|
||||
else:
|
||||
pipeline = StableDiffusionUpscalePipeline.from_pretrained(
|
||||
'stabilityai/stable-diffusion-x4-upscaler')
|
||||
pipeline = StableDiffusionUpscalePipeline.from_pretrained(model_path)
|
||||
|
||||
last_pipeline_instance = pipeline
|
||||
last_pipeline_params = cache_params
|
||||
|
|
|
@ -6,6 +6,7 @@ from diffusers import (
|
|||
OnnxRuntimeModel,
|
||||
OnnxStableDiffusionPipeline,
|
||||
StableDiffusionPipeline,
|
||||
StableDiffusionUpscalePipeline,
|
||||
)
|
||||
from logging import getLogger
|
||||
from onnx import load, save_model
|
||||
|
@ -202,7 +203,7 @@ def onnx_export(
|
|||
|
||||
|
||||
@torch.no_grad()
|
||||
def convert_diffuser(name: str, url: str, opset: int, half: bool, token: str):
|
||||
def convert_diffuser(name: str, url: str, opset: int, half: bool, token: str, single_vae: bool = False):
|
||||
'''
|
||||
From https://github.com/huggingface/diffusers/blob/main/scripts/convert_stable_diffusion_checkpoint_to_onnx.py
|
||||
'''
|
||||
|
@ -212,6 +213,9 @@ def convert_diffuser(name: str, url: str, opset: int, half: bool, token: str):
|
|||
# diffusers go into a directory rather than .onnx file
|
||||
logger.info('converting Diffusers model: %s -> %s/', name, dest_path)
|
||||
|
||||
if single_vae:
|
||||
logger.info('converting model with single VAE')
|
||||
|
||||
if path.isdir(dest_path):
|
||||
logger.info('ONNX model already exists, skipping.')
|
||||
return
|
||||
|
@ -295,50 +299,75 @@ def convert_diffuser(name: str, url: str, opset: int, half: bool, token: str):
|
|||
)
|
||||
del pipeline.unet
|
||||
|
||||
# VAE ENCODER
|
||||
vae_encoder = pipeline.vae
|
||||
vae_in_channels = vae_encoder.config.in_channels
|
||||
vae_sample_size = vae_encoder.config.sample_size
|
||||
# need to get the raw tensor output (sample) from the encoder
|
||||
vae_encoder.forward = lambda sample, return_dict: vae_encoder.encode(
|
||||
sample, return_dict)[0].sample()
|
||||
onnx_export(
|
||||
vae_encoder,
|
||||
model_args=(
|
||||
torch.randn(1, vae_in_channels, vae_sample_size, vae_sample_size).to(
|
||||
device=training_device, dtype=dtype),
|
||||
False,
|
||||
),
|
||||
output_path=output_path / "vae_encoder" / "model.onnx",
|
||||
ordered_input_names=["sample", "return_dict"],
|
||||
output_names=["latent_sample"],
|
||||
dynamic_axes={
|
||||
"sample": {0: "batch", 1: "channels", 2: "height", 3: "width"},
|
||||
},
|
||||
opset=opset,
|
||||
)
|
||||
if single_vae:
|
||||
# SINGLE VAE
|
||||
vae_only = pipeline.vae
|
||||
vae_in_channels = vae_only.config.in_channels
|
||||
vae_sample_size = vae_only.config.sample_size
|
||||
# need to get the raw tensor output (sample) from the encoder
|
||||
vae_only.forward = lambda sample, return_dict: vae_only.encode(
|
||||
sample, return_dict)[0].sample()
|
||||
onnx_export(
|
||||
vae_only,
|
||||
model_args=(
|
||||
torch.randn(1, vae_in_channels, vae_sample_size, vae_sample_size).to(
|
||||
device=training_device, dtype=dtype),
|
||||
False,
|
||||
),
|
||||
output_path=output_path / "vae" / "model.onnx",
|
||||
ordered_input_names=["sample", "return_dict"],
|
||||
output_names=["latent_sample"],
|
||||
dynamic_axes={
|
||||
"sample": {0: "batch", 1: "channels", 2: "height", 3: "width"},
|
||||
},
|
||||
opset=opset,
|
||||
)
|
||||
else:
|
||||
# VAE ENCODER
|
||||
vae_encoder = pipeline.vae
|
||||
vae_in_channels = vae_encoder.config.in_channels
|
||||
vae_sample_size = vae_encoder.config.sample_size
|
||||
# need to get the raw tensor output (sample) from the encoder
|
||||
vae_encoder.forward = lambda sample, return_dict: vae_encoder.encode(
|
||||
sample, return_dict)[0].sample()
|
||||
onnx_export(
|
||||
vae_encoder,
|
||||
model_args=(
|
||||
torch.randn(1, vae_in_channels, vae_sample_size, vae_sample_size).to(
|
||||
device=training_device, dtype=dtype),
|
||||
False,
|
||||
),
|
||||
output_path=output_path / "vae_encoder" / "model.onnx",
|
||||
ordered_input_names=["sample", "return_dict"],
|
||||
output_names=["latent_sample"],
|
||||
dynamic_axes={
|
||||
"sample": {0: "batch", 1: "channels", 2: "height", 3: "width"},
|
||||
},
|
||||
opset=opset,
|
||||
)
|
||||
|
||||
# VAE DECODER
|
||||
vae_decoder = pipeline.vae
|
||||
vae_latent_channels = vae_decoder.config.latent_channels
|
||||
vae_out_channels = vae_decoder.config.out_channels
|
||||
# forward only through the decoder part
|
||||
vae_decoder.forward = vae_encoder.decode
|
||||
onnx_export(
|
||||
vae_decoder,
|
||||
model_args=(
|
||||
torch.randn(1, vae_latent_channels, unet_sample_size, unet_sample_size).to(
|
||||
device=training_device, dtype=dtype),
|
||||
False,
|
||||
),
|
||||
output_path=output_path / "vae_decoder" / "model.onnx",
|
||||
ordered_input_names=["latent_sample", "return_dict"],
|
||||
output_names=["sample"],
|
||||
dynamic_axes={
|
||||
"latent_sample": {0: "batch", 1: "channels", 2: "height", 3: "width"},
|
||||
},
|
||||
opset=opset,
|
||||
)
|
||||
|
||||
# VAE DECODER
|
||||
vae_decoder = pipeline.vae
|
||||
vae_latent_channels = vae_decoder.config.latent_channels
|
||||
vae_out_channels = vae_decoder.config.out_channels
|
||||
# forward only through the decoder part
|
||||
vae_decoder.forward = vae_encoder.decode
|
||||
onnx_export(
|
||||
vae_decoder,
|
||||
model_args=(
|
||||
torch.randn(1, vae_latent_channels, unet_sample_size, unet_sample_size).to(
|
||||
device=training_device, dtype=dtype),
|
||||
False,
|
||||
),
|
||||
output_path=output_path / "vae_decoder" / "model.onnx",
|
||||
ordered_input_names=["latent_sample", "return_dict"],
|
||||
output_names=["sample"],
|
||||
dynamic_axes={
|
||||
"latent_sample": {0: "batch", 1: "channels", 2: "height", 3: "width"},
|
||||
},
|
||||
opset=opset,
|
||||
)
|
||||
del pipeline.vae
|
||||
|
||||
# SAFETY CHECKER
|
||||
|
@ -376,20 +405,32 @@ def convert_diffuser(name: str, url: str, opset: int, half: bool, token: str):
|
|||
safety_checker = None
|
||||
feature_extractor = None
|
||||
|
||||
onnx_pipeline = OnnxStableDiffusionPipeline(
|
||||
vae_encoder=OnnxRuntimeModel.from_pretrained(
|
||||
output_path / "vae_encoder"),
|
||||
vae_decoder=OnnxRuntimeModel.from_pretrained(
|
||||
output_path / "vae_decoder"),
|
||||
text_encoder=OnnxRuntimeModel.from_pretrained(
|
||||
output_path / "text_encoder"),
|
||||
tokenizer=pipeline.tokenizer,
|
||||
unet=OnnxRuntimeModel.from_pretrained(output_path / "unet"),
|
||||
scheduler=pipeline.scheduler,
|
||||
safety_checker=safety_checker,
|
||||
feature_extractor=feature_extractor,
|
||||
requires_safety_checker=safety_checker is not None,
|
||||
)
|
||||
if single_vae:
|
||||
onnx_pipeline = StableDiffusionUpscalePipeline(
|
||||
vae=OnnxRuntimeModel.from_pretrained(
|
||||
output_path / "vae"),
|
||||
text_encoder=OnnxRuntimeModel.from_pretrained(
|
||||
output_path / "text_encoder"),
|
||||
tokenizer=pipeline.tokenizer,
|
||||
low_res_scheduler=pipeline.scheduler,
|
||||
unet=OnnxRuntimeModel.from_pretrained(output_path / "unet"),
|
||||
scheduler=pipeline.scheduler,
|
||||
)
|
||||
else:
|
||||
onnx_pipeline = OnnxStableDiffusionPipeline(
|
||||
vae_encoder=OnnxRuntimeModel.from_pretrained(
|
||||
output_path / "vae_encoder"),
|
||||
vae_decoder=OnnxRuntimeModel.from_pretrained(
|
||||
output_path / "vae_decoder"),
|
||||
text_encoder=OnnxRuntimeModel.from_pretrained(
|
||||
output_path / "text_encoder"),
|
||||
tokenizer=pipeline.tokenizer,
|
||||
unet=OnnxRuntimeModel.from_pretrained(output_path / "unet"),
|
||||
scheduler=pipeline.scheduler,
|
||||
safety_checker=safety_checker,
|
||||
feature_extractor=feature_extractor,
|
||||
requires_safety_checker=safety_checker is not None,
|
||||
)
|
||||
|
||||
logger.info('exporting ONNX model')
|
||||
|
||||
|
@ -398,8 +439,15 @@ def convert_diffuser(name: str, url: str, opset: int, half: bool, token: str):
|
|||
|
||||
del pipeline
|
||||
del onnx_pipeline
|
||||
_ = OnnxStableDiffusionPipeline.from_pretrained(
|
||||
output_path, provider="CPUExecutionProvider")
|
||||
|
||||
if single_vae:
|
||||
_ = StableDiffusionUpscalePipeline.from_pretrained(
|
||||
output_path, provider="CPUExecutionProvider"
|
||||
)
|
||||
else:
|
||||
_ = OnnxStableDiffusionPipeline.from_pretrained(
|
||||
output_path, provider="CPUExecutionProvider")
|
||||
|
||||
logger.info("ONNX pipeline is loadable")
|
||||
|
||||
|
||||
|
@ -409,7 +457,8 @@ def load_models(args, models: Models):
|
|||
if source[0] in args.skip:
|
||||
logger.info('Skipping model: %s', source[0])
|
||||
else:
|
||||
convert_diffuser(*source, args.opset, args.half, args.token)
|
||||
single_vae = 'upscaling' in source[0]
|
||||
convert_diffuser(*source, args.opset, args.half, args.token, single_vae=single_vae)
|
||||
|
||||
if args.upscaling:
|
||||
for source in models.get('upscaling'):
|
||||
|
|
|
@ -32,6 +32,9 @@ from .chain import (
|
|||
correct_gfpgan,
|
||||
persist_disk,
|
||||
persist_s3,
|
||||
reduce_thumbnail,
|
||||
reduce_crop,
|
||||
source_noise,
|
||||
source_txt2img,
|
||||
upscale_outpaint,
|
||||
upscale_resrgan,
|
||||
|
@ -61,7 +64,6 @@ from .params import (
|
|||
Border,
|
||||
ImageParams,
|
||||
Size,
|
||||
SizeChart,
|
||||
StageParams,
|
||||
UpscaleParams,
|
||||
)
|
||||
|
@ -129,6 +131,9 @@ chain_stages = {
|
|||
'correct-gfpgan': correct_gfpgan,
|
||||
'persist-disk': persist_disk,
|
||||
'persist-s3': persist_s3,
|
||||
'reduce-crop': reduce_crop,
|
||||
'reduce-thumbnail': reduce_thumbnail,
|
||||
'source-noise': source_noise,
|
||||
'source-txt2img': source_txt2img,
|
||||
'upscale-outpaint': upscale_outpaint,
|
||||
'upscale-resrgan': upscale_resrgan,
|
||||
|
|
|
@ -56,6 +56,9 @@
|
|||
"spinalcase",
|
||||
"stabilityai",
|
||||
"stringcase",
|
||||
"uncond",
|
||||
"unet",
|
||||
"untruncated",
|
||||
"upsampler",
|
||||
"upscaling",
|
||||
"venv",
|
||||
|
|
Loading…
Reference in New Issue