fix(api): load blending tensors onto CPU
This commit is contained in:
parent
1c631c28d3
commit
33008531e9
|
@ -5,6 +5,8 @@ from .chain import (
|
|||
upscale_resrgan,
|
||||
upscale_stable_diffusion,
|
||||
)
|
||||
from .convert.diffusion.lora import blend_loras
|
||||
from .convert.diffusion.textual_inversion import blend_textual_inversions
|
||||
from .diffusers.load import get_latents_from_seed, load_pipeline, optimize_pipeline
|
||||
from .diffusers.run import (
|
||||
run_blend_pipeline,
|
||||
|
|
|
@ -16,7 +16,7 @@ from onnxruntime import InferenceSession, OrtValue, SessionOptions
|
|||
from safetensors.torch import load_file
|
||||
|
||||
from ...server.context import ServerContext
|
||||
from ..utils import ConversionContext
|
||||
from ..utils import ConversionContext, load_tensor
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
|
@ -61,8 +61,11 @@ def blend_loras(
|
|||
loras: List[Tuple[str, float]],
|
||||
model_type: Literal["text_encoder", "unet"],
|
||||
):
|
||||
# always load to CPU for blending
|
||||
device = torch.device("cpu")
|
||||
|
||||
base_model = base_name if isinstance(base_name, ModelProto) else load(base_name)
|
||||
lora_models = [load_file(name) for name, _weight in loras]
|
||||
lora_models = [load_tensor(name, map_location=device) for name, _weight in loras]
|
||||
|
||||
if model_type == "text_encoder":
|
||||
lora_prefix = "lora_te_"
|
||||
|
|
|
@ -9,7 +9,7 @@ from onnx import ModelProto, load_model, numpy_helper, save_model
|
|||
from transformers import CLIPTokenizer
|
||||
|
||||
from ...server.context import ServerContext
|
||||
from ..utils import ConversionContext
|
||||
from ..utils import ConversionContext, load_tensor
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
|
@ -21,6 +21,8 @@ def blend_textual_inversions(
|
|||
tokenizer: CLIPTokenizer,
|
||||
inversions: List[Tuple[str, float, Optional[str], Optional[str]]],
|
||||
) -> Tuple[ModelProto, CLIPTokenizer]:
|
||||
# always load to CPU for blending
|
||||
device = torch.device("cpu")
|
||||
dtype = np.float
|
||||
embeds = {}
|
||||
|
||||
|
@ -47,19 +49,19 @@ def blend_textual_inversions(
|
|||
with open(token_file, "r") as f:
|
||||
token = base_token or f.read()
|
||||
|
||||
loaded_embeds = torch.load(embeds_file)
|
||||
loaded_embeds = load_tensor(embeds_file, map_location=device)
|
||||
|
||||
# separate token and the embeds
|
||||
trained_token = list(loaded_embeds.keys())[0]
|
||||
|
||||
layer = loaded_embeds[trained_token].cpu().numpy().astype(dtype)
|
||||
layer = loaded_embeds[trained_token].numpy().astype(dtype)
|
||||
layer *= weight
|
||||
if trained_token in embeds:
|
||||
embeds[token] += layer
|
||||
else:
|
||||
embeds[token] = layer
|
||||
elif inversion_format == "embeddings":
|
||||
loaded_embeds = torch.load(name)
|
||||
loaded_embeds = load_tensor(name, map_location=device)
|
||||
|
||||
string_to_token = loaded_embeds["string_to_token"]
|
||||
string_to_param = loaded_embeds["string_to_param"]
|
||||
|
@ -75,7 +77,7 @@ def blend_textual_inversions(
|
|||
|
||||
for i in range(num_tokens):
|
||||
token = f"{base_token}-{i}"
|
||||
layer = trained_embeds[i, :].cpu().numpy().astype(dtype)
|
||||
layer = trained_embeds[i, :].numpy().astype(dtype)
|
||||
layer *= weight
|
||||
|
||||
sum_layer += layer
|
||||
|
|
|
@ -199,9 +199,11 @@ def remove_prefix(name: str, prefix: str) -> str:
|
|||
|
||||
|
||||
def load_tensor(name: str, map_location=None):
|
||||
logger.info("loading model from checkpoint")
|
||||
logger.debug("loading tensor: %s", name)
|
||||
_, extension = path.splitext(name)
|
||||
if extension.lower() == ".safetensors":
|
||||
extension = extension[1:].lower()
|
||||
|
||||
if extension == "safetensors":
|
||||
environ["SAFETENSORS_FAST_GPU"] = "1"
|
||||
try:
|
||||
logger.debug("loading safetensors")
|
||||
|
@ -209,7 +211,7 @@ def load_tensor(name: str, map_location=None):
|
|||
except Exception as e:
|
||||
try:
|
||||
logger.warning(
|
||||
"failed to load as safetensors file, falling back to torch: %s", e
|
||||
"failed to load as safetensors file, falling back to Torch JIT: %s", e
|
||||
)
|
||||
checkpoint = torch.jit.load(name)
|
||||
except Exception as e:
|
||||
|
@ -217,16 +219,17 @@ def load_tensor(name: str, map_location=None):
|
|||
"failed to load with Torch JIT, falling back to PyTorch: %s", e
|
||||
)
|
||||
checkpoint = torch.load(name, map_location=map_location)
|
||||
checkpoint = (
|
||||
checkpoint["state_dict"]
|
||||
if "state_dict" in checkpoint
|
||||
else checkpoint
|
||||
)
|
||||
else:
|
||||
elif extension in ["", "bin", "ckpt", "pt"]:
|
||||
logger.debug("loading ckpt")
|
||||
checkpoint = torch.load(name, map_location=map_location)
|
||||
checkpoint = (
|
||||
checkpoint["state_dict"] if "state_dict" in checkpoint else checkpoint
|
||||
)
|
||||
elif extension in ["onnx", "pt"]:
|
||||
logger.warning("unknown tensor extension, may be ONNX model: %s", extension)
|
||||
checkpoint = torch.load(name, map_location=map_location)
|
||||
else:
|
||||
logger.warning("unknown tensor extension: %s", extension)
|
||||
checkpoint = torch.load(name, map_location=map_location)
|
||||
|
||||
if "state_dict" in checkpoint:
|
||||
checkpoint = checkpoint["state_dict"]
|
||||
|
||||
return checkpoint
|
||||
|
|
Loading…
Reference in New Issue