1
0
Fork 0

look up input types for SD1.5 as well

This commit is contained in:
Sean Sube 2024-03-03 12:14:46 -06:00
parent 6dbed8c114
commit d084f53a7e
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
1 changed files with 12 additions and 1 deletions

View File

@ -7,6 +7,16 @@ from compel import Compel, ReturnedEmbeddingsType
from diffusers import OnnxStableDiffusionPipeline
def get_inference_session(model):
if hasattr(model, "session"):
return model.session
if hasattr(model, "model"):
return model.model
raise ValueError("Model does not have an inference session")
def wrap_encoder(text_encoder, sdxl=False):
class WrappedEncoder:
device = "cpu"
@ -24,7 +34,8 @@ def wrap_encoder(text_encoder, sdxl=False):
If `output_hidden_states` is None, return pooled embeds.
"""
dtype = np.int32
if text_encoder.session.get_inputs()[0].type == "tensor(int64)":
session = get_inference_session(self.text_encoder)
if session.get_inputs()[0].type == "tensor(int64)":
dtype = np.int64
# TODO: does compel use attention masks?