fix(api): make SD upscaling compatible with more schedulers
This commit is contained in:
parent
3e5edb1c39
commit
2b29b099f0
|
@ -126,7 +126,9 @@ def convert_diffusion_stable(
|
|||
# UNET
|
||||
if single_vae:
|
||||
unet_inputs = ["sample", "timestep", "encoder_hidden_states", "class_labels"]
|
||||
unet_scale = torch.tensor(4).to(device=ctx.training_device, dtype=torch.int)
|
||||
unet_scale = torch.tensor(4).to(
|
||||
device=ctx.training_device, dtype=torch.long
|
||||
)
|
||||
else:
|
||||
unet_inputs = ["sample", "timestep", "encoder_hidden_states", "return_dict"]
|
||||
unet_scale = torch.tensor(False).to(
|
||||
|
|
|
@ -19,11 +19,15 @@ from diffusers.pipeline_utils import ImagePipelineOutput
|
|||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
# TODO: make this dynamic, from self.vae.config.latent_channels
|
||||
num_channels_latents = 4
|
||||
|
||||
# TODO: make this dynamic, from self.unet.config.in_channels
|
||||
unet_in_channels = 7
|
||||
NUM_LATENT_CHANNELS = 4
|
||||
NUM_UNET_INPUT_CHANNELS = 7
|
||||
|
||||
# TODO: should this be a lookup? it needs to match the conversion script
|
||||
class_labels_dtype = np.long
|
||||
|
||||
# TODO: should this be a lookup or converted? can it vary on ONNX?
|
||||
text_embeddings_dtype = torch.float32
|
||||
|
||||
|
||||
def preprocess(image):
|
||||
|
@ -93,7 +97,6 @@ class OnnxStableDiffusionUpscalePipeline(StableDiffusionUpscalePipeline):
|
|||
text_embeddings = self._encode_prompt(
|
||||
prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt
|
||||
)
|
||||
text_embeddings_dtype = torch.float32 # TODO: convert text_embeddings.dtype to torch dtype
|
||||
|
||||
# 4. Preprocess image
|
||||
image = preprocess(image)
|
||||
|
@ -116,7 +119,7 @@ class OnnxStableDiffusionUpscalePipeline(StableDiffusionUpscalePipeline):
|
|||
height, width = image.shape[2:]
|
||||
latents = self.prepare_latents(
|
||||
batch_size * num_images_per_prompt,
|
||||
num_channels_latents,
|
||||
NUM_LATENT_CHANNELS,
|
||||
height,
|
||||
width,
|
||||
text_embeddings_dtype,
|
||||
|
@ -127,12 +130,12 @@ class OnnxStableDiffusionUpscalePipeline(StableDiffusionUpscalePipeline):
|
|||
|
||||
# 7. Check that sizes of image and latents match
|
||||
num_channels_image = image.shape[1]
|
||||
if num_channels_latents + num_channels_image != unet_in_channels:
|
||||
if NUM_LATENT_CHANNELS + num_channels_image != NUM_UNET_INPUT_CHANNELS:
|
||||
raise ValueError(
|
||||
"Incorrect configuration settings! The config of `pipeline.unet` expects"
|
||||
f" {unet_in_channels} but received `num_channels_latents`: {num_channels_latents} +"
|
||||
f" {NUM_UNET_INPUT_CHANNELS} but received `num_channels_latents`: {NUM_LATENT_CHANNELS} +"
|
||||
f" `num_channels_image`: {num_channels_image} "
|
||||
f" = {num_channels_latents+num_channels_image}. Please verify the config of"
|
||||
f" = {NUM_LATENT_CHANNELS+num_channels_image}. Please verify the config of"
|
||||
" `pipeline.unet` or your `image` input."
|
||||
)
|
||||
|
||||
|
@ -158,7 +161,7 @@ class OnnxStableDiffusionUpscalePipeline(StableDiffusionUpscalePipeline):
|
|||
sample=latent_model_input,
|
||||
timestep=timestep,
|
||||
encoder_hidden_states=text_embeddings,
|
||||
class_labels=noise_level,
|
||||
class_labels=noise_level.astype(class_labels_dtype),
|
||||
)[0]
|
||||
|
||||
# perform guidance
|
||||
|
@ -167,7 +170,9 @@ class OnnxStableDiffusionUpscalePipeline(StableDiffusionUpscalePipeline):
|
|||
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
|
||||
latents = self.scheduler.step(
|
||||
torch.from_numpy(noise_pred), t, latents, **extra_step_kwargs
|
||||
).prev_sample
|
||||
|
||||
# call the callback, if provided
|
||||
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
||||
|
|
Loading…
Reference in New Issue