feat(api): add support for custom tokens for textual inversions (#179)
This commit is contained in:
parent
d0b80451ad
commit
39d36618e6
|
@ -234,7 +234,7 @@ def convert_models(ctx: ConversionContext, args, models: Models):
|
|||
for inversion in model.get("inversions", []):
|
||||
inversion_name = inversion["name"]
|
||||
inversion_source = inversion["source"]
|
||||
inversion_format = inversion.get("format", "huggingface")
|
||||
inversion_format = inversion.get("format", "embeddings")
|
||||
inversion_source = fetch_model(
|
||||
ctx, f"{name}-inversion-{inversion_name}", inversion_source
|
||||
)
|
||||
|
@ -244,6 +244,7 @@ def convert_models(ctx: ConversionContext, args, models: Models):
|
|||
model["source"],
|
||||
inversion_source,
|
||||
inversion_format,
|
||||
base_token=inversion.get("token"),
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
from logging import getLogger
|
||||
from os import makedirs, path
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from huggingface_hub.file_download import hf_hub_download
|
||||
|
@ -13,7 +14,7 @@ logger = getLogger(__name__)
|
|||
|
||||
@torch.no_grad()
|
||||
def convert_diffusion_textual_inversion(
|
||||
context: ConversionContext, name: str, base_model: str, inversion: str, format: str
|
||||
context: ConversionContext, name: str, base_model: str, inversion: str, format: str, base_token: Optional[str] = None,
|
||||
):
|
||||
dest_path = path.join(context.model_path, f"inversion-{name}")
|
||||
logger.info(
|
||||
|
@ -31,7 +32,7 @@ def convert_diffusion_textual_inversion(
|
|||
token_file = hf_hub_download(repo_id=inversion, filename="token_identifier.txt")
|
||||
|
||||
with open(token_file, "r") as f:
|
||||
token = f.read()
|
||||
token = base_token or f.read()
|
||||
|
||||
loaded_embeds = torch.load(embeds_file, map_location=context.map_location)
|
||||
|
||||
|
@ -50,7 +51,7 @@ def convert_diffusion_textual_inversion(
|
|||
|
||||
num_tokens = embeds.shape[0]
|
||||
logger.info("generating %s layer tokens", num_tokens)
|
||||
token = [f"{name}-{i}" for i in range(num_tokens)]
|
||||
token = [f"{base_token or name}-{i}" for i in range(num_tokens)]
|
||||
else:
|
||||
raise ValueError(f"unknown textual inversion format: {format}")
|
||||
|
||||
|
|
|
@ -18,6 +18,11 @@ $defs:
|
|||
type: string
|
||||
source:
|
||||
type: string
|
||||
format:
|
||||
type: string
|
||||
enum: [concept, embeddings]
|
||||
token:
|
||||
type: string
|
||||
|
||||
base_model:
|
||||
type: object
|
||||
|
|
Loading…
Reference in New Issue