fix(api): make SD upscaling compatible with more schedulers
This commit is contained in:
parent
3e5edb1c39
commit
2b29b099f0
|
@ -126,7 +126,9 @@ def convert_diffusion_stable(
|
||||||
# UNET
|
# UNET
|
||||||
if single_vae:
|
if single_vae:
|
||||||
unet_inputs = ["sample", "timestep", "encoder_hidden_states", "class_labels"]
|
unet_inputs = ["sample", "timestep", "encoder_hidden_states", "class_labels"]
|
||||||
unet_scale = torch.tensor(4).to(device=ctx.training_device, dtype=torch.int)
|
unet_scale = torch.tensor(4).to(
|
||||||
|
device=ctx.training_device, dtype=torch.long
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
unet_inputs = ["sample", "timestep", "encoder_hidden_states", "return_dict"]
|
unet_inputs = ["sample", "timestep", "encoder_hidden_states", "return_dict"]
|
||||||
unet_scale = torch.tensor(False).to(
|
unet_scale = torch.tensor(False).to(
|
||||||
|
|
|
@ -19,11 +19,15 @@ from diffusers.pipeline_utils import ImagePipelineOutput
|
||||||
|
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
# TODO: make this dynamic, from self.vae.config.latent_channels
|
|
||||||
num_channels_latents = 4
|
|
||||||
|
|
||||||
# TODO: make this dynamic, from self.unet.config.in_channels
|
NUM_LATENT_CHANNELS = 4
|
||||||
unet_in_channels = 7
|
NUM_UNET_INPUT_CHANNELS = 7
|
||||||
|
|
||||||
|
# TODO: should this be a lookup? it needs to match the conversion script
|
||||||
|
class_labels_dtype = np.long
|
||||||
|
|
||||||
|
# TODO: should this be a lookup or converted? can it vary on ONNX?
|
||||||
|
text_embeddings_dtype = torch.float32
|
||||||
|
|
||||||
|
|
||||||
def preprocess(image):
|
def preprocess(image):
|
||||||
|
@ -93,7 +97,6 @@ class OnnxStableDiffusionUpscalePipeline(StableDiffusionUpscalePipeline):
|
||||||
text_embeddings = self._encode_prompt(
|
text_embeddings = self._encode_prompt(
|
||||||
prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt
|
prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt
|
||||||
)
|
)
|
||||||
text_embeddings_dtype = torch.float32 # TODO: convert text_embeddings.dtype to torch dtype
|
|
||||||
|
|
||||||
# 4. Preprocess image
|
# 4. Preprocess image
|
||||||
image = preprocess(image)
|
image = preprocess(image)
|
||||||
|
@ -116,7 +119,7 @@ class OnnxStableDiffusionUpscalePipeline(StableDiffusionUpscalePipeline):
|
||||||
height, width = image.shape[2:]
|
height, width = image.shape[2:]
|
||||||
latents = self.prepare_latents(
|
latents = self.prepare_latents(
|
||||||
batch_size * num_images_per_prompt,
|
batch_size * num_images_per_prompt,
|
||||||
num_channels_latents,
|
NUM_LATENT_CHANNELS,
|
||||||
height,
|
height,
|
||||||
width,
|
width,
|
||||||
text_embeddings_dtype,
|
text_embeddings_dtype,
|
||||||
|
@ -127,12 +130,12 @@ class OnnxStableDiffusionUpscalePipeline(StableDiffusionUpscalePipeline):
|
||||||
|
|
||||||
# 7. Check that sizes of image and latents match
|
# 7. Check that sizes of image and latents match
|
||||||
num_channels_image = image.shape[1]
|
num_channels_image = image.shape[1]
|
||||||
if num_channels_latents + num_channels_image != unet_in_channels:
|
if NUM_LATENT_CHANNELS + num_channels_image != NUM_UNET_INPUT_CHANNELS:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Incorrect configuration settings! The config of `pipeline.unet` expects"
|
"Incorrect configuration settings! The config of `pipeline.unet` expects"
|
||||||
f" {unet_in_channels} but received `num_channels_latents`: {num_channels_latents} +"
|
f" {NUM_UNET_INPUT_CHANNELS} but received `num_channels_latents`: {NUM_LATENT_CHANNELS} +"
|
||||||
f" `num_channels_image`: {num_channels_image} "
|
f" `num_channels_image`: {num_channels_image} "
|
||||||
f" = {num_channels_latents+num_channels_image}. Please verify the config of"
|
f" = {NUM_LATENT_CHANNELS+num_channels_image}. Please verify the config of"
|
||||||
" `pipeline.unet` or your `image` input."
|
" `pipeline.unet` or your `image` input."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -158,7 +161,7 @@ class OnnxStableDiffusionUpscalePipeline(StableDiffusionUpscalePipeline):
|
||||||
sample=latent_model_input,
|
sample=latent_model_input,
|
||||||
timestep=timestep,
|
timestep=timestep,
|
||||||
encoder_hidden_states=text_embeddings,
|
encoder_hidden_states=text_embeddings,
|
||||||
class_labels=noise_level,
|
class_labels=noise_level.astype(class_labels_dtype),
|
||||||
)[0]
|
)[0]
|
||||||
|
|
||||||
# perform guidance
|
# perform guidance
|
||||||
|
@ -167,7 +170,9 @@ class OnnxStableDiffusionUpscalePipeline(StableDiffusionUpscalePipeline):
|
||||||
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
|
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
|
||||||
|
|
||||||
# compute the previous noisy sample x_t -> x_t-1
|
# compute the previous noisy sample x_t -> x_t-1
|
||||||
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
|
latents = self.scheduler.step(
|
||||||
|
torch.from_numpy(noise_pred), t, latents, **extra_step_kwargs
|
||||||
|
).prev_sample
|
||||||
|
|
||||||
# call the callback, if provided
|
# call the callback, if provided
|
||||||
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
||||||
|
|
Loading…
Reference in New Issue