look up input types for SD1.5 as well
This commit is contained in:
parent
6dbed8c114
commit
d084f53a7e
|
@ -7,6 +7,16 @@ from compel import Compel, ReturnedEmbeddingsType
|
|||
from diffusers import OnnxStableDiffusionPipeline
|
||||
|
||||
|
||||
def get_inference_session(model):
    """Return the ONNX inference session held by *model*.

    Different pipeline wrappers expose the session under different
    attribute names, so try each known name in priority order.

    Raises:
        ValueError: if *model* exposes neither attribute.
    """
    # "session" takes precedence over "model" when both exist.
    for attr_name in ("session", "model"):
        if hasattr(model, attr_name):
            return getattr(model, attr_name)

    raise ValueError("Model does not have an inference session")
|
||||
|
||||
|
||||
def wrap_encoder(text_encoder, sdxl=False):
|
||||
class WrappedEncoder:
|
||||
device = "cpu"
|
||||
|
@ -24,7 +34,8 @@ def wrap_encoder(text_encoder, sdxl=False):
|
|||
If `output_hidden_states` is None, return pooled embeds.
|
||||
"""
|
||||
dtype = np.int32
|
||||
if text_encoder.session.get_inputs()[0].type == "tensor(int64)":
|
||||
session = get_inference_session(self.text_encoder)
|
||||
if session.get_inputs()[0].type == "tensor(int64)":
|
||||
dtype = np.int64
|
||||
|
||||
# TODO: does compel use attention masks?
|
||||
|
|
Loading…
Reference in New Issue