diff --git a/api/onnx_web/chain/source_txt2img.py b/api/onnx_web/chain/source_txt2img.py index a6008d98..40377fe0 100644 --- a/api/onnx_web/chain/source_txt2img.py +++ b/api/onnx_web/chain/source_txt2img.py @@ -79,11 +79,11 @@ class SourceTxt2ImgStage(BaseStage): latents = get_tile_latents(latents, int(params.seed), latent_size, dims) # reseed latents as needed - reseed_rng = np.random.default_rng(params.seed) + reseed_rng = np.random.RandomState(params.seed) prompt, reseed = parse_reseed(prompt) for top, left, bottom, right, region_seed in reseed: if region_seed == -1: - region_seed = reseed_rng.integers(2**32) + region_seed = reseed_rng.random_integers(2**32 - 1) logger.debug( "reseed latent region: [:, :, %s:%s, %s:%s] with %s", diff --git a/api/onnx_web/diffusers/run.py b/api/onnx_web/diffusers/run.py index 6c6ebc8e..a18cb75d 100644 --- a/api/onnx_web/diffusers/run.py +++ b/api/onnx_web/diffusers/run.py @@ -62,7 +62,12 @@ def run_txt2img_pipeline( ) # apply upscaling and correction, before highres - stage = StageParams(tile_size=params.unet_tile) + if params.is_panorama() and server.panorama_tiles: + highres_size = tile_size * highres.scale + else: + highres_size = params.unet_tile + + stage = StageParams(tile_size=highres_size) first_upscale, after_upscale = split_upscale(upscale) if first_upscale: stage_upscale_correction( diff --git a/api/onnx_web/diffusers/utils.py b/api/onnx_web/diffusers/utils.py index d4dccc2d..6655e2bc 100644 --- a/api/onnx_web/diffusers/utils.py +++ b/api/onnx_web/diffusers/utils.py @@ -233,9 +233,6 @@ def get_tokens_from_prompt( pattern: Pattern, parser=parse_float_group, ) -> Tuple[str, List[Tuple[str, float]]]: - """ - TODO: replace with Arpeggio - """ remaining_prompt = prompt tokens = [] diff --git a/api/onnx_web/server/context.py b/api/onnx_web/server/context.py index 9ab9210c..eb7f5a2a 100644 --- a/api/onnx_web/server/context.py +++ b/api/onnx_web/server/context.py @@ -10,13 +10,38 @@ from .model_cache import ModelCache logger = getLogger(__name__) +DEFAULT_ANY_PLATFORM = True DEFAULT_CACHE_LIMIT = 5 DEFAULT_JOB_LIMIT = 10 DEFAULT_IMAGE_FORMAT = "png" DEFAULT_SERVER_VERSION = "v0.10.0" +DEFAULT_SHOW_PROGRESS = True +DEFAULT_PANORAMA_TILES = False +DEFAULT_WORKER_RETRIES = 3 class ServerContext: + bundle_path: str + model_path: str + output_path: str + params_path: str + cors_origin: str + any_platform: bool + block_platforms: List[str] + default_platform: str + image_format: str + cache_limit: int + cache_path: str + show_progress: bool + optimizations: List[str] + extra_models: List[str] + job_limit: int + memory_limit: int + admin_token: str + server_version: str + worker_retries: int + panorama_tiles: bool + def __init__( self, bundle_path: str = ".", @@ -24,20 +49,21 @@ class ServerContext: output_path: str = ".", params_path: str = ".", cors_origin: str = "*", - any_platform: bool = True, + any_platform: bool = DEFAULT_ANY_PLATFORM, block_platforms: Optional[List[str]] = None, default_platform: Optional[str] = None, image_format: str = DEFAULT_IMAGE_FORMAT, cache_limit: int = DEFAULT_CACHE_LIMIT, cache_path: Optional[str] = None, - show_progress: bool = True, + show_progress: bool = DEFAULT_SHOW_PROGRESS, optimizations: Optional[List[str]] = None, extra_models: Optional[List[str]] = None, job_limit: int = DEFAULT_JOB_LIMIT, memory_limit: Optional[int] = None, admin_token: Optional[str] = None, server_version: Optional[str] = DEFAULT_SERVER_VERSION, - worker_retries: Optional[int] = 3, + worker_retries: Optional[int] = DEFAULT_WORKER_RETRIES, + panorama_tiles: Optional[bool] = DEFAULT_PANORAMA_TILES, ) -> None: self.bundle_path = bundle_path self.model_path = model_path @@ -58,6 +84,7 @@ class ServerContext: self.admin_token = admin_token or token_urlsafe() self.server_version = server_version self.worker_retries = worker_retries + self.panorama_tiles = panorama_tiles self.cache = ModelCache(self.cache_limit) @@ -76,12 +103,16 @@ class ServerContext: params_path=environ.get("ONNX_WEB_PARAMS_PATH", "."), # others cors_origin=environ.get("ONNX_WEB_CORS_ORIGIN", "*").split(","), - any_platform=get_boolean(environ, "ONNX_WEB_ANY_PLATFORM", True), + any_platform=get_boolean( + environ, "ONNX_WEB_ANY_PLATFORM", DEFAULT_ANY_PLATFORM + ), block_platforms=environ.get("ONNX_WEB_BLOCK_PLATFORMS", "").split(","), default_platform=environ.get("ONNX_WEB_DEFAULT_PLATFORM", None), image_format=environ.get("ONNX_WEB_IMAGE_FORMAT", "png"), cache_limit=int(environ.get("ONNX_WEB_CACHE_MODELS", DEFAULT_CACHE_LIMIT)), - show_progress=get_boolean(environ, "ONNX_WEB_SHOW_PROGRESS", True), + show_progress=get_boolean( + environ, "ONNX_WEB_SHOW_PROGRESS", DEFAULT_SHOW_PROGRESS + ), optimizations=environ.get("ONNX_WEB_OPTIMIZATIONS", "").split(","), extra_models=environ.get("ONNX_WEB_EXTRA_MODELS", "").split(","), job_limit=int(environ.get("ONNX_WEB_JOB_LIMIT", DEFAULT_JOB_LIMIT)), @@ -90,7 +121,12 @@ class ServerContext: server_version=environ.get( "ONNX_WEB_SERVER_VERSION", DEFAULT_SERVER_VERSION ), - worker_retries=int(environ.get("ONNX_WEB_WORKER_RETRIES", 3)), + worker_retries=int( + environ.get("ONNX_WEB_WORKER_RETRIES", DEFAULT_WORKER_RETRIES) + ), + panorama_tiles=get_boolean( + environ, "ONNX_WEB_PANORAMA_TILES", DEFAULT_PANORAMA_TILES + ), ) def torch_dtype(self): diff --git a/onnx-web.code-workspace b/onnx-web.code-workspace index 9674ca2f..995a75d0 100644 --- a/onnx-web.code-workspace +++ b/onnx-web.code-workspace @@ -35,6 +35,7 @@ "ddpm", "deis", "denoise", + "denoised", "denoising", "directml", "Dreambooth",