1
0
Fork 0

basic region prompt parsing

This commit is contained in:
Sean Sube 2023-11-05 15:46:37 -06:00
parent 44851e3785
commit 8ba9f3c0b7
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
5 changed files with 31 additions and 19 deletions

View File

@ -267,12 +267,20 @@ def load_pipeline(
# update panorama params # update panorama params
if params.is_panorama(): if params.is_panorama():
unet_stride = (params.unet_tile * (1 - params.unet_overlap)) // 8 unet_stride = (params.unet_tile * (1 - params.unet_overlap)) // 8
logger.debug("setting panorama window parameters: %s/%s for UNet, %s/%s for VAE", params.unet_tile, unet_stride, params.vae_tile, params.vae_overlap) logger.debug(
"setting panorama window parameters: %s/%s for UNet, %s/%s for VAE",
params.unet_tile,
unet_stride,
params.vae_tile,
params.vae_overlap,
)
pipe.set_window_size(params.unet_tile // 8, unet_stride) pipe.set_window_size(params.unet_tile // 8, unet_stride)
for vae in VAE_COMPONENTS: for vae in VAE_COMPONENTS:
if hasattr(pipe, vae): if hasattr(pipe, vae):
getattr(pipe, vae).set_window_size(params.vae_tile // 8, params.vae_overlap) getattr(pipe, vae).set_window_size(
params.vae_tile // 8, params.vae_overlap
)
run_gc([device]) run_gc([device])

View File

@ -303,7 +303,14 @@ class StableDiffusionXLPanoramaPipelineMixin(StableDiffusionXLImg2ImgPipelineMix
# 3.b. Encode region prompts # 3.b. Encode region prompts
regions = parse_regions(prompt) regions = parse_regions(prompt)
region_embeds: List[Tuple[List[np.ndarray], Optional[np.ndarray], Optional[np.ndarray], Optional[np.ndarray]]] = [] region_embeds: List[
Tuple[
List[np.ndarray],
Optional[np.ndarray],
Optional[np.ndarray],
Optional[np.ndarray],
]
] = []
add_region_embeds: List[np.ndarray] = [] add_region_embeds: List[np.ndarray] = []
for _top, _left, _bottom, _right, _mode, region_prompt in regions: for _top, _left, _bottom, _right, _mode, region_prompt in regions:
@ -322,9 +329,11 @@ class StableDiffusionXLPanoramaPipelineMixin(StableDiffusionXLImg2ImgPipelineMix
current_region_embeds[0] = np.concatenate( current_region_embeds[0] = np.concatenate(
(current_region_embeds[1], current_region_embeds[0]), axis=0 (current_region_embeds[1], current_region_embeds[0]), axis=0
) )
add_region_embeds.append(np.concatenate( add_region_embeds.append(
np.concatenate(
(current_region_embeds[3], current_region_embeds[2]), axis=0 (current_region_embeds[3], current_region_embeds[2]), axis=0
)) )
)
region_embeds.append(current_region_embeds) region_embeds.append(current_region_embeds)
@ -492,7 +501,7 @@ class StableDiffusionXLPanoramaPipelineMixin(StableDiffusionXLImg2ImgPipelineMix
) )
latents_view_denoised = scheduler_output.prev_sample.numpy() latents_view_denoised = scheduler_output.prev_sample.numpy()
if mode: if mode == "replace":
value[:, :, h_start:h_end, w_start:w_end] = latents_view_denoised value[:, :, h_start:h_end, w_start:w_end] = latents_view_denoised
count[:, :, h_start:h_end, w_start:w_end] = 1 count[:, :, h_start:h_end, w_start:w_end] = 1
else: else:

View File

@ -43,8 +43,9 @@ def run_txt2img_pipeline(
highres: HighresParams, highres: HighresParams,
) -> None: ) -> None:
# if using panorama, the pipeline will tile itself (views) # if using panorama, the pipeline will tile itself (views)
if params.is_panorama() or params.is_xl(): if params.is_panorama():
tile_size = max(params.unet_tile, size.width, size.height) tile_size = max(params.unet_tile, size.width, size.height)
logger.debug("adjusting tile size for panorama to %s", tile_size)
else: else:
tile_size = params.unet_tile tile_size = params.unet_tile

View File

@ -3,7 +3,7 @@ from copy import deepcopy
from logging import getLogger from logging import getLogger
from math import ceil from math import ceil
from re import Pattern, compile from re import Pattern, compile
from typing import Dict, List, Optional, Tuple from typing import Dict, List, Literal, Optional, Tuple
import numpy as np import numpy as np
import torch import torch
@ -21,6 +21,7 @@ CLIP_TOKEN = compile(r"\<clip:([-\w]+):(\d+)\>")
INVERSION_TOKEN = compile(r"\<inversion:([^:\>]+):(-?[\.|\d]+)\>") INVERSION_TOKEN = compile(r"\<inversion:([^:\>]+):(-?[\.|\d]+)\>")
LORA_TOKEN = compile(r"\<lora:([^:\>]+):(-?[\.|\d]+)\>") LORA_TOKEN = compile(r"\<lora:([^:\>]+):(-?[\.|\d]+)\>")
WILDCARD_TOKEN = compile(r"__([-/\\\w]+)__") WILDCARD_TOKEN = compile(r"__([-/\\\w]+)__")
REGION_TOKEN = compile(r"\<region:(\d+):(\d+):(\d+):(\d+):(add|replace):([^\>])\>")
INTERVAL_RANGE = compile(r"(\w+)-{(\d+),(\d+)(?:,(\d+))?}") INTERVAL_RANGE = compile(r"(\w+)-{(\d+),(\d+)(?:,(\d+))?}")
ALTERNATIVE_RANGE = compile(r"\(([^\)]+)\)") ALTERNATIVE_RANGE = compile(r"\(([^\)]+)\)")
@ -446,7 +447,8 @@ def slice_prompt(prompt: str, slice: int) -> str:
return prompt return prompt
Region = Tuple[int, int, int, int, bool, str] Region = Tuple[int, int, int, int, Literal["add", "replace"], str]
def parse_regions(prompt: str) -> List[Region]: def parse_regions(prompt: str) -> List[Region]:
return [] return get_tokens_from_prompt(prompt, REGION_TOKEN)

View File

@ -2,16 +2,8 @@ import sys
from functools import partial from functools import partial
from logging import getLogger from logging import getLogger
from os import path from os import path
from pathlib import Path
from typing import Dict, Optional, Union
from urllib.parse import urlparse from urllib.parse import urlparse
from optimum.onnxruntime.modeling_diffusion import (
ORTModel,
ORTStableDiffusionPipelineBase,
)
from ..torch_before_ort import SessionOptions
from ..utils import run_gc from ..utils import run_gc
from .context import ServerContext from .context import ServerContext