1
0
Fork 0

apply lint

This commit is contained in:
Sean Sube 2023-12-29 08:19:58 -06:00
parent ce90ffb0ee
commit 1035915d36
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
4 changed files with 19 additions and 9 deletions

View File

@ -32,7 +32,10 @@ class BlendDenoiseLocalStdStage(BaseStage):
logger.info("denoising source images") logger.info("denoising source images")
return StageResult.from_arrays( return StageResult.from_arrays(
[remove_noise(source, threshold=strength)[0] for source in sources.as_numpy()] [
remove_noise(source, threshold=strength)[0]
for source in sources.as_numpy()
]
) )
@ -50,7 +53,7 @@ def downscale_image(image: np.ndarray, scale: int = 2):
return result_image return result_image
def replace_noise(region: np.ndarray, threshold: int, deviation: float, op = np.median): def replace_noise(region: np.ndarray, threshold: int, deviation: float, op=np.median):
# Identify stray pixels (brightness significantly deviates from surrounding pixels) # Identify stray pixels (brightness significantly deviates from surrounding pixels)
central_pixel = np.mean(region[2:4, 2:4]) central_pixel = np.mean(region[2:4, 2:4])
@ -59,7 +62,9 @@ def replace_noise(region: np.ndarray, threshold: int, deviation: float, op = np.
diff = np.abs(central_pixel - region_normal) diff = np.abs(central_pixel - region_normal)
# If the whole region is fairly consistent but the central pixel deviates significantly, # If the whole region is fairly consistent but the central pixel deviates significantly,
if diff > (region_deviation + threshold) and diff < (region_deviation + threshold * deviation): if diff > (region_deviation + threshold) and diff < (
region_deviation + threshold * deviation
):
surrounding_pixels = region[region != central_pixel] surrounding_pixels = region[region != central_pixel]
surrounding_median = op(surrounding_pixels) surrounding_median = op(surrounding_pixels)
# replace it with the median of surrounding pixels # replace it with the median of surrounding pixels
@ -69,7 +74,12 @@ def replace_noise(region: np.ndarray, threshold: int, deviation: float, op = np.
return False return False
def remove_noise(image: np.ndarray, threshold: int, deviation: float, region_size: Tuple[int, int] = (6, 6)): def remove_noise(
image: np.ndarray,
threshold: int,
deviation: float,
region_size: Tuple[int, int] = (6, 6),
):
# Create a copy of the original image to store the result # Create a copy of the original image to store the result
result_image = np.copy(image) result_image = np.copy(image)
result_mask = np.zeros_like(image) result_mask = np.zeros_like(image)
@ -87,7 +97,7 @@ def remove_noise(image: np.ndarray, threshold: int, deviation: float, region_siz
# print(i_min, i_max, j_min, j_max) # print(i_min, i_max, j_min, j_max)
# skip if the central pixels have already been masked by a previous artifact # skip if the central pixels have already been masked by a previous artifact
if np.any(result_mask[i - 1:i + 1, j - 1:j + 1] > 0): if np.any(result_mask[i - 1 : i + 1, j - 1 : j + 1] > 0):
pass pass
# Extract region from each channel # Extract region from each channel

View File

@ -227,7 +227,7 @@ def load_pipeline(
vae_encoder_session=components.get("vae_encoder_session", None), vae_encoder_session=components.get("vae_encoder_session", None),
text_encoder_2_session=components.get("text_encoder_2_session", None), text_encoder_2_session=components.get("text_encoder_2_session", None),
tokenizer_2=components.get("tokenizer_2", None), tokenizer_2=components.get("tokenizer_2", None),
add_watermarker=False, # not so invisible: https://github.com/ssube/onnx-web/issues/438 add_watermarker=False, # not so invisible: https://github.com/ssube/onnx-web/issues/438
) )
else: else:
if "controlnet" in components: if "controlnet" in components:

View File

@ -5,9 +5,9 @@ import numpy as np
import torch import torch
from diffusers import OnnxRuntimeModel from diffusers import OnnxRuntimeModel
from diffusers.pipelines.onnx_utils import ORT_TO_NP_TYPE from diffusers.pipelines.onnx_utils import ORT_TO_NP_TYPE
from ..version_safe_diffusers import AutoencoderKLOutput, DecoderOutput
from ...server import ServerContext from ...server import ServerContext
from ..version_safe_diffusers import AutoencoderKLOutput, DecoderOutput
logger = getLogger(__name__) logger = getLogger(__name__)

View File

@ -37,8 +37,8 @@ else:
if is_diffusers_0_24: if is_diffusers_0_24:
from diffusers.models.modeling_outputs import AutoencoderKLOutput
from diffusers.models.autoencoders.vae import DecoderOutput from diffusers.models.autoencoders.vae import DecoderOutput
from diffusers.models.modeling_outputs import AutoencoderKLOutput
else: else:
from diffusers.models.autoencoder_kl import AutoencoderKLOutput from diffusers.models.autoencoder_kl import AutoencoderKLOutput
from diffusers.models.vae import DecoderOutput from diffusers.models.vae import DecoderOutput