From 2c479040571c58392dab18d127dfef9b7b18e3db Mon Sep 17 00:00:00 2001
From: Sean Sube
Date: Fri, 24 Mar 2023 08:14:19 -0500
Subject: [PATCH] lint(api): use constant for model filename

---
 api/onnx_web/chain/base.py                    |  7 +++++++
 api/onnx_web/chain/utils.py                   |  7 +++++++
 api/onnx_web/constants.py                     |  2 ++
 api/onnx_web/convert/__main__.py              | 11 ++++++-----
 api/onnx_web/convert/diffusion/diffusers.py   | 15 ++++++++-------
 api/onnx_web/convert/diffusion/lora.py        |  2 +-
 .../convert/diffusion/textual_inversion.py    |  5 +++--
 api/onnx_web/diffusers/load.py                | 11 +++++------
 8 files changed, 39 insertions(+), 21 deletions(-)
 create mode 100644 api/onnx_web/constants.py

diff --git a/api/onnx_web/chain/base.py b/api/onnx_web/chain/base.py
index 06d10ffd..adace06f 100644
--- a/api/onnx_web/chain/base.py
+++ b/api/onnx_web/chain/base.py
@@ -16,6 +16,10 @@ logger = getLogger(__name__)
 
 
 class StageCallback(Protocol):
+    """
+    Definition for a stage job function.
+    """
+
     def __call__(
         self,
         job: WorkerContext,
@@ -25,6 +29,9 @@ class StageCallback(Protocol):
         source: Image.Image,
         **kwargs: Any
     ) -> Image.Image:
+        """
+        Run this stage against a source image.
+        """
         pass
 
 
diff --git a/api/onnx_web/chain/utils.py b/api/onnx_web/chain/utils.py
index f0799020..9df7c5c7 100644
--- a/api/onnx_web/chain/utils.py
+++ b/api/onnx_web/chain/utils.py
@@ -9,7 +9,14 @@ logger = getLogger(__name__)
 
 
 class TileCallback(Protocol):
+    """
+    Definition for a tile job function.
+    """
+
     def __call__(self, image: Image.Image, dims: Tuple[int, int, int]) -> Image.Image:
+        """
+        Run this stage against a single tile.
+        """
         pass
 
 
diff --git a/api/onnx_web/constants.py b/api/onnx_web/constants.py
new file mode 100644
index 00000000..4fe47f98
--- /dev/null
+++ b/api/onnx_web/constants.py
@@ -0,0 +1,2 @@
+ONNX_MODEL = "model.onnx"
+ONNX_WEIGHTS = "weights.pb"
diff --git a/api/onnx_web/convert/__main__.py b/api/onnx_web/convert/__main__.py
index 98cc3b7a..6c77ba0e 100644
--- a/api/onnx_web/convert/__main__.py
+++ b/api/onnx_web/convert/__main__.py
@@ -12,6 +12,7 @@ from onnx import load_model, save_model
 from transformers import CLIPTokenizer
 from yaml import safe_load
 
+from ..constants import ONNX_MODEL, ONNX_WEIGHTS
 from .correction_gfpgan import convert_correction_gfpgan
 from .diffusion.diffusers import convert_diffusion_diffusers
 from .diffusion.lora import blend_loras
@@ -297,7 +298,7 @@ def convert_models(ctx: ConversionContext, args, models: Models):
                         path.join(
                             dest,
                             "text_encoder",
-                            "model.onnx",
+                            ONNX_MODEL,
                         )
                     )
 
@@ -341,13 +342,13 @@ def convert_models(ctx: ConversionContext, args, models: Models):
                         path.join(
                             dest,
                             "text_encoder",
-                            "model.onnx",
+                            ONNX_MODEL,
                         )
                     )
 
                     if "unet" not in blend_models:
                         blend_models["text_encoder"] = load_model(
-                            path.join(dest, "unet", "model.onnx")
+                            path.join(dest, "unet", ONNX_MODEL)
                         )
 
                     # load models if not loaded yet
