1
0
Fork 0

lint fixes

This commit is contained in:
Sean Sube 2024-01-12 22:29:39 -06:00
parent 7966f001e1
commit c5d64e7b1e
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
7 changed files with 57 additions and 53 deletions

View File

@ -267,13 +267,6 @@ def img2img(server: ServerContext, pool: DevicePoolExecutor):
replace_wildcards(params, get_wildcard_data()) replace_wildcards(params, get_wildcard_data())
output_count = params.batch
if source_filter is not None and source_filter != "none":
logger.debug(
"including filtered source with outputs, filter: %s", source_filter
)
output_count += 1
job_name = make_job_name("img2img", params, size, extras=[strength]) job_name = make_job_name("img2img", params, size, extras=[strength])
queue = pool.submit( queue = pool.submit(
job_name, job_name,

View File

@ -423,7 +423,7 @@ class DevicePoolExecutor:
Returns a tuple of: job/device, progress, progress, finished, cancelled, failed Returns a tuple of: job/device, progress, progress, finished, cancelled, failed
""" """
jobs: Tuple[str, int, JobStatus] = [] jobs: List[Tuple[str, int, JobStatus]] = []
jobs.extend( jobs.extend(
[ [
( (

View File

@ -6,6 +6,7 @@ import torch
from onnx import GraphProto, ModelProto, NodeProto from onnx import GraphProto, ModelProto, NodeProto
from onnx.numpy_helper import from_array from onnx.numpy_helper import from_array
from onnx_web.constants import ONNX_MODEL
from onnx_web.convert.diffusion.lora import ( from onnx_web.convert.diffusion.lora import (
blend_loras, blend_loras,
blend_node_conv_gemm, blend_node_conv_gemm,
@ -231,7 +232,6 @@ class BlendLoRATests(unittest.TestCase):
@patch("onnx_web.convert.diffusion.lora.load") @patch("onnx_web.convert.diffusion.lora.load")
@patch("onnx_web.convert.diffusion.lora.load_tensor") @patch("onnx_web.convert.diffusion.lora.load_tensor")
def test_blend_loras_load_str(self, mock_load_tensor, mock_load): def test_blend_loras_load_str(self, mock_load_tensor, mock_load):
base_name = "model.onnx"
loras = [("loras/model1.safetensors", 0.5), ("loras/safetensors.onnx", 0.5)] loras = [("loras/model1.safetensors", 0.5), ("loras/safetensors.onnx", 0.5)]
model_type = "unet" model_type = "unet"
model_index = 2 model_index = 2
@ -241,10 +241,12 @@ class BlendLoRATests(unittest.TestCase):
mock_load_tensor.return_value = MagicMock() mock_load_tensor.return_value = MagicMock()
# Call the blend_loras function # Call the blend_loras function
blended_model = blend_loras(None, base_name, loras, model_type, model_index, xl) blended_model = blend_loras(
None, ONNX_MODEL, loras, model_type, model_index, xl
)
# Assert that the InferenceSession is called with the correct arguments # Assert that the InferenceSession is called with the correct arguments
mock_load.assert_called_once_with(base_name) mock_load.assert_called_once_with(ONNX_MODEL)
# Assert that the model is loaded successfully # Assert that the model is loaded successfully
self.assertEqual(blended_model, mock_load.return_value) self.assertEqual(blended_model, mock_load.return_value)

View File

@ -3,6 +3,7 @@ from os import path
from unittest import mock from unittest import mock
from unittest.mock import MagicMock, patch from unittest.mock import MagicMock, patch
from onnx_web.constants import ONNX_MODEL
from onnx_web.convert.utils import ( from onnx_web.convert.utils import (
DEFAULT_OPSET, DEFAULT_OPSET,
ConversionContext, ConversionContext,
@ -267,36 +268,41 @@ class ResolveTensorTests(unittest.TestCase):
self.assertIsNone(resolve_tensor("missing")) self.assertIsNone(resolve_tensor("missing"))
TORCH_MODEL = "model.pth"
class LoadTorchTests(unittest.TestCase): class LoadTorchTests(unittest.TestCase):
@patch("onnx_web.convert.utils.logger") @patch("onnx_web.convert.utils.logger")
@patch("onnx_web.convert.utils.torch") @patch("onnx_web.convert.utils.torch")
def test_load_torch_with_torch_load(self, mock_torch, mock_logger): def test_load_torch_with_torch_load(self, mock_torch, mock_logger):
name = "model.pth"
map_location = "cpu" map_location = "cpu"
checkpoint = MagicMock() checkpoint = MagicMock()
mock_torch.load.return_value = checkpoint mock_torch.load.return_value = checkpoint
result = load_torch(name, map_location) result = load_torch(TORCH_MODEL, map_location)
mock_logger.debug.assert_called_once_with("loading tensor with Torch: %s", name) mock_logger.debug.assert_called_once_with(
mock_torch.load.assert_called_once_with(name, map_location=map_location) "loading tensor with Torch: %s", TORCH_MODEL
)
mock_torch.load.assert_called_once_with(TORCH_MODEL, map_location=map_location)
self.assertEqual(result, checkpoint) self.assertEqual(result, checkpoint)
@patch("onnx_web.convert.utils.logger") @patch("onnx_web.convert.utils.logger")
@patch("onnx_web.convert.utils.torch") @patch("onnx_web.convert.utils.torch")
def test_load_torch_with_torch_jit_load(self, mock_torch, mock_logger): def test_load_torch_with_torch_jit_load(self, mock_torch, mock_logger):
name = "model.pth"
checkpoint = MagicMock() checkpoint = MagicMock()
mock_torch.load.side_effect = Exception() mock_torch.load.side_effect = Exception()
mock_torch.jit.load.return_value = checkpoint mock_torch.jit.load.return_value = checkpoint
result = load_torch(name) result = load_torch(TORCH_MODEL)
mock_logger.debug.assert_called_once_with("loading tensor with Torch: %s", name) mock_logger.debug.assert_called_once_with(
mock_logger.exception.assert_called_once_with( "loading tensor with Torch: %s", TORCH_MODEL
"error loading with Torch, trying with Torch JIT: %s", name
) )
mock_torch.jit.load.assert_called_once_with(name) mock_logger.exception.assert_called_once_with(
"error loading with Torch, trying with Torch JIT: %s", TORCH_MODEL
)
mock_torch.jit.load.assert_called_once_with(TORCH_MODEL)
self.assertEqual(result, checkpoint) self.assertEqual(result, checkpoint)
@ -358,20 +364,21 @@ class LoadTensorTests(unittest.TestCase):
@patch("onnx_web.convert.utils.logger") @patch("onnx_web.convert.utils.logger")
@patch("onnx_web.convert.utils.torch") @patch("onnx_web.convert.utils.torch")
def test_load_tensor_with_onnx_extension(self, mock_torch, mock_logger): def test_load_tensor_with_onnx_extension(self, mock_torch, mock_logger):
name = "model.onnx"
map_location = "cpu" map_location = "cpu"
checkpoint = MagicMock() checkpoint = MagicMock()
mock_torch.load.side_effect = [checkpoint] mock_torch.load.side_effect = [checkpoint]
result = load_tensor(name, map_location) result = load_tensor(ONNX_MODEL, map_location)
mock_logger.debug.assert_has_calls([mock.call("loading tensor: %s", name)]) mock_logger.debug.assert_has_calls(
[mock.call("loading tensor: %s", ONNX_MODEL)]
)
mock_logger.warning.assert_called_once_with( mock_logger.warning.assert_called_once_with(
"tensor has ONNX extension, attempting to use PyTorch anyways: %s", "onnx" "tensor has ONNX extension, attempting to use PyTorch anyways: %s", "onnx"
) )
mock_torch.load.assert_has_calls( mock_torch.load.assert_has_calls(
[ [
mock.call(name, map_location=map_location), mock.call(ONNX_MODEL, map_location=map_location),
] ]
) )
self.assertEqual(result, checkpoint) self.assertEqual(result, checkpoint)
@ -429,16 +436,15 @@ class FixDiffusionNameTests(unittest.TestCase):
class BuildCachePathsTests(unittest.TestCase): class BuildCachePathsTests(unittest.TestCase):
def test_build_cache_paths_without_format(self): def test_build_cache_paths_without_format(self):
name = "model.onnx"
client = "client1" client = "client1"
cache = "/path/to/cache" cache = "/path/to/cache"
conversion = ConversionContext(cache_path=cache) conversion = ConversionContext(cache_path=cache)
result = build_cache_paths(conversion, name, client, cache) result = build_cache_paths(conversion, ONNX_MODEL, client, cache)
expected_paths = [ expected_paths = [
path.join("/path/to/cache", "model.onnx"), path.join("/path/to/cache", ONNX_MODEL),
path.join("/path/to/cache/client1", "model.onnx"), path.join("/path/to/cache/client1", ONNX_MODEL),
] ]
self.assertEqual(result, expected_paths) self.assertEqual(result, expected_paths)
@ -452,23 +458,22 @@ class BuildCachePathsTests(unittest.TestCase):
result = build_cache_paths(conversion, name, client, cache, format) result = build_cache_paths(conversion, name, client, cache, format)
expected_paths = [ expected_paths = [
path.join("/path/to/cache", "model.onnx"), path.join("/path/to/cache", ONNX_MODEL),
path.join("/path/to/cache/client2", "model.onnx"), path.join("/path/to/cache/client2", ONNX_MODEL),
] ]
self.assertEqual(result, expected_paths) self.assertEqual(result, expected_paths)
def test_build_cache_paths_with_existing_extension(self): def test_build_cache_paths_with_existing_extension(self):
name = "model.pth"
client = "client3" client = "client3"
cache = "/path/to/cache" cache = "/path/to/cache"
format = "onnx" format = "onnx"
conversion = ConversionContext(cache_path=cache) conversion = ConversionContext(cache_path=cache)
result = build_cache_paths(conversion, name, client, cache, format) result = build_cache_paths(conversion, TORCH_MODEL, client, cache, format)
expected_paths = [ expected_paths = [
path.join("/path/to/cache", "model.pth"), path.join("/path/to/cache", TORCH_MODEL),
path.join("/path/to/cache/client3", "model.pth"), path.join("/path/to/cache/client3", TORCH_MODEL),
] ]
self.assertEqual(result, expected_paths) self.assertEqual(result, expected_paths)
@ -482,8 +487,8 @@ class BuildCachePathsTests(unittest.TestCase):
result = build_cache_paths(conversion, name, client, cache, format) result = build_cache_paths(conversion, name, client, cache, format)
expected_paths = [ expected_paths = [
path.join("/path/to/cache", "model.onnx"), path.join("/path/to/cache", ONNX_MODEL),
path.join("/path/to/cache/client4", "model.onnx"), path.join("/path/to/cache/client4", ONNX_MODEL),
] ]
self.assertEqual(result, expected_paths) self.assertEqual(result, expected_paths)

View File

@ -31,16 +31,16 @@ class TestUtils(unittest.TestCase):
self.assertEqual(split_list(" a ,b "), ["a", "b"]) self.assertEqual(split_list(" a ,b "), ["a", "b"])
def test_get_boolean_empty(self): def test_get_boolean_empty(self):
self.assertEqual(get_boolean({}, "key", False), False) self.assertFalse(get_boolean({}, "key", False))
self.assertEqual(get_boolean({}, "key", True), True) self.assertTrue(get_boolean({}, "key", True))
def test_get_boolean_true(self): def test_get_boolean_true(self):
self.assertEqual(get_boolean({"key": True}, "key", False), True) self.assertTrue(get_boolean({"key": True}, "key", False))
self.assertEqual(get_boolean({"key": True}, "key", True), True) self.assertTrue(get_boolean({"key": True}, "key", True))
def test_get_boolean_false(self): def test_get_boolean_false(self):
self.assertEqual(get_boolean({"key": False}, "key", False), False) self.assertFalse(get_boolean({"key": False}, "key", False))
self.assertEqual(get_boolean({"key": False}, "key", True), False) self.assertFalse(get_boolean({"key": False}, "key", True))
def test_get_list_empty(self): def test_get_list_empty(self):
self.assertEqual(get_list({}, "key", ""), []) self.assertEqual(get_list({}, "key", ""), [])

View File

@ -1,7 +1,7 @@
import { mustExist } from '@apextoaster/js-utils'; import { mustExist } from '@apextoaster/js-utils';
import { TabContext, TabList, TabPanel } from '@mui/lab'; import { TabContext, TabList, TabPanel } from '@mui/lab';
import { Box, Container, CssBaseline, Divider, Stack, Tab, useMediaQuery } from '@mui/material'; import { Box, Container, CssBaseline, Divider, Stack, Tab, useMediaQuery } from '@mui/material';
import { SxProps, Theme, ThemeProvider, createTheme } from '@mui/material/styles'; import { createTheme, ThemeProvider } from '@mui/material/styles';
import { Allotment } from 'allotment'; import { Allotment } from 'allotment';
import * as React from 'react'; import * as React from 'react';
import { useContext, useMemo } from 'react'; import { useContext, useMemo } from 'react';
@ -21,7 +21,7 @@ import { Models } from './tab/Models.js';
import { Settings } from './tab/Settings.js'; import { Settings } from './tab/Settings.js';
import { Txt2Img } from './tab/Txt2Img.js'; import { Txt2Img } from './tab/Txt2Img.js';
import { Upscale } from './tab/Upscale.js'; import { Upscale } from './tab/Upscale.js';
import { TAB_LABELS, getTab, getTheme } from './utils.js'; import { getTab, getTheme, TAB_LABELS } from './utils.js';
import 'allotment/dist/style.css'; import 'allotment/dist/style.css';
import './main.css'; import './main.css';
@ -49,8 +49,6 @@ export function OnnxWeb(props: OnnxWebProps) {
[prefersDarkMode, stateTheme], [prefersDarkMode, stateTheme],
); );
const historyStyle: SxProps<Theme> = layout.history.style;
return ( return (
<ThemeProvider theme={theme}> <ThemeProvider theme={theme}>
<CssBaseline /> <CssBaseline />
@ -59,7 +57,7 @@ export function OnnxWeb(props: OnnxWebProps) {
<Logo /> <Logo />
</Box> </Box>
{props.motd && <Motd />} {props.motd && <Motd />}
{renderBody(direction, historyStyle, historyWidth)} {renderBody(direction, historyWidth)}
</Container> </Container>
</ThemeProvider> </ThemeProvider>
); );
@ -77,25 +75,30 @@ export function selectHistoryWidth(state: OnnxState) {
return state.historyWidth; return state.historyWidth;
} }
function renderBody(direction: Layout, historyStyle: SxProps<Theme>, historyWidth: number) { function renderBody(direction: Layout, historyWidth: number) {
if (direction === 'vertical') { if (direction === 'vertical') {
return <VerticalBody direction={direction} style={historyStyle} width={historyWidth} />; return <VerticalBody direction={direction} width={historyWidth} />;
} else { } else {
return <HorizontalBody direction={direction} style={historyStyle} width={historyWidth} />; return <HorizontalBody direction={direction} width={historyWidth} />;
} }
} }
// used for both horizontal and vertical // used for both horizontal and vertical
export interface BodyProps { export interface BodyProps {
direction: Layout; direction: Layout;
style: SxProps<Theme>;
width: number; width: number;
} }
export function HorizontalBody(props: BodyProps) { export function HorizontalBody(props: BodyProps) {
const layout = LAYOUT_STYLES[props.direction]; const layout = LAYOUT_STYLES[props.direction];
return <Allotment separator className='body-allotment' minSize={LAYOUT_MIN} defaultSizes={LAYOUT_PROPORTIONS} snap> return <Allotment
className='body-allotment'
defaultSizes={LAYOUT_PROPORTIONS}
minSize={LAYOUT_MIN}
separator
snap
>
<TabGroup direction={props.direction} /> <TabGroup direction={props.direction} />
<Box className='box-history' sx={layout.history.style}> <Box className='box-history' sx={layout.history.style}>
<ImageHistory width={props.width} /> <ImageHistory width={props.width} />

View File

@ -25,6 +25,7 @@
"backlighting", "backlighting",
"basicsr", "basicsr",
"bokeh", "bokeh",
"BSRGAN",
"Civitai", "Civitai",
"ckpt", "ckpt",
"cnet", "cnet",