from os import environ, path
from time import time
from struct import pack
from typing import Any, Dict, List, Optional, Tuple, Union
from hashlib import sha256

from .params import (
    ImageParams,
    Param,
    Size,
)


class ServerContext:
    """
    Server-wide settings: filesystem locations, allowed CORS origin(s),
    worker count, and a list of blocked platform names.
    """

    def __init__(
        self,
        bundle_path: str = '.',
        model_path: str = '.',
        output_path: str = '.',
        params_path: str = '.',
        cors_origin: Union[str, List[str]] = '*',
        num_workers: int = 1,
        block_platforms: Optional[List[str]] = None,
    ) -> None:
        """
        Store the given settings on the instance.

        `cors_origin` may be a single origin string or a list of origins
        (`from_environ` passes a list). `block_platforms` defaults to a new,
        empty list for every instance.
        """
        self.bundle_path = bundle_path
        self.model_path = model_path
        self.output_path = output_path
        self.params_path = params_path
        self.cors_origin = cors_origin
        self.num_workers = num_workers
        # never share one default list between instances (mutable-default pitfall)
        self.block_platforms = [] if block_platforms is None else block_platforms

    @classmethod
    def from_environ(cls) -> 'ServerContext':
        """
        Build a context from ONNX_WEB_* environment variables.

        Paths default to locations relative to the working directory.
        ONNX_WEB_CORS_ORIGIN and ONNX_WEB_BLOCK_PLATFORMS are comma-separated
        lists. Blank platform entries are dropped, so an unset variable yields
        an empty list rather than `['']`.
        Raises ValueError if ONNX_WEB_NUM_WORKERS is not an integer.
        """
        block_platforms = [
            platform.strip()
            for platform in environ.get('ONNX_WEB_BLOCK_PLATFORMS', '').split(',')
            if platform.strip()
        ]

        # use cls so subclasses get instances of their own type
        return cls(
            bundle_path=environ.get(
                'ONNX_WEB_BUNDLE_PATH', path.join('..', 'gui', 'out')),
            model_path=environ.get(
                'ONNX_WEB_MODEL_PATH', path.join('..', 'models')),
            output_path=environ.get(
                'ONNX_WEB_OUTPUT_PATH', path.join('..', 'outputs')),
            params_path=environ.get('ONNX_WEB_PARAMS_PATH', '.'),
            # others
            cors_origin=environ.get('ONNX_WEB_CORS_ORIGIN', '*').split(','),
            num_workers=int(environ.get('ONNX_WEB_NUM_WORKERS', 1)),
            block_platforms=block_platforms,
        )


def is_debug() -> bool:
    """Return True when the DEBUG environment variable is set (to any value, even empty)."""
    return environ.get('DEBUG') is not None


def get_and_clamp_float(args: Any, key: str, default_value: float, max_value: float, min_value=0.0) -> float:
    """
    Read `args[key]` (or `default_value`) as a float, clamped to [min_value, max_value].

    Note the argument order: max_value comes before min_value.
    Raises ValueError if the value cannot be converted with float().
    """
    return min(max(float(args.get(key, default_value)), min_value), max_value)


def get_and_clamp_int(args: Any, key: str, default_value: int, max_value: int, min_value=1) -> int:
    """
    Read `args[key]` (or `default_value`) as an int, clamped to [min_value, max_value].

    Note the argument order: max_value comes before min_value.
    Raises ValueError if the value cannot be converted with int() (e.g. '1.5').
    """
    return min(max(int(args.get(key, default_value)), min_value), max_value)


def get_from_list(args: Any, key: str, values: List[Any]) -> Optional[Any]:
    """
    Return `args[key]` if it is one of `values`.

    Otherwise print a warning and fall back to the first entry of `values`,
    or None when `values` is empty.
    """
    selected = args.get(key, None)
    if selected in values:
        return selected

    print('invalid selection: %s' % (selected))
    if values:
        return values[0]

    return None


def get_from_map(args: Any, key: str, values: Dict[str, Any], default: Any) -> Any:
    """
    Map `args[key]` (or `default`) through `values` and return the result.

    An unknown selection falls back to `values[default]`.
    Raises KeyError if `default` itself is not a key of `values`.
    """
    selected = args.get(key, default)
    if selected in values:
        return values[selected]

    return values[default]


def get_not_empty(args: Any, key: str, default: Any) -> Any:
    """
    Return `args[key]`, or `default` when the key is missing, None, or zero-length.

    Assumes a present value supports len() (e.g. str or list); other types
    raise TypeError.
    """
    val = args.get(key, default)
    if val is None or len(val) == 0:
        val = default

    return val


def hash_value(sha, param: Param):
    """
    Feed one parameter into the running hash object `sha` (e.g. a hashlib sha256).

    None is skipped. Floats are packed as 4-byte big-endian single precision,
    so values that differ only beyond single precision hash identically.
    Ints in the unsigned 32-bit range are packed as 4 big-endian bytes. Any
    other int uses minimal signed big-endian bytes, since '!I' would raise
    struct.error. Strings are hashed as UTF-8. Unsupported types are reported
    with print() and skipped.
    """
    if param is None:
        return
    elif isinstance(param, float):
        sha.update(bytearray(pack('!f', param)))
    elif isinstance(param, int):
        if 0 <= param <= 0xFFFFFFFF:
            # unchanged encoding, so existing hashes stay the same
            sha.update(bytearray(pack('!I', param)))
        else:
            # negative or >32-bit ints (e.g. seed -1): +8 leaves room for the sign bit
            length = (param.bit_length() + 8) // 8
            sha.update(param.to_bytes(length, 'big', signed=True))
    elif isinstance(param, str):
        sha.update(param.encode('utf-8'))
    else:
        print('cannot hash param: %s, %s' % (param, type(param)))


def make_output_name(
    mode: str,
    params: ImageParams,
    size: Size,
    extras: Optional[Tuple[Param, ...]] = None
) -> str:
    """
    Build an output filename of the form '<mode>_<seed>_<sha256 hex>_<unix seconds>.png'.

    The digest covers the mode, the listed image parameters, the size, and
    any `extras`, fed in a fixed order. The trailing timestamp is
    int(time()), so repeated identical requests still get distinct names
    unless they fall in the same second.
    """
    now = int(time())
    sha = sha256()

    hash_value(sha, mode)
    hash_value(sha, params.model)
    hash_value(sha, params.provider)
    hash_value(sha, params.scheduler.__name__)
    hash_value(sha, params.prompt)
    hash_value(sha, params.negative_prompt)
    hash_value(sha, params.cfg)
    hash_value(sha, params.steps)
    hash_value(sha, params.seed)
    hash_value(sha, size.width)
    hash_value(sha, size.height)

    if extras is not None:
        for param in extras:
            hash_value(sha, param)

    return '%s_%s_%s_%s.png' % (mode, params.seed, sha.hexdigest(), now)


def base_join(base: str, tail: str) -> str:
    """
    Join `tail` onto `base` without letting `tail` climb above `base`.

    `tail` is anchored at '/', normalized (collapsing any '..' segments at the
    root), and then made relative again. '../x' and '/x' therefore both
    become '<base>/x'.
    NOTE(review): on Windows, a drive-qualified `tail` may make relpath raise
    ValueError; confirm whether that case can reach this function.
    """
    tail_path = path.relpath(path.normpath(path.join('/', tail)), '/')
    return path.join(base, tail_path)