apply lint, rename lookup table
This commit is contained in:
parent
f5ae9dd492
commit
388eb640c0
|
@ -23,7 +23,6 @@ from logging import getLogger
|
||||||
from typing import Dict, List
|
from typing import Dict, List
|
||||||
|
|
||||||
import huggingface_hub.utils.tqdm
|
import huggingface_hub.utils.tqdm
|
||||||
import safetensors.torch
|
|
||||||
import torch
|
import torch
|
||||||
from diffusers import (
|
from diffusers import (
|
||||||
AutoencoderKL,
|
AutoencoderKL,
|
||||||
|
@ -1147,8 +1146,8 @@ def extract_checkpoint(
|
||||||
extract_ema=False,
|
extract_ema=False,
|
||||||
train_unfrozen=False,
|
train_unfrozen=False,
|
||||||
is_512=True,
|
is_512=True,
|
||||||
config_file: str =None,
|
config_file: str = None,
|
||||||
vae_file: str =None,
|
vae_file: str = None,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
|
@ -38,7 +38,7 @@ ORT_TO_NP_TYPE = {
|
||||||
"tensor(double)": np.float64,
|
"tensor(double)": np.float64,
|
||||||
}
|
}
|
||||||
|
|
||||||
TORCH_DTYPES = {
|
ORT_TO_PT_TYPE = {
|
||||||
"float16": torch.float16,
|
"float16": torch.float16,
|
||||||
"float32": torch.float32,
|
"float32": torch.float32,
|
||||||
}
|
}
|
||||||
|
@ -112,7 +112,7 @@ class OnnxStableDiffusionUpscalePipeline(StableDiffusionUpscalePipeline):
|
||||||
prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt
|
prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt
|
||||||
)
|
)
|
||||||
|
|
||||||
latents_dtype = TORCH_DTYPES[str(text_embeddings.dtype)]
|
latents_dtype = ORT_TO_PT_TYPE[str(text_embeddings.dtype)]
|
||||||
|
|
||||||
# 4. Preprocess image
|
# 4. Preprocess image
|
||||||
image = preprocess(image)
|
image = preprocess(image)
|
||||||
|
|
Loading…
Reference in New Issue