
feat(api): add flag to disable progress bars (#158)

Author: Sean Sube
Date: 2023-02-18 09:25:01 -06:00
Parent: 400e579491
Commit: b4e66ef502
Signed by: ssube (GPG key ID: 3EED7B957D362AF1)

7 changed files with 26 additions and 13 deletions
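
Taken together, the change threads a single `show_progress` flag through the server: `ServerContext` gains the flag and reads it from `ONNX_WEB_SHOW_PROGRESS`, the conversion and training code switches to `huggingface_hub`'s toggleable `tqdm` wrapper, and the server entry point turns off the `diffusers` bars when the flag is false. A minimal sketch of the resulting wiring, assuming the `onnx_web.utils` module path implied by the hunks below:

```python
# Sketch of the new startup behavior; mirrors the hunks below.
from diffusers.utils.logging import disable_progress_bar

from onnx_web.utils import ServerContext  # module path assumed from this diff

context = ServerContext.from_environ()  # reads ONNX_WEB_SHOW_PROGRESS

if not context.show_progress:
    disable_progress_bar()  # silence diffusers' per-step tqdm bars
```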

@@ -2,8 +2,8 @@ from logging import getLogger
 from typing import List, Optional

 from PIL import Image

-from onnx_web.image import valid_image
+from onnx_web.image import valid_image
 from onnx_web.output import save_image

 from ..params import ImageParams, StageParams
@@ -34,6 +34,9 @@ def blend_mask(
     save_image(server, "last-mask.png", mask)
     save_image(server, "last-mult-mask.png", mult_mask)

-    resized = [valid_image(s, min_dims=mult_mask.size, max_dims=mult_mask.size) for s in resized]
+    resized = [
+        valid_image(s, min_dims=mult_mask.size, max_dims=mult_mask.size)
+        for s in resized
+    ]

     return Image.composite(resized[0], resized[1], mult_mask)

@@ -22,7 +22,6 @@ import traceback
 from logging import getLogger
 from typing import Dict, List

-import huggingface_hub.utils.tqdm
 import torch
 from diffusers import (
     AutoencoderKL,
@@ -45,6 +44,7 @@ from diffusers.pipelines.latent_diffusion.pipeline_latent_diffusion import (
 from diffusers.pipelines.paint_by_example import PaintByExampleImageEncoder
 from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
 from huggingface_hub import HfApi, hf_hub_download
+from huggingface_hub.utils.tqdm import tqdm
 from transformers import (
     AutoFeatureExtractor,
     BertTokenizerFast,
@@ -1054,9 +1054,8 @@ def download_model(db_config: TrainingConfig, token):
         logger.debug("nothing to fetch")
         return None, None

-    mytqdm = huggingface_hub.utils.tqdm.tqdm
     out_model = None
-    for repo_file in mytqdm(files_to_fetch, desc=f"Fetching {len(files_to_fetch)} files"):
+    for repo_file in tqdm(files_to_fetch, desc=f"Fetching {len(files_to_fetch)} files"):
         out = hf_hub_download(
             hub_url,
             filename=repo_file,
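
The download loop now imports `tqdm` from `huggingface_hub.utils.tqdm` rather than aliasing the module by hand. That wrapper is the hub's own progress-bar class, so it can be silenced globally. A hedged sketch, assuming a `huggingface_hub` release that ships the `disable_progress_bars()` switch alongside the wrapper:

```python
# Hedged sketch: huggingface_hub's tqdm wrapper honors a global switch.
from huggingface_hub.utils import disable_progress_bars
from huggingface_hub.utils.tqdm import tqdm

disable_progress_bars()  # every bar built on this wrapper stays hidden

for repo_file in tqdm(["a.bin", "b.bin"], desc="Fetching 2 files"):
    ...  # work proceeds, but no bar is drawn
```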

@@ -8,9 +8,11 @@ from typing import Dict, List, Optional, Tuple, Union
 import requests
 import safetensors
 import torch
-from tqdm.auto import tqdm
+from huggingface_hub.utils.tqdm import tqdm
 from yaml import safe_load

+from ..utils import ServerContext
+
 logger = getLogger(__name__)
@@ -18,7 +20,7 @@ ModelDict = Dict[str, Union[str, int]]
 LegacyModel = Tuple[str, str, Optional[bool], Optional[bool], Optional[int]]


-class ConversionContext:
+class ConversionContext(ServerContext):
     def __init__(
         self,
         model_path: Optional[str] = None,
@@ -28,10 +30,8 @@ class ConversionContext:
         opset: Optional[int] = None,
         token: Optional[str] = None,
     ) -> None:
-        self.model_path = model_path or environ.get(
-            "ONNX_WEB_MODEL_PATH", path.join("..", "models")
-        )
-        self.cache_path = cache_path or path.join(self.model_path, ".cache")
+        super().__init__(model_path=model_path, cache_path=cache_path)
+
         self.half = half
         self.opset = opset
         self.token = token
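
`ConversionContext` now inherits its model and cache path handling from `ServerContext` instead of reading `ONNX_WEB_MODEL_PATH` itself. A condensed sketch of the pattern, with the signatures trimmed to the relevant parameters:

```python
# Condensed sketch of the refactor; the real classes take more parameters.
from os import environ, path
from typing import Optional


class ServerContext:
    def __init__(
        self,
        model_path: Optional[str] = None,
        cache_path: Optional[str] = None,
    ) -> None:
        # single source of truth for the path defaults
        self.model_path = model_path or environ.get(
            "ONNX_WEB_MODEL_PATH", path.join("..", "models")
        )
        self.cache_path = cache_path or path.join(self.model_path, ".cache")


class ConversionContext(ServerContext):
    def __init__(
        self,
        model_path: Optional[str] = None,
        cache_path: Optional[str] = None,
        half: bool = False,
    ) -> None:
        # delegate path handling instead of duplicating the environ lookups
        super().__init__(model_path=model_path, cache_path=cache_path)
        self.half = half
```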

@@ -1,7 +1,8 @@
+from typing import Tuple, Union
+
 import numpy as np
 from numpy import random
 from PIL import Image, ImageChops, ImageFilter, ImageOps
-from typing import Tuple, Union

 from .params import Border, Point, Size
@@ -210,4 +211,4 @@ def valid_image(
     # check for square
     return image

@@ -9,6 +9,7 @@ from typing import Dict, List, Tuple, Union
 import numpy as np
 import torch
 import yaml
+from diffusers.utils.logging import disable_progress_bar
 from flask import Flask, jsonify, make_response, request, send_from_directory, url_for
 from flask_cors import CORS
 from jsonschema import validate
@@ -388,6 +389,9 @@ load_models(context)
 load_params(context)
 load_platforms(context)

+if not context.show_progress:
+    disable_progress_bar()
+
 app = Flask(__name__)
 CORS(app, origins=context.cors_origin)
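
`disable_progress_bar()` flips a module-level switch in `diffusers`, so calling it once at startup covers every pipeline created afterwards. A hedged sketch; `enable_progress_bar` is assumed to be the matching toggle in the same module:

```python
# Hedged sketch of diffusers' global progress-bar toggle.
from diffusers.utils.logging import disable_progress_bar, enable_progress_bar

disable_progress_bar()  # pipelines built after this draw no step bars
# ... construct pipelines, run inference ...
enable_progress_bar()  # restore the default if bars are wanted again
```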

@@ -27,6 +27,7 @@ class ServerContext:
         image_format: str = "png",
         cache: ModelCache = None,
         cache_path: str = None,
+        show_progress: bool = True,
     ) -> None:
         self.bundle_path = bundle_path
         self.model_path = model_path
@@ -40,6 +41,7 @@ class ServerContext:
         self.image_format = image_format
         self.cache = cache or ModelCache(num_workers)
         self.cache_path = cache_path or path.join(model_path, ".cache")
+        self.show_progress = show_progress

     @classmethod
     def from_environ(cls):
@@ -61,6 +63,7 @@ class ServerContext:
             default_platform=environ.get("ONNX_WEB_DEFAULT_PLATFORM", None),
             image_format=environ.get("ONNX_WEB_IMAGE_FORMAT", "png"),
             cache=ModelCache(limit=cache_limit),
+            show_progress=get_boolean(environ, "ONNX_WEB_SHOW_PROGRESS", True),
         )
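
`get_boolean` is referenced here but not changed in this diff; a plausible shape for it, shown only to make the example self-contained (hypothetical, not the project's verbatim helper):

```python
from os import environ


def get_boolean(args, key: str, default: bool) -> bool:
    # hypothetical stand-in for onnx_web's helper: parse a string as a bool
    return args.get(key, str(default)).lower() in ("1", "t", "true", "y", "yes")


show_progress = get_boolean(environ, "ONNX_WEB_SHOW_PROGRESS", True)
print(show_progress)  # True unless the variable is set to a falsy string
```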

@@ -70,6 +70,9 @@ Others:
 - `ONNX_WEB_NUM_WORKERS`
   - number of background workers for image pipelines
   - this should be equal to or less than the number of available GPUs
+- `ONNX_WEB_SHOW_PROGRESS`
+  - show progress bars in the logs
+  - disabling this can reduce noise in server logs, especially when logging to a file

 ### Server Parameters
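
For a service that logs to a file, the variable has to be set before the server builds its context. A hedged usage sketch; the exact strings accepted as falsy depend on `get_boolean`, which is not part of this diff:

```python
# Hedged usage sketch: turn the bars off for a file-logged deployment.
from os import environ

from onnx_web.utils import ServerContext  # import path assumed

environ["ONNX_WEB_SHOW_PROGRESS"] = "0"  # set before from_environ() runs

context = ServerContext.from_environ()
print(context.show_progress)  # expected: False, given a falsy parse of "0"
```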