feat(api): add flag to disable progress bars (#158)
This commit is contained in:
parent
400e579491
commit
b4e66ef502
|
@ -2,8 +2,8 @@ from logging import getLogger
|
|||
from typing import List, Optional
|
||||
|
||||
from PIL import Image
|
||||
from onnx_web.image import valid_image
|
||||
|
||||
from onnx_web.image import valid_image
|
||||
from onnx_web.output import save_image
|
||||
|
||||
from ..params import ImageParams, StageParams
|
||||
|
@ -34,6 +34,9 @@ def blend_mask(
|
|||
save_image(server, "last-mask.png", mask)
|
||||
save_image(server, "last-mult-mask.png", mult_mask)
|
||||
|
||||
resized = [valid_image(s, min_dims=mult_mask.size, max_dims=mult_mask.size) for s in resized]
|
||||
resized = [
|
||||
valid_image(s, min_dims=mult_mask.size, max_dims=mult_mask.size)
|
||||
for s in resized
|
||||
]
|
||||
|
||||
return Image.composite(resized[0], resized[1], mult_mask)
|
||||
|
|
|
@ -22,7 +22,6 @@ import traceback
|
|||
from logging import getLogger
|
||||
from typing import Dict, List
|
||||
|
||||
import huggingface_hub.utils.tqdm
|
||||
import torch
|
||||
from diffusers import (
|
||||
AutoencoderKL,
|
||||
|
@ -45,6 +44,7 @@ from diffusers.pipelines.latent_diffusion.pipeline_latent_diffusion import (
|
|||
from diffusers.pipelines.paint_by_example import PaintByExampleImageEncoder
|
||||
from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
|
||||
from huggingface_hub import HfApi, hf_hub_download
|
||||
from huggingface_hub.utils.tqdm import tqdm
|
||||
from transformers import (
|
||||
AutoFeatureExtractor,
|
||||
BertTokenizerFast,
|
||||
|
@ -1054,9 +1054,8 @@ def download_model(db_config: TrainingConfig, token):
|
|||
logger.debug("nothing to fetch")
|
||||
return None, None
|
||||
|
||||
mytqdm = huggingface_hub.utils.tqdm.tqdm
|
||||
out_model = None
|
||||
for repo_file in mytqdm(files_to_fetch, desc=f"Fetching {len(files_to_fetch)} files"):
|
||||
for repo_file in tqdm(files_to_fetch, desc=f"Fetching {len(files_to_fetch)} files"):
|
||||
out = hf_hub_download(
|
||||
hub_url,
|
||||
filename=repo_file,
|
||||
|
|
|
@ -8,9 +8,11 @@ from typing import Dict, List, Optional, Tuple, Union
|
|||
import requests
|
||||
import safetensors
|
||||
import torch
|
||||
from tqdm.auto import tqdm
|
||||
from huggingface_hub.utils.tqdm import tqdm
|
||||
from yaml import safe_load
|
||||
|
||||
from ..utils import ServerContext
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
|
||||
|
@ -18,7 +20,7 @@ ModelDict = Dict[str, Union[str, int]]
|
|||
LegacyModel = Tuple[str, str, Optional[bool], Optional[bool], Optional[int]]
|
||||
|
||||
|
||||
class ConversionContext:
|
||||
class ConversionContext(ServerContext):
|
||||
def __init__(
|
||||
self,
|
||||
model_path: Optional[str] = None,
|
||||
|
@ -28,10 +30,8 @@ class ConversionContext:
|
|||
opset: Optional[int] = None,
|
||||
token: Optional[str] = None,
|
||||
) -> None:
|
||||
self.model_path = model_path or environ.get(
|
||||
"ONNX_WEB_MODEL_PATH", path.join("..", "models")
|
||||
)
|
||||
self.cache_path = cache_path or path.join(self.model_path, ".cache")
|
||||
super().__init__(self, model_path=model_path, cache_path=cache_path)
|
||||
|
||||
self.half = half
|
||||
self.opset = opset
|
||||
self.token = token
|
||||
|
|
|
@ -1,7 +1,8 @@
|
|||
from typing import Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
from numpy import random
|
||||
from PIL import Image, ImageChops, ImageFilter, ImageOps
|
||||
from typing import Tuple, Union
|
||||
|
||||
from .params import Border, Point, Size
|
||||
|
||||
|
|
|
@ -9,6 +9,7 @@ from typing import Dict, List, Tuple, Union
|
|||
import numpy as np
|
||||
import torch
|
||||
import yaml
|
||||
from diffusers.utils.logging import disable_progress_bar
|
||||
from flask import Flask, jsonify, make_response, request, send_from_directory, url_for
|
||||
from flask_cors import CORS
|
||||
from jsonschema import validate
|
||||
|
@ -388,6 +389,9 @@ load_models(context)
|
|||
load_params(context)
|
||||
load_platforms(context)
|
||||
|
||||
if not context.show_progress:
|
||||
disable_progress_bar()
|
||||
|
||||
app = Flask(__name__)
|
||||
CORS(app, origins=context.cors_origin)
|
||||
|
||||
|
|
|
@ -27,6 +27,7 @@ class ServerContext:
|
|||
image_format: str = "png",
|
||||
cache: ModelCache = None,
|
||||
cache_path: str = None,
|
||||
show_progress: bool = True,
|
||||
) -> None:
|
||||
self.bundle_path = bundle_path
|
||||
self.model_path = model_path
|
||||
|
@ -40,6 +41,7 @@ class ServerContext:
|
|||
self.image_format = image_format
|
||||
self.cache = cache or ModelCache(num_workers)
|
||||
self.cache_path = cache_path or path.join(model_path, ".cache")
|
||||
self.show_progress = show_progress
|
||||
|
||||
@classmethod
|
||||
def from_environ(cls):
|
||||
|
@ -61,6 +63,7 @@ class ServerContext:
|
|||
default_platform=environ.get("ONNX_WEB_DEFAULT_PLATFORM", None),
|
||||
image_format=environ.get("ONNX_WEB_IMAGE_FORMAT", "png"),
|
||||
cache=ModelCache(limit=cache_limit),
|
||||
show_progress=get_boolean(environ, "ONNX_WEB_SHOW_PROGRESS", True),
|
||||
)
|
||||
|
||||
|
||||
|
|
|
@ -70,6 +70,9 @@ Others:
|
|||
- `ONNX_WEB_NUM_WORKERS`
|
||||
- number of background workers for image pipelines
|
||||
- this should be equal to or less than the number of available GPUs
|
||||
- `ONNX_WEB_SHOW_PROGRESS`
|
||||
- show progress bars in the logs
|
||||
- disabling this can reduce noise in server logs, especially when logging to a file
|
||||
|
||||
### Server Parameters
|
||||
|
||||
|
|
Loading…
Reference in New Issue