type and test fixes
This commit is contained in:
parent
6d2d5058d9
commit
cfe7a55935
|
@ -25,7 +25,9 @@ class TileCallback(Protocol):
|
||||||
Definition for a tile job function.
|
Definition for a tile job function.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __call__(self, sources: List[Image.Image], mask: Image.Image, dims: Tuple[int, int, int]) -> StageResult:
|
def __call__(
|
||||||
|
self, sources: List[Image.Image], mask: Image.Image, dims: Tuple[int, int, int]
|
||||||
|
) -> StageResult:
|
||||||
"""
|
"""
|
||||||
Run this stage against a single tile.
|
Run this stage against a single tile.
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -200,7 +200,9 @@ def expand_prompt(
|
||||||
)[0]
|
)[0]
|
||||||
|
|
||||||
if negative_prompt_embeds is not None:
|
if negative_prompt_embeds is not None:
|
||||||
negative_padding = tokens.input_ids.shape[1] - negative_prompt_embeds.shape[1]
|
negative_padding = (
|
||||||
|
tokens.input_ids.shape[1] - negative_prompt_embeds.shape[1]
|
||||||
|
)
|
||||||
logger.trace(
|
logger.trace(
|
||||||
"padding negative prompt to match input: %s, %s, %s extra tokens",
|
"padding negative prompt to match input: %s, %s, %s extra tokens",
|
||||||
tokens.input_ids.shape,
|
tokens.input_ids.shape,
|
||||||
|
|
|
@ -1,8 +1,9 @@
|
||||||
import unittest
|
import unittest
|
||||||
|
from unittest.mock import MagicMock
|
||||||
from os import path
|
from os import path
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from diffusers import DDIMScheduler
|
from diffusers import DDIMScheduler, OnnxRuntimeModel
|
||||||
|
|
||||||
from onnx_web.diffusers.load import (
|
from onnx_web.diffusers.load import (
|
||||||
get_available_pipelines,
|
get_available_pipelines,
|
||||||
|
@ -120,8 +121,13 @@ class TestPatchPipeline(unittest.TestCase):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def test_unet_wrapper_not_xl(self):
|
def test_unet_wrapper_not_xl(self):
|
||||||
|
session = MagicMock()
|
||||||
|
session.get_inputs.return_value = []
|
||||||
|
|
||||||
server = ServerContext()
|
server = ServerContext()
|
||||||
pipeline = MockPipeline()
|
pipeline = MockPipeline()
|
||||||
|
pipeline.unet = OnnxRuntimeModel(model=session)
|
||||||
|
|
||||||
patch_pipeline(
|
patch_pipeline(
|
||||||
server,
|
server,
|
||||||
pipeline,
|
pipeline,
|
||||||
|
@ -131,8 +137,13 @@ class TestPatchPipeline(unittest.TestCase):
|
||||||
self.assertTrue(isinstance(pipeline.unet, UNetWrapper))
|
self.assertTrue(isinstance(pipeline.unet, UNetWrapper))
|
||||||
|
|
||||||
def test_unet_wrapper_xl(self):
|
def test_unet_wrapper_xl(self):
|
||||||
|
session = MagicMock()
|
||||||
|
session.get_inputs.return_value = []
|
||||||
|
|
||||||
server = ServerContext()
|
server = ServerContext()
|
||||||
pipeline = MockPipeline()
|
pipeline = MockPipeline()
|
||||||
|
pipeline.unet = OnnxRuntimeModel(model=session)
|
||||||
|
|
||||||
patch_pipeline(
|
patch_pipeline(
|
||||||
server,
|
server,
|
||||||
pipeline,
|
pipeline,
|
||||||
|
@ -142,8 +153,13 @@ class TestPatchPipeline(unittest.TestCase):
|
||||||
self.assertTrue(isinstance(pipeline.unet, UNetWrapper))
|
self.assertTrue(isinstance(pipeline.unet, UNetWrapper))
|
||||||
|
|
||||||
def test_vae_wrapper(self):
|
def test_vae_wrapper(self):
|
||||||
|
session = MagicMock()
|
||||||
|
session.get_inputs.return_value = []
|
||||||
|
|
||||||
server = ServerContext()
|
server = ServerContext()
|
||||||
pipeline = MockPipeline()
|
pipeline = MockPipeline()
|
||||||
|
pipeline.unet = OnnxRuntimeModel(model=session)
|
||||||
|
|
||||||
patch_pipeline(
|
patch_pipeline(
|
||||||
server,
|
server,
|
||||||
pipeline,
|
pipeline,
|
||||||
|
|
Loading…
Reference in New Issue