1
0
Fork 0

feat(api): add feature flag for single-tile panorama highres

This commit is contained in:
Sean Sube 2023-11-11 17:03:01 -06:00
parent 798fa5fc6d
commit 5fb2de85c5
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
5 changed files with 51 additions and 12 deletions

View File

@ -79,11 +79,11 @@ class SourceTxt2ImgStage(BaseStage):
latents = get_tile_latents(latents, int(params.seed), latent_size, dims) latents = get_tile_latents(latents, int(params.seed), latent_size, dims)
# reseed latents as needed # reseed latents as needed
reseed_rng = np.random.default_rng(params.seed) reseed_rng = np.random.RandomState(params.seed)
prompt, reseed = parse_reseed(prompt) prompt, reseed = parse_reseed(prompt)
for top, left, bottom, right, region_seed in reseed: for top, left, bottom, right, region_seed in reseed:
if region_seed == -1: if region_seed == -1:
region_seed = reseed_rng.integers(2**32) region_seed = reseed_rng.random_integers(2**32 - 1)
logger.debug( logger.debug(
"reseed latent region: [:, :, %s:%s, %s:%s] with %s", "reseed latent region: [:, :, %s:%s, %s:%s] with %s",

View File

@ -62,7 +62,12 @@ def run_txt2img_pipeline(
) )
# apply upscaling and correction, before highres # apply upscaling and correction, before highres
stage = StageParams(tile_size=params.unet_tile) if params.is_panorama() and server.panorama_tiles:
highres_size = tile_size * highres.scale
else:
highres_size = params.unet_tile
stage = StageParams(tile_size=highres_size)
first_upscale, after_upscale = split_upscale(upscale) first_upscale, after_upscale = split_upscale(upscale)
if first_upscale: if first_upscale:
stage_upscale_correction( stage_upscale_correction(

View File

@ -233,9 +233,6 @@ def get_tokens_from_prompt(
pattern: Pattern, pattern: Pattern,
parser=parse_float_group, parser=parse_float_group,
) -> Tuple[str, List[Tuple[str, float]]]: ) -> Tuple[str, List[Tuple[str, float]]]:
"""
TODO: replace with Arpeggio
"""
remaining_prompt = prompt remaining_prompt = prompt
tokens = [] tokens = []

View File

@ -10,13 +10,38 @@ from .model_cache import ModelCache
logger = getLogger(__name__) logger = getLogger(__name__)
DEFAULT_ANY_PLATFORM = True
DEFAULT_CACHE_LIMIT = 5 DEFAULT_CACHE_LIMIT = 5
DEFAULT_JOB_LIMIT = 10 DEFAULT_JOB_LIMIT = 10
DEFAULT_IMAGE_FORMAT = "png" DEFAULT_IMAGE_FORMAT = "png"
DEFAULT_SERVER_VERSION = "v0.10.0" DEFAULT_SERVER_VERSION = "v0.10.0"
DEFAULT_SHOW_PROGRESS = True
DEFAULT_PANORAMA_TILES = False
DEFAULT_WORKER_RETRIES = 3
class ServerContext: class ServerContext:
bundle_path: str
model_path: str
output_path: str
params_path: str
cors_origin: str
any_platform: bool
block_platforms: List[str]
default_platform: str
image_format: str
cache_limit: int
cache_path: str
show_progress: bool
optimizations: List[str]
extra_models: List[str]
job_limit: int
memory_limit: int
admin_token: str
server_version: str
worker_retries: int
panorama_tiles: bool
def __init__( def __init__(
self, self,
bundle_path: str = ".", bundle_path: str = ".",
@ -24,20 +49,21 @@ class ServerContext:
output_path: str = ".", output_path: str = ".",
params_path: str = ".", params_path: str = ".",
cors_origin: str = "*", cors_origin: str = "*",
any_platform: bool = True, any_platform: bool = DEFAULT_ANY_PLATFORM,
block_platforms: Optional[List[str]] = None, block_platforms: Optional[List[str]] = None,
default_platform: Optional[str] = None, default_platform: Optional[str] = None,
image_format: str = DEFAULT_IMAGE_FORMAT, image_format: str = DEFAULT_IMAGE_FORMAT,
cache_limit: int = DEFAULT_CACHE_LIMIT, cache_limit: int = DEFAULT_CACHE_LIMIT,
cache_path: Optional[str] = None, cache_path: Optional[str] = None,
show_progress: bool = True, show_progress: bool = DEFAULT_SHOW_PROGRESS,
optimizations: Optional[List[str]] = None, optimizations: Optional[List[str]] = None,
extra_models: Optional[List[str]] = None, extra_models: Optional[List[str]] = None,
job_limit: int = DEFAULT_JOB_LIMIT, job_limit: int = DEFAULT_JOB_LIMIT,
memory_limit: Optional[int] = None, memory_limit: Optional[int] = None,
admin_token: Optional[str] = None, admin_token: Optional[str] = None,
server_version: Optional[str] = DEFAULT_SERVER_VERSION, server_version: Optional[str] = DEFAULT_SERVER_VERSION,
worker_retries: Optional[int] = 3, worker_retries: Optional[int] = DEFAULT_WORKER_RETRIES,
panorama_tiles: Optional[bool] = DEFAULT_PANORAMA_TILES,
) -> None: ) -> None:
self.bundle_path = bundle_path self.bundle_path = bundle_path
self.model_path = model_path self.model_path = model_path
@ -58,6 +84,7 @@ class ServerContext:
self.admin_token = admin_token or token_urlsafe() self.admin_token = admin_token or token_urlsafe()
self.server_version = server_version self.server_version = server_version
self.worker_retries = worker_retries self.worker_retries = worker_retries
self.panorama_tiles = panorama_tiles
self.cache = ModelCache(self.cache_limit) self.cache = ModelCache(self.cache_limit)
@ -76,12 +103,16 @@ class ServerContext:
params_path=environ.get("ONNX_WEB_PARAMS_PATH", "."), params_path=environ.get("ONNX_WEB_PARAMS_PATH", "."),
# others # others
cors_origin=environ.get("ONNX_WEB_CORS_ORIGIN", "*").split(","), cors_origin=environ.get("ONNX_WEB_CORS_ORIGIN", "*").split(","),
any_platform=get_boolean(environ, "ONNX_WEB_ANY_PLATFORM", True), any_platform=get_boolean(
environ, "ONNX_WEB_ANY_PLATFORM", DEFAULT_ANY_PLATFORM
),
block_platforms=environ.get("ONNX_WEB_BLOCK_PLATFORMS", "").split(","), block_platforms=environ.get("ONNX_WEB_BLOCK_PLATFORMS", "").split(","),
default_platform=environ.get("ONNX_WEB_DEFAULT_PLATFORM", None), default_platform=environ.get("ONNX_WEB_DEFAULT_PLATFORM", None),
image_format=environ.get("ONNX_WEB_IMAGE_FORMAT", "png"), image_format=environ.get("ONNX_WEB_IMAGE_FORMAT", "png"),
cache_limit=int(environ.get("ONNX_WEB_CACHE_MODELS", DEFAULT_CACHE_LIMIT)), cache_limit=int(environ.get("ONNX_WEB_CACHE_MODELS", DEFAULT_CACHE_LIMIT)),
show_progress=get_boolean(environ, "ONNX_WEB_SHOW_PROGRESS", True), show_progress=get_boolean(
environ, "ONNX_WEB_SHOW_PROGRESS", DEFAULT_SHOW_PROGRESS
),
optimizations=environ.get("ONNX_WEB_OPTIMIZATIONS", "").split(","), optimizations=environ.get("ONNX_WEB_OPTIMIZATIONS", "").split(","),
extra_models=environ.get("ONNX_WEB_EXTRA_MODELS", "").split(","), extra_models=environ.get("ONNX_WEB_EXTRA_MODELS", "").split(","),
job_limit=int(environ.get("ONNX_WEB_JOB_LIMIT", DEFAULT_JOB_LIMIT)), job_limit=int(environ.get("ONNX_WEB_JOB_LIMIT", DEFAULT_JOB_LIMIT)),
@ -90,7 +121,12 @@ class ServerContext:
server_version=environ.get( server_version=environ.get(
"ONNX_WEB_SERVER_VERSION", DEFAULT_SERVER_VERSION "ONNX_WEB_SERVER_VERSION", DEFAULT_SERVER_VERSION
), ),
worker_retries=int(environ.get("ONNX_WEB_WORKER_RETRIES", 3)), worker_retries=int(
environ.get("ONNX_WEB_WORKER_RETRIES", DEFAULT_WORKER_RETRIES)
),
panorama_tiles=get_boolean(
environ, "ONNX_WEB_PANORAMA_TILES", DEFAULT_PANORAMA_TILES
),
) )
def torch_dtype(self): def torch_dtype(self):

View File

@ -35,6 +35,7 @@
"ddpm", "ddpm",
"deis", "deis",
"denoise", "denoise",
"denoised",
"denoising", "denoising",
"directml", "directml",
"Dreambooth", "Dreambooth",