apply lint

Sean Sube 2023-03-21 22:19:50 -05:00
parent e445d2afaa
commit fa71d87e2c
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
4 changed files with 16 additions and 8 deletions

@@ -204,7 +204,9 @@ def blend_loras(
             logger.trace("blended weight shape: %s", blended.shape)

             # replace the original initializer
-            updated_node = numpy_helper.from_array(blended.astype(base_weights.dtype), weight_node.name)
+            updated_node = numpy_helper.from_array(
+                blended.astype(base_weights.dtype), weight_node.name
+            )
             del base_model.graph.initializer[weight_idx]
             base_model.graph.initializer.insert(weight_idx, updated_node)
elif matmul_key in fixed_node_names:
@@ -233,7 +235,9 @@ def blend_loras(
             logger.trace("blended weight shape: %s", blended.shape)

             # replace the original initializer
-            updated_node = numpy_helper.from_array(blended.astype(base_weights.dtype), matmul_node.name)
+            updated_node = numpy_helper.from_array(
+                blended.astype(base_weights.dtype), matmul_node.name
+            )
             del base_model.graph.initializer[matmul_idx]
             base_model.graph.initializer.insert(matmul_idx, updated_node)
         else:
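
Both hunks above only re-wrap a long call, but the pattern they touch is worth isolating: ONNX stores initializers in a repeated protobuf field that supports del and insert but not item assignment, hence the delete-and-reinsert by index. A minimal standalone sketch of that pattern (replace_initializer is a hypothetical helper, not part of this commit):

import numpy as np
import onnx
from onnx import numpy_helper


def replace_initializer(model: onnx.ModelProto, name: str, blended: np.ndarray) -> None:
    # graph.initializer is a repeated protobuf field; search it by name
    for idx, init in enumerate(model.graph.initializer):
        if init.name == name:
            base = numpy_helper.to_array(init)
            # cast to the original dtype so downstream nodes still type-check
            updated = numpy_helper.from_array(blended.astype(base.dtype), name)
            # repeated fields do not allow item assignment, so delete the old
            # entry and insert the new one at the same index, as the diff does
            del model.graph.initializer[idx]
            model.graph.initializer.insert(idx, updated)
            return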

@@ -200,13 +200,13 @@ def remove_prefix(name: str, prefix: str) -> str:


 def load_torch(name: str, map_location=None) -> Optional[Dict]:
     try:
-        logger.debug("loading tensor with Torch JIT: %s", name)
-        checkpoint = torch.jit.load(name)
+        logger.debug("loading tensor with Torch: %s", name)
+        checkpoint = torch.load(name, map_location=map_location)
     except Exception:
         logger.exception(
-            "error loading with Torch JIT, falling back to Torch: %s", name
+            "error loading with Torch, trying with Torch JIT: %s", name
         )
-        checkpoint = torch.load(name, map_location=map_location)
+        checkpoint = torch.jit.load(name)

     return checkpoint
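
Beyond formatting, this hunk reverses the load order: plain torch.load is tried first and TorchScript becomes the fallback, matching the updated log messages. Assembled into a standalone sketch for readability (the logger setup is assumed; the body follows the post-commit code):

from logging import getLogger
from typing import Dict, Optional

import torch

logger = getLogger(__name__)


def load_torch(name: str, map_location=None) -> Optional[Dict]:
    try:
        logger.debug("loading tensor with Torch: %s", name)
        checkpoint = torch.load(name, map_location=map_location)
    except Exception:
        logger.exception("error loading with Torch, trying with Torch JIT: %s", name)
        # torch.load cannot read TorchScript archives, so retry with jit
        checkpoint = torch.jit.load(name)
    return checkpoint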

@@ -360,7 +360,12 @@ class UNetWrapper(object):
         global timestep_dtype
         timestep_dtype = timestep.dtype

-        logger.trace("UNet parameter types: %s, %s, %s", sample.dtype, timestep.dtype, encoder_hidden_states.dtype)
+        logger.trace(
+            "UNet parameter types: %s, %s, %s",
+            sample.dtype,
+            timestep.dtype,
+            encoder_hidden_states.dtype,
+        )
         if sample.dtype != timestep.dtype:
             logger.trace("converting UNet sample to timestep dtype")
             sample = sample.astype(timestep.dtype)
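
The re-wrapped trace call sits beside the dtype coercion it reports on: presumably because the exported UNet declares fixed input types, the wrapper aligns sample with timestep before running the model. A toy sketch of just that check, with hypothetical stand-in arrays:

import numpy as np

# stand-ins for the wrapper's sample and timestep parameters
sample = np.zeros((1, 4, 64, 64), dtype=np.float32)
timestep = np.array([1.0], dtype=np.float16)

# same coercion as the diff: convert the sample to the timestep dtype
if sample.dtype != timestep.dtype:
    sample = sample.astype(timestep.dtype)

assert sample.dtype == np.float16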

@@ -3,7 +3,6 @@ from os import environ, path
 from typing import List, Optional

 import torch
-import numpy as np

 from ..utils import get_boolean
 from .model_cache import ModelCache