
name threads, max queues, type/lint fixes

Sean Sube 2023-02-28 21:44:52 -06:00
parent c95ac1fbdd
commit c99aa67220
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
30 changed files with 179 additions and 111 deletions

View File

@@ -39,4 +39,4 @@ lint-fix:
 	flake8 onnx_web

 typecheck:
-	mypy -m onnx_web.serve
+	mypy onnx_web

View File

@ -1,7 +1,7 @@
version: 1 version: 1
formatters: formatters:
simple: simple:
format: '[%(asctime)s] %(levelname)s: %(name)s: %(message)s' format: '[%(asctime)s] %(levelname)s: %(processName)s %(threadName)s %(name)s: %(message)s'
handlers: handlers:
console: console:
class: logging.StreamHandler class: logging.StreamHandler
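
Adding %(processName)s and %(threadName)s to the format ties each log line to the named workers introduced later in this commit. A minimal sketch of the same format with the stdlib logging module (the message and timestamp are illustrative):

    import logging

    logging.basicConfig(
        format="[%(asctime)s] %(levelname)s: %(processName)s %(threadName)s %(name)s: %(message)s",
        level=logging.INFO,
    )
    logging.getLogger("onnx_web.serve").info("starting workers")
    # prints something like:
    # [2023-02-28 21:44:52,000] INFO: MainProcess MainThread onnx_web.serve: starting workers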

View File

@@ -62,7 +62,7 @@ class ChainPipeline:
     def __init__(
         self,
-        stages: List[PipelineStage] = None,
+        stages: Optional[List[PipelineStage]] = None,
     ):
         """
         Create a new pipeline that will run the given stages.

@@ -82,7 +82,7 @@ class ChainPipeline:
         server: ServerContext,
         params: ImageParams,
         source: Image.Image,
-        callback: ProgressCallback = None,
+        callback: Optional[ProgressCallback] = None,
         **pipeline_kwargs
     ) -> Image.Image:
         """

View File

@ -9,6 +9,7 @@ from ..diffusion.load import load_pipeline
from ..params import ImageParams, StageParams from ..params import ImageParams, StageParams
from ..server import ServerContext from ..server import ServerContext
from ..worker import ProgressCallback, WorkerContext from ..worker import ProgressCallback, WorkerContext
from typing import Optional
logger = getLogger(__name__) logger = getLogger(__name__)
@ -20,7 +21,7 @@ def blend_img2img(
params: ImageParams, params: ImageParams,
source: Image.Image, source: Image.Image,
*, *,
callback: ProgressCallback = None, callback: Optional[ProgressCallback] = None,
stage_source: Image.Image, stage_source: Image.Image,
**kwargs, **kwargs,
) -> Image.Image: ) -> Image.Image:

View File

@@ -31,7 +31,7 @@ def blend_inpaint(
     fill_color: str = "white",
     mask_filter: Callable = mask_filter_none,
     noise_source: Callable = noise_source_histogram,
-    callback: ProgressCallback = None,
+    callback: Optional[ProgressCallback] = None,
     **kwargs,
 ) -> Image.Image:
     params = params.with_args(**kwargs)

View File

@@ -21,7 +21,7 @@ def blend_mask(
     *,
     sources: Optional[List[Image.Image]] = None,
     stage_mask: Optional[Image.Image] = None,
-    _callback: ProgressCallback = None,
+    _callback: Optional[ProgressCallback] = None,
     **kwargs,
 ) -> Image.Image:
     logger.info("blending image using mask")

View File

@@ -5,6 +5,7 @@ from PIL import Image
 from ..params import ImageParams, StageParams, UpscaleParams
 from ..server import ServerContext
 from ..worker import WorkerContext
+from typing import Optional

 logger = getLogger(__name__)

@@ -18,7 +19,7 @@ def correct_codeformer(
     _params: ImageParams,
     source: Image.Image,
     *,
-    stage_source: Image.Image = None,
+    stage_source: Optional[Image.Image] = None,
     upscale: UpscaleParams,
     **kwargs,
 ) -> Image.Image:

View File

@@ -7,6 +7,7 @@ from PIL import Image
 from ..params import ImageParams, StageParams
 from ..server import ServerContext
 from ..worker import WorkerContext
+from typing import Optional

 logger = getLogger(__name__)

@@ -20,9 +21,9 @@ def persist_s3(
     *,
     output: str,
     bucket: str,
-    endpoint_url: str = None,
-    profile_name: str = None,
-    stage_source: Image.Image = None,
+    endpoint_url: Optional[str] = None,
+    profile_name: Optional[str] = None,
+    stage_source: Optional[Image.Image] = None,
     **kwargs,
 ) -> Image.Image:
     source = stage_source or source

View File

@@ -5,6 +5,7 @@ from PIL import Image
 from ..params import ImageParams, Size, StageParams
 from ..server import ServerContext
 from ..worker import WorkerContext
+from typing import Optional

 logger = getLogger(__name__)

@@ -18,7 +19,7 @@ def reduce_crop(
     *,
     origin: Size,
     size: Size,
-    stage_source: Image.Image = None,
+    stage_source: Optional[Image.Image] = None,
     **kwargs,
 ) -> Image.Image:
     source = stage_source or source

View File

@@ -9,6 +9,7 @@ from ..diffusion.load import get_latents_from_seed, load_pipeline
 from ..params import ImageParams, Size, StageParams
 from ..server import ServerContext
 from ..worker import ProgressCallback, WorkerContext
+from typing import Optional

 logger = getLogger(__name__)

@@ -21,7 +22,7 @@ def source_txt2img(
     _source: Image.Image,
     *,
     size: Size,
-    callback: ProgressCallback = None,
+    callback: Optional[ProgressCallback] = None,
     **kwargs,
 ) -> Image.Image:
     params = params.with_args(**kwargs)

View File

@@ -31,7 +31,7 @@ def upscale_outpaint(
     fill_color: str = "white",
     mask_filter: Callable = mask_filter_none,
     noise_source: Callable = noise_source_histogram,
-    callback: ProgressCallback = None,
+    callback: Optional[ProgressCallback] = None,
     **kwargs,
 ) -> Image.Image:
     source = stage_source or source

View File

@@ -9,12 +9,10 @@ from ..params import DeviceParams, ImageParams, StageParams, UpscaleParams
 from ..server import ServerContext
 from ..utils import run_gc
 from ..worker import WorkerContext
+from typing import Optional

 logger = getLogger(__name__)

-last_pipeline_instance = None
-last_pipeline_params = (None, None)

 TAG_X4_V3 = "real-esrgan-x4-v3"

@@ -104,7 +102,7 @@ def upscale_resrgan(
     source: Image.Image,
     *,
     upscale: UpscaleParams,
-    stage_source: Image.Image = None,
+    stage_source: Optional[Image.Image] = None,
     **kwargs,
 ) -> Image.Image:
     source = stage_source or source

View File

@@ -13,6 +13,7 @@ from ..params import DeviceParams, ImageParams, StageParams, UpscaleParams
 from ..server import ServerContext
 from ..utils import run_gc
 from ..worker import ProgressCallback, WorkerContext
+from typing import Optional

 logger = getLogger(__name__)

@@ -70,8 +71,8 @@ def upscale_stable_diffusion(
     source: Image.Image,
     *,
     upscale: UpscaleParams,
-    stage_source: Image.Image = None,
-    callback: ProgressCallback = None,
+    stage_source: Optional[Image.Image] = None,
+    callback: Optional[ProgressCallback] = None,
     **kwargs,
 ) -> Image.Image:
     params = params.with_args(**kwargs)

View File

@@ -110,3 +110,6 @@ def process_tile_order(
     elif order == TileOrder.spiral:
         logger.debug("using spiral tile order with tile size: %s", tile)
         return process_tile_spiral(source, tile, scale, filters, **kwargs)
+    else:
+        logger.warn("unknown tile order: %s", order)
+        raise ValueError()

View File

@@ -3,7 +3,7 @@ from argparse import ArgumentParser
 from logging import getLogger
 from os import makedirs, path
 from sys import exit
-from typing import Dict, List, Optional, Tuple
+from typing import Any, Dict, List, Optional, Tuple
 from urllib.parse import urlparse

 from jsonschema import ValidationError, validate

@@ -36,7 +36,7 @@ warnings.filterwarnings(
     ".*Converting a tensor to a Python boolean might cause the trace to be incorrect.*",
 )

-Models = Dict[str, List[Tuple[str, str, Optional[int]]]]
+Models = Dict[str, List[Any]]

 logger = getLogger(__name__)

View File

@@ -23,7 +23,7 @@ LegacyModel = Tuple[str, str, Optional[bool], Optional[bool], Optional[int]]
 class ConversionContext(ServerContext):
     def __init__(
         self,
-        model_path: Optional[str] = None,
+        model_path: str,
         cache_path: Optional[str] = None,
         device: Optional[str] = None,
         half: Optional[bool] = False,

@@ -31,7 +31,7 @@ class ConversionContext(ServerContext):
         token: Optional[str] = None,
         **kwargs,
     ) -> None:
-        super().__init__(self, model_path=model_path, cache_path=cache_path)
+        super().__init__(model_path=model_path, cache_path=cache_path)
         self.half = half
         self.opset = opset

@@ -153,7 +153,7 @@ def source_format(model: Dict) -> Optional[str]:
         return model["format"]

     if "source" in model:
-        ext = path.splitext(model["source"])
+        _name, ext = path.splitext(model["source"])
         if ext in model_formats:
             return ext

@@ -183,7 +183,7 @@ class Config(object):
             setattr(target, k, v)

-def load_yaml(file: str) -> str:
+def load_yaml(file: str) -> Config:
     with open(file, "r") as f:
         data = safe_load(f.read())
         return Config(data)
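
The source_format fix unpacks the splitext result: the old code bound the whole (root, ext) tuple to ext, so the membership test against a list of extension strings could never succeed. A quick illustration:

    from os import path

    # splitext returns a (root, ext) pair, with the dot kept on the extension
    _name, ext = path.splitext("stable-diffusion.safetensors")
    print(ext)  # ".safetensors"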

View File

@ -6,6 +6,7 @@ from diffusers.utils.logging import disable_progress_bar
from flask import Flask from flask import Flask
from flask_cors import CORS from flask_cors import CORS
from huggingface_hub.utils.tqdm import disable_progress_bars from huggingface_hub.utils.tqdm import disable_progress_bars
from setproctitle import setproctitle
from torch.multiprocessing import set_start_method from torch.multiprocessing import set_start_method
from .server.api import register_api_routes from .server.api import register_api_routes
@ -26,6 +27,7 @@ logger = getLogger(__name__)
def main(): def main():
setproctitle("onnx-web server")
set_start_method("spawn", force=True) set_start_method("spawn", force=True)
context = ServerContext.from_environ() context = ServerContext.from_environ()

View File

@@ -15,7 +15,7 @@ from .utils import base_join

 logger = getLogger(__name__)

-def hash_value(sha, param: Param):
+def hash_value(sha, param: Optional[Param]):
     if param is None:
         return
     elif isinstance(param, bool):

@@ -63,7 +63,7 @@ def make_output_name(
     mode: str,
     params: ImageParams,
     size: Size,
-    extras: Optional[Tuple[Param]] = None,
+    extras: Optional[List[Optional[Param]]] = None,
 ) -> List[str]:
     now = int(time())
     sha = sha256()
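
The extras change fixes a subtle typing bug: Tuple[Param] means a tuple of exactly one Param, so only single-item extras ever type-checked. The variable-length alternatives, for comparison:

    from typing import List, Optional, Tuple

    exactly_one: Tuple[int] = (1,)
    any_length: Tuple[int, ...] = (1, 2, 3)
    as_list: List[Optional[int]] = [1, None, 3]  # the shape this commit adopts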

View File

@@ -101,12 +101,12 @@ class DeviceParams:
         self.device = device
         self.provider = provider
         self.options = options
-        self.optimizations = optimizations
+        self.optimizations = optimizations or []

     def __str__(self) -> str:
         return "%s - %s (%s)" % (self.device, self.provider, self.options)

-    def ort_provider(self) -> Tuple[str, Any]:
+    def ort_provider(self) -> Union[str, Tuple[str, Any]]:
         if self.options is None:
             return self.provider
         else:
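
The widened return type matches what onnxruntime accepts in its providers list: either a plain provider name or a (name, options) pair. A sketch of the consuming side (the model path and options here are illustrative):

    import onnxruntime as ort

    # each entry may be a provider name or a (name, options) tuple, which is
    # the Union[str, Tuple[str, Any]] that ort_provider now returns
    session = ort.InferenceSession(
        "model.onnx",
        providers=[("CUDAExecutionProvider", {"device_id": 0}), "CPUExecutionProvider"],
    )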

View File

@@ -81,7 +81,7 @@ def introspect(context: ServerContext, app: Flask):
     return {
         "name": "onnx-web",
         "routes": [
-            {"path": url_from_rule(rule), "methods": list(rule.methods).sort()}
+            {"path": url_from_rule(rule), "methods": list(rule.methods or []).sort()}
             for rule in app.url_map.iter_rules()
         ],
     }

@@ -119,10 +119,10 @@ def list_schedulers(context: ServerContext):

 def img2img(context: ServerContext, pool: DevicePoolExecutor):
-    if "source" not in request.files:
+    source_file = request.files.get("source")
+    if source_file is None:
         return error_reply("source image is required")

-    source_file = request.files.get("source")
     source = Image.open(BytesIO(source_file.read())).convert("RGB")

     device, params, size = pipeline_from_request(context)

@@ -136,7 +136,7 @@ def img2img(context: ServerContext, pool: DevicePoolExecutor):
         get_config_value("strength", "min"),
     )

-    output = make_output_name(context, "img2img", params, size, extras=(strength,))
+    output = make_output_name(context, "img2img", params, size, extras=[strength])

     job_name = output[0]
     logger.info("img2img job queued for: %s", job_name)

@@ -179,16 +179,15 @@ def txt2img(context: ServerContext, pool: DevicePoolExecutor):

 def inpaint(context: ServerContext, pool: DevicePoolExecutor):
-    if "source" not in request.files:
+    source_file = request.files.get("source")
+    if source_file is None:
         return error_reply("source image is required")

-    if "mask" not in request.files:
+    mask_file = request.files.get("mask")
+    if mask_file is None:
         return error_reply("mask image is required")

-    source_file = request.files.get("source")
     source = Image.open(BytesIO(source_file.read())).convert("RGB")

-    mask_file = request.files.get("mask")
     mask = Image.open(BytesIO(mask_file.read())).convert("RGB")

     device, params, size = pipeline_from_request(context)

@@ -207,7 +206,7 @@ def inpaint(context: ServerContext, pool: DevicePoolExecutor):
         "inpaint",
         params,
         size,
-        extras=(
+        extras=[
             expand.left,
             expand.right,
             expand.top,

@@ -216,7 +215,7 @@ def inpaint(context: ServerContext, pool: DevicePoolExecutor):
             noise_source.__name__,
             fill_color,
             tile_order,
-        ),
+        ],
     )

     job_name = output[0]
     logger.info("inpaint job queued for: %s", job_name)

@@ -245,10 +244,10 @@ def inpaint(context: ServerContext, pool: DevicePoolExecutor):

 def upscale(context: ServerContext, pool: DevicePoolExecutor):
-    if "source" not in request.files:
+    source_file = request.files.get("source")
+    if source_file is None:
         return error_reply("source image is required")

-    source_file = request.files.get("source")
     source = Image.open(BytesIO(source_file.read())).convert("RGB")

     device, params, size = pipeline_from_request(context)

@@ -324,6 +323,7 @@ def chain(context: ServerContext, pool: DevicePoolExecutor):
                 stage.name,
             )
             source_file = request.files.get(stage_source_name)
+            if source_file is not None:
                 source = Image.open(BytesIO(source_file.read())).convert("RGB")
                 source = valid_image(source, max_dims=(size.width, size.height))
                 kwargs["stage_source"] = source

@@ -335,6 +335,7 @@ def chain(context: ServerContext, pool: DevicePoolExecutor):
                 stage.name,
             )
             mask_file = request.files.get(stage_mask_name)
+            if mask_file is not None:
                 mask = Image.open(BytesIO(mask_file.read())).convert("RGB")
                 mask = valid_image(mask, max_dims=(size.width, size.height))
                 kwargs["stage_mask"] = mask

@@ -360,10 +361,10 @@ def chain(context: ServerContext, pool: DevicePoolExecutor):

 def blend(context: ServerContext, pool: DevicePoolExecutor):
-    if "mask" not in request.files:
+    mask_file = request.files.get("mask")
+    if mask_file is None:
         return error_reply("mask image is required")

-    mask_file = request.files.get("mask")
     mask = Image.open(BytesIO(mask_file.read())).convert("RGBA")
     mask = valid_image(mask)

@@ -372,6 +373,9 @@ def blend(context: ServerContext, pool: DevicePoolExecutor):
     for i in range(max_sources):
         source_file = request.files.get("source:%s" % (i))
+        if source_file is None:
+            logger.warning("missing source %s", i)
+        else:
             source = Image.open(BytesIO(source_file.read())).convert("RGBA")
             source = valid_image(source, mask.size, mask.size)
             sources.append(source)

@@ -403,10 +407,11 @@ def txt2txt(context: ServerContext, pool: DevicePoolExecutor):
     device, params, size = pipeline_from_request(context)
     output = make_output_name(context, "txt2txt", params, size)

-    logger.info("upscale job queued for: %s", output)
+    job_name = output[0]
+    logger.info("upscale job queued for: %s", job_name)

     pool.submit(
-        output,
+        job_name,
         run_txt2txt_pipeline,
         context,
         params,

@@ -420,6 +425,8 @@ def txt2txt(context: ServerContext, pool: DevicePoolExecutor):

 def cancel(context: ServerContext, pool: DevicePoolExecutor):
     output_file = request.args.get("output", None)
+    if output_file is None:
+        return error_reply("output name is required")

     cancel = pool.cancel(output_file)

@@ -428,6 +435,8 @@ def cancel(context: ServerContext, pool: DevicePoolExecutor):

 def ready(context: ServerContext, pool: DevicePoolExecutor):
     output_file = request.args.get("output", None)
+    if output_file is None:
+        return error_reply("output name is required")

     done, progress = pool.done(output_file)

@@ -436,7 +445,7 @@ def ready(context: ServerContext, pool: DevicePoolExecutor):
     if path.exists(output):
         return ready_reply(True)

-    return ready_reply(done, progress=progress)
+    return ready_reply(done or False, progress=progress)

 def status(context: ServerContext, pool: DevicePoolExecutor):
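
The handler changes above all follow one pattern: fetch the upload once and branch on None, which both validates the request and lets mypy narrow Optional[FileStorage] before .read() is called. A condensed sketch of that pattern (the helper name is ours, not the server's):

    from io import BytesIO
    from typing import Optional

    from flask import request
    from PIL import Image

    def load_upload(name: str) -> Optional[Image.Image]:
        upload = request.files.get(name)  # Optional[FileStorage]
        if upload is None:
            return None  # callers turn this into error_reply(...)
        return Image.open(BytesIO(upload.read())).convert("RGB")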

View File

@@ -1,6 +1,6 @@
 from logging import getLogger
 from os import environ, path
-from typing import List
+from typing import List, Optional

 from ..utils import get_boolean
 from .model_cache import ModelCache

@@ -18,13 +18,13 @@ class ServerContext:
         cors_origin: str = "*",
         num_workers: int = 1,
         any_platform: bool = True,
-        block_platforms: List[str] = None,
-        default_platform: str = None,
+        block_platforms: Optional[List[str]] = None,
+        default_platform: Optional[str] = None,
         image_format: str = "png",
-        cache: ModelCache = None,
-        cache_path: str = None,
+        cache: Optional[ModelCache] = None,
+        cache_path: Optional[str] = None,
         show_progress: bool = True,
-        optimizations: List[str] = None,
+        optimizations: Optional[List[str]] = None,
     ) -> None:
         self.bundle_path = bundle_path
         self.model_path = model_path

View File

@@ -119,9 +119,8 @@ def patch_not_impl():

 def patch_cache_path(ctx: ServerContext, url: str, **kwargs) -> str:
-    if url in cache_path_map:
-        cache_path = cache_path_map.get(url)
-    else:
+    cache_path = cache_path_map.get(url, None)
+    if cache_path is None:
         parsed = urlparse(url)
         cache_path = path.basename(parsed.path)

View File

@@ -22,13 +22,13 @@ def run_txt2txt_pipeline(
     device = job.get_device()

-    model = GPTJForCausalLM.from_pretrained(model).to(device.torch_device())
+    pipe = GPTJForCausalLM.from_pretrained(model).to(device.torch_str())
     tokenizer = AutoTokenizer.from_pretrained(model)

     input_ids = tokenizer.encode(params.prompt, return_tensors="pt").to(
-        device.torch_device()
+        device.torch_str()
     )
-    results = model.generate(
+    results = pipe.generate(
         input_ids,
         do_sample=True,
         max_length=tokens,

View File

@ -12,6 +12,7 @@ from .chain import (
from .params import ImageParams, SizeChart, StageParams, UpscaleParams from .params import ImageParams, SizeChart, StageParams, UpscaleParams
from .server import ServerContext from .server import ServerContext
from .worker import ProgressCallback, WorkerContext from .worker import ProgressCallback, WorkerContext
from typing import Optional
logger = getLogger(__name__) logger = getLogger(__name__)
@ -24,7 +25,7 @@ def run_upscale_correction(
image: Image.Image, image: Image.Image,
*, *,
upscale: UpscaleParams, upscale: UpscaleParams,
callback: ProgressCallback = None, callback: Optional[ProgressCallback] = None,
) -> Image.Image: ) -> Image.Image:
""" """
This is a convenience method for a chain pipeline that will run upscaling and This is a convenience method for a chain pipeline that will run upscaling and

View File

@@ -2,7 +2,7 @@ import gc
 import threading
 from logging import getLogger
 from os import environ, path
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, Dict, List, Optional, Sequence, Union

 import torch

@@ -36,7 +36,7 @@ def get_and_clamp_int(
     return min(max(int(args.get(key, default_value)), min_value), max_value)

-def get_from_list(args: Any, key: str, values: List[Any]) -> Optional[Any]:
+def get_from_list(args: Any, key: str, values: Sequence[Any]) -> Optional[Any]:
     selected = args.get(key, None)
     if selected in values:
         return selected

@@ -82,7 +82,7 @@ def get_size(val: Union[int, str, None]) -> Union[int, SizeChart]:
     raise ValueError("invalid size")

-def run_gc(devices: List[DeviceParams] = None):
+def run_gc(devices: Optional[List[DeviceParams]] = None):
     logger.debug(
         "running garbage collection with %s active threads", threading.active_count()
     )
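
Switching get_from_list from List[Any] to Sequence[Any] loosens the input side: a tuple is not a List, so the old annotation forced callers to build throwaway lists, while Sequence covers lists, tuples, and other read-only ordered collections. A condensed sketch:

    from typing import Any, Optional, Sequence

    def get_from_list(args: Any, key: str, values: Sequence[Any]) -> Optional[Any]:
        selected = args.get(key, None)
        if selected in values:
            return selected
        return None

    get_from_list({"order": "spiral"}, "order", ("grid", "spiral"))  # tuples now type-check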

View File

@@ -12,20 +12,20 @@ ProgressCallback = Callable[[int, int, Any], None]

 class WorkerContext:
-    cancel: "Value[bool]" = None
-    job: str = None
-    pending: "Queue[Tuple[Callable, Any, Any]]" = None
-    progress: "Value[int]" = None
+    cancel: "Value[bool]"
+    job: str
+    pending: "Queue[Tuple[str, Callable[..., None], Any, Any]]"
+    progress: "Value[int]"

     def __init__(
         self,
         job: str,
         device: DeviceParams,
-        cancel: "Value[bool]" = None,
-        logs: "Queue[str]" = None,
-        pending: "Queue[Any]" = None,
-        progress: "Queue[Tuple[str, int]]" = None,
-        finished: "Queue[str]" = None,
+        cancel: "Value[bool]",
+        logs: "Queue[str]",
+        pending: "Queue[Tuple[str, Callable[..., None], Any, Any]]",
+        progress: "Queue[Tuple[str, str, int]]",
+        finished: "Queue[Tuple[str, str]]",
     ):
         self.job = job
         self.device = device
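
Dropping the `= None` defaults turns these into bare class-level annotations: the checker records the attribute types without creating class attributes whose default value contradicts the annotation. The same idea in miniature:

    class JobState:
        # declared for the type checker, assigned in __init__;
        # `job: str = None` would claim str while actually storing None
        job: str
        progress: int

        def __init__(self, job: str) -> None:
            self.job = job
            self.progress = 0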

View File

@ -2,7 +2,7 @@ from collections import Counter
from logging import getLogger from logging import getLogger
from queue import Empty from queue import Empty
from threading import Thread from threading import Thread
from typing import Callable, Dict, List, Optional, Tuple from typing import Any, Callable, Dict, List, Optional, Tuple
from torch.multiprocessing import Process, Queue, Value from torch.multiprocessing import Process, Queue, Value
@ -15,23 +15,38 @@ logger = getLogger(__name__)
class DevicePoolExecutor: class DevicePoolExecutor:
context: Dict[str, WorkerContext] = None # Device -> Context server: ServerContext
devices: List[DeviceParams] = None devices: List[DeviceParams]
pending: Dict[str, "Queue[WorkerContext]"] = None max_jobs_per_worker: int
workers: Dict[str, Process] = None max_pending_per_worker: int
active_jobs: Dict[str, Tuple[str, int]] = None # should be Dict[Device, JobStatus] join_timeout: float
finished_jobs: List[Tuple[str, int, bool]] = None # should be List[JobStatus]
context: Dict[str, WorkerContext] # Device -> Context
pending: Dict[str, "Queue[Tuple[str, Callable[..., None], Any, Any]]"]
threads: Dict[str, Thread]
workers: Dict[str, Process]
active_jobs: Dict[str, Tuple[str, int]] # should be Dict[Device, JobStatus]
cancelled_jobs: List[str]
finished_jobs: List[Tuple[str, int, bool]] # should be List[JobStatus]
total_jobs: int
logs: "Queue"
progress: "Queue[Tuple[str, str, int]]"
finished: "Queue[Tuple[str, str]]"
def __init__( def __init__(
self, self,
server: ServerContext, server: ServerContext,
devices: List[DeviceParams], devices: List[DeviceParams],
max_jobs_per_worker: int = 10, max_jobs_per_worker: int = 10,
max_pending_per_worker: int = 100,
join_timeout: float = 1.0, join_timeout: float = 1.0,
): ):
self.server = server self.server = server
self.devices = devices self.devices = devices
self.max_jobs_per_worker = max_jobs_per_worker self.max_jobs_per_worker = max_jobs_per_worker
self.max_pending_per_worker = max_pending_per_worker
self.join_timeout = join_timeout self.join_timeout = join_timeout
self.context = {} self.context = {}
@ -44,9 +59,9 @@ class DevicePoolExecutor:
self.finished_jobs = [] self.finished_jobs = []
self.total_jobs = 0 # TODO: turn this into a Dict per-worker self.total_jobs = 0 # TODO: turn this into a Dict per-worker
self.logs = Queue() self.logs = Queue(self.max_pending_per_worker)
self.progress = Queue() self.progress = Queue(self.max_pending_per_worker)
self.finished = Queue() self.finished = Queue(self.max_pending_per_worker)
self.create_logger_worker() self.create_logger_worker()
self.create_progress_worker() self.create_progress_worker()
@ -67,7 +82,7 @@ class DevicePoolExecutor:
pending = self.pending[name] pending = self.pending[name]
else: else:
logger.debug("creating new pending job queue") logger.debug("creating new pending job queue")
pending = Queue() pending = Queue(self.max_pending_per_worker)
self.pending[name] = pending self.pending[name] = pending
context = WorkerContext( context = WorkerContext(
@ -80,7 +95,11 @@ class DevicePoolExecutor:
pending=pending, pending=pending,
) )
self.context[name] = context self.context[name] = context
self.workers[name] = Process(target=worker_main, args=(context, self.server)) self.workers[name] = Process(
name=f"onnx-web worker: {name}",
target=worker_main,
args=(context, self.server),
)
logger.debug("starting worker for device %s", device) logger.debug("starting worker for device %s", device)
self.workers[name].start() self.workers[name].start()
@ -102,7 +121,9 @@ class DevicePoolExecutor:
except Exception as err: except Exception as err:
logger.error("error in log worker: %s", err) logger.error("error in log worker: %s", err)
logger_thread = Thread(target=logger_worker, args=(self.logs,), daemon=True) logger_thread = Thread(
name="onnx-web logger", target=logger_worker, args=(self.logs,), daemon=True
)
self.threads["logger"] = logger_thread self.threads["logger"] = logger_thread
logger.debug("starting logger worker") logger.debug("starting logger worker")
@ -128,7 +149,12 @@ class DevicePoolExecutor:
except Exception as err: except Exception as err:
logger.error("error in progress worker: %s", err) logger.error("error in progress worker: %s", err)
progress_thread = Thread(target=progress_worker, args=(self.progress,), daemon=True) progress_thread = Thread(
name="onnx-web progress",
target=progress_worker,
args=(self.progress,),
daemon=True,
)
self.threads["progress"] = progress_thread self.threads["progress"] = progress_thread
logger.debug("starting progress worker") logger.debug("starting progress worker")
@ -152,7 +178,12 @@ class DevicePoolExecutor:
except Exception as err: except Exception as err:
logger.error("error in finished worker: %s", err) logger.error("error in finished worker: %s", err)
finished_thread = Thread(target=finished_worker, args=(self.finished,), daemon=True) finished_thread = Thread(
name="onnx-web finished",
target=finished_worker,
args=(self.finished,),
daemon=True,
)
self.threads["finished"] = finished_thread self.threads["finished"] = finished_thread
logger.debug("started finished worker") logger.debug("started finished worker")
@ -221,8 +252,18 @@ class DevicePoolExecutor:
return (False, progress) return (False, progress)
def join(self): def join(self):
logger.debug("stopping worker pool") logger.info("stopping worker pool")
logger.debug("closing queues")
self.logs.close()
self.finished.close()
self.progress.close()
for queue in self.pending.values():
queue.close()
self.pending.clear()
logger.debug("stopping device workers")
for device, worker in self.workers.items(): for device, worker in self.workers.items():
if worker.is_alive(): if worker.is_alive():
logger.debug("stopping worker for device %s", device) logger.debug("stopping worker for device %s", device)
@ -235,15 +276,6 @@ class DevicePoolExecutor:
logger.debug("stopping worker thread: %s", name) logger.debug("stopping worker thread: %s", name)
thread.join(self.join_timeout) thread.join(self.join_timeout)
logger.debug("closing queues")
self.logs.close()
self.finished.close()
self.progress.close()
for queue in self.pending.values():
queue.close()
self.pending.clear()
logger.debug("worker pool fully joined") logger.debug("worker pool fully joined")
def recycle(self): def recycle(self):
@ -292,7 +324,7 @@ class DevicePoolExecutor:
def status(self) -> List[Tuple[str, int, bool, bool]]: def status(self) -> List[Tuple[str, int, bool, bool]]:
history = [ history = [
(name, progress, False, name in self.cancelled_jobs) (name, progress, False, name in self.cancelled_jobs)
for name, _device, progress in self.active_jobs.items() for name, (_device, progress) in self.active_jobs.items()
] ]
history.extend( history.extend(
[ [

View File

@ -4,7 +4,6 @@ from sys import exit
from traceback import format_exception from traceback import format_exception
from setproctitle import setproctitle from setproctitle import setproctitle
from torch.multiprocessing import Queue
from ..server import ServerContext, apply_patches from ..server import ServerContext, apply_patches
from ..torch_before_ort import get_available_providers from ..torch_before_ort import get_available_providers

View File

@@ -8,16 +8,35 @@ skip_glob = ["*/lpw_stable_diffusion_onnx.py", "*/pipeline_onnx_stable_diffusion
 [tool.mypy]
 # ignore_missing_imports = true

+exclude = [
+  "onnx_web.diffusion.lpw_stable_diffusion_onnx",
+  "onnx_web.diffusion.pipeline_onnx_stable_diffusion_upscale"
+]

 [[tool.mypy.overrides]]
 module = [
   "basicsr.archs.rrdbnet_arch",
+  "basicsr.utils.download_util",
+  "basicsr.utils",
+  "basicsr",
   "boto3",
   "codeformer",
+  "codeformer.facelib.utils.misc",
+  "codeformer.facelib.utils",
+  "codeformer.facelib",
   "diffusers",
+  "diffusers.pipelines.latent_diffusion.pipeline_latent_diffusion",
+  "diffusers.pipelines.paint_by_example",
+  "diffusers.pipelines.stable_diffusion",
   "diffusers.pipeline_utils",
+  "diffusers.utils.logging",
+  "facexlib.utils",
+  "facexlib",
   "gfpgan",
   "onnxruntime",
-  "realesrgan"
+  "realesrgan",
+  "realesrgan.archs.srvgg_arch",
+  "safetensors",
+  "transformers"
 ]
 ignore_missing_imports = true

View File

@ -4,7 +4,7 @@ test_images=0
while true; while true;
do do
curl "http://${test_host}:5000/api/txt2img?"\ curl "http://${test_host}:5000/api/txt2img?"\
'cfg=16.00&steps=3&scheduler=deis-multi&seed=-1&'\ 'cfg=16.00&steps=3&scheduler=ddim&seed=-1&'\
'prompt=an+astronaut+eating+a+hamburger&negativePrompt=&'\ 'prompt=an+astronaut+eating+a+hamburger&negativePrompt=&'\
'model=stable-diffusion-onnx-v1-5&platform=any&'\ 'model=stable-diffusion-onnx-v1-5&platform=any&'\
'upscaling=upscaling-real-esrgan-x2-plus&correction=correction-codeformer&'\ 'upscaling=upscaling-real-esrgan-x2-plus&correction=correction-codeformer&'\