From fa71d87e2c8569a195a989678b1e6e88f75f4ffb Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Tue, 21 Mar 2023 22:19:50 -0500 Subject: [PATCH] apply lint --- api/onnx_web/convert/diffusion/lora.py | 8 ++++++-- api/onnx_web/convert/utils.py | 8 ++++---- api/onnx_web/diffusers/load.py | 7 ++++++- api/onnx_web/server/context.py | 1 - 4 files changed, 16 insertions(+), 8 deletions(-) diff --git a/api/onnx_web/convert/diffusion/lora.py b/api/onnx_web/convert/diffusion/lora.py index da6f06b2..38d2ea1a 100644 --- a/api/onnx_web/convert/diffusion/lora.py +++ b/api/onnx_web/convert/diffusion/lora.py @@ -204,7 +204,9 @@ def blend_loras( logger.trace("blended weight shape: %s", blended.shape) # replace the original initializer - updated_node = numpy_helper.from_array(blended.astype(base_weights.dtype), weight_node.name) + updated_node = numpy_helper.from_array( + blended.astype(base_weights.dtype), weight_node.name + ) del base_model.graph.initializer[weight_idx] base_model.graph.initializer.insert(weight_idx, updated_node) elif matmul_key in fixed_node_names: @@ -233,7 +235,9 @@ def blend_loras( logger.trace("blended weight shape: %s", blended.shape) # replace the original initializer - updated_node = numpy_helper.from_array(blended.astype(base_weights.dtype), matmul_node.name) + updated_node = numpy_helper.from_array( + blended.astype(base_weights.dtype), matmul_node.name + ) del base_model.graph.initializer[matmul_idx] base_model.graph.initializer.insert(matmul_idx, updated_node) else: diff --git a/api/onnx_web/convert/utils.py b/api/onnx_web/convert/utils.py index 290aa513..b2ab00b2 100644 --- a/api/onnx_web/convert/utils.py +++ b/api/onnx_web/convert/utils.py @@ -200,13 +200,13 @@ def remove_prefix(name: str, prefix: str) -> str: def load_torch(name: str, map_location=None) -> Optional[Dict]: try: - logger.debug("loading tensor with Torch JIT: %s", name) - checkpoint = torch.jit.load(name) + logger.debug("loading tensor with Torch: %s", name) + checkpoint = 
torch.load(name, map_location=map_location) except Exception: logger.exception( - "error loading with Torch JIT, falling back to Torch: %s", name + "error loading with Torch, trying with Torch JIT: %s", name ) - checkpoint = torch.load(name, map_location=map_location) + checkpoint = torch.jit.load(name) return checkpoint diff --git a/api/onnx_web/diffusers/load.py b/api/onnx_web/diffusers/load.py index f7ffec0b..2c88eb1b 100644 --- a/api/onnx_web/diffusers/load.py +++ b/api/onnx_web/diffusers/load.py @@ -360,7 +360,12 @@ class UNetWrapper(object): global timestep_dtype timestep_dtype = timestep.dtype - logger.trace("UNet parameter types: %s, %s, %s", sample.dtype, timestep.dtype, encoder_hidden_states.dtype) + logger.trace( + "UNet parameter types: %s, %s, %s", + sample.dtype, + timestep.dtype, + encoder_hidden_states.dtype, + ) if sample.dtype != timestep.dtype: logger.trace("converting UNet sample to timestep dtype") sample = sample.astype(timestep.dtype) diff --git a/api/onnx_web/server/context.py b/api/onnx_web/server/context.py index 773abef9..20cf361b 100644 --- a/api/onnx_web/server/context.py +++ b/api/onnx_web/server/context.py @@ -3,7 +3,6 @@ from os import environ, path from typing import List, Optional import torch -import numpy as np from ..utils import get_boolean from .model_cache import ModelCache