feat(api): add feature flag for single-tile panorama highres
This commit is contained in:
parent
798fa5fc6d
commit
5fb2de85c5
|
@ -79,11 +79,11 @@ class SourceTxt2ImgStage(BaseStage):
|
||||||
latents = get_tile_latents(latents, int(params.seed), latent_size, dims)
|
latents = get_tile_latents(latents, int(params.seed), latent_size, dims)
|
||||||
|
|
||||||
# reseed latents as needed
|
# reseed latents as needed
|
||||||
reseed_rng = np.random.default_rng(params.seed)
|
reseed_rng = np.random.RandomState(params.seed)
|
||||||
prompt, reseed = parse_reseed(prompt)
|
prompt, reseed = parse_reseed(prompt)
|
||||||
for top, left, bottom, right, region_seed in reseed:
|
for top, left, bottom, right, region_seed in reseed:
|
||||||
if region_seed == -1:
|
if region_seed == -1:
|
||||||
region_seed = reseed_rng.integers(2**32)
|
region_seed = reseed_rng.random_integers(2**32 - 1)
|
||||||
|
|
||||||
logger.debug(
|
logger.debug(
|
||||||
"reseed latent region: [:, :, %s:%s, %s:%s] with %s",
|
"reseed latent region: [:, :, %s:%s, %s:%s] with %s",
|
||||||
|
|
|
@ -62,7 +62,12 @@ def run_txt2img_pipeline(
|
||||||
)
|
)
|
||||||
|
|
||||||
# apply upscaling and correction, before highres
|
# apply upscaling and correction, before highres
|
||||||
stage = StageParams(tile_size=params.unet_tile)
|
if params.is_panorama() and server.panorama_tiles:
|
||||||
|
highres_size = tile_size * highres.scale
|
||||||
|
else:
|
||||||
|
highres_size = params.unet_tile
|
||||||
|
|
||||||
|
stage = StageParams(tile_size=highres_size)
|
||||||
first_upscale, after_upscale = split_upscale(upscale)
|
first_upscale, after_upscale = split_upscale(upscale)
|
||||||
if first_upscale:
|
if first_upscale:
|
||||||
stage_upscale_correction(
|
stage_upscale_correction(
|
||||||
|
|
|
@ -233,9 +233,6 @@ def get_tokens_from_prompt(
|
||||||
pattern: Pattern,
|
pattern: Pattern,
|
||||||
parser=parse_float_group,
|
parser=parse_float_group,
|
||||||
) -> Tuple[str, List[Tuple[str, float]]]:
|
) -> Tuple[str, List[Tuple[str, float]]]:
|
||||||
"""
|
|
||||||
TODO: replace with Arpeggio
|
|
||||||
"""
|
|
||||||
remaining_prompt = prompt
|
remaining_prompt = prompt
|
||||||
|
|
||||||
tokens = []
|
tokens = []
|
||||||
|
|
|
@ -10,13 +10,38 @@ from .model_cache import ModelCache
|
||||||
|
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
DEFAULT_ANY_PLATFORM = True
|
||||||
DEFAULT_CACHE_LIMIT = 5
|
DEFAULT_CACHE_LIMIT = 5
|
||||||
DEFAULT_JOB_LIMIT = 10
|
DEFAULT_JOB_LIMIT = 10
|
||||||
DEFAULT_IMAGE_FORMAT = "png"
|
DEFAULT_IMAGE_FORMAT = "png"
|
||||||
DEFAULT_SERVER_VERSION = "v0.10.0"
|
DEFAULT_SERVER_VERSION = "v0.10.0"
|
||||||
|
DEFAULT_SHOW_PROGRESS = True
|
||||||
|
DEFAULT_PANORAMA_TILES = False
|
||||||
|
DEFAULT_WORKER_RETRIES = 3
|
||||||
|
|
||||||
|
|
||||||
class ServerContext:
|
class ServerContext:
|
||||||
|
bundle_path: str
|
||||||
|
model_path: str
|
||||||
|
output_path: str
|
||||||
|
params_path: str
|
||||||
|
cors_origin: str
|
||||||
|
any_platform: bool
|
||||||
|
block_platforms: List[str]
|
||||||
|
default_platform: str
|
||||||
|
image_format: str
|
||||||
|
cache_limit: int
|
||||||
|
cache_path: str
|
||||||
|
show_progress: bool
|
||||||
|
optimizations: List[str]
|
||||||
|
extra_models: List[str]
|
||||||
|
job_limit: int
|
||||||
|
memory_limit: int
|
||||||
|
admin_token: str
|
||||||
|
server_version: str
|
||||||
|
worker_retries: int
|
||||||
|
panorama_tiles: bool
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
bundle_path: str = ".",
|
bundle_path: str = ".",
|
||||||
|
@ -24,20 +49,21 @@ class ServerContext:
|
||||||
output_path: str = ".",
|
output_path: str = ".",
|
||||||
params_path: str = ".",
|
params_path: str = ".",
|
||||||
cors_origin: str = "*",
|
cors_origin: str = "*",
|
||||||
any_platform: bool = True,
|
any_platform: bool = DEFAULT_ANY_PLATFORM,
|
||||||
block_platforms: Optional[List[str]] = None,
|
block_platforms: Optional[List[str]] = None,
|
||||||
default_platform: Optional[str] = None,
|
default_platform: Optional[str] = None,
|
||||||
image_format: str = DEFAULT_IMAGE_FORMAT,
|
image_format: str = DEFAULT_IMAGE_FORMAT,
|
||||||
cache_limit: int = DEFAULT_CACHE_LIMIT,
|
cache_limit: int = DEFAULT_CACHE_LIMIT,
|
||||||
cache_path: Optional[str] = None,
|
cache_path: Optional[str] = None,
|
||||||
show_progress: bool = True,
|
show_progress: bool = DEFAULT_SHOW_PROGRESS,
|
||||||
optimizations: Optional[List[str]] = None,
|
optimizations: Optional[List[str]] = None,
|
||||||
extra_models: Optional[List[str]] = None,
|
extra_models: Optional[List[str]] = None,
|
||||||
job_limit: int = DEFAULT_JOB_LIMIT,
|
job_limit: int = DEFAULT_JOB_LIMIT,
|
||||||
memory_limit: Optional[int] = None,
|
memory_limit: Optional[int] = None,
|
||||||
admin_token: Optional[str] = None,
|
admin_token: Optional[str] = None,
|
||||||
server_version: Optional[str] = DEFAULT_SERVER_VERSION,
|
server_version: Optional[str] = DEFAULT_SERVER_VERSION,
|
||||||
worker_retries: Optional[int] = 3,
|
worker_retries: Optional[int] = DEFAULT_WORKER_RETRIES,
|
||||||
|
panorama_tiles: Optional[bool] = DEFAULT_PANORAMA_TILES,
|
||||||
) -> None:
|
) -> None:
|
||||||
self.bundle_path = bundle_path
|
self.bundle_path = bundle_path
|
||||||
self.model_path = model_path
|
self.model_path = model_path
|
||||||
|
@ -58,6 +84,7 @@ class ServerContext:
|
||||||
self.admin_token = admin_token or token_urlsafe()
|
self.admin_token = admin_token or token_urlsafe()
|
||||||
self.server_version = server_version
|
self.server_version = server_version
|
||||||
self.worker_retries = worker_retries
|
self.worker_retries = worker_retries
|
||||||
|
self.panorama_tiles = panorama_tiles
|
||||||
|
|
||||||
self.cache = ModelCache(self.cache_limit)
|
self.cache = ModelCache(self.cache_limit)
|
||||||
|
|
||||||
|
@ -76,12 +103,16 @@ class ServerContext:
|
||||||
params_path=environ.get("ONNX_WEB_PARAMS_PATH", "."),
|
params_path=environ.get("ONNX_WEB_PARAMS_PATH", "."),
|
||||||
# others
|
# others
|
||||||
cors_origin=environ.get("ONNX_WEB_CORS_ORIGIN", "*").split(","),
|
cors_origin=environ.get("ONNX_WEB_CORS_ORIGIN", "*").split(","),
|
||||||
any_platform=get_boolean(environ, "ONNX_WEB_ANY_PLATFORM", True),
|
any_platform=get_boolean(
|
||||||
|
environ, "ONNX_WEB_ANY_PLATFORM", DEFAULT_ANY_PLATFORM
|
||||||
|
),
|
||||||
block_platforms=environ.get("ONNX_WEB_BLOCK_PLATFORMS", "").split(","),
|
block_platforms=environ.get("ONNX_WEB_BLOCK_PLATFORMS", "").split(","),
|
||||||
default_platform=environ.get("ONNX_WEB_DEFAULT_PLATFORM", None),
|
default_platform=environ.get("ONNX_WEB_DEFAULT_PLATFORM", None),
|
||||||
image_format=environ.get("ONNX_WEB_IMAGE_FORMAT", "png"),
|
image_format=environ.get("ONNX_WEB_IMAGE_FORMAT", "png"),
|
||||||
cache_limit=int(environ.get("ONNX_WEB_CACHE_MODELS", DEFAULT_CACHE_LIMIT)),
|
cache_limit=int(environ.get("ONNX_WEB_CACHE_MODELS", DEFAULT_CACHE_LIMIT)),
|
||||||
show_progress=get_boolean(environ, "ONNX_WEB_SHOW_PROGRESS", True),
|
show_progress=get_boolean(
|
||||||
|
environ, "ONNX_WEB_SHOW_PROGRESS", DEFAULT_SHOW_PROGRESS
|
||||||
|
),
|
||||||
optimizations=environ.get("ONNX_WEB_OPTIMIZATIONS", "").split(","),
|
optimizations=environ.get("ONNX_WEB_OPTIMIZATIONS", "").split(","),
|
||||||
extra_models=environ.get("ONNX_WEB_EXTRA_MODELS", "").split(","),
|
extra_models=environ.get("ONNX_WEB_EXTRA_MODELS", "").split(","),
|
||||||
job_limit=int(environ.get("ONNX_WEB_JOB_LIMIT", DEFAULT_JOB_LIMIT)),
|
job_limit=int(environ.get("ONNX_WEB_JOB_LIMIT", DEFAULT_JOB_LIMIT)),
|
||||||
|
@ -90,7 +121,12 @@ class ServerContext:
|
||||||
server_version=environ.get(
|
server_version=environ.get(
|
||||||
"ONNX_WEB_SERVER_VERSION", DEFAULT_SERVER_VERSION
|
"ONNX_WEB_SERVER_VERSION", DEFAULT_SERVER_VERSION
|
||||||
),
|
),
|
||||||
worker_retries=int(environ.get("ONNX_WEB_WORKER_RETRIES", 3)),
|
worker_retries=int(
|
||||||
|
environ.get("ONNX_WEB_WORKER_RETRIES", DEFAULT_WORKER_RETRIES)
|
||||||
|
),
|
||||||
|
panorama_tiles=get_boolean(
|
||||||
|
environ, "ONNX_WEB_PANORAMA_TILES", DEFAULT_PANORAMA_TILES
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
def torch_dtype(self):
|
def torch_dtype(self):
|
||||||
|
|
|
@ -35,6 +35,7 @@
|
||||||
"ddpm",
|
"ddpm",
|
||||||
"deis",
|
"deis",
|
||||||
"denoise",
|
"denoise",
|
||||||
|
"denoised",
|
||||||
"denoising",
|
"denoising",
|
||||||
"directml",
|
"directml",
|
||||||
"Dreambooth",
|
"Dreambooth",
|
||||||
|
|
Loading…
Reference in New Issue