start porting SD upscale to ONNX
This commit is contained in:
parent
f0c905721f
commit
8d346cbed0
|
@ -3,17 +3,17 @@ from diffusers import (
|
|||
OnnxRuntimeModel,
|
||||
StableDiffusionUpscalePipeline,
|
||||
)
|
||||
from logging import getLogger
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
Union,
|
||||
List,
|
||||
Optional,
|
||||
)
|
||||
|
||||
import PIL
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
|
||||
|
||||
class OnnxStableDiffusionUpscalePipeline(StableDiffusionUpscalePipeline):
|
||||
def __init__(
|
||||
|
@ -56,7 +56,17 @@ class OnnxStableDiffusionUpscalePipeline(StableDiffusionUpscalePipeline):
|
|||
f" {self.tokenizer.model_max_length} tokens: {removed_text}"
|
||||
)
|
||||
|
||||
text_embeddings = self.text_encoder(input_ids=text_input_ids.int().to(device))
|
||||
if hasattr(text_inputs, 'attention_mask') and text_inputs.attention_mask is not None:
|
||||
attention_mask = text_inputs.attention_mask.to(device)
|
||||
else:
|
||||
attention_mask = None
|
||||
|
||||
# TODO: TypeError: __call__() takes 1 positional argument but 2 were given
|
||||
# no positional arguments to text_encoder
|
||||
text_embeddings = self.text_encoder(
|
||||
input_ids=text_input_ids.int().to(device),
|
||||
# attention_mask=attention_mask,
|
||||
)
|
||||
text_embeddings = text_embeddings[0]
|
||||
|
||||
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
||||
|
@ -94,7 +104,15 @@ class OnnxStableDiffusionUpscalePipeline(StableDiffusionUpscalePipeline):
|
|||
return_tensors="pt",
|
||||
)
|
||||
|
||||
uncond_embeddings = self.text_encoder(input_ids=uncond_input.input_ids.int().to(device))
|
||||
if hasattr(uncond_input, 'attention_mask') and uncond_input.attention_mask is not None:
|
||||
attention_mask = uncond_input.attention_mask.to(device)
|
||||
else:
|
||||
attention_mask = None
|
||||
|
||||
uncond_embeddings = self.text_encoder(
|
||||
input_ids=uncond_input.input_ids.int().to(device),
|
||||
# attention_mask=attention_mask,
|
||||
)
|
||||
uncond_embeddings = uncond_embeddings[0]
|
||||
|
||||
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
|
||||
|
@ -105,6 +123,6 @@ class OnnxStableDiffusionUpscalePipeline(StableDiffusionUpscalePipeline):
|
|||
# For classifier free guidance, we need to do two forward passes.
|
||||
# Here we concatenate the unconditional and text embeddings into a single batch
|
||||
# to avoid doing two forward passes
|
||||
text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
|
||||
text_embeddings = np.concatenate([uncond_embeddings, text_embeddings])
|
||||
|
||||
return text_embeddings
|
||||
|
|
Loading…
Reference in New Issue