fix(api): update SD upscaling pipeline
This commit is contained in:
parent
1ea1a57706
commit
3d73b9e621
|
@ -222,4 +222,4 @@ def load_tensor(name: str, map_location=None):
|
|||
checkpoint["state_dict"] if "state_dict" in checkpoint else checkpoint
|
||||
)
|
||||
|
||||
return checkpoint
|
||||
return checkpoint
|
||||
|
|
|
@ -23,11 +23,10 @@ logger = getLogger(__name__)
|
|||
NUM_LATENT_CHANNELS = 4
|
||||
NUM_UNET_INPUT_CHANNELS = 7
|
||||
|
||||
# TODO: should this be a lookup? it needs to match the conversion script
|
||||
class_labels_dtype = np.long
|
||||
|
||||
# TODO: should this be a lookup or converted? can it vary on ONNX?
|
||||
text_embeddings_dtype = torch.float32
|
||||
TORCH_DTYPES = {
|
||||
"float16": torch.float16,
|
||||
"float32": torch.float32,
|
||||
}
|
||||
|
||||
|
||||
def preprocess(image):
|
||||
|
@ -98,6 +97,8 @@ class OnnxStableDiffusionUpscalePipeline(StableDiffusionUpscalePipeline):
|
|||
prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt
|
||||
)
|
||||
|
||||
latents_dtype = TORCH_DTYPES[str(text_embeddings.dtype)]
|
||||
|
||||
# 4. Preprocess image
|
||||
image = preprocess(image)
|
||||
image = image.cpu()
|
||||
|
@ -108,7 +109,7 @@ class OnnxStableDiffusionUpscalePipeline(StableDiffusionUpscalePipeline):
|
|||
|
||||
# 5. Add noise to image
|
||||
noise_level = torch.tensor([noise_level], dtype=torch.long, device=device)
|
||||
noise = torch.randn(image.shape, generator=generator, device=device, dtype=text_embeddings_dtype)
|
||||
noise = torch.randn(image.shape, generator=generator, device=device, dtype=latents_dtype)
|
||||
image = self.low_res_scheduler.add_noise(image, noise, noise_level)
|
||||
|
||||
batch_multiplier = 2 if do_classifier_free_guidance else 1
|
||||
|
@ -122,7 +123,7 @@ class OnnxStableDiffusionUpscalePipeline(StableDiffusionUpscalePipeline):
|
|||
NUM_LATENT_CHANNELS,
|
||||
height,
|
||||
width,
|
||||
text_embeddings_dtype,
|
||||
latents_dtype,
|
||||
device,
|
||||
generator,
|
||||
latents,
|
||||
|
@ -142,6 +143,11 @@ class OnnxStableDiffusionUpscalePipeline(StableDiffusionUpscalePipeline):
|
|||
# 8. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
|
||||
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
||||
|
||||
timestep_dtype = next(
|
||||
(input.type for input in self.unet.model.get_inputs() if input.name == "timestep"), "tensor(float)"
|
||||
)
|
||||
timestep_dtype = ORT_TO_NP_TYPE[timestep_dtype]
|
||||
|
||||
# 9. Denoising loop
|
||||
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||
|
@ -154,14 +160,14 @@ class OnnxStableDiffusionUpscalePipeline(StableDiffusionUpscalePipeline):
|
|||
latent_model_input = np.concatenate([latent_model_input, image], axis=1)
|
||||
|
||||
# timestep to tensor
|
||||
timestep = np.array([t], dtype=np.float32)
|
||||
timestep = np.array([t], dtype=timestep_dtype)
|
||||
|
||||
# predict the noise residual
|
||||
noise_pred = self.unet(
|
||||
sample=latent_model_input,
|
||||
timestep=timestep,
|
||||
encoder_hidden_states=text_embeddings,
|
||||
class_labels=noise_level.astype(class_labels_dtype),
|
||||
class_labels=noise_level.astype(np.int64),
|
||||
)[0]
|
||||
|
||||
# perform guidance
|
||||
|
|
Loading…
Reference in New Issue