From a71298ff33f457933e1c2e082e6c68fd8d174c5b Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Fri, 15 Sep 2023 08:40:56 -0500 Subject: [PATCH] remove unused schema, lint --- api/onnx_web/chain/blend_img2img.py | 2 +- api/onnx_web/chain/source_txt2img.py | 2 +- api/onnx_web/chain/upscale_outpaint.py | 2 +- api/onnx_web/convert/diffusion/lora.py | 3 - api/onnx_web/params.py | 2 +- api/schemas/generate.yaml | 117 ----------------------- api/tests/chain/test_base.py | 26 +++++ api/tests/convert/diffusion/test_lora.py | 11 ++- 8 files changed, 40 insertions(+), 125 deletions(-) delete mode 100644 api/schemas/generate.yaml create mode 100644 api/tests/chain/test_base.py diff --git a/api/onnx_web/chain/blend_img2img.py b/api/onnx_web/chain/blend_img2img.py index 020cd917..0d5ad28e 100644 --- a/api/onnx_web/chain/blend_img2img.py +++ b/api/onnx_web/chain/blend_img2img.py @@ -16,7 +16,7 @@ logger = getLogger(__name__) class BlendImg2ImgStage(BaseStage): - max_tile = SizeChart.unlimited + max_tile = SizeChart.max def run( self, diff --git a/api/onnx_web/chain/source_txt2img.py b/api/onnx_web/chain/source_txt2img.py index dfed652a..b4dd7c24 100644 --- a/api/onnx_web/chain/source_txt2img.py +++ b/api/onnx_web/chain/source_txt2img.py @@ -22,7 +22,7 @@ logger = getLogger(__name__) class SourceTxt2ImgStage(BaseStage): - max_tile = SizeChart.unlimited + max_tile = SizeChart.max def run( self, diff --git a/api/onnx_web/chain/upscale_outpaint.py b/api/onnx_web/chain/upscale_outpaint.py index 71de2629..67f7ca0a 100644 --- a/api/onnx_web/chain/upscale_outpaint.py +++ b/api/onnx_web/chain/upscale_outpaint.py @@ -24,7 +24,7 @@ logger = getLogger(__name__) class UpscaleOutpaintStage(BaseStage): - max_tile = SizeChart.unlimited + max_tile = SizeChart.max def run( self, diff --git a/api/onnx_web/convert/diffusion/lora.py b/api/onnx_web/convert/diffusion/lora.py index 3d3fed7c..cb0e2db9 100644 --- a/api/onnx_web/convert/diffusion/lora.py +++ b/api/onnx_web/convert/diffusion/lora.py @@ -543,9 +543,6 @@ def blend_loras( if len(unmatched_keys) > 0: logger.warning("could not find nodes for some keys: %s", unmatched_keys) - # if model_type == "unet": - # save_model(base_model, f"/tmp/lora_blend_{model_type}.onnx", save_as_external_data=True, all_tensors_to_one_file=True, location="weights.pb") - return base_model diff --git a/api/onnx_web/params.py b/api/onnx_web/params.py index 5b504c6a..885b09b3 100644 --- a/api/onnx_web/params.py +++ b/api/onnx_web/params.py @@ -24,7 +24,7 @@ class SizeChart(IntEnum): hd16k = 2**14 hd32k = 2**15 hd64k = 2**16 - unlimited = 2**32 # sort of + max = 2**32 # should be a reasonable upper limit for now class TileOrder: diff --git a/api/schemas/generate.yaml b/api/schemas/generate.yaml deleted file mode 100644 index 8666468e..00000000 --- a/api/schemas/generate.yaml +++ /dev/null @@ -1,117 +0,0 @@ -$id: TODO -$schema: https://json-schema.org/draft/2020-12/schema - -$defs: - grid: - type: object - additionalProperties: False - required: [width, height] - width: - type: number - height: - type: number - labels: - type: object - additionalProperties: False - properties: - title: - type: string - rows: - type: array - items: - type: string - columns: - type: array - items: - type: string - order: - type: array - items: number - - job_base: - type: object - additionalProperties: true - required: [ - device, - model, - pipeline, - scheduler, - prompt, - cfg, - steps, - seed, - ] - properties: - batch: - type: number - device: - type: string - model: - type: string - control: - type: string - pipeline: - type: string - scheduler: - type: string - prompt: - type: string - negative_prompt: - type: string - cfg: - type: number - eta: - type: number - steps: - type: number - tiled_vae: - type: boolean - tiles: - type: number - overlap: - type: number - seed: - type: number - stride: - type: number - - job_txt2img: - allOf: - - $ref: "#/$defs/job_base" - - type: object - additionalProperties: False - required: [ - height, - width, - ] - properties: - width: - type: number - height: - type: number - - job_img2img: - allOf: - - $ref: "#/$defs/job_base" - - type: object - additionalProperties: False - required: [] - properties: - loopback: - type: number - -type: object -additionalProperties: False -properties: - txt2img: - type: array - items: - $ref: "#/$defs/job_txt2img" - img2img: - type: array - items: - $ref: "#/$defs/job_img2img" - grid: - type: array - items: - $ref: "#/$defs/grid" diff --git a/api/tests/chain/test_base.py b/api/tests/chain/test_base.py new file mode 100644 index 00000000..a2530600 --- /dev/null +++ b/api/tests/chain/test_base.py @@ -0,0 +1,26 @@ +import unittest + +from onnx_web.chain.base import ChainProgress + + +class ChainProgressTests(unittest.TestCase): + def test_accumulate_with_reset(self): + def parent(step, timestep, latents): + pass + + progress = ChainProgress(parent) + progress(5, 1, None) + progress(0, 1, None) + progress(5, 1, None) + + self.assertEqual(progress.get_total(), 10) + + def test_start_value(self): + def parent(step, timestep, latents): + pass + + progress = ChainProgress(parent, 5) + self.assertEqual(progress.get_total(), 5) + + progress(10, 1, None) + self.assertEqual(progress.get_total(), 10) diff --git a/api/tests/convert/diffusion/test_lora.py b/api/tests/convert/diffusion/test_lora.py index 58462b80..672c6be6 100644 --- a/api/tests/convert/diffusion/test_lora.py +++ b/api/tests/convert/diffusion/test_lora.py @@ -150,8 +150,17 @@ class KernelSliceTests(unittest.TestCase): (2, 2), ) + class BlendLoRATests(unittest.TestCase): - pass + def test_blend_unet(self): + pass + + def test_blend_text_encoder(self): + pass + + def test_blend_text_encoder_index(self): + pass + class InterpToMatchTests(unittest.TestCase): def test_same_shape(self):