1
0
Fork 0

apply lint

This commit is contained in:
Sean Sube 2023-11-08 22:04:15 -06:00
parent f4f3bda6f8
commit 4a2498ad8d
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
5 changed files with 52 additions and 11 deletions

View File

@ -101,7 +101,7 @@ def make_tile_mask(
# sort gradient points
p1 = adj_tile
p2 = (tile - adj_tile)
p2 = tile - adj_tile
points = [0, min(p1, p2), max(p1, p2), tile]
# build gradients

View File

@ -36,7 +36,14 @@ from ...diffusers.pipelines.upscale import OnnxStableDiffusionUpscalePipeline
from ...diffusers.version_safe_diffusers import AttnProcessor
from ...models.cnet import UNet2DConditionModel_CNet
from ...utils import run_gc
from ..utils import RESOLVE_FORMATS, ConversionContext, check_ext, is_torch_2_0, load_tensor, onnx_export
from ..utils import (
RESOLVE_FORMATS,
ConversionContext,
check_ext,
is_torch_2_0,
load_tensor,
onnx_export,
)
from .checkpoint import convert_extract_checkpoint
logger = getLogger(__name__)

View File

@ -619,7 +619,14 @@ class OnnxStableDiffusionPanoramaPipeline(DiffusionPipeline):
# get the latents corresponding to the current view coordinates
latents_for_region = latents[:, :, h_start:h_end, w_start:w_end]
logger.trace("region latent shape: [:,:,%s:%s,%s:%s] -> %s", h_start, h_end, w_start, w_end, latents_for_region.shape)
logger.trace(
"region latent shape: [:,:,%s:%s,%s:%s] -> %s",
h_start,
h_end,
w_start,
w_end,
latents_for_region.shape,
)
# expand the latents if we are doing classifier free guidance
latent_region_input = (
@ -660,14 +667,19 @@ class OnnxStableDiffusionPanoramaPipeline(DiffusionPipeline):
latents_region_denoised = scheduler_output.prev_sample.numpy()
if feather > 0.0:
mask = make_tile_mask((h_end - h_start, w_end - w_start), self.window, feather)
mask = make_tile_mask(
(h_end - h_start, w_end - w_start), self.window, feather
)
mask = np.expand_dims(mask, axis=0)
mask = np.repeat(mask, 4, axis=0)
mask = np.expand_dims(mask, axis=0)
else:
mask = 1
if weight >= 10.0:
value[:, :, h_start:h_end, w_start:w_end] = latents_region_denoised * mask
value[:, :, h_start:h_end, w_start:w_end] = (
latents_region_denoised * mask
)
count[:, :, h_start:h_end, w_start:w_end] = mask
else:
value[:, :, h_start:h_end, w_start:w_end] += (

View File

@ -464,7 +464,14 @@ class StableDiffusionXLPanoramaPipelineMixin(StableDiffusionXLImg2ImgPipelineMix
# get the latents corresponding to the current view coordinates
latents_for_region = latents[:, :, h_start:h_end, w_start:w_end]
logger.trace("region latent shape: [:,:,%s:%s,%s:%s] -> %s", h_start, h_end, w_start, w_end, latents_for_region.shape)
logger.trace(
"region latent shape: [:,:,%s:%s,%s:%s] -> %s",
h_start,
h_end,
w_start,
w_end,
latents_for_region.shape,
)
# expand the latents if we are doing classifier free guidance
latent_region_input = (
@ -514,14 +521,19 @@ class StableDiffusionXLPanoramaPipelineMixin(StableDiffusionXLImg2ImgPipelineMix
latents_region_denoised = scheduler_output.prev_sample.numpy()
if feather > 0.0:
mask = make_tile_mask((h_end - h_start, w_end - w_start), self.window, feather)
mask = make_tile_mask(
(h_end - h_start, w_end - w_start), self.window, feather
)
mask = np.expand_dims(mask, axis=0)
mask = np.repeat(mask, 4, axis=0)
mask = np.expand_dims(mask, axis=0)
else:
mask = 1
if weight >= 10.0:
value[:, :, h_start:h_end, w_start:w_end] = latents_region_denoised * mask
value[:, :, h_start:h_end, w_start:w_end] = (
latents_region_denoised * mask
)
count[:, :, h_start:h_end, w_start:w_end] = mask
else:
value[:, :, h_start:h_end, w_start:w_end] += (

View File

@ -3,7 +3,7 @@ from copy import deepcopy
from logging import getLogger
from math import ceil
from re import Pattern, compile
from typing import Dict, List, Literal, Optional, Tuple
from typing import Dict, List, Optional, Tuple
import numpy as np
import torch
@ -21,7 +21,9 @@ CLIP_TOKEN = compile(r"\<clip:([-\w]+):(\d+)\>")
INVERSION_TOKEN = compile(r"\<inversion:([^:\>]+):(-?[\.|\d]+)\>")
LORA_TOKEN = compile(r"\<lora:([^:\>]+):(-?[\.|\d]+)\>")
WILDCARD_TOKEN = compile(r"__([-/\\\w]+)__")
REGION_TOKEN = compile(r"\<region:(\d+):(\d+):(\d+):(\d+):(\d+):(-?[\.|\d]+):([^\>]+)\>")
REGION_TOKEN = compile(
r"\<region:(\d+):(\d+):(\d+):(\d+):(\d+):(-?[\.|\d]+):([^\>]+)\>"
)
INTERVAL_RANGE = compile(r"(\w+)-{(\d+),(\d+)(?:,(\d+))?}")
ALTERNATIVE_RANGE = compile(r"\(([^\)]+)\)")
@ -460,7 +462,15 @@ Region = Tuple[int, int, int, int, float, str]
def parse_region_group(group) -> Region:
top, left, bottom, right, weight, feather, prompt = group
return (int(top), int(left), int(bottom), int(right), float(weight), float(feather), prompt)
return (
int(top),
int(left),
int(bottom),
int(right),
float(weight),
float(feather),
prompt,
)
def parse_regions(prompt: str) -> Tuple[str, List[Region]]: