feat(api): add support for custom tokens for textual inversions (#179)
This commit is contained in:
parent
d0b80451ad
commit
39d36618e6
|
@ -234,7 +234,7 @@ def convert_models(ctx: ConversionContext, args, models: Models):
|
||||||
for inversion in model.get("inversions", []):
|
for inversion in model.get("inversions", []):
|
||||||
inversion_name = inversion["name"]
|
inversion_name = inversion["name"]
|
||||||
inversion_source = inversion["source"]
|
inversion_source = inversion["source"]
|
||||||
inversion_format = inversion.get("format", "huggingface")
|
inversion_format = inversion.get("format", "embeddings")
|
||||||
inversion_source = fetch_model(
|
inversion_source = fetch_model(
|
||||||
ctx, f"{name}-inversion-{inversion_name}", inversion_source
|
ctx, f"{name}-inversion-{inversion_name}", inversion_source
|
||||||
)
|
)
|
||||||
|
@ -244,6 +244,7 @@ def convert_models(ctx: ConversionContext, args, models: Models):
|
||||||
model["source"],
|
model["source"],
|
||||||
inversion_source,
|
inversion_source,
|
||||||
inversion_format,
|
inversion_format,
|
||||||
|
base_token=inversion.get("token"),
|
||||||
)
|
)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
from logging import getLogger
|
from logging import getLogger
|
||||||
from os import makedirs, path
|
from os import makedirs, path
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from huggingface_hub.file_download import hf_hub_download
|
from huggingface_hub.file_download import hf_hub_download
|
||||||
|
@ -13,7 +14,7 @@ logger = getLogger(__name__)
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def convert_diffusion_textual_inversion(
|
def convert_diffusion_textual_inversion(
|
||||||
context: ConversionContext, name: str, base_model: str, inversion: str, format: str
|
context: ConversionContext, name: str, base_model: str, inversion: str, format: str, base_token: Optional[str] = None,
|
||||||
):
|
):
|
||||||
dest_path = path.join(context.model_path, f"inversion-{name}")
|
dest_path = path.join(context.model_path, f"inversion-{name}")
|
||||||
logger.info(
|
logger.info(
|
||||||
|
@ -31,7 +32,7 @@ def convert_diffusion_textual_inversion(
|
||||||
token_file = hf_hub_download(repo_id=inversion, filename="token_identifier.txt")
|
token_file = hf_hub_download(repo_id=inversion, filename="token_identifier.txt")
|
||||||
|
|
||||||
with open(token_file, "r") as f:
|
with open(token_file, "r") as f:
|
||||||
token = f.read()
|
token = base_token or f.read()
|
||||||
|
|
||||||
loaded_embeds = torch.load(embeds_file, map_location=context.map_location)
|
loaded_embeds = torch.load(embeds_file, map_location=context.map_location)
|
||||||
|
|
||||||
|
@ -50,7 +51,7 @@ def convert_diffusion_textual_inversion(
|
||||||
|
|
||||||
num_tokens = embeds.shape[0]
|
num_tokens = embeds.shape[0]
|
||||||
logger.info("generating %s layer tokens", num_tokens)
|
logger.info("generating %s layer tokens", num_tokens)
|
||||||
token = [f"{name}-{i}" for i in range(num_tokens)]
|
token = [f"{base_token or name}-{i}" for i in range(num_tokens)]
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"unknown textual inversion format: {format}")
|
raise ValueError(f"unknown textual inversion format: {format}")
|
||||||
|
|
||||||
|
|
|
@ -18,6 +18,11 @@ $defs:
|
||||||
type: string
|
type: string
|
||||||
source:
|
source:
|
||||||
type: string
|
type: string
|
||||||
|
format:
|
||||||
|
type: string
|
||||||
|
enum: [concept, embeddings]
|
||||||
|
token:
|
||||||
|
type: string
|
||||||
|
|
||||||
base_model:
|
base_model:
|
||||||
type: object
|
type: object
|
||||||
|
|
Loading…
Reference in New Issue