fix(api): apply fp16 optimizations to LoRA and Textual Inversion blending
This commit is contained in:
parent
4f6574c88e
commit
0315a8cbc6
|
@ -97,8 +97,7 @@ def convert_diffusion_diffusers(
|
|||
single_vae = model.get("single_vae")
|
||||
replace_vae = model.get("vae")
|
||||
|
||||
torch_half = "torch-fp16" in ctx.optimizations
|
||||
torch_dtype = torch.float16 if torch_half else torch.float32
|
||||
torch_dtype = ctx.torch_dtype()
|
||||
logger.debug("using Torch dtype %s for pipeline", torch_dtype)
|
||||
|
||||
dest_path = path.join(ctx.model_path, name)
|
||||
|
|
|
@ -62,6 +62,7 @@ def blend_loras(
|
|||
):
|
||||
# always load to CPU for blending
|
||||
device = torch.device("cpu")
|
||||
dtype = context.torch_dtype()
|
||||
|
||||
base_model = base_name if isinstance(base_name, ModelProto) else load(base_name)
|
||||
lora_models = [load_tensor(name, map_location=device) for name, _weight in loras]
|
||||
|
@ -88,11 +89,11 @@ def blend_loras(
|
|||
"blending weights for keys: %s, %s, %s", key, up_key, alpha_key
|
||||
)
|
||||
|
||||
down_weight = lora_model[key].to(dtype=torch.float32)
|
||||
up_weight = lora_model[up_key].to(dtype=torch.float32)
|
||||
down_weight = lora_model[key].to(dtype=dtype)
|
||||
up_weight = lora_model[up_key].to(dtype=dtype)
|
||||
|
||||
dim = down_weight.size()[0]
|
||||
alpha = lora_model.get(alpha_key, dim).to(torch.float32).numpy()
|
||||
alpha = lora_model.get(alpha_key, dim).to(dtype).numpy()
|
||||
|
||||
try:
|
||||
if len(up_weight.size()) == 2:
|
||||
|
@ -203,7 +204,7 @@ def blend_loras(
|
|||
logger.trace("blended weight shape: %s", blended.shape)
|
||||
|
||||
# replace the original initializer
|
||||
updated_node = numpy_helper.from_array(blended, weight_node.name)
|
||||
updated_node = numpy_helper.from_array(blended.astype(base_weights.dtype), weight_node.name)
|
||||
del base_model.graph.initializer[weight_idx]
|
||||
base_model.graph.initializer.insert(weight_idx, updated_node)
|
||||
elif matmul_key in fixed_node_names:
|
||||
|
@ -232,7 +233,7 @@ def blend_loras(
|
|||
logger.trace("blended weight shape: %s", blended.shape)
|
||||
|
||||
# replace the original initializer
|
||||
updated_node = numpy_helper.from_array(blended, matmul_node.name)
|
||||
updated_node = numpy_helper.from_array(blended.astype(base_weights.dtype), matmul_node.name)
|
||||
del base_model.graph.initializer[matmul_idx]
|
||||
base_model.graph.initializer.insert(matmul_idx, updated_node)
|
||||
else:
|
||||
|
|
|
@ -22,7 +22,7 @@ def blend_textual_inversions(
|
|||
) -> Tuple[ModelProto, CLIPTokenizer]:
|
||||
# always load to CPU for blending
|
||||
device = torch.device("cpu")
|
||||
dtype = np.float
|
||||
dtype = context.numpy_dtype()
|
||||
embeds = {}
|
||||
|
||||
for name, weight, base_token, inversion_format in inversions:
|
||||
|
@ -149,7 +149,7 @@ def blend_textual_inversions(
|
|||
== "text_model.embeddings.token_embedding.weight"
|
||||
):
|
||||
new_initializer = numpy_helper.from_array(
|
||||
embedding_weights.astype(np.float32), embedding_node.name
|
||||
embedding_weights.astype(dtype), embedding_node.name
|
||||
)
|
||||
logger.trace("new initializer data type: %s", new_initializer.data_type)
|
||||
del text_encoder.graph.initializer[i]
|
||||
|
|
|
@ -2,6 +2,9 @@ from logging import getLogger
|
|||
from os import environ, path
|
||||
from typing import List, Optional
|
||||
|
||||
import torch
|
||||
import numpy as np
|
||||
|
||||
from ..utils import get_boolean
|
||||
from .model_cache import ModelCache
|
||||
|
||||
|
@ -77,3 +80,15 @@ class ServerContext:
|
|||
job_limit=int(environ.get("ONNX_WEB_JOB_LIMIT", DEFAULT_JOB_LIMIT)),
|
||||
memory_limit=memory_limit,
|
||||
)
|
||||
|
||||
def torch_dtype(self):
|
||||
if "torch-fp16" in self.optimizations:
|
||||
return torch.float16
|
||||
else:
|
||||
return torch.float32
|
||||
|
||||
def numpy_dtype(self):
|
||||
if "torch-fp16" in self.optimizations or "onnx-fp16" in self.optimizations:
|
||||
return np.float16
|
||||
else:
|
||||
return np.float32
|
Loading…
Reference in New Issue