name threads, max queues, type/lint fixes
This commit is contained in:
parent
c95ac1fbdd
commit
c99aa67220
|
@ -39,4 +39,4 @@ lint-fix:
|
|||
flake8 onnx_web
|
||||
|
||||
typecheck:
|
||||
mypy -m onnx_web.serve
|
||||
mypy onnx_web
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
version: 1
|
||||
formatters:
|
||||
simple:
|
||||
format: '[%(asctime)s] %(levelname)s: %(name)s: %(message)s'
|
||||
format: '[%(asctime)s] %(levelname)s: %(processName)s %(threadName)s %(name)s: %(message)s'
|
||||
handlers:
|
||||
console:
|
||||
class: logging.StreamHandler
|
||||
|
|
|
@ -62,7 +62,7 @@ class ChainPipeline:
|
|||
|
||||
def __init__(
|
||||
self,
|
||||
stages: List[PipelineStage] = None,
|
||||
stages: Optional[List[PipelineStage]] = None,
|
||||
):
|
||||
"""
|
||||
Create a new pipeline that will run the given stages.
|
||||
|
@ -82,7 +82,7 @@ class ChainPipeline:
|
|||
server: ServerContext,
|
||||
params: ImageParams,
|
||||
source: Image.Image,
|
||||
callback: ProgressCallback = None,
|
||||
callback: Optional[ProgressCallback] = None,
|
||||
**pipeline_kwargs
|
||||
) -> Image.Image:
|
||||
"""
|
||||
|
|
|
@ -9,6 +9,7 @@ from ..diffusion.load import load_pipeline
|
|||
from ..params import ImageParams, StageParams
|
||||
from ..server import ServerContext
|
||||
from ..worker import ProgressCallback, WorkerContext
|
||||
from typing import Optional
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
|
@ -20,7 +21,7 @@ def blend_img2img(
|
|||
params: ImageParams,
|
||||
source: Image.Image,
|
||||
*,
|
||||
callback: ProgressCallback = None,
|
||||
callback: Optional[ProgressCallback] = None,
|
||||
stage_source: Image.Image,
|
||||
**kwargs,
|
||||
) -> Image.Image:
|
||||
|
|
|
@ -31,7 +31,7 @@ def blend_inpaint(
|
|||
fill_color: str = "white",
|
||||
mask_filter: Callable = mask_filter_none,
|
||||
noise_source: Callable = noise_source_histogram,
|
||||
callback: ProgressCallback = None,
|
||||
callback: Optional[ProgressCallback] = None,
|
||||
**kwargs,
|
||||
) -> Image.Image:
|
||||
params = params.with_args(**kwargs)
|
||||
|
|
|
@ -21,7 +21,7 @@ def blend_mask(
|
|||
*,
|
||||
sources: Optional[List[Image.Image]] = None,
|
||||
stage_mask: Optional[Image.Image] = None,
|
||||
_callback: ProgressCallback = None,
|
||||
_callback: Optional[ProgressCallback] = None,
|
||||
**kwargs,
|
||||
) -> Image.Image:
|
||||
logger.info("blending image using mask")
|
||||
|
|
|
@ -5,6 +5,7 @@ from PIL import Image
|
|||
from ..params import ImageParams, StageParams, UpscaleParams
|
||||
from ..server import ServerContext
|
||||
from ..worker import WorkerContext
|
||||
from typing import Optional
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
|
@ -18,7 +19,7 @@ def correct_codeformer(
|
|||
_params: ImageParams,
|
||||
source: Image.Image,
|
||||
*,
|
||||
stage_source: Image.Image = None,
|
||||
stage_source: Optional[Image.Image] = None,
|
||||
upscale: UpscaleParams,
|
||||
**kwargs,
|
||||
) -> Image.Image:
|
||||
|
|
|
@ -7,6 +7,7 @@ from PIL import Image
|
|||
from ..params import ImageParams, StageParams
|
||||
from ..server import ServerContext
|
||||
from ..worker import WorkerContext
|
||||
from typing import Optional
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
|
@ -20,9 +21,9 @@ def persist_s3(
|
|||
*,
|
||||
output: str,
|
||||
bucket: str,
|
||||
endpoint_url: str = None,
|
||||
profile_name: str = None,
|
||||
stage_source: Image.Image = None,
|
||||
endpoint_url: Optional[str] = None,
|
||||
profile_name: Optional[str] = None,
|
||||
stage_source: Optional[Image.Image] = None,
|
||||
**kwargs,
|
||||
) -> Image.Image:
|
||||
source = stage_source or source
|
||||
|
|
|
@ -5,6 +5,7 @@ from PIL import Image
|
|||
from ..params import ImageParams, Size, StageParams
|
||||
from ..server import ServerContext
|
||||
from ..worker import WorkerContext
|
||||
from typing import Optional
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
|
@ -18,7 +19,7 @@ def reduce_crop(
|
|||
*,
|
||||
origin: Size,
|
||||
size: Size,
|
||||
stage_source: Image.Image = None,
|
||||
stage_source: Optional[Image.Image] = None,
|
||||
**kwargs,
|
||||
) -> Image.Image:
|
||||
source = stage_source or source
|
||||
|
|
|
@ -9,6 +9,7 @@ from ..diffusion.load import get_latents_from_seed, load_pipeline
|
|||
from ..params import ImageParams, Size, StageParams
|
||||
from ..server import ServerContext
|
||||
from ..worker import ProgressCallback, WorkerContext
|
||||
from typing import Optional
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
|
@ -21,7 +22,7 @@ def source_txt2img(
|
|||
_source: Image.Image,
|
||||
*,
|
||||
size: Size,
|
||||
callback: ProgressCallback = None,
|
||||
callback: Optional[ProgressCallback] = None,
|
||||
**kwargs,
|
||||
) -> Image.Image:
|
||||
params = params.with_args(**kwargs)
|
||||
|
|
|
@ -31,7 +31,7 @@ def upscale_outpaint(
|
|||
fill_color: str = "white",
|
||||
mask_filter: Callable = mask_filter_none,
|
||||
noise_source: Callable = noise_source_histogram,
|
||||
callback: ProgressCallback = None,
|
||||
callback: Optional[ProgressCallback] = None,
|
||||
**kwargs,
|
||||
) -> Image.Image:
|
||||
source = stage_source or source
|
||||
|
|
|
@ -9,12 +9,10 @@ from ..params import DeviceParams, ImageParams, StageParams, UpscaleParams
|
|||
from ..server import ServerContext
|
||||
from ..utils import run_gc
|
||||
from ..worker import WorkerContext
|
||||
from typing import Optional
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
last_pipeline_instance = None
|
||||
last_pipeline_params = (None, None)
|
||||
|
||||
TAG_X4_V3 = "real-esrgan-x4-v3"
|
||||
|
||||
|
||||
|
@ -104,7 +102,7 @@ def upscale_resrgan(
|
|||
source: Image.Image,
|
||||
*,
|
||||
upscale: UpscaleParams,
|
||||
stage_source: Image.Image = None,
|
||||
stage_source: Optional[Image.Image] = None,
|
||||
**kwargs,
|
||||
) -> Image.Image:
|
||||
source = stage_source or source
|
||||
|
|
|
@ -13,6 +13,7 @@ from ..params import DeviceParams, ImageParams, StageParams, UpscaleParams
|
|||
from ..server import ServerContext
|
||||
from ..utils import run_gc
|
||||
from ..worker import ProgressCallback, WorkerContext
|
||||
from typing import Optional
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
|
@ -70,8 +71,8 @@ def upscale_stable_diffusion(
|
|||
source: Image.Image,
|
||||
*,
|
||||
upscale: UpscaleParams,
|
||||
stage_source: Image.Image = None,
|
||||
callback: ProgressCallback = None,
|
||||
stage_source: Optional[Image.Image] = None,
|
||||
callback: Optional[ProgressCallback] = None,
|
||||
**kwargs,
|
||||
) -> Image.Image:
|
||||
params = params.with_args(**kwargs)
|
||||
|
|
|
@ -110,3 +110,6 @@ def process_tile_order(
|
|||
elif order == TileOrder.spiral:
|
||||
logger.debug("using spiral tile order with tile size: %s", tile)
|
||||
return process_tile_spiral(source, tile, scale, filters, **kwargs)
|
||||
else:
|
||||
logger.warn("unknown tile order: %s", order)
|
||||
raise ValueError()
|
||||
|
|
|
@ -3,7 +3,7 @@ from argparse import ArgumentParser
|
|||
from logging import getLogger
|
||||
from os import makedirs, path
|
||||
from sys import exit
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from jsonschema import ValidationError, validate
|
||||
|
@ -36,7 +36,7 @@ warnings.filterwarnings(
|
|||
".*Converting a tensor to a Python boolean might cause the trace to be incorrect.*",
|
||||
)
|
||||
|
||||
Models = Dict[str, List[Tuple[str, str, Optional[int]]]]
|
||||
Models = Dict[str, List[Any]]
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
|
|
|
@ -23,7 +23,7 @@ LegacyModel = Tuple[str, str, Optional[bool], Optional[bool], Optional[int]]
|
|||
class ConversionContext(ServerContext):
|
||||
def __init__(
|
||||
self,
|
||||
model_path: Optional[str] = None,
|
||||
model_path: str,
|
||||
cache_path: Optional[str] = None,
|
||||
device: Optional[str] = None,
|
||||
half: Optional[bool] = False,
|
||||
|
@ -31,7 +31,7 @@ class ConversionContext(ServerContext):
|
|||
token: Optional[str] = None,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
super().__init__(self, model_path=model_path, cache_path=cache_path)
|
||||
super().__init__(model_path=model_path, cache_path=cache_path)
|
||||
|
||||
self.half = half
|
||||
self.opset = opset
|
||||
|
@ -153,7 +153,7 @@ def source_format(model: Dict) -> Optional[str]:
|
|||
return model["format"]
|
||||
|
||||
if "source" in model:
|
||||
ext = path.splitext(model["source"])
|
||||
_name, ext = path.splitext(model["source"])
|
||||
if ext in model_formats:
|
||||
return ext
|
||||
|
||||
|
@ -183,7 +183,7 @@ class Config(object):
|
|||
setattr(target, k, v)
|
||||
|
||||
|
||||
def load_yaml(file: str) -> str:
|
||||
def load_yaml(file: str) -> Config:
|
||||
with open(file, "r") as f:
|
||||
data = safe_load(f.read())
|
||||
return Config(data)
|
||||
|
|
|
@ -6,6 +6,7 @@ from diffusers.utils.logging import disable_progress_bar
|
|||
from flask import Flask
|
||||
from flask_cors import CORS
|
||||
from huggingface_hub.utils.tqdm import disable_progress_bars
|
||||
from setproctitle import setproctitle
|
||||
from torch.multiprocessing import set_start_method
|
||||
|
||||
from .server.api import register_api_routes
|
||||
|
@ -26,6 +27,7 @@ logger = getLogger(__name__)
|
|||
|
||||
|
||||
def main():
|
||||
setproctitle("onnx-web server")
|
||||
set_start_method("spawn", force=True)
|
||||
|
||||
context = ServerContext.from_environ()
|
||||
|
|
|
@ -15,7 +15,7 @@ from .utils import base_join
|
|||
logger = getLogger(__name__)
|
||||
|
||||
|
||||
def hash_value(sha, param: Param):
|
||||
def hash_value(sha, param: Optional[Param]):
|
||||
if param is None:
|
||||
return
|
||||
elif isinstance(param, bool):
|
||||
|
@ -63,7 +63,7 @@ def make_output_name(
|
|||
mode: str,
|
||||
params: ImageParams,
|
||||
size: Size,
|
||||
extras: Optional[Tuple[Param]] = None,
|
||||
extras: Optional[List[Optional[Param]]] = None,
|
||||
) -> List[str]:
|
||||
now = int(time())
|
||||
sha = sha256()
|
||||
|
|
|
@ -101,12 +101,12 @@ class DeviceParams:
|
|||
self.device = device
|
||||
self.provider = provider
|
||||
self.options = options
|
||||
self.optimizations = optimizations
|
||||
self.optimizations = optimizations or []
|
||||
|
||||
def __str__(self) -> str:
|
||||
return "%s - %s (%s)" % (self.device, self.provider, self.options)
|
||||
|
||||
def ort_provider(self) -> Tuple[str, Any]:
|
||||
def ort_provider(self) -> Union[str, Tuple[str, Any]]:
|
||||
if self.options is None:
|
||||
return self.provider
|
||||
else:
|
||||
|
|
|
@ -81,7 +81,7 @@ def introspect(context: ServerContext, app: Flask):
|
|||
return {
|
||||
"name": "onnx-web",
|
||||
"routes": [
|
||||
{"path": url_from_rule(rule), "methods": list(rule.methods).sort()}
|
||||
{"path": url_from_rule(rule), "methods": list(rule.methods or []).sort()}
|
||||
for rule in app.url_map.iter_rules()
|
||||
],
|
||||
}
|
||||
|
@ -119,10 +119,10 @@ def list_schedulers(context: ServerContext):
|
|||
|
||||
|
||||
def img2img(context: ServerContext, pool: DevicePoolExecutor):
|
||||
if "source" not in request.files:
|
||||
source_file = request.files.get("source")
|
||||
if source_file is None:
|
||||
return error_reply("source image is required")
|
||||
|
||||
source_file = request.files.get("source")
|
||||
source = Image.open(BytesIO(source_file.read())).convert("RGB")
|
||||
|
||||
device, params, size = pipeline_from_request(context)
|
||||
|
@ -136,7 +136,7 @@ def img2img(context: ServerContext, pool: DevicePoolExecutor):
|
|||
get_config_value("strength", "min"),
|
||||
)
|
||||
|
||||
output = make_output_name(context, "img2img", params, size, extras=(strength,))
|
||||
output = make_output_name(context, "img2img", params, size, extras=[strength])
|
||||
job_name = output[0]
|
||||
logger.info("img2img job queued for: %s", job_name)
|
||||
|
||||
|
@ -179,16 +179,15 @@ def txt2img(context: ServerContext, pool: DevicePoolExecutor):
|
|||
|
||||
|
||||
def inpaint(context: ServerContext, pool: DevicePoolExecutor):
|
||||
if "source" not in request.files:
|
||||
source_file = request.files.get("source")
|
||||
if source_file is None:
|
||||
return error_reply("source image is required")
|
||||
|
||||
if "mask" not in request.files:
|
||||
mask_file = request.files.get("mask")
|
||||
if mask_file is None:
|
||||
return error_reply("mask image is required")
|
||||
|
||||
source_file = request.files.get("source")
|
||||
source = Image.open(BytesIO(source_file.read())).convert("RGB")
|
||||
|
||||
mask_file = request.files.get("mask")
|
||||
mask = Image.open(BytesIO(mask_file.read())).convert("RGB")
|
||||
|
||||
device, params, size = pipeline_from_request(context)
|
||||
|
@ -207,7 +206,7 @@ def inpaint(context: ServerContext, pool: DevicePoolExecutor):
|
|||
"inpaint",
|
||||
params,
|
||||
size,
|
||||
extras=(
|
||||
extras=[
|
||||
expand.left,
|
||||
expand.right,
|
||||
expand.top,
|
||||
|
@ -216,7 +215,7 @@ def inpaint(context: ServerContext, pool: DevicePoolExecutor):
|
|||
noise_source.__name__,
|
||||
fill_color,
|
||||
tile_order,
|
||||
),
|
||||
],
|
||||
)
|
||||
job_name = output[0]
|
||||
logger.info("inpaint job queued for: %s", job_name)
|
||||
|
@ -245,10 +244,10 @@ def inpaint(context: ServerContext, pool: DevicePoolExecutor):
|
|||
|
||||
|
||||
def upscale(context: ServerContext, pool: DevicePoolExecutor):
|
||||
if "source" not in request.files:
|
||||
source_file = request.files.get("source")
|
||||
if source_file is None:
|
||||
return error_reply("source image is required")
|
||||
|
||||
source_file = request.files.get("source")
|
||||
source = Image.open(BytesIO(source_file.read())).convert("RGB")
|
||||
|
||||
device, params, size = pipeline_from_request(context)
|
||||
|
@ -324,6 +323,7 @@ def chain(context: ServerContext, pool: DevicePoolExecutor):
|
|||
stage.name,
|
||||
)
|
||||
source_file = request.files.get(stage_source_name)
|
||||
if source_file is not None:
|
||||
source = Image.open(BytesIO(source_file.read())).convert("RGB")
|
||||
source = valid_image(source, max_dims=(size.width, size.height))
|
||||
kwargs["stage_source"] = source
|
||||
|
@ -335,6 +335,7 @@ def chain(context: ServerContext, pool: DevicePoolExecutor):
|
|||
stage.name,
|
||||
)
|
||||
mask_file = request.files.get(stage_mask_name)
|
||||
if mask_file is not None:
|
||||
mask = Image.open(BytesIO(mask_file.read())).convert("RGB")
|
||||
mask = valid_image(mask, max_dims=(size.width, size.height))
|
||||
kwargs["stage_mask"] = mask
|
||||
|
@ -360,10 +361,10 @@ def chain(context: ServerContext, pool: DevicePoolExecutor):
|
|||
|
||||
|
||||
def blend(context: ServerContext, pool: DevicePoolExecutor):
|
||||
if "mask" not in request.files:
|
||||
mask_file = request.files.get("mask")
|
||||
if mask_file is None:
|
||||
return error_reply("mask image is required")
|
||||
|
||||
mask_file = request.files.get("mask")
|
||||
mask = Image.open(BytesIO(mask_file.read())).convert("RGBA")
|
||||
mask = valid_image(mask)
|
||||
|
||||
|
@ -372,6 +373,9 @@ def blend(context: ServerContext, pool: DevicePoolExecutor):
|
|||
|
||||
for i in range(max_sources):
|
||||
source_file = request.files.get("source:%s" % (i))
|
||||
if source_file is None:
|
||||
logger.warning("missing source %s", i)
|
||||
else:
|
||||
source = Image.open(BytesIO(source_file.read())).convert("RGBA")
|
||||
source = valid_image(source, mask.size, mask.size)
|
||||
sources.append(source)
|
||||
|
@ -403,10 +407,11 @@ def txt2txt(context: ServerContext, pool: DevicePoolExecutor):
|
|||
device, params, size = pipeline_from_request(context)
|
||||
|
||||
output = make_output_name(context, "txt2txt", params, size)
|
||||
logger.info("upscale job queued for: %s", output)
|
||||
job_name = output[0]
|
||||
logger.info("upscale job queued for: %s", job_name)
|
||||
|
||||
pool.submit(
|
||||
output,
|
||||
job_name,
|
||||
run_txt2txt_pipeline,
|
||||
context,
|
||||
params,
|
||||
|
@ -420,6 +425,8 @@ def txt2txt(context: ServerContext, pool: DevicePoolExecutor):
|
|||
|
||||
def cancel(context: ServerContext, pool: DevicePoolExecutor):
|
||||
output_file = request.args.get("output", None)
|
||||
if output_file is None:
|
||||
return error_reply("output name is required")
|
||||
|
||||
cancel = pool.cancel(output_file)
|
||||
|
||||
|
@ -428,6 +435,8 @@ def cancel(context: ServerContext, pool: DevicePoolExecutor):
|
|||
|
||||
def ready(context: ServerContext, pool: DevicePoolExecutor):
|
||||
output_file = request.args.get("output", None)
|
||||
if output_file is None:
|
||||
return error_reply("output name is required")
|
||||
|
||||
done, progress = pool.done(output_file)
|
||||
|
||||
|
@ -436,7 +445,7 @@ def ready(context: ServerContext, pool: DevicePoolExecutor):
|
|||
if path.exists(output):
|
||||
return ready_reply(True)
|
||||
|
||||
return ready_reply(done, progress=progress)
|
||||
return ready_reply(done or False, progress=progress)
|
||||
|
||||
|
||||
def status(context: ServerContext, pool: DevicePoolExecutor):
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
from logging import getLogger
|
||||
from os import environ, path
|
||||
from typing import List
|
||||
from typing import List, Optional
|
||||
|
||||
from ..utils import get_boolean
|
||||
from .model_cache import ModelCache
|
||||
|
@ -18,13 +18,13 @@ class ServerContext:
|
|||
cors_origin: str = "*",
|
||||
num_workers: int = 1,
|
||||
any_platform: bool = True,
|
||||
block_platforms: List[str] = None,
|
||||
default_platform: str = None,
|
||||
block_platforms: Optional[List[str]] = None,
|
||||
default_platform: Optional[str] = None,
|
||||
image_format: str = "png",
|
||||
cache: ModelCache = None,
|
||||
cache_path: str = None,
|
||||
cache: Optional[ModelCache] = None,
|
||||
cache_path: Optional[str] = None,
|
||||
show_progress: bool = True,
|
||||
optimizations: List[str] = None,
|
||||
optimizations: Optional[List[str]] = None,
|
||||
) -> None:
|
||||
self.bundle_path = bundle_path
|
||||
self.model_path = model_path
|
||||
|
|
|
@ -119,9 +119,8 @@ def patch_not_impl():
|
|||
|
||||
|
||||
def patch_cache_path(ctx: ServerContext, url: str, **kwargs) -> str:
|
||||
if url in cache_path_map:
|
||||
cache_path = cache_path_map.get(url)
|
||||
else:
|
||||
cache_path = cache_path_map.get(url, None)
|
||||
if cache_path is None:
|
||||
parsed = urlparse(url)
|
||||
cache_path = path.basename(parsed.path)
|
||||
|
||||
|
|
|
@ -22,13 +22,13 @@ def run_txt2txt_pipeline(
|
|||
|
||||
device = job.get_device()
|
||||
|
||||
model = GPTJForCausalLM.from_pretrained(model).to(device.torch_device())
|
||||
pipe = GPTJForCausalLM.from_pretrained(model).to(device.torch_str())
|
||||
tokenizer = AutoTokenizer.from_pretrained(model)
|
||||
|
||||
input_ids = tokenizer.encode(params.prompt, return_tensors="pt").to(
|
||||
device.torch_device()
|
||||
device.torch_str()
|
||||
)
|
||||
results = model.generate(
|
||||
results = pipe.generate(
|
||||
input_ids,
|
||||
do_sample=True,
|
||||
max_length=tokens,
|
||||
|
|
|
@ -12,6 +12,7 @@ from .chain import (
|
|||
from .params import ImageParams, SizeChart, StageParams, UpscaleParams
|
||||
from .server import ServerContext
|
||||
from .worker import ProgressCallback, WorkerContext
|
||||
from typing import Optional
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
|
@ -24,7 +25,7 @@ def run_upscale_correction(
|
|||
image: Image.Image,
|
||||
*,
|
||||
upscale: UpscaleParams,
|
||||
callback: ProgressCallback = None,
|
||||
callback: Optional[ProgressCallback] = None,
|
||||
) -> Image.Image:
|
||||
"""
|
||||
This is a convenience method for a chain pipeline that will run upscaling and
|
||||
|
|
|
@ -2,7 +2,7 @@ import gc
|
|||
import threading
|
||||
from logging import getLogger
|
||||
from os import environ, path
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
from typing import Any, Dict, List, Optional, Sequence, Union
|
||||
|
||||
import torch
|
||||
|
||||
|
@ -36,7 +36,7 @@ def get_and_clamp_int(
|
|||
return min(max(int(args.get(key, default_value)), min_value), max_value)
|
||||
|
||||
|
||||
def get_from_list(args: Any, key: str, values: List[Any]) -> Optional[Any]:
|
||||
def get_from_list(args: Any, key: str, values: Sequence[Any]) -> Optional[Any]:
|
||||
selected = args.get(key, None)
|
||||
if selected in values:
|
||||
return selected
|
||||
|
@ -82,7 +82,7 @@ def get_size(val: Union[int, str, None]) -> Union[int, SizeChart]:
|
|||
raise ValueError("invalid size")
|
||||
|
||||
|
||||
def run_gc(devices: List[DeviceParams] = None):
|
||||
def run_gc(devices: Optional[List[DeviceParams]] = None):
|
||||
logger.debug(
|
||||
"running garbage collection with %s active threads", threading.active_count()
|
||||
)
|
||||
|
|
|
@ -12,20 +12,20 @@ ProgressCallback = Callable[[int, int, Any], None]
|
|||
|
||||
|
||||
class WorkerContext:
|
||||
cancel: "Value[bool]" = None
|
||||
job: str = None
|
||||
pending: "Queue[Tuple[Callable, Any, Any]]" = None
|
||||
progress: "Value[int]" = None
|
||||
cancel: "Value[bool]"
|
||||
job: str
|
||||
pending: "Queue[Tuple[str, Callable[..., None], Any, Any]]"
|
||||
progress: "Value[int]"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
job: str,
|
||||
device: DeviceParams,
|
||||
cancel: "Value[bool]" = None,
|
||||
logs: "Queue[str]" = None,
|
||||
pending: "Queue[Any]" = None,
|
||||
progress: "Queue[Tuple[str, int]]" = None,
|
||||
finished: "Queue[str]" = None,
|
||||
cancel: "Value[bool]",
|
||||
logs: "Queue[str]",
|
||||
pending: "Queue[Tuple[str, Callable[..., None], Any, Any]]",
|
||||
progress: "Queue[Tuple[str, str, int]]",
|
||||
finished: "Queue[Tuple[str, str]]",
|
||||
):
|
||||
self.job = job
|
||||
self.device = device
|
||||
|
|
|
@ -2,7 +2,7 @@ from collections import Counter
|
|||
from logging import getLogger
|
||||
from queue import Empty
|
||||
from threading import Thread
|
||||
from typing import Callable, Dict, List, Optional, Tuple
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple
|
||||
|
||||
from torch.multiprocessing import Process, Queue, Value
|
||||
|
||||
|
@ -15,23 +15,38 @@ logger = getLogger(__name__)
|
|||
|
||||
|
||||
class DevicePoolExecutor:
|
||||
context: Dict[str, WorkerContext] = None # Device -> Context
|
||||
devices: List[DeviceParams] = None
|
||||
pending: Dict[str, "Queue[WorkerContext]"] = None
|
||||
workers: Dict[str, Process] = None
|
||||
active_jobs: Dict[str, Tuple[str, int]] = None # should be Dict[Device, JobStatus]
|
||||
finished_jobs: List[Tuple[str, int, bool]] = None # should be List[JobStatus]
|
||||
server: ServerContext
|
||||
devices: List[DeviceParams]
|
||||
max_jobs_per_worker: int
|
||||
max_pending_per_worker: int
|
||||
join_timeout: float
|
||||
|
||||
context: Dict[str, WorkerContext] # Device -> Context
|
||||
pending: Dict[str, "Queue[Tuple[str, Callable[..., None], Any, Any]]"]
|
||||
threads: Dict[str, Thread]
|
||||
workers: Dict[str, Process]
|
||||
|
||||
active_jobs: Dict[str, Tuple[str, int]] # should be Dict[Device, JobStatus]
|
||||
cancelled_jobs: List[str]
|
||||
finished_jobs: List[Tuple[str, int, bool]] # should be List[JobStatus]
|
||||
total_jobs: int
|
||||
|
||||
logs: "Queue"
|
||||
progress: "Queue[Tuple[str, str, int]]"
|
||||
finished: "Queue[Tuple[str, str]]"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
server: ServerContext,
|
||||
devices: List[DeviceParams],
|
||||
max_jobs_per_worker: int = 10,
|
||||
max_pending_per_worker: int = 100,
|
||||
join_timeout: float = 1.0,
|
||||
):
|
||||
self.server = server
|
||||
self.devices = devices
|
||||
self.max_jobs_per_worker = max_jobs_per_worker
|
||||
self.max_pending_per_worker = max_pending_per_worker
|
||||
self.join_timeout = join_timeout
|
||||
|
||||
self.context = {}
|
||||
|
@ -44,9 +59,9 @@ class DevicePoolExecutor:
|
|||
self.finished_jobs = []
|
||||
self.total_jobs = 0 # TODO: turn this into a Dict per-worker
|
||||
|
||||
self.logs = Queue()
|
||||
self.progress = Queue()
|
||||
self.finished = Queue()
|
||||
self.logs = Queue(self.max_pending_per_worker)
|
||||
self.progress = Queue(self.max_pending_per_worker)
|
||||
self.finished = Queue(self.max_pending_per_worker)
|
||||
|
||||
self.create_logger_worker()
|
||||
self.create_progress_worker()
|
||||
|
@ -67,7 +82,7 @@ class DevicePoolExecutor:
|
|||
pending = self.pending[name]
|
||||
else:
|
||||
logger.debug("creating new pending job queue")
|
||||
pending = Queue()
|
||||
pending = Queue(self.max_pending_per_worker)
|
||||
self.pending[name] = pending
|
||||
|
||||
context = WorkerContext(
|
||||
|
@ -80,7 +95,11 @@ class DevicePoolExecutor:
|
|||
pending=pending,
|
||||
)
|
||||
self.context[name] = context
|
||||
self.workers[name] = Process(target=worker_main, args=(context, self.server))
|
||||
self.workers[name] = Process(
|
||||
name=f"onnx-web worker: {name}",
|
||||
target=worker_main,
|
||||
args=(context, self.server),
|
||||
)
|
||||
|
||||
logger.debug("starting worker for device %s", device)
|
||||
self.workers[name].start()
|
||||
|
@ -102,7 +121,9 @@ class DevicePoolExecutor:
|
|||
except Exception as err:
|
||||
logger.error("error in log worker: %s", err)
|
||||
|
||||
logger_thread = Thread(target=logger_worker, args=(self.logs,), daemon=True)
|
||||
logger_thread = Thread(
|
||||
name="onnx-web logger", target=logger_worker, args=(self.logs,), daemon=True
|
||||
)
|
||||
self.threads["logger"] = logger_thread
|
||||
|
||||
logger.debug("starting logger worker")
|
||||
|
@ -128,7 +149,12 @@ class DevicePoolExecutor:
|
|||
except Exception as err:
|
||||
logger.error("error in progress worker: %s", err)
|
||||
|
||||
progress_thread = Thread(target=progress_worker, args=(self.progress,), daemon=True)
|
||||
progress_thread = Thread(
|
||||
name="onnx-web progress",
|
||||
target=progress_worker,
|
||||
args=(self.progress,),
|
||||
daemon=True,
|
||||
)
|
||||
self.threads["progress"] = progress_thread
|
||||
|
||||
logger.debug("starting progress worker")
|
||||
|
@ -152,7 +178,12 @@ class DevicePoolExecutor:
|
|||
except Exception as err:
|
||||
logger.error("error in finished worker: %s", err)
|
||||
|
||||
finished_thread = Thread(target=finished_worker, args=(self.finished,), daemon=True)
|
||||
finished_thread = Thread(
|
||||
name="onnx-web finished",
|
||||
target=finished_worker,
|
||||
args=(self.finished,),
|
||||
daemon=True,
|
||||
)
|
||||
self.threads["finished"] = finished_thread
|
||||
|
||||
logger.debug("started finished worker")
|
||||
|
@ -221,8 +252,18 @@ class DevicePoolExecutor:
|
|||
return (False, progress)
|
||||
|
||||
def join(self):
|
||||
logger.debug("stopping worker pool")
|
||||
logger.info("stopping worker pool")
|
||||
|
||||
logger.debug("closing queues")
|
||||
self.logs.close()
|
||||
self.finished.close()
|
||||
self.progress.close()
|
||||
for queue in self.pending.values():
|
||||
queue.close()
|
||||
|
||||
self.pending.clear()
|
||||
|
||||
logger.debug("stopping device workers")
|
||||
for device, worker in self.workers.items():
|
||||
if worker.is_alive():
|
||||
logger.debug("stopping worker for device %s", device)
|
||||
|
@ -235,15 +276,6 @@ class DevicePoolExecutor:
|
|||
logger.debug("stopping worker thread: %s", name)
|
||||
thread.join(self.join_timeout)
|
||||
|
||||
logger.debug("closing queues")
|
||||
self.logs.close()
|
||||
self.finished.close()
|
||||
self.progress.close()
|
||||
for queue in self.pending.values():
|
||||
queue.close()
|
||||
|
||||
self.pending.clear()
|
||||
|
||||
logger.debug("worker pool fully joined")
|
||||
|
||||
def recycle(self):
|
||||
|
@ -292,7 +324,7 @@ class DevicePoolExecutor:
|
|||
def status(self) -> List[Tuple[str, int, bool, bool]]:
|
||||
history = [
|
||||
(name, progress, False, name in self.cancelled_jobs)
|
||||
for name, _device, progress in self.active_jobs.items()
|
||||
for name, (_device, progress) in self.active_jobs.items()
|
||||
]
|
||||
history.extend(
|
||||
[
|
||||
|
|
|
@ -4,7 +4,6 @@ from sys import exit
|
|||
from traceback import format_exception
|
||||
|
||||
from setproctitle import setproctitle
|
||||
from torch.multiprocessing import Queue
|
||||
|
||||
from ..server import ServerContext, apply_patches
|
||||
from ..torch_before_ort import get_available_providers
|
||||
|
|
|
@ -8,16 +8,35 @@ skip_glob = ["*/lpw_stable_diffusion_onnx.py", "*/pipeline_onnx_stable_diffusion
|
|||
|
||||
[tool.mypy]
|
||||
# ignore_missing_imports = true
|
||||
exclude = [
|
||||
"onnx_web.diffusion.lpw_stable_diffusion_onnx",
|
||||
"onnx_web.diffusion.pipeline_onnx_stable_diffusion_upscale"
|
||||
]
|
||||
|
||||
[[tool.mypy.overrides]]
|
||||
module = [
|
||||
"basicsr.archs.rrdbnet_arch",
|
||||
"basicsr.utils.download_util",
|
||||
"basicsr.utils",
|
||||
"basicsr",
|
||||
"boto3",
|
||||
"codeformer",
|
||||
"codeformer.facelib.utils.misc",
|
||||
"codeformer.facelib.utils",
|
||||
"codeformer.facelib",
|
||||
"diffusers",
|
||||
"diffusers.pipelines.latent_diffusion.pipeline_latent_diffusion",
|
||||
"diffusers.pipelines.paint_by_example",
|
||||
"diffusers.pipelines.stable_diffusion",
|
||||
"diffusers.pipeline_utils",
|
||||
"diffusers.utils.logging",
|
||||
"facexlib.utils",
|
||||
"facexlib",
|
||||
"gfpgan",
|
||||
"onnxruntime",
|
||||
"realesrgan"
|
||||
"realesrgan",
|
||||
"realesrgan.archs.srvgg_arch",
|
||||
"safetensors",
|
||||
"transformers"
|
||||
]
|
||||
ignore_missing_imports = true
|
|
@ -4,7 +4,7 @@ test_images=0
|
|||
while true;
|
||||
do
|
||||
curl "http://${test_host}:5000/api/txt2img?"\
|
||||
'cfg=16.00&steps=3&scheduler=deis-multi&seed=-1&'\
|
||||
'cfg=16.00&steps=3&scheduler=ddim&seed=-1&'\
|
||||
'prompt=an+astronaut+eating+a+hamburger&negativePrompt=&'\
|
||||
'model=stable-diffusion-onnx-v1-5&platform=any&'\
|
||||
'upscaling=upscaling-real-esrgan-x2-plus&correction=correction-codeformer&'\
|
||||
|
|
Loading…
Reference in New Issue