lint fixes
This commit is contained in:
parent
7966f001e1
commit
c5d64e7b1e
|
@ -267,13 +267,6 @@ def img2img(server: ServerContext, pool: DevicePoolExecutor):
|
|||
|
||||
replace_wildcards(params, get_wildcard_data())
|
||||
|
||||
output_count = params.batch
|
||||
if source_filter is not None and source_filter != "none":
|
||||
logger.debug(
|
||||
"including filtered source with outputs, filter: %s", source_filter
|
||||
)
|
||||
output_count += 1
|
||||
|
||||
job_name = make_job_name("img2img", params, size, extras=[strength])
|
||||
queue = pool.submit(
|
||||
job_name,
|
||||
|
|
|
@ -423,7 +423,7 @@ class DevicePoolExecutor:
|
|||
Returns a tuple of: job/device, progress, progress, finished, cancelled, failed
|
||||
"""
|
||||
|
||||
jobs: Tuple[str, int, JobStatus] = []
|
||||
jobs: List[Tuple[str, int, JobStatus]] = []
|
||||
jobs.extend(
|
||||
[
|
||||
(
|
||||
|
|
|
@ -6,6 +6,7 @@ import torch
|
|||
from onnx import GraphProto, ModelProto, NodeProto
|
||||
from onnx.numpy_helper import from_array
|
||||
|
||||
from onnx_web.constants import ONNX_MODEL
|
||||
from onnx_web.convert.diffusion.lora import (
|
||||
blend_loras,
|
||||
blend_node_conv_gemm,
|
||||
|
@ -231,7 +232,6 @@ class BlendLoRATests(unittest.TestCase):
|
|||
@patch("onnx_web.convert.diffusion.lora.load")
|
||||
@patch("onnx_web.convert.diffusion.lora.load_tensor")
|
||||
def test_blend_loras_load_str(self, mock_load_tensor, mock_load):
|
||||
base_name = "model.onnx"
|
||||
loras = [("loras/model1.safetensors", 0.5), ("loras/safetensors.onnx", 0.5)]
|
||||
model_type = "unet"
|
||||
model_index = 2
|
||||
|
@ -241,10 +241,12 @@ class BlendLoRATests(unittest.TestCase):
|
|||
mock_load_tensor.return_value = MagicMock()
|
||||
|
||||
# Call the blend_loras function
|
||||
blended_model = blend_loras(None, base_name, loras, model_type, model_index, xl)
|
||||
blended_model = blend_loras(
|
||||
None, ONNX_MODEL, loras, model_type, model_index, xl
|
||||
)
|
||||
|
||||
# Assert that the InferenceSession is called with the correct arguments
|
||||
mock_load.assert_called_once_with(base_name)
|
||||
mock_load.assert_called_once_with(ONNX_MODEL)
|
||||
|
||||
# Assert that the model is loaded successfully
|
||||
self.assertEqual(blended_model, mock_load.return_value)
|
||||
|
|
|
@ -3,6 +3,7 @@ from os import path
|
|||
from unittest import mock
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from onnx_web.constants import ONNX_MODEL
|
||||
from onnx_web.convert.utils import (
|
||||
DEFAULT_OPSET,
|
||||
ConversionContext,
|
||||
|
@ -267,36 +268,41 @@ class ResolveTensorTests(unittest.TestCase):
|
|||
self.assertIsNone(resolve_tensor("missing"))
|
||||
|
||||
|
||||
TORCH_MODEL = "model.pth"
|
||||
|
||||
|
||||
class LoadTorchTests(unittest.TestCase):
|
||||
@patch("onnx_web.convert.utils.logger")
|
||||
@patch("onnx_web.convert.utils.torch")
|
||||
def test_load_torch_with_torch_load(self, mock_torch, mock_logger):
|
||||
name = "model.pth"
|
||||
map_location = "cpu"
|
||||
checkpoint = MagicMock()
|
||||
mock_torch.load.return_value = checkpoint
|
||||
|
||||
result = load_torch(name, map_location)
|
||||
result = load_torch(TORCH_MODEL, map_location)
|
||||
|
||||
mock_logger.debug.assert_called_once_with("loading tensor with Torch: %s", name)
|
||||
mock_torch.load.assert_called_once_with(name, map_location=map_location)
|
||||
mock_logger.debug.assert_called_once_with(
|
||||
"loading tensor with Torch: %s", TORCH_MODEL
|
||||
)
|
||||
mock_torch.load.assert_called_once_with(TORCH_MODEL, map_location=map_location)
|
||||
self.assertEqual(result, checkpoint)
|
||||
|
||||
@patch("onnx_web.convert.utils.logger")
|
||||
@patch("onnx_web.convert.utils.torch")
|
||||
def test_load_torch_with_torch_jit_load(self, mock_torch, mock_logger):
|
||||
name = "model.pth"
|
||||
checkpoint = MagicMock()
|
||||
mock_torch.load.side_effect = Exception()
|
||||
mock_torch.jit.load.return_value = checkpoint
|
||||
|
||||
result = load_torch(name)
|
||||
result = load_torch(TORCH_MODEL)
|
||||
|
||||
mock_logger.debug.assert_called_once_with("loading tensor with Torch: %s", name)
|
||||
mock_logger.exception.assert_called_once_with(
|
||||
"error loading with Torch, trying with Torch JIT: %s", name
|
||||
mock_logger.debug.assert_called_once_with(
|
||||
"loading tensor with Torch: %s", TORCH_MODEL
|
||||
)
|
||||
mock_torch.jit.load.assert_called_once_with(name)
|
||||
mock_logger.exception.assert_called_once_with(
|
||||
"error loading with Torch, trying with Torch JIT: %s", TORCH_MODEL
|
||||
)
|
||||
mock_torch.jit.load.assert_called_once_with(TORCH_MODEL)
|
||||
self.assertEqual(result, checkpoint)
|
||||
|
||||
|
||||
|
@ -358,20 +364,21 @@ class LoadTensorTests(unittest.TestCase):
|
|||
@patch("onnx_web.convert.utils.logger")
|
||||
@patch("onnx_web.convert.utils.torch")
|
||||
def test_load_tensor_with_onnx_extension(self, mock_torch, mock_logger):
|
||||
name = "model.onnx"
|
||||
map_location = "cpu"
|
||||
checkpoint = MagicMock()
|
||||
mock_torch.load.side_effect = [checkpoint]
|
||||
|
||||
result = load_tensor(name, map_location)
|
||||
result = load_tensor(ONNX_MODEL, map_location)
|
||||
|
||||
mock_logger.debug.assert_has_calls([mock.call("loading tensor: %s", name)])
|
||||
mock_logger.debug.assert_has_calls(
|
||||
[mock.call("loading tensor: %s", ONNX_MODEL)]
|
||||
)
|
||||
mock_logger.warning.assert_called_once_with(
|
||||
"tensor has ONNX extension, attempting to use PyTorch anyways: %s", "onnx"
|
||||
)
|
||||
mock_torch.load.assert_has_calls(
|
||||
[
|
||||
mock.call(name, map_location=map_location),
|
||||
mock.call(ONNX_MODEL, map_location=map_location),
|
||||
]
|
||||
)
|
||||
self.assertEqual(result, checkpoint)
|
||||
|
@ -429,16 +436,15 @@ class FixDiffusionNameTests(unittest.TestCase):
|
|||
|
||||
class BuildCachePathsTests(unittest.TestCase):
|
||||
def test_build_cache_paths_without_format(self):
|
||||
name = "model.onnx"
|
||||
client = "client1"
|
||||
cache = "/path/to/cache"
|
||||
|
||||
conversion = ConversionContext(cache_path=cache)
|
||||
result = build_cache_paths(conversion, name, client, cache)
|
||||
result = build_cache_paths(conversion, ONNX_MODEL, client, cache)
|
||||
|
||||
expected_paths = [
|
||||
path.join("/path/to/cache", "model.onnx"),
|
||||
path.join("/path/to/cache/client1", "model.onnx"),
|
||||
path.join("/path/to/cache", ONNX_MODEL),
|
||||
path.join("/path/to/cache/client1", ONNX_MODEL),
|
||||
]
|
||||
self.assertEqual(result, expected_paths)
|
||||
|
||||
|
@ -452,23 +458,22 @@ class BuildCachePathsTests(unittest.TestCase):
|
|||
result = build_cache_paths(conversion, name, client, cache, format)
|
||||
|
||||
expected_paths = [
|
||||
path.join("/path/to/cache", "model.onnx"),
|
||||
path.join("/path/to/cache/client2", "model.onnx"),
|
||||
path.join("/path/to/cache", ONNX_MODEL),
|
||||
path.join("/path/to/cache/client2", ONNX_MODEL),
|
||||
]
|
||||
self.assertEqual(result, expected_paths)
|
||||
|
||||
def test_build_cache_paths_with_existing_extension(self):
|
||||
name = "model.pth"
|
||||
client = "client3"
|
||||
cache = "/path/to/cache"
|
||||
format = "onnx"
|
||||
|
||||
conversion = ConversionContext(cache_path=cache)
|
||||
result = build_cache_paths(conversion, name, client, cache, format)
|
||||
result = build_cache_paths(conversion, TORCH_MODEL, client, cache, format)
|
||||
|
||||
expected_paths = [
|
||||
path.join("/path/to/cache", "model.pth"),
|
||||
path.join("/path/to/cache/client3", "model.pth"),
|
||||
path.join("/path/to/cache", TORCH_MODEL),
|
||||
path.join("/path/to/cache/client3", TORCH_MODEL),
|
||||
]
|
||||
self.assertEqual(result, expected_paths)
|
||||
|
||||
|
@ -482,8 +487,8 @@ class BuildCachePathsTests(unittest.TestCase):
|
|||
result = build_cache_paths(conversion, name, client, cache, format)
|
||||
|
||||
expected_paths = [
|
||||
path.join("/path/to/cache", "model.onnx"),
|
||||
path.join("/path/to/cache/client4", "model.onnx"),
|
||||
path.join("/path/to/cache", ONNX_MODEL),
|
||||
path.join("/path/to/cache/client4", ONNX_MODEL),
|
||||
]
|
||||
self.assertEqual(result, expected_paths)
|
||||
|
||||
|
|
|
@ -31,16 +31,16 @@ class TestUtils(unittest.TestCase):
|
|||
self.assertEqual(split_list(" a ,b "), ["a", "b"])
|
||||
|
||||
def test_get_boolean_empty(self):
|
||||
self.assertEqual(get_boolean({}, "key", False), False)
|
||||
self.assertEqual(get_boolean({}, "key", True), True)
|
||||
self.assertFalse(get_boolean({}, "key", False))
|
||||
self.assertTrue(get_boolean({}, "key", True))
|
||||
|
||||
def test_get_boolean_true(self):
|
||||
self.assertEqual(get_boolean({"key": True}, "key", False), True)
|
||||
self.assertEqual(get_boolean({"key": True}, "key", True), True)
|
||||
self.assertTrue(get_boolean({"key": True}, "key", False))
|
||||
self.assertTrue(get_boolean({"key": True}, "key", True))
|
||||
|
||||
def test_get_boolean_false(self):
|
||||
self.assertEqual(get_boolean({"key": False}, "key", False), False)
|
||||
self.assertEqual(get_boolean({"key": False}, "key", True), False)
|
||||
self.assertFalse(get_boolean({"key": False}, "key", False))
|
||||
self.assertFalse(get_boolean({"key": False}, "key", True))
|
||||
|
||||
def test_get_list_empty(self):
|
||||
self.assertEqual(get_list({}, "key", ""), [])
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
import { mustExist } from '@apextoaster/js-utils';
|
||||
import { TabContext, TabList, TabPanel } from '@mui/lab';
|
||||
import { Box, Container, CssBaseline, Divider, Stack, Tab, useMediaQuery } from '@mui/material';
|
||||
import { SxProps, Theme, ThemeProvider, createTheme } from '@mui/material/styles';
|
||||
import { createTheme, ThemeProvider } from '@mui/material/styles';
|
||||
import { Allotment } from 'allotment';
|
||||
import * as React from 'react';
|
||||
import { useContext, useMemo } from 'react';
|
||||
|
@ -21,7 +21,7 @@ import { Models } from './tab/Models.js';
|
|||
import { Settings } from './tab/Settings.js';
|
||||
import { Txt2Img } from './tab/Txt2Img.js';
|
||||
import { Upscale } from './tab/Upscale.js';
|
||||
import { TAB_LABELS, getTab, getTheme } from './utils.js';
|
||||
import { getTab, getTheme, TAB_LABELS } from './utils.js';
|
||||
|
||||
import 'allotment/dist/style.css';
|
||||
import './main.css';
|
||||
|
@ -49,8 +49,6 @@ export function OnnxWeb(props: OnnxWebProps) {
|
|||
[prefersDarkMode, stateTheme],
|
||||
);
|
||||
|
||||
const historyStyle: SxProps<Theme> = layout.history.style;
|
||||
|
||||
return (
|
||||
<ThemeProvider theme={theme}>
|
||||
<CssBaseline />
|
||||
|
@ -59,7 +57,7 @@ export function OnnxWeb(props: OnnxWebProps) {
|
|||
<Logo />
|
||||
</Box>
|
||||
{props.motd && <Motd />}
|
||||
{renderBody(direction, historyStyle, historyWidth)}
|
||||
{renderBody(direction, historyWidth)}
|
||||
</Container>
|
||||
</ThemeProvider>
|
||||
);
|
||||
|
@ -77,25 +75,30 @@ export function selectHistoryWidth(state: OnnxState) {
|
|||
return state.historyWidth;
|
||||
}
|
||||
|
||||
function renderBody(direction: Layout, historyStyle: SxProps<Theme>, historyWidth: number) {
|
||||
function renderBody(direction: Layout, historyWidth: number) {
|
||||
if (direction === 'vertical') {
|
||||
return <VerticalBody direction={direction} style={historyStyle} width={historyWidth} />;
|
||||
return <VerticalBody direction={direction} width={historyWidth} />;
|
||||
} else {
|
||||
return <HorizontalBody direction={direction} style={historyStyle} width={historyWidth} />;
|
||||
return <HorizontalBody direction={direction} width={historyWidth} />;
|
||||
}
|
||||
}
|
||||
|
||||
// used for both horizontal and vertical
|
||||
export interface BodyProps {
|
||||
direction: Layout;
|
||||
style: SxProps<Theme>;
|
||||
width: number;
|
||||
}
|
||||
|
||||
export function HorizontalBody(props: BodyProps) {
|
||||
const layout = LAYOUT_STYLES[props.direction];
|
||||
|
||||
return <Allotment separator className='body-allotment' minSize={LAYOUT_MIN} defaultSizes={LAYOUT_PROPORTIONS} snap>
|
||||
return <Allotment
|
||||
className='body-allotment'
|
||||
defaultSizes={LAYOUT_PROPORTIONS}
|
||||
minSize={LAYOUT_MIN}
|
||||
separator
|
||||
snap
|
||||
>
|
||||
<TabGroup direction={props.direction} />
|
||||
<Box className='box-history' sx={layout.history.style}>
|
||||
<ImageHistory width={props.width} />
|
||||
|
|
|
@ -25,6 +25,7 @@
|
|||
"backlighting",
|
||||
"basicsr",
|
||||
"bokeh",
|
||||
"BSRGAN",
|
||||
"Civitai",
|
||||
"ckpt",
|
||||
"cnet",
|
||||
|
|
Loading…
Reference in New Issue