continue converting upscale to ONNX
This commit is contained in:
parent
c905fbb728
commit
8c6d957a53
|
@ -6,6 +6,7 @@ from diffusers import (
|
||||||
from logging import getLogger
|
from logging import getLogger
|
||||||
from typing import (
|
from typing import (
|
||||||
Any,
|
Any,
|
||||||
|
List,
|
||||||
)
|
)
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
@ -56,12 +57,6 @@ class OnnxStableDiffusionUpscalePipeline(StableDiffusionUpscalePipeline):
|
||||||
f" {self.tokenizer.model_max_length} tokens: {removed_text}"
|
f" {self.tokenizer.model_max_length} tokens: {removed_text}"
|
||||||
)
|
)
|
||||||
|
|
||||||
if hasattr(text_inputs, 'attention_mask') and text_inputs.attention_mask is not None:
|
|
||||||
attention_mask = text_inputs.attention_mask.to(device)
|
|
||||||
else:
|
|
||||||
attention_mask = None
|
|
||||||
|
|
||||||
# TODO: TypeError: __call__() takes 1 positional argument but 2 were given
|
|
||||||
# no positional arguments to text_encoder
|
# no positional arguments to text_encoder
|
||||||
text_embeddings = self.text_encoder(
|
text_embeddings = self.text_encoder(
|
||||||
input_ids=text_input_ids.int().to(device),
|
input_ids=text_input_ids.int().to(device),
|
||||||
|
@ -70,7 +65,6 @@ class OnnxStableDiffusionUpscalePipeline(StableDiffusionUpscalePipeline):
|
||||||
text_embeddings = text_embeddings[0]
|
text_embeddings = text_embeddings[0]
|
||||||
|
|
||||||
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
||||||
bs_embed, seq_len, _ = text_embeddings.shape
|
|
||||||
text_embeddings = text_embeddings.repeat(1, num_images_per_prompt) #, 1)
|
text_embeddings = text_embeddings.repeat(1, num_images_per_prompt) #, 1)
|
||||||
# text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
|
# text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
|
||||||
|
|
||||||
|
@ -104,11 +98,6 @@ class OnnxStableDiffusionUpscalePipeline(StableDiffusionUpscalePipeline):
|
||||||
return_tensors="pt",
|
return_tensors="pt",
|
||||||
)
|
)
|
||||||
|
|
||||||
if hasattr(uncond_input, 'attention_mask') and uncond_input.attention_mask is not None:
|
|
||||||
attention_mask = uncond_input.attention_mask.to(device)
|
|
||||||
else:
|
|
||||||
attention_mask = None
|
|
||||||
|
|
||||||
uncond_embeddings = self.text_encoder(
|
uncond_embeddings = self.text_encoder(
|
||||||
input_ids=uncond_input.input_ids.int().to(device),
|
input_ids=uncond_input.input_ids.int().to(device),
|
||||||
# attention_mask=attention_mask,
|
# attention_mask=attention_mask,
|
||||||
|
@ -116,7 +105,6 @@ class OnnxStableDiffusionUpscalePipeline(StableDiffusionUpscalePipeline):
|
||||||
uncond_embeddings = uncond_embeddings[0]
|
uncond_embeddings = uncond_embeddings[0]
|
||||||
|
|
||||||
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
|
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
|
||||||
seq_len = uncond_embeddings.shape[1]
|
|
||||||
uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt) #, 1)
|
uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt) #, 1)
|
||||||
# uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1)
|
# uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1)
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue