
fix(api): apply fp16 optimizations to LoRA and Textual Inversion blending

Sean Sube 2023-03-21 21:45:27 -05:00
parent 4f6574c88e
commit 0315a8cbc6
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
4 changed files with 24 additions and 9 deletions

View File

@@ -97,8 +97,7 @@ def convert_diffusion_diffusers(
     single_vae = model.get("single_vae")
     replace_vae = model.get("vae")

-    torch_half = "torch-fp16" in ctx.optimizations
-    torch_dtype = torch.float16 if torch_half else torch.float32
+    torch_dtype = ctx.torch_dtype()
     logger.debug("using Torch dtype %s for pipeline", torch_dtype)

     dest_path = path.join(ctx.model_path, name)
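
Note: the removed torch_half check and the new ctx.torch_dtype() call select the same dtype; the helper, added to ServerContext at the end of this commit, just centralizes the decision. A runnable sketch of that equivalence, using a plain list in place of the real context object (purely illustrative, not code from the repository):

    import torch

    optimizations = ["torch-fp16"]  # stand-in for ctx.optimizations

    # old, inline form
    torch_half = "torch-fp16" in optimizations
    old_dtype = torch.float16 if torch_half else torch.float32

    # new, centralized form (see the ServerContext.torch_dtype() hunk below)
    new_dtype = torch.float16 if "torch-fp16" in optimizations else torch.float32

    assert old_dtype == new_dtype == torch.float16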

View File

@@ -62,6 +62,7 @@ def blend_loras(
 ):
     # always load to CPU for blending
     device = torch.device("cpu")
+    dtype = context.torch_dtype()

     base_model = base_name if isinstance(base_name, ModelProto) else load(base_name)
     lora_models = [load_tensor(name, map_location=device) for name, _weight in loras]
@@ -88,11 +89,11 @@ def blend_loras(
                 "blending weights for keys: %s, %s, %s", key, up_key, alpha_key
             )

-            down_weight = lora_model[key].to(dtype=torch.float32)
-            up_weight = lora_model[up_key].to(dtype=torch.float32)
+            down_weight = lora_model[key].to(dtype=dtype)
+            up_weight = lora_model[up_key].to(dtype=dtype)

             dim = down_weight.size()[0]
-            alpha = lora_model.get(alpha_key, dim).to(torch.float32).numpy()
+            alpha = lora_model.get(alpha_key, dim).to(dtype).numpy()

             try:
                 if len(up_weight.size()) == 2:
@@ -203,7 +204,7 @@ def blend_loras(
             logger.trace("blended weight shape: %s", blended.shape)

             # replace the original initializer
-            updated_node = numpy_helper.from_array(blended, weight_node.name)
+            updated_node = numpy_helper.from_array(blended.astype(base_weights.dtype), weight_node.name)
             del base_model.graph.initializer[weight_idx]
             base_model.graph.initializer.insert(weight_idx, updated_node)
         elif matmul_key in fixed_node_names:
@@ -232,7 +233,7 @@ def blend_loras(
             logger.trace("blended weight shape: %s", blended.shape)

             # replace the original initializer
-            updated_node = numpy_helper.from_array(blended, matmul_node.name)
+            updated_node = numpy_helper.from_array(blended.astype(base_weights.dtype), matmul_node.name)
             del base_model.graph.initializer[matmul_idx]
             base_model.graph.initializer.insert(matmul_idx, updated_node)
         else:
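
Note: blend_loras now loads LoRA tensors in the context dtype and, before writing the blended result back into the ONNX graph, casts it to the dtype of the original initializer, so an fp16 model stays fp16 even when the blending math runs at higher precision. A self-contained sketch of that cast-back step (toy shapes and values, not taken from the repository):

    import numpy as np
    from onnx import numpy_helper

    # toy initializer standing in for a UNet weight
    base_weights = np.ones((4, 4), dtype=np.float16)
    weight_node = numpy_helper.from_array(base_weights, "toy_weight")

    # blend a LoRA delta at float32, as the math may run at higher precision
    lora_delta = np.full((4, 4), 0.25, dtype=np.float32)
    blended = base_weights.astype(np.float32) + 0.8 * lora_delta

    # cast back to the initializer's original dtype before rebuilding it
    updated_node = numpy_helper.from_array(
        blended.astype(base_weights.dtype), weight_node.name
    )
    assert updated_node.data_type == weight_node.data_type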

View File

@@ -22,7 +22,7 @@ def blend_textual_inversions(
 ) -> Tuple[ModelProto, CLIPTokenizer]:
     # always load to CPU for blending
     device = torch.device("cpu")
-    dtype = np.float
+    dtype = context.numpy_dtype()
     embeds = {}

     for name, weight, base_token, inversion_format in inversions:
@@ -149,7 +149,7 @@ def blend_textual_inversions(
                 == "text_model.embeddings.token_embedding.weight"
             ):
                 new_initializer = numpy_helper.from_array(
-                    embedding_weights.astype(np.float32), embedding_node.name
+                    embedding_weights.astype(dtype), embedding_node.name
                 )
                 logger.trace("new initializer data type: %s", new_initializer.data_type)
                 del text_encoder.graph.initializer[i]
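
Note: the embedding dtype now follows the server optimizations instead of the hard-coded np.float (a deprecated alias that newer NumPy versions have removed outright), so the rebuilt token embedding matches the rest of an fp16 text encoder. A small illustration with made-up shapes; the initializer name is the one from the diff above, everything else is toy data:

    import numpy as np
    from onnx import numpy_helper

    embedding_weights = np.random.rand(8, 4).astype(np.float32)  # toy embedding table
    dtype = np.float16  # what numpy_dtype() returns when an fp16 optimization is enabled

    new_initializer = numpy_helper.from_array(
        embedding_weights.astype(dtype),
        "text_model.embeddings.token_embedding.weight",
    )
    print(new_initializer.data_type)  # 10, i.e. TensorProto.FLOAT16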

View File

@@ -2,6 +2,9 @@ from logging import getLogger
 from os import environ, path
 from typing import List, Optional

+import torch
+import numpy as np
+
 from ..utils import get_boolean
 from .model_cache import ModelCache

@@ -77,3 +80,15 @@ class ServerContext:
             job_limit=int(environ.get("ONNX_WEB_JOB_LIMIT", DEFAULT_JOB_LIMIT)),
             memory_limit=memory_limit,
         )
+
+    def torch_dtype(self):
+        if "torch-fp16" in self.optimizations:
+            return torch.float16
+        else:
+            return torch.float32
+
+    def numpy_dtype(self):
+        if "torch-fp16" in self.optimizations or "onnx-fp16" in self.optimizations:
+            return np.float16
+        else:
+            return np.float32
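
Note: both helpers key off the optimizations list, so conversion and blending code can ask one place for a matching Torch or NumPy dtype; onnx-fp16 affects only the NumPy side, presumably because it targets the exported ONNX graph rather than the Torch pipeline. A hedged usage sketch; the constructor call is simplified and may not match ServerContext's real signature:

    context = ServerContext(optimizations=["torch-fp16"])
    context.torch_dtype()  # torch.float16
    context.numpy_dtype()  # numpy.float16

    context = ServerContext(optimizations=["onnx-fp16"])
    context.torch_dtype()  # torch.float32
    context.numpy_dtype()  # numpy.float16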