
continue converting upscale to ONNX

Sean Sube 2023-01-29 15:23:19 -06:00
parent c905fbb728
commit 8c6d957a53
1 changed file with 1 addition and 13 deletions


@@ -6,6 +6,7 @@ from diffusers import (
 from logging import getLogger
 from typing import (
     Any,
+    List,
 )
 import numpy as np
@@ -56,12 +57,6 @@ class OnnxStableDiffusionUpscalePipeline(StableDiffusionUpscalePipeline):
                 f" {self.tokenizer.model_max_length} tokens: {removed_text}"
             )
-        if hasattr(text_inputs, 'attention_mask') and text_inputs.attention_mask is not None:
-            attention_mask = text_inputs.attention_mask.to(device)
-        else:
-            attention_mask = None
-        # TODO: TypeError: __call__() takes 1 positional argument but 2 were given
-        # no positional arguments to text_encoder
         text_embeddings = self.text_encoder(
             input_ids=text_input_ids.int().to(device),
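
The TODO removed in this hunk records why the encoder call above uses a keyword argument. Assuming the text encoder here is a diffusers OnnxRuntimeModel (the usual wrapper in ONNX pipelines; an assumption, not stated in the diff), its __call__ only accepts keyword arguments, which it maps onto the graph's input names, so a positional call fails. A minimal sketch:

    # Sketch assuming an OnnxRuntimeModel text encoder; the model path and
    # token shape are illustrative, not taken from this commit.
    import numpy as np
    from diffusers import OnnxRuntimeModel

    text_encoder = OnnxRuntimeModel.from_pretrained("./model/text_encoder")  # hypothetical path
    token_ids = np.zeros((1, 77), dtype=np.int32)  # placeholder CLIP token ids

    # text_encoder(token_ids)  # TypeError: __call__() takes 1 positional argument but 2 were given
    text_embeddings = text_encoder(input_ids=token_ids)[0]  # keyword form matches the graph input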
@@ -70,7 +65,6 @@ class OnnxStableDiffusionUpscalePipeline(StableDiffusionUpscalePipeline):
         text_embeddings = text_embeddings[0]
         # duplicate text embeddings for each generation per prompt, using mps friendly method
         bs_embed, seq_len, _ = text_embeddings.shape
         text_embeddings = text_embeddings.repeat(1, num_images_per_prompt) #, 1)
         # text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
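
A note on the duplication being rewritten here: torch's repeat(1, num_images_per_prompt, 1) followed by view(bs_embed * num_images_per_prompt, seq_len, -1) has a one-step numpy equivalent once the ONNX encoder returns arrays instead of tensors. A sketch under that assumption, with illustrative shapes:

    import numpy as np

    text_embeddings = np.ones((1, 77, 1024), dtype=np.float32)  # (batch, seq_len, dim), illustrative
    num_images_per_prompt = 2

    # same element order as torch's repeat(1, n, 1) + view(batch * n, seq_len, -1)
    text_embeddings = np.repeat(text_embeddings, num_images_per_prompt, axis=0)
    print(text_embeddings.shape)  # (2, 77, 1024)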
@@ -104,11 +98,6 @@ class OnnxStableDiffusionUpscalePipeline(StableDiffusionUpscalePipeline):
                 return_tensors="pt",
             )
-            if hasattr(uncond_input, 'attention_mask') and uncond_input.attention_mask is not None:
-                attention_mask = uncond_input.attention_mask.to(device)
-            else:
-                attention_mask = None
             uncond_embeddings = self.text_encoder(
                 input_ids=uncond_input.input_ids.int().to(device),
                 # attention_mask=attention_mask,
@@ -116,7 +105,6 @@ class OnnxStableDiffusionUpscalePipeline(StableDiffusionUpscalePipeline):
             uncond_embeddings = uncond_embeddings[0]
             # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
             seq_len = uncond_embeddings.shape[1]
             uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt) #, 1)
             # uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1)
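
The unconditional embeddings duplicated in this final hunk feed classifier-free guidance: downstream (outside this diff), the pipeline typically concatenates them ahead of the prompt embeddings so the UNet evaluates both in one batch. A sketch of that convention:

    import numpy as np

    uncond_embeddings = np.zeros((2, 77, 1024), dtype=np.float32)  # illustrative shapes
    text_embeddings = np.ones((2, 77, 1024), dtype=np.float32)

    # unconditional first, conditional second, per the diffusers convention
    text_embeddings = np.concatenate([uncond_embeddings, text_embeddings])
    print(text_embeddings.shape)  # (4, 77, 1024)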