fix(api): load blending tensors onto CPU
This commit is contained in:
parent
1c631c28d3
commit
33008531e9
|
@ -5,6 +5,8 @@ from .chain import (
|
||||||
upscale_resrgan,
|
upscale_resrgan,
|
||||||
upscale_stable_diffusion,
|
upscale_stable_diffusion,
|
||||||
)
|
)
|
||||||
|
from .convert.diffusion.lora import blend_loras
|
||||||
|
from .convert.diffusion.textual_inversion import blend_textual_inversions
|
||||||
from .diffusers.load import get_latents_from_seed, load_pipeline, optimize_pipeline
|
from .diffusers.load import get_latents_from_seed, load_pipeline, optimize_pipeline
|
||||||
from .diffusers.run import (
|
from .diffusers.run import (
|
||||||
run_blend_pipeline,
|
run_blend_pipeline,
|
||||||
|
|
|
@ -16,7 +16,7 @@ from onnxruntime import InferenceSession, OrtValue, SessionOptions
|
||||||
from safetensors.torch import load_file
|
from safetensors.torch import load_file
|
||||||
|
|
||||||
from ...server.context import ServerContext
|
from ...server.context import ServerContext
|
||||||
from ..utils import ConversionContext
|
from ..utils import ConversionContext, load_tensor
|
||||||
|
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
@ -61,8 +61,11 @@ def blend_loras(
|
||||||
loras: List[Tuple[str, float]],
|
loras: List[Tuple[str, float]],
|
||||||
model_type: Literal["text_encoder", "unet"],
|
model_type: Literal["text_encoder", "unet"],
|
||||||
):
|
):
|
||||||
|
# always load to CPU for blending
|
||||||
|
device = torch.device("cpu")
|
||||||
|
|
||||||
base_model = base_name if isinstance(base_name, ModelProto) else load(base_name)
|
base_model = base_name if isinstance(base_name, ModelProto) else load(base_name)
|
||||||
lora_models = [load_file(name) for name, _weight in loras]
|
lora_models = [load_tensor(name, map_location=device) for name, _weight in loras]
|
||||||
|
|
||||||
if model_type == "text_encoder":
|
if model_type == "text_encoder":
|
||||||
lora_prefix = "lora_te_"
|
lora_prefix = "lora_te_"
|
||||||
|
|
|
@ -9,7 +9,7 @@ from onnx import ModelProto, load_model, numpy_helper, save_model
|
||||||
from transformers import CLIPTokenizer
|
from transformers import CLIPTokenizer
|
||||||
|
|
||||||
from ...server.context import ServerContext
|
from ...server.context import ServerContext
|
||||||
from ..utils import ConversionContext
|
from ..utils import ConversionContext, load_tensor
|
||||||
|
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
@ -21,6 +21,8 @@ def blend_textual_inversions(
|
||||||
tokenizer: CLIPTokenizer,
|
tokenizer: CLIPTokenizer,
|
||||||
inversions: List[Tuple[str, float, Optional[str], Optional[str]]],
|
inversions: List[Tuple[str, float, Optional[str], Optional[str]]],
|
||||||
) -> Tuple[ModelProto, CLIPTokenizer]:
|
) -> Tuple[ModelProto, CLIPTokenizer]:
|
||||||
|
# always load to CPU for blending
|
||||||
|
device = torch.device("cpu")
|
||||||
dtype = np.float
|
dtype = np.float
|
||||||
embeds = {}
|
embeds = {}
|
||||||
|
|
||||||
|
@ -47,19 +49,19 @@ def blend_textual_inversions(
|
||||||
with open(token_file, "r") as f:
|
with open(token_file, "r") as f:
|
||||||
token = base_token or f.read()
|
token = base_token or f.read()
|
||||||
|
|
||||||
loaded_embeds = torch.load(embeds_file)
|
loaded_embeds = load_tensor(embeds_file, map_location=device)
|
||||||
|
|
||||||
# separate token and the embeds
|
# separate token and the embeds
|
||||||
trained_token = list(loaded_embeds.keys())[0]
|
trained_token = list(loaded_embeds.keys())[0]
|
||||||
|
|
||||||
layer = loaded_embeds[trained_token].cpu().numpy().astype(dtype)
|
layer = loaded_embeds[trained_token].numpy().astype(dtype)
|
||||||
layer *= weight
|
layer *= weight
|
||||||
if trained_token in embeds:
|
if trained_token in embeds:
|
||||||
embeds[token] += layer
|
embeds[token] += layer
|
||||||
else:
|
else:
|
||||||
embeds[token] = layer
|
embeds[token] = layer
|
||||||
elif inversion_format == "embeddings":
|
elif inversion_format == "embeddings":
|
||||||
loaded_embeds = torch.load(name)
|
loaded_embeds = load_tensor(name, map_location=device)
|
||||||
|
|
||||||
string_to_token = loaded_embeds["string_to_token"]
|
string_to_token = loaded_embeds["string_to_token"]
|
||||||
string_to_param = loaded_embeds["string_to_param"]
|
string_to_param = loaded_embeds["string_to_param"]
|
||||||
|
@ -75,7 +77,7 @@ def blend_textual_inversions(
|
||||||
|
|
||||||
for i in range(num_tokens):
|
for i in range(num_tokens):
|
||||||
token = f"{base_token}-{i}"
|
token = f"{base_token}-{i}"
|
||||||
layer = trained_embeds[i, :].cpu().numpy().astype(dtype)
|
layer = trained_embeds[i, :].numpy().astype(dtype)
|
||||||
layer *= weight
|
layer *= weight
|
||||||
|
|
||||||
sum_layer += layer
|
sum_layer += layer
|
||||||
|
|
|
@ -199,9 +199,11 @@ def remove_prefix(name: str, prefix: str) -> str:
|
||||||
|
|
||||||
|
|
||||||
def load_tensor(name: str, map_location=None):
|
def load_tensor(name: str, map_location=None):
|
||||||
logger.info("loading model from checkpoint")
|
logger.debug("loading tensor: %s", name)
|
||||||
_, extension = path.splitext(name)
|
_, extension = path.splitext(name)
|
||||||
if extension.lower() == ".safetensors":
|
extension = extension[1:].lower()
|
||||||
|
|
||||||
|
if extension == "safetensors":
|
||||||
environ["SAFETENSORS_FAST_GPU"] = "1"
|
environ["SAFETENSORS_FAST_GPU"] = "1"
|
||||||
try:
|
try:
|
||||||
logger.debug("loading safetensors")
|
logger.debug("loading safetensors")
|
||||||
|
@ -209,7 +211,7 @@ def load_tensor(name: str, map_location=None):
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
try:
|
try:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"failed to load as safetensors file, falling back to torch: %s", e
|
"failed to load as safetensors file, falling back to Torch JIT: %s", e
|
||||||
)
|
)
|
||||||
checkpoint = torch.jit.load(name)
|
checkpoint = torch.jit.load(name)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
@ -217,16 +219,17 @@ def load_tensor(name: str, map_location=None):
|
||||||
"failed to load with Torch JIT, falling back to PyTorch: %s", e
|
"failed to load with Torch JIT, falling back to PyTorch: %s", e
|
||||||
)
|
)
|
||||||
checkpoint = torch.load(name, map_location=map_location)
|
checkpoint = torch.load(name, map_location=map_location)
|
||||||
checkpoint = (
|
elif extension in ["", "bin", "ckpt", "pt"]:
|
||||||
checkpoint["state_dict"]
|
|
||||||
if "state_dict" in checkpoint
|
|
||||||
else checkpoint
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
logger.debug("loading ckpt")
|
logger.debug("loading ckpt")
|
||||||
checkpoint = torch.load(name, map_location=map_location)
|
checkpoint = torch.load(name, map_location=map_location)
|
||||||
checkpoint = (
|
elif extension in ["onnx", "pt"]:
|
||||||
checkpoint["state_dict"] if "state_dict" in checkpoint else checkpoint
|
logger.warning("unknown tensor extension, may be ONNX model: %s", extension)
|
||||||
)
|
checkpoint = torch.load(name, map_location=map_location)
|
||||||
|
else:
|
||||||
|
logger.warning("unknown tensor extension: %s", extension)
|
||||||
|
checkpoint = torch.load(name, map_location=map_location)
|
||||||
|
|
||||||
|
if "state_dict" in checkpoint:
|
||||||
|
checkpoint = checkpoint["state_dict"]
|
||||||
|
|
||||||
return checkpoint
|
return checkpoint
|
||||||
|
|
Loading…
Reference in New Issue