lint(api): use constant for model filename
This commit is contained in:
parent
6b4c046867
commit
2c47904057
|
@ -16,6 +16,10 @@ logger = getLogger(__name__)
|
|||
|
||||
|
||||
class StageCallback(Protocol):
|
||||
"""
|
||||
Definition for a stage job function.
|
||||
"""
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
job: WorkerContext,
|
||||
|
@ -25,6 +29,9 @@ class StageCallback(Protocol):
|
|||
source: Image.Image,
|
||||
**kwargs: Any
|
||||
) -> Image.Image:
|
||||
"""
|
||||
Run this stage against a source image.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
|
|
|
@ -9,7 +9,14 @@ logger = getLogger(__name__)
|
|||
|
||||
|
||||
class TileCallback(Protocol):
|
||||
"""
|
||||
Definition for a tile job function.
|
||||
"""
|
||||
|
||||
def __call__(self, image: Image.Image, dims: Tuple[int, int, int]) -> Image.Image:
|
||||
"""
|
||||
Run this stage against a single tile.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
|
|
|
@ -0,0 +1,2 @@
|
|||
ONNX_MODEL = "model.onnx"
|
||||
ONNX_WEIGHTS = "weights.pb"
|
|
@ -12,6 +12,7 @@ from onnx import load_model, save_model
|
|||
from transformers import CLIPTokenizer
|
||||
from yaml import safe_load
|
||||
|
||||
from ..constants import ONNX_MODEL, ONNX_WEIGHTS
|
||||
from .correction_gfpgan import convert_correction_gfpgan
|
||||
from .diffusion.diffusers import convert_diffusion_diffusers
|
||||
from .diffusion.lora import blend_loras
|
||||
|
@ -297,7 +298,7 @@ def convert_models(ctx: ConversionContext, args, models: Models):
|
|||
path.join(
|
||||
dest,
|
||||
"text_encoder",
|
||||
"model.onnx",
|
||||
ONNX_MODEL,
|
||||
)
|
||||
)
|
||||
|
||||
|
@ -341,13 +342,13 @@ def convert_models(ctx: ConversionContext, args, models: Models):
|
|||
path.join(
|
||||
dest,
|
||||
"text_encoder",
|
||||
"model.onnx",
|
||||
ONNX_MODEL,
|
||||
)
|
||||
)
|
||||
|
||||
if "unet" not in blend_models:
|
||||
blend_models["text_encoder"] = load_model(
|
||||
path.join(dest, "unet", "model.onnx")
|
||||
path.join(dest, "unet", ONNX_MODEL)
|
||||
)
|
||||
|
||||
# load models if not loaded yet
|
||||
|
@ -377,7 +378,7 @@ def convert_models(ctx: ConversionContext, args, models: Models):
|
|||
|
||||
for name in ["text_encoder", "unet"]:
|
||||
if name in blend_models:
|
||||
dest_path = path.join(dest, name, "model.onnx")
|
||||
dest_path = path.join(dest, name, ONNX_MODEL)
|
||||
logger.debug(
|
||||
"saving blended %s model to %s", name, dest_path
|
||||
)
|
||||
|
@ -386,7 +387,7 @@ def convert_models(ctx: ConversionContext, args, models: Models):
|
|||
dest_path,
|
||||
save_as_external_data=True,
|
||||
all_tensors_to_one_file=True,
|
||||
location="weights.pb",
|
||||
location=ONNX_WEIGHTS,
|
||||
)
|
||||
|
||||
except Exception:
|
||||
|
|
|
@ -27,6 +27,7 @@ from onnx.shape_inference import infer_shapes_path
|
|||
from onnxruntime.transformers.float16 import convert_float_to_float16
|
||||
from torch.onnx import export
|
||||
|
||||
from ...constants import ONNX_MODEL, ONNX_WEIGHTS
|
||||
from ...diffusers.load import optimize_pipeline
|
||||
from ...diffusers.pipeline_onnx_stable_diffusion_upscale import (
|
||||
OnnxStableDiffusionUpscalePipeline,
|
||||
|
@ -79,7 +80,7 @@ def onnx_export(
|
|||
f"{output_file}",
|
||||
save_as_external_data=external_data,
|
||||
all_tensors_to_one_file=True,
|
||||
location="weights.pb",
|
||||
location=ONNX_WEIGHTS,
|
||||
)
|
||||
|
||||
|
||||
|
@ -144,7 +145,7 @@ def convert_diffusion_diffusers(
|
|||
None, # output attentions
|
||||
torch.tensor(True).to(device=ctx.training_device, dtype=torch.bool),
|
||||
),
|
||||
output_path=output_path / "text_encoder" / "model.onnx",
|
||||
output_path=output_path / "text_encoder" / ONNX_MODEL,
|
||||
ordered_input_names=["input_ids"],
|
||||
output_names=["last_hidden_state", "pooler_output", "hidden_states"],
|
||||
dynamic_axes={
|
||||
|
@ -169,7 +170,7 @@ def convert_diffusion_diffusers(
|
|||
|
||||
unet_in_channels = pipeline.unet.config.in_channels
|
||||
unet_sample_size = pipeline.unet.config.sample_size
|
||||
unet_path = output_path / "unet" / "model.onnx"
|
||||
unet_path = output_path / "unet" / ONNX_MODEL
|
||||
onnx_export(
|
||||
pipeline.unet,
|
||||
model_args=(
|
||||
|
@ -207,7 +208,7 @@ def convert_diffusion_diffusers(
|
|||
unet_model_path,
|
||||
save_as_external_data=True,
|
||||
all_tensors_to_one_file=True,
|
||||
location="weights.pb",
|
||||
location=ONNX_WEIGHTS,
|
||||
convert_attribute=False,
|
||||
)
|
||||
del pipeline.unet
|
||||
|
@ -233,7 +234,7 @@ def convert_diffusion_diffusers(
|
|||
).to(device=ctx.training_device, dtype=dtype),
|
||||
False,
|
||||
),
|
||||
output_path=output_path / "vae" / "model.onnx",
|
||||
output_path=output_path / "vae" / ONNX_MODEL,
|
||||
ordered_input_names=["latent_sample", "return_dict"],
|
||||
output_names=["sample"],
|
||||
dynamic_axes={
|
||||
|
@ -259,7 +260,7 @@ def convert_diffusion_diffusers(
|
|||
),
|
||||
False,
|
||||
),
|
||||
output_path=output_path / "vae_encoder" / "model.onnx",
|
||||
output_path=output_path / "vae_encoder" / ONNX_MODEL,
|
||||
ordered_input_names=["sample", "return_dict"],
|
||||
output_names=["latent_sample"],
|
||||
dynamic_axes={
|
||||
|
@ -282,7 +283,7 @@ def convert_diffusion_diffusers(
|
|||
).to(device=ctx.training_device, dtype=dtype),
|
||||
False,
|
||||
),
|
||||
output_path=output_path / "vae_decoder" / "model.onnx",
|
||||
output_path=output_path / "vae_decoder" / ONNX_MODEL,
|
||||
ordered_input_names=["latent_sample", "return_dict"],
|
||||
output_names=["sample"],
|
||||
dynamic_axes={
|
||||
|
|
|
@ -55,7 +55,7 @@ def fix_node_name(key: str):
|
|||
|
||||
|
||||
def blend_loras(
|
||||
context: ServerContext,
|
||||
_context: ServerContext,
|
||||
base_name: Union[str, ModelProto],
|
||||
loras: List[Tuple[str, float]],
|
||||
model_type: Literal["text_encoder", "unet"],
|
||||
|
|
|
@ -7,6 +7,7 @@ import torch
|
|||
from onnx import ModelProto, load_model, numpy_helper, save_model
|
||||
from transformers import CLIPTokenizer
|
||||
|
||||
from ...constants import ONNX_MODEL
|
||||
from ...server.context import ServerContext
|
||||
from ..utils import ConversionContext, load_tensor
|
||||
|
||||
|
@ -174,7 +175,7 @@ def convert_diffusion_textual_inversion(
|
|||
)
|
||||
|
||||
encoder_path = path.join(dest_path, "text_encoder")
|
||||
encoder_model = path.join(encoder_path, "model.onnx")
|
||||
encoder_model = path.join(encoder_path, ONNX_MODEL)
|
||||
tokenizer_path = path.join(dest_path, "tokenizer")
|
||||
|
||||
if (
|
||||
|
@ -187,7 +188,7 @@ def convert_diffusion_textual_inversion(
|
|||
|
||||
makedirs(encoder_path, exist_ok=True)
|
||||
|
||||
text_encoder = load_model(path.join(base_model, "text_encoder", "model.onnx"))
|
||||
text_encoder = load_model(path.join(base_model, "text_encoder", ONNX_MODEL))
|
||||
tokenizer = CLIPTokenizer.from_pretrained(
|
||||
base_model,
|
||||
subfolder="tokenizer",
|
||||
|
|
|
@ -25,7 +25,8 @@ from diffusers import (
|
|||
from onnx import load_model
|
||||
from transformers import CLIPTokenizer
|
||||
|
||||
from onnx_web.diffusers.utils import expand_prompt
|
||||
from ..constants import ONNX_MODEL
|
||||
from ..diffusers.utils import expand_prompt
|
||||
|
||||
try:
|
||||
from diffusers import DEISMultistepScheduler
|
||||
|
@ -191,7 +192,7 @@ def load_pipeline(
|
|||
path.join(server.model_path, "inversion", name)
|
||||
for name in inversion_names
|
||||
]
|
||||
text_encoder = load_model(path.join(model, "text_encoder", "model.onnx"))
|
||||
text_encoder = load_model(path.join(model, "text_encoder", ONNX_MODEL))
|
||||
tokenizer = CLIPTokenizer.from_pretrained(
|
||||
model,
|
||||
subfolder="tokenizer",
|
||||
|
@ -233,9 +234,7 @@ def load_pipeline(
|
|||
)
|
||||
|
||||
# blend and load text encoder
|
||||
text_encoder = text_encoder or path.join(
|
||||
model, "text_encoder", "model.onnx"
|
||||
)
|
||||
text_encoder = text_encoder or path.join(model, "text_encoder", ONNX_MODEL)
|
||||
text_encoder = blend_loras(
|
||||
server,
|
||||
text_encoder,
|
||||
|
@ -261,7 +260,7 @@ def load_pipeline(
|
|||
# blend and load unet
|
||||
blended_unet = blend_loras(
|
||||
server,
|
||||
path.join(model, "unet", "model.onnx"),
|
||||
path.join(model, "unet", ONNX_MODEL),
|
||||
list(zip(lora_models, lora_weights)),
|
||||
"unet",
|
||||
)
|
||||
|
|
Loading…
Reference in New Issue