basic region prompt parsing
This commit is contained in:
parent
44851e3785
commit
8ba9f3c0b7
|
@ -267,12 +267,20 @@ def load_pipeline(
|
||||||
# update panorama params
|
# update panorama params
|
||||||
if params.is_panorama():
|
if params.is_panorama():
|
||||||
unet_stride = (params.unet_tile * (1 - params.unet_overlap)) // 8
|
unet_stride = (params.unet_tile * (1 - params.unet_overlap)) // 8
|
||||||
logger.debug("setting panorama window parameters: %s/%s for UNet, %s/%s for VAE", params.unet_tile, unet_stride, params.vae_tile, params.vae_overlap)
|
logger.debug(
|
||||||
|
"setting panorama window parameters: %s/%s for UNet, %s/%s for VAE",
|
||||||
|
params.unet_tile,
|
||||||
|
unet_stride,
|
||||||
|
params.vae_tile,
|
||||||
|
params.vae_overlap,
|
||||||
|
)
|
||||||
pipe.set_window_size(params.unet_tile // 8, unet_stride)
|
pipe.set_window_size(params.unet_tile // 8, unet_stride)
|
||||||
|
|
||||||
for vae in VAE_COMPONENTS:
|
for vae in VAE_COMPONENTS:
|
||||||
if hasattr(pipe, vae):
|
if hasattr(pipe, vae):
|
||||||
getattr(pipe, vae).set_window_size(params.vae_tile // 8, params.vae_overlap)
|
getattr(pipe, vae).set_window_size(
|
||||||
|
params.vae_tile // 8, params.vae_overlap
|
||||||
|
)
|
||||||
|
|
||||||
run_gc([device])
|
run_gc([device])
|
||||||
|
|
||||||
|
|
|
@ -303,7 +303,14 @@ class StableDiffusionXLPanoramaPipelineMixin(StableDiffusionXLImg2ImgPipelineMix
|
||||||
|
|
||||||
# 3.b. Encode region prompts
|
# 3.b. Encode region prompts
|
||||||
regions = parse_regions(prompt)
|
regions = parse_regions(prompt)
|
||||||
region_embeds: List[Tuple[List[np.ndarray], Optional[np.ndarray], Optional[np.ndarray], Optional[np.ndarray]]] = []
|
region_embeds: List[
|
||||||
|
Tuple[
|
||||||
|
List[np.ndarray],
|
||||||
|
Optional[np.ndarray],
|
||||||
|
Optional[np.ndarray],
|
||||||
|
Optional[np.ndarray],
|
||||||
|
]
|
||||||
|
] = []
|
||||||
add_region_embeds: List[np.ndarray] = []
|
add_region_embeds: List[np.ndarray] = []
|
||||||
|
|
||||||
for _top, _left, _bottom, _right, _mode, region_prompt in regions:
|
for _top, _left, _bottom, _right, _mode, region_prompt in regions:
|
||||||
|
@ -322,9 +329,11 @@ class StableDiffusionXLPanoramaPipelineMixin(StableDiffusionXLImg2ImgPipelineMix
|
||||||
current_region_embeds[0] = np.concatenate(
|
current_region_embeds[0] = np.concatenate(
|
||||||
(current_region_embeds[1], current_region_embeds[0]), axis=0
|
(current_region_embeds[1], current_region_embeds[0]), axis=0
|
||||||
)
|
)
|
||||||
add_region_embeds.append(np.concatenate(
|
add_region_embeds.append(
|
||||||
|
np.concatenate(
|
||||||
(current_region_embeds[3], current_region_embeds[2]), axis=0
|
(current_region_embeds[3], current_region_embeds[2]), axis=0
|
||||||
))
|
)
|
||||||
|
)
|
||||||
|
|
||||||
region_embeds.append(current_region_embeds)
|
region_embeds.append(current_region_embeds)
|
||||||
|
|
||||||
|
@ -492,7 +501,7 @@ class StableDiffusionXLPanoramaPipelineMixin(StableDiffusionXLImg2ImgPipelineMix
|
||||||
)
|
)
|
||||||
latents_view_denoised = scheduler_output.prev_sample.numpy()
|
latents_view_denoised = scheduler_output.prev_sample.numpy()
|
||||||
|
|
||||||
if mode:
|
if mode == "replace":
|
||||||
value[:, :, h_start:h_end, w_start:w_end] = latents_view_denoised
|
value[:, :, h_start:h_end, w_start:w_end] = latents_view_denoised
|
||||||
count[:, :, h_start:h_end, w_start:w_end] = 1
|
count[:, :, h_start:h_end, w_start:w_end] = 1
|
||||||
else:
|
else:
|
||||||
|
|
|
@ -43,8 +43,9 @@ def run_txt2img_pipeline(
|
||||||
highres: HighresParams,
|
highres: HighresParams,
|
||||||
) -> None:
|
) -> None:
|
||||||
# if using panorama, the pipeline will tile itself (views)
|
# if using panorama, the pipeline will tile itself (views)
|
||||||
if params.is_panorama() or params.is_xl():
|
if params.is_panorama():
|
||||||
tile_size = max(params.unet_tile, size.width, size.height)
|
tile_size = max(params.unet_tile, size.width, size.height)
|
||||||
|
logger.debug("adjusting tile size for panorama to %s", tile_size)
|
||||||
else:
|
else:
|
||||||
tile_size = params.unet_tile
|
tile_size = params.unet_tile
|
||||||
|
|
||||||
|
|
|
@ -3,7 +3,7 @@ from copy import deepcopy
|
||||||
from logging import getLogger
|
from logging import getLogger
|
||||||
from math import ceil
|
from math import ceil
|
||||||
from re import Pattern, compile
|
from re import Pattern, compile
|
||||||
from typing import Dict, List, Optional, Tuple
|
from typing import Dict, List, Literal, Optional, Tuple
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
@ -21,6 +21,7 @@ CLIP_TOKEN = compile(r"\<clip:([-\w]+):(\d+)\>")
|
||||||
INVERSION_TOKEN = compile(r"\<inversion:([^:\>]+):(-?[\.|\d]+)\>")
|
INVERSION_TOKEN = compile(r"\<inversion:([^:\>]+):(-?[\.|\d]+)\>")
|
||||||
LORA_TOKEN = compile(r"\<lora:([^:\>]+):(-?[\.|\d]+)\>")
|
LORA_TOKEN = compile(r"\<lora:([^:\>]+):(-?[\.|\d]+)\>")
|
||||||
WILDCARD_TOKEN = compile(r"__([-/\\\w]+)__")
|
WILDCARD_TOKEN = compile(r"__([-/\\\w]+)__")
|
||||||
|
REGION_TOKEN = compile(r"\<region:(\d+):(\d+):(\d+):(\d+):(add|replace):([^\>])\>")
|
||||||
|
|
||||||
INTERVAL_RANGE = compile(r"(\w+)-{(\d+),(\d+)(?:,(\d+))?}")
|
INTERVAL_RANGE = compile(r"(\w+)-{(\d+),(\d+)(?:,(\d+))?}")
|
||||||
ALTERNATIVE_RANGE = compile(r"\(([^\)]+)\)")
|
ALTERNATIVE_RANGE = compile(r"\(([^\)]+)\)")
|
||||||
|
@ -446,7 +447,8 @@ def slice_prompt(prompt: str, slice: int) -> str:
|
||||||
return prompt
|
return prompt
|
||||||
|
|
||||||
|
|
||||||
Region = Tuple[int, int, int, int, bool, str]
|
Region = Tuple[int, int, int, int, Literal["add", "replace"], str]
|
||||||
|
|
||||||
|
|
||||||
def parse_regions(prompt: str) -> List[Region]:
|
def parse_regions(prompt: str) -> List[Region]:
|
||||||
return []
|
return get_tokens_from_prompt(prompt, REGION_TOKEN)
|
||||||
|
|
|
@ -2,16 +2,8 @@ import sys
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from logging import getLogger
|
from logging import getLogger
|
||||||
from os import path
|
from os import path
|
||||||
from pathlib import Path
|
|
||||||
from typing import Dict, Optional, Union
|
|
||||||
from urllib.parse import urlparse
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
from optimum.onnxruntime.modeling_diffusion import (
|
|
||||||
ORTModel,
|
|
||||||
ORTStableDiffusionPipelineBase,
|
|
||||||
)
|
|
||||||
|
|
||||||
from ..torch_before_ort import SessionOptions
|
|
||||||
from ..utils import run_gc
|
from ..utils import run_gc
|
||||||
from .context import ServerContext
|
from .context import ServerContext
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue