1
0
Fork 0

type and test fixes

This commit is contained in:
Sean Sube 2023-12-26 20:21:34 -06:00
parent 6d2d5058d9
commit cfe7a55935
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
3 changed files with 23 additions and 3 deletions

View File

@ -25,7 +25,9 @@ class TileCallback(Protocol):
Definition for a tile job function. Definition for a tile job function.
""" """
def __call__(self, sources: List[Image.Image], mask: Image.Image, dims: Tuple[int, int, int]) -> StageResult: def __call__(
self, sources: List[Image.Image], mask: Image.Image, dims: Tuple[int, int, int]
) -> StageResult:
""" """
Run this stage against a single tile. Run this stage against a single tile.
""" """

View File

@ -200,7 +200,9 @@ def expand_prompt(
)[0] )[0]
if negative_prompt_embeds is not None: if negative_prompt_embeds is not None:
negative_padding = tokens.input_ids.shape[1] - negative_prompt_embeds.shape[1] negative_padding = (
tokens.input_ids.shape[1] - negative_prompt_embeds.shape[1]
)
logger.trace( logger.trace(
"padding negative prompt to match input: %s, %s, %s extra tokens", "padding negative prompt to match input: %s, %s, %s extra tokens",
tokens.input_ids.shape, tokens.input_ids.shape,

View File

@ -1,8 +1,9 @@
import unittest import unittest
from unittest.mock import MagicMock
from os import path from os import path
import torch import torch
from diffusers import DDIMScheduler from diffusers import DDIMScheduler, OnnxRuntimeModel
from onnx_web.diffusers.load import ( from onnx_web.diffusers.load import (
get_available_pipelines, get_available_pipelines,
@ -120,8 +121,13 @@ class TestPatchPipeline(unittest.TestCase):
pass pass
def test_unet_wrapper_not_xl(self): def test_unet_wrapper_not_xl(self):
session = MagicMock()
session.get_inputs.return_value = []
server = ServerContext() server = ServerContext()
pipeline = MockPipeline() pipeline = MockPipeline()
pipeline.unet = OnnxRuntimeModel(model=session)
patch_pipeline( patch_pipeline(
server, server,
pipeline, pipeline,
@ -131,8 +137,13 @@ class TestPatchPipeline(unittest.TestCase):
self.assertTrue(isinstance(pipeline.unet, UNetWrapper)) self.assertTrue(isinstance(pipeline.unet, UNetWrapper))
def test_unet_wrapper_xl(self): def test_unet_wrapper_xl(self):
session = MagicMock()
session.get_inputs.return_value = []
server = ServerContext() server = ServerContext()
pipeline = MockPipeline() pipeline = MockPipeline()
pipeline.unet = OnnxRuntimeModel(model=session)
patch_pipeline( patch_pipeline(
server, server,
pipeline, pipeline,
@ -142,8 +153,13 @@ class TestPatchPipeline(unittest.TestCase):
self.assertTrue(isinstance(pipeline.unet, UNetWrapper)) self.assertTrue(isinstance(pipeline.unet, UNetWrapper))
def test_vae_wrapper(self): def test_vae_wrapper(self):
session = MagicMock()
session.get_inputs.return_value = []
server = ServerContext() server = ServerContext()
pipeline = MockPipeline() pipeline = MockPipeline()
pipeline.unet = OnnxRuntimeModel(model=session)
patch_pipeline( patch_pipeline(
server, server,
pipeline, pipeline,