1
0
Fork 0

apply lint

This commit is contained in:
Sean Sube 2023-03-21 22:19:50 -05:00
parent e445d2afaa
commit fa71d87e2c
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
4 changed files with 16 additions and 8 deletions

View File

@ -204,7 +204,9 @@ def blend_loras(
logger.trace("blended weight shape: %s", blended.shape) logger.trace("blended weight shape: %s", blended.shape)
# replace the original initializer # replace the original initializer
updated_node = numpy_helper.from_array(blended.astype(base_weights.dtype), weight_node.name) updated_node = numpy_helper.from_array(
blended.astype(base_weights.dtype), weight_node.name
)
del base_model.graph.initializer[weight_idx] del base_model.graph.initializer[weight_idx]
base_model.graph.initializer.insert(weight_idx, updated_node) base_model.graph.initializer.insert(weight_idx, updated_node)
elif matmul_key in fixed_node_names: elif matmul_key in fixed_node_names:
@ -233,7 +235,9 @@ def blend_loras(
logger.trace("blended weight shape: %s", blended.shape) logger.trace("blended weight shape: %s", blended.shape)
# replace the original initializer # replace the original initializer
updated_node = numpy_helper.from_array(blended.astype(base_weights.dtype), matmul_node.name) updated_node = numpy_helper.from_array(
blended.astype(base_weights.dtype), matmul_node.name
)
del base_model.graph.initializer[matmul_idx] del base_model.graph.initializer[matmul_idx]
base_model.graph.initializer.insert(matmul_idx, updated_node) base_model.graph.initializer.insert(matmul_idx, updated_node)
else: else:

View File

@ -200,13 +200,13 @@ def remove_prefix(name: str, prefix: str) -> str:
def load_torch(name: str, map_location=None) -> Optional[Dict]: def load_torch(name: str, map_location=None) -> Optional[Dict]:
try: try:
logger.debug("loading tensor with Torch JIT: %s", name) logger.debug("loading tensor with Torch: %s", name)
checkpoint = torch.jit.load(name) checkpoint = torch.load(name, map_location=map_location)
except Exception: except Exception:
logger.exception( logger.exception(
"error loading with Torch JIT, falling back to Torch: %s", name "error loading with Torch JIT, trying with Torch JIT: %s", name
) )
checkpoint = torch.load(name, map_location=map_location) checkpoint = torch.jit.load(name)
return checkpoint return checkpoint

View File

@ -360,7 +360,12 @@ class UNetWrapper(object):
global timestep_dtype global timestep_dtype
timestep_dtype = timestep.dtype timestep_dtype = timestep.dtype
logger.trace("UNet parameter types: %s, %s, %s", sample.dtype, timestep.dtype, encoder_hidden_states.dtype) logger.trace(
"UNet parameter types: %s, %s, %s",
sample.dtype,
timestep.dtype,
encoder_hidden_states.dtype,
)
if sample.dtype != timestep.dtype: if sample.dtype != timestep.dtype:
logger.trace("converting UNet sample to timestep dtype") logger.trace("converting UNet sample to timestep dtype")
sample = sample.astype(timestep.dtype) sample = sample.astype(timestep.dtype)

View File

@ -3,7 +3,6 @@ from os import environ, path
from typing import List, Optional from typing import List, Optional
import torch import torch
import numpy as np
from ..utils import get_boolean from ..utils import get_boolean
from .model_cache import ModelCache from .model_cache import ModelCache