fix(api): load blending tensors onto CPU

Sean Sube, 2023-03-19 15:13:54 -05:00
parent 1c631c28d3
commit 33008531e9
Signed by: ssube (GPG Key ID: 3EED7B957D362AF1)

4 changed files with 29 additions and 19 deletions
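
The change is small but load-bearing: every tensor that feeds the blending code is now loaded with an explicit CPU map_location, so the later .numpy() calls succeed no matter which device the checkpoint was saved from. A minimal sketch of that pattern using only public PyTorch APIs; the file name is hypothetical:

```python
import torch

# map_location relocates every tensor in the checkpoint onto the CPU at
# load time, even if it was serialized from a CUDA device
device = torch.device("cpu")
state = torch.load("embedding.pt", map_location=device)  # hypothetical file

for key, tensor in state.items():
    # already on CPU, so .numpy() works without an extra .cpu() copy
    print(key, tensor.numpy().shape)
```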

View File

@@ -5,6 +5,8 @@ from .chain import (
     upscale_resrgan,
     upscale_stable_diffusion,
 )
+from .convert.diffusion.lora import blend_loras
+from .convert.diffusion.textual_inversion import blend_textual_inversions
 from .diffusers.load import get_latents_from_seed, load_pipeline, optimize_pipeline
 from .diffusers.run import (
     run_blend_pipeline,

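The two new imports re-export the blending helpers from what looks like the package's __init__ module, so callers can reach them without knowing the convert module layout. A hedged one-liner of what that enables; the top-level package name is an assumption:

```python
# hypothetical usage; assumes the package root imports as onnx_web
from onnx_web import blend_loras, blend_textual_inversions
```
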
View File

@@ -16,7 +16,7 @@ from onnxruntime import InferenceSession, OrtValue, SessionOptions
 from safetensors.torch import load_file

 from ...server.context import ServerContext
-from ..utils import ConversionContext
+from ..utils import ConversionContext, load_tensor

 logger = getLogger(__name__)
@@ -61,8 +61,11 @@ def blend_loras(
     loras: List[Tuple[str, float]],
     model_type: Literal["text_encoder", "unet"],
 ):
+    # always load to CPU for blending
+    device = torch.device("cpu")
+
     base_model = base_name if isinstance(base_name, ModelProto) else load(base_name)
-    lora_models = [load_file(name) for name, _weight in loras]
+    lora_models = [load_tensor(name, map_location=device) for name, _weight in loras]

     if model_type == "text_encoder":
         lora_prefix = "lora_te_"

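Replacing safetensors.torch.load_file with the shared load_tensor helper lets blend_loras accept .pt and .ckpt LoRAs alongside .safetensors, and guarantees every model is on the CPU before its weights are mixed. A self-contained sketch of weighted blending on CPU tensors; the helper name, key layout, and float32 dtype here are assumptions, not onnx-web's exact scheme:

```python
import numpy as np
import torch

device = torch.device("cpu")

def blend_weighted(models, key):
    # sum weight * tensor for one key across several loaded state dicts
    total = None
    for state, weight in models:
        # tensors loaded with map_location="cpu" convert without a copy
        layer = state[key].to(device).numpy().astype(np.float32) * weight
        total = layer if total is None else total + layer
    return total

# hypothetical inputs: two LoRA state dicts with a matching key
first = {"lora_up": torch.randn(4, 8)}
second = {"lora_up": torch.randn(4, 8)}
blended = blend_weighted([(first, 0.6), (second, 0.4)], "lora_up")
```
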
View File

@@ -9,7 +9,7 @@ from onnx import ModelProto, load_model, numpy_helper, save_model
 from transformers import CLIPTokenizer

 from ...server.context import ServerContext
-from ..utils import ConversionContext
+from ..utils import ConversionContext, load_tensor

 logger = getLogger(__name__)
@@ -21,6 +21,8 @@ def blend_textual_inversions(
     tokenizer: CLIPTokenizer,
     inversions: List[Tuple[str, float, Optional[str], Optional[str]]],
 ) -> Tuple[ModelProto, CLIPTokenizer]:
+    # always load to CPU for blending
+    device = torch.device("cpu")
     dtype = np.float
     embeds = {}
@@ -47,19 +49,19 @@ def blend_textual_inversions(
             with open(token_file, "r") as f:
                 token = base_token or f.read()

-            loaded_embeds = torch.load(embeds_file)
+            loaded_embeds = load_tensor(embeds_file, map_location=device)

             # separate token and the embeds
             trained_token = list(loaded_embeds.keys())[0]
-            layer = loaded_embeds[trained_token].cpu().numpy().astype(dtype)
+            layer = loaded_embeds[trained_token].numpy().astype(dtype)
             layer *= weight

             if trained_token in embeds:
                 embeds[token] += layer
             else:
                 embeds[token] = layer
         elif inversion_format == "embeddings":
-            loaded_embeds = torch.load(name)
+            loaded_embeds = load_tensor(name, map_location=device)

             string_to_token = loaded_embeds["string_to_token"]
             string_to_param = loaded_embeds["string_to_param"]
@@ -75,7 +77,7 @@ def blend_textual_inversions(
             for i in range(num_tokens):
                 token = f"{base_token}-{i}"
-                layer = trained_embeds[i, :].cpu().numpy().astype(dtype)
+                layer = trained_embeds[i, :].numpy().astype(dtype)
                 layer *= weight
                 sum_layer += layer

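In the "embeddings" branch above, the checkpoint carries a single (num_tokens, embedding_dim) matrix, and each row becomes its own "{base_token}-{i}" token with a running sum kept alongside. A self-contained sketch of that loop, with a synthetic tensor standing in for the loaded parameter; the shapes and float32 dtype are assumptions:

```python
import numpy as np
import torch

device = torch.device("cpu")
weight = 0.8
base_token = "my-style"  # hypothetical token name

# stand-in for the single value in string_to_param: (num_tokens, dim)
trained_embeds = torch.randn(4, 768, device=device)

embeds = {}
sum_layer = np.zeros(trained_embeds.shape[1], dtype=np.float32)
for i in range(trained_embeds.shape[0]):
    # rows are already CPU tensors, so no .cpu() hop before .numpy()
    layer = trained_embeds[i, :].numpy().astype(np.float32) * weight
    embeds[f"{base_token}-{i}"] = layer
    sum_layer += layer
```
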
View File

@@ -199,9 +199,11 @@ def remove_prefix(name: str, prefix: str) -> str:
 def load_tensor(name: str, map_location=None):
     logger.info("loading model from checkpoint")
     logger.debug("loading tensor: %s", name)
     _, extension = path.splitext(name)
-    if extension.lower() == ".safetensors":
+    extension = extension[1:].lower()
+
+    if extension == "safetensors":
         environ["SAFETENSORS_FAST_GPU"] = "1"
         try:
             logger.debug("loading safetensors")
@@ -209,7 +211,7 @@ def load_tensor(name: str, map_location=None):
         except Exception as e:
             try:
                 logger.warning(
-                    "failed to load as safetensors file, falling back to torch: %s", e
+                    "failed to load as safetensors file, falling back to Torch JIT: %s", e
                 )
                 checkpoint = torch.jit.load(name)
             except Exception as e:
@@ -217,16 +219,17 @@ def load_tensor(name: str, map_location=None):
                     "failed to load with Torch JIT, falling back to PyTorch: %s", e
                 )
                 checkpoint = torch.load(name, map_location=map_location)
-                checkpoint = (
-                    checkpoint["state_dict"]
-                    if "state_dict" in checkpoint
-                    else checkpoint
-                )
-    else:
+    elif extension in ["", "bin", "ckpt", "pt"]:
         logger.debug("loading ckpt")
         checkpoint = torch.load(name, map_location=map_location)
-        checkpoint = (
-            checkpoint["state_dict"] if "state_dict" in checkpoint else checkpoint
-        )
+    elif extension in ["onnx", "pt"]:
+        logger.warning("unknown tensor extension, may be ONNX model: %s", extension)
+        checkpoint = torch.load(name, map_location=map_location)
+    else:
+        logger.warning("unknown tensor extension: %s", extension)
+        checkpoint = torch.load(name, map_location=map_location)
+
+    if "state_dict" in checkpoint:
+        checkpoint = checkpoint["state_dict"]

     return checkpoint
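
After this change, load_tensor dispatches on the normalized extension, falls back from safetensors to TorchScript to plain torch.load, and unwraps a nested state_dict once at the end instead of in each branch. (Note that "pt" appears in both extension lists in the diff, so the second branch can only ever match "onnx".) A condensed sketch of that chain using only public torch and safetensors APIs; the safetensors call is assumed since the hunk elides it, and this illustrates the pattern rather than being a drop-in replacement:

```python
from os import path

import torch
from safetensors.torch import load_file

def load_any(name: str, map_location=None):
    # normalize ".SafeTensors" -> "safetensors"; "" for extensionless files
    extension = path.splitext(name)[1][1:].lower()
    if extension == "safetensors":
        try:
            return load_file(name, device="cpu")
        except Exception:
            try:
                # TorchScript archives load whole, not as state dicts
                return torch.jit.load(name)
            except Exception:
                checkpoint = torch.load(name, map_location=map_location)
    else:
        # "", "bin", "ckpt", "pt", and unknown extensions all land here
        checkpoint = torch.load(name, map_location=map_location)
    # training checkpoints often nest the weights under "state_dict"
    if isinstance(checkpoint, dict) and "state_dict" in checkpoint:
        checkpoint = checkpoint["state_dict"]
    return checkpoint
```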