From 4a2498ad8d6dd1c90296271a1a326e0377ceeb8f Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Wed, 8 Nov 2023 22:04:15 -0600 Subject: [PATCH] apply lint --- api/onnx_web/chain/tile.py | 2 +- api/onnx_web/convert/diffusion/diffusion.py | 9 ++++++++- api/onnx_web/diffusers/pipelines/panorama.py | 18 +++++++++++++++--- .../diffusers/pipelines/panorama_xl.py | 18 +++++++++++++++--- api/onnx_web/diffusers/utils.py | 16 +++++++++++++--- 5 files changed, 52 insertions(+), 11 deletions(-) diff --git a/api/onnx_web/chain/tile.py b/api/onnx_web/chain/tile.py index 1e9fa38d..efbdc771 100644 --- a/api/onnx_web/chain/tile.py +++ b/api/onnx_web/chain/tile.py @@ -101,7 +101,7 @@ def make_tile_mask( # sort gradient points p1 = adj_tile - p2 = (tile - adj_tile) + p2 = tile - adj_tile points = [0, min(p1, p2), max(p1, p2), tile] # build gradients diff --git a/api/onnx_web/convert/diffusion/diffusion.py b/api/onnx_web/convert/diffusion/diffusion.py index cee41d52..45762ffe 100644 --- a/api/onnx_web/convert/diffusion/diffusion.py +++ b/api/onnx_web/convert/diffusion/diffusion.py @@ -36,7 +36,14 @@ from ...diffusers.pipelines.upscale import OnnxStableDiffusionUpscalePipeline from ...diffusers.version_safe_diffusers import AttnProcessor from ...models.cnet import UNet2DConditionModel_CNet from ...utils import run_gc -from ..utils import RESOLVE_FORMATS, ConversionContext, check_ext, is_torch_2_0, load_tensor, onnx_export +from ..utils import ( + RESOLVE_FORMATS, + ConversionContext, + check_ext, + is_torch_2_0, + load_tensor, + onnx_export, +) from .checkpoint import convert_extract_checkpoint logger = getLogger(__name__) diff --git a/api/onnx_web/diffusers/pipelines/panorama.py b/api/onnx_web/diffusers/pipelines/panorama.py index afb05b76..98da1ac8 100644 --- a/api/onnx_web/diffusers/pipelines/panorama.py +++ b/api/onnx_web/diffusers/pipelines/panorama.py @@ -619,7 +619,14 @@ class OnnxStableDiffusionPanoramaPipeline(DiffusionPipeline): # get the latents corresponding to the current view coordinates latents_for_region = latents[:, :, h_start:h_end, w_start:w_end] - logger.trace("region latent shape: [:,:,%s:%s,%s:%s] -> %s", h_start, h_end, w_start, w_end, latents_for_region.shape) + logger.trace( + "region latent shape: [:,:,%s:%s,%s:%s] -> %s", + h_start, + h_end, + w_start, + w_end, + latents_for_region.shape, + ) # expand the latents if we are doing classifier free guidance latent_region_input = ( @@ -660,14 +667,19 @@ class OnnxStableDiffusionPanoramaPipeline(DiffusionPipeline): latents_region_denoised = scheduler_output.prev_sample.numpy() if feather > 0.0: - mask = make_tile_mask((h_end - h_start, w_end - w_start), self.window, feather) + mask = make_tile_mask( + (h_end - h_start, w_end - w_start), self.window, feather + ) + mask = np.expand_dims(mask, axis=0) mask = np.repeat(mask, 4, axis=0) mask = np.expand_dims(mask, axis=0) else: mask = 1 if weight >= 10.0: - value[:, :, h_start:h_end, w_start:w_end] = latents_region_denoised * mask + value[:, :, h_start:h_end, w_start:w_end] = ( + latents_region_denoised * mask + ) count[:, :, h_start:h_end, w_start:w_end] = mask else: value[:, :, h_start:h_end, w_start:w_end] += ( diff --git a/api/onnx_web/diffusers/pipelines/panorama_xl.py b/api/onnx_web/diffusers/pipelines/panorama_xl.py index ee6d0413..f5efd2a1 100644 --- a/api/onnx_web/diffusers/pipelines/panorama_xl.py +++ b/api/onnx_web/diffusers/pipelines/panorama_xl.py @@ -464,7 +464,14 @@ class StableDiffusionXLPanoramaPipelineMixin(StableDiffusionXLImg2ImgPipelineMix # get the latents corresponding to the current view coordinates latents_for_region = latents[:, :, h_start:h_end, w_start:w_end] - logger.trace("region latent shape: [:,:,%s:%s,%s:%s] -> %s", h_start, h_end, w_start, w_end, latents_for_region.shape) + logger.trace( + "region latent shape: [:,:,%s:%s,%s:%s] -> %s", + h_start, + h_end, + w_start, + w_end, + latents_for_region.shape, + ) # expand the latents if we are doing classifier free guidance latent_region_input = ( @@ -514,14 +521,19 @@ class StableDiffusionXLPanoramaPipelineMixin(StableDiffusionXLImg2ImgPipelineMix latents_region_denoised = scheduler_output.prev_sample.numpy() if feather > 0.0: - mask = make_tile_mask((h_end - h_start, w_end - w_start), self.window, feather) + mask = make_tile_mask( + (h_end - h_start, w_end - w_start), self.window, feather + ) + mask = np.expand_dims(mask, axis=0) mask = np.repeat(mask, 4, axis=0) mask = np.expand_dims(mask, axis=0) else: mask = 1 if weight >= 10.0: - value[:, :, h_start:h_end, w_start:w_end] = latents_region_denoised * mask + value[:, :, h_start:h_end, w_start:w_end] = ( + latents_region_denoised * mask + ) count[:, :, h_start:h_end, w_start:w_end] = mask else: value[:, :, h_start:h_end, w_start:w_end] += ( diff --git a/api/onnx_web/diffusers/utils.py b/api/onnx_web/diffusers/utils.py index fa4b1dcd..8d16662a 100644 --- a/api/onnx_web/diffusers/utils.py +++ b/api/onnx_web/diffusers/utils.py @@ -3,7 +3,7 @@ from copy import deepcopy from logging import getLogger from math import ceil from re import Pattern, compile -from typing import Dict, List, Literal, Optional, Tuple +from typing import Dict, List, Optional, Tuple import numpy as np import torch @@ -21,7 +21,9 @@ CLIP_TOKEN = compile(r"\") INVERSION_TOKEN = compile(r"\]+):(-?[\.|\d]+)\>") LORA_TOKEN = compile(r"\]+):(-?[\.|\d]+)\>") WILDCARD_TOKEN = compile(r"__([-/\\\w]+)__") -REGION_TOKEN = compile(r"\]+)\>") +REGION_TOKEN = compile( + r"\]+)\>" +) INTERVAL_RANGE = compile(r"(\w+)-{(\d+),(\d+)(?:,(\d+))?}") ALTERNATIVE_RANGE = compile(r"\(([^\)]+)\)") @@ -460,7 +462,15 @@ Region = Tuple[int, int, int, int, float, str] def parse_region_group(group) -> Region: top, left, bottom, right, weight, feather, prompt = group - return (int(top), int(left), int(bottom), int(right), float(weight), float(feather), prompt) + return ( + int(top), + int(left), + int(bottom), + int(right), + float(weight), + float(feather), + prompt, + ) def parse_regions(prompt: str) -> Tuple[str, List[Region]]: