fix(api): apply fp16 optimizations to LoRA and Textual Inversion blending

parent 4f6574c88e
commit 0315a8cbc6
@@ -97,8 +97,7 @@ def convert_diffusion_diffusers(
     single_vae = model.get("single_vae")
     replace_vae = model.get("vae")

-    torch_half = "torch-fp16" in ctx.optimizations
-    torch_dtype = torch.float16 if torch_half else torch.float32
+    torch_dtype = ctx.torch_dtype()
     logger.debug("using Torch dtype %s for pipeline", torch_dtype)

     dest_path = path.join(ctx.model_path, name)
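The two removed lines duplicated dtype-selection logic that now lives in the new ServerContext.torch_dtype() helper (added at the end of this commit), so every conversion path resolves the pipeline dtype the same way. A hedged sketch of how the resolved dtype typically feeds a diffusers pipeline load in a converter like this one; the model id is illustrative, not from this commit:

    import torch
    from diffusers import StableDiffusionPipeline

    torch_dtype = torch.float16  # what ctx.torch_dtype() returns when "torch-fp16" is enabled
    pipe = StableDiffusionPipeline.from_pretrained(
        "runwayml/stable-diffusion-v1-5",  # illustrative model id
        torch_dtype=torch_dtype,
    )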
@@ -62,6 +62,7 @@ def blend_loras(
 ):
     # always load to CPU for blending
     device = torch.device("cpu")
+    dtype = context.torch_dtype()

     base_model = base_name if isinstance(base_name, ModelProto) else load(base_name)
     lora_models = [load_tensor(name, map_location=device) for name, _weight in loras]
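blend_loras keeps all blending on the CPU and now takes its working dtype from the context instead of hard-coding fp32. A minimal sketch of the load step, assuming a safetensors checkpoint; onnx-web's load_tensor abstracts over checkpoint formats, and this shows only one possible path:

    import torch
    from safetensors.torch import load_file

    device = torch.device("cpu")  # blending always happens on CPU
    lora_model = load_file("lora.safetensors", device="cpu")  # illustrative path
    # individual tensors are cast per key later, using context.torch_dtype()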
@@ -88,11 +89,11 @@ def blend_loras(
                 "blending weights for keys: %s, %s, %s", key, up_key, alpha_key
             )

-            down_weight = lora_model[key].to(dtype=torch.float32)
-            up_weight = lora_model[up_key].to(dtype=torch.float32)
+            down_weight = lora_model[key].to(dtype=dtype)
+            up_weight = lora_model[up_key].to(dtype=dtype)

             dim = down_weight.size()[0]
-            alpha = lora_model.get(alpha_key, dim).to(torch.float32).numpy()
+            alpha = lora_model.get(alpha_key, dim).to(dtype).numpy()

             try:
                 if len(up_weight.size()) == 2:
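With the casts above, the low-rank weights stay in the context dtype from the moment they leave the checkpoint. For the 2D (Linear/MatMul) case, the blend follows the standard LoRA reconstruction, W' = W + weight * (alpha / dim) * (up @ down). A self-contained sketch with illustrative names and toy shapes:

    import numpy as np

    def blend_linear(base, up, down, alpha, dim, weight):
        # standard LoRA reconstruction for 2D weights: the low-rank product
        # up @ down is scaled by alpha/dim and by the user-selected weight
        scale = weight * (alpha / dim)
        return base + scale * (up @ down)

    # toy shapes: a rank-4 LoRA applied to an 8x8 base weight
    base = np.zeros((8, 8), dtype=np.float16)
    up = np.ones((8, 4), dtype=np.float16)
    down = np.ones((4, 8), dtype=np.float16)
    blended = blend_linear(base, up, down, alpha=4.0, dim=4, weight=0.8)
    print(blended.shape)  # (8, 8)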
@@ -203,7 +204,7 @@ def blend_loras(
                 logger.trace("blended weight shape: %s", blended.shape)

                 # replace the original initializer
-                updated_node = numpy_helper.from_array(blended, weight_node.name)
+                updated_node = numpy_helper.from_array(blended.astype(base_weights.dtype), weight_node.name)
                 del base_model.graph.initializer[weight_idx]
                 base_model.graph.initializer.insert(weight_idx, updated_node)
             elif matmul_key in fixed_node_names:
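The .astype(base_weights.dtype) cast is the substance of the fp16 fix here: blending may run at a different precision than the graph stores, and writing a mismatched initializer back would leave an fp16 graph holding an fp32 tensor. A hedged sketch of the delete-and-reinsert pattern, with replace_initializer as a hypothetical helper name; the same pattern is applied to the MatMul branch in the next hunk:

    import numpy as np
    from onnx import numpy_helper

    def replace_initializer(graph, idx, blended):
        # keep the name and stored dtype of the initializer being replaced
        old = graph.initializer[idx]
        base = numpy_helper.to_array(old)
        updated = numpy_helper.from_array(blended.astype(base.dtype), old.name)
        # repeated protobuf fields do not support item assignment,
        # so delete the old entry and insert the new one at the same index
        del graph.initializer[idx]
        graph.initializer.insert(idx, updated)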
@@ -232,7 +233,7 @@ def blend_loras(
                 logger.trace("blended weight shape: %s", blended.shape)

                 # replace the original initializer
-                updated_node = numpy_helper.from_array(blended, matmul_node.name)
+                updated_node = numpy_helper.from_array(blended.astype(base_weights.dtype), matmul_node.name)
                 del base_model.graph.initializer[matmul_idx]
                 base_model.graph.initializer.insert(matmul_idx, updated_node)
             else:
@@ -22,7 +22,7 @@ def blend_textual_inversions(
 ) -> Tuple[ModelProto, CLIPTokenizer]:
     # always load to CPU for blending
     device = torch.device("cpu")
-    dtype = np.float
+    dtype = context.numpy_dtype()
     embeds = {}

     for name, weight, base_token, inversion_format in inversions:
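The old dtype = np.float was both deprecated (the np.float alias was removed in NumPy 1.24) and unconditionally fp32; context.numpy_dtype() keys off the same optimization flags as the Torch side. A sketch of how the per-token accumulation might use it, with a hypothetical inversions list for illustration:

    import numpy as np

    dtype = np.float16  # what context.numpy_dtype() returns with fp16 enabled
    embeds = {}

    # illustrative (token, weight, vector) triples; real ones come from checkpoints
    inversions = [("<style>", 1.0, np.ones(768)), ("<style>", 0.5, np.ones(768))]
    for token, weight, vector in inversions:
        vector = vector.astype(dtype) * weight
        # repeated tokens are blended by accumulating their weighted vectors
        embeds[token] = embeds.get(token, np.zeros_like(vector)) + vector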
@@ -149,7 +149,7 @@ def blend_textual_inversions(
            == "text_model.embeddings.token_embedding.weight"
         ):
             new_initializer = numpy_helper.from_array(
-                embedding_weights.astype(np.float32), embedding_node.name
+                embedding_weights.astype(dtype), embedding_node.name
             )
             logger.trace("new initializer data type: %s", new_initializer.data_type)
             del text_encoder.graph.initializer[i]
@@ -2,6 +2,9 @@ from logging import getLogger
 from os import environ, path
 from typing import List, Optional

+import torch
+import numpy as np
+
 from ..utils import get_boolean
 from .model_cache import ModelCache

@@ -77,3 +80,15 @@ class ServerContext:
             job_limit=int(environ.get("ONNX_WEB_JOB_LIMIT", DEFAULT_JOB_LIMIT)),
             memory_limit=memory_limit,
         )
+
+    def torch_dtype(self):
+        if "torch-fp16" in self.optimizations:
+            return torch.float16
+        else:
+            return torch.float32
+
+    def numpy_dtype(self):
+        if "torch-fp16" in self.optimizations or "onnx-fp16" in self.optimizations:
+            return np.float16
+        else:
+            return np.float32
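Note the asymmetry between the two helpers: torch_dtype() only honors "torch-fp16" (Torch-side blending precision), while numpy_dtype() also returns fp16 when only "onnx-fp16" is set, so initializers written back into an fp16-converted ONNX graph match its storage type. A usage sketch; the constructor call is hypothetical, since ServerContext is normally built from environment variables:

    context = ServerContext(optimizations=["onnx-fp16"])  # hypothetical construction
    assert context.torch_dtype() == torch.float32  # Torch blending stays fp32
    assert context.numpy_dtype() == np.float16     # ONNX initializers go to fp16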