diff --git a/api/onnx_web/server/api.py b/api/onnx_web/server/api.py index f838a019..354fcc0e 100644 --- a/api/onnx_web/server/api.py +++ b/api/onnx_web/server/api.py @@ -267,13 +267,6 @@ def img2img(server: ServerContext, pool: DevicePoolExecutor): replace_wildcards(params, get_wildcard_data()) - output_count = params.batch - if source_filter is not None and source_filter != "none": - logger.debug( - "including filtered source with outputs, filter: %s", source_filter - ) - output_count += 1 - job_name = make_job_name("img2img", params, size, extras=[strength]) queue = pool.submit( job_name, diff --git a/api/onnx_web/worker/pool.py b/api/onnx_web/worker/pool.py index 90ecb503..cc9d990b 100644 --- a/api/onnx_web/worker/pool.py +++ b/api/onnx_web/worker/pool.py @@ -423,7 +423,7 @@ class DevicePoolExecutor: Returns a tuple of: job/device, progress, progress, finished, cancelled, failed """ - jobs: Tuple[str, int, JobStatus] = [] + jobs: List[Tuple[str, int, JobStatus]] = [] jobs.extend( [ ( diff --git a/api/tests/convert/diffusion/test_lora.py b/api/tests/convert/diffusion/test_lora.py index 17dc3b6a..61c345a6 100644 --- a/api/tests/convert/diffusion/test_lora.py +++ b/api/tests/convert/diffusion/test_lora.py @@ -6,6 +6,7 @@ import torch from onnx import GraphProto, ModelProto, NodeProto from onnx.numpy_helper import from_array +from onnx_web.constants import ONNX_MODEL from onnx_web.convert.diffusion.lora import ( blend_loras, blend_node_conv_gemm, @@ -231,7 +232,6 @@ class BlendLoRATests(unittest.TestCase): @patch("onnx_web.convert.diffusion.lora.load") @patch("onnx_web.convert.diffusion.lora.load_tensor") def test_blend_loras_load_str(self, mock_load_tensor, mock_load): - base_name = "model.onnx" loras = [("loras/model1.safetensors", 0.5), ("loras/safetensors.onnx", 0.5)] model_type = "unet" model_index = 2 @@ -241,10 +241,12 @@ class BlendLoRATests(unittest.TestCase): mock_load_tensor.return_value = MagicMock() # Call the blend_loras function - blended_model = blend_loras(None, base_name, loras, model_type, model_index, xl) + blended_model = blend_loras( + None, ONNX_MODEL, loras, model_type, model_index, xl + ) # Assert that the InferenceSession is called with the correct arguments - mock_load.assert_called_once_with(base_name) + mock_load.assert_called_once_with(ONNX_MODEL) # Assert that the model is loaded successfully self.assertEqual(blended_model, mock_load.return_value) diff --git a/api/tests/convert/test_utils.py b/api/tests/convert/test_utils.py index da8bc550..e3cbed05 100644 --- a/api/tests/convert/test_utils.py +++ b/api/tests/convert/test_utils.py @@ -3,6 +3,7 @@ from os import path from unittest import mock from unittest.mock import MagicMock, patch +from onnx_web.constants import ONNX_MODEL from onnx_web.convert.utils import ( DEFAULT_OPSET, ConversionContext, @@ -267,36 +268,41 @@ class ResolveTensorTests(unittest.TestCase): self.assertIsNone(resolve_tensor("missing")) +TORCH_MODEL = "model.pth" + + class LoadTorchTests(unittest.TestCase): @patch("onnx_web.convert.utils.logger") @patch("onnx_web.convert.utils.torch") def test_load_torch_with_torch_load(self, mock_torch, mock_logger): - name = "model.pth" map_location = "cpu" checkpoint = MagicMock() mock_torch.load.return_value = checkpoint - result = load_torch(name, map_location) + result = load_torch(TORCH_MODEL, map_location) - mock_logger.debug.assert_called_once_with("loading tensor with Torch: %s", name) - mock_torch.load.assert_called_once_with(name, map_location=map_location) + mock_logger.debug.assert_called_once_with( + "loading tensor with Torch: %s", TORCH_MODEL + ) + mock_torch.load.assert_called_once_with(TORCH_MODEL, map_location=map_location) self.assertEqual(result, checkpoint) @patch("onnx_web.convert.utils.logger") @patch("onnx_web.convert.utils.torch") def test_load_torch_with_torch_jit_load(self, mock_torch, mock_logger): - name = "model.pth" checkpoint = MagicMock() mock_torch.load.side_effect = Exception() mock_torch.jit.load.return_value = checkpoint - result = load_torch(name) + result = load_torch(TORCH_MODEL) - mock_logger.debug.assert_called_once_with("loading tensor with Torch: %s", name) - mock_logger.exception.assert_called_once_with( - "error loading with Torch, trying with Torch JIT: %s", name + mock_logger.debug.assert_called_once_with( + "loading tensor with Torch: %s", TORCH_MODEL ) - mock_torch.jit.load.assert_called_once_with(name) + mock_logger.exception.assert_called_once_with( + "error loading with Torch, trying with Torch JIT: %s", TORCH_MODEL + ) + mock_torch.jit.load.assert_called_once_with(TORCH_MODEL) self.assertEqual(result, checkpoint) @@ -358,20 +364,21 @@ class LoadTensorTests(unittest.TestCase): @patch("onnx_web.convert.utils.logger") @patch("onnx_web.convert.utils.torch") def test_load_tensor_with_onnx_extension(self, mock_torch, mock_logger): - name = "model.onnx" map_location = "cpu" checkpoint = MagicMock() mock_torch.load.side_effect = [checkpoint] - result = load_tensor(name, map_location) + result = load_tensor(ONNX_MODEL, map_location) - mock_logger.debug.assert_has_calls([mock.call("loading tensor: %s", name)]) + mock_logger.debug.assert_has_calls( + [mock.call("loading tensor: %s", ONNX_MODEL)] + ) mock_logger.warning.assert_called_once_with( "tensor has ONNX extension, attempting to use PyTorch anyways: %s", "onnx" ) mock_torch.load.assert_has_calls( [ - mock.call(name, map_location=map_location), + mock.call(ONNX_MODEL, map_location=map_location), ] ) self.assertEqual(result, checkpoint) @@ -429,16 +436,15 @@ class FixDiffusionNameTests(unittest.TestCase): class BuildCachePathsTests(unittest.TestCase): def test_build_cache_paths_without_format(self): - name = "model.onnx" client = "client1" cache = "/path/to/cache" conversion = ConversionContext(cache_path=cache) - result = build_cache_paths(conversion, name, client, cache) + result = build_cache_paths(conversion, ONNX_MODEL, client, cache) expected_paths = [ - path.join("/path/to/cache", "model.onnx"), - path.join("/path/to/cache/client1", "model.onnx"), + path.join("/path/to/cache", ONNX_MODEL), + path.join("/path/to/cache/client1", ONNX_MODEL), ] self.assertEqual(result, expected_paths) @@ -452,23 +458,22 @@ class BuildCachePathsTests(unittest.TestCase): result = build_cache_paths(conversion, name, client, cache, format) expected_paths = [ - path.join("/path/to/cache", "model.onnx"), - path.join("/path/to/cache/client2", "model.onnx"), + path.join("/path/to/cache", ONNX_MODEL), + path.join("/path/to/cache/client2", ONNX_MODEL), ] self.assertEqual(result, expected_paths) def test_build_cache_paths_with_existing_extension(self): - name = "model.pth" client = "client3" cache = "/path/to/cache" format = "onnx" conversion = ConversionContext(cache_path=cache) - result = build_cache_paths(conversion, name, client, cache, format) + result = build_cache_paths(conversion, TORCH_MODEL, client, cache, format) expected_paths = [ - path.join("/path/to/cache", "model.pth"), - path.join("/path/to/cache/client3", "model.pth"), + path.join("/path/to/cache", TORCH_MODEL), + path.join("/path/to/cache/client3", TORCH_MODEL), ] self.assertEqual(result, expected_paths) @@ -482,8 +487,8 @@ class BuildCachePathsTests(unittest.TestCase): result = build_cache_paths(conversion, name, client, cache, format) expected_paths = [ - path.join("/path/to/cache", "model.onnx"), - path.join("/path/to/cache/client4", "model.onnx"), + path.join("/path/to/cache", ONNX_MODEL), + path.join("/path/to/cache/client4", ONNX_MODEL), ] self.assertEqual(result, expected_paths) diff --git a/api/tests/test_utils.py b/api/tests/test_utils.py index 6db6838a..04c2818d 100644 --- a/api/tests/test_utils.py +++ b/api/tests/test_utils.py @@ -31,16 +31,16 @@ class TestUtils(unittest.TestCase): self.assertEqual(split_list(" a ,b "), ["a", "b"]) def test_get_boolean_empty(self): - self.assertEqual(get_boolean({}, "key", False), False) - self.assertEqual(get_boolean({}, "key", True), True) + self.assertFalse(get_boolean({}, "key", False)) + self.assertTrue(get_boolean({}, "key", True)) def test_get_boolean_true(self): - self.assertEqual(get_boolean({"key": True}, "key", False), True) - self.assertEqual(get_boolean({"key": True}, "key", True), True) + self.assertTrue(get_boolean({"key": True}, "key", False)) + self.assertTrue(get_boolean({"key": True}, "key", True)) def test_get_boolean_false(self): - self.assertEqual(get_boolean({"key": False}, "key", False), False) - self.assertEqual(get_boolean({"key": False}, "key", True), False) + self.assertFalse(get_boolean({"key": False}, "key", False)) + self.assertFalse(get_boolean({"key": False}, "key", True)) def test_get_list_empty(self): self.assertEqual(get_list({}, "key", ""), []) diff --git a/gui/src/components/OnnxWeb.tsx b/gui/src/components/OnnxWeb.tsx index 387bbc39..8de0d4b8 100644 --- a/gui/src/components/OnnxWeb.tsx +++ b/gui/src/components/OnnxWeb.tsx @@ -1,7 +1,7 @@ import { mustExist } from '@apextoaster/js-utils'; import { TabContext, TabList, TabPanel } from '@mui/lab'; import { Box, Container, CssBaseline, Divider, Stack, Tab, useMediaQuery } from '@mui/material'; -import { SxProps, Theme, ThemeProvider, createTheme } from '@mui/material/styles'; +import { createTheme, ThemeProvider } from '@mui/material/styles'; import { Allotment } from 'allotment'; import * as React from 'react'; import { useContext, useMemo } from 'react'; @@ -21,7 +21,7 @@ import { Models } from './tab/Models.js'; import { Settings } from './tab/Settings.js'; import { Txt2Img } from './tab/Txt2Img.js'; import { Upscale } from './tab/Upscale.js'; -import { TAB_LABELS, getTab, getTheme } from './utils.js'; +import { getTab, getTheme, TAB_LABELS } from './utils.js'; import 'allotment/dist/style.css'; import './main.css'; @@ -49,8 +49,6 @@ export function OnnxWeb(props: OnnxWebProps) { [prefersDarkMode, stateTheme], ); - const historyStyle: SxProps = layout.history.style; - return ( @@ -59,7 +57,7 @@ export function OnnxWeb(props: OnnxWebProps) { {props.motd && } - {renderBody(direction, historyStyle, historyWidth)} + {renderBody(direction, historyWidth)} ); @@ -77,25 +75,30 @@ export function selectHistoryWidth(state: OnnxState) { return state.historyWidth; } -function renderBody(direction: Layout, historyStyle: SxProps, historyWidth: number) { +function renderBody(direction: Layout, historyWidth: number) { if (direction === 'vertical') { - return ; + return ; } else { - return ; + return ; } } // used for both horizontal and vertical export interface BodyProps { direction: Layout; - style: SxProps; width: number; } export function HorizontalBody(props: BodyProps) { const layout = LAYOUT_STYLES[props.direction]; - return + return diff --git a/onnx-web.code-workspace b/onnx-web.code-workspace index f64caca5..3abd6864 100644 --- a/onnx-web.code-workspace +++ b/onnx-web.code-workspace @@ -25,6 +25,7 @@ "backlighting", "basicsr", "bokeh", + "BSRGAN", "Civitai", "ckpt", "cnet",