From f28fdda47a7a318c109af6b9a7ae1df4a7fcfeb3 Mon Sep 17 00:00:00 2001
From: Sean Sube
Date: Wed, 27 Dec 2023 20:17:35 -0600
Subject: [PATCH] feat(api): add stage for local standard deviation denoising for XL

---
 ...enoise.py => blend_denoise_fastnlmeans.py} |   2 +-
 api/onnx_web/chain/blend_denoise_localstd.py  | 124 ++++++++++++++++++
 api/onnx_web/chain/stages.py                  |   7 +-
 api/onnx_web/diffusers/run.py                 |   6 +-
 4 files changed, 133 insertions(+), 6 deletions(-)
 rename api/onnx_web/chain/{blend_denoise.py => blend_denoise_fastnlmeans.py} (95%)
 create mode 100644 api/onnx_web/chain/blend_denoise_localstd.py

diff --git a/api/onnx_web/chain/blend_denoise.py b/api/onnx_web/chain/blend_denoise_fastnlmeans.py
similarity index 95%
rename from api/onnx_web/chain/blend_denoise.py
rename to api/onnx_web/chain/blend_denoise_fastnlmeans.py
index e40a30a2..c3de00e2 100644
--- a/api/onnx_web/chain/blend_denoise.py
+++ b/api/onnx_web/chain/blend_denoise_fastnlmeans.py
@@ -13,7 +13,7 @@ from .result import StageResult
 logger = getLogger(__name__)
 
 
