From fe498b16f0a31033aa1a6d559564683e885e0163 Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Sun, 19 Mar 2023 19:01:22 -0500 Subject: [PATCH] fix(api): embed Inversion concepts using their name --- api/onnx_web/convert/diffusion/textual_inversion.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/api/onnx_web/convert/diffusion/textual_inversion.py b/api/onnx_web/convert/diffusion/textual_inversion.py index 51f0978c..da0179d4 100644 --- a/api/onnx_web/convert/diffusion/textual_inversion.py +++ b/api/onnx_web/convert/diffusion/textual_inversion.py @@ -44,10 +44,10 @@ def blend_textual_inversions( if inversion_format == "concept": # TODO: this should be done in fetch, maybe embeds_file = hf_hub_download(repo_id=name, filename="learned_embeds.bin") - token_file = hf_hub_download(repo_id=name, filename="token_identifier.txt") + token_file = hf_hub_download(repo_id=name, filename="token_identifier.txt") # not strictly needed with open(token_file, "r") as f: - token = base_token or f.read() + token = f.read() loaded_embeds = load_tensor(embeds_file, map_location=device) if loaded_embeds is None: @@ -59,7 +59,13 @@ def blend_textual_inversions( layer = loaded_embeds[trained_token].numpy().astype(dtype) layer *= weight - if trained_token in embeds: + + if base_token in embeds: + embeds[base_token] += layer + else: + embeds[base_token] = layer + + if token in embeds: embeds[token] += layer else: embeds[token] = layer