lint, tests
This commit is contained in:
parent
d8ea00582e
commit
963794abaa
|
@ -22,12 +22,12 @@ class BlendGridStage(BaseStage):
|
|||
*,
|
||||
height: int,
|
||||
width: int,
|
||||
rows: Optional[List[str]] = None,
|
||||
columns: Optional[List[str]] = None,
|
||||
title: Optional[str] = None,
|
||||
# rows: Optional[List[str]] = None,
|
||||
# columns: Optional[List[str]] = None,
|
||||
# title: Optional[str] = None,
|
||||
order: Optional[int] = None,
|
||||
stage_source: Optional[Image.Image] = None,
|
||||
_callback: Optional[ProgressCallback] = None,
|
||||
callback: Optional[ProgressCallback] = None,
|
||||
**kwargs,
|
||||
) -> List[Image.Image]:
|
||||
logger.info("combining source images using grid layout")
|
||||
|
@ -51,7 +51,7 @@ class BlendGridStage(BaseStage):
|
|||
|
||||
def outputs(
|
||||
self,
|
||||
params: ImageParams,
|
||||
_params: ImageParams,
|
||||
sources: int,
|
||||
) -> int:
|
||||
return sources + 1
|
||||
|
|
|
@ -12,11 +12,11 @@ class BaseStage:
|
|||
|
||||
def run(
|
||||
self,
|
||||
worker: WorkerContext,
|
||||
server: ServerContext,
|
||||
stage: StageParams,
|
||||
_worker: WorkerContext,
|
||||
_server: ServerContext,
|
||||
_stage: StageParams,
|
||||
_params: ImageParams,
|
||||
sources: List[Image.Image],
|
||||
_sources: List[Image.Image],
|
||||
*args,
|
||||
stage_source: Optional[Image.Image] = None,
|
||||
**kwargs,
|
||||
|
@ -25,14 +25,14 @@ class BaseStage:
|
|||
|
||||
def steps(
|
||||
self,
|
||||
params: ImageParams,
|
||||
size: Size,
|
||||
_params: ImageParams,
|
||||
_size: Size,
|
||||
) -> int:
|
||||
return 1 # noqa
|
||||
|
||||
def outputs(
|
||||
self,
|
||||
params: ImageParams,
|
||||
_params: ImageParams,
|
||||
sources: int,
|
||||
) -> int:
|
||||
return sources
|
||||
|
|
|
@ -35,7 +35,7 @@ logger = getLogger(__name__)
|
|||
|
||||
|
||||
def build_device(
|
||||
server: ServerContext,
|
||||
_server: ServerContext,
|
||||
data: Dict[str, str],
|
||||
) -> Optional[DeviceParams]:
|
||||
# platform stuff
|
||||
|
@ -172,7 +172,7 @@ def build_params(
|
|||
|
||||
|
||||
def build_size(
|
||||
server: ServerContext,
|
||||
_server: ServerContext,
|
||||
data: Dict[str, str],
|
||||
) -> Size:
|
||||
height = get_and_clamp_int(
|
||||
|
|
|
@ -153,12 +153,33 @@ class KernelSliceTests(unittest.TestCase):
|
|||
|
||||
class BlendLoRATests(unittest.TestCase):
|
||||
def test_blend_unet(self):
|
||||
"""
|
||||
blend_loras(None, "test", [], "unet")
|
||||
"""
|
||||
pass
|
||||
|
||||
def test_blend_text_encoder(self):
|
||||
"""
|
||||
blend_loras(None, "test", [], "text_encoder")
|
||||
"""
|
||||
pass
|
||||
|
||||
def test_blend_text_encoder_index(self):
|
||||
"""
|
||||
blend_loras(None, "test", [], "text_encoder", model_index=2)
|
||||
"""
|
||||
pass
|
||||
|
||||
def test_unmatched_keys(self):
|
||||
pass
|
||||
|
||||
def test_xl_keys(self):
|
||||
"""
|
||||
blend_loras(None, "test", [], "unet", xl=True)
|
||||
"""
|
||||
pass
|
||||
|
||||
def test_node_dtype(self):
|
||||
pass
|
||||
|
||||
|
||||
|
|
|
@ -0,0 +1,43 @@
|
|||
from typing import Any, Optional
|
||||
|
||||
|
||||
class MockPipeline():
|
||||
# flags
|
||||
slice_size: Optional[str]
|
||||
vae_slicing: Optional[bool]
|
||||
sequential_offload: Optional[bool]
|
||||
model_offload: Optional[bool]
|
||||
xformers: Optional[bool]
|
||||
|
||||
# stubs
|
||||
_encode_prompt: Optional[Any]
|
||||
unet: Optional[Any]
|
||||
vae_decoder: Optional[Any]
|
||||
vae_encoder: Optional[Any]
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.slice_size = None
|
||||
self.vae_slicing = None
|
||||
self.sequential_offload = None
|
||||
self.model_offload = None
|
||||
self.xformers = None
|
||||
|
||||
self._encode_prompt = None
|
||||
self.unet = None
|
||||
self.vae_decoder = None
|
||||
self.vae_encoder = None
|
||||
|
||||
def enable_attention_slicing(self, slice_size: str = None):
|
||||
self.slice_size = slice_size
|
||||
|
||||
def enable_vae_slicing(self):
|
||||
self.vae_slicing = True
|
||||
|
||||
def enable_sequential_cpu_offload(self):
|
||||
self.sequential_offload = True
|
||||
|
||||
def enable_model_cpu_offload(self):
|
||||
self.model_offload = True
|
||||
|
||||
def enable_xformers_memory_efficient_attention(self):
|
||||
self.xformers = True
|
|
@ -6,7 +6,15 @@ from onnx_web.diffusers.load import (
|
|||
get_available_pipelines,
|
||||
get_pipeline_schedulers,
|
||||
get_scheduler_name,
|
||||
optimize_pipeline,
|
||||
patch_pipeline,
|
||||
)
|
||||
from onnx_web.diffusers.patches.unet import UNetWrapper
|
||||
from onnx_web.diffusers.patches.vae import VAEWrapper
|
||||
from onnx_web.diffusers.utils import expand_prompt
|
||||
from onnx_web.params import ImageParams
|
||||
from onnx_web.server.context import ServerContext
|
||||
from tests.mocks import MockPipeline
|
||||
|
||||
|
||||
class TestAvailablePipelines(unittest.TestCase):
|
||||
|
@ -35,30 +43,91 @@ class TestSchedulerNames(unittest.TestCase):
|
|||
|
||||
class TestOptimizePipeline(unittest.TestCase):
|
||||
def test_auto_attention_slicing(self):
|
||||
pass
|
||||
server = ServerContext(
|
||||
optimizations=[
|
||||
"diffusers-attention-slicing-auto",
|
||||
],
|
||||
)
|
||||
pipeline = MockPipeline()
|
||||
optimize_pipeline(server, pipeline)
|
||||
self.assertEqual(pipeline.slice_size, "auto")
|
||||
|
||||
def test_max_attention_slicing(self):
|
||||
pass
|
||||
server = ServerContext(
|
||||
optimizations=[
|
||||
"diffusers-attention-slicing-max",
|
||||
]
|
||||
)
|
||||
pipeline = MockPipeline()
|
||||
optimize_pipeline(server, pipeline)
|
||||
self.assertEqual(pipeline.slice_size, "max")
|
||||
|
||||
def test_vae_slicing(self):
|
||||
pass
|
||||
server = ServerContext(
|
||||
optimizations=[
|
||||
"diffusers-vae-slicing",
|
||||
]
|
||||
)
|
||||
pipeline = MockPipeline()
|
||||
optimize_pipeline(server, pipeline)
|
||||
self.assertEqual(pipeline.vae_slicing, True)
|
||||
|
||||
def test_cpu_offload_sequential(self):
|
||||
pass
|
||||
server = ServerContext(
|
||||
optimizations=[
|
||||
"diffusers-cpu-offload-sequential",
|
||||
]
|
||||
)
|
||||
pipeline = MockPipeline()
|
||||
optimize_pipeline(server, pipeline)
|
||||
self.assertEqual(pipeline.sequential_offload, True)
|
||||
|
||||
def test_cpu_offload_model(self):
|
||||
pass
|
||||
server = ServerContext(
|
||||
optimizations=[
|
||||
"diffusers-cpu-offload-model",
|
||||
]
|
||||
)
|
||||
pipeline = MockPipeline()
|
||||
optimize_pipeline(server, pipeline)
|
||||
self.assertEqual(pipeline.model_offload, True)
|
||||
|
||||
def test_memory_efficient_attention(self):
|
||||
pass
|
||||
server = ServerContext(
|
||||
optimizations=[
|
||||
"diffusers-memory-efficient-attention",
|
||||
]
|
||||
)
|
||||
pipeline = MockPipeline()
|
||||
optimize_pipeline(server, pipeline)
|
||||
self.assertEqual(pipeline.xformers, True)
|
||||
|
||||
|
||||
class TestPatchPipeline(unittest.TestCase):
|
||||
def test_expand_not_lpw(self):
|
||||
"""
|
||||
server = ServerContext()
|
||||
pipeline = MockPipeline()
|
||||
patch_pipeline(server, pipeline, None, ImageParams("test", "txt2img", "ddim", "test", 1.0, 10, 1))
|
||||
self.assertEqual(pipeline._encode_prompt, expand_prompt)
|
||||
"""
|
||||
pass
|
||||
|
||||
def test_unet_wrapper_not_xl(self):
|
||||
pass
|
||||
server = ServerContext()
|
||||
pipeline = MockPipeline()
|
||||
patch_pipeline(server, pipeline, None, ImageParams("test", "txt2img", "ddim", "test", 1.0, 10, 1))
|
||||
self.assertTrue(isinstance(pipeline.unet, UNetWrapper))
|
||||
|
||||
def test_unet_wrapper_xl(self):
|
||||
server = ServerContext()
|
||||
pipeline = MockPipeline()
|
||||
patch_pipeline(server, pipeline, None, ImageParams("test", "txt2img-sdxl", "ddim", "test", 1.0, 10, 1))
|
||||
self.assertFalse(isinstance(pipeline.unet, UNetWrapper))
|
||||
|
||||
def test_vae_wrapper(self):
|
||||
pass
|
||||
server = ServerContext()
|
||||
pipeline = MockPipeline()
|
||||
patch_pipeline(server, pipeline, None, ImageParams("test", "txt2img", "ddim", "test", 1.0, 10, 1))
|
||||
self.assertTrue(isinstance(pipeline.vae_decoder, VAEWrapper))
|
||||
self.assertTrue(isinstance(pipeline.vae_encoder, VAEWrapper))
|
||||
|
|
|
@ -0,0 +1,85 @@
|
|||
import unittest
|
||||
|
||||
import numpy as np
|
||||
|
||||
from onnx_web.diffusers.utils import (
|
||||
expand_interval_ranges,
|
||||
expand_alternative_ranges,
|
||||
get_inversions_from_prompt,
|
||||
get_latents_from_seed,
|
||||
get_loras_from_prompt,
|
||||
get_scaled_latents,
|
||||
get_tokens_from_prompt,
|
||||
)
|
||||
from onnx_web.params import Size
|
||||
|
||||
class TestExpandIntervalRanges(unittest.TestCase):
|
||||
def test_prompt_with_no_ranges(self):
|
||||
prompt = "an astronaut eating a hamburger"
|
||||
result = expand_interval_ranges(prompt)
|
||||
self.assertEqual(prompt, result)
|
||||
|
||||
def test_prompt_with_range(self):
|
||||
prompt = "an astronaut-{1,4} eating a hamburger"
|
||||
result = expand_interval_ranges(prompt)
|
||||
self.assertEqual(result, "an astronaut-1 astronaut-2 astronaut-3 eating a hamburger")
|
||||
|
||||
class TestExpandAlternativeRanges(unittest.TestCase):
|
||||
def test_prompt_with_no_ranges(self):
|
||||
prompt = "an astronaut eating a hamburger"
|
||||
result = expand_alternative_ranges(prompt)
|
||||
self.assertEqual([prompt], result)
|
||||
|
||||
def test_ranges_match(self):
|
||||
prompt = "(an astronaut|a squirrel) eating (a hamburger|an acorn)"
|
||||
result = expand_alternative_ranges(prompt)
|
||||
self.assertEqual(result, ["an astronaut eating a hamburger", "a squirrel eating an acorn"])
|
||||
|
||||
class TestInversionsFromPrompt(unittest.TestCase):
|
||||
def test_get_inversions(self):
|
||||
prompt = "<inversion:test:1.0> an astronaut eating an embedding"
|
||||
result, tokens = get_inversions_from_prompt(prompt)
|
||||
|
||||
self.assertEqual(result, " an astronaut eating an embedding")
|
||||
self.assertEqual(tokens, [("test", 1.0)])
|
||||
|
||||
class TestLoRAsFromPrompt(unittest.TestCase):
|
||||
def test_get_loras(self):
|
||||
prompt = "<lora:test:1.0> an astronaut eating a LoRA"
|
||||
result, tokens = get_loras_from_prompt(prompt)
|
||||
|
||||
self.assertEqual(result, " an astronaut eating a LoRA")
|
||||
self.assertEqual(tokens, [("test", 1.0)])
|
||||
|
||||
class TestLatentsFromSeed(unittest.TestCase):
|
||||
def test_batch_size(self):
|
||||
latents = get_latents_from_seed(1, Size(64, 64), batch=4)
|
||||
self.assertEqual(latents.shape, (4, 4, 8, 8))
|
||||
|
||||
def test_consistency(self):
|
||||
latents1 = get_latents_from_seed(1, Size(64, 64))
|
||||
latents2 = get_latents_from_seed(1, Size(64, 64))
|
||||
self.assertTrue(np.array_equal(latents1, latents2))
|
||||
|
||||
class TestTileLatents(unittest.TestCase):
|
||||
def test_full_tile(self):
|
||||
pass
|
||||
|
||||
def test_partial_tile(self):
|
||||
pass
|
||||
|
||||
class TestScaledLatents(unittest.TestCase):
|
||||
def test_scale_up(self):
|
||||
latents = get_latents_from_seed(1, Size(16, 16))
|
||||
scaled = get_scaled_latents(1, Size(16, 16), scale=2)
|
||||
self.assertEqual(latents[0, 0, 0, 0], scaled[0, 0, 0, 0])
|
||||
|
||||
def test_scale_down(self):
|
||||
latents = get_latents_from_seed(1, Size(16, 16))
|
||||
scaled = get_scaled_latents(1, Size(16, 16), scale=0.5)
|
||||
self.assertEqual((
|
||||
latents[0, 0, 0, 0] +
|
||||
latents[0, 0, 0, 1] +
|
||||
latents[0, 0, 1, 0] +
|
||||
latents[0, 0, 1, 1]
|
||||
) / 4, scaled[0, 0, 0, 0])
|
Loading…
Reference in New Issue