start porting SD upscale to ONNX

Sean Sube 2023-01-29 13:49:30 -06:00
parent f0c905721f
commit 8d346cbed0
1 changed file with 26 additions and 8 deletions


@@ -3,17 +3,17 @@ from diffusers import (
    OnnxRuntimeModel,
    StableDiffusionUpscalePipeline,
)
from logging import getLogger
from typing import (
    Any,
    Callable,
    Union,
    List,
    Optional,
)

import PIL
import numpy as np
import torch

logger = getLogger(__name__)

class OnnxStableDiffusionUpscalePipeline(StableDiffusionUpscalePipeline):
    def __init__(
@@ -56,7 +56,17 @@ class OnnxStableDiffusionUpscalePipeline(StableDiffusionUpscalePipeline):
                f" {self.tokenizer.model_max_length} tokens: {removed_text}"
            )

-        text_embeddings = self.text_encoder(input_ids=text_input_ids.int().to(device))
+        if hasattr(text_inputs, 'attention_mask') and text_inputs.attention_mask is not None:
+            attention_mask = text_inputs.attention_mask.to(device)
+        else:
+            attention_mask = None
+
+        # TODO: TypeError: __call__() takes 1 positional argument but 2 were given
+        # no positional arguments to text_encoder
+        text_embeddings = self.text_encoder(
+            input_ids=text_input_ids.int().to(device),
+            # attention_mask=attention_mask,
+        )
        text_embeddings = text_embeddings[0]

        # duplicate text embeddings for each generation per prompt, using mps friendly method
@@ -94,7 +104,15 @@ class OnnxStableDiffusionUpscalePipeline(StableDiffusionUpscalePipeline):
                return_tensors="pt",
            )

-            uncond_embeddings = self.text_encoder(input_ids=uncond_input.input_ids.int().to(device))
+            if hasattr(uncond_input, 'attention_mask') and uncond_input.attention_mask is not None:
+                attention_mask = uncond_input.attention_mask.to(device)
+            else:
+                attention_mask = None
+
+            uncond_embeddings = self.text_encoder(
+                input_ids=uncond_input.input_ids.int().to(device),
+                # attention_mask=attention_mask,
+            )
            uncond_embeddings = uncond_embeddings[0]

            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
@@ -105,6 +123,6 @@ class OnnxStableDiffusionUpscalePipeline(StableDiffusionUpscalePipeline):
            # For classifier free guidance, we need to do two forward passes.
            # Here we concatenate the unconditional and text embeddings into a single batch
            # to avoid doing two forward passes
-            text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
+            text_embeddings = np.concatenate([uncond_embeddings, text_embeddings])

        return text_embeddings
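
A note on the two changes above, with a minimal sketch of the call pattern this port moves toward. In diffusers, OnnxRuntimeModel wraps an ONNX Runtime inference session: its __call__ accepts keyword arguments only (which is the TODO about "__call__() takes 1 positional argument but 2 were given" when the text encoder is called positionally), and it returns numpy arrays rather than torch tensors, which is why torch.cat is replaced with np.concatenate. The sketch below is not part of the commit; the model paths and prompt are hypothetical, and it assumes a converted ONNX text encoder and its CLIP tokenizer already exist on disk.

# Hypothetical example (not from the commit): encoding a prompt with an
# ONNX text encoder, mirroring the keyword-only call and numpy output above.
import numpy as np
from diffusers import OnnxRuntimeModel
from transformers import CLIPTokenizer

# assumed local paths to a converted upscale model; adjust to your own layout
text_encoder = OnnxRuntimeModel.from_pretrained("./upscaling-onnx/text_encoder")
tokenizer = CLIPTokenizer.from_pretrained("./upscaling-onnx/tokenizer")

text_inputs = tokenizer(
    "a photo of an astronaut",
    padding="max_length",
    max_length=tokenizer.model_max_length,
    return_tensors="np",
)

# keyword arguments only; the session returns a list of numpy arrays
text_embeddings = text_encoder(input_ids=text_inputs.input_ids.astype(np.int32))[0]

# conditional and unconditional embeddings would then be joined with numpy, not torch
print(text_embeddings.shape)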