feat(api): add flag to disable progress bars (#158)
This commit is contained in:
parent
400e579491
commit
b4e66ef502
|
@ -2,8 +2,8 @@ from logging import getLogger
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
|
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from onnx_web.image import valid_image
|
|
||||||
|
|
||||||
|
from onnx_web.image import valid_image
|
||||||
from onnx_web.output import save_image
|
from onnx_web.output import save_image
|
||||||
|
|
||||||
from ..params import ImageParams, StageParams
|
from ..params import ImageParams, StageParams
|
||||||
|
@ -34,6 +34,9 @@ def blend_mask(
|
||||||
save_image(server, "last-mask.png", mask)
|
save_image(server, "last-mask.png", mask)
|
||||||
save_image(server, "last-mult-mask.png", mult_mask)
|
save_image(server, "last-mult-mask.png", mult_mask)
|
||||||
|
|
||||||
resized = [valid_image(s, min_dims=mult_mask.size, max_dims=mult_mask.size) for s in resized]
|
resized = [
|
||||||
|
valid_image(s, min_dims=mult_mask.size, max_dims=mult_mask.size)
|
||||||
|
for s in resized
|
||||||
|
]
|
||||||
|
|
||||||
return Image.composite(resized[0], resized[1], mult_mask)
|
return Image.composite(resized[0], resized[1], mult_mask)
|
||||||
|
|
|
@ -22,7 +22,6 @@ import traceback
|
||||||
from logging import getLogger
|
from logging import getLogger
|
||||||
from typing import Dict, List
|
from typing import Dict, List
|
||||||
|
|
||||||
import huggingface_hub.utils.tqdm
|
|
||||||
import torch
|
import torch
|
||||||
from diffusers import (
|
from diffusers import (
|
||||||
AutoencoderKL,
|
AutoencoderKL,
|
||||||
|
@ -45,6 +44,7 @@ from diffusers.pipelines.latent_diffusion.pipeline_latent_diffusion import (
|
||||||
from diffusers.pipelines.paint_by_example import PaintByExampleImageEncoder
|
from diffusers.pipelines.paint_by_example import PaintByExampleImageEncoder
|
||||||
from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
|
from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
|
||||||
from huggingface_hub import HfApi, hf_hub_download
|
from huggingface_hub import HfApi, hf_hub_download
|
||||||
|
from huggingface_hub.utils.tqdm import tqdm
|
||||||
from transformers import (
|
from transformers import (
|
||||||
AutoFeatureExtractor,
|
AutoFeatureExtractor,
|
||||||
BertTokenizerFast,
|
BertTokenizerFast,
|
||||||
|
@ -1054,9 +1054,8 @@ def download_model(db_config: TrainingConfig, token):
|
||||||
logger.debug("nothing to fetch")
|
logger.debug("nothing to fetch")
|
||||||
return None, None
|
return None, None
|
||||||
|
|
||||||
mytqdm = huggingface_hub.utils.tqdm.tqdm
|
|
||||||
out_model = None
|
out_model = None
|
||||||
for repo_file in mytqdm(files_to_fetch, desc=f"Fetching {len(files_to_fetch)} files"):
|
for repo_file in tqdm(files_to_fetch, desc=f"Fetching {len(files_to_fetch)} files"):
|
||||||
out = hf_hub_download(
|
out = hf_hub_download(
|
||||||
hub_url,
|
hub_url,
|
||||||
filename=repo_file,
|
filename=repo_file,
|
||||||
|
|
|
@ -8,9 +8,11 @@ from typing import Dict, List, Optional, Tuple, Union
|
||||||
import requests
|
import requests
|
||||||
import safetensors
|
import safetensors
|
||||||
import torch
|
import torch
|
||||||
from tqdm.auto import tqdm
|
from huggingface_hub.utils.tqdm import tqdm
|
||||||
from yaml import safe_load
|
from yaml import safe_load
|
||||||
|
|
||||||
|
from ..utils import ServerContext
|
||||||
|
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@ -18,7 +20,7 @@ ModelDict = Dict[str, Union[str, int]]
|
||||||
LegacyModel = Tuple[str, str, Optional[bool], Optional[bool], Optional[int]]
|
LegacyModel = Tuple[str, str, Optional[bool], Optional[bool], Optional[int]]
|
||||||
|
|
||||||
|
|
||||||
class ConversionContext:
|
class ConversionContext(ServerContext):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
model_path: Optional[str] = None,
|
model_path: Optional[str] = None,
|
||||||
|
@ -28,10 +30,8 @@ class ConversionContext:
|
||||||
opset: Optional[int] = None,
|
opset: Optional[int] = None,
|
||||||
token: Optional[str] = None,
|
token: Optional[str] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
self.model_path = model_path or environ.get(
|
super().__init__(self, model_path=model_path, cache_path=cache_path)
|
||||||
"ONNX_WEB_MODEL_PATH", path.join("..", "models")
|
|
||||||
)
|
|
||||||
self.cache_path = cache_path or path.join(self.model_path, ".cache")
|
|
||||||
self.half = half
|
self.half = half
|
||||||
self.opset = opset
|
self.opset = opset
|
||||||
self.token = token
|
self.token = token
|
||||||
|
|
|
@ -1,7 +1,8 @@
|
||||||
|
from typing import Tuple, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from numpy import random
|
from numpy import random
|
||||||
from PIL import Image, ImageChops, ImageFilter, ImageOps
|
from PIL import Image, ImageChops, ImageFilter, ImageOps
|
||||||
from typing import Tuple, Union
|
|
||||||
|
|
||||||
from .params import Border, Point, Size
|
from .params import Border, Point, Size
|
||||||
|
|
||||||
|
|
|
@ -9,6 +9,7 @@ from typing import Dict, List, Tuple, Union
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
import yaml
|
import yaml
|
||||||
|
from diffusers.utils.logging import disable_progress_bar
|
||||||
from flask import Flask, jsonify, make_response, request, send_from_directory, url_for
|
from flask import Flask, jsonify, make_response, request, send_from_directory, url_for
|
||||||
from flask_cors import CORS
|
from flask_cors import CORS
|
||||||
from jsonschema import validate
|
from jsonschema import validate
|
||||||
|
@ -388,6 +389,9 @@ load_models(context)
|
||||||
load_params(context)
|
load_params(context)
|
||||||
load_platforms(context)
|
load_platforms(context)
|
||||||
|
|
||||||
|
if not context.show_progress:
|
||||||
|
disable_progress_bar()
|
||||||
|
|
||||||
app = Flask(__name__)
|
app = Flask(__name__)
|
||||||
CORS(app, origins=context.cors_origin)
|
CORS(app, origins=context.cors_origin)
|
||||||
|
|
||||||
|
|
|
@ -27,6 +27,7 @@ class ServerContext:
|
||||||
image_format: str = "png",
|
image_format: str = "png",
|
||||||
cache: ModelCache = None,
|
cache: ModelCache = None,
|
||||||
cache_path: str = None,
|
cache_path: str = None,
|
||||||
|
show_progress: bool = True,
|
||||||
) -> None:
|
) -> None:
|
||||||
self.bundle_path = bundle_path
|
self.bundle_path = bundle_path
|
||||||
self.model_path = model_path
|
self.model_path = model_path
|
||||||
|
@ -40,6 +41,7 @@ class ServerContext:
|
||||||
self.image_format = image_format
|
self.image_format = image_format
|
||||||
self.cache = cache or ModelCache(num_workers)
|
self.cache = cache or ModelCache(num_workers)
|
||||||
self.cache_path = cache_path or path.join(model_path, ".cache")
|
self.cache_path = cache_path or path.join(model_path, ".cache")
|
||||||
|
self.show_progress = show_progress
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_environ(cls):
|
def from_environ(cls):
|
||||||
|
@ -61,6 +63,7 @@ class ServerContext:
|
||||||
default_platform=environ.get("ONNX_WEB_DEFAULT_PLATFORM", None),
|
default_platform=environ.get("ONNX_WEB_DEFAULT_PLATFORM", None),
|
||||||
image_format=environ.get("ONNX_WEB_IMAGE_FORMAT", "png"),
|
image_format=environ.get("ONNX_WEB_IMAGE_FORMAT", "png"),
|
||||||
cache=ModelCache(limit=cache_limit),
|
cache=ModelCache(limit=cache_limit),
|
||||||
|
show_progress=get_boolean(environ, "ONNX_WEB_SHOW_PROGRESS", True),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -70,6 +70,9 @@ Others:
|
||||||
- `ONNX_WEB_NUM_WORKERS`
|
- `ONNX_WEB_NUM_WORKERS`
|
||||||
- number of background workers for image pipelines
|
- number of background workers for image pipelines
|
||||||
- this should be equal to or less than the number of available GPUs
|
- this should be equal to or less than the number of available GPUs
|
||||||
|
- `ONNX_WEB_SHOW_PROGRESS`
|
||||||
|
- show progress bars in the logs
|
||||||
|
- disabling this can reduce noise in server logs, especially when logging to a file
|
||||||
|
|
||||||
### Server Parameters
|
### Server Parameters
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue