
name threads, max queues, type/lint fixes

Sean Sube 2023-02-28 21:44:52 -06:00
parent c95ac1fbdd
commit c99aa67220
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
30 changed files with 179 additions and 111 deletions

View File

@@ -39,4 +39,4 @@ lint-fix:
	flake8 onnx_web

typecheck:
-	mypy -m onnx_web.serve
+	mypy onnx_web

View File

@ -1,7 +1,7 @@
version: 1
formatters:
simple:
format: '[%(asctime)s] %(levelname)s: %(name)s: %(message)s'
format: '[%(asctime)s] %(levelname)s: %(processName)s %(threadName)s %(name)s: %(message)s'
handlers:
console:
class: logging.StreamHandler
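The new format adds %(processName)s and %(threadName)s, which pick up the process and thread names assigned elsewhere in this commit. A minimal stdlib sketch of the effect (the logger and thread names here are illustrative):

import logging
import threading

logging.basicConfig(
    format="[%(asctime)s] %(levelname)s: %(processName)s %(threadName)s %(name)s: %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger("example")

def work():
    # Emits: [...] INFO: MainProcess demo-thread example: working
    logger.info("working")

thread = threading.Thread(name="demo-thread", target=work)
thread.start()
thread.join()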

View File

@@ -62,7 +62,7 @@ class ChainPipeline:
    def __init__(
        self,
-        stages: List[PipelineStage] = None,
+        stages: Optional[List[PipelineStage]] = None,
    ):
        """
        Create a new pipeline that will run the given stages.
@@ -82,7 +82,7 @@ class ChainPipeline:
        server: ServerContext,
        params: ImageParams,
        source: Image.Image,
-        callback: ProgressCallback = None,
+        callback: Optional[ProgressCallback] = None,
        **pipeline_kwargs
    ) -> Image.Image:
        """

View File

@@ -9,6 +9,7 @@ from ..diffusion.load import load_pipeline
from ..params import ImageParams, StageParams
from ..server import ServerContext
from ..worker import ProgressCallback, WorkerContext
+from typing import Optional

logger = getLogger(__name__)
@@ -20,7 +21,7 @@ def blend_img2img(
    params: ImageParams,
    source: Image.Image,
    *,
-    callback: ProgressCallback = None,
+    callback: Optional[ProgressCallback] = None,
    stage_source: Image.Image,
    **kwargs,
) -> Image.Image:

View File

@@ -31,7 +31,7 @@ def blend_inpaint(
    fill_color: str = "white",
    mask_filter: Callable = mask_filter_none,
    noise_source: Callable = noise_source_histogram,
-    callback: ProgressCallback = None,
+    callback: Optional[ProgressCallback] = None,
    **kwargs,
) -> Image.Image:
    params = params.with_args(**kwargs)

View File

@@ -21,7 +21,7 @@ def blend_mask(
    *,
    sources: Optional[List[Image.Image]] = None,
    stage_mask: Optional[Image.Image] = None,
-    _callback: ProgressCallback = None,
+    _callback: Optional[ProgressCallback] = None,
    **kwargs,
) -> Image.Image:
    logger.info("blending image using mask")

View File

@@ -5,6 +5,7 @@ from PIL import Image
from ..params import ImageParams, StageParams, UpscaleParams
from ..server import ServerContext
from ..worker import WorkerContext
+from typing import Optional

logger = getLogger(__name__)
@@ -18,7 +19,7 @@ def correct_codeformer(
    _params: ImageParams,
    source: Image.Image,
    *,
-    stage_source: Image.Image = None,
+    stage_source: Optional[Image.Image] = None,
    upscale: UpscaleParams,
    **kwargs,
) -> Image.Image:

View File

@@ -7,6 +7,7 @@ from PIL import Image
from ..params import ImageParams, StageParams
from ..server import ServerContext
from ..worker import WorkerContext
+from typing import Optional

logger = getLogger(__name__)
@@ -20,9 +21,9 @@ def persist_s3(
    *,
    output: str,
    bucket: str,
-    endpoint_url: str = None,
-    profile_name: str = None,
-    stage_source: Image.Image = None,
+    endpoint_url: Optional[str] = None,
+    profile_name: Optional[str] = None,
+    stage_source: Optional[Image.Image] = None,
    **kwargs,
) -> Image.Image:
    source = stage_source or source

View File

@@ -5,6 +5,7 @@ from PIL import Image
from ..params import ImageParams, Size, StageParams
from ..server import ServerContext
from ..worker import WorkerContext
+from typing import Optional

logger = getLogger(__name__)
@@ -18,7 +19,7 @@ def reduce_crop(
    *,
    origin: Size,
    size: Size,
-    stage_source: Image.Image = None,
+    stage_source: Optional[Image.Image] = None,
    **kwargs,
) -> Image.Image:
    source = stage_source or source

View File

@@ -9,6 +9,7 @@ from ..diffusion.load import get_latents_from_seed, load_pipeline
from ..params import ImageParams, Size, StageParams
from ..server import ServerContext
from ..worker import ProgressCallback, WorkerContext
+from typing import Optional

logger = getLogger(__name__)
@@ -21,7 +22,7 @@ def source_txt2img(
    _source: Image.Image,
    *,
    size: Size,
-    callback: ProgressCallback = None,
+    callback: Optional[ProgressCallback] = None,
    **kwargs,
) -> Image.Image:
    params = params.with_args(**kwargs)

View File

@@ -31,7 +31,7 @@ def upscale_outpaint(
    fill_color: str = "white",
    mask_filter: Callable = mask_filter_none,
    noise_source: Callable = noise_source_histogram,
-    callback: ProgressCallback = None,
+    callback: Optional[ProgressCallback] = None,
    **kwargs,
) -> Image.Image:
    source = stage_source or source

View File

@@ -9,12 +9,10 @@ from ..params import DeviceParams, ImageParams, StageParams, UpscaleParams
from ..server import ServerContext
from ..utils import run_gc
from ..worker import WorkerContext
+from typing import Optional

logger = getLogger(__name__)

-last_pipeline_instance = None
-last_pipeline_params = (None, None)

TAG_X4_V3 = "real-esrgan-x4-v3"
@@ -104,7 +102,7 @@ def upscale_resrgan(
    source: Image.Image,
    *,
    upscale: UpscaleParams,
-    stage_source: Image.Image = None,
+    stage_source: Optional[Image.Image] = None,
    **kwargs,
) -> Image.Image:
    source = stage_source or source

View File

@@ -13,6 +13,7 @@ from ..params import DeviceParams, ImageParams, StageParams, UpscaleParams
from ..server import ServerContext
from ..utils import run_gc
from ..worker import ProgressCallback, WorkerContext
+from typing import Optional

logger = getLogger(__name__)
@@ -70,8 +71,8 @@ def upscale_stable_diffusion(
    source: Image.Image,
    *,
    upscale: UpscaleParams,
-    stage_source: Image.Image = None,
-    callback: ProgressCallback = None,
+    stage_source: Optional[Image.Image] = None,
+    callback: Optional[ProgressCallback] = None,
    **kwargs,
) -> Image.Image:
    params = params.with_args(**kwargs)

View File

@@ -110,3 +110,6 @@ def process_tile_order(
    elif order == TileOrder.spiral:
        logger.debug("using spiral tile order with tile size: %s", tile)
        return process_tile_spiral(source, tile, scale, filters, **kwargs)
+    else:
+        logger.warn("unknown tile order: %s", order)
+        raise ValueError()

View File

@@ -3,7 +3,7 @@ from argparse import ArgumentParser
from logging import getLogger
from os import makedirs, path
from sys import exit
-from typing import Dict, List, Optional, Tuple
+from typing import Any, Dict, List, Optional, Tuple
from urllib.parse import urlparse

from jsonschema import ValidationError, validate
@@ -36,7 +36,7 @@ warnings.filterwarnings(
    ".*Converting a tensor to a Python boolean might cause the trace to be incorrect.*",
)

-Models = Dict[str, List[Tuple[str, str, Optional[int]]]]
+Models = Dict[str, List[Any]]

logger = getLogger(__name__)

View File

@@ -23,7 +23,7 @@ LegacyModel = Tuple[str, str, Optional[bool], Optional[bool], Optional[int]]
class ConversionContext(ServerContext):
    def __init__(
        self,
-        model_path: Optional[str] = None,
+        model_path: str,
        cache_path: Optional[str] = None,
        device: Optional[str] = None,
        half: Optional[bool] = False,
@@ -31,7 +31,7 @@ class ConversionContext(ServerContext):
        token: Optional[str] = None,
        **kwargs,
    ) -> None:
-        super().__init__(self, model_path=model_path, cache_path=cache_path)
+        super().__init__(model_path=model_path, cache_path=cache_path)
        self.half = half
        self.opset = opset
@@ -153,7 +153,7 @@ def source_format(model: Dict) -> Optional[str]:
        return model["format"]

    if "source" in model:
-        ext = path.splitext(model["source"])
+        _name, ext = path.splitext(model["source"])
        if ext in model_formats:
            return ext
@@ -183,7 +183,7 @@ class Config(object):
            setattr(target, k, v)

-def load_yaml(file: str) -> str:
+def load_yaml(file: str) -> Config:
    with open(file, "r") as f:
        data = safe_load(f.read())
        return Config(data)
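Two genuine bugs are fixed in this file: super() already binds self, so passing it again sends the instance twice, and path.splitext returns a (root, ext) tuple rather than the extension alone. A reduced illustration with made-up classes:

from os import path

class Base:
    def __init__(self, value: int) -> None:
        self.value = value

class Child(Base):
    def __init__(self) -> None:
        # super().__init__(self, value=1) would pass the instance
        # twice and raise a TypeError; super() handles the binding.
        super().__init__(value=1)

# splitext returns a tuple, so the extension must be unpacked:
_name, ext = path.splitext("model.safetensors")
assert ext == ".safetensors"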

View File

@@ -6,6 +6,7 @@ from diffusers.utils.logging import disable_progress_bar
from flask import Flask
from flask_cors import CORS
from huggingface_hub.utils.tqdm import disable_progress_bars
+from setproctitle import setproctitle
from torch.multiprocessing import set_start_method

from .server.api import register_api_routes
@@ -26,6 +27,7 @@ logger = getLogger(__name__)

def main():
+    setproctitle("onnx-web server")
    set_start_method("spawn", force=True)

    context = ServerContext.from_environ()
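setproctitle renames the process as shown by tools like ps and top, which makes the server distinguishable from the workers it spawns. A minimal sketch (the title string is illustrative):

from setproctitle import setproctitle

# Replaces the generic "python main.py" entry in ps/top output
# with a descriptive title for this process.
setproctitle("my-server: api")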

View File

@@ -15,7 +15,7 @@ from .utils import base_join

logger = getLogger(__name__)

-def hash_value(sha, param: Param):
+def hash_value(sha, param: Optional[Param]):
    if param is None:
        return
    elif isinstance(param, bool):
@@ -63,7 +63,7 @@ def make_output_name(
    mode: str,
    params: ImageParams,
    size: Size,
-    extras: Optional[Tuple[Param]] = None,
+    extras: Optional[List[Optional[Param]]] = None,
) -> List[str]:
    now = int(time())
    sha = sha256()
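The extras change also corrects a subtle typing mistake: Tuple[Param] describes a tuple of exactly one element, so longer tuples failed the annotation. Variable-length containers need Tuple[Param, ...] or a List, as sketched here:

from typing import List, Tuple

one: Tuple[int] = (1,)             # exactly one element
many: Tuple[int, ...] = (1, 2, 3)  # homogeneous, any length
flexible: List[int] = [1, 2, 3]    # mutable, any length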

View File

@@ -101,12 +101,12 @@ class DeviceParams:
        self.device = device
        self.provider = provider
        self.options = options
-        self.optimizations = optimizations
+        self.optimizations = optimizations or []

    def __str__(self) -> str:
        return "%s - %s (%s)" % (self.device, self.provider, self.options)

-    def ort_provider(self) -> Tuple[str, Any]:
+    def ort_provider(self) -> Union[str, Tuple[str, Any]]:
        if self.options is None:
            return self.provider
        else:

View File

@@ -81,7 +81,7 @@ def introspect(context: ServerContext, app: Flask):
    return {
        "name": "onnx-web",
        "routes": [
-            {"path": url_from_rule(rule), "methods": list(rule.methods).sort()}
+            {"path": url_from_rule(rule), "methods": list(rule.methods or []).sort()}
            for rule in app.url_map.iter_rules()
        ],
    }
@@ -119,10 +119,10 @@ def list_schedulers(context: ServerContext):

def img2img(context: ServerContext, pool: DevicePoolExecutor):
-    if "source" not in request.files:
+    source_file = request.files.get("source")
+    if source_file is None:
        return error_reply("source image is required")

-    source_file = request.files.get("source")
    source = Image.open(BytesIO(source_file.read())).convert("RGB")

    device, params, size = pipeline_from_request(context)
@@ -136,7 +136,7 @@ def img2img(context: ServerContext, pool: DevicePoolExecutor):
        get_config_value("strength", "min"),
    )

-    output = make_output_name(context, "img2img", params, size, extras=(strength,))
+    output = make_output_name(context, "img2img", params, size, extras=[strength])
    job_name = output[0]
    logger.info("img2img job queued for: %s", job_name)
@@ -179,16 +179,15 @@ def txt2img(context: ServerContext, pool: DevicePoolExecutor):

def inpaint(context: ServerContext, pool: DevicePoolExecutor):
-    if "source" not in request.files:
+    source_file = request.files.get("source")
+    if source_file is None:
        return error_reply("source image is required")

-    if "mask" not in request.files:
+    mask_file = request.files.get("mask")
+    if mask_file is None:
        return error_reply("mask image is required")

-    source_file = request.files.get("source")
    source = Image.open(BytesIO(source_file.read())).convert("RGB")
-    mask_file = request.files.get("mask")
    mask = Image.open(BytesIO(mask_file.read())).convert("RGB")

    device, params, size = pipeline_from_request(context)
@@ -207,7 +206,7 @@ def inpaint(context: ServerContext, pool: DevicePoolExecutor):
        "inpaint",
        params,
        size,
-        extras=(
+        extras=[
            expand.left,
            expand.right,
            expand.top,
@@ -216,7 +215,7 @@ def inpaint(context: ServerContext, pool: DevicePoolExecutor):
            noise_source.__name__,
            fill_color,
            tile_order,
-        ),
+        ],
    )
    job_name = output[0]
    logger.info("inpaint job queued for: %s", job_name)
@@ -245,10 +244,10 @@ def inpaint(context: ServerContext, pool: DevicePoolExecutor):

def upscale(context: ServerContext, pool: DevicePoolExecutor):
-    if "source" not in request.files:
+    source_file = request.files.get("source")
+    if source_file is None:
        return error_reply("source image is required")

-    source_file = request.files.get("source")
    source = Image.open(BytesIO(source_file.read())).convert("RGB")

    device, params, size = pipeline_from_request(context)
@@ -324,9 +323,10 @@ def chain(context: ServerContext, pool: DevicePoolExecutor):
                stage.name,
            )
            source_file = request.files.get(stage_source_name)
-            source = Image.open(BytesIO(source_file.read())).convert("RGB")
-            source = valid_image(source, max_dims=(size.width, size.height))
-            kwargs["stage_source"] = source
+            if source_file is not None:
+                source = Image.open(BytesIO(source_file.read())).convert("RGB")
+                source = valid_image(source, max_dims=(size.width, size.height))
+                kwargs["stage_source"] = source

        if stage_mask_name in request.files:
            logger.debug(
@@ -335,9 +335,10 @@ def chain(context: ServerContext, pool: DevicePoolExecutor):
                stage.name,
            )
            mask_file = request.files.get(stage_mask_name)
-            mask = Image.open(BytesIO(mask_file.read())).convert("RGB")
-            mask = valid_image(mask, max_dims=(size.width, size.height))
-            kwargs["stage_mask"] = mask
+            if mask_file is not None:
+                mask = Image.open(BytesIO(mask_file.read())).convert("RGB")
+                mask = valid_image(mask, max_dims=(size.width, size.height))
+                kwargs["stage_mask"] = mask
        pipeline.append((callback, stage, kwargs))
@@ -360,10 +361,10 @@ def chain(context: ServerContext, pool: DevicePoolExecutor):

def blend(context: ServerContext, pool: DevicePoolExecutor):
-    if "mask" not in request.files:
+    mask_file = request.files.get("mask")
+    if mask_file is None:
        return error_reply("mask image is required")

-    mask_file = request.files.get("mask")
    mask = Image.open(BytesIO(mask_file.read())).convert("RGBA")
    mask = valid_image(mask)
@@ -372,9 +373,12 @@ def blend(context: ServerContext, pool: DevicePoolExecutor):
    for i in range(max_sources):
        source_file = request.files.get("source:%s" % (i))
-        source = Image.open(BytesIO(source_file.read())).convert("RGBA")
-        source = valid_image(source, mask.size, mask.size)
-        sources.append(source)
+        if source_file is None:
+            logger.warning("missing source %s", i)
+        else:
+            source = Image.open(BytesIO(source_file.read())).convert("RGBA")
+            source = valid_image(source, mask.size, mask.size)
+            sources.append(source)

    device, params, size = pipeline_from_request(context)
    upscale = upscale_from_request()
@@ -403,10 +407,11 @@ def txt2txt(context: ServerContext, pool: DevicePoolExecutor):
    device, params, size = pipeline_from_request(context)

    output = make_output_name(context, "txt2txt", params, size)
-    logger.info("upscale job queued for: %s", output)
+    job_name = output[0]
+    logger.info("upscale job queued for: %s", job_name)

    pool.submit(
-        output,
+        job_name,
        run_txt2txt_pipeline,
        context,
        params,
@@ -420,6 +425,8 @@ def txt2txt(context: ServerContext, pool: DevicePoolExecutor):

def cancel(context: ServerContext, pool: DevicePoolExecutor):
    output_file = request.args.get("output", None)
+    if output_file is None:
+        return error_reply("output name is required")

    cancel = pool.cancel(output_file)
@@ -428,6 +435,8 @@ def cancel(context: ServerContext, pool: DevicePoolExecutor):

def ready(context: ServerContext, pool: DevicePoolExecutor):
    output_file = request.args.get("output", None)
+    if output_file is None:
+        return error_reply("output name is required")

    done, progress = pool.done(output_file)
@@ -436,7 +445,7 @@ def ready(context: ServerContext, pool: DevicePoolExecutor):
    if path.exists(output):
        return ready_reply(True)

-    return ready_reply(done, progress=progress)
+    return ready_reply(done or False, progress=progress)

def status(context: ServerContext, pool: DevicePoolExecutor):
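The recurring rewrite in these handlers, fetching the upload once with .get() and then checking for None, drops the second dictionary lookup and lets mypy narrow the Optional result. A standalone sketch of the pattern (the mapping and helper here are illustrative):

from typing import Mapping, Optional

def upload_size(files: Mapping[str, bytes], name: str) -> Optional[int]:
    data = files.get(name)
    if data is None:
        # Early return: this upload is missing from the request.
        return None
    # mypy narrows `data` to bytes here; the old
    # `if name not in files` form gave no such guarantee.
    return len(data)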

View File

@@ -1,6 +1,6 @@
from logging import getLogger
from os import environ, path
-from typing import List
+from typing import List, Optional

from ..utils import get_boolean
from .model_cache import ModelCache
@@ -18,13 +18,13 @@ class ServerContext:
        cors_origin: str = "*",
        num_workers: int = 1,
        any_platform: bool = True,
-        block_platforms: List[str] = None,
-        default_platform: str = None,
+        block_platforms: Optional[List[str]] = None,
+        default_platform: Optional[str] = None,
        image_format: str = "png",
-        cache: ModelCache = None,
-        cache_path: str = None,
+        cache: Optional[ModelCache] = None,
+        cache_path: Optional[str] = None,
        show_progress: bool = True,
-        optimizations: List[str] = None,
+        optimizations: Optional[List[str]] = None,
    ) -> None:
        self.bundle_path = bundle_path
        self.model_path = model_path

View File

@@ -119,9 +119,8 @@ def patch_not_impl():

def patch_cache_path(ctx: ServerContext, url: str, **kwargs) -> str:
-    if url in cache_path_map:
-        cache_path = cache_path_map.get(url)
-    else:
+    cache_path = cache_path_map.get(url, None)
+    if cache_path is None:
        parsed = urlparse(url)
        cache_path = path.basename(parsed.path)

View File

@@ -22,13 +22,13 @@ def run_txt2txt_pipeline(
    device = job.get_device()

-    model = GPTJForCausalLM.from_pretrained(model).to(device.torch_device())
+    pipe = GPTJForCausalLM.from_pretrained(model).to(device.torch_str())
    tokenizer = AutoTokenizer.from_pretrained(model)

    input_ids = tokenizer.encode(params.prompt, return_tensors="pt").to(
-        device.torch_device()
+        device.torch_str()
    )
-    results = model.generate(
+    results = pipe.generate(
        input_ids,
        do_sample=True,
        max_length=tokens,
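Renaming the loaded network to pipe also removes a shadowing bug: the old code reassigned model, so the later AutoTokenizer.from_pretrained(model) received the loaded module instead of the checkpoint name. A reduced illustration with placeholder helpers:

def load_network(name: str) -> object:
    return object()  # stand-in for an expensive model load

def load_tokenizer(name: str) -> str:
    assert isinstance(name, str), "expected a checkpoint name"
    return f"tokenizer for {name}"

model = "gpt-j-6b"
pipe = load_network(model)         # distinct name keeps the string intact
tokenizer = load_tokenizer(model)  # still sees the checkpoint name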

View File

@@ -12,6 +12,7 @@ from .chain import (
from .params import ImageParams, SizeChart, StageParams, UpscaleParams
from .server import ServerContext
from .worker import ProgressCallback, WorkerContext
+from typing import Optional

logger = getLogger(__name__)
@@ -24,7 +25,7 @@ def run_upscale_correction(
    image: Image.Image,
    *,
    upscale: UpscaleParams,
-    callback: ProgressCallback = None,
+    callback: Optional[ProgressCallback] = None,
) -> Image.Image:
    """
    This is a convenience method for a chain pipeline that will run upscaling and

View File

@@ -2,7 +2,7 @@ import gc
import threading
from logging import getLogger
from os import environ, path
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, Dict, List, Optional, Sequence, Union

import torch
@@ -36,7 +36,7 @@ def get_and_clamp_int(
    return min(max(int(args.get(key, default_value)), min_value), max_value)

-def get_from_list(args: Any, key: str, values: List[Any]) -> Optional[Any]:
+def get_from_list(args: Any, key: str, values: Sequence[Any]) -> Optional[Any]:
    selected = args.get(key, None)
    if selected in values:
        return selected
@@ -82,7 +82,7 @@ def get_size(val: Union[int, str, None]) -> Union[int, SizeChart]:
    raise ValueError("invalid size")

-def run_gc(devices: List[DeviceParams] = None):
+def run_gc(devices: Optional[List[DeviceParams]] = None):
    logger.debug(
        "running garbage collection with %s active threads", threading.active_count()
    )
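Widening values from List to Sequence lets callers pass any read-only sequence, such as a tuple of allowed options, since the function only reads from it. For example, with the signature above:

from typing import Any, Optional, Sequence

def get_from_list(args: Any, key: str, values: Sequence[Any]) -> Optional[Any]:
    selected = args.get(key, None)
    if selected in values:
        return selected
    return None

# A tuple now satisfies the annotation as well as a list:
print(get_from_list({"order": "spiral"}, "order", ("grid", "spiral")))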

View File

@@ -12,20 +12,20 @@ ProgressCallback = Callable[[int, int, Any], None]

class WorkerContext:
-    cancel: "Value[bool]" = None
-    job: str = None
-    pending: "Queue[Tuple[Callable, Any, Any]]" = None
-    progress: "Value[int]" = None
+    cancel: "Value[bool]"
+    job: str
+    pending: "Queue[Tuple[str, Callable[..., None], Any, Any]]"
+    progress: "Value[int]"

    def __init__(
        self,
        job: str,
        device: DeviceParams,
-        cancel: "Value[bool]" = None,
-        logs: "Queue[str]" = None,
-        pending: "Queue[Any]" = None,
-        progress: "Queue[Tuple[str, int]]" = None,
-        finished: "Queue[str]" = None,
+        cancel: "Value[bool]",
+        logs: "Queue[str]",
+        pending: "Queue[Tuple[str, Callable[..., None], Any, Any]]",
+        progress: "Queue[Tuple[str, str, int]]",
+        finished: "Queue[Tuple[str, str]]",
    ):
        self.job = job
        self.device = device
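Dropping the = None defaults turns these into annotation-only declarations: the attributes are typed for the checker without creating class-level values shared by every instance. A brief sketch:

class JobState:
    # Declared for the type checker; there is no class-level default,
    # so forgetting to assign in __init__ fails fast with
    # AttributeError instead of silently reading None.
    name: str
    progress: int

    def __init__(self, name: str) -> None:
        self.name = name
        self.progress = 0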

View File

@@ -2,7 +2,7 @@ from collections import Counter
from logging import getLogger
from queue import Empty
from threading import Thread
-from typing import Callable, Dict, List, Optional, Tuple
+from typing import Any, Callable, Dict, List, Optional, Tuple

from torch.multiprocessing import Process, Queue, Value
@@ -15,23 +15,38 @@ logger = getLogger(__name__)

class DevicePoolExecutor:
-    context: Dict[str, WorkerContext] = None  # Device -> Context
-    devices: List[DeviceParams] = None
-    pending: Dict[str, "Queue[WorkerContext]"] = None
-    workers: Dict[str, Process] = None
-    active_jobs: Dict[str, Tuple[str, int]] = None  # should be Dict[Device, JobStatus]
-    finished_jobs: List[Tuple[str, int, bool]] = None  # should be List[JobStatus]
+    server: ServerContext
+    devices: List[DeviceParams]
+    max_jobs_per_worker: int
+    max_pending_per_worker: int
+    join_timeout: float
+    context: Dict[str, WorkerContext]  # Device -> Context
+    pending: Dict[str, "Queue[Tuple[str, Callable[..., None], Any, Any]]"]
+    threads: Dict[str, Thread]
+    workers: Dict[str, Process]
+    active_jobs: Dict[str, Tuple[str, int]]  # should be Dict[Device, JobStatus]
+    cancelled_jobs: List[str]
+    finished_jobs: List[Tuple[str, int, bool]]  # should be List[JobStatus]
+    total_jobs: int
+    logs: "Queue"
+    progress: "Queue[Tuple[str, str, int]]"
+    finished: "Queue[Tuple[str, str]]"

    def __init__(
        self,
        server: ServerContext,
        devices: List[DeviceParams],
        max_jobs_per_worker: int = 10,
+        max_pending_per_worker: int = 100,
        join_timeout: float = 1.0,
    ):
        self.server = server
        self.devices = devices
        self.max_jobs_per_worker = max_jobs_per_worker
+        self.max_pending_per_worker = max_pending_per_worker
        self.join_timeout = join_timeout

        self.context = {}
@@ -44,9 +59,9 @@ class DevicePoolExecutor:
        self.finished_jobs = []
        self.total_jobs = 0  # TODO: turn this into a Dict per-worker

-        self.logs = Queue()
-        self.progress = Queue()
-        self.finished = Queue()
+        self.logs = Queue(self.max_pending_per_worker)
+        self.progress = Queue(self.max_pending_per_worker)
+        self.finished = Queue(self.max_pending_per_worker)

        self.create_logger_worker()
        self.create_progress_worker()
@@ -67,7 +82,7 @@ class DevicePoolExecutor:
            pending = self.pending[name]
        else:
            logger.debug("creating new pending job queue")
-            pending = Queue()
+            pending = Queue(self.max_pending_per_worker)
            self.pending[name] = pending

        context = WorkerContext(
@@ -80,7 +95,11 @@ class DevicePoolExecutor:
            pending=pending,
        )
        self.context[name] = context
-        self.workers[name] = Process(target=worker_main, args=(context, self.server))
+        self.workers[name] = Process(
+            name=f"onnx-web worker: {name}",
+            target=worker_main,
+            args=(context, self.server),
+        )

        logger.debug("starting worker for device %s", device)
        self.workers[name].start()
@@ -102,7 +121,9 @@ class DevicePoolExecutor:
            except Exception as err:
                logger.error("error in log worker: %s", err)

-        logger_thread = Thread(target=logger_worker, args=(self.logs,), daemon=True)
+        logger_thread = Thread(
+            name="onnx-web logger", target=logger_worker, args=(self.logs,), daemon=True
+        )
        self.threads["logger"] = logger_thread

        logger.debug("starting logger worker")
@@ -128,7 +149,12 @@ class DevicePoolExecutor:
            except Exception as err:
                logger.error("error in progress worker: %s", err)

-        progress_thread = Thread(target=progress_worker, args=(self.progress,), daemon=True)
+        progress_thread = Thread(
+            name="onnx-web progress",
+            target=progress_worker,
+            args=(self.progress,),
+            daemon=True,
+        )
        self.threads["progress"] = progress_thread

        logger.debug("starting progress worker")
@@ -152,7 +178,12 @@ class DevicePoolExecutor:
            except Exception as err:
                logger.error("error in finished worker: %s", err)

-        finished_thread = Thread(target=finished_worker, args=(self.finished,), daemon=True)
+        finished_thread = Thread(
+            name="onnx-web finished",
+            target=finished_worker,
+            args=(self.finished,),
+            daemon=True,
+        )
        self.threads["finished"] = finished_thread

        logger.debug("started finished worker")
@@ -221,8 +252,18 @@ class DevicePoolExecutor:
        return (False, progress)

    def join(self):
-        logger.debug("stopping worker pool")
+        logger.info("stopping worker pool")
+
+        logger.debug("closing queues")
+        self.logs.close()
+        self.finished.close()
+        self.progress.close()
+        for queue in self.pending.values():
+            queue.close()
+        self.pending.clear()
+
+        logger.debug("stopping device workers")
        for device, worker in self.workers.items():
            if worker.is_alive():
                logger.debug("stopping worker for device %s", device)
@@ -235,15 +276,6 @@ class DevicePoolExecutor:
            logger.debug("stopping worker thread: %s", name)
            thread.join(self.join_timeout)

-        logger.debug("closing queues")
-        self.logs.close()
-        self.finished.close()
-        self.progress.close()
-        for queue in self.pending.values():
-            queue.close()
-        self.pending.clear()

        logger.debug("worker pool fully joined")
def recycle(self):
@@ -292,7 +324,7 @@ class DevicePoolExecutor:
    def status(self) -> List[Tuple[str, int, bool, bool]]:
        history = [
            (name, progress, False, name in self.cancelled_jobs)
-            for name, _device, progress in self.active_jobs.items()
+            for name, (_device, progress) in self.active_jobs.items()
        ]
        history.extend(
            [
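Two themes of the commit meet in this class: every queue now takes a maxsize (max_pending_per_worker), so a stalled consumer blocks producers instead of letting memory grow without bound, and every Thread and Process gets a name that shows up in logs and process listings. A stdlib sketch of both behaviors (sizes and names are illustrative):

from queue import Full, Queue
from threading import Thread

pending = Queue(maxsize=2)  # bounded, like max_pending_per_worker
pending.put("job-1")
pending.put("job-2")
try:
    pending.put("job-3", timeout=0.1)  # full queue blocks, then raises
except Full:
    print("backpressure: pending queue is at capacity")

def drain():
    while not pending.empty():
        print("processing", pending.get())

# Named threads appear in %(threadName)s log fields and debuggers.
worker = Thread(name="example worker", target=drain, daemon=True)
worker.start()
worker.join()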

View File

@@ -4,7 +4,6 @@ from sys import exit
from traceback import format_exception

from setproctitle import setproctitle
-from torch.multiprocessing import Queue

from ..server import ServerContext, apply_patches
from ..torch_before_ort import get_available_providers

View File

@@ -8,16 +8,35 @@ skip_glob = ["*/lpw_stable_diffusion_onnx.py", "*/pipeline_onnx_stable_diffusion
[tool.mypy]
# ignore_missing_imports = true
+exclude = [
+    "onnx_web.diffusion.lpw_stable_diffusion_onnx",
+    "onnx_web.diffusion.pipeline_onnx_stable_diffusion_upscale"
+]

[[tool.mypy.overrides]]
module = [
    "basicsr.archs.rrdbnet_arch",
    "basicsr.utils.download_util",
+    "basicsr.utils",
+    "basicsr",
    "boto3",
    "codeformer",
+    "codeformer.facelib.utils.misc",
+    "codeformer.facelib.utils",
+    "codeformer.facelib",
    "diffusers",
+    "diffusers.pipelines.latent_diffusion.pipeline_latent_diffusion",
+    "diffusers.pipelines.paint_by_example",
+    "diffusers.pipelines.stable_diffusion",
+    "diffusers.pipeline_utils",
+    "diffusers.utils.logging",
+    "facexlib.utils",
+    "facexlib",
    "gfpgan",
    "onnxruntime",
-    "realesrgan"
+    "realesrgan",
+    "realesrgan.archs.srvgg_arch",
+    "safetensors",
+    "transformers"
]
ignore_missing_imports = true

View File

@@ -4,7 +4,7 @@ test_images=0
while true;
do
    curl "http://${test_host}:5000/api/txt2img?"\
-        'cfg=16.00&steps=3&scheduler=deis-multi&seed=-1&'\
+        'cfg=16.00&steps=3&scheduler=ddim&seed=-1&'\
        'prompt=an+astronaut+eating+a+hamburger&negativePrompt=&'\
        'model=stable-diffusion-onnx-v1-5&platform=any&'\
        'upscaling=upscaling-real-esrgan-x2-plus&correction=correction-codeformer&'\