diff --git a/api/onnx_web/convert/diffusion/textual_inversion.py b/api/onnx_web/convert/diffusion/textual_inversion.py index 5aa4b715..a13f5203 100644 --- a/api/onnx_web/convert/diffusion/textual_inversion.py +++ b/api/onnx_web/convert/diffusion/textual_inversion.py @@ -65,14 +65,24 @@ def blend_textual_inversions( num_tokens = trained_embeds.shape[0] logger.debug("generating %s layer tokens for %s", num_tokens, name) + sum_layer = np.zeros(trained_embeds[0, :].shape) + for i in range(num_tokens): token = f"{base_token or name}-{i}" layer = trained_embeds[i, :].cpu().numpy().astype(dtype) layer *= weight + sum_layer += layer if token in embeds: embeds[token] += layer else: embeds[token] = layer + + # add sum layer to embeds + sum_token = f"{base_token or name}-all" + if sum_token in embeds: + embeds[sum_token] += sum_layer + else: + embeds[sum_token] = sum_layer else: raise ValueError(f"unknown Textual Inversion format: {format}")