From cfe7a55935c4815e7470d4da68baaf3f6f2a5365 Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Tue, 26 Dec 2023 20:21:34 -0600 Subject: [PATCH] type and test fixes --- api/onnx_web/chain/tile.py | 4 +++- api/onnx_web/diffusers/utils.py | 4 +++- api/tests/test_diffusers/test_load.py | 18 +++++++++++++++++- 3 files changed, 23 insertions(+), 3 deletions(-) diff --git a/api/onnx_web/chain/tile.py b/api/onnx_web/chain/tile.py index a00a5b4d..e8e1baff 100644 --- a/api/onnx_web/chain/tile.py +++ b/api/onnx_web/chain/tile.py @@ -25,7 +25,9 @@ class TileCallback(Protocol): Definition for a tile job function. """ - def __call__(self, sources: List[Image.Image], mask: Image.Image, dims: Tuple[int, int, int]) -> StageResult: + def __call__( + self, sources: List[Image.Image], mask: Image.Image, dims: Tuple[int, int, int] + ) -> StageResult: """ Run this stage against a single tile. """ diff --git a/api/onnx_web/diffusers/utils.py b/api/onnx_web/diffusers/utils.py index d83d0f98..ecf01373 100644 --- a/api/onnx_web/diffusers/utils.py +++ b/api/onnx_web/diffusers/utils.py @@ -200,7 +200,9 @@ def expand_prompt( )[0] if negative_prompt_embeds is not None: - negative_padding = tokens.input_ids.shape[1] - negative_prompt_embeds.shape[1] + negative_padding = ( + tokens.input_ids.shape[1] - negative_prompt_embeds.shape[1] + ) logger.trace( "padding negative prompt to match input: %s, %s, %s extra tokens", tokens.input_ids.shape, diff --git a/api/tests/test_diffusers/test_load.py b/api/tests/test_diffusers/test_load.py index 014f7aa0..22ca1a8d 100644 --- a/api/tests/test_diffusers/test_load.py +++ b/api/tests/test_diffusers/test_load.py @@ -1,8 +1,9 @@ import unittest +from unittest.mock import MagicMock from os import path import torch -from diffusers import DDIMScheduler +from diffusers import DDIMScheduler, OnnxRuntimeModel from onnx_web.diffusers.load import ( get_available_pipelines, @@ -120,8 +121,13 @@ class TestPatchPipeline(unittest.TestCase): pass def test_unet_wrapper_not_xl(self): + session = MagicMock() + session.get_inputs.return_value = [] + server = ServerContext() pipeline = MockPipeline() + pipeline.unet = OnnxRuntimeModel(model=session) + patch_pipeline( server, pipeline, @@ -131,8 +137,13 @@ class TestPatchPipeline(unittest.TestCase): self.assertTrue(isinstance(pipeline.unet, UNetWrapper)) def test_unet_wrapper_xl(self): + session = MagicMock() + session.get_inputs.return_value = [] + server = ServerContext() pipeline = MockPipeline() + pipeline.unet = OnnxRuntimeModel(model=session) + patch_pipeline( server, pipeline, @@ -142,8 +153,13 @@ class TestPatchPipeline(unittest.TestCase): self.assertTrue(isinstance(pipeline.unet, UNetWrapper)) def test_vae_wrapper(self): + session = MagicMock() + session.get_inputs.return_value = [] + server = ServerContext() pipeline = MockPipeline() + pipeline.unet = OnnxRuntimeModel(model=session) + patch_pipeline( server, pipeline,