continue converting upscale to ONNX
This commit is contained in:
parent
c905fbb728
commit
8c6d957a53
|
@ -6,6 +6,7 @@ from diffusers import (
|
|||
from logging import getLogger
|
||||
from typing import (
|
||||
Any,
|
||||
List,
|
||||
)
|
||||
|
||||
import numpy as np
|
||||
|
@ -56,12 +57,6 @@ class OnnxStableDiffusionUpscalePipeline(StableDiffusionUpscalePipeline):
|
|||
f" {self.tokenizer.model_max_length} tokens: {removed_text}"
|
||||
)
|
||||
|
||||
if hasattr(text_inputs, 'attention_mask') and text_inputs.attention_mask is not None:
|
||||
attention_mask = text_inputs.attention_mask.to(device)
|
||||
else:
|
||||
attention_mask = None
|
||||
|
||||
# TODO: TypeError: __call__() takes 1 positional argument but 2 were given
|
||||
# no positional arguments to text_encoder
|
||||
text_embeddings = self.text_encoder(
|
||||
input_ids=text_input_ids.int().to(device),
|
||||
|
@ -70,7 +65,6 @@ class OnnxStableDiffusionUpscalePipeline(StableDiffusionUpscalePipeline):
|
|||
text_embeddings = text_embeddings[0]
|
||||
|
||||
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
||||
bs_embed, seq_len, _ = text_embeddings.shape
|
||||
text_embeddings = text_embeddings.repeat(1, num_images_per_prompt) #, 1)
|
||||
# text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
|
||||
|
||||
|
@ -104,11 +98,6 @@ class OnnxStableDiffusionUpscalePipeline(StableDiffusionUpscalePipeline):
|
|||
return_tensors="pt",
|
||||
)
|
||||
|
||||
if hasattr(uncond_input, 'attention_mask') and uncond_input.attention_mask is not None:
|
||||
attention_mask = uncond_input.attention_mask.to(device)
|
||||
else:
|
||||
attention_mask = None
|
||||
|
||||
uncond_embeddings = self.text_encoder(
|
||||
input_ids=uncond_input.input_ids.int().to(device),
|
||||
# attention_mask=attention_mask,
|
||||
|
@ -116,7 +105,6 @@ class OnnxStableDiffusionUpscalePipeline(StableDiffusionUpscalePipeline):
|
|||
uncond_embeddings = uncond_embeddings[0]
|
||||
|
||||
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
|
||||
seq_len = uncond_embeddings.shape[1]
|
||||
uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt) #, 1)
|
||||
# uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1)
|
||||
|
||||
|
|
Loading…
Reference in New Issue