type and test fixes
This commit is contained in:
parent
6d2d5058d9
commit
cfe7a55935
|
@ -25,7 +25,9 @@ class TileCallback(Protocol):
|
|||
Definition for a tile job function.
|
||||
"""
|
||||
|
||||
def __call__(self, sources: List[Image.Image], mask: Image.Image, dims: Tuple[int, int, int]) -> StageResult:
|
||||
def __call__(
|
||||
self, sources: List[Image.Image], mask: Image.Image, dims: Tuple[int, int, int]
|
||||
) -> StageResult:
|
||||
"""
|
||||
Run this stage against a single tile.
|
||||
"""
|
||||
|
|
|
@ -200,7 +200,9 @@ def expand_prompt(
|
|||
)[0]
|
||||
|
||||
if negative_prompt_embeds is not None:
|
||||
negative_padding = tokens.input_ids.shape[1] - negative_prompt_embeds.shape[1]
|
||||
negative_padding = (
|
||||
tokens.input_ids.shape[1] - negative_prompt_embeds.shape[1]
|
||||
)
|
||||
logger.trace(
|
||||
"padding negative prompt to match input: %s, %s, %s extra tokens",
|
||||
tokens.input_ids.shape,
|
||||
|
|
|
@ -1,8 +1,9 @@
|
|||
import unittest
|
||||
from unittest.mock import MagicMock
|
||||
from os import path
|
||||
|
||||
import torch
|
||||
from diffusers import DDIMScheduler
|
||||
from diffusers import DDIMScheduler, OnnxRuntimeModel
|
||||
|
||||
from onnx_web.diffusers.load import (
|
||||
get_available_pipelines,
|
||||
|
@ -120,8 +121,13 @@ class TestPatchPipeline(unittest.TestCase):
|
|||
pass
|
||||
|
||||
def test_unet_wrapper_not_xl(self):
|
||||
session = MagicMock()
|
||||
session.get_inputs.return_value = []
|
||||
|
||||
server = ServerContext()
|
||||
pipeline = MockPipeline()
|
||||
pipeline.unet = OnnxRuntimeModel(model=session)
|
||||
|
||||
patch_pipeline(
|
||||
server,
|
||||
pipeline,
|
||||
|
@ -131,8 +137,13 @@ class TestPatchPipeline(unittest.TestCase):
|
|||
self.assertTrue(isinstance(pipeline.unet, UNetWrapper))
|
||||
|
||||
def test_unet_wrapper_xl(self):
|
||||
session = MagicMock()
|
||||
session.get_inputs.return_value = []
|
||||
|
||||
server = ServerContext()
|
||||
pipeline = MockPipeline()
|
||||
pipeline.unet = OnnxRuntimeModel(model=session)
|
||||
|
||||
patch_pipeline(
|
||||
server,
|
||||
pipeline,
|
||||
|
@ -142,8 +153,13 @@ class TestPatchPipeline(unittest.TestCase):
|
|||
self.assertTrue(isinstance(pipeline.unet, UNetWrapper))
|
||||
|
||||
def test_vae_wrapper(self):
|
||||
session = MagicMock()
|
||||
session.get_inputs.return_value = []
|
||||
|
||||
server = ServerContext()
|
||||
pipeline = MockPipeline()
|
||||
pipeline.unet = OnnxRuntimeModel(model=session)
|
||||
|
||||
patch_pipeline(
|
||||
server,
|
||||
pipeline,
|
||||
|
|
Loading…
Reference in New Issue