basic region prompt parsing
This commit is contained in:
parent
44851e3785
commit
8ba9f3c0b7
|
@ -267,12 +267,20 @@ def load_pipeline(
|
|||
# update panorama params
|
||||
if params.is_panorama():
|
||||
unet_stride = (params.unet_tile * (1 - params.unet_overlap)) // 8
|
||||
logger.debug("setting panorama window parameters: %s/%s for UNet, %s/%s for VAE", params.unet_tile, unet_stride, params.vae_tile, params.vae_overlap)
|
||||
logger.debug(
|
||||
"setting panorama window parameters: %s/%s for UNet, %s/%s for VAE",
|
||||
params.unet_tile,
|
||||
unet_stride,
|
||||
params.vae_tile,
|
||||
params.vae_overlap,
|
||||
)
|
||||
pipe.set_window_size(params.unet_tile // 8, unet_stride)
|
||||
|
||||
for vae in VAE_COMPONENTS:
|
||||
if hasattr(pipe, vae):
|
||||
getattr(pipe, vae).set_window_size(params.vae_tile // 8, params.vae_overlap)
|
||||
getattr(pipe, vae).set_window_size(
|
||||
params.vae_tile // 8, params.vae_overlap
|
||||
)
|
||||
|
||||
run_gc([device])
|
||||
|
||||
|
|
|
@ -303,7 +303,14 @@ class StableDiffusionXLPanoramaPipelineMixin(StableDiffusionXLImg2ImgPipelineMix
|
|||
|
||||
# 3.b. Encode region prompts
|
||||
regions = parse_regions(prompt)
|
||||
region_embeds: List[Tuple[List[np.ndarray], Optional[np.ndarray], Optional[np.ndarray], Optional[np.ndarray]]] = []
|
||||
region_embeds: List[
|
||||
Tuple[
|
||||
List[np.ndarray],
|
||||
Optional[np.ndarray],
|
||||
Optional[np.ndarray],
|
||||
Optional[np.ndarray],
|
||||
]
|
||||
] = []
|
||||
add_region_embeds: List[np.ndarray] = []
|
||||
|
||||
for _top, _left, _bottom, _right, _mode, region_prompt in regions:
|
||||
|
@ -322,9 +329,11 @@ class StableDiffusionXLPanoramaPipelineMixin(StableDiffusionXLImg2ImgPipelineMix
|
|||
current_region_embeds[0] = np.concatenate(
|
||||
(current_region_embeds[1], current_region_embeds[0]), axis=0
|
||||
)
|
||||
add_region_embeds.append(np.concatenate(
|
||||
(current_region_embeds[3], current_region_embeds[2]), axis=0
|
||||
))
|
||||
add_region_embeds.append(
|
||||
np.concatenate(
|
||||
(current_region_embeds[3], current_region_embeds[2]), axis=0
|
||||
)
|
||||
)
|
||||
|
||||
region_embeds.append(current_region_embeds)
|
||||
|
||||
|
@ -492,7 +501,7 @@ class StableDiffusionXLPanoramaPipelineMixin(StableDiffusionXLImg2ImgPipelineMix
|
|||
)
|
||||
latents_view_denoised = scheduler_output.prev_sample.numpy()
|
||||
|
||||
if mode:
|
||||
if mode == "replace":
|
||||
value[:, :, h_start:h_end, w_start:w_end] = latents_view_denoised
|
||||
count[:, :, h_start:h_end, w_start:w_end] = 1
|
||||
else:
|
||||
|
|
|
@ -43,8 +43,9 @@ def run_txt2img_pipeline(
|
|||
highres: HighresParams,
|
||||
) -> None:
|
||||
# if using panorama, the pipeline will tile itself (views)
|
||||
if params.is_panorama() or params.is_xl():
|
||||
if params.is_panorama():
|
||||
tile_size = max(params.unet_tile, size.width, size.height)
|
||||
logger.debug("adjusting tile size for panorama to %s", tile_size)
|
||||
else:
|
||||
tile_size = params.unet_tile
|
||||
|
||||
|
|
|
@ -3,7 +3,7 @@ from copy import deepcopy
|
|||
from logging import getLogger
|
||||
from math import ceil
|
||||
from re import Pattern, compile
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
from typing import Dict, List, Literal, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
@ -21,6 +21,7 @@ CLIP_TOKEN = compile(r"\<clip:([-\w]+):(\d+)\>")
|
|||
INVERSION_TOKEN = compile(r"\<inversion:([^:\>]+):(-?[\.|\d]+)\>")
|
||||
LORA_TOKEN = compile(r"\<lora:([^:\>]+):(-?[\.|\d]+)\>")
|
||||
WILDCARD_TOKEN = compile(r"__([-/\\\w]+)__")
|
||||
# Region token: <region:top:left:bottom:right:mode:prompt>, where mode is
# "add" or "replace". The prompt group needs a `+` quantifier; without it the
# pattern only matches region prompts that are exactly one character long.
REGION_TOKEN = compile(r"\<region:(\d+):(\d+):(\d+):(\d+):(add|replace):([^\>]+)\>")
|
||||
|
||||
INTERVAL_RANGE = compile(r"(\w+)-{(\d+),(\d+)(?:,(\d+))?}")
|
||||
ALTERNATIVE_RANGE = compile(r"\(([^\)]+)\)")
|
||||
|
@ -446,7 +447,8 @@ def slice_prompt(prompt: str, slice: int) -> str:
|
|||
return prompt
|
||||
|
||||
|
||||
Region = Tuple[int, int, int, int, bool, str]
|
||||
Region = Tuple[int, int, int, int, Literal["add", "replace"], str]
|
||||
|
||||
|
||||
def parse_regions(prompt: str) -> List[Region]:
|
||||
return []
|
||||
return get_tokens_from_prompt(prompt, REGION_TOKEN)
|
||||
|
|
|
@ -2,16 +2,8 @@ import sys
|
|||
from functools import partial
|
||||
from logging import getLogger
|
||||
from os import path
|
||||
from pathlib import Path
|
||||
from typing import Dict, Optional, Union
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from optimum.onnxruntime.modeling_diffusion import (
|
||||
ORTModel,
|
||||
ORTStableDiffusionPipelineBase,
|
||||
)
|
||||
|
||||
from ..torch_before_ort import SessionOptions
|
||||
from ..utils import run_gc
|
||||
from .context import ServerContext
|
||||
|
||||
|
|
Loading…
Reference in New Issue