feat(api): add feature flag for single-tile panorama highres
This commit is contained in:
parent
798fa5fc6d
commit
5fb2de85c5
|
@ -79,11 +79,11 @@ class SourceTxt2ImgStage(BaseStage):
|
|||
latents = get_tile_latents(latents, int(params.seed), latent_size, dims)
|
||||
|
||||
# reseed latents as needed
|
||||
reseed_rng = np.random.default_rng(params.seed)
|
||||
reseed_rng = np.random.RandomState(params.seed)
|
||||
prompt, reseed = parse_reseed(prompt)
|
||||
for top, left, bottom, right, region_seed in reseed:
|
||||
if region_seed == -1:
|
||||
region_seed = reseed_rng.integers(2**32)
|
||||
region_seed = reseed_rng.random_integers(2**32 - 1)
|
||||
|
||||
logger.debug(
|
||||
"reseed latent region: [:, :, %s:%s, %s:%s] with %s",
|
||||
|
|
|
@ -62,7 +62,12 @@ def run_txt2img_pipeline(
|
|||
)
|
||||
|
||||
# apply upscaling and correction, before highres
|
||||
stage = StageParams(tile_size=params.unet_tile)
|
||||
if params.is_panorama() and server.panorama_tiles:
|
||||
highres_size = tile_size * highres.scale
|
||||
else:
|
||||
highres_size = params.unet_tile
|
||||
|
||||
stage = StageParams(tile_size=highres_size)
|
||||
first_upscale, after_upscale = split_upscale(upscale)
|
||||
if first_upscale:
|
||||
stage_upscale_correction(
|
||||
|
|
|
@ -233,9 +233,6 @@ def get_tokens_from_prompt(
|
|||
pattern: Pattern,
|
||||
parser=parse_float_group,
|
||||
) -> Tuple[str, List[Tuple[str, float]]]:
|
||||
"""
|
||||
TODO: replace with Arpeggio
|
||||
"""
|
||||
remaining_prompt = prompt
|
||||
|
||||
tokens = []
|
||||
|
|
|
@ -10,13 +10,38 @@ from .model_cache import ModelCache
|
|||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
DEFAULT_ANY_PLATFORM = True
|
||||
DEFAULT_CACHE_LIMIT = 5
|
||||
DEFAULT_JOB_LIMIT = 10
|
||||
DEFAULT_IMAGE_FORMAT = "png"
|
||||
DEFAULT_SERVER_VERSION = "v0.10.0"
|
||||
DEFAULT_SHOW_PROGRESS = True
|
||||
DEFAULT_PANORAMA_TILES = False
|
||||
DEFAULT_WORKER_RETRIES = 3
|
||||
|
||||
|
||||
class ServerContext:
|
||||
bundle_path: str
|
||||
model_path: str
|
||||
output_path: str
|
||||
params_path: str
|
||||
cors_origin: str
|
||||
any_platform: bool
|
||||
block_platforms: List[str]
|
||||
default_platform: str
|
||||
image_format: str
|
||||
cache_limit: int
|
||||
cache_path: str
|
||||
show_progress: bool
|
||||
optimizations: List[str]
|
||||
extra_models: List[str]
|
||||
job_limit: int
|
||||
memory_limit: int
|
||||
admin_token: str
|
||||
server_version: str
|
||||
worker_retries: int
|
||||
panorama_tiles: bool
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
bundle_path: str = ".",
|
||||
|
@ -24,20 +49,21 @@ class ServerContext:
|
|||
output_path: str = ".",
|
||||
params_path: str = ".",
|
||||
cors_origin: str = "*",
|
||||
any_platform: bool = True,
|
||||
any_platform: bool = DEFAULT_ANY_PLATFORM,
|
||||
block_platforms: Optional[List[str]] = None,
|
||||
default_platform: Optional[str] = None,
|
||||
image_format: str = DEFAULT_IMAGE_FORMAT,
|
||||
cache_limit: int = DEFAULT_CACHE_LIMIT,
|
||||
cache_path: Optional[str] = None,
|
||||
show_progress: bool = True,
|
||||
show_progress: bool = DEFAULT_SHOW_PROGRESS,
|
||||
optimizations: Optional[List[str]] = None,
|
||||
extra_models: Optional[List[str]] = None,
|
||||
job_limit: int = DEFAULT_JOB_LIMIT,
|
||||
memory_limit: Optional[int] = None,
|
||||
admin_token: Optional[str] = None,
|
||||
server_version: Optional[str] = DEFAULT_SERVER_VERSION,
|
||||
worker_retries: Optional[int] = 3,
|
||||
worker_retries: Optional[int] = DEFAULT_WORKER_RETRIES,
|
||||
panorama_tiles: Optional[bool] = DEFAULT_PANORAMA_TILES,
|
||||
) -> None:
|
||||
self.bundle_path = bundle_path
|
||||
self.model_path = model_path
|
||||
|
@ -58,6 +84,7 @@ class ServerContext:
|
|||
self.admin_token = admin_token or token_urlsafe()
|
||||
self.server_version = server_version
|
||||
self.worker_retries = worker_retries
|
||||
self.panorama_tiles = panorama_tiles
|
||||
|
||||
self.cache = ModelCache(self.cache_limit)
|
||||
|
||||
|
@ -76,12 +103,16 @@ class ServerContext:
|
|||
params_path=environ.get("ONNX_WEB_PARAMS_PATH", "."),
|
||||
# others
|
||||
cors_origin=environ.get("ONNX_WEB_CORS_ORIGIN", "*").split(","),
|
||||
any_platform=get_boolean(environ, "ONNX_WEB_ANY_PLATFORM", True),
|
||||
any_platform=get_boolean(
|
||||
environ, "ONNX_WEB_ANY_PLATFORM", DEFAULT_ANY_PLATFORM
|
||||
),
|
||||
block_platforms=environ.get("ONNX_WEB_BLOCK_PLATFORMS", "").split(","),
|
||||
default_platform=environ.get("ONNX_WEB_DEFAULT_PLATFORM", None),
|
||||
image_format=environ.get("ONNX_WEB_IMAGE_FORMAT", "png"),
|
||||
cache_limit=int(environ.get("ONNX_WEB_CACHE_MODELS", DEFAULT_CACHE_LIMIT)),
|
||||
show_progress=get_boolean(environ, "ONNX_WEB_SHOW_PROGRESS", True),
|
||||
show_progress=get_boolean(
|
||||
environ, "ONNX_WEB_SHOW_PROGRESS", DEFAULT_SHOW_PROGRESS
|
||||
),
|
||||
optimizations=environ.get("ONNX_WEB_OPTIMIZATIONS", "").split(","),
|
||||
extra_models=environ.get("ONNX_WEB_EXTRA_MODELS", "").split(","),
|
||||
job_limit=int(environ.get("ONNX_WEB_JOB_LIMIT", DEFAULT_JOB_LIMIT)),
|
||||
|
@ -90,7 +121,12 @@ class ServerContext:
|
|||
server_version=environ.get(
|
||||
"ONNX_WEB_SERVER_VERSION", DEFAULT_SERVER_VERSION
|
||||
),
|
||||
worker_retries=int(environ.get("ONNX_WEB_WORKER_RETRIES", 3)),
|
||||
worker_retries=int(
|
||||
environ.get("ONNX_WEB_WORKER_RETRIES", DEFAULT_WORKER_RETRIES)
|
||||
),
|
||||
panorama_tiles=get_boolean(
|
||||
environ, "ONNX_WEB_PANORAMA_TILES", DEFAULT_PANORAMA_TILES
|
||||
),
|
||||
)
|
||||
|
||||
def torch_dtype(self):
|
||||
|
|
|
@ -35,6 +35,7 @@
|
|||
"ddpm",
|
||||
"deis",
|
||||
"denoise",
|
||||
"denoised",
|
||||
"denoising",
|
||||
"directml",
|
||||
"Dreambooth",
|
||||
|
|
Loading…
Reference in New Issue