1
0
Fork 0

lint fixes

This commit is contained in:
Sean Sube 2024-01-14 19:24:50 -06:00
parent 5ffb44c8fa
commit 2c988b8a16
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
10 changed files with 69 additions and 42 deletions

View File

@ -7,7 +7,7 @@ import torch
from PIL import Image
from torchvision.transforms.functional import normalize
from ..params import ImageParams, StageParams, UpscaleParams
from ..params import HighresParams, ImageParams, StageParams, UpscaleParams
from ..server import ServerContext
from ..worker import WorkerContext
from .base import BaseStage
@ -28,8 +28,9 @@ class CorrectCodeformerStage(BaseStage):
_params: ImageParams,
sources: StageResult,
*,
stage_source: Optional[Image.Image] = None,
upscale: UpscaleParams,
highres: Optional[HighresParams] = None,
stage_source: Optional[Image.Image] = None,
**kwargs,
) -> StageResult:
# adapted from https://github.com/kadirnar/codeformer-pip/blob/main/codeformer/app.py and

View File

@ -4,7 +4,13 @@ from typing import Optional
from PIL import Image
from ..params import DeviceParams, ImageParams, StageParams, UpscaleParams
from ..params import (
DeviceParams,
HighresParams,
ImageParams,
StageParams,
UpscaleParams,
)
from ..server import ModelTypes, ServerContext
from ..utils import run_gc
from ..worker import WorkerContext
@ -60,6 +66,7 @@ class CorrectGFPGANStage(BaseStage):
sources: StageResult,
*,
upscale: UpscaleParams,
highres: Optional[HighresParams] = None,
stage_source: Optional[Image.Image] = None,
**kwargs,
) -> StageResult:

View File

@ -4,6 +4,7 @@ from typing import Any, Optional
from PIL import Image
from ..errors import CancelledException
from ..output import save_metadata
from ..params import ImageParams, SizeChart, StageParams
from ..server import ServerContext
from ..server.model_cache import ModelTypes
@ -44,7 +45,7 @@ class EditSafetyStage(BaseStage):
def run(
self,
_worker: WorkerContext,
worker: WorkerContext,
server: ServerContext,
_stage: StageParams,
_params: ImageParams,
@ -65,23 +66,24 @@ class EditSafetyStage(BaseStage):
images = sources.as_images()
results = []
for i, image in enumerate(images):
prompt = sources.metadata[i].params.prompt
metadata = sources.metadata[i]
prompt = metadata.params.prompt
check = nsfw_checker.check_for_nsfw(image, prompt=prompt)
if check.is_csam:
logger.warning("flagging csam result: %s, %s", i, prompt)
is_csam = True
continue
if check.is_nsfw and block_nsfw:
report_name = f"csam-report-{worker.job}-{i}"
report_path = save_metadata(server, report_name, metadata)
logger.info("saved csam report: %s", report_path)
elif check.is_nsfw and block_nsfw:
logger.warning("blocking nsfw image: %s, %s", i, prompt)
results.append(Image.new("RGB", image.size, color="black"))
continue
results.append(image)
else:
results.append(image)
if is_csam:
# TODO: save metadata to a report file
logger.warning("blocking csam result")
raise CancelledException(reason="csam")
else:

View File

@ -8,6 +8,7 @@ from PIL import Image
from ..models.onnx import OnnxModel
from ..params import (
DeviceParams,
HighresParams,
ImageParams,
Size,
SizeChart,
@ -65,6 +66,7 @@ class UpscaleBSRGANStage(BaseStage):
sources: StageResult,
*,
upscale: UpscaleParams,
highres: Optional[HighresParams] = None,
stage_source: Optional[Image.Image] = None,
**kwargs,
) -> StageResult:

View File

@ -5,7 +5,13 @@ from typing import Optional
from PIL import Image
from ..onnx import OnnxRRDBNet
from ..params import DeviceParams, ImageParams, StageParams, UpscaleParams
from ..params import (
DeviceParams,
HighresParams,
ImageParams,
StageParams,
UpscaleParams,
)
from ..server import ModelTypes, ServerContext
from ..utils import run_gc
from ..worker import WorkerContext
@ -102,6 +108,7 @@ class UpscaleRealESRGANStage(BaseStage):
sources: StageResult,
*,
upscale: UpscaleParams,
highres: Optional[HighresParams] = None,
stage_source: Optional[Image.Image] = None,
**kwargs,
) -> StageResult:

View File

@ -6,7 +6,14 @@ import numpy as np
from PIL import Image
from ..models.onnx import OnnxModel
from ..params import DeviceParams, ImageParams, SizeChart, StageParams, UpscaleParams
from ..params import (
DeviceParams,
HighresParams,
ImageParams,
SizeChart,
StageParams,
UpscaleParams,
)
from ..server import ModelTypes, ServerContext
from ..utils import run_gc
from ..worker import WorkerContext
@ -58,6 +65,7 @@ class UpscaleSwinIRStage(BaseStage):
sources: StageResult,
*,
upscale: UpscaleParams,
highres: Optional[HighresParams] = None,
stage_source: Optional[Image.Image] = None,
**kwargs,
) -> StageResult:

View File

@ -28,6 +28,7 @@ from ..utils import (
get_boolean,
get_from_list,
get_from_map,
get_list,
get_not_empty,
get_size,
load_config,
@ -648,7 +649,7 @@ def job_create(server: ServerContext, pool: DevicePoolExecutor):
def job_cancel(server: ServerContext, pool: DevicePoolExecutor):
legacy_job_name = request.args.get("job", None)
job_list = request.args.get("jobs", "").split(",")
job_list = get_list(request.args, "jobs")
if legacy_job_name is not None:
job_list.append(legacy_job_name)
@ -672,7 +673,7 @@ def job_cancel(server: ServerContext, pool: DevicePoolExecutor):
def job_status(server: ServerContext, pool: DevicePoolExecutor):
legacy_job_name = request.args.get("job", None)
job_list = request.args.get("jobs", "").split(",")
job_list = get_list(request.args, "jobs")
if legacy_job_name is not None:
job_list.append(legacy_job_name)

View File

@ -4,9 +4,6 @@ from onnx_web.chain.result import ImageMetadata
class ImageMetadataTests(unittest.TestCase):
def test_image_metadata(self):
pass
def test_from_exif_normal(self):
exif_data = """test prompt
Negative prompt: negative prompt
@ -36,8 +33,3 @@ Steps: 30, Seed: 5
self.assertEqual(metadata.params.cfg, 4.0)
self.assertEqual(metadata.params.steps, 30)
self.assertEqual(metadata.params.seed, 5)
class StageResultTests(unittest.TestCase):
def test_stage_result(self):
pass

View File

@ -306,6 +306,9 @@ class LoadTorchTests(unittest.TestCase):
self.assertEqual(result, checkpoint)
LOAD_TENSOR_LOG = "loading tensor: %s"
class LoadTensorTests(unittest.TestCase):
@patch("onnx_web.convert.utils.logger")
@patch("onnx_web.convert.utils.path")
@ -320,7 +323,7 @@ class LoadTensorTests(unittest.TestCase):
result = load_tensor(name, map_location)
mock_logger.debug.assert_has_calls([mock.call("loading tensor: %s", name)])
mock_logger.debug.assert_has_calls([mock.call(LOAD_TENSOR_LOG, name)])
mock_path.splitext.assert_called_once_with(name)
mock_path.exists.assert_called_once_with(name)
mock_torch.load.assert_called_once_with(name, map_location=map_location)
@ -339,7 +342,7 @@ class LoadTensorTests(unittest.TestCase):
result = load_tensor(name)
mock_logger.debug.assert_has_calls([mock.call("loading tensor: %s", name)])
mock_logger.debug.assert_has_calls([mock.call(LOAD_TENSOR_LOG, name)])
mock_safetensors.torch.load_file.assert_called_once_with(name, device="cpu")
self.assertEqual(result, checkpoint)
@ -353,7 +356,7 @@ class LoadTensorTests(unittest.TestCase):
result = load_tensor(name, map_location)
mock_logger.debug.assert_has_calls([mock.call("loading tensor: %s", name)])
mock_logger.debug.assert_has_calls([mock.call(LOAD_TENSOR_LOG, name)])
mock_torch.load.assert_has_calls(
[
mock.call(name, map_location=map_location),
@ -370,9 +373,7 @@ class LoadTensorTests(unittest.TestCase):
result = load_tensor(ONNX_MODEL, map_location)
mock_logger.debug.assert_has_calls(
[mock.call("loading tensor: %s", ONNX_MODEL)]
)
mock_logger.debug.assert_has_calls([mock.call(LOAD_TENSOR_LOG, ONNX_MODEL)])
mock_logger.warning.assert_called_once_with(
"tensor has ONNX extension, attempting to use PyTorch anyways: %s", "onnx"
)
@ -393,7 +394,7 @@ class LoadTensorTests(unittest.TestCase):
result = load_tensor(name, map_location)
mock_logger.debug.assert_has_calls([mock.call("loading tensor: %s", name)])
mock_logger.debug.assert_has_calls([mock.call(LOAD_TENSOR_LOG, name)])
mock_logger.warning.assert_called_once_with(
"unknown tensor type, falling back to PyTorch: %s", "xyz"
)
@ -434,13 +435,15 @@ class FixDiffusionNameTests(unittest.TestCase):
)
CACHE_PATH = "/path/to/cache"
class BuildCachePathsTests(unittest.TestCase):
def test_build_cache_paths_without_format(self):
client = "client1"
cache = "/path/to/cache"
conversion = ConversionContext(cache_path=cache)
result = build_cache_paths(conversion, ONNX_MODEL, client, cache)
conversion = ConversionContext(cache_path=CACHE_PATH)
result = build_cache_paths(conversion, ONNX_MODEL, client, CACHE_PATH)
expected_paths = [
path.join("/path/to/cache", ONNX_MODEL),
@ -451,11 +454,10 @@ class BuildCachePathsTests(unittest.TestCase):
def test_build_cache_paths_with_format(self):
name = "model"
client = "client2"
cache = "/path/to/cache"
model_format = "onnx"
conversion = ConversionContext(cache_path=cache)
result = build_cache_paths(conversion, name, client, cache, model_format)
conversion = ConversionContext(cache_path=CACHE_PATH)
result = build_cache_paths(conversion, name, client, CACHE_PATH, model_format)
expected_paths = [
path.join("/path/to/cache", ONNX_MODEL),
@ -465,11 +467,12 @@ class BuildCachePathsTests(unittest.TestCase):
def test_build_cache_paths_with_existing_extension(self):
client = "client3"
cache = "/path/to/cache"
model_format = "onnx"
conversion = ConversionContext(cache_path=cache)
result = build_cache_paths(conversion, TORCH_MODEL, client, cache, model_format)
conversion = ConversionContext(cache_path=CACHE_PATH)
result = build_cache_paths(
conversion, TORCH_MODEL, client, CACHE_PATH, model_format
)
expected_paths = [
path.join("/path/to/cache", TORCH_MODEL),
@ -480,11 +483,10 @@ class BuildCachePathsTests(unittest.TestCase):
def test_build_cache_paths_with_empty_extension(self):
name = "model"
client = "client4"
cache = "/path/to/cache"
model_format = "onnx"
conversion = ConversionContext(cache_path=cache)
result = build_cache_paths(conversion, name, client, cache, model_format)
conversion = ConversionContext(cache_path=CACHE_PATH)
result = build_cache_paths(conversion, name, client, CACHE_PATH, model_format)
expected_paths = [
path.join("/path/to/cache", ONNX_MODEL),

View File

@ -68,6 +68,7 @@ class TestWorkerPool(unittest.TestCase):
self.assertTrue(self.pool.cancel("test"))
self.assertEqual(self.pool.status("test"), (JobStatus.CANCELLED, None, None))
@unittest.skip("TODO")
def test_cancel_running(self):
pass
@ -144,15 +145,19 @@ class TestWorkerPool(unittest.TestCase):
status, _progress, _queue = self.pool.status("test")
self.assertEqual(status, JobStatus.SUCCESS)
@unittest.skip("TODO")
def test_recycle_live(self):
pass
@unittest.skip("TODO")
def test_recycle_dead(self):
pass
@unittest.skip("TODO")
def test_running_status(self):
pass
@unittest.skip("TODO")
def test_progress_update(self):
pass