-class BlendDenoiseStage(BaseStage):
+class BlendDenoiseFastNLMeansStage(BaseStage):
     max_tile = SizeChart.max
 
     def run(
diff --git a/api/onnx_web/chain/blend_denoise_localstd.py b/api/onnx_web/chain/blend_denoise_localstd.py
new file mode 100644
index 00000000..a98ae22c
--- /dev/null
+++ b/api/onnx_web/chain/blend_denoise_localstd.py
@@ -0,0 +1,124 @@
+from logging import getLogger
+from typing import Optional
+
+import numpy as np
+from PIL import Image
+
+from ..params import ImageParams, SizeChart, StageParams
+from ..server import ServerContext
+from ..worker import ProgressCallback, WorkerContext
+from .base import BaseStage
+from .result import StageResult
+
+logger = getLogger(__name__)
+
+
+class BlendDenoiseLocalStdStage(BaseStage):
+    """Chain stage that removes stray pixels using local deviation statistics."""
+
+    max_tile = SizeChart.max
+
+    def run(
+        self,
+        _worker: WorkerContext,
+        _server: ServerContext,
+        _stage: StageParams,
+        _params: ImageParams,
+        sources: StageResult,
+        *,
+        strength: int = 3,
+        stage_source: Optional[Image.Image] = None,
+        callback: Optional[ProgressCallback] = None,
+        **kwargs,
+    ) -> StageResult:
+        """Run remove_noise over every source image with its default settings.
+
+        The strength, stage_source, and callback arguments are accepted but are
+        not used by this stage.
+        """
+        logger.info("denoising source images")
+
+        results = [remove_noise(source) for source in sources.as_numpy()]
+
+        return StageResult(arrays=results)
+
+
+def downscale_image(image):
+    """Halve a 2D array by averaging each 2x2 block, truncating to uint8.
+
+    A trailing odd row or column is dropped.
+    """
+    rows = image.shape[0] // 2
+    cols = image.shape[1] // 2
+
+    # Split both axes into (block, offset) pairs so that one vectorized mean
+    # replaces a per-pixel Python loop.
+    blocks = image[: rows * 2, : cols * 2].reshape(rows, 2, cols, 2)
+    return blocks.mean(axis=(1, 3)).astype(np.uint8)
+
+
+def replace_noise(region, threshold):
+    """Replace the central pixel of region in place when it is a stray pixel.
+
+    The central pixel is stray when its distance from the region median exceeds
+    both the region standard deviation and threshold. Returns True when the
+    pixel was replaced, False otherwise.
+    """
+    # Use the true center rather than a fixed [1, 1], which is only the center
+    # of a 3x3 region.
+    center = (region.shape[0] // 2, region.shape[1] // 2)
+    central_pixel = region[center]
+
+    region_median = np.median(region)
+    region_deviation = np.std(region)
+    diff = np.abs(central_pixel - region_median)
+
+    # If the whole region is fairly consistent but the central pixel deviates significantly,
+    if diff > region_deviation and diff > threshold:
+        surrounding_pixels = region[region != central_pixel]
+        surrounding_median = np.median(surrounding_pixels)
+        # replace it with the median of surrounding pixels
+        region[center] = surrounding_median
+        return True
+
+    return False
+
+
+def remove_noise(image, region_size=(6, 6), threshold=10):
+    """Return a copy of an RGB image array with stray pixels smoothed out.
+
+    Each window of roughly region_size pixels is downscaled by two per channel
+    and its central pixel is tested with replace_noise. When any channel is
+    replaced, the matching 2x2 block of the copy is overwritten in the first
+    three channels.
+    """
+    # Create a copy of the original image to store the result
+    result_image = np.copy(image)
+
+    i_inc = region_size[0] // 2
+    j_inc = region_size[1] // 2
+
+    # Offset of the 2x2 source block behind the central downscaled pixel,
+    # relative to the window center; this is -1 for the default 6x6 region.
+    i_off = 2 * (i_inc // 2) - i_inc
+    j_off = 2 * (j_inc // 2) - j_inc
+
+    for i in range(i_inc, image.shape[0] - i_inc, 1):
+        for j in range(j_inc, image.shape[1] - j_inc, 1):
+            window = image[i - i_inc : i + i_inc, j - j_inc : j + j_inc]
+
+            # Extract region from each channel
+            regions = [downscale_image(window[:, :, c]) for c in range(3)]
+
+            # Build a list so that all three channels are checked: replace_noise
+            # mutates its region, so any() must not short-circuit.
+            replaced = any([replace_noise(region, threshold) for region in regions])
+
+            if replaced:
+                # Assign the processed center back to the result image
+                i_dst = slice(i + i_off, i + i_off + 2)
+                j_dst = slice(j + j_off, j + j_off + 2)
+                for c, region in enumerate(regions):
+                    result_image[i_dst, j_dst, c] = region[i_inc // 2, j_inc // 2]
+
+    return result_image
diff --git a/api/onnx_web/chain/stages.py b/api/onnx_web/chain/stages.py
index 4ae14346..0b3e6359 100644
--- a/api/onnx_web/chain/stages.py
+++ b/api/onnx_web/chain/stages.py
@@ -1,7 +1,8 @@
 from logging import getLogger
 
 from .base import BaseStage
-from .blend_denoise import BlendDenoiseStage
+from .blend_denoise_fastnlmeans import BlendDenoiseFastNLMeansStage
+from .blend_denoise_localstd import BlendDenoiseLocalStdStage
 from .blend_grid import BlendGridStage
 from .blend_img2img import BlendImg2ImgStage
 from .blend_linear import BlendLinearStage
@@ -27,7 +28,9 @@ from .upscale_swinir import UpscaleSwinIRStage
 logger = getLogger(__name__)
 
 CHAIN_STAGES = {
-    "blend-denoise": BlendDenoiseStage,
+    "blend-denoise": BlendDenoiseFastNLMeansStage,
+    "blend-denoise-fastnlmeans": BlendDenoiseFastNLMeansStage,
+    "blend-denoise-localstd": BlendDenoiseLocalStdStage,
     "blend-img2img": BlendImg2ImgStage,
     "blend-inpaint": UpscaleOutpaintStage,
     "blend-grid": BlendGridStage,
diff --git a/api/onnx_web/diffusers/run.py b/api/onnx_web/diffusers/run.py
index 5bbf65a1..48269fd4 100644
--- a/api/onnx_web/diffusers/run.py
+++ b/api/onnx_web/diffusers/run.py
@@ -5,7 +5,7 @@ from typing import Any, List, Optional
 from PIL import Image, ImageOps
 
 from ..chain import (
-    BlendDenoiseStage,
+    BlendDenoiseLocalStdStage,
     BlendImg2ImgStage,
     BlendMaskStage,
     ChainPipeline,
@@ -78,9 +78,9 @@ def run_txt2img_pipeline(
 
     # apply upscaling and correction, before highres
     highres_size = get_highres_tile(server, params, highres, tile_size)
-    if params.is_panorama():
+    if params.is_xl():
         chain.stage(
-            BlendDenoiseStage(),
+            BlendDenoiseLocalStdStage(),
             StageParams(tile_size=highres_size),
         )
 