apply lint, rename lookup table
This commit is contained in:
parent
f5ae9dd492
commit
388eb640c0
|
@ -23,7 +23,6 @@ from logging import getLogger
|
|||
from typing import Dict, List
|
||||
|
||||
import huggingface_hub.utils.tqdm
|
||||
import safetensors.torch
|
||||
import torch
|
||||
from diffusers import (
|
||||
AutoencoderKL,
|
||||
|
@ -1147,8 +1146,8 @@ def extract_checkpoint(
|
|||
extract_ema=False,
|
||||
train_unfrozen=False,
|
||||
is_512=True,
|
||||
config_file: str =None,
|
||||
vae_file: str =None,
|
||||
config_file: str = None,
|
||||
vae_file: str = None,
|
||||
):
|
||||
"""
|
||||
|
||||
|
|
|
@ -38,7 +38,7 @@ ORT_TO_NP_TYPE = {
|
|||
"tensor(double)": np.float64,
|
||||
}
|
||||
|
||||
TORCH_DTYPES = {
|
||||
ORT_TO_PT_TYPE = {
|
||||
"float16": torch.float16,
|
||||
"float32": torch.float32,
|
||||
}
|
||||
|
@ -112,7 +112,7 @@ class OnnxStableDiffusionUpscalePipeline(StableDiffusionUpscalePipeline):
|
|||
prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt
|
||||
)
|
||||
|
||||
latents_dtype = TORCH_DTYPES[str(text_embeddings.dtype)]
|
||||
latents_dtype = ORT_TO_PT_TYPE[str(text_embeddings.dtype)]
|
||||
|
||||
# 4. Preprocess image
|
||||
image = preprocess(image)
|
||||
|
|
Loading…
Reference in New Issue