1
0
Fork 0

feat(api): add flag to disable progress bars (#158)

This commit is contained in:
Sean Sube 2023-02-18 09:25:01 -06:00
parent 400e579491
commit b4e66ef502
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
7 changed files with 26 additions and 13 deletions

View File

@ -2,8 +2,8 @@ from logging import getLogger
from typing import List, Optional
from PIL import Image
from onnx_web.image import valid_image
from onnx_web.image import valid_image
from onnx_web.output import save_image
from ..params import ImageParams, StageParams
@ -34,6 +34,9 @@ def blend_mask(
save_image(server, "last-mask.png", mask)
save_image(server, "last-mult-mask.png", mult_mask)
resized = [valid_image(s, min_dims=mult_mask.size, max_dims=mult_mask.size) for s in resized]
resized = [
valid_image(s, min_dims=mult_mask.size, max_dims=mult_mask.size)
for s in resized
]
return Image.composite(resized[0], resized[1], mult_mask)

View File

@ -22,7 +22,6 @@ import traceback
from logging import getLogger
from typing import Dict, List
import huggingface_hub.utils.tqdm
import torch
from diffusers import (
AutoencoderKL,
@ -45,6 +44,7 @@ from diffusers.pipelines.latent_diffusion.pipeline_latent_diffusion import (
from diffusers.pipelines.paint_by_example import PaintByExampleImageEncoder
from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
from huggingface_hub import HfApi, hf_hub_download
from huggingface_hub.utils.tqdm import tqdm
from transformers import (
AutoFeatureExtractor,
BertTokenizerFast,
@ -1054,9 +1054,8 @@ def download_model(db_config: TrainingConfig, token):
logger.debug("nothing to fetch")
return None, None
mytqdm = huggingface_hub.utils.tqdm.tqdm
out_model = None
for repo_file in mytqdm(files_to_fetch, desc=f"Fetching {len(files_to_fetch)} files"):
for repo_file in tqdm(files_to_fetch, desc=f"Fetching {len(files_to_fetch)} files"):
out = hf_hub_download(
hub_url,
filename=repo_file,

View File

@ -8,9 +8,11 @@ from typing import Dict, List, Optional, Tuple, Union
import requests
import safetensors
import torch
from tqdm.auto import tqdm
from huggingface_hub.utils.tqdm import tqdm
from yaml import safe_load
from ..utils import ServerContext
logger = getLogger(__name__)
@ -18,7 +20,7 @@ ModelDict = Dict[str, Union[str, int]]
LegacyModel = Tuple[str, str, Optional[bool], Optional[bool], Optional[int]]
class ConversionContext:
class ConversionContext(ServerContext):
def __init__(
self,
model_path: Optional[str] = None,
@ -28,10 +30,8 @@ class ConversionContext:
opset: Optional[int] = None,
token: Optional[str] = None,
) -> None:
self.model_path = model_path or environ.get(
"ONNX_WEB_MODEL_PATH", path.join("..", "models")
)
self.cache_path = cache_path or path.join(self.model_path, ".cache")
super().__init__(self, model_path=model_path, cache_path=cache_path)
self.half = half
self.opset = opset
self.token = token

View File

@ -1,7 +1,8 @@
from typing import Tuple, Union
import numpy as np
from numpy import random
from PIL import Image, ImageChops, ImageFilter, ImageOps
from typing import Tuple, Union
from .params import Border, Point, Size

View File

@ -9,6 +9,7 @@ from typing import Dict, List, Tuple, Union
import numpy as np
import torch
import yaml
from diffusers.utils.logging import disable_progress_bar
from flask import Flask, jsonify, make_response, request, send_from_directory, url_for
from flask_cors import CORS
from jsonschema import validate
@ -388,6 +389,9 @@ load_models(context)
load_params(context)
load_platforms(context)
if not context.show_progress:
disable_progress_bar()
app = Flask(__name__)
CORS(app, origins=context.cors_origin)

View File

@ -27,6 +27,7 @@ class ServerContext:
image_format: str = "png",
cache: ModelCache = None,
cache_path: str = None,
show_progress: bool = True,
) -> None:
self.bundle_path = bundle_path
self.model_path = model_path
@ -40,6 +41,7 @@ class ServerContext:
self.image_format = image_format
self.cache = cache or ModelCache(num_workers)
self.cache_path = cache_path or path.join(model_path, ".cache")
self.show_progress = show_progress
@classmethod
def from_environ(cls):
@ -61,6 +63,7 @@ class ServerContext:
default_platform=environ.get("ONNX_WEB_DEFAULT_PLATFORM", None),
image_format=environ.get("ONNX_WEB_IMAGE_FORMAT", "png"),
cache=ModelCache(limit=cache_limit),
show_progress=get_boolean(environ, "ONNX_WEB_SHOW_PROGRESS", True),
)

View File

@ -70,6 +70,9 @@ Others:
- `ONNX_WEB_NUM_WORKERS`
- number of background workers for image pipelines
- this should be equal to or less than the number of available GPUs
- `ONNX_WEB_SHOW_PROGRESS`
- show progress bars in the logs
- disabling this can reduce noise in server logs, especially when logging to a file
### Server Parameters