@@ -377,7 +378,7 @@ def convert_models(ctx: ConversionContext, args, models: Models):
 
                     for name in ["text_encoder", "unet"]:
                         if name in blend_models:
-                            dest_path = path.join(dest, name, "model.onnx")
+                            dest_path = path.join(dest, name, ONNX_MODEL)
                             logger.debug(
                                 "saving blended %s model to %s", name, dest_path
                             )
@@ -386,7 +387,7 @@ def convert_models(ctx: ConversionContext, args, models: Models):
                                 dest_path,
                                 save_as_external_data=True,
                                 all_tensors_to_one_file=True,
-                                location="weights.pb",
+                                location=ONNX_WEIGHTS,
                             )
 
             except Exception:
diff --git a/api/onnx_web/convert/diffusion/diffusers.py b/api/onnx_web/convert/diffusion/diffusers.py
index 6c434a2d..af0b27b2 100644
--- a/api/onnx_web/convert/diffusion/diffusers.py
+++ b/api/onnx_web/convert/diffusion/diffusers.py
@@ -27,6 +27,7 @@ from onnx.shape_inference import infer_shapes_path
 from onnxruntime.transformers.float16 import convert_float_to_float16
 from torch.onnx import export
 
+from ...constants import ONNX_MODEL, ONNX_WEIGHTS
 from ...diffusers.load import optimize_pipeline
 from ...diffusers.pipeline_onnx_stable_diffusion_upscale import (
     OnnxStableDiffusionUpscalePipeline,
 )
@@ -79,7 +80,7 @@ def onnx_export(
             f"{output_file}",
             save_as_external_data=external_data,
             all_tensors_to_one_file=True,
-            location="weights.pb",
+            location=ONNX_WEIGHTS,
         )
 
 
@@ -144,7 +145,7 @@ def convert_diffusion_diffusers(
             None,  # output attentions
             torch.tensor(True).to(device=ctx.training_device, dtype=torch.bool),
         ),
