fix(api): update SD upscaling pipeline
This commit is contained in:
parent
1ea1a57706
commit
3d73b9e621
|
@ -23,11 +23,10 @@ logger = getLogger(__name__)
|
||||||
NUM_LATENT_CHANNELS = 4
|
NUM_LATENT_CHANNELS = 4
|
||||||
NUM_UNET_INPUT_CHANNELS = 7
|
NUM_UNET_INPUT_CHANNELS = 7
|
||||||
|
|
||||||
# TODO: should this be a lookup? it needs to match the conversion script
|
TORCH_DTYPES = {
|
||||||
class_labels_dtype = np.long
|
"float16": torch.float16,
|
||||||
|
"float32": torch.float32,
|
||||||
# TODO: should this be a lookup or converted? can it vary on ONNX?
|
}
|
||||||
text_embeddings_dtype = torch.float32
|
|
||||||
|
|
||||||
|
|
||||||
def preprocess(image):
|
def preprocess(image):
|
||||||
|
@ -98,6 +97,8 @@ class OnnxStableDiffusionUpscalePipeline(StableDiffusionUpscalePipeline):
|
||||||
prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt
|
prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt
|
||||||
)
|
)
|
||||||
|
|
||||||
|
latents_dtype = TORCH_DTYPES[str(text_embeddings.dtype)]
|
||||||
|
|
||||||
# 4. Preprocess image
|
# 4. Preprocess image
|
||||||
image = preprocess(image)
|
image = preprocess(image)
|
||||||
image = image.cpu()
|
image = image.cpu()
|
||||||
|
@ -108,7 +109,7 @@ class OnnxStableDiffusionUpscalePipeline(StableDiffusionUpscalePipeline):
|
||||||
|
|
||||||
# 5. Add noise to image
|
# 5. Add noise to image
|
||||||
noise_level = torch.tensor([noise_level], dtype=torch.long, device=device)
|
noise_level = torch.tensor([noise_level], dtype=torch.long, device=device)
|
||||||
noise = torch.randn(image.shape, generator=generator, device=device, dtype=text_embeddings_dtype)
|
noise = torch.randn(image.shape, generator=generator, device=device, dtype=latents_dtype)
|
||||||
image = self.low_res_scheduler.add_noise(image, noise, noise_level)
|
image = self.low_res_scheduler.add_noise(image, noise, noise_level)
|
||||||
|
|
||||||
batch_multiplier = 2 if do_classifier_free_guidance else 1
|
batch_multiplier = 2 if do_classifier_free_guidance else 1
|
||||||
|
@ -122,7 +123,7 @@ class OnnxStableDiffusionUpscalePipeline(StableDiffusionUpscalePipeline):
|
||||||
NUM_LATENT_CHANNELS,
|
NUM_LATENT_CHANNELS,
|
||||||
height,
|
height,
|
||||||
width,
|
width,
|
||||||
text_embeddings_dtype,
|
latents_dtype,
|
||||||
device,
|
device,
|
||||||
generator,
|
generator,
|
||||||
latents,
|
latents,
|
||||||
|
@ -142,6 +143,11 @@ class OnnxStableDiffusionUpscalePipeline(StableDiffusionUpscalePipeline):
|
||||||
# 8. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
|
# 8. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
|
||||||
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
||||||
|
|
||||||
|
timestep_dtype = next(
|
||||||
|
(input.type for input in self.unet.model.get_inputs() if input.name == "timestep"), "tensor(float)"
|
||||||
|
)
|
||||||
|
timestep_dtype = ORT_TO_NP_TYPE[timestep_dtype]
|
||||||
|
|
||||||
# 9. Denoising loop
|
# 9. Denoising loop
|
||||||
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
|
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
|
||||||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||||
|
@ -154,14 +160,14 @@ class OnnxStableDiffusionUpscalePipeline(StableDiffusionUpscalePipeline):
|
||||||
latent_model_input = np.concatenate([latent_model_input, image], axis=1)
|
latent_model_input = np.concatenate([latent_model_input, image], axis=1)
|
||||||
|
|
||||||
# timestep to tensor
|
# timestep to tensor
|
||||||
timestep = np.array([t], dtype=np.float32)
|
timestep = np.array([t], dtype=timestep_dtype)
|
||||||
|
|
||||||
# predict the noise residual
|
# predict the noise residual
|
||||||
noise_pred = self.unet(
|
noise_pred = self.unet(
|
||||||
sample=latent_model_input,
|
sample=latent_model_input,
|
||||||
timestep=timestep,
|
timestep=timestep,
|
||||||
encoder_hidden_states=text_embeddings,
|
encoder_hidden_states=text_embeddings,
|
||||||
class_labels=noise_level.astype(class_labels_dtype),
|
class_labels=noise_level.astype(np.int64),
|
||||||
)[0]
|
)[0]
|
||||||
|
|
||||||
# perform guidance
|
# perform guidance
|
||||||
|
|
Loading…
Reference in New Issue