1
0
Fork 0

type and test fixes

This commit is contained in:
Sean Sube 2023-12-26 20:21:34 -06:00
parent 6d2d5058d9
commit cfe7a55935
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
3 changed files with 23 additions and 3 deletions

View File

@ -25,7 +25,9 @@ class TileCallback(Protocol):
Definition for a tile job function.
"""
def __call__(self, sources: List[Image.Image], mask: Image.Image, dims: Tuple[int, int, int]) -> StageResult:
def __call__(
self, sources: List[Image.Image], mask: Image.Image, dims: Tuple[int, int, int]
) -> StageResult:
"""
Run this stage against a single tile.
"""

View File

@ -200,7 +200,9 @@ def expand_prompt(
)[0]
if negative_prompt_embeds is not None:
negative_padding = tokens.input_ids.shape[1] - negative_prompt_embeds.shape[1]
negative_padding = (
tokens.input_ids.shape[1] - negative_prompt_embeds.shape[1]
)
logger.trace(
"padding negative prompt to match input: %s, %s, %s extra tokens",
tokens.input_ids.shape,

View File

@ -1,8 +1,9 @@
import unittest
from unittest.mock import MagicMock
from os import path
import torch
from diffusers import DDIMScheduler
from diffusers import DDIMScheduler, OnnxRuntimeModel
from onnx_web.diffusers.load import (
get_available_pipelines,
@ -120,8 +121,13 @@ class TestPatchPipeline(unittest.TestCase):
pass
def test_unet_wrapper_not_xl(self):
session = MagicMock()
session.get_inputs.return_value = []
server = ServerContext()
pipeline = MockPipeline()
pipeline.unet = OnnxRuntimeModel(model=session)
patch_pipeline(
server,
pipeline,
@ -131,8 +137,13 @@ class TestPatchPipeline(unittest.TestCase):
self.assertTrue(isinstance(pipeline.unet, UNetWrapper))
def test_unet_wrapper_xl(self):
session = MagicMock()
session.get_inputs.return_value = []
server = ServerContext()
pipeline = MockPipeline()
pipeline.unet = OnnxRuntimeModel(model=session)
patch_pipeline(
server,
pipeline,
@ -142,8 +153,13 @@ class TestPatchPipeline(unittest.TestCase):
self.assertTrue(isinstance(pipeline.unet, UNetWrapper))
def test_vae_wrapper(self):
session = MagicMock()
session.get_inputs.return_value = []
server = ServerContext()
pipeline = MockPipeline()
pipeline.unet = OnnxRuntimeModel(model=session)
patch_pipeline(
server,
pipeline,