-        output_path=output_path / "text_encoder" / "model.onnx",
+        output_path=output_path / "text_encoder" / ONNX_MODEL,
         ordered_input_names=["input_ids"],
         output_names=["last_hidden_state", "pooler_output", "hidden_states"],
        dynamic_axes={
@@ -169,7 +170,7 @@ def convert_diffusion_diffusers(
     unet_in_channels = pipeline.unet.config.in_channels
     unet_sample_size = pipeline.unet.config.sample_size
 
-    unet_path = output_path / "unet" / "model.onnx"
+    unet_path = output_path / "unet" / ONNX_MODEL
     onnx_export(
         pipeline.unet,
         model_args=(
@@ -207,7 +208,7 @@ def convert_diffusion_diffusers(
         unet_model_path,
         save_as_external_data=True,
         all_tensors_to_one_file=True,
-        location="weights.pb",
+        location=ONNX_WEIGHTS,
         convert_attribute=False,
     )
     del pipeline.unet
@@ -233,7 +234,7 @@ def convert_diffusion_diffusers(
             ).to(device=ctx.training_device, dtype=dtype),
             False,
         ),
-        output_path=output_path / "vae" / "model.onnx",
+        output_path=output_path / "vae" / ONNX_MODEL,
         ordered_input_names=["latent_sample", "return_dict"],
         output_names=["sample"],
         dynamic_axes={
@@ -259,7 +260,7 @@ def convert_diffusion_diffusers(
             ),
             False,
         ),
-        output_path=output_path / "vae_encoder" / "model.onnx",
+        output_path=output_path / "vae_encoder" / ONNX_MODEL,
         ordered_input_names=["sample", "return_dict"],
         output_names=["latent_sample"],
         dynamic_axes={
@@ -282,7 +283,7 @@ def convert_diffusion_diffusers(
             ).to(device=ctx.training_device, dtype=dtype),
             False,
         ),
-        output_path=output_path / "vae_decoder" / "model.onnx",
+        output_path=output_path / "vae_decoder" / ONNX_MODEL,
         ordered_input_names=["latent_sample", "return_dict"],
         output_names=["sample"],
         dynamic_axes={
diff --git a/api/onnx_web/convert/diffusion/lora.py b/api/onnx_web/convert/diffusion/lora.py
index 38d2ea1a..b4ff756e 100644
--- a/api/onnx_web/convert/diffusion/lora.py
+++ b/api/onnx_web/convert/diffusion/lora.py
@@ -55,7 +55,7 @@ def fix_node_name(key: str):
 
 
 def blend_loras(
-    context: ServerContext,
+    _context: ServerContext,
     base_name: Union[str, ModelProto],
     loras: List[Tuple[str, float]],
     model_type: Literal["text_encoder", "unet"],
diff --git a/api/onnx_web/convert/diffusion/textual_inversion.py b/api/onnx_web/convert/diffusion/textual_inversion.py
index f96c7a32..336074e4 100644
--- a/api/onnx_web/convert/diffusion/textual_inversion.py
+++ b/api/onnx_web/convert/diffusion/textual_inversion.py
@@ -7,6 +7,7 @@ import torch
 from onnx import ModelProto, load_model, numpy_helper, save_model
 from transformers import CLIPTokenizer
 
+from ...constants import ONNX_MODEL
 from ...server.context import ServerContext
 from ..utils import ConversionContext, load_tensor
 
@@ -174,7 +175,7 @@ def convert_diffusion_textual_inversion(
     )
 
     encoder_path = path.join(dest_path, "text_encoder")
-    encoder_model = path.join(encoder_path, "model.onnx")
+    encoder_model = path.join(encoder_path, ONNX_MODEL)
     tokenizer_path = path.join(dest_path, "tokenizer")
 
     if (
@@ -187,7 +188,7 @@ def convert_diffusion_textual_inversion(
 
     makedirs(encoder_path, exist_ok=True)
 
-    text_encoder = load_model(path.join(base_model, "text_encoder", "model.onnx"))
+    text_encoder = load_model(path.join(base_model, "text_encoder", ONNX_MODEL))
     tokenizer = CLIPTokenizer.from_pretrained(
         base_model,
         subfolder="tokenizer",
diff --git a/api/onnx_web/diffusers/load.py b/api/onnx_web/diffusers/load.py
index 2c88eb1b..32f86f65 100644
--- a/api/onnx_web/diffusers/load.py
+++ b/api/onnx_web/diffusers/load.py
@@ -25,7 +25,8 @@ from diffusers import (
 from onnx import load_model
 from transformers import CLIPTokenizer
 
-from onnx_web.diffusers.utils import expand_prompt
+from ..constants import ONNX_MODEL
+from ..diffusers.utils import expand_prompt
 
 try:
     from diffusers import DEISMultistepScheduler
@@ -191,7 +192,7 @@ def load_pipeline(
             path.join(server.model_path, "inversion", name)
             for name in inversion_names
         ]
-        text_encoder = load_model(path.join(model, "text_encoder", "model.onnx"))
+        text_encoder = load_model(path.join(model, "text_encoder", ONNX_MODEL))
         tokenizer = CLIPTokenizer.from_pretrained(
             model,
             subfolder="tokenizer",
@@ -233,9 +234,7 @@ def load_pipeline(
             )
 
             # blend and load text encoder
-            text_encoder = text_encoder or path.join(
-                model, "text_encoder", "model.onnx"
-            )
+            text_encoder = text_encoder or path.join(model, "text_encoder", ONNX_MODEL)
             text_encoder = blend_loras(
                 server,
                 text_encoder,
@@ -261,7 +260,7 @@ def load_pipeline(
             # blend and load unet
             blended_unet = blend_loras(
                 server,
-                path.join(model, "unet", "model.onnx"),
+                path.join(model, "unet", ONNX_MODEL),
                 list(zip(lora_models, lora_weights)),
                 "unet",
             )