diff --git a/api/onnx_web/chain/blend_mask.py b/api/onnx_web/chain/blend_mask.py index 21333498..9beb2c26 100644 --- a/api/onnx_web/chain/blend_mask.py +++ b/api/onnx_web/chain/blend_mask.py @@ -2,8 +2,8 @@ from logging import getLogger from typing import List, Optional from PIL import Image -from onnx_web.image import valid_image +from onnx_web.image import valid_image from onnx_web.output import save_image from ..params import ImageParams, StageParams @@ -34,6 +34,9 @@ def blend_mask( save_image(server, "last-mask.png", mask) save_image(server, "last-mult-mask.png", mult_mask) - resized = [valid_image(s, min_dims=mult_mask.size, max_dims=mult_mask.size) for s in resized] + resized = [ + valid_image(s, min_dims=mult_mask.size, max_dims=mult_mask.size) + for s in resized + ] return Image.composite(resized[0], resized[1], mult_mask) diff --git a/api/onnx_web/convert/diffusion_original.py b/api/onnx_web/convert/diffusion_original.py index 45dd0ada..4aed40b9 100644 --- a/api/onnx_web/convert/diffusion_original.py +++ b/api/onnx_web/convert/diffusion_original.py @@ -22,7 +22,6 @@ import traceback from logging import getLogger from typing import Dict, List -import huggingface_hub.utils.tqdm import torch from diffusers import ( AutoencoderKL, @@ -45,6 +44,7 @@ from diffusers.pipelines.latent_diffusion.pipeline_latent_diffusion import ( from diffusers.pipelines.paint_by_example import PaintByExampleImageEncoder from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker from huggingface_hub import HfApi, hf_hub_download +from huggingface_hub.utils.tqdm import tqdm from transformers import ( AutoFeatureExtractor, BertTokenizerFast, @@ -1054,9 +1054,8 @@ def download_model(db_config: TrainingConfig, token): logger.debug("nothing to fetch") return None, None - mytqdm = huggingface_hub.utils.tqdm.tqdm out_model = None - for repo_file in mytqdm(files_to_fetch, desc=f"Fetching {len(files_to_fetch)} files"): + for repo_file in tqdm(files_to_fetch, desc=f"Fetching {len(files_to_fetch)} files"): out = hf_hub_download( hub_url, filename=repo_file, diff --git a/api/onnx_web/convert/utils.py b/api/onnx_web/convert/utils.py index 676e2663..0f63230d 100644 --- a/api/onnx_web/convert/utils.py +++ b/api/onnx_web/convert/utils.py @@ -8,9 +8,11 @@ from typing import Dict, List, Optional, Tuple, Union import requests import safetensors import torch -from tqdm.auto import tqdm +from huggingface_hub.utils.tqdm import tqdm from yaml import safe_load +from ..utils import ServerContext + logger = getLogger(__name__) @@ -18,7 +20,7 @@ ModelDict = Dict[str, Union[str, int]] LegacyModel = Tuple[str, str, Optional[bool], Optional[bool], Optional[int]] -class ConversionContext: +class ConversionContext(ServerContext): def __init__( self, model_path: Optional[str] = None, @@ -28,10 +30,8 @@ class ConversionContext: opset: Optional[int] = None, token: Optional[str] = None, ) -> None: - self.model_path = model_path or environ.get( - "ONNX_WEB_MODEL_PATH", path.join("..", "models") - ) - self.cache_path = cache_path or path.join(self.model_path, ".cache") + super().__init__(self, model_path=model_path, cache_path=cache_path) + self.half = half self.opset = opset self.token = token diff --git a/api/onnx_web/image.py b/api/onnx_web/image.py index be7153ea..bfe687d8 100644 --- a/api/onnx_web/image.py +++ b/api/onnx_web/image.py @@ -1,7 +1,8 @@ +from typing import Tuple, Union + import numpy as np from numpy import random from PIL import Image, ImageChops, ImageFilter, ImageOps -from typing import Tuple, Union from .params import Border, Point, Size @@ -210,4 +211,4 @@ def valid_image( # check for square - return image \ No newline at end of file + return image diff --git a/api/onnx_web/serve.py b/api/onnx_web/serve.py index 1ae14860..3e75c041 100644 --- a/api/onnx_web/serve.py +++ b/api/onnx_web/serve.py @@ -9,6 +9,7 @@ from typing import Dict, List, Tuple, Union import numpy as np import torch import yaml +from diffusers.utils.logging import disable_progress_bar from flask import Flask, jsonify, make_response, request, send_from_directory, url_for from flask_cors import CORS from jsonschema import validate @@ -388,6 +389,9 @@ load_models(context) load_params(context) load_platforms(context) +if not context.show_progress: + disable_progress_bar() + app = Flask(__name__) CORS(app, origins=context.cors_origin) diff --git a/api/onnx_web/utils.py b/api/onnx_web/utils.py index a23198e4..8c2bb24e 100644 --- a/api/onnx_web/utils.py +++ b/api/onnx_web/utils.py @@ -27,6 +27,7 @@ class ServerContext: image_format: str = "png", cache: ModelCache = None, cache_path: str = None, + show_progress: bool = True, ) -> None: self.bundle_path = bundle_path self.model_path = model_path @@ -40,6 +41,7 @@ class ServerContext: self.image_format = image_format self.cache = cache or ModelCache(num_workers) self.cache_path = cache_path or path.join(model_path, ".cache") + self.show_progress = show_progress @classmethod def from_environ(cls): @@ -61,6 +63,7 @@ class ServerContext: default_platform=environ.get("ONNX_WEB_DEFAULT_PLATFORM", None), image_format=environ.get("ONNX_WEB_IMAGE_FORMAT", "png"), cache=ModelCache(limit=cache_limit), + show_progress=get_boolean(environ, "ONNX_WEB_SHOW_PROGRESS", True), ) diff --git a/docs/server-admin.md b/docs/server-admin.md index f5c0eeae..23e4e07b 100644 --- a/docs/server-admin.md +++ b/docs/server-admin.md @@ -70,6 +70,9 @@ Others: - `ONNX_WEB_NUM_WORKERS` - number of background workers for image pipelines - this should be equal to or less than the number of available GPUs +- `ONNX_WEB_SHOW_PROGRESS` + - show progress bars in the logs + - disabling this can reduce noise in server logs, especially when logging to a file ### Server Parameters