lint fixes
This commit is contained in:
parent
5ffb44c8fa
commit
2c988b8a16
|
@ -7,7 +7,7 @@ import torch
|
|||
from PIL import Image
|
||||
from torchvision.transforms.functional import normalize
|
||||
|
||||
from ..params import ImageParams, StageParams, UpscaleParams
|
||||
from ..params import HighresParams, ImageParams, StageParams, UpscaleParams
|
||||
from ..server import ServerContext
|
||||
from ..worker import WorkerContext
|
||||
from .base import BaseStage
|
||||
|
@ -28,8 +28,9 @@ class CorrectCodeformerStage(BaseStage):
|
|||
_params: ImageParams,
|
||||
sources: StageResult,
|
||||
*,
|
||||
stage_source: Optional[Image.Image] = None,
|
||||
upscale: UpscaleParams,
|
||||
highres: Optional[HighresParams] = None,
|
||||
stage_source: Optional[Image.Image] = None,
|
||||
**kwargs,
|
||||
) -> StageResult:
|
||||
# adapted from https://github.com/kadirnar/codeformer-pip/blob/main/codeformer/app.py and
|
||||
|
|
|
@ -4,7 +4,13 @@ from typing import Optional
|
|||
|
||||
from PIL import Image
|
||||
|
||||
from ..params import DeviceParams, ImageParams, StageParams, UpscaleParams
|
||||
from ..params import (
|
||||
DeviceParams,
|
||||
HighresParams,
|
||||
ImageParams,
|
||||
StageParams,
|
||||
UpscaleParams,
|
||||
)
|
||||
from ..server import ModelTypes, ServerContext
|
||||
from ..utils import run_gc
|
||||
from ..worker import WorkerContext
|
||||
|
@ -60,6 +66,7 @@ class CorrectGFPGANStage(BaseStage):
|
|||
sources: StageResult,
|
||||
*,
|
||||
upscale: UpscaleParams,
|
||||
highres: Optional[HighresParams] = None,
|
||||
stage_source: Optional[Image.Image] = None,
|
||||
**kwargs,
|
||||
) -> StageResult:
|
||||
|
|
|
@ -4,6 +4,7 @@ from typing import Any, Optional
|
|||
from PIL import Image
|
||||
|
||||
from ..errors import CancelledException
|
||||
from ..output import save_metadata
|
||||
from ..params import ImageParams, SizeChart, StageParams
|
||||
from ..server import ServerContext
|
||||
from ..server.model_cache import ModelTypes
|
||||
|
@ -44,7 +45,7 @@ class EditSafetyStage(BaseStage):
|
|||
|
||||
def run(
|
||||
self,
|
||||
_worker: WorkerContext,
|
||||
worker: WorkerContext,
|
||||
server: ServerContext,
|
||||
_stage: StageParams,
|
||||
_params: ImageParams,
|
||||
|
@ -65,23 +66,24 @@ class EditSafetyStage(BaseStage):
|
|||
images = sources.as_images()
|
||||
results = []
|
||||
for i, image in enumerate(images):
|
||||
prompt = sources.metadata[i].params.prompt
|
||||
metadata = sources.metadata[i]
|
||||
prompt = metadata.params.prompt
|
||||
check = nsfw_checker.check_for_nsfw(image, prompt=prompt)
|
||||
|
||||
if check.is_csam:
|
||||
logger.warning("flagging csam result: %s, %s", i, prompt)
|
||||
is_csam = True
|
||||
continue
|
||||
|
||||
if check.is_nsfw and block_nsfw:
|
||||
report_name = f"csam-report-{worker.job}-{i}"
|
||||
report_path = save_metadata(server, report_name, metadata)
|
||||
logger.info("saved csam report: %s", report_path)
|
||||
elif check.is_nsfw and block_nsfw:
|
||||
logger.warning("blocking nsfw image: %s, %s", i, prompt)
|
||||
results.append(Image.new("RGB", image.size, color="black"))
|
||||
continue
|
||||
|
||||
else:
|
||||
results.append(image)
|
||||
|
||||
if is_csam:
|
||||
# TODO: save metadata to a report file
|
||||
logger.warning("blocking csam result")
|
||||
raise CancelledException(reason="csam")
|
||||
else:
|
||||
|
|
|
@ -8,6 +8,7 @@ from PIL import Image
|
|||
from ..models.onnx import OnnxModel
|
||||
from ..params import (
|
||||
DeviceParams,
|
||||
HighresParams,
|
||||
ImageParams,
|
||||
Size,
|
||||
SizeChart,
|
||||
|
@ -65,6 +66,7 @@ class UpscaleBSRGANStage(BaseStage):
|
|||
sources: StageResult,
|
||||
*,
|
||||
upscale: UpscaleParams,
|
||||
highres: Optional[HighresParams] = None,
|
||||
stage_source: Optional[Image.Image] = None,
|
||||
**kwargs,
|
||||
) -> StageResult:
|
||||
|
|
|
@ -5,7 +5,13 @@ from typing import Optional
|
|||
from PIL import Image
|
||||
|
||||
from ..onnx import OnnxRRDBNet
|
||||
from ..params import DeviceParams, ImageParams, StageParams, UpscaleParams
|
||||
from ..params import (
|
||||
DeviceParams,
|
||||
HighresParams,
|
||||
ImageParams,
|
||||
StageParams,
|
||||
UpscaleParams,
|
||||
)
|
||||
from ..server import ModelTypes, ServerContext
|
||||
from ..utils import run_gc
|
||||
from ..worker import WorkerContext
|
||||
|
@ -102,6 +108,7 @@ class UpscaleRealESRGANStage(BaseStage):
|
|||
sources: StageResult,
|
||||
*,
|
||||
upscale: UpscaleParams,
|
||||
highres: Optional[HighresParams] = None,
|
||||
stage_source: Optional[Image.Image] = None,
|
||||
**kwargs,
|
||||
) -> StageResult:
|
||||
|
|
|
@ -6,7 +6,14 @@ import numpy as np
|
|||
from PIL import Image
|
||||
|
||||
from ..models.onnx import OnnxModel
|
||||
from ..params import DeviceParams, ImageParams, SizeChart, StageParams, UpscaleParams
|
||||
from ..params import (
|
||||
DeviceParams,
|
||||
HighresParams,
|
||||
ImageParams,
|
||||
SizeChart,
|
||||
StageParams,
|
||||
UpscaleParams,
|
||||
)
|
||||
from ..server import ModelTypes, ServerContext
|
||||
from ..utils import run_gc
|
||||
from ..worker import WorkerContext
|
||||
|
@ -58,6 +65,7 @@ class UpscaleSwinIRStage(BaseStage):
|
|||
sources: StageResult,
|
||||
*,
|
||||
upscale: UpscaleParams,
|
||||
highres: Optional[HighresParams] = None,
|
||||
stage_source: Optional[Image.Image] = None,
|
||||
**kwargs,
|
||||
) -> StageResult:
|
||||
|
|
|
@ -28,6 +28,7 @@ from ..utils import (
|
|||
get_boolean,
|
||||
get_from_list,
|
||||
get_from_map,
|
||||
get_list,
|
||||
get_not_empty,
|
||||
get_size,
|
||||
load_config,
|
||||
|
@ -648,7 +649,7 @@ def job_create(server: ServerContext, pool: DevicePoolExecutor):
|
|||
|
||||
def job_cancel(server: ServerContext, pool: DevicePoolExecutor):
|
||||
legacy_job_name = request.args.get("job", None)
|
||||
job_list = request.args.get("jobs", "").split(",")
|
||||
job_list = get_list(request.args, "jobs")
|
||||
|
||||
if legacy_job_name is not None:
|
||||
job_list.append(legacy_job_name)
|
||||
|
@ -672,7 +673,7 @@ def job_cancel(server: ServerContext, pool: DevicePoolExecutor):
|
|||
|
||||
def job_status(server: ServerContext, pool: DevicePoolExecutor):
|
||||
legacy_job_name = request.args.get("job", None)
|
||||
job_list = request.args.get("jobs", "").split(",")
|
||||
job_list = get_list(request.args, "jobs")
|
||||
|
||||
if legacy_job_name is not None:
|
||||
job_list.append(legacy_job_name)
|
||||
|
|
|
@ -4,9 +4,6 @@ from onnx_web.chain.result import ImageMetadata
|
|||
|
||||
|
||||
class ImageMetadataTests(unittest.TestCase):
|
||||
def test_image_metadata(self):
|
||||
pass
|
||||
|
||||
def test_from_exif_normal(self):
|
||||
exif_data = """test prompt
|
||||
Negative prompt: negative prompt
|
||||
|
@ -36,8 +33,3 @@ Steps: 30, Seed: 5
|
|||
self.assertEqual(metadata.params.cfg, 4.0)
|
||||
self.assertEqual(metadata.params.steps, 30)
|
||||
self.assertEqual(metadata.params.seed, 5)
|
||||
|
||||
|
||||
class StageResultTests(unittest.TestCase):
|
||||
def test_stage_result(self):
|
||||
pass
|
||||
|
|
|
@ -306,6 +306,9 @@ class LoadTorchTests(unittest.TestCase):
|
|||
self.assertEqual(result, checkpoint)
|
||||
|
||||
|
||||
LOAD_TENSOR_LOG = "loading tensor: %s"
|
||||
|
||||
|
||||
class LoadTensorTests(unittest.TestCase):
|
||||
@patch("onnx_web.convert.utils.logger")
|
||||
@patch("onnx_web.convert.utils.path")
|
||||
|
@ -320,7 +323,7 @@ class LoadTensorTests(unittest.TestCase):
|
|||
|
||||
result = load_tensor(name, map_location)
|
||||
|
||||
mock_logger.debug.assert_has_calls([mock.call("loading tensor: %s", name)])
|
||||
mock_logger.debug.assert_has_calls([mock.call(LOAD_TENSOR_LOG, name)])
|
||||
mock_path.splitext.assert_called_once_with(name)
|
||||
mock_path.exists.assert_called_once_with(name)
|
||||
mock_torch.load.assert_called_once_with(name, map_location=map_location)
|
||||
|
@ -339,7 +342,7 @@ class LoadTensorTests(unittest.TestCase):
|
|||
|
||||
result = load_tensor(name)
|
||||
|
||||
mock_logger.debug.assert_has_calls([mock.call("loading tensor: %s", name)])
|
||||
mock_logger.debug.assert_has_calls([mock.call(LOAD_TENSOR_LOG, name)])
|
||||
mock_safetensors.torch.load_file.assert_called_once_with(name, device="cpu")
|
||||
self.assertEqual(result, checkpoint)
|
||||
|
||||
|
@ -353,7 +356,7 @@ class LoadTensorTests(unittest.TestCase):
|
|||
|
||||
result = load_tensor(name, map_location)
|
||||
|
||||
mock_logger.debug.assert_has_calls([mock.call("loading tensor: %s", name)])
|
||||
mock_logger.debug.assert_has_calls([mock.call(LOAD_TENSOR_LOG, name)])
|
||||
mock_torch.load.assert_has_calls(
|
||||
[
|
||||
mock.call(name, map_location=map_location),
|
||||
|
@ -370,9 +373,7 @@ class LoadTensorTests(unittest.TestCase):
|
|||
|
||||
result = load_tensor(ONNX_MODEL, map_location)
|
||||
|
||||
mock_logger.debug.assert_has_calls(
|
||||
[mock.call("loading tensor: %s", ONNX_MODEL)]
|
||||
)
|
||||
mock_logger.debug.assert_has_calls([mock.call(LOAD_TENSOR_LOG, ONNX_MODEL)])
|
||||
mock_logger.warning.assert_called_once_with(
|
||||
"tensor has ONNX extension, attempting to use PyTorch anyways: %s", "onnx"
|
||||
)
|
||||
|
@ -393,7 +394,7 @@ class LoadTensorTests(unittest.TestCase):
|
|||
|
||||
result = load_tensor(name, map_location)
|
||||
|
||||
mock_logger.debug.assert_has_calls([mock.call("loading tensor: %s", name)])
|
||||
mock_logger.debug.assert_has_calls([mock.call(LOAD_TENSOR_LOG, name)])
|
||||
mock_logger.warning.assert_called_once_with(
|
||||
"unknown tensor type, falling back to PyTorch: %s", "xyz"
|
||||
)
|
||||
|
@ -434,13 +435,15 @@ class FixDiffusionNameTests(unittest.TestCase):
|
|||
)
|
||||
|
||||
|
||||
CACHE_PATH = "/path/to/cache"
|
||||
|
||||
|
||||
class BuildCachePathsTests(unittest.TestCase):
|
||||
def test_build_cache_paths_without_format(self):
|
||||
client = "client1"
|
||||
cache = "/path/to/cache"
|
||||
|
||||
conversion = ConversionContext(cache_path=cache)
|
||||
result = build_cache_paths(conversion, ONNX_MODEL, client, cache)
|
||||
conversion = ConversionContext(cache_path=CACHE_PATH)
|
||||
result = build_cache_paths(conversion, ONNX_MODEL, client, CACHE_PATH)
|
||||
|
||||
expected_paths = [
|
||||
path.join("/path/to/cache", ONNX_MODEL),
|
||||
|
@ -451,11 +454,10 @@ class BuildCachePathsTests(unittest.TestCase):
|
|||
def test_build_cache_paths_with_format(self):
|
||||
name = "model"
|
||||
client = "client2"
|
||||
cache = "/path/to/cache"
|
||||
model_format = "onnx"
|
||||
|
||||
conversion = ConversionContext(cache_path=cache)
|
||||
result = build_cache_paths(conversion, name, client, cache, model_format)
|
||||
conversion = ConversionContext(cache_path=CACHE_PATH)
|
||||
result = build_cache_paths(conversion, name, client, CACHE_PATH, model_format)
|
||||
|
||||
expected_paths = [
|
||||
path.join("/path/to/cache", ONNX_MODEL),
|
||||
|
@ -465,11 +467,12 @@ class BuildCachePathsTests(unittest.TestCase):
|
|||
|
||||
def test_build_cache_paths_with_existing_extension(self):
|
||||
client = "client3"
|
||||
cache = "/path/to/cache"
|
||||
model_format = "onnx"
|
||||
|
||||
conversion = ConversionContext(cache_path=cache)
|
||||
result = build_cache_paths(conversion, TORCH_MODEL, client, cache, model_format)
|
||||
conversion = ConversionContext(cache_path=CACHE_PATH)
|
||||
result = build_cache_paths(
|
||||
conversion, TORCH_MODEL, client, CACHE_PATH, model_format
|
||||
)
|
||||
|
||||
expected_paths = [
|
||||
path.join("/path/to/cache", TORCH_MODEL),
|
||||
|
@ -480,11 +483,10 @@ class BuildCachePathsTests(unittest.TestCase):
|
|||
def test_build_cache_paths_with_empty_extension(self):
|
||||
name = "model"
|
||||
client = "client4"
|
||||
cache = "/path/to/cache"
|
||||
model_format = "onnx"
|
||||
|
||||
conversion = ConversionContext(cache_path=cache)
|
||||
result = build_cache_paths(conversion, name, client, cache, model_format)
|
||||
conversion = ConversionContext(cache_path=CACHE_PATH)
|
||||
result = build_cache_paths(conversion, name, client, CACHE_PATH, model_format)
|
||||
|
||||
expected_paths = [
|
||||
path.join("/path/to/cache", ONNX_MODEL),
|
||||
|
|
|
@ -68,6 +68,7 @@ class TestWorkerPool(unittest.TestCase):
|
|||
self.assertTrue(self.pool.cancel("test"))
|
||||
self.assertEqual(self.pool.status("test"), (JobStatus.CANCELLED, None, None))
|
||||
|
||||
@unittest.skip("TODO")
|
||||
def test_cancel_running(self):
|
||||
pass
|
||||
|
||||
|
@ -144,15 +145,19 @@ class TestWorkerPool(unittest.TestCase):
|
|||
status, _progress, _queue = self.pool.status("test")
|
||||
self.assertEqual(status, JobStatus.SUCCESS)
|
||||
|
||||
@unittest.skip("TODO")
|
||||
def test_recycle_live(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("TODO")
|
||||
def test_recycle_dead(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("TODO")
|
||||
def test_running_status(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("TODO")
|
||||
def test_progress_update(self):
|
||||
pass
|
||||
|
||||
|
|
Loading…
Reference in New Issue