1
0
Fork 0

apply lint, rename lookup table

This commit is contained in:
Sean Sube 2023-02-16 22:22:46 -06:00
parent f5ae9dd492
commit 388eb640c0
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
2 changed files with 4 additions and 5 deletions

View File

@ -23,7 +23,6 @@ from logging import getLogger
from typing import Dict, List from typing import Dict, List
import huggingface_hub.utils.tqdm import huggingface_hub.utils.tqdm
import safetensors.torch
import torch import torch
from diffusers import ( from diffusers import (
AutoencoderKL, AutoencoderKL,
@ -1147,8 +1146,8 @@ def extract_checkpoint(
extract_ema=False, extract_ema=False,
train_unfrozen=False, train_unfrozen=False,
is_512=True, is_512=True,
config_file: str =None, config_file: str = None,
vae_file: str =None, vae_file: str = None,
): ):
""" """

View File

@ -38,7 +38,7 @@ ORT_TO_NP_TYPE = {
"tensor(double)": np.float64, "tensor(double)": np.float64,
} }
TORCH_DTYPES = { ORT_TO_PT_TYPE = {
"float16": torch.float16, "float16": torch.float16,
"float32": torch.float32, "float32": torch.float32,
} }
@ -112,7 +112,7 @@ class OnnxStableDiffusionUpscalePipeline(StableDiffusionUpscalePipeline):
prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt
) )
latents_dtype = TORCH_DTYPES[str(text_embeddings.dtype)] latents_dtype = ORT_TO_PT_TYPE[str(text_embeddings.dtype)]
# 4. Preprocess image # 4. Preprocess image
image = preprocess(image) image = preprocess(image)