1
0
Fork 0

feat(api): add support for custom tokens for textual inversions (#179)

This commit is contained in:
Sean Sube 2023-03-02 23:32:20 -06:00
parent d0b80451ad
commit 39d36618e6
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
3 changed files with 11 additions and 4 deletions

View File

@ -234,7 +234,7 @@ def convert_models(ctx: ConversionContext, args, models: Models):
for inversion in model.get("inversions", []): for inversion in model.get("inversions", []):
inversion_name = inversion["name"] inversion_name = inversion["name"]
inversion_source = inversion["source"] inversion_source = inversion["source"]
inversion_format = inversion.get("format", "huggingface") inversion_format = inversion.get("format", "embeddings")
inversion_source = fetch_model( inversion_source = fetch_model(
ctx, f"{name}-inversion-{inversion_name}", inversion_source ctx, f"{name}-inversion-{inversion_name}", inversion_source
) )
@ -244,6 +244,7 @@ def convert_models(ctx: ConversionContext, args, models: Models):
model["source"], model["source"],
inversion_source, inversion_source,
inversion_format, inversion_format,
base_token=inversion.get("token"),
) )
except Exception as e: except Exception as e:

View File

@ -1,5 +1,6 @@
from logging import getLogger from logging import getLogger
from os import makedirs, path from os import makedirs, path
from typing import Optional
import torch import torch
from huggingface_hub.file_download import hf_hub_download from huggingface_hub.file_download import hf_hub_download
@ -13,7 +14,7 @@ logger = getLogger(__name__)
@torch.no_grad() @torch.no_grad()
def convert_diffusion_textual_inversion( def convert_diffusion_textual_inversion(
context: ConversionContext, name: str, base_model: str, inversion: str, format: str context: ConversionContext, name: str, base_model: str, inversion: str, format: str, base_token: Optional[str] = None,
): ):
dest_path = path.join(context.model_path, f"inversion-{name}") dest_path = path.join(context.model_path, f"inversion-{name}")
logger.info( logger.info(
@ -31,7 +32,7 @@ def convert_diffusion_textual_inversion(
token_file = hf_hub_download(repo_id=inversion, filename="token_identifier.txt") token_file = hf_hub_download(repo_id=inversion, filename="token_identifier.txt")
with open(token_file, "r") as f: with open(token_file, "r") as f:
token = f.read() token = base_token or f.read()
loaded_embeds = torch.load(embeds_file, map_location=context.map_location) loaded_embeds = torch.load(embeds_file, map_location=context.map_location)
@ -50,7 +51,7 @@ def convert_diffusion_textual_inversion(
num_tokens = embeds.shape[0] num_tokens = embeds.shape[0]
logger.info("generating %s layer tokens", num_tokens) logger.info("generating %s layer tokens", num_tokens)
token = [f"{name}-{i}" for i in range(num_tokens)] token = [f"{base_token or name}-{i}" for i in range(num_tokens)]
else: else:
raise ValueError(f"unknown textual inversion format: {format}") raise ValueError(f"unknown textual inversion format: {format}")

View File

@ -18,6 +18,11 @@ $defs:
type: string type: string
source: source:
type: string type: string
format:
type: string
enum: [concept, embeddings]
token:
type: string
base_model: base_model:
type: object type: object