1
0
Fork 0

basic region prompt parsing

This commit is contained in:
Sean Sube 2023-11-05 15:46:37 -06:00
parent 44851e3785
commit 8ba9f3c0b7
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
5 changed files with 31 additions and 19 deletions

View File

@ -267,12 +267,20 @@ def load_pipeline(
# update panorama params
if params.is_panorama():
unet_stride = (params.unet_tile * (1 - params.unet_overlap)) // 8
logger.debug("setting panorama window parameters: %s/%s for UNet, %s/%s for VAE", params.unet_tile, unet_stride, params.vae_tile, params.vae_overlap)
logger.debug(
"setting panorama window parameters: %s/%s for UNet, %s/%s for VAE",
params.unet_tile,
unet_stride,
params.vae_tile,
params.vae_overlap,
)
pipe.set_window_size(params.unet_tile // 8, unet_stride)
for vae in VAE_COMPONENTS:
if hasattr(pipe, vae):
getattr(pipe, vae).set_window_size(params.vae_tile // 8, params.vae_overlap)
getattr(pipe, vae).set_window_size(
params.vae_tile // 8, params.vae_overlap
)
run_gc([device])

View File

@ -303,7 +303,14 @@ class StableDiffusionXLPanoramaPipelineMixin(StableDiffusionXLImg2ImgPipelineMix
# 3.b. Encode region prompts
regions = parse_regions(prompt)
region_embeds: List[Tuple[List[np.ndarray], Optional[np.ndarray], Optional[np.ndarray], Optional[np.ndarray]]] = []
region_embeds: List[
Tuple[
List[np.ndarray],
Optional[np.ndarray],
Optional[np.ndarray],
Optional[np.ndarray],
]
] = []
add_region_embeds: List[np.ndarray] = []
for _top, _left, _bottom, _right, _mode, region_prompt in regions:
@ -322,9 +329,11 @@ class StableDiffusionXLPanoramaPipelineMixin(StableDiffusionXLImg2ImgPipelineMix
current_region_embeds[0] = np.concatenate(
(current_region_embeds[1], current_region_embeds[0]), axis=0
)
add_region_embeds.append(np.concatenate(
(current_region_embeds[3], current_region_embeds[2]), axis=0
))
add_region_embeds.append(
np.concatenate(
(current_region_embeds[3], current_region_embeds[2]), axis=0
)
)
region_embeds.append(current_region_embeds)
@ -492,7 +501,7 @@ class StableDiffusionXLPanoramaPipelineMixin(StableDiffusionXLImg2ImgPipelineMix
)
latents_view_denoised = scheduler_output.prev_sample.numpy()
if mode:
if mode == "replace":
value[:, :, h_start:h_end, w_start:w_end] = latents_view_denoised
count[:, :, h_start:h_end, w_start:w_end] = 1
else:

View File

@ -43,8 +43,9 @@ def run_txt2img_pipeline(
highres: HighresParams,
) -> None:
# if using panorama, the pipeline will tile itself (views)
if params.is_panorama() or params.is_xl():
if params.is_panorama():
tile_size = max(params.unet_tile, size.width, size.height)
logger.debug("adjusting tile size for panorama to %s", tile_size)
else:
tile_size = params.unet_tile

View File

@ -3,7 +3,7 @@ from copy import deepcopy
from logging import getLogger
from math import ceil
from re import Pattern, compile
from typing import Dict, List, Optional, Tuple
from typing import Dict, List, Literal, Optional, Tuple
import numpy as np
import torch
@ -21,6 +21,7 @@ CLIP_TOKEN = compile(r"\<clip:([-\w]+):(\d+)\>")
INVERSION_TOKEN = compile(r"\<inversion:([^:\>]+):(-?[\.|\d]+)\>")
LORA_TOKEN = compile(r"\<lora:([^:\>]+):(-?[\.|\d]+)\>")
WILDCARD_TOKEN = compile(r"__([-/\\\w]+)__")
REGION_TOKEN = compile(r"\<region:(\d+):(\d+):(\d+):(\d+):(add|replace):([^\>])\>")
INTERVAL_RANGE = compile(r"(\w+)-{(\d+),(\d+)(?:,(\d+))?}")
ALTERNATIVE_RANGE = compile(r"\(([^\)]+)\)")
@ -446,7 +447,8 @@ def slice_prompt(prompt: str, slice: int) -> str:
return prompt
Region = Tuple[int, int, int, int, bool, str]
Region = Tuple[int, int, int, int, Literal["add", "replace"], str]
def parse_regions(prompt: str) -> List[Region]:
return []
return get_tokens_from_prompt(prompt, REGION_TOKEN)

View File

@ -2,16 +2,8 @@ import sys
from functools import partial
from logging import getLogger
from os import path
from pathlib import Path
from typing import Dict, Optional, Union
from urllib.parse import urlparse
from optimum.onnxruntime.modeling_diffusion import (
ORTModel,
ORTStableDiffusionPipelineBase,
)
from ..torch_before_ort import SessionOptions
from ..utils import run_gc
from .context import ServerContext