From 2c988b8a1656f61507cc5a0723719b5111125cc9 Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Sun, 14 Jan 2024 19:24:50 -0600 Subject: [PATCH] lint fixes --- api/onnx_web/chain/correct_codeformer.py | 5 +-- api/onnx_web/chain/correct_gfpgan.py | 9 +++++- api/onnx_web/chain/edit_safety.py | 18 ++++++----- api/onnx_web/chain/upscale_bsrgan.py | 2 ++ api/onnx_web/chain/upscale_resrgan.py | 9 +++++- api/onnx_web/chain/upscale_swinir.py | 10 +++++- api/onnx_web/server/api.py | 5 +-- api/tests/chain/test_result.py | 8 ----- api/tests/convert/test_utils.py | 40 +++++++++++++----------- api/tests/worker/test_pool.py | 5 +++ 10 files changed, 69 insertions(+), 42 deletions(-) diff --git a/api/onnx_web/chain/correct_codeformer.py b/api/onnx_web/chain/correct_codeformer.py index 173e05e4..8a33cba8 100644 --- a/api/onnx_web/chain/correct_codeformer.py +++ b/api/onnx_web/chain/correct_codeformer.py @@ -7,7 +7,7 @@ import torch from PIL import Image from torchvision.transforms.functional import normalize -from ..params import ImageParams, StageParams, UpscaleParams +from ..params import HighresParams, ImageParams, StageParams, UpscaleParams from ..server import ServerContext from ..worker import WorkerContext from .base import BaseStage @@ -28,8 +28,9 @@ class CorrectCodeformerStage(BaseStage): _params: ImageParams, sources: StageResult, *, - stage_source: Optional[Image.Image] = None, upscale: UpscaleParams, + highres: Optional[HighresParams] = None, + stage_source: Optional[Image.Image] = None, **kwargs, ) -> StageResult: # adapted from https://github.com/kadirnar/codeformer-pip/blob/main/codeformer/app.py and diff --git a/api/onnx_web/chain/correct_gfpgan.py b/api/onnx_web/chain/correct_gfpgan.py index 52ef659a..d80363dc 100644 --- a/api/onnx_web/chain/correct_gfpgan.py +++ b/api/onnx_web/chain/correct_gfpgan.py @@ -4,7 +4,13 @@ from typing import Optional from PIL import Image -from ..params import DeviceParams, ImageParams, StageParams, UpscaleParams +from ..params import ( + DeviceParams, + HighresParams, + ImageParams, + StageParams, + UpscaleParams, +) from ..server import ModelTypes, ServerContext from ..utils import run_gc from ..worker import WorkerContext @@ -60,6 +66,7 @@ class CorrectGFPGANStage(BaseStage): sources: StageResult, *, upscale: UpscaleParams, + highres: Optional[HighresParams] = None, stage_source: Optional[Image.Image] = None, **kwargs, ) -> StageResult: diff --git a/api/onnx_web/chain/edit_safety.py b/api/onnx_web/chain/edit_safety.py index 71650892..8778c963 100644 --- a/api/onnx_web/chain/edit_safety.py +++ b/api/onnx_web/chain/edit_safety.py @@ -4,6 +4,7 @@ from typing import Any, Optional from PIL import Image from ..errors import CancelledException +from ..output import save_metadata from ..params import ImageParams, SizeChart, StageParams from ..server import ServerContext from ..server.model_cache import ModelTypes @@ -44,7 +45,7 @@ class EditSafetyStage(BaseStage): def run( self, - _worker: WorkerContext, + worker: WorkerContext, server: ServerContext, _stage: StageParams, _params: ImageParams, @@ -65,23 +66,24 @@ class EditSafetyStage(BaseStage): images = sources.as_images() results = [] for i, image in enumerate(images): - prompt = sources.metadata[i].params.prompt + metadata = sources.metadata[i] + prompt = metadata.params.prompt check = nsfw_checker.check_for_nsfw(image, prompt=prompt) if check.is_csam: logger.warning("flagging csam result: %s, %s", i, prompt) is_csam = True - continue - if check.is_nsfw and block_nsfw: + report_name = f"csam-report-{worker.job}-{i}" + report_path = save_metadata(server, report_name, metadata) + logger.info("saved csam report: %s", report_path) + elif check.is_nsfw and block_nsfw: logger.warning("blocking nsfw image: %s, %s", i, prompt) results.append(Image.new("RGB", image.size, color="black")) - continue - - results.append(image) + else: + results.append(image) if is_csam: - # TODO: save metadata to a report file logger.warning("blocking csam result") raise CancelledException(reason="csam") else: diff --git a/api/onnx_web/chain/upscale_bsrgan.py b/api/onnx_web/chain/upscale_bsrgan.py index 3d410992..bffc32f1 100644 --- a/api/onnx_web/chain/upscale_bsrgan.py +++ b/api/onnx_web/chain/upscale_bsrgan.py @@ -8,6 +8,7 @@ from PIL import Image from ..models.onnx import OnnxModel from ..params import ( DeviceParams, + HighresParams, ImageParams, Size, SizeChart, @@ -65,6 +66,7 @@ class UpscaleBSRGANStage(BaseStage): sources: StageResult, *, upscale: UpscaleParams, + highres: Optional[HighresParams] = None, stage_source: Optional[Image.Image] = None, **kwargs, ) -> StageResult: diff --git a/api/onnx_web/chain/upscale_resrgan.py b/api/onnx_web/chain/upscale_resrgan.py index 51f0a5ae..f818057c 100644 --- a/api/onnx_web/chain/upscale_resrgan.py +++ b/api/onnx_web/chain/upscale_resrgan.py @@ -5,7 +5,13 @@ from typing import Optional from PIL import Image from ..onnx import OnnxRRDBNet -from ..params import DeviceParams, ImageParams, StageParams, UpscaleParams +from ..params import ( + DeviceParams, + HighresParams, + ImageParams, + StageParams, + UpscaleParams, +) from ..server import ModelTypes, ServerContext from ..utils import run_gc from ..worker import WorkerContext @@ -102,6 +108,7 @@ class UpscaleRealESRGANStage(BaseStage): sources: StageResult, *, upscale: UpscaleParams, + highres: Optional[HighresParams] = None, stage_source: Optional[Image.Image] = None, **kwargs, ) -> StageResult: diff --git a/api/onnx_web/chain/upscale_swinir.py b/api/onnx_web/chain/upscale_swinir.py index cab60078..a5510c28 100644 --- a/api/onnx_web/chain/upscale_swinir.py +++ b/api/onnx_web/chain/upscale_swinir.py @@ -6,7 +6,14 @@ import numpy as np from PIL import Image from ..models.onnx import OnnxModel -from ..params import DeviceParams, ImageParams, SizeChart, StageParams, UpscaleParams +from ..params import ( + DeviceParams, + HighresParams, + ImageParams, + SizeChart, + StageParams, + UpscaleParams, +) from ..server import ModelTypes, ServerContext from ..utils import run_gc from ..worker import WorkerContext @@ -58,6 +65,7 @@ class UpscaleSwinIRStage(BaseStage): sources: StageResult, *, upscale: UpscaleParams, + highres: Optional[HighresParams] = None, stage_source: Optional[Image.Image] = None, **kwargs, ) -> StageResult: diff --git a/api/onnx_web/server/api.py b/api/onnx_web/server/api.py index 100ce0cb..ab568b59 100644 --- a/api/onnx_web/server/api.py +++ b/api/onnx_web/server/api.py @@ -28,6 +28,7 @@ from ..utils import ( get_boolean, get_from_list, get_from_map, + get_list, get_not_empty, get_size, load_config, @@ -648,7 +649,7 @@ def job_create(server: ServerContext, pool: DevicePoolExecutor): def job_cancel(server: ServerContext, pool: DevicePoolExecutor): legacy_job_name = request.args.get("job", None) - job_list = request.args.get("jobs", "").split(",") + job_list = get_list(request.args, "jobs") if legacy_job_name is not None: job_list.append(legacy_job_name) @@ -672,7 +673,7 @@ def job_cancel(server: ServerContext, pool: DevicePoolExecutor): def job_status(server: ServerContext, pool: DevicePoolExecutor): legacy_job_name = request.args.get("job", None) - job_list = request.args.get("jobs", "").split(",") + job_list = get_list(request.args, "jobs") if legacy_job_name is not None: job_list.append(legacy_job_name) diff --git a/api/tests/chain/test_result.py b/api/tests/chain/test_result.py index 0d062d82..02f99890 100644 --- a/api/tests/chain/test_result.py +++ b/api/tests/chain/test_result.py @@ -4,9 +4,6 @@ from onnx_web.chain.result import ImageMetadata class ImageMetadataTests(unittest.TestCase): - def test_image_metadata(self): - pass - def test_from_exif_normal(self): exif_data = """test prompt Negative prompt: negative prompt @@ -36,8 +33,3 @@ Steps: 30, Seed: 5 self.assertEqual(metadata.params.cfg, 4.0) self.assertEqual(metadata.params.steps, 30) self.assertEqual(metadata.params.seed, 5) - - -class StageResultTests(unittest.TestCase): - def test_stage_result(self): - pass diff --git a/api/tests/convert/test_utils.py b/api/tests/convert/test_utils.py index 6014889f..b48e8672 100644 --- a/api/tests/convert/test_utils.py +++ b/api/tests/convert/test_utils.py @@ -306,6 +306,9 @@ class LoadTorchTests(unittest.TestCase): self.assertEqual(result, checkpoint) +LOAD_TENSOR_LOG = "loading tensor: %s" + + class LoadTensorTests(unittest.TestCase): @patch("onnx_web.convert.utils.logger") @patch("onnx_web.convert.utils.path") @@ -320,7 +323,7 @@ class LoadTensorTests(unittest.TestCase): result = load_tensor(name, map_location) - mock_logger.debug.assert_has_calls([mock.call("loading tensor: %s", name)]) + mock_logger.debug.assert_has_calls([mock.call(LOAD_TENSOR_LOG, name)]) mock_path.splitext.assert_called_once_with(name) mock_path.exists.assert_called_once_with(name) mock_torch.load.assert_called_once_with(name, map_location=map_location) @@ -339,7 +342,7 @@ class LoadTensorTests(unittest.TestCase): result = load_tensor(name) - mock_logger.debug.assert_has_calls([mock.call("loading tensor: %s", name)]) + mock_logger.debug.assert_has_calls([mock.call(LOAD_TENSOR_LOG, name)]) mock_safetensors.torch.load_file.assert_called_once_with(name, device="cpu") self.assertEqual(result, checkpoint) @@ -353,7 +356,7 @@ class LoadTensorTests(unittest.TestCase): result = load_tensor(name, map_location) - mock_logger.debug.assert_has_calls([mock.call("loading tensor: %s", name)]) + mock_logger.debug.assert_has_calls([mock.call(LOAD_TENSOR_LOG, name)]) mock_torch.load.assert_has_calls( [ mock.call(name, map_location=map_location), @@ -370,9 +373,7 @@ class LoadTensorTests(unittest.TestCase): result = load_tensor(ONNX_MODEL, map_location) - mock_logger.debug.assert_has_calls( - [mock.call("loading tensor: %s", ONNX_MODEL)] - ) + mock_logger.debug.assert_has_calls([mock.call(LOAD_TENSOR_LOG, ONNX_MODEL)]) mock_logger.warning.assert_called_once_with( "tensor has ONNX extension, attempting to use PyTorch anyways: %s", "onnx" ) @@ -393,7 +394,7 @@ class LoadTensorTests(unittest.TestCase): result = load_tensor(name, map_location) - mock_logger.debug.assert_has_calls([mock.call("loading tensor: %s", name)]) + mock_logger.debug.assert_has_calls([mock.call(LOAD_TENSOR_LOG, name)]) mock_logger.warning.assert_called_once_with( "unknown tensor type, falling back to PyTorch: %s", "xyz" ) @@ -434,13 +435,15 @@ class FixDiffusionNameTests(unittest.TestCase): ) +CACHE_PATH = "/path/to/cache" + + class BuildCachePathsTests(unittest.TestCase): def test_build_cache_paths_without_format(self): client = "client1" - cache = "/path/to/cache" - conversion = ConversionContext(cache_path=cache) - result = build_cache_paths(conversion, ONNX_MODEL, client, cache) + conversion = ConversionContext(cache_path=CACHE_PATH) + result = build_cache_paths(conversion, ONNX_MODEL, client, CACHE_PATH) expected_paths = [ path.join("/path/to/cache", ONNX_MODEL), @@ -451,11 +454,10 @@ class BuildCachePathsTests(unittest.TestCase): def test_build_cache_paths_with_format(self): name = "model" client = "client2" - cache = "/path/to/cache" model_format = "onnx" - conversion = ConversionContext(cache_path=cache) - result = build_cache_paths(conversion, name, client, cache, model_format) + conversion = ConversionContext(cache_path=CACHE_PATH) + result = build_cache_paths(conversion, name, client, CACHE_PATH, model_format) expected_paths = [ path.join("/path/to/cache", ONNX_MODEL), @@ -465,11 +467,12 @@ class BuildCachePathsTests(unittest.TestCase): def test_build_cache_paths_with_existing_extension(self): client = "client3" - cache = "/path/to/cache" model_format = "onnx" - conversion = ConversionContext(cache_path=cache) - result = build_cache_paths(conversion, TORCH_MODEL, client, cache, model_format) + conversion = ConversionContext(cache_path=CACHE_PATH) + result = build_cache_paths( + conversion, TORCH_MODEL, client, CACHE_PATH, model_format + ) expected_paths = [ path.join("/path/to/cache", TORCH_MODEL), @@ -480,11 +483,10 @@ class BuildCachePathsTests(unittest.TestCase): def test_build_cache_paths_with_empty_extension(self): name = "model" client = "client4" - cache = "/path/to/cache" model_format = "onnx" - conversion = ConversionContext(cache_path=cache) - result = build_cache_paths(conversion, name, client, cache, model_format) + conversion = ConversionContext(cache_path=CACHE_PATH) + result = build_cache_paths(conversion, name, client, CACHE_PATH, model_format) expected_paths = [ path.join("/path/to/cache", ONNX_MODEL), diff --git a/api/tests/worker/test_pool.py b/api/tests/worker/test_pool.py index 2df53576..ea170818 100644 --- a/api/tests/worker/test_pool.py +++ b/api/tests/worker/test_pool.py @@ -68,6 +68,7 @@ class TestWorkerPool(unittest.TestCase): self.assertTrue(self.pool.cancel("test")) self.assertEqual(self.pool.status("test"), (JobStatus.CANCELLED, None, None)) + @unittest.skip("TODO") def test_cancel_running(self): pass @@ -144,15 +145,19 @@ class TestWorkerPool(unittest.TestCase): status, _progress, _queue = self.pool.status("test") self.assertEqual(status, JobStatus.SUCCESS) + @unittest.skip("TODO") def test_recycle_live(self): pass + @unittest.skip("TODO") def test_recycle_dead(self): pass + @unittest.skip("TODO") def test_running_status(self): pass + @unittest.skip("TODO") def test_progress_update(self): pass