From 963794abaa4a131018f2435a1c7f09baf444a2b4 Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Fri, 15 Sep 2023 19:16:47 -0500 Subject: [PATCH] lint, tests --- api/onnx_web/chain/blend_grid.py | 10 +-- api/onnx_web/chain/stage.py | 14 ++-- api/onnx_web/server/params.py | 4 +- api/tests/convert/diffusion/test_lora.py | 21 ++++++ api/tests/mocks.py | 43 ++++++++++++ api/tests/test_diffusers/test_load.py | 85 +++++++++++++++++++++--- api/tests/test_diffusers/test_utils.py | 85 ++++++++++++++++++++++++ 7 files changed, 240 insertions(+), 22 deletions(-) create mode 100644 api/tests/mocks.py create mode 100644 api/tests/test_diffusers/test_utils.py diff --git a/api/onnx_web/chain/blend_grid.py b/api/onnx_web/chain/blend_grid.py index cf4b9f90..19b1eca2 100644 --- a/api/onnx_web/chain/blend_grid.py +++ b/api/onnx_web/chain/blend_grid.py @@ -22,12 +22,12 @@ class BlendGridStage(BaseStage): *, height: int, width: int, - rows: Optional[List[str]] = None, - columns: Optional[List[str]] = None, - title: Optional[str] = None, + # rows: Optional[List[str]] = None, + # columns: Optional[List[str]] = None, + # title: Optional[str] = None, order: Optional[int] = None, stage_source: Optional[Image.Image] = None, - _callback: Optional[ProgressCallback] = None, + callback: Optional[ProgressCallback] = None, **kwargs, ) -> List[Image.Image]: logger.info("combining source images using grid layout") @@ -51,7 +51,7 @@ class BlendGridStage(BaseStage): def outputs( self, - params: ImageParams, + _params: ImageParams, sources: int, ) -> int: return sources + 1 diff --git a/api/onnx_web/chain/stage.py b/api/onnx_web/chain/stage.py index 3942460b..c9c6eafd 100644 --- a/api/onnx_web/chain/stage.py +++ b/api/onnx_web/chain/stage.py @@ -12,11 +12,11 @@ class BaseStage: def run( self, - worker: WorkerContext, - server: ServerContext, - stage: StageParams, + _worker: WorkerContext, + _server: ServerContext, + _stage: StageParams, _params: ImageParams, - sources: List[Image.Image], + _sources: List[Image.Image], *args, stage_source: Optional[Image.Image] = None, **kwargs, @@ -25,14 +25,14 @@ class BaseStage: def steps( self, - params: ImageParams, - size: Size, + _params: ImageParams, + _size: Size, ) -> int: return 1 # noqa def outputs( self, - params: ImageParams, + _params: ImageParams, sources: int, ) -> int: return sources diff --git a/api/onnx_web/server/params.py b/api/onnx_web/server/params.py index 62332944..d68e2dcc 100644 --- a/api/onnx_web/server/params.py +++ b/api/onnx_web/server/params.py @@ -35,7 +35,7 @@ logger = getLogger(__name__) def build_device( - server: ServerContext, + _server: ServerContext, data: Dict[str, str], ) -> Optional[DeviceParams]: # platform stuff @@ -172,7 +172,7 @@ def build_params( def build_size( - server: ServerContext, + _server: ServerContext, data: Dict[str, str], ) -> Size: height = get_and_clamp_int( diff --git a/api/tests/convert/diffusion/test_lora.py b/api/tests/convert/diffusion/test_lora.py index 672c6be6..01372e93 100644 --- a/api/tests/convert/diffusion/test_lora.py +++ b/api/tests/convert/diffusion/test_lora.py @@ -153,12 +153,33 @@ class KernelSliceTests(unittest.TestCase): class BlendLoRATests(unittest.TestCase): def test_blend_unet(self): + """ + blend_loras(None, "test", [], "unet") + """ pass def test_blend_text_encoder(self): + """ + blend_loras(None, "test", [], "text_encoder") + """ pass def test_blend_text_encoder_index(self): + """ + blend_loras(None, "test", [], "text_encoder", model_index=2) + """ + pass + + def test_unmatched_keys(self): + pass + + def test_xl_keys(self): + """ + blend_loras(None, "test", [], "unet", xl=True) + """ + pass + + def test_node_dtype(self): pass diff --git a/api/tests/mocks.py b/api/tests/mocks.py new file mode 100644 index 00000000..f16ae22f --- /dev/null +++ b/api/tests/mocks.py @@ -0,0 +1,43 @@ +from typing import Any, Optional + + +class MockPipeline(): + # flags + slice_size: Optional[str] + vae_slicing: Optional[bool] + sequential_offload: Optional[bool] + model_offload: Optional[bool] + xformers: Optional[bool] + + # stubs + _encode_prompt: Optional[Any] + unet: Optional[Any] + vae_decoder: Optional[Any] + vae_encoder: Optional[Any] + + def __init__(self) -> None: + self.slice_size = None + self.vae_slicing = None + self.sequential_offload = None + self.model_offload = None + self.xformers = None + + self._encode_prompt = None + self.unet = None + self.vae_decoder = None + self.vae_encoder = None + + def enable_attention_slicing(self, slice_size: str = None): + self.slice_size = slice_size + + def enable_vae_slicing(self): + self.vae_slicing = True + + def enable_sequential_cpu_offload(self): + self.sequential_offload = True + + def enable_model_cpu_offload(self): + self.model_offload = True + + def enable_xformers_memory_efficient_attention(self): + self.xformers = True \ No newline at end of file diff --git a/api/tests/test_diffusers/test_load.py b/api/tests/test_diffusers/test_load.py index c3649a47..474b86d9 100644 --- a/api/tests/test_diffusers/test_load.py +++ b/api/tests/test_diffusers/test_load.py @@ -6,7 +6,15 @@ from onnx_web.diffusers.load import ( get_available_pipelines, get_pipeline_schedulers, get_scheduler_name, + optimize_pipeline, + patch_pipeline, ) +from onnx_web.diffusers.patches.unet import UNetWrapper +from onnx_web.diffusers.patches.vae import VAEWrapper +from onnx_web.diffusers.utils import expand_prompt +from onnx_web.params import ImageParams +from onnx_web.server.context import ServerContext +from tests.mocks import MockPipeline class TestAvailablePipelines(unittest.TestCase): @@ -35,30 +43,91 @@ class TestSchedulerNames(unittest.TestCase): class TestOptimizePipeline(unittest.TestCase): def test_auto_attention_slicing(self): - pass + server = ServerContext( + optimizations=[ + "diffusers-attention-slicing-auto", + ], + ) + pipeline = MockPipeline() + optimize_pipeline(server, pipeline) + self.assertEqual(pipeline.slice_size, "auto") def test_max_attention_slicing(self): - pass + server = ServerContext( + optimizations=[ + "diffusers-attention-slicing-max", + ] + ) + pipeline = MockPipeline() + optimize_pipeline(server, pipeline) + self.assertEqual(pipeline.slice_size, "max") def test_vae_slicing(self): - pass + server = ServerContext( + optimizations=[ + "diffusers-vae-slicing", + ] + ) + pipeline = MockPipeline() + optimize_pipeline(server, pipeline) + self.assertEqual(pipeline.vae_slicing, True) def test_cpu_offload_sequential(self): - pass + server = ServerContext( + optimizations=[ + "diffusers-cpu-offload-sequential", + ] + ) + pipeline = MockPipeline() + optimize_pipeline(server, pipeline) + self.assertEqual(pipeline.sequential_offload, True) def test_cpu_offload_model(self): - pass + server = ServerContext( + optimizations=[ + "diffusers-cpu-offload-model", + ] + ) + pipeline = MockPipeline() + optimize_pipeline(server, pipeline) + self.assertEqual(pipeline.model_offload, True) def test_memory_efficient_attention(self): - pass + server = ServerContext( + optimizations=[ + "diffusers-memory-efficient-attention", + ] + ) + pipeline = MockPipeline() + optimize_pipeline(server, pipeline) + self.assertEqual(pipeline.xformers, True) class TestPatchPipeline(unittest.TestCase): def test_expand_not_lpw(self): + """ + server = ServerContext() + pipeline = MockPipeline() + patch_pipeline(server, pipeline, None, ImageParams("test", "txt2img", "ddim", "test", 1.0, 10, 1)) + self.assertEqual(pipeline._encode_prompt, expand_prompt) + """ pass def test_unet_wrapper_not_xl(self): - pass + server = ServerContext() + pipeline = MockPipeline() + patch_pipeline(server, pipeline, None, ImageParams("test", "txt2img", "ddim", "test", 1.0, 10, 1)) + self.assertTrue(isinstance(pipeline.unet, UNetWrapper)) + + def test_unet_wrapper_xl(self): + server = ServerContext() + pipeline = MockPipeline() + patch_pipeline(server, pipeline, None, ImageParams("test", "txt2img-sdxl", "ddim", "test", 1.0, 10, 1)) + self.assertFalse(isinstance(pipeline.unet, UNetWrapper)) def test_vae_wrapper(self): - pass \ No newline at end of file + server = ServerContext() + pipeline = MockPipeline() + patch_pipeline(server, pipeline, None, ImageParams("test", "txt2img", "ddim", "test", 1.0, 10, 1)) + self.assertTrue(isinstance(pipeline.vae_decoder, VAEWrapper)) + self.assertTrue(isinstance(pipeline.vae_encoder, VAEWrapper)) diff --git a/api/tests/test_diffusers/test_utils.py b/api/tests/test_diffusers/test_utils.py new file mode 100644 index 00000000..b723beee --- /dev/null +++ b/api/tests/test_diffusers/test_utils.py @@ -0,0 +1,85 @@ +import unittest + +import numpy as np + +from onnx_web.diffusers.utils import ( + expand_interval_ranges, + expand_alternative_ranges, + get_inversions_from_prompt, + get_latents_from_seed, + get_loras_from_prompt, + get_scaled_latents, + get_tokens_from_prompt, +) +from onnx_web.params import Size + +class TestExpandIntervalRanges(unittest.TestCase): + def test_prompt_with_no_ranges(self): + prompt = "an astronaut eating a hamburger" + result = expand_interval_ranges(prompt) + self.assertEqual(prompt, result) + + def test_prompt_with_range(self): + prompt = "an astronaut-{1,4} eating a hamburger" + result = expand_interval_ranges(prompt) + self.assertEqual(result, "an astronaut-1 astronaut-2 astronaut-3 eating a hamburger") + +class TestExpandAlternativeRanges(unittest.TestCase): + def test_prompt_with_no_ranges(self): + prompt = "an astronaut eating a hamburger" + result = expand_alternative_ranges(prompt) + self.assertEqual([prompt], result) + + def test_ranges_match(self): + prompt = "(an astronaut|a squirrel) eating (a hamburger|an acorn)" + result = expand_alternative_ranges(prompt) + self.assertEqual(result, ["an astronaut eating a hamburger", "a squirrel eating an acorn"]) + +class TestInversionsFromPrompt(unittest.TestCase): + def test_get_inversions(self): + prompt = " an astronaut eating an embedding" + result, tokens = get_inversions_from_prompt(prompt) + + self.assertEqual(result, " an astronaut eating an embedding") + self.assertEqual(tokens, [("test", 1.0)]) + +class TestLoRAsFromPrompt(unittest.TestCase): + def test_get_loras(self): + prompt = " an astronaut eating a LoRA" + result, tokens = get_loras_from_prompt(prompt) + + self.assertEqual(result, " an astronaut eating a LoRA") + self.assertEqual(tokens, [("test", 1.0)]) + +class TestLatentsFromSeed(unittest.TestCase): + def test_batch_size(self): + latents = get_latents_from_seed(1, Size(64, 64), batch=4) + self.assertEqual(latents.shape, (4, 4, 8, 8)) + + def test_consistency(self): + latents1 = get_latents_from_seed(1, Size(64, 64)) + latents2 = get_latents_from_seed(1, Size(64, 64)) + self.assertTrue(np.array_equal(latents1, latents2)) + +class TestTileLatents(unittest.TestCase): + def test_full_tile(self): + pass + + def test_partial_tile(self): + pass + +class TestScaledLatents(unittest.TestCase): + def test_scale_up(self): + latents = get_latents_from_seed(1, Size(16, 16)) + scaled = get_scaled_latents(1, Size(16, 16), scale=2) + self.assertEqual(latents[0, 0, 0, 0], scaled[0, 0, 0, 0]) + + def test_scale_down(self): + latents = get_latents_from_seed(1, Size(16, 16)) + scaled = get_scaled_latents(1, Size(16, 16), scale=0.5) + self.assertEqual(( + latents[0, 0, 0, 0] + + latents[0, 0, 0, 1] + + latents[0, 0, 1, 0] + + latents[0, 0, 1, 1] + ) / 4, scaled[0, 0, 0, 0])