apply lint
This commit is contained in:
parent
e445d2afaa
commit
fa71d87e2c
|
@ -204,7 +204,9 @@ def blend_loras(
|
||||||
logger.trace("blended weight shape: %s", blended.shape)
|
logger.trace("blended weight shape: %s", blended.shape)
|
||||||
|
|
||||||
# replace the original initializer
|
# replace the original initializer
|
||||||
updated_node = numpy_helper.from_array(blended.astype(base_weights.dtype), weight_node.name)
|
updated_node = numpy_helper.from_array(
|
||||||
|
blended.astype(base_weights.dtype), weight_node.name
|
||||||
|
)
|
||||||
del base_model.graph.initializer[weight_idx]
|
del base_model.graph.initializer[weight_idx]
|
||||||
base_model.graph.initializer.insert(weight_idx, updated_node)
|
base_model.graph.initializer.insert(weight_idx, updated_node)
|
||||||
elif matmul_key in fixed_node_names:
|
elif matmul_key in fixed_node_names:
|
||||||
|
@ -233,7 +235,9 @@ def blend_loras(
|
||||||
logger.trace("blended weight shape: %s", blended.shape)
|
logger.trace("blended weight shape: %s", blended.shape)
|
||||||
|
|
||||||
# replace the original initializer
|
# replace the original initializer
|
||||||
updated_node = numpy_helper.from_array(blended.astype(base_weights.dtype), matmul_node.name)
|
updated_node = numpy_helper.from_array(
|
||||||
|
blended.astype(base_weights.dtype), matmul_node.name
|
||||||
|
)
|
||||||
del base_model.graph.initializer[matmul_idx]
|
del base_model.graph.initializer[matmul_idx]
|
||||||
base_model.graph.initializer.insert(matmul_idx, updated_node)
|
base_model.graph.initializer.insert(matmul_idx, updated_node)
|
||||||
else:
|
else:
|
||||||
|
|
|
@ -200,13 +200,13 @@ def remove_prefix(name: str, prefix: str) -> str:
|
||||||
|
|
||||||
def load_torch(name: str, map_location=None) -> Optional[Dict]:
|
def load_torch(name: str, map_location=None) -> Optional[Dict]:
|
||||||
try:
|
try:
|
||||||
logger.debug("loading tensor with Torch JIT: %s", name)
|
logger.debug("loading tensor with Torch: %s", name)
|
||||||
checkpoint = torch.jit.load(name)
|
checkpoint = torch.load(name, map_location=map_location)
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.exception(
|
logger.exception(
|
||||||
"error loading with Torch JIT, falling back to Torch: %s", name
|
"error loading with Torch JIT, trying with Torch JIT: %s", name
|
||||||
)
|
)
|
||||||
checkpoint = torch.load(name, map_location=map_location)
|
checkpoint = torch.jit.load(name)
|
||||||
|
|
||||||
return checkpoint
|
return checkpoint
|
||||||
|
|
||||||
|
|
|
@ -360,7 +360,12 @@ class UNetWrapper(object):
|
||||||
global timestep_dtype
|
global timestep_dtype
|
||||||
timestep_dtype = timestep.dtype
|
timestep_dtype = timestep.dtype
|
||||||
|
|
||||||
logger.trace("UNet parameter types: %s, %s, %s", sample.dtype, timestep.dtype, encoder_hidden_states.dtype)
|
logger.trace(
|
||||||
|
"UNet parameter types: %s, %s, %s",
|
||||||
|
sample.dtype,
|
||||||
|
timestep.dtype,
|
||||||
|
encoder_hidden_states.dtype,
|
||||||
|
)
|
||||||
if sample.dtype != timestep.dtype:
|
if sample.dtype != timestep.dtype:
|
||||||
logger.trace("converting UNet sample to timestep dtype")
|
logger.trace("converting UNet sample to timestep dtype")
|
||||||
sample = sample.astype(timestep.dtype)
|
sample = sample.astype(timestep.dtype)
|
||||||
|
|
|
@ -3,7 +3,6 @@ from os import environ, path
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
from ..utils import get_boolean
|
from ..utils import get_boolean
|
||||||
from .model_cache import ModelCache
|
from .model_cache import ModelCache
|
||||||
|
|
Loading…
Reference in New Issue