
lint(api): use constant for model filename

Sean Sube 2023-03-24 08:14:19 -05:00
parent 6b4c046867
commit 2c47904057
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
8 changed files with 39 additions and 21 deletions

View File

@@ -16,6 +16,10 @@ logger = getLogger(__name__)
 class StageCallback(Protocol):
+    """
+    Definition for a stage job function.
+    """
+
     def __call__(
         self,
         job: WorkerContext,
@@ -25,6 +29,9 @@ class StageCallback(Protocol):
         source: Image.Image,
         **kwargs: Any
     ) -> Image.Image:
+        """
+        Run this stage against a source image.
+        """
         pass
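For readers unfamiliar with typing.Protocol: any callable whose signature matches satisfies the protocol structurally, with no inheritance required. A minimal sketch with a simplified stand-in signature (the real StageCallback takes more parameters than these hunks show; the names below are illustrative only):

from typing import Any, Protocol

from PIL import Image, ImageFilter


class SimpleStageCallback(Protocol):
    """Simplified stand-in for the StageCallback protocol above."""

    def __call__(self, source: Image.Image, **kwargs: Any) -> Image.Image:
        ...


def blur_stage(source: Image.Image, **kwargs: Any) -> Image.Image:
    # a plain function with a matching signature is a valid stage
    return source.filter(ImageFilter.GaussianBlur(2))


stage: SimpleStageCallback = blur_stage  # passes a structural type check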

View File

@@ -9,7 +9,14 @@ logger = getLogger(__name__)
 class TileCallback(Protocol):
+    """
+    Definition for a tile job function.
+    """
+
     def __call__(self, image: Image.Image, dims: Tuple[int, int, int]) -> Image.Image:
+        """
+        Run this stage against a single tile.
+        """
         pass
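A tile callback runs per tile rather than per image. A rough sketch of a conforming implementation, assuming dims unpacks as (left, top, tile_size) — the hunk only shows the Tuple[int, int, int] type, so that layout is a guess:

from typing import Tuple

from PIL import Image


def grayscale_tile(image: Image.Image, dims: Tuple[int, int, int]) -> Image.Image:
    # assumed layout: (left, top, tile_size); only the type is visible above
    left, top, size = dims
    tile = image.crop((left, top, left + size, top + size))
    # return the processed tile; the caller presumably pastes it back
    return tile.convert("L").convert("RGB")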

View File

@@ -0,0 +1,2 @@
+ONNX_MODEL = "model.onnx"
+ONNX_WEIGHTS = "weights.pb"
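Centralizing the two filenames means every call site imports the same names instead of repeating string literals. A minimal usage sketch — the onnx_web.constants module path is inferred from the relative imports in the hunks below, not stated in this diff:

from os import path

from onnx_web.constants import ONNX_MODEL  # assumed module path


def text_encoder_model(dest: str) -> str:
    # one definition of the filename, shared by every call site
    return path.join(dest, "text_encoder", ONNX_MODEL)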

View File

@@ -12,6 +12,7 @@ from onnx import load_model, save_model
 from transformers import CLIPTokenizer
 from yaml import safe_load
+from ..constants import ONNX_MODEL, ONNX_WEIGHTS
 from .correction_gfpgan import convert_correction_gfpgan
 from .diffusion.diffusers import convert_diffusion_diffusers
 from .diffusion.lora import blend_loras
@@ -297,7 +298,7 @@ def convert_models(ctx: ConversionContext, args, models: Models):
                     path.join(
                         dest,
                         "text_encoder",
-                        "model.onnx",
+                        ONNX_MODEL,
                     )
                 )
@@ -341,13 +342,13 @@ def convert_models(ctx: ConversionContext, args, models: Models):
                     path.join(
                         dest,
                         "text_encoder",
-                        "model.onnx",
+                        ONNX_MODEL,
                     )
                 )
             if "unet" not in blend_models:
                 blend_models["text_encoder"] = load_model(
-                    path.join(dest, "unet", "model.onnx")
+                    path.join(dest, "unet", ONNX_MODEL)
                 )
             # load models if not loaded yet
@@ -377,7 +378,7 @@ def convert_models(ctx: ConversionContext, args, models: Models):
             for name in ["text_encoder", "unet"]:
                 if name in blend_models:
-                    dest_path = path.join(dest, name, "model.onnx")
+                    dest_path = path.join(dest, name, ONNX_MODEL)
                     logger.debug(
                         "saving blended %s model to %s", name, dest_path
                     )
@@ -386,7 +387,7 @@ def convert_models(ctx: ConversionContext, args, models: Models):
                         dest_path,
                         save_as_external_data=True,
                         all_tensors_to_one_file=True,
-                        location="weights.pb",
+                        location=ONNX_WEIGHTS,
                     )
     except Exception:
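The location=ONNX_WEIGHTS argument names the sidecar file that onnx writes when a model's tensors are stored externally, which is how graphs over the 2 GB protobuf limit get saved. A standalone sketch of the same onnx API, with hypothetical paths:

from onnx import load_model, save_model

ONNX_WEIGHTS = "weights.pb"  # mirrors the new constant

model = load_model("unet/model.onnx")  # hypothetical path
save_model(
    model,
    "unet/model.onnx",
    save_as_external_data=True,    # split tensors out of the graph file
    all_tensors_to_one_file=True,  # one weights file, not one per tensor
    location=ONNX_WEIGHTS,         # written relative to the model file
)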

View File

@@ -27,6 +27,7 @@ from onnx.shape_inference import infer_shapes_path
 from onnxruntime.transformers.float16 import convert_float_to_float16
 from torch.onnx import export
+from ...constants import ONNX_MODEL, ONNX_WEIGHTS
 from ...diffusers.load import optimize_pipeline
 from ...diffusers.pipeline_onnx_stable_diffusion_upscale import (
     OnnxStableDiffusionUpscalePipeline,
@@ -79,7 +80,7 @@ def onnx_export(
             f"{output_file}",
             save_as_external_data=external_data,
             all_tensors_to_one_file=True,
-            location="weights.pb",
+            location=ONNX_WEIGHTS,
         )
@@ -144,7 +145,7 @@ def convert_diffusion_diffusers(
             None,  # output attentions
             torch.tensor(True).to(device=ctx.training_device, dtype=torch.bool),
         ),
-        output_path=output_path / "text_encoder" / "model.onnx",
+        output_path=output_path / "text_encoder" / ONNX_MODEL,
         ordered_input_names=["input_ids"],
         output_names=["last_hidden_state", "pooler_output", "hidden_states"],
         dynamic_axes={
@@ -169,7 +170,7 @@ def convert_diffusion_diffusers(
     unet_in_channels = pipeline.unet.config.in_channels
     unet_sample_size = pipeline.unet.config.sample_size
-    unet_path = output_path / "unet" / "model.onnx"
+    unet_path = output_path / "unet" / ONNX_MODEL
     onnx_export(
         pipeline.unet,
         model_args=(
@@ -207,7 +208,7 @@ def convert_diffusion_diffusers(
             unet_model_path,
             save_as_external_data=True,
             all_tensors_to_one_file=True,
-            location="weights.pb",
+            location=ONNX_WEIGHTS,
             convert_attribute=False,
         )
         del pipeline.unet
@@ -233,7 +234,7 @@ def convert_diffusion_diffusers(
             ).to(device=ctx.training_device, dtype=dtype),
             False,
         ),
-        output_path=output_path / "vae" / "model.onnx",
+        output_path=output_path / "vae" / ONNX_MODEL,
         ordered_input_names=["latent_sample", "return_dict"],
         output_names=["sample"],
         dynamic_axes={
@@ -259,7 +260,7 @@ def convert_diffusion_diffusers(
             ),
             False,
         ),
-        output_path=output_path / "vae_encoder" / "model.onnx",
+        output_path=output_path / "vae_encoder" / ONNX_MODEL,
         ordered_input_names=["sample", "return_dict"],
         output_names=["latent_sample"],
         dynamic_axes={
@@ -282,7 +283,7 @@ def convert_diffusion_diffusers(
             ).to(device=ctx.training_device, dtype=dtype),
             False,
         ),
-        output_path=output_path / "vae_decoder" / "model.onnx",
+        output_path=output_path / "vae_decoder" / ONNX_MODEL,
         ordered_input_names=["latent_sample", "return_dict"],
         output_names=["sample"],
        dynamic_axes={
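Note that this converter builds paths with pathlib's / operator rather than os.path.join; both produce the same string, so the constant drops in unchanged. A small sketch with a hypothetical destination:

from pathlib import Path

ONNX_MODEL = "model.onnx"  # mirrors the new constant

output_path = Path("/tmp/converted")  # hypothetical destination
unet_path = output_path / "unet" / ONNX_MODEL
print(unet_path)  # /tmp/converted/unet/model.onnx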

View File

@@ -55,7 +55,7 @@ def fix_node_name(key: str):
 def blend_loras(
-    context: ServerContext,
+    _context: ServerContext,
     base_name: Union[str, ModelProto],
     loras: List[Tuple[str, float]],
     model_type: Literal["text_encoder", "unet"],
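The rename from context to _context is a pure lint fix: a leading underscore is the conventional way to mark a parameter that is accepted for interface compatibility but unused in the body. Schematically (types omitted):

def blend_loras(_context, base_name, loras, model_type):
    # _context is kept so callers keep passing it, but linters no
    # longer flag it as an unused argument
    ...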

View File

@@ -7,6 +7,7 @@ import torch
 from onnx import ModelProto, load_model, numpy_helper, save_model
 from transformers import CLIPTokenizer
+from ...constants import ONNX_MODEL
 from ...server.context import ServerContext
 from ..utils import ConversionContext, load_tensor
@@ -174,7 +175,7 @@ def convert_diffusion_textual_inversion(
     )
     encoder_path = path.join(dest_path, "text_encoder")
-    encoder_model = path.join(encoder_path, "model.onnx")
+    encoder_model = path.join(encoder_path, ONNX_MODEL)
     tokenizer_path = path.join(dest_path, "tokenizer")
     if (
@@ -187,7 +188,7 @@ def convert_diffusion_textual_inversion(
     makedirs(encoder_path, exist_ok=True)
-    text_encoder = load_model(path.join(base_model, "text_encoder", "model.onnx"))
+    text_encoder = load_model(path.join(base_model, "text_encoder", ONNX_MODEL))
     tokenizer = CLIPTokenizer.from_pretrained(
         base_model,
         subfolder="tokenizer",
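The subfolder argument lets transformers load the tokenizer out of a converted-pipeline directory without pointing at the subdirectory directly. A sketch against a hypothetical model directory laid out the way this code expects:

from transformers import CLIPTokenizer

base_model = "/models/stable-diffusion-v1-5"  # hypothetical directory
tokenizer = CLIPTokenizer.from_pretrained(base_model, subfolder="tokenizer")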

View File

@@ -25,7 +25,8 @@ from diffusers import (
 from onnx import load_model
 from transformers import CLIPTokenizer
-from onnx_web.diffusers.utils import expand_prompt
+from ..constants import ONNX_MODEL
+from ..diffusers.utils import expand_prompt
 try:
     from diffusers import DEISMultistepScheduler
@@ -191,7 +192,7 @@ def load_pipeline(
                 path.join(server.model_path, "inversion", name)
                 for name in inversion_names
             ]
-            text_encoder = load_model(path.join(model, "text_encoder", "model.onnx"))
+            text_encoder = load_model(path.join(model, "text_encoder", ONNX_MODEL))
             tokenizer = CLIPTokenizer.from_pretrained(
                 model,
                 subfolder="tokenizer",
@@ -233,9 +234,7 @@ def load_pipeline(
             )
             # blend and load text encoder
-            text_encoder = text_encoder or path.join(
-                model, "text_encoder", "model.onnx"
-            )
+            text_encoder = text_encoder or path.join(model, "text_encoder", ONNX_MODEL)
             text_encoder = blend_loras(
                 server,
                 text_encoder,
@@ -261,7 +260,7 @@ def load_pipeline(
             # blend and load unet
             blended_unet = blend_loras(
                 server,
-                path.join(model, "unet", "model.onnx"),
+                path.join(model, "unet", ONNX_MODEL),
                 list(zip(lora_models, lora_weights)),
                 "unet",
             )
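The collapsed text_encoder line relies on Python's or returning its first truthy operand: a blended in-memory model wins, otherwise the on-disk path is used. In isolation:

from os import path

ONNX_MODEL = "model.onnx"  # mirrors the new constant
model = "/models/base"     # hypothetical model directory

text_encoder = None  # would hold a blended model if inversions were applied
text_encoder = text_encoder or path.join(model, "text_encoder", ONNX_MODEL)
# -> "/models/base/text_encoder/model.onnx"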