diff --git a/api/onnx_web/__init__.py b/api/onnx_web/__init__.py index 39cb180f..3afedb07 100644 --- a/api/onnx_web/__init__.py +++ b/api/onnx_web/__init__.py @@ -36,14 +36,14 @@ from .server import ( DeviceParams, DevicePoolExecutor, ModelCache, + ServerContext, apply_patch_basicsr, apply_patch_codeformer, apply_patch_facexlib, apply_patches, - run_upscale_correction, ) +from .upscale import run_upscale_correction from .utils import ( - ServerContext, base_join, get_and_clamp_float, get_and_clamp_int, diff --git a/api/onnx_web/chain/base.py b/api/onnx_web/chain/base.py index b144b546..702981a5 100644 --- a/api/onnx_web/chain/base.py +++ b/api/onnx_web/chain/base.py @@ -7,8 +7,8 @@ from PIL import Image from ..output import save_image from ..params import ImageParams, StageParams -from ..server.device_pool import JobContext, ProgressCallback -from ..utils import ServerContext, is_debug +from ..server import JobContext, ProgressCallback, ServerContext +from ..utils import is_debug from .utils import process_tile_order logger = getLogger(__name__) diff --git a/api/onnx_web/chain/blend_img2img.py b/api/onnx_web/chain/blend_img2img.py index 0a26650a..525f2507 100644 --- a/api/onnx_web/chain/blend_img2img.py +++ b/api/onnx_web/chain/blend_img2img.py @@ -7,8 +7,7 @@ from PIL import Image from ..diffusion.load import load_pipeline from ..params import ImageParams, StageParams -from ..server.device_pool import JobContext, ProgressCallback -from ..utils import ServerContext +from ..server import JobContext, ProgressCallback, ServerContext logger = getLogger(__name__) diff --git a/api/onnx_web/chain/blend_inpaint.py b/api/onnx_web/chain/blend_inpaint.py index 783f9fb4..1234aef7 100644 --- a/api/onnx_web/chain/blend_inpaint.py +++ b/api/onnx_web/chain/blend_inpaint.py @@ -10,8 +10,8 @@ from ..diffusion.load import get_latents_from_seed, load_pipeline from ..image import expand_image, mask_filter_none, noise_source_histogram from ..output import save_image from ..params import Border, ImageParams, Size, SizeChart, StageParams -from ..server.device_pool import JobContext, ProgressCallback -from ..utils import ServerContext, is_debug +from ..server import JobContext, ProgressCallback, ServerContext +from ..utils import is_debug from .utils import process_tile_order logger = getLogger(__name__) diff --git a/api/onnx_web/chain/blend_mask.py b/api/onnx_web/chain/blend_mask.py index 521e5379..1fbe80ef 100644 --- a/api/onnx_web/chain/blend_mask.py +++ b/api/onnx_web/chain/blend_mask.py @@ -7,8 +7,8 @@ from onnx_web.image import valid_image from onnx_web.output import save_image from ..params import ImageParams, StageParams -from ..server.device_pool import JobContext, ProgressCallback -from ..utils import ServerContext, is_debug +from ..server import JobContext, ProgressCallback, ServerContext +from ..utils import is_debug logger = getLogger(__name__) diff --git a/api/onnx_web/chain/correct_codeformer.py b/api/onnx_web/chain/correct_codeformer.py index ef157162..f6b26203 100644 --- a/api/onnx_web/chain/correct_codeformer.py +++ b/api/onnx_web/chain/correct_codeformer.py @@ -3,8 +3,7 @@ from logging import getLogger from PIL import Image from ..params import ImageParams, StageParams, UpscaleParams -from ..server.device_pool import JobContext -from ..utils import ServerContext +from ..server import JobContext, ServerContext logger = getLogger(__name__) diff --git a/api/onnx_web/chain/correct_gfpgan.py b/api/onnx_web/chain/correct_gfpgan.py index ab2de54b..c9efc96a 100644 --- a/api/onnx_web/chain/correct_gfpgan.py +++ b/api/onnx_web/chain/correct_gfpgan.py @@ -5,8 +5,8 @@ import numpy as np from PIL import Image from ..params import DeviceParams, ImageParams, StageParams, UpscaleParams -from ..server.device_pool import JobContext -from ..utils import ServerContext, run_gc +from ..server import JobContext, ServerContext +from ..utils import run_gc logger = getLogger(__name__) diff --git a/api/onnx_web/chain/persist_disk.py b/api/onnx_web/chain/persist_disk.py index 28986a98..a8162607 100644 --- a/api/onnx_web/chain/persist_disk.py +++ b/api/onnx_web/chain/persist_disk.py @@ -4,8 +4,7 @@ from PIL import Image from ..output import save_image from ..params import ImageParams, StageParams -from ..server.device_pool import JobContext -from ..utils import ServerContext +from ..server import JobContext, ServerContext logger = getLogger(__name__) diff --git a/api/onnx_web/chain/persist_s3.py b/api/onnx_web/chain/persist_s3.py index 8b37eb69..c3a889fa 100644 --- a/api/onnx_web/chain/persist_s3.py +++ b/api/onnx_web/chain/persist_s3.py @@ -5,8 +5,7 @@ from boto3 import Session from PIL import Image from ..params import ImageParams, StageParams -from ..server.device_pool import JobContext -from ..utils import ServerContext +from ..server import JobContext, ServerContext logger = getLogger(__name__) diff --git a/api/onnx_web/chain/reduce_crop.py b/api/onnx_web/chain/reduce_crop.py index efd8367f..43debc83 100644 --- a/api/onnx_web/chain/reduce_crop.py +++ b/api/onnx_web/chain/reduce_crop.py @@ -3,8 +3,7 @@ from logging import getLogger from PIL import Image from ..params import ImageParams, Size, StageParams -from ..server.device_pool import JobContext -from ..utils import ServerContext +from ..server import JobContext, ServerContext logger = getLogger(__name__) diff --git a/api/onnx_web/chain/reduce_thumbnail.py b/api/onnx_web/chain/reduce_thumbnail.py index 2633309f..c5d143b5 100644 --- a/api/onnx_web/chain/reduce_thumbnail.py +++ b/api/onnx_web/chain/reduce_thumbnail.py @@ -3,8 +3,7 @@ from logging import getLogger from PIL import Image from ..params import ImageParams, Size, StageParams -from ..server.device_pool import JobContext -from ..utils import ServerContext +from ..server import JobContext, ServerContext logger = getLogger(__name__) diff --git a/api/onnx_web/chain/source_noise.py b/api/onnx_web/chain/source_noise.py index 59c4921e..8135de81 100644 --- a/api/onnx_web/chain/source_noise.py +++ b/api/onnx_web/chain/source_noise.py @@ -4,8 +4,7 @@ from typing import Callable from PIL import Image from ..params import ImageParams, Size, StageParams -from ..server.device_pool import JobContext -from ..utils import ServerContext +from ..server import JobContext, ServerContext logger = getLogger(__name__) diff --git a/api/onnx_web/chain/source_txt2img.py b/api/onnx_web/chain/source_txt2img.py index 56be4686..b5e0dd1c 100644 --- a/api/onnx_web/chain/source_txt2img.py +++ b/api/onnx_web/chain/source_txt2img.py @@ -7,8 +7,7 @@ from PIL import Image from ..diffusion.load import get_latents_from_seed, load_pipeline from ..params import ImageParams, Size, StageParams -from ..server.device_pool import JobContext, ProgressCallback -from ..utils import ServerContext +from ..server import JobContext, ProgressCallback, ServerContext logger = getLogger(__name__) diff --git a/api/onnx_web/chain/upscale_outpaint.py b/api/onnx_web/chain/upscale_outpaint.py index 875b9052..14a20cf3 100644 --- a/api/onnx_web/chain/upscale_outpaint.py +++ b/api/onnx_web/chain/upscale_outpaint.py @@ -10,8 +10,8 @@ from ..diffusion.load import get_latents_from_seed, get_tile_latents, load_pipel from ..image import expand_image, mask_filter_none, noise_source_histogram from ..output import save_image from ..params import Border, ImageParams, Size, SizeChart, StageParams -from ..server.device_pool import JobContext, ProgressCallback -from ..utils import ServerContext, is_debug +from ..server import JobContext, ProgressCallback, ServerContext +from ..utils import is_debug from .utils import process_tile_grid, process_tile_order logger = getLogger(__name__) diff --git a/api/onnx_web/chain/upscale_resrgan.py b/api/onnx_web/chain/upscale_resrgan.py index b7ee8569..a57a3100 100644 --- a/api/onnx_web/chain/upscale_resrgan.py +++ b/api/onnx_web/chain/upscale_resrgan.py @@ -6,8 +6,8 @@ from PIL import Image from ..onnx import OnnxNet from ..params import DeviceParams, ImageParams, StageParams, UpscaleParams -from ..server.device_pool import JobContext -from ..utils import ServerContext, run_gc +from ..server import JobContext, ServerContext +from ..utils import run_gc logger = getLogger(__name__) diff --git a/api/onnx_web/chain/upscale_stable_diffusion.py b/api/onnx_web/chain/upscale_stable_diffusion.py index 4ff0e245..851bdb86 100644 --- a/api/onnx_web/chain/upscale_stable_diffusion.py +++ b/api/onnx_web/chain/upscale_stable_diffusion.py @@ -10,8 +10,8 @@ from ..diffusion.pipeline_onnx_stable_diffusion_upscale import ( OnnxStableDiffusionUpscalePipeline, ) from ..params import DeviceParams, ImageParams, StageParams, UpscaleParams -from ..server.device_pool import JobContext, ProgressCallback -from ..utils import ServerContext, run_gc +from ..server import JobContext, ProgressCallback, ServerContext +from ..utils import run_gc logger = getLogger(__name__) diff --git a/api/onnx_web/convert/utils.py b/api/onnx_web/convert/utils.py index f7c41cb5..24850252 100644 --- a/api/onnx_web/convert/utils.py +++ b/api/onnx_web/convert/utils.py @@ -11,7 +11,7 @@ import torch from huggingface_hub.utils.tqdm import tqdm from yaml import safe_load -from ..utils import ServerContext +from ..server import ServerContext logger = getLogger(__name__) diff --git a/api/onnx_web/diffusion/load.py b/api/onnx_web/diffusion/load.py index 8516fe44..33978759 100644 --- a/api/onnx_web/diffusion/load.py +++ b/api/onnx_web/diffusion/load.py @@ -26,7 +26,8 @@ except ImportError: from .stub_scheduler import StubScheduler as DEISMultistepScheduler from ..params import DeviceParams, Size -from ..utils import ServerContext, run_gc +from ..server import ServerContext +from ..utils import run_gc logger = getLogger(__name__) diff --git a/api/onnx_web/diffusion/run.py b/api/onnx_web/diffusion/run.py index 9b9effb9..d42ff51c 100644 --- a/api/onnx_web/diffusion/run.py +++ b/api/onnx_web/diffusion/run.py @@ -11,10 +11,10 @@ from onnx_web.chain.base import ChainProgress from ..chain import upscale_outpaint from ..output import save_image, save_params -from ..params import Border, ImageParams, Size, StageParams -from ..server.device_pool import JobContext -from ..server.upscale import UpscaleParams, run_upscale_correction -from ..utils import ServerContext, run_gc +from ..params import Border, ImageParams, Size, StageParams, UpscaleParams +from ..server import JobContext, ServerContext +from ..upscale import run_upscale_correction +from ..utils import run_gc from .load import get_latents_from_seed, load_pipeline logger = getLogger(__name__) diff --git a/api/onnx_web/onnx/onnx_net.py b/api/onnx_web/onnx/onnx_net.py index bf0cf524..97d5c8b0 100644 --- a/api/onnx_web/onnx/onnx_net.py +++ b/api/onnx_web/onnx/onnx_net.py @@ -5,7 +5,7 @@ import numpy as np import torch from onnxruntime import InferenceSession, SessionOptions -from ..utils import ServerContext +from ..server import ServerContext class OnnxTensor: diff --git a/api/onnx_web/output.py b/api/onnx_web/output.py index 645b6ad3..f01b0be0 100644 --- a/api/onnx_web/output.py +++ b/api/onnx_web/output.py @@ -10,7 +10,8 @@ from PIL import Image from .diffusion.load import get_scheduler_name from .params import Border, ImageParams, Param, Size, UpscaleParams -from .utils import ServerContext, base_join +from .server import ServerContext +from .utils import base_join logger = getLogger(__name__) diff --git a/api/onnx_web/serve.py b/api/onnx_web/serve.py index d4ba6c55..3be531cf 100644 --- a/api/onnx_web/serve.py +++ b/api/onnx_web/serve.py @@ -63,10 +63,8 @@ from .params import ( TileOrder, UpscaleParams, ) -from .server.device_pool import DevicePoolExecutor -from .server.hacks import apply_patches +from .server import DevicePoolExecutor, ServerContext, apply_patches from .utils import ( - ServerContext, base_join, get_and_clamp_float, get_and_clamp_int, diff --git a/api/onnx_web/server/__init__.py b/api/onnx_web/server/__init__.py index 81c49cbf..0403746c 100644 --- a/api/onnx_web/server/__init__.py +++ b/api/onnx_web/server/__init__.py @@ -1,4 +1,10 @@ -from .device_pool import DeviceParams, DevicePoolExecutor +from .device_pool import ( + DeviceParams, + DevicePoolExecutor, + Job, + JobContext, + ProgressCallback, +) from .hacks import ( apply_patch_basicsr, apply_patch_codeformer, @@ -6,4 +12,4 @@ from .hacks import ( apply_patches, ) from .model_cache import ModelCache -from .upscale import run_upscale_correction +from .context import ServerContext diff --git a/api/onnx_web/server/context.py b/api/onnx_web/server/context.py new file mode 100644 index 00000000..a3e60f78 --- /dev/null +++ b/api/onnx_web/server/context.py @@ -0,0 +1,66 @@ +from logging import getLogger +from os import environ, path +from typing import List + +from ..utils import get_boolean +from .model_cache import ModelCache + +logger = getLogger(__name__) + + +class ServerContext: + def __init__( + self, + bundle_path: str = ".", + model_path: str = ".", + output_path: str = ".", + params_path: str = ".", + cors_origin: str = "*", + num_workers: int = 1, + any_platform: bool = True, + block_platforms: List[str] = [], + default_platform: str = None, + image_format: str = "png", + cache: ModelCache = None, + cache_path: str = None, + show_progress: bool = True, + optimizations: List[str] = [], + ) -> None: + self.bundle_path = bundle_path + self.model_path = model_path + self.output_path = output_path + self.params_path = params_path + self.cors_origin = cors_origin + self.num_workers = num_workers + self.any_platform = any_platform + self.block_platforms = block_platforms + self.default_platform = default_platform + self.image_format = image_format + self.cache = cache or ModelCache(num_workers) + self.cache_path = cache_path or path.join(model_path, ".cache") + self.show_progress = show_progress + self.optimizations = optimizations + + @classmethod + def from_environ(cls): + num_workers = int(environ.get("ONNX_WEB_NUM_WORKERS", 1)) + cache_limit = int(environ.get("ONNX_WEB_CACHE_MODELS", num_workers + 2)) + + return cls( + bundle_path=environ.get( + "ONNX_WEB_BUNDLE_PATH", path.join("..", "gui", "out") + ), + model_path=environ.get("ONNX_WEB_MODEL_PATH", path.join("..", "models")), + output_path=environ.get("ONNX_WEB_OUTPUT_PATH", path.join("..", "outputs")), + params_path=environ.get("ONNX_WEB_PARAMS_PATH", "."), + # others + cors_origin=environ.get("ONNX_WEB_CORS_ORIGIN", "*").split(","), + num_workers=num_workers, + any_platform=get_boolean(environ, "ONNX_WEB_ANY_PLATFORM", True), + block_platforms=environ.get("ONNX_WEB_BLOCK_PLATFORMS", "").split(","), + default_platform=environ.get("ONNX_WEB_DEFAULT_PLATFORM", None), + image_format=environ.get("ONNX_WEB_IMAGE_FORMAT", "png"), + cache=ModelCache(limit=cache_limit), + show_progress=get_boolean(environ, "ONNX_WEB_SHOW_PROGRESS", True), + optimizations=environ.get("ONNX_WEB_OPTIMIZATIONS", "").split(","), + ) diff --git a/api/onnx_web/server/hacks.py b/api/onnx_web/server/hacks.py index 0835b5ea..4b7ac0b9 100644 --- a/api/onnx_web/server/hacks.py +++ b/api/onnx_web/server/hacks.py @@ -8,7 +8,7 @@ import basicsr.utils.download_util import codeformer.facelib.utils.misc import facexlib.utils -from ..utils import ServerContext +from .context import ServerContext logger = getLogger(__name__) diff --git a/api/onnx_web/server/upscale.py b/api/onnx_web/upscale.py similarity index 93% rename from api/onnx_web/server/upscale.py rename to api/onnx_web/upscale.py index 725ae7dc..c04d6efb 100644 --- a/api/onnx_web/server/upscale.py +++ b/api/onnx_web/upscale.py @@ -2,16 +2,15 @@ from logging import getLogger from PIL import Image -from ..chain import ( +from .chain import ( ChainPipeline, correct_codeformer, correct_gfpgan, upscale_resrgan, upscale_stable_diffusion, ) -from ..params import ImageParams, SizeChart, StageParams, UpscaleParams -from ..utils import ServerContext -from .device_pool import JobContext, ProgressCallback +from .params import ImageParams, SizeChart, StageParams, UpscaleParams +from .server import JobContext, ProgressCallback, ServerContext logger = getLogger(__name__) diff --git a/api/onnx_web/utils.py b/api/onnx_web/utils.py index 20d9d32c..1d6464b1 100644 --- a/api/onnx_web/utils.py +++ b/api/onnx_web/utils.py @@ -7,69 +7,10 @@ from typing import Any, Dict, List, Optional, Union import torch from .params import DeviceParams, SizeChart -from .server.model_cache import ModelCache logger = getLogger(__name__) -class ServerContext: - def __init__( - self, - bundle_path: str = ".", - model_path: str = ".", - output_path: str = ".", - params_path: str = ".", - cors_origin: str = "*", - num_workers: int = 1, - any_platform: bool = True, - block_platforms: List[str] = [], - default_platform: str = None, - image_format: str = "png", - cache: ModelCache = None, - cache_path: str = None, - show_progress: bool = True, - optimizations: List[str] = [], - ) -> None: - self.bundle_path = bundle_path - self.model_path = model_path - self.output_path = output_path - self.params_path = params_path - self.cors_origin = cors_origin - self.num_workers = num_workers - self.any_platform = any_platform - self.block_platforms = block_platforms - self.default_platform = default_platform - self.image_format = image_format - self.cache = cache or ModelCache(num_workers) - self.cache_path = cache_path or path.join(model_path, ".cache") - self.show_progress = show_progress - self.optimizations = optimizations - - @classmethod - def from_environ(cls): - num_workers = int(environ.get("ONNX_WEB_NUM_WORKERS", 1)) - cache_limit = int(environ.get("ONNX_WEB_CACHE_MODELS", num_workers + 2)) - - return cls( - bundle_path=environ.get( - "ONNX_WEB_BUNDLE_PATH", path.join("..", "gui", "out") - ), - model_path=environ.get("ONNX_WEB_MODEL_PATH", path.join("..", "models")), - output_path=environ.get("ONNX_WEB_OUTPUT_PATH", path.join("..", "outputs")), - params_path=environ.get("ONNX_WEB_PARAMS_PATH", "."), - # others - cors_origin=environ.get("ONNX_WEB_CORS_ORIGIN", "*").split(","), - num_workers=num_workers, - any_platform=get_boolean(environ, "ONNX_WEB_ANY_PLATFORM", True), - block_platforms=environ.get("ONNX_WEB_BLOCK_PLATFORMS", "").split(","), - default_platform=environ.get("ONNX_WEB_DEFAULT_PLATFORM", None), - image_format=environ.get("ONNX_WEB_IMAGE_FORMAT", "png"), - cache=ModelCache(limit=cache_limit), - show_progress=get_boolean(environ, "ONNX_WEB_SHOW_PROGRESS", True), - optimizations=environ.get("ONNX_WEB_OPTIMIZATIONS", "").split(","), - ) - - def base_join(base: str, tail: str) -> str: tail_path = path.relpath(path.normpath(path.join("/", tail)), "/") return path.join(base, tail_path)