apply lint
This commit is contained in:
parent
e445d2afaa
commit
fa71d87e2c
|
@ -204,7 +204,9 @@ def blend_loras(
|
|||
logger.trace("blended weight shape: %s", blended.shape)
|
||||
|
||||
# replace the original initializer
|
||||
updated_node = numpy_helper.from_array(blended.astype(base_weights.dtype), weight_node.name)
|
||||
updated_node = numpy_helper.from_array(
|
||||
blended.astype(base_weights.dtype), weight_node.name
|
||||
)
|
||||
del base_model.graph.initializer[weight_idx]
|
||||
base_model.graph.initializer.insert(weight_idx, updated_node)
|
||||
elif matmul_key in fixed_node_names:
|
||||
|
@ -233,7 +235,9 @@ def blend_loras(
|
|||
logger.trace("blended weight shape: %s", blended.shape)
|
||||
|
||||
# replace the original initializer
|
||||
updated_node = numpy_helper.from_array(blended.astype(base_weights.dtype), matmul_node.name)
|
||||
updated_node = numpy_helper.from_array(
|
||||
blended.astype(base_weights.dtype), matmul_node.name
|
||||
)
|
||||
del base_model.graph.initializer[matmul_idx]
|
||||
base_model.graph.initializer.insert(matmul_idx, updated_node)
|
||||
else:
|
||||
|
|
|
@ -200,13 +200,13 @@ def remove_prefix(name: str, prefix: str) -> str:
|
|||
|
||||
def load_torch(name: str, map_location=None) -> Optional[Dict]:
|
||||
try:
|
||||
logger.debug("loading tensor with Torch JIT: %s", name)
|
||||
checkpoint = torch.jit.load(name)
|
||||
logger.debug("loading tensor with Torch: %s", name)
|
||||
checkpoint = torch.load(name, map_location=map_location)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"error loading with Torch JIT, falling back to Torch: %s", name
|
||||
"error loading with Torch JIT, trying with Torch JIT: %s", name
|
||||
)
|
||||
checkpoint = torch.load(name, map_location=map_location)
|
||||
checkpoint = torch.jit.load(name)
|
||||
|
||||
return checkpoint
|
||||
|
||||
|
|
|
@ -360,7 +360,12 @@ class UNetWrapper(object):
|
|||
global timestep_dtype
|
||||
timestep_dtype = timestep.dtype
|
||||
|
||||
logger.trace("UNet parameter types: %s, %s, %s", sample.dtype, timestep.dtype, encoder_hidden_states.dtype)
|
||||
logger.trace(
|
||||
"UNet parameter types: %s, %s, %s",
|
||||
sample.dtype,
|
||||
timestep.dtype,
|
||||
encoder_hidden_states.dtype,
|
||||
)
|
||||
if sample.dtype != timestep.dtype:
|
||||
logger.trace("converting UNet sample to timestep dtype")
|
||||
sample = sample.astype(timestep.dtype)
|
||||
|
|
|
@ -3,7 +3,6 @@ from os import environ, path
|
|||
from typing import List, Optional
|
||||
|
||||
import torch
|
||||
import numpy as np
|
||||
|
||||
from ..utils import get_boolean
|
||||
from .model_cache import ModelCache
|
||||
|
|
Loading…
Reference in New Issue