From 8d346cbed0dce286c9f0d6e67907942512a84742 Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Sun, 29 Jan 2023 13:49:30 -0600 Subject: [PATCH] start porting SD upscale to ONNX --- .../pipeline_onnx_stable_diffusion_upscale.py | 34 ++++++++++++++----- 1 file changed, 26 insertions(+), 8 deletions(-) diff --git a/api/onnx_web/diffusion/pipeline_onnx_stable_diffusion_upscale.py b/api/onnx_web/diffusion/pipeline_onnx_stable_diffusion_upscale.py index 9e651152..e5d9592c 100644 --- a/api/onnx_web/diffusion/pipeline_onnx_stable_diffusion_upscale.py +++ b/api/onnx_web/diffusion/pipeline_onnx_stable_diffusion_upscale.py @@ -3,17 +3,17 @@ from diffusers import ( OnnxRuntimeModel, StableDiffusionUpscalePipeline, ) +from logging import getLogger from typing import ( Any, - Callable, - Union, - List, - Optional, ) -import PIL +import numpy as np import torch +logger = getLogger(__name__) + + class OnnxStableDiffusionUpscalePipeline(StableDiffusionUpscalePipeline): def __init__( @@ -56,7 +56,17 @@ class OnnxStableDiffusionUpscalePipeline(StableDiffusionUpscalePipeline): f" {self.tokenizer.model_max_length} tokens: {removed_text}" ) - text_embeddings = self.text_encoder(input_ids=text_input_ids.int().to(device)) + if hasattr(text_inputs, 'attention_mask') and text_inputs.attention_mask is not None: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + # TODO: TypeError: __call__() takes 1 positional argument but 2 were given + # no positional arguments to text_encoder + text_embeddings = self.text_encoder( + input_ids=text_input_ids.int().to(device), + # attention_mask=attention_mask, + ) text_embeddings = text_embeddings[0] # duplicate text embeddings for each generation per prompt, using mps friendly method @@ -94,7 +104,15 @@ class OnnxStableDiffusionUpscalePipeline(StableDiffusionUpscalePipeline): return_tensors="pt", ) - uncond_embeddings = self.text_encoder(input_ids=uncond_input.input_ids.int().to(device)) + if hasattr(uncond_input, 'attention_mask') and uncond_input.attention_mask is not None: + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + uncond_embeddings = self.text_encoder( + input_ids=uncond_input.input_ids.int().to(device), + # attention_mask=attention_mask, + ) uncond_embeddings = uncond_embeddings[0] # duplicate unconditional embeddings for each generation per prompt, using mps friendly method @@ -105,6 +123,6 @@ class OnnxStableDiffusionUpscalePipeline(StableDiffusionUpscalePipeline): # For classifier free guidance, we need to do two forward passes. # Here we concatenate the unconditional and text embeddings into a single batch # to avoid doing two forward passes - text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) + text_embeddings = np.concatenate([uncond_embeddings, text_embeddings]) return text_embeddings