
fix(api): make SD upscaling compatible with more schedulers

Sean Sube 2023-02-15 17:16:20 -06:00
parent 3e5edb1c39
commit 2b29b099f0
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
2 changed files with 19 additions and 12 deletions


@@ -126,7 +126,9 @@ def convert_diffusion_stable(
     # UNET
     if single_vae:
         unet_inputs = ["sample", "timestep", "encoder_hidden_states", "class_labels"]
-        unet_scale = torch.tensor(4).to(device=ctx.training_device, dtype=torch.int)
+        unet_scale = torch.tensor(4).to(
+            device=ctx.training_device, dtype=torch.long
+        )
     else:
         unet_inputs = ["sample", "timestep", "encoder_hidden_states", "return_dict"]
         unet_scale = torch.tensor(False).to(
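
Note (reviewer illustration, not part of the diff): a minimal sketch of the dtype change above, assuming the usual torch aliases. torch.int is int32 and torch.long is int64; ONNX Runtime validates fed inputs against the dtypes declared in the exported graph, so the class_labels input baked in here has to use the same 64-bit integers the pipeline sends at runtime.

    import torch

    # torch.int is an alias for int32, torch.long for int64; a graph
    # exported with an int64 class_labels input will reject an int32
    # feed, so the conversion script and pipeline must agree on int64.
    scale_int32 = torch.tensor(4).to(dtype=torch.int)
    scale_int64 = torch.tensor(4).to(dtype=torch.long)
    print(scale_int32.dtype)  # torch.int32
    print(scale_int64.dtype)  # torch.int64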


@@ -19,11 +19,15 @@ from diffusers.pipeline_utils import ImagePipelineOutput
 
 logger = getLogger(__name__)
 
-# TODO: make this dynamic, from self.vae.config.latent_channels
-num_channels_latents = 4
-# TODO: make this dynamic, from self.unet.config.in_channels
-unet_in_channels = 7
+NUM_LATENT_CHANNELS = 4
+NUM_UNET_INPUT_CHANNELS = 7
+
+# TODO: should this be a lookup? it needs to match the conversion script
+class_labels_dtype = np.long
+
+# TODO: should this be a lookup or converted? can it vary on ONNX?
+text_embeddings_dtype = torch.float32
 
 
 def preprocess(image):
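
Note (reviewer illustration, not part of the diff): class_labels_dtype is consumed below as noise_level.astype(class_labels_dtype), and it has to match the torch.long used by the conversion script. A tiny sketch, using np.int64 as the explicit stand-in for np.long (which aliases the platform C long):

    import numpy as np

    # noise_level arrives as a plain integer array; the cast makes the
    # dtype explicit so it matches the exported UNet's int64 input.
    noise_level = np.array([20])
    print(noise_level.astype(np.int64).dtype)  # int64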
@@ -93,7 +97,6 @@ class OnnxStableDiffusionUpscalePipeline(StableDiffusionUpscalePipeline):
         text_embeddings = self._encode_prompt(
             prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt
         )
-        text_embeddings_dtype = torch.float32  # TODO: convert text_embeddings.dtype to torch dtype
 
         # 4. Preprocess image
         image = preprocess(image)
@@ -116,7 +119,7 @@ class OnnxStableDiffusionUpscalePipeline(StableDiffusionUpscalePipeline):
         height, width = image.shape[2:]
         latents = self.prepare_latents(
             batch_size * num_images_per_prompt,
-            num_channels_latents,
+            NUM_LATENT_CHANNELS,
             height,
             width,
             text_embeddings_dtype,
@@ -127,12 +130,12 @@ class OnnxStableDiffusionUpscalePipeline(StableDiffusionUpscalePipeline):
         # 7. Check that sizes of image and latents match
         num_channels_image = image.shape[1]
-        if num_channels_latents + num_channels_image != unet_in_channels:
+        if NUM_LATENT_CHANNELS + num_channels_image != NUM_UNET_INPUT_CHANNELS:
             raise ValueError(
                 "Incorrect configuration settings! The config of `pipeline.unet` expects"
-                f" {unet_in_channels} but received `num_channels_latents`: {num_channels_latents} +"
+                f" {NUM_UNET_INPUT_CHANNELS} but received `num_channels_latents`: {NUM_LATENT_CHANNELS} +"
                 f" `num_channels_image`: {num_channels_image} "
-                f" = {num_channels_latents+num_channels_image}. Please verify the config of"
+                f" = {NUM_LATENT_CHANNELS+num_channels_image}. Please verify the config of"
                 " `pipeline.unet` or your `image` input."
             )
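
Note (reviewer illustration, not part of the diff): the 7 comes from how the x4 upscaling UNet is fed: the 4-channel latents are concatenated with the 3-channel low-resolution image along the channel axis. A minimal sketch with assumed shapes:

    import numpy as np

    # latents: (batch, 4, h, w); conditioning image: (batch, 3, h, w)
    latents = np.zeros((1, 4, 64, 64), dtype=np.float32)
    image = np.zeros((1, 3, 64, 64), dtype=np.float32)
    model_input = np.concatenate([latents, image], axis=1)
    assert model_input.shape[1] == 7  # NUM_LATENT_CHANNELS + RGB channels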
@@ -158,7 +161,7 @@ class OnnxStableDiffusionUpscalePipeline(StableDiffusionUpscalePipeline):
                     sample=latent_model_input,
                     timestep=timestep,
                     encoder_hidden_states=text_embeddings,
-                    class_labels=noise_level,
+                    class_labels=noise_level.astype(class_labels_dtype),
                 )[0]
 
                 # perform guidance
@@ -167,7 +170,9 @@ class OnnxStableDiffusionUpscalePipeline(StableDiffusionUpscalePipeline):
                 noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
 
                 # compute the previous noisy sample x_t -> x_t-1
-                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
+                latents = self.scheduler.step(
+                    torch.from_numpy(noise_pred), t, latents, **extra_step_kwargs
+                ).prev_sample
 
                 # call the callback, if provided
                 if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
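
Note (reviewer illustration, not part of the diff): this is the heart of the scheduler fix. ONNX Runtime returns numpy arrays, but diffusers schedulers do torch arithmetic inside step(), and only some of them happen to tolerate numpy input. Converting up front keeps the rest working; a minimal sketch using DDPMScheduler as a stand-in:

    import numpy as np
    import torch
    from diffusers import DDPMScheduler

    scheduler = DDPMScheduler()
    scheduler.set_timesteps(10)
    t = scheduler.timesteps[0]

    noise_pred = np.zeros((1, 4, 64, 64), dtype=np.float32)  # ONNX output
    latents = torch.zeros((1, 4, 64, 64))

    # torch.from_numpy shares memory with the array, so the wrap is cheap.
    latents = scheduler.step(torch.from_numpy(noise_pred), t, latents).prev_sample
    print(latents.shape)  # torch.Size([1, 4, 64, 64])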