testing ONNX upscaling
parent ca613cabe1
commit bacce0ace9
@@ -63,11 +63,7 @@ def upscale_stable_diffusion(
     logger.info('upscaling with Stable Diffusion, %s steps: %s', params.steps, prompt)

     pipeline = load_stable_diffusion(ctx, upscale)
-
-    if upscale.format == 'onnx':
-        generator = np.random.default_rng(params.seed)
-    else:
-        generator = torch.manual_seed(params.seed)
+    generator = torch.manual_seed(params.seed)

     return pipeline(
         params.prompt,
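Note: torch.manual_seed returns the seeded default torch.Generator, so a single torch-based generator now serves both the ONNX and the Torch code paths, matching the generator: Optional[torch.Generator] parameter the ONNX upscale pipeline declares further down. A minimal illustration (not part of the commit):

import torch

generator = torch.manual_seed(42)
print(isinstance(generator, torch.Generator))  # True: usable wherever a torch.Generator is expected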
@@ -253,7 +253,17 @@ def convert_diffuser(name: str, url: str, opset: int, half: bool, token: str, si
     )
     del pipeline.text_encoder

+    logger.info('UNET config: %s', pipeline.unet.config)
+
     # UNET
+    if single_vae:
+        unet_inputs = ["sample", "timestep", "encoder_hidden_states", "class_labels"]
+        # unet_inputs = ["latent_model_input", "timestep", "encoder_hidden_states", "class_labels"]
+        unet_scale = 4
+    else:
+        unet_inputs = ["sample", "timestep", "encoder_hidden_states", "return_dict"]
+        unet_scale = False
+
     unet_in_channels = pipeline.unet.config.in_channels
     unet_sample_size = pipeline.unet.config.sample_size
     unet_path = output_path / "unet" / "model.onnx"
@@ -265,13 +275,10 @@ def convert_diffuser(name: str, url: str, opset: int, half: bool, token: str, si
             torch.randn(2).to(device=training_device, dtype=dtype),
             torch.randn(2, num_tokens, text_hidden_size).to(
                 device=training_device, dtype=dtype),
-            # TODO: needs to be Int or Long for upscaling, Bool for regular
-            4,
-            # False,
+            unet_scale,
         ),
         output_path=unet_path,
-        ordered_input_names=["sample", "timestep",
-                             "encoder_hidden_states", "return_dict"],
+        ordered_input_names=unet_inputs,
         # has to be different from "sample" for correct tracing
         output_names=["out_sample"],
         dynamic_axes={
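As a rough way to exercise the exported upscaler UNet, a sketch using onnxruntime (not part of the commit; the session path, tensor shapes, and hidden size are assumptions), feeding the four inputs named in unet_inputs for the single_vae case:

import numpy as np
import onnxruntime as ort

# assumed location, matching unet_path above
sess = ort.InferenceSession("unet/model.onnx")

outputs = sess.run(None, {
    "sample": np.zeros((2, 7, 128, 128), dtype=np.float32),               # latent + low-res image channels (assumed shape)
    "timestep": np.array([999], dtype=np.float32),
    "encoder_hidden_states": np.zeros((2, 77, 1024), dtype=np.float32),   # assumed hidden size
    "class_labels": np.array([20], dtype=np.int64),                       # noise level
})
noise_pred = outputs[0]  # exported under the name "out_sample"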
@@ -300,25 +307,26 @@ def convert_diffuser(name: str, url: str, opset: int, half: bool, token: str, si
     del pipeline.unet

     if single_vae:
+        logger.info('VAE config: %s', pipeline.vae.config)
+
         # SINGLE VAE
         vae_only = pipeline.vae
-        vae_in_channels = vae_only.config.in_channels
-        vae_sample_size = vae_only.config.sample_size
-        # need to get the raw tensor output (sample) from the encoder
-        vae_only.forward = lambda sample, return_dict: vae_only.encode(
-            sample, return_dict)[0].sample()
+        vae_latent_channels = vae_only.config.latent_channels
+        vae_out_channels = vae_only.config.out_channels
+        # forward only through the decoder part
+        vae_only.forward = vae_only.decode
         onnx_export(
             vae_only,
             model_args=(
-                torch.randn(1, vae_in_channels, vae_sample_size, vae_sample_size).to(
+                torch.randn(1, vae_latent_channels, unet_sample_size, unet_sample_size).to(
                     device=training_device, dtype=dtype),
                 False,
             ),
             output_path=output_path / "vae" / "model.onnx",
-            ordered_input_names=["sample", "return_dict"],
-            output_names=["latent_sample"],
+            ordered_input_names=["latent_sample", "return_dict"],
+            output_names=["sample"],
             dynamic_axes={
-                "sample": {0: "batch", 1: "channels", 2: "height", 3: "width"},
+                "latent_sample": {0: "batch", 1: "channels", 2: "height", 3: "width"},
             },
             opset=opset,
         )
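The single-VAE export now wraps only the decoder (forward = decode), so the ONNX model maps latents at UNet resolution back to pixels, with the input renamed to latent_sample and the output to sample; that is exactly how decode_latents in the pipeline hunk below calls it. A sketch of loading the result (illustrative; assumed path and shape):

import numpy as np
import onnxruntime as ort

vae = ort.InferenceSession("vae/model.onnx")
latents = np.zeros((1, 4, 128, 128), dtype=np.float32)   # latent_channels x unet_sample_size (assumed)
image = vae.run(None, {"latent_sample": latents})[0]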
@@ -435,7 +443,7 @@ def convert_diffuser(name: str, url: str, opset: int, half: bool, token: str, si
     logger.info('exporting ONNX model')

     onnx_pipeline.save_pretrained(output_path)
-    logger.info("ONNX pipeline saved to", output_path)
+    logger.info("ONNX pipeline saved to %s", output_path)

     del pipeline
     del onnx_pipeline
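The old call passed output_path as a formatting argument without a placeholder, which makes the logging module raise a formatting error instead of printing the path; the %s placeholder lets the path be interpolated lazily.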
@@ -21,6 +21,8 @@ import torch

 logger = getLogger(__name__)

+num_channels_latents = 4 # self.vae.config.latent_channels
+unet_in_channels = 7 # self.unet.config.in_channels

 def preprocess(image):
     if isinstance(image, torch.Tensor):
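These module-level constants hard-code what the model configs would normally report: 4 latent channels from the VAE plus the 3 channels of the low-res image that is concatenated onto the latents gives the 7 input channels the UNet expects, which is exactly what the channel check further down verifies.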
@@ -68,7 +70,8 @@ class OnnxStableDiffusionUpscalePipeline(StableDiffusionUpscalePipeline):
         negative_prompt: Optional[Union[str, List[str]]] = None,
         num_images_per_prompt: Optional[int] = 1,
         eta: float = 0.0,
-        generator: Optional[Union[np.random.Generator, List[np.random.Generator]]] = None,
+        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+        # generator: Optional[Union[np.random.Generator, List[np.random.Generator]]] = None,
         latents: Optional[torch.FloatTensor] = None,
         output_type: Optional[str] = "pil",
         return_dict: bool = True,
@@ -100,11 +103,11 @@ class OnnxStableDiffusionUpscalePipeline(StableDiffusionUpscalePipeline):
         timesteps = self.scheduler.timesteps

         # 5. Add noise to image
-        print('text embedding dtype', text_embeddings.dtype)
+        # print('text embedding dtype', text_embeddings.dtype)
+        text_embeddings_dtype = torch.float32

         noise_level = torch.tensor([noise_level], dtype=torch.long, device=device)
-        noise = generator.random(size=image.shape, dtype=text_embeddings.dtype)
-        noise = torch.from_numpy(noise).to(device)
+        noise = torch.randn(image.shape, generator=generator, device=device, dtype=text_embeddings_dtype)
         image = self.low_res_scheduler.add_noise(image, noise, noise_level)

         batch_multiplier = 2 if do_classifier_free_guidance else 1
@@ -113,13 +116,12 @@ class OnnxStableDiffusionUpscalePipeline(StableDiffusionUpscalePipeline):

         # 6. Prepare latent variables
         height, width = image.shape[2:]
-        num_channels_latents = self.vae.config.latent_channels # TODO: config
         latents = self.prepare_latents(
             batch_size * num_images_per_prompt,
             num_channels_latents,
             height,
             width,
-            text_embeddings.dtype,
+            text_embeddings_dtype,
             device,
             generator,
             latents,
@@ -127,10 +129,10 @@ class OnnxStableDiffusionUpscalePipeline(StableDiffusionUpscalePipeline):

         # 7. Check that sizes of image and latents match
         num_channels_image = image.shape[1]
-        if num_channels_latents + num_channels_image != self.unet.config.in_channels: # TODO: config
+        if num_channels_latents + num_channels_image != unet_in_channels:
             raise ValueError(
-                f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects"
-                f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
+                f"Incorrect configuration settings! The config of `pipeline.unet` expects"
+                f" {unet_in_channels} but received `num_channels_latents`: {num_channels_latents} +"
                 f" `num_channels_image`: {num_channels_image} "
                 f" = {num_channels_latents+num_channels_image}. Please verify the config of"
                 " `pipeline.unet` or your `image` input."
@@ -148,16 +150,23 @@ class OnnxStableDiffusionUpscalePipeline(StableDiffusionUpscalePipeline):

             # concat latents, mask, masked_image_latents in the channel dimension
             latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
-            latent_model_input = np.concatenate([latent_model_input, image], dim=1)
+            latent_model_input = np.concatenate([latent_model_input, image], axis=1)

+            # timestep to tensor
+            timestep = np.array([t], dtype=np.float32)
+
             # predict the noise residual
+            # print('noise pred unet', latent_model_input.dtype, text_embeddings.dtype, t, noise_level)
             noise_pred = self.unet(
-                latent_model_input, t, encoder_hidden_states=text_embeddings, class_labels=noise_level
-            ).sample
+                sample=latent_model_input,
+                timestep=timestep,
+                encoder_hidden_states=text_embeddings,
+                class_labels=noise_level
+            )[0]

             # perform guidance
             if do_classifier_free_guidance:
-                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+                noise_pred_uncond, noise_pred_text = np.split(noise_pred, 2) # noise_pred.chunk(2)
                 noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

             # compute the previous noisy sample x_t -> x_t-1
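np.split stands in for tensor.chunk here: the ONNX session returns a plain numpy array with no .chunk method, and np.split along the default axis 0 produces the same unconditional/text halves. A toy illustration (not from the commit):

import numpy as np

noise_pred = np.arange(16, dtype=np.float32).reshape(4, 2, 2)   # toy batch of 2 uncond + 2 text predictions
noise_pred_uncond, noise_pred_text = np.split(noise_pred, 2)    # axis 0 split, equivalent to tensor.chunk(2)
guided = noise_pred_uncond + 7.5 * (noise_pred_text - noise_pred_uncond)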
@@ -171,7 +180,7 @@ class OnnxStableDiffusionUpscalePipeline(StableDiffusionUpscalePipeline):

         # 10. Post-processing
         # make sure the VAE is in float32 mode, as it overflows in float16
-        self.vae.to(dtype=np.float32)
+        # self.vae.to(dtype=np.float32)
         image = self.decode_latents(latents.float())

         # 11. Convert to PIL
@@ -183,6 +192,12 @@ class OnnxStableDiffusionUpscalePipeline(StableDiffusionUpscalePipeline):

         return ImagePipelineOutput(images=image)

+    def decode_latents(self, latents):
+        latents = 1 / 0.08333 * latents
+        image = self.vae(latent_sample=latents)[0]
+        image = np.clip(image / 2 + 0.5, 0, 1)
+        image = image.transpose((0, 2, 3, 1))
+        return image

     def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt):
         batch_size = len(prompt) if isinstance(prompt, list) else 1
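This decode_latents override works on numpy arrays end to end: the 1 / 0.08333 factor undoes the latent scaling applied at encode time (hard-coded here, like the channel counts at the top of the file), the ONNX VAE is called through its latent_sample input, and the result comes back NHWC in [0, 1]. A sketch of turning that output into PIL images (illustrative, not from the commit):

import numpy as np
from PIL import Image

images = (image * 255).round().astype("uint8")
pils = [Image.fromarray(img) for img in images]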
@@ -10,7 +10,6 @@ try:
     if path.exists(logging_path):
         with open(logging_path, 'r') as f:
             config_logging = safe_load(f)
-            # print('configuring logger', config_logging)
             dictConfig(config_logging)
 except Exception as err:
     print('error loading logging config: %s' % (err))
@@ -575,11 +575,9 @@ def chain():

     pipeline = ChainPipeline()
     for stage_data in data.get('stages', []):
-        logger.info('request stage: %s', stage_data)
-
         callback = chain_stages[stage_data.get('type')]
         kwargs = stage_data.get('params', {})
-        print('stage', callback.__name__, kwargs)
+        logger.info('request stage: %s, %s', callback.__name__, kwargs)

         stage = StageParams(
             stage_data.get('name', callback.__name__),
@@ -587,12 +585,10 @@ def chain():
             outscale=get_and_clamp_int(kwargs,'outscale', 1, 4),
         )

-        # TODO: create Border from border
         if 'border' in kwargs:
             border = Border.even(int(kwargs.get('border')))
             kwargs['border'] = border

-        # TODO: create Upscale from upscale
         if 'upscale' in kwargs:
             upscale = UpscaleParams(kwargs.get('upscale'), params.provider)
             kwargs['upscale'] = upscale
@@ -20,6 +20,7 @@
     "denoise",
     "denoising",
     "directml",
+    "dtype",
     "ESRGAN",
     "ftfy",
     "gfpgan",
@@ -44,6 +45,7 @@
     "pndm",
     "pretrained",
     "protobuf",
+    "randn",
     "realesr",
     "resrgan",
     "RRDB",
@@ -56,6 +58,8 @@
     "spinalcase",
     "stabilityai",
     "stringcase",
+    "timestep",
+    "timesteps",
     "uncond",
     "unet",
     "untruncated",