
fix(api): apply fp16 optimizations to LoRA and Textual Inversion blending

Sean Sube 2023-03-21 21:45:27 -05:00
parent 4f6574c88e
commit 0315a8cbc6
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
4 changed files with 24 additions and 9 deletions

View File

@@ -97,8 +97,7 @@ def convert_diffusion_diffusers(
     single_vae = model.get("single_vae")
     replace_vae = model.get("vae")
 
-    torch_half = "torch-fp16" in ctx.optimizations
-    torch_dtype = torch.float16 if torch_half else torch.float32
+    torch_dtype = ctx.torch_dtype()
     logger.debug("using Torch dtype %s for pipeline", torch_dtype)
 
     dest_path = path.join(ctx.model_path, name)
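
The converter now asks the server context for its Torch dtype instead of repeating the optimization-flag check inline. A minimal sketch of how that dtype typically reaches a diffusers pipeline load; the from_pretrained call and the source variable here are illustrative assumptions, not part of this diff:

    from diffusers import StableDiffusionPipeline

    # ctx is a ServerContext; torch_dtype() is added later in this commit
    torch_dtype = ctx.torch_dtype()  # torch.float16 when "torch-fp16" is enabled
    pipeline = StableDiffusionPipeline.from_pretrained(source, torch_dtype=torch_dtype)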

View File

@@ -62,6 +62,7 @@ def blend_loras(
 ):
     # always load to CPU for blending
     device = torch.device("cpu")
+    dtype = context.torch_dtype()
 
     base_model = base_name if isinstance(base_name, ModelProto) else load(base_name)
     lora_models = [load_tensor(name, map_location=device) for name, _weight in loras]
@@ -88,11 +89,11 @@ def blend_loras(
             "blending weights for keys: %s, %s, %s", key, up_key, alpha_key
         )
 
-        down_weight = lora_model[key].to(dtype=torch.float32)
-        up_weight = lora_model[up_key].to(dtype=torch.float32)
+        down_weight = lora_model[key].to(dtype=dtype)
+        up_weight = lora_model[up_key].to(dtype=dtype)
 
         dim = down_weight.size()[0]
-        alpha = lora_model.get(alpha_key, dim).to(torch.float32).numpy()
+        alpha = lora_model.get(alpha_key, dim).to(dtype).numpy()
 
         try:
             if len(up_weight.size()) == 2:
@@ -203,7 +204,7 @@ def blend_loras(
                 logger.trace("blended weight shape: %s", blended.shape)
 
                 # replace the original initializer
-                updated_node = numpy_helper.from_array(blended, weight_node.name)
+                updated_node = numpy_helper.from_array(blended.astype(base_weights.dtype), weight_node.name)
                 del base_model.graph.initializer[weight_idx]
                 base_model.graph.initializer.insert(weight_idx, updated_node)
             elif matmul_key in fixed_node_names:
@@ -232,7 +233,7 @@ def blend_loras(
                 logger.trace("blended weight shape: %s", blended.shape)
 
                 # replace the original initializer
-                updated_node = numpy_helper.from_array(blended, matmul_node.name)
+                updated_node = numpy_helper.from_array(blended.astype(base_weights.dtype), matmul_node.name)
                 del base_model.graph.initializer[matmul_idx]
                 base_model.graph.initializer.insert(matmul_idx, updated_node)
             else:
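
Casting back with blended.astype(base_weights.dtype) is what makes the fp16 path work here: NumPy promotes mixed-precision arithmetic to float32, so without the cast an fp16 model would get float32 initializers written back into its graph. A small standalone sketch of that promotion, with illustrative variable names:

    import numpy as np

    base_weights = np.ones((4, 4), dtype=np.float16)  # fp16 initializer from the ONNX graph
    lora_delta = np.ones((4, 4), dtype=np.float32)    # blended LoRA weights
    blended = base_weights + lora_delta               # NumPy promotes the result to float32
    assert blended.dtype == np.float32
    blended = blended.astype(base_weights.dtype)      # cast back before rebuilding the node
    assert blended.dtype == np.float16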

View File

@@ -22,7 +22,7 @@ def blend_textual_inversions(
 ) -> Tuple[ModelProto, CLIPTokenizer]:
     # always load to CPU for blending
     device = torch.device("cpu")
-    dtype = np.float
+    dtype = context.numpy_dtype()
     embeds = {}
 
     for name, weight, base_token, inversion_format in inversions:
@@ -149,7 +149,7 @@ def blend_textual_inversions(
             == "text_model.embeddings.token_embedding.weight"
         ):
             new_initializer = numpy_helper.from_array(
-                embedding_weights.astype(np.float32), embedding_node.name
+                embedding_weights.astype(dtype), embedding_node.name
             )
             logger.trace("new initializer data type: %s", new_initializer.data_type)
             del text_encoder.graph.initializer[i]
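
Besides honoring the fp16 setting, this change also removes a use of np.float, which was only an alias for the builtin float (float64); it was deprecated in NumPy 1.20 and removed in NumPy 1.24, so the old line fails on current NumPy. The replacement resolves to an explicit dtype, as in this sketch mirroring the numpy_dtype helper added below:

    import numpy as np

    optimizations = ["onnx-fp16"]  # illustrative server configuration
    if "torch-fp16" in optimizations or "onnx-fp16" in optimizations:
        dtype = np.float16
    else:
        dtype = np.float32
    assert dtype == np.float16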

View File

@@ -2,6 +2,9 @@ from logging import getLogger
 from os import environ, path
 from typing import List, Optional
 
+import torch
+import numpy as np
+
 from ..utils import get_boolean
 from .model_cache import ModelCache
 
@@ -77,3 +80,15 @@ class ServerContext:
             job_limit=int(environ.get("ONNX_WEB_JOB_LIMIT", DEFAULT_JOB_LIMIT)),
             memory_limit=memory_limit,
         )
+
+    def torch_dtype(self):
+        if "torch-fp16" in self.optimizations:
+            return torch.float16
+        else:
+            return torch.float32
+
+    def numpy_dtype(self):
+        if "torch-fp16" in self.optimizations or "onnx-fp16" in self.optimizations:
+            return np.float16
+        else:
+            return np.float32
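
A quick sketch of how the two new helpers behave; constructing ServerContext with only an optimizations keyword is a simplification, since the real constructor takes several other arguments not shown in this diff:

    # hypothetical construction; only the optimizations list matters for the helpers
    ctx = ServerContext(optimizations=["torch-fp16"])
    assert ctx.torch_dtype() == torch.float16
    assert ctx.numpy_dtype() == np.float16  # torch-fp16 implies fp16 numpy blending too

    ctx = ServerContext(optimizations=[])
    assert ctx.torch_dtype() == torch.float32
    assert ctx.numpy_dtype() == np.float32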