
testing ONNX upscaling

Sean Sube 2023-01-29 18:42:05 -06:00
parent ca613cabe1
commit bacce0ace9
6 changed files with 58 additions and 40 deletions

View File

@@ -63,11 +63,7 @@ def upscale_stable_diffusion(
     logger.info('upscaling with Stable Diffusion, %s steps: %s', params.steps, prompt)
     pipeline = load_stable_diffusion(ctx, upscale)
-    if upscale.format == 'onnx':
-        generator = np.random.default_rng(params.seed)
-    else:
-        generator = torch.manual_seed(params.seed)
+    generator = torch.manual_seed(params.seed)
 
     return pipeline(
         params.prompt,
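A note on the change above: torch.manual_seed seeds the global RNG and returns a torch.Generator, so one object now serves both the torch and ONNX paths, and the reworked ONNX pipeline below draws its noise with torch.randn. A minimal sketch of the reproducibility this buys, with illustrative shapes not taken from the code:

    import torch

    # manual_seed returns a torch.Generator tied to the seed
    generator = torch.manual_seed(42)
    noise_a = torch.randn((1, 3, 64, 64), generator=generator)

    # reseeding replays the identical noise, keeping upscaling runs deterministic
    generator = torch.manual_seed(42)
    noise_b = torch.randn((1, 3, 64, 64), generator=generator)
    assert torch.equal(noise_a, noise_b)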

View File

@@ -253,7 +253,17 @@ def convert_diffuser(name: str, url: str, opset: int, half: bool, token: str, si
     )
     del pipeline.text_encoder
 
+    logger.info('UNET config: %s', pipeline.unet.config)
+
     # UNET
+    if single_vae:
+        unet_inputs = ["sample", "timestep", "encoder_hidden_states", "class_labels"]
+        # unet_inputs = ["latent_model_input", "timestep", "encoder_hidden_states", "class_labels"]
+        unet_scale = 4
+    else:
+        unet_inputs = ["sample", "timestep", "encoder_hidden_states", "return_dict"]
+        unet_scale = False
+
     unet_in_channels = pipeline.unet.config.in_channels
     unet_sample_size = pipeline.unet.config.sample_size
     unet_path = output_path / "unet" / "model.onnx"
@@ -265,13 +275,10 @@ def convert_diffuser(name: str, url: str, opset: int, half: bool, token: str, si
             torch.randn(2).to(device=training_device, dtype=dtype),
             torch.randn(2, num_tokens, text_hidden_size).to(
                 device=training_device, dtype=dtype),
-            # TODO: needs to be Int or Long for upscaling, Bool for regular
-            4,
-            # False,
+            unet_scale,
         ),
         output_path=unet_path,
-        ordered_input_names=["sample", "timestep",
-                             "encoder_hidden_states", "return_dict"],
+        ordered_input_names=unet_inputs,
         # has to be different from "sample" for correct tracing
         output_names=["out_sample"],
         dynamic_axes={
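The dummy arguments passed to onnx_export line up positionally with unet_inputs: for the single-VAE (x4 upscaler) export the fourth argument is class_labels, traced as the integer 4 per the old TODO about needing an Int or Long, while the regular UNet receives return_dict=False in that slot. A hedged way to confirm which variant landed on disk is to list the graph inputs with onnxruntime; since return_dict is a plain Python bool it is typically folded away during tracing, so only the upscaler export should report a fourth input:

    import onnxruntime

    # path as built above: output_path / "unet" / "model.onnx"
    session = onnxruntime.InferenceSession("unet/model.onnx")
    for graph_input in session.get_inputs():
        # expect sample, timestep, encoder_hidden_states, and, for the
        # upscaler export only, class_labels
        print(graph_input.name, graph_input.type, graph_input.shape)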
@@ -300,25 +307,26 @@ def convert_diffuser(name: str, url: str, opset: int, half: bool, token: str, si
     del pipeline.unet
 
     if single_vae:
+        logger.info('VAE config: %s', pipeline.vae.config)
         # SINGLE VAE
         vae_only = pipeline.vae
-        vae_in_channels = vae_only.config.in_channels
-        vae_sample_size = vae_only.config.sample_size
-        # need to get the raw tensor output (sample) from the encoder
-        vae_only.forward = lambda sample, return_dict: vae_only.encode(
-            sample, return_dict)[0].sample()
+        vae_latent_channels = vae_only.config.latent_channels
+        vae_out_channels = vae_only.config.out_channels
+        # forward only through the decoder part
+        vae_only.forward = vae_only.decode
+
         onnx_export(
             vae_only,
             model_args=(
-                torch.randn(1, vae_in_channels, vae_sample_size, vae_sample_size).to(
+                torch.randn(1, vae_latent_channels, unet_sample_size, unet_sample_size).to(
                     device=training_device, dtype=dtype),
                 False,
             ),
             output_path=output_path / "vae" / "model.onnx",
-            ordered_input_names=["sample", "return_dict"],
-            output_names=["latent_sample"],
+            ordered_input_names=["latent_sample", "return_dict"],
+            output_names=["sample"],
             dynamic_axes={
-                "sample": {0: "batch", 1: "channels", 2: "height", 3: "width"},
+                "latent_sample": {0: "batch", 1: "channels", 2: "height", 3: "width"},
             },
             opset=opset,
         )
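Reassigning forward before export is the usual trick for capturing only half of an autoencoder: tracing follows whatever forward points at, so vae_only.forward = vae_only.decode produces a decoder-only graph, and the renamed bindings (latent_sample in, sample out) are exactly what the pipeline's new decode_latents calls below. A standalone sketch of the same idea; the checkpoint name is an assumption based on the x4-upscaler focus of this commit:

    import torch
    from diffusers import AutoencoderKL

    # assumed checkpoint; any AutoencoderKL behaves the same way
    vae = AutoencoderKL.from_pretrained(
        "stabilityai/stable-diffusion-x4-upscaler", subfolder="vae")
    vae.forward = vae.decode

    # tracing this call records only the decoder path
    latents = torch.randn(1, vae.config.latent_channels, 128, 128)
    with torch.no_grad():
        image = vae(latents, False)[0]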
@@ -435,7 +443,7 @@ def convert_diffuser(name: str, url: str, opset: int, half: bool, token: str, si
     logger.info('exporting ONNX model')
 
     onnx_pipeline.save_pretrained(output_path)
-    logger.info("ONNX pipeline saved to", output_path)
+    logger.info("ONNX pipeline saved to %s", output_path)
 
     del pipeline
     del onnx_pipeline
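For orientation, a hypothetical end-to-end invocation of the converter; every value below is invented for illustration, and the parameter truncated in the hunk header is assumed to be single_vae, since that is the flag used throughout the function body:

    convert_diffuser(
        name="stable-diffusion-x4-upscaler",           # hypothetical short name
        url="stabilityai/stable-diffusion-x4-upscaler",
        opset=14,                                      # hypothetical opset
        half=False,
        token="",                                      # no Hugging Face token
        single_vae=True,                               # decode-only VAE export
    )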

View File

@@ -21,6 +21,8 @@ import torch
 logger = getLogger(__name__)
 
+num_channels_latents = 4  # self.vae.config.latent_channels
+unet_in_channels = 7  # self.unet.config.in_channels
 
 def preprocess(image):
     if isinstance(image, torch.Tensor):
@@ -68,7 +70,8 @@ class OnnxStableDiffusionUpscalePipeline(StableDiffusionUpscalePipeline):
         negative_prompt: Optional[Union[str, List[str]]] = None,
         num_images_per_prompt: Optional[int] = 1,
         eta: float = 0.0,
-        generator: Optional[Union[np.random.Generator, List[np.random.Generator]]] = None,
+        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+        # generator: Optional[Union[np.random.Generator, List[np.random.Generator]]] = None,
         latents: Optional[torch.FloatTensor] = None,
         output_type: Optional[str] = "pil",
         return_dict: bool = True,
@@ -100,11 +103,11 @@ class OnnxStableDiffusionUpscalePipeline(StableDiffusionUpscalePipeline):
         timesteps = self.scheduler.timesteps
 
         # 5. Add noise to image
-        print('text embedding dtype', text_embeddings.dtype)
+        # print('text embedding dtype', text_embeddings.dtype)
+        text_embeddings_dtype = torch.float32
         noise_level = torch.tensor([noise_level], dtype=torch.long, device=device)
-        noise = generator.random(size=image.shape, dtype=text_embeddings.dtype)
-        noise = torch.from_numpy(noise).to(device)
+        noise = torch.randn(image.shape, generator=generator, device=device, dtype=text_embeddings_dtype)
         image = self.low_res_scheduler.add_noise(image, noise, noise_level)
 
         batch_multiplier = 2 if do_classifier_free_guidance else 1
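Two things change in the noise draw above: np.random.Generator.random samples uniformly on [0, 1), which was arguably the wrong distribution for Gaussian noise, and the result then paid a from_numpy copy; torch.randn draws standard-normal noise in a single call, seeded through the torch.Generator and pinned to float32 via text_embeddings_dtype. A toy comparison of the two paths (the values differ by design; only shape and dtype line up):

    import numpy as np
    import torch

    shape = (1, 3, 128, 128)

    # old path: uniform samples, then a host copy into torch
    rng = np.random.default_rng(42)
    noise_old = torch.from_numpy(rng.random(size=shape, dtype=np.float32))

    # new path: standard-normal samples directly as a torch tensor
    generator = torch.manual_seed(42)
    noise_new = torch.randn(shape, generator=generator, dtype=torch.float32)

    assert noise_old.shape == noise_new.shape
    assert noise_old.dtype == noise_new.dtype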
@@ -113,13 +116,12 @@ class OnnxStableDiffusionUpscalePipeline(StableDiffusionUpscalePipeline):
         # 6. Prepare latent variables
         height, width = image.shape[2:]
-        num_channels_latents = self.vae.config.latent_channels
+        # TODO: config
         latents = self.prepare_latents(
             batch_size * num_images_per_prompt,
             num_channels_latents,
             height,
             width,
-            text_embeddings.dtype,
+            text_embeddings_dtype,
             device,
             generator,
             latents,
@@ -127,10 +129,10 @@ class OnnxStableDiffusionUpscalePipeline(StableDiffusionUpscalePipeline):
         # 7. Check that sizes of image and latents match
         num_channels_image = image.shape[1]
-        if num_channels_latents + num_channels_image != self.unet.config.in_channels:
+        # TODO: config
+        if num_channels_latents + num_channels_image != unet_in_channels:
             raise ValueError(
-                f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects"
-                f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
+                f"Incorrect configuration settings! The config of `pipeline.unet` expects"
+                f" {unet_in_channels} but received `num_channels_latents`: {num_channels_latents} +"
                 f" `num_channels_image`: {num_channels_image} "
                 f" = {num_channels_latents+num_channels_image}. Please verify the config of"
                 " `pipeline.unet` or your `image` input."
@@ -148,16 +150,23 @@ class OnnxStableDiffusionUpscalePipeline(StableDiffusionUpscalePipeline):
             # concat latents, mask, masked_image_latents in the channel dimension
             latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
-            latent_model_input = np.concatenate([latent_model_input, image], dim=1)
+            latent_model_input = np.concatenate([latent_model_input, image], axis=1)
+
+            # timestep to tensor
+            timestep = np.array([t], dtype=np.float32)
 
             # predict the noise residual
+            # print('noise pred unet', latent_model_input.dtype, text_embeddings.dtype, t, noise_level)
             noise_pred = self.unet(
-                latent_model_input, t, encoder_hidden_states=text_embeddings, class_labels=noise_level
-            ).sample
+                sample=latent_model_input,
+                timestep=timestep,
+                encoder_hidden_states=text_embeddings,
+                class_labels=noise_level
+            )[0]
 
             # perform guidance
             if do_classifier_free_guidance:
-                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+                noise_pred_uncond, noise_pred_text = np.split(noise_pred, 2)  # noise_pred.chunk(2)
                 noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
 
             # compute the previous noisy sample x_t -> x_t-1
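np.split stands in for torch's chunk in the guidance step: the batch was doubled for classifier-free guidance, so splitting along the default axis 0 recovers the unconditional and text-conditioned halves. A standalone sketch with toy values:

    import numpy as np

    guidance_scale = 7.5
    noise_pred = np.arange(16, dtype=np.float32).reshape(2, 2, 2, 2)

    # default axis=0 matches the batch dimension that chunk(2) operated on
    noise_pred_uncond, noise_pred_text = np.split(noise_pred, 2)
    guided = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
    assert guided.shape == (1, 2, 2, 2)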
@@ -171,7 +180,7 @@ class OnnxStableDiffusionUpscalePipeline(StableDiffusionUpscalePipeline):
         # 10. Post-processing
         # make sure the VAE is in float32 mode, as it overflows in float16
-        self.vae.to(dtype=np.float32)
+        # self.vae.to(dtype=np.float32)
         image = self.decode_latents(latents.float())
 
         # 11. Convert to PIL
@@ -183,6 +192,12 @@ class OnnxStableDiffusionUpscalePipeline(StableDiffusionUpscalePipeline):
         return ImagePipelineOutput(images=image)
 
+    def decode_latents(self, latents):
+        latents = 1 / 0.08333 * latents
+        image = self.vae(latent_sample=latents)[0]
+        image = np.clip(image / 2 + 0.5, 0, 1)
+        image = image.transpose((0, 2, 3, 1))
+        return image
+
     def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt):
         batch_size = len(prompt) if isinstance(prompt, list) else 1
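In the decode_latents override, the 1 / 0.08333 factor undoes the latent scaling the upscaler's VAE applies at encode time (diffusers exposes the same value as vae.config.scaling_factor, so the hardcoding here is presumably temporary), the call uses the latent_sample input name from the ONNX export above, and the transpose converts NCHW to NHWC. A hedged sketch of the step that usually follows, turning the [0, 1] float array into PIL images:

    import numpy as np
    from PIL import Image

    def to_pil(image: np.ndarray) -> list:
        # image: (batch, height, width, channels) floats in [0, 1]
        pixels = (image * 255).round().astype("uint8")
        return [Image.fromarray(img) for img in pixels]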

View File

@@ -10,7 +10,6 @@ try:
     if path.exists(logging_path):
         with open(logging_path, 'r') as f:
             config_logging = safe_load(f)
-            # print('configuring logger', config_logging)
             dictConfig(config_logging)
 except Exception as err:
     print('error loading logging config: %s' % (err))

View File

@@ -575,11 +575,9 @@ def chain():
     pipeline = ChainPipeline()
     for stage_data in data.get('stages', []):
-        logger.info('request stage: %s', stage_data)
         callback = chain_stages[stage_data.get('type')]
         kwargs = stage_data.get('params', {})
-        print('stage', callback.__name__, kwargs)
+        logger.info('request stage: %s, %s', callback.__name__, kwargs)
 
         stage = StageParams(
             stage_data.get('name', callback.__name__),
@@ -587,12 +585,10 @@ def chain():
             outscale=get_and_clamp_int(kwargs, 'outscale', 1, 4),
         )
 
-        # TODO: create Border from border
         if 'border' in kwargs:
             border = Border.even(int(kwargs.get('border')))
             kwargs['border'] = border
 
-        # TODO: create Upscale from upscale
         if 'upscale' in kwargs:
             upscale = UpscaleParams(kwargs.get('upscale'), params.provider)
             kwargs['upscale'] = upscale
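Read together with the parsing above, a chain request body is a list of stages whose params may carry border and upscale values that get promoted to Border and UpscaleParams. A hypothetical request body, with the stage type names invented for illustration:

    data = {
        'stages': [
            {
                'name': 'expand',
                'type': 'expand-border',      # hypothetical stage type
                'params': {
                    'border': 64,             # becomes Border.even(64)
                    'outscale': 1,
                },
            },
            {
                'name': 'upscale',
                'type': 'upscale-sd',         # hypothetical stage type
                'params': {
                    'upscale': 'stable-diffusion-x4-upscaler',  # becomes UpscaleParams(...)
                    'outscale': 4,
                },
            },
        ],
    }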

View File

@@ -20,6 +20,7 @@
         "denoise",
         "denoising",
         "directml",
+        "dtype",
         "ESRGAN",
         "ftfy",
         "gfpgan",
@@ -44,6 +45,7 @@
         "pndm",
         "pretrained",
         "protobuf",
+        "randn",
         "realesr",
         "resrgan",
         "RRDB",
@@ -56,6 +58,8 @@
         "spinalcase",
         "stabilityai",
         "stringcase",
+        "timestep",
+        "timesteps",
         "uncond",
         "unet",
         "untruncated",