diff --git a/api/onnx_web/convert/diffusion/textual_inversion.py b/api/onnx_web/convert/diffusion/textual_inversion.py index cebb5a18..809e580e 100644 --- a/api/onnx_web/convert/diffusion/textual_inversion.py +++ b/api/onnx_web/convert/diffusion/textual_inversion.py @@ -44,11 +44,13 @@ def convert_diffusion_textual_inversion( string_to_token = loaded_embeds["string_to_token"] string_to_param = loaded_embeds["string_to_param"] - token = name - # separate token and embeds trained_token = list(string_to_token.keys())[0] embeds = string_to_param[trained_token] + + num_tokens = embeds.shape[0] + logger.info("generating %s layer tokens", num_tokens) + token = [f"{name}-{i}" for i in range(num_tokens)] else: raise ValueError(f"unknown textual inversion format: {format}") @@ -74,12 +76,23 @@ def convert_diffusion_textual_inversion( f"The tokenizer already contains the token {token}. Please pass a different `token` that is not already in the tokenizer." ) + logger.info("added %s tokens", num_added_tokens) + # resize the token embeddings text_encoder.resize_token_embeddings(len(tokenizer)) - # get the id for the token and assign the embeds - token_id = tokenizer.convert_tokens_to_ids(token) - text_encoder.get_input_embeddings().weight.data[token_id] = embeds + if len(embeds.shape) == 2: + # multiple vectors in embeds + for i in range(embeds.shape[0]): + layer_embeds = embeds[i] + layer_token = token[i] + logger.info("embedding %s vector for layer %s", layer_embeds.shape, layer_token) + token_id = tokenizer.convert_tokens_to_ids(layer_token) + text_encoder.get_input_embeddings().weight.data[token_id] = layer_embeds + else: + # get the id for the token and assign the embeds + token_id = tokenizer.convert_tokens_to_ids(token) + text_encoder.get_input_embeddings().weight.data[token_id] = embeds # conversion stuff text_input = tokenizer(