lint fixes
This commit is contained in:
parent
7966f001e1
commit
c5d64e7b1e
|
@ -267,13 +267,6 @@ def img2img(server: ServerContext, pool: DevicePoolExecutor):
|
||||||
|
|
||||||
replace_wildcards(params, get_wildcard_data())
|
replace_wildcards(params, get_wildcard_data())
|
||||||
|
|
||||||
output_count = params.batch
|
|
||||||
if source_filter is not None and source_filter != "none":
|
|
||||||
logger.debug(
|
|
||||||
"including filtered source with outputs, filter: %s", source_filter
|
|
||||||
)
|
|
||||||
output_count += 1
|
|
||||||
|
|
||||||
job_name = make_job_name("img2img", params, size, extras=[strength])
|
job_name = make_job_name("img2img", params, size, extras=[strength])
|
||||||
queue = pool.submit(
|
queue = pool.submit(
|
||||||
job_name,
|
job_name,
|
||||||
|
|
|
@ -423,7 +423,7 @@ class DevicePoolExecutor:
|
||||||
Returns a tuple of: job/device, progress, progress, finished, cancelled, failed
|
Returns a tuple of: job/device, progress, progress, finished, cancelled, failed
|
||||||
"""
|
"""
|
||||||
|
|
||||||
jobs: Tuple[str, int, JobStatus] = []
|
jobs: List[Tuple[str, int, JobStatus]] = []
|
||||||
jobs.extend(
|
jobs.extend(
|
||||||
[
|
[
|
||||||
(
|
(
|
||||||
|
|
|
@ -6,6 +6,7 @@ import torch
|
||||||
from onnx import GraphProto, ModelProto, NodeProto
|
from onnx import GraphProto, ModelProto, NodeProto
|
||||||
from onnx.numpy_helper import from_array
|
from onnx.numpy_helper import from_array
|
||||||
|
|
||||||
|
from onnx_web.constants import ONNX_MODEL
|
||||||
from onnx_web.convert.diffusion.lora import (
|
from onnx_web.convert.diffusion.lora import (
|
||||||
blend_loras,
|
blend_loras,
|
||||||
blend_node_conv_gemm,
|
blend_node_conv_gemm,
|
||||||
|
@ -231,7 +232,6 @@ class BlendLoRATests(unittest.TestCase):
|
||||||
@patch("onnx_web.convert.diffusion.lora.load")
|
@patch("onnx_web.convert.diffusion.lora.load")
|
||||||
@patch("onnx_web.convert.diffusion.lora.load_tensor")
|
@patch("onnx_web.convert.diffusion.lora.load_tensor")
|
||||||
def test_blend_loras_load_str(self, mock_load_tensor, mock_load):
|
def test_blend_loras_load_str(self, mock_load_tensor, mock_load):
|
||||||
base_name = "model.onnx"
|
|
||||||
loras = [("loras/model1.safetensors", 0.5), ("loras/safetensors.onnx", 0.5)]
|
loras = [("loras/model1.safetensors", 0.5), ("loras/safetensors.onnx", 0.5)]
|
||||||
model_type = "unet"
|
model_type = "unet"
|
||||||
model_index = 2
|
model_index = 2
|
||||||
|
@ -241,10 +241,12 @@ class BlendLoRATests(unittest.TestCase):
|
||||||
mock_load_tensor.return_value = MagicMock()
|
mock_load_tensor.return_value = MagicMock()
|
||||||
|
|
||||||
# Call the blend_loras function
|
# Call the blend_loras function
|
||||||
blended_model = blend_loras(None, base_name, loras, model_type, model_index, xl)
|
blended_model = blend_loras(
|
||||||
|
None, ONNX_MODEL, loras, model_type, model_index, xl
|
||||||
|
)
|
||||||
|
|
||||||
# Assert that the InferenceSession is called with the correct arguments
|
# Assert that the InferenceSession is called with the correct arguments
|
||||||
mock_load.assert_called_once_with(base_name)
|
mock_load.assert_called_once_with(ONNX_MODEL)
|
||||||
|
|
||||||
# Assert that the model is loaded successfully
|
# Assert that the model is loaded successfully
|
||||||
self.assertEqual(blended_model, mock_load.return_value)
|
self.assertEqual(blended_model, mock_load.return_value)
|
||||||
|
|
|
@ -3,6 +3,7 @@ from os import path
|
||||||
from unittest import mock
|
from unittest import mock
|
||||||
from unittest.mock import MagicMock, patch
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
from onnx_web.constants import ONNX_MODEL
|
||||||
from onnx_web.convert.utils import (
|
from onnx_web.convert.utils import (
|
||||||
DEFAULT_OPSET,
|
DEFAULT_OPSET,
|
||||||
ConversionContext,
|
ConversionContext,
|
||||||
|
@ -267,36 +268,41 @@ class ResolveTensorTests(unittest.TestCase):
|
||||||
self.assertIsNone(resolve_tensor("missing"))
|
self.assertIsNone(resolve_tensor("missing"))
|
||||||
|
|
||||||
|
|
||||||
|
TORCH_MODEL = "model.pth"
|
||||||
|
|
||||||
|
|
||||||
class LoadTorchTests(unittest.TestCase):
|
class LoadTorchTests(unittest.TestCase):
|
||||||
@patch("onnx_web.convert.utils.logger")
|
@patch("onnx_web.convert.utils.logger")
|
||||||
@patch("onnx_web.convert.utils.torch")
|
@patch("onnx_web.convert.utils.torch")
|
||||||
def test_load_torch_with_torch_load(self, mock_torch, mock_logger):
|
def test_load_torch_with_torch_load(self, mock_torch, mock_logger):
|
||||||
name = "model.pth"
|
|
||||||
map_location = "cpu"
|
map_location = "cpu"
|
||||||
checkpoint = MagicMock()
|
checkpoint = MagicMock()
|
||||||
mock_torch.load.return_value = checkpoint
|
mock_torch.load.return_value = checkpoint
|
||||||
|
|
||||||
result = load_torch(name, map_location)
|
result = load_torch(TORCH_MODEL, map_location)
|
||||||
|
|
||||||
mock_logger.debug.assert_called_once_with("loading tensor with Torch: %s", name)
|
mock_logger.debug.assert_called_once_with(
|
||||||
mock_torch.load.assert_called_once_with(name, map_location=map_location)
|
"loading tensor with Torch: %s", TORCH_MODEL
|
||||||
|
)
|
||||||
|
mock_torch.load.assert_called_once_with(TORCH_MODEL, map_location=map_location)
|
||||||
self.assertEqual(result, checkpoint)
|
self.assertEqual(result, checkpoint)
|
||||||
|
|
||||||
@patch("onnx_web.convert.utils.logger")
|
@patch("onnx_web.convert.utils.logger")
|
||||||
@patch("onnx_web.convert.utils.torch")
|
@patch("onnx_web.convert.utils.torch")
|
||||||
def test_load_torch_with_torch_jit_load(self, mock_torch, mock_logger):
|
def test_load_torch_with_torch_jit_load(self, mock_torch, mock_logger):
|
||||||
name = "model.pth"
|
|
||||||
checkpoint = MagicMock()
|
checkpoint = MagicMock()
|
||||||
mock_torch.load.side_effect = Exception()
|
mock_torch.load.side_effect = Exception()
|
||||||
mock_torch.jit.load.return_value = checkpoint
|
mock_torch.jit.load.return_value = checkpoint
|
||||||
|
|
||||||
result = load_torch(name)
|
result = load_torch(TORCH_MODEL)
|
||||||
|
|
||||||
mock_logger.debug.assert_called_once_with("loading tensor with Torch: %s", name)
|
mock_logger.debug.assert_called_once_with(
|
||||||
mock_logger.exception.assert_called_once_with(
|
"loading tensor with Torch: %s", TORCH_MODEL
|
||||||
"error loading with Torch, trying with Torch JIT: %s", name
|
|
||||||
)
|
)
|
||||||
mock_torch.jit.load.assert_called_once_with(name)
|
mock_logger.exception.assert_called_once_with(
|
||||||
|
"error loading with Torch, trying with Torch JIT: %s", TORCH_MODEL
|
||||||
|
)
|
||||||
|
mock_torch.jit.load.assert_called_once_with(TORCH_MODEL)
|
||||||
self.assertEqual(result, checkpoint)
|
self.assertEqual(result, checkpoint)
|
||||||
|
|
||||||
|
|
||||||
|
@ -358,20 +364,21 @@ class LoadTensorTests(unittest.TestCase):
|
||||||
@patch("onnx_web.convert.utils.logger")
|
@patch("onnx_web.convert.utils.logger")
|
||||||
@patch("onnx_web.convert.utils.torch")
|
@patch("onnx_web.convert.utils.torch")
|
||||||
def test_load_tensor_with_onnx_extension(self, mock_torch, mock_logger):
|
def test_load_tensor_with_onnx_extension(self, mock_torch, mock_logger):
|
||||||
name = "model.onnx"
|
|
||||||
map_location = "cpu"
|
map_location = "cpu"
|
||||||
checkpoint = MagicMock()
|
checkpoint = MagicMock()
|
||||||
mock_torch.load.side_effect = [checkpoint]
|
mock_torch.load.side_effect = [checkpoint]
|
||||||
|
|
||||||
result = load_tensor(name, map_location)
|
result = load_tensor(ONNX_MODEL, map_location)
|
||||||
|
|
||||||
mock_logger.debug.assert_has_calls([mock.call("loading tensor: %s", name)])
|
mock_logger.debug.assert_has_calls(
|
||||||
|
[mock.call("loading tensor: %s", ONNX_MODEL)]
|
||||||
|
)
|
||||||
mock_logger.warning.assert_called_once_with(
|
mock_logger.warning.assert_called_once_with(
|
||||||
"tensor has ONNX extension, attempting to use PyTorch anyways: %s", "onnx"
|
"tensor has ONNX extension, attempting to use PyTorch anyways: %s", "onnx"
|
||||||
)
|
)
|
||||||
mock_torch.load.assert_has_calls(
|
mock_torch.load.assert_has_calls(
|
||||||
[
|
[
|
||||||
mock.call(name, map_location=map_location),
|
mock.call(ONNX_MODEL, map_location=map_location),
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
self.assertEqual(result, checkpoint)
|
self.assertEqual(result, checkpoint)
|
||||||
|
@ -429,16 +436,15 @@ class FixDiffusionNameTests(unittest.TestCase):
|
||||||
|
|
||||||
class BuildCachePathsTests(unittest.TestCase):
|
class BuildCachePathsTests(unittest.TestCase):
|
||||||
def test_build_cache_paths_without_format(self):
|
def test_build_cache_paths_without_format(self):
|
||||||
name = "model.onnx"
|
|
||||||
client = "client1"
|
client = "client1"
|
||||||
cache = "/path/to/cache"
|
cache = "/path/to/cache"
|
||||||
|
|
||||||
conversion = ConversionContext(cache_path=cache)
|
conversion = ConversionContext(cache_path=cache)
|
||||||
result = build_cache_paths(conversion, name, client, cache)
|
result = build_cache_paths(conversion, ONNX_MODEL, client, cache)
|
||||||
|
|
||||||
expected_paths = [
|
expected_paths = [
|
||||||
path.join("/path/to/cache", "model.onnx"),
|
path.join("/path/to/cache", ONNX_MODEL),
|
||||||
path.join("/path/to/cache/client1", "model.onnx"),
|
path.join("/path/to/cache/client1", ONNX_MODEL),
|
||||||
]
|
]
|
||||||
self.assertEqual(result, expected_paths)
|
self.assertEqual(result, expected_paths)
|
||||||
|
|
||||||
|
@ -452,23 +458,22 @@ class BuildCachePathsTests(unittest.TestCase):
|
||||||
result = build_cache_paths(conversion, name, client, cache, format)
|
result = build_cache_paths(conversion, name, client, cache, format)
|
||||||
|
|
||||||
expected_paths = [
|
expected_paths = [
|
||||||
path.join("/path/to/cache", "model.onnx"),
|
path.join("/path/to/cache", ONNX_MODEL),
|
||||||
path.join("/path/to/cache/client2", "model.onnx"),
|
path.join("/path/to/cache/client2", ONNX_MODEL),
|
||||||
]
|
]
|
||||||
self.assertEqual(result, expected_paths)
|
self.assertEqual(result, expected_paths)
|
||||||
|
|
||||||
def test_build_cache_paths_with_existing_extension(self):
|
def test_build_cache_paths_with_existing_extension(self):
|
||||||
name = "model.pth"
|
|
||||||
client = "client3"
|
client = "client3"
|
||||||
cache = "/path/to/cache"
|
cache = "/path/to/cache"
|
||||||
format = "onnx"
|
format = "onnx"
|
||||||
|
|
||||||
conversion = ConversionContext(cache_path=cache)
|
conversion = ConversionContext(cache_path=cache)
|
||||||
result = build_cache_paths(conversion, name, client, cache, format)
|
result = build_cache_paths(conversion, TORCH_MODEL, client, cache, format)
|
||||||
|
|
||||||
expected_paths = [
|
expected_paths = [
|
||||||
path.join("/path/to/cache", "model.pth"),
|
path.join("/path/to/cache", TORCH_MODEL),
|
||||||
path.join("/path/to/cache/client3", "model.pth"),
|
path.join("/path/to/cache/client3", TORCH_MODEL),
|
||||||
]
|
]
|
||||||
self.assertEqual(result, expected_paths)
|
self.assertEqual(result, expected_paths)
|
||||||
|
|
||||||
|
@ -482,8 +487,8 @@ class BuildCachePathsTests(unittest.TestCase):
|
||||||
result = build_cache_paths(conversion, name, client, cache, format)
|
result = build_cache_paths(conversion, name, client, cache, format)
|
||||||
|
|
||||||
expected_paths = [
|
expected_paths = [
|
||||||
path.join("/path/to/cache", "model.onnx"),
|
path.join("/path/to/cache", ONNX_MODEL),
|
||||||
path.join("/path/to/cache/client4", "model.onnx"),
|
path.join("/path/to/cache/client4", ONNX_MODEL),
|
||||||
]
|
]
|
||||||
self.assertEqual(result, expected_paths)
|
self.assertEqual(result, expected_paths)
|
||||||
|
|
||||||
|
|
|
@ -31,16 +31,16 @@ class TestUtils(unittest.TestCase):
|
||||||
self.assertEqual(split_list(" a ,b "), ["a", "b"])
|
self.assertEqual(split_list(" a ,b "), ["a", "b"])
|
||||||
|
|
||||||
def test_get_boolean_empty(self):
|
def test_get_boolean_empty(self):
|
||||||
self.assertEqual(get_boolean({}, "key", False), False)
|
self.assertFalse(get_boolean({}, "key", False))
|
||||||
self.assertEqual(get_boolean({}, "key", True), True)
|
self.assertTrue(get_boolean({}, "key", True))
|
||||||
|
|
||||||
def test_get_boolean_true(self):
|
def test_get_boolean_true(self):
|
||||||
self.assertEqual(get_boolean({"key": True}, "key", False), True)
|
self.assertTrue(get_boolean({"key": True}, "key", False))
|
||||||
self.assertEqual(get_boolean({"key": True}, "key", True), True)
|
self.assertTrue(get_boolean({"key": True}, "key", True))
|
||||||
|
|
||||||
def test_get_boolean_false(self):
|
def test_get_boolean_false(self):
|
||||||
self.assertEqual(get_boolean({"key": False}, "key", False), False)
|
self.assertFalse(get_boolean({"key": False}, "key", False))
|
||||||
self.assertEqual(get_boolean({"key": False}, "key", True), False)
|
self.assertFalse(get_boolean({"key": False}, "key", True))
|
||||||
|
|
||||||
def test_get_list_empty(self):
|
def test_get_list_empty(self):
|
||||||
self.assertEqual(get_list({}, "key", ""), [])
|
self.assertEqual(get_list({}, "key", ""), [])
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
import { mustExist } from '@apextoaster/js-utils';
|
import { mustExist } from '@apextoaster/js-utils';
|
||||||
import { TabContext, TabList, TabPanel } from '@mui/lab';
|
import { TabContext, TabList, TabPanel } from '@mui/lab';
|
||||||
import { Box, Container, CssBaseline, Divider, Stack, Tab, useMediaQuery } from '@mui/material';
|
import { Box, Container, CssBaseline, Divider, Stack, Tab, useMediaQuery } from '@mui/material';
|
||||||
import { SxProps, Theme, ThemeProvider, createTheme } from '@mui/material/styles';
|
import { createTheme, ThemeProvider } from '@mui/material/styles';
|
||||||
import { Allotment } from 'allotment';
|
import { Allotment } from 'allotment';
|
||||||
import * as React from 'react';
|
import * as React from 'react';
|
||||||
import { useContext, useMemo } from 'react';
|
import { useContext, useMemo } from 'react';
|
||||||
|
@ -21,7 +21,7 @@ import { Models } from './tab/Models.js';
|
||||||
import { Settings } from './tab/Settings.js';
|
import { Settings } from './tab/Settings.js';
|
||||||
import { Txt2Img } from './tab/Txt2Img.js';
|
import { Txt2Img } from './tab/Txt2Img.js';
|
||||||
import { Upscale } from './tab/Upscale.js';
|
import { Upscale } from './tab/Upscale.js';
|
||||||
import { TAB_LABELS, getTab, getTheme } from './utils.js';
|
import { getTab, getTheme, TAB_LABELS } from './utils.js';
|
||||||
|
|
||||||
import 'allotment/dist/style.css';
|
import 'allotment/dist/style.css';
|
||||||
import './main.css';
|
import './main.css';
|
||||||
|
@ -49,8 +49,6 @@ export function OnnxWeb(props: OnnxWebProps) {
|
||||||
[prefersDarkMode, stateTheme],
|
[prefersDarkMode, stateTheme],
|
||||||
);
|
);
|
||||||
|
|
||||||
const historyStyle: SxProps<Theme> = layout.history.style;
|
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<ThemeProvider theme={theme}>
|
<ThemeProvider theme={theme}>
|
||||||
<CssBaseline />
|
<CssBaseline />
|
||||||
|
@ -59,7 +57,7 @@ export function OnnxWeb(props: OnnxWebProps) {
|
||||||
<Logo />
|
<Logo />
|
||||||
</Box>
|
</Box>
|
||||||
{props.motd && <Motd />}
|
{props.motd && <Motd />}
|
||||||
{renderBody(direction, historyStyle, historyWidth)}
|
{renderBody(direction, historyWidth)}
|
||||||
</Container>
|
</Container>
|
||||||
</ThemeProvider>
|
</ThemeProvider>
|
||||||
);
|
);
|
||||||
|
@ -77,25 +75,30 @@ export function selectHistoryWidth(state: OnnxState) {
|
||||||
return state.historyWidth;
|
return state.historyWidth;
|
||||||
}
|
}
|
||||||
|
|
||||||
function renderBody(direction: Layout, historyStyle: SxProps<Theme>, historyWidth: number) {
|
function renderBody(direction: Layout, historyWidth: number) {
|
||||||
if (direction === 'vertical') {
|
if (direction === 'vertical') {
|
||||||
return <VerticalBody direction={direction} style={historyStyle} width={historyWidth} />;
|
return <VerticalBody direction={direction} width={historyWidth} />;
|
||||||
} else {
|
} else {
|
||||||
return <HorizontalBody direction={direction} style={historyStyle} width={historyWidth} />;
|
return <HorizontalBody direction={direction} width={historyWidth} />;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// used for both horizontal and vertical
|
// used for both horizontal and vertical
|
||||||
export interface BodyProps {
|
export interface BodyProps {
|
||||||
direction: Layout;
|
direction: Layout;
|
||||||
style: SxProps<Theme>;
|
|
||||||
width: number;
|
width: number;
|
||||||
}
|
}
|
||||||
|
|
||||||
export function HorizontalBody(props: BodyProps) {
|
export function HorizontalBody(props: BodyProps) {
|
||||||
const layout = LAYOUT_STYLES[props.direction];
|
const layout = LAYOUT_STYLES[props.direction];
|
||||||
|
|
||||||
return <Allotment separator className='body-allotment' minSize={LAYOUT_MIN} defaultSizes={LAYOUT_PROPORTIONS} snap>
|
return <Allotment
|
||||||
|
className='body-allotment'
|
||||||
|
defaultSizes={LAYOUT_PROPORTIONS}
|
||||||
|
minSize={LAYOUT_MIN}
|
||||||
|
separator
|
||||||
|
snap
|
||||||
|
>
|
||||||
<TabGroup direction={props.direction} />
|
<TabGroup direction={props.direction} />
|
||||||
<Box className='box-history' sx={layout.history.style}>
|
<Box className='box-history' sx={layout.history.style}>
|
||||||
<ImageHistory width={props.width} />
|
<ImageHistory width={props.width} />
|
||||||
|
|
|
@ -25,6 +25,7 @@
|
||||||
"backlighting",
|
"backlighting",
|
||||||
"basicsr",
|
"basicsr",
|
||||||
"bokeh",
|
"bokeh",
|
||||||
|
"BSRGAN",
|
||||||
"Civitai",
|
"Civitai",
|
||||||
"ckpt",
|
"ckpt",
|
||||||
"cnet",
|
"cnet",
|
||||||
|
|
Loading…
Reference in New Issue