1
0
Fork 0

fix(api): update SD upscaling pipeline

This commit is contained in:
Sean Sube 2023-02-16 21:44:33 -06:00
parent 1ea1a57706
commit 3d73b9e621
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
2 changed files with 16 additions and 10 deletions

View File

@ -222,4 +222,4 @@ def load_tensor(name: str, map_location=None):
checkpoint["state_dict"] if "state_dict" in checkpoint else checkpoint
)
return checkpoint
return checkpoint

View File

@ -23,11 +23,10 @@ logger = getLogger(__name__)
NUM_LATENT_CHANNELS = 4
NUM_UNET_INPUT_CHANNELS = 7
# TODO: should this be a lookup? it needs to match the conversion script
class_labels_dtype = np.long
# TODO: should this be a lookup or converted? can it vary on ONNX?
text_embeddings_dtype = torch.float32
TORCH_DTYPES = {
"float16": torch.float16,
"float32": torch.float32,
}
def preprocess(image):
@ -98,6 +97,8 @@ class OnnxStableDiffusionUpscalePipeline(StableDiffusionUpscalePipeline):
prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt
)
latents_dtype = TORCH_DTYPES[str(text_embeddings.dtype)]
# 4. Preprocess image
image = preprocess(image)
image = image.cpu()
@ -108,7 +109,7 @@ class OnnxStableDiffusionUpscalePipeline(StableDiffusionUpscalePipeline):
# 5. Add noise to image
noise_level = torch.tensor([noise_level], dtype=torch.long, device=device)
noise = torch.randn(image.shape, generator=generator, device=device, dtype=text_embeddings_dtype)
noise = torch.randn(image.shape, generator=generator, device=device, dtype=latents_dtype)
image = self.low_res_scheduler.add_noise(image, noise, noise_level)
batch_multiplier = 2 if do_classifier_free_guidance else 1
@ -122,7 +123,7 @@ class OnnxStableDiffusionUpscalePipeline(StableDiffusionUpscalePipeline):
NUM_LATENT_CHANNELS,
height,
width,
text_embeddings_dtype,
latents_dtype,
device,
generator,
latents,
@ -142,6 +143,11 @@ class OnnxStableDiffusionUpscalePipeline(StableDiffusionUpscalePipeline):
# 8. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
timestep_dtype = next(
(input.type for input in self.unet.model.get_inputs() if input.name == "timestep"), "tensor(float)"
)
timestep_dtype = ORT_TO_NP_TYPE[timestep_dtype]
# 9. Denoising loop
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
with self.progress_bar(total=num_inference_steps) as progress_bar:
@ -154,14 +160,14 @@ class OnnxStableDiffusionUpscalePipeline(StableDiffusionUpscalePipeline):
latent_model_input = np.concatenate([latent_model_input, image], axis=1)
# timestep to tensor
timestep = np.array([t], dtype=np.float32)
timestep = np.array([t], dtype=timestep_dtype)
# predict the noise residual
noise_pred = self.unet(
sample=latent_model_input,
timestep=timestep,
encoder_hidden_states=text_embeddings,
class_labels=noise_level.astype(class_labels_dtype),
class_labels=noise_level.astype(np.int64),
)[0]
# perform guidance