1
0
Fork 0

feat(api): add stage for local standard deviation denoising for XL

This commit is contained in:
Sean Sube 2023-12-27 20:17:35 -06:00
parent 7d56689527
commit f28fdda47a
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
4 changed files with 122 additions and 6 deletions

View File

@ -13,7 +13,7 @@ from .result import StageResult
logger = getLogger(__name__) logger = getLogger(__name__)
class BlendDenoiseStage(BaseStage): class BlendDenoiseFastNLMeansStage(BaseStage):
max_tile = SizeChart.max max_tile = SizeChart.max
def run( def run(

View File

@ -0,0 +1,113 @@
from logging import getLogger
from typing import Optional
import numpy as np
from PIL import Image
from ..params import ImageParams, SizeChart, StageParams
from ..server import ServerContext
from ..worker import ProgressCallback, WorkerContext
from .base import BaseStage
from .result import StageResult
logger = getLogger(__name__)
class BlendDenoiseLocalStdStage(BaseStage):
max_tile = SizeChart.max
def run(
self,
_worker: WorkerContext,
_server: ServerContext,
_stage: StageParams,
_params: ImageParams,
sources: StageResult,
*,
strength: int = 3,
stage_source: Optional[Image.Image] = None,
callback: Optional[ProgressCallback] = None,
**kwargs,
) -> StageResult:
logger.info("denoising source images")
results = []
for source in sources.as_numpy():
results.append(remove_noise(source))
return StageResult(arrays=results)
def downscale_image(image):
result_image = np.zeros((image.shape[0] // 2, image.shape[1] // 2), dtype=np.uint8)
for i in range(0, image.shape[0] - 1, 2):
for j in range(0, image.shape[1] - 1, 2):
# Average the four neighboring pixels
pixel_average = np.mean(image[i : i + 2, j : j + 2], axis=(0, 1))
result_image[i // 2, j // 2] = pixel_average.astype(np.uint8)
return result_image
def replace_noise(region, threshold):
# Identify stray pixels (brightness significantly deviates from surrounding pixels)
central_pixel = region[1, 1]
region_median = np.median(region)
region_deviation = np.std(region)
diff = np.abs(central_pixel - region_median)
# If the whole region is fairly consistent but the central pixel deviates significantly,
if diff > region_deviation and diff > threshold:
surrounding_pixels = region[region != central_pixel]
surrounding_median = np.median(surrounding_pixels)
# replace it with the median of surrounding pixels
region[1, 1] = surrounding_median
return True
return False
def remove_noise(image, region_size=(6, 6), threshold=10):
# Assuming 'image' is a 3D numpy array representing the RGB image
# Create a copy of the original image to store the result
result_image = np.copy(image)
# result_mask = np.ones_like(image) * 255
# Iterate over regions in each channel
i_inc = region_size[0] // 2
j_inc = region_size[1] // 2
for i in range(i_inc, image.shape[0] - i_inc, 1):
for j in range(j_inc, image.shape[1] - j_inc, 1):
i_min = i - (region_size[0] // 2)
i_max = i + (region_size[0] // 2)
j_min = j - (region_size[1] // 2)
j_max = j + (region_size[1] // 2)
# Extract region from each channel
region_red = downscale_image(image[i_min:i_max, j_min:j_max, 0])
region_green = downscale_image(image[i_min:i_max, j_min:j_max, 1])
region_blue = downscale_image(image[i_min:i_max, j_min:j_max, 2])
replaced = any(
[
replace_noise(region_red, threshold),
replace_noise(region_green, threshold),
]
)
# Apply the noise removal function to each channel
if replaced:
# Assign the processed region back to the result image
result_image[i - 1 : i + 1, j - 1 : j + 1, 0] = region_red[1, 1]
result_image[i - 1 : i + 1, j - 1 : j + 1, 1] = region_green[1, 1]
result_image[i - 1 : i + 1, j - 1 : j + 1, 2] = region_blue[1, 1]
# result_mask[i-1:i+1, j-1:j+1, 0] = 0
# result_mask[i-1:i+1, j-1:j+1, 1] = 0
# result_mask[i-1:i+1, j-1:j+1, 2] = 0
return result_image # , result_mask)

View File

@ -1,7 +1,8 @@
from logging import getLogger from logging import getLogger
from .base import BaseStage from .base import BaseStage
from .blend_denoise import BlendDenoiseStage from .blend_denoise_fastnlmeans import BlendDenoiseFastNLMeansStage
from .blend_denoise_localstd import BlendDenoiseLocalStdStage
from .blend_grid import BlendGridStage from .blend_grid import BlendGridStage
from .blend_img2img import BlendImg2ImgStage from .blend_img2img import BlendImg2ImgStage
from .blend_linear import BlendLinearStage from .blend_linear import BlendLinearStage
@ -27,7 +28,9 @@ from .upscale_swinir import UpscaleSwinIRStage
logger = getLogger(__name__) logger = getLogger(__name__)
CHAIN_STAGES = { CHAIN_STAGES = {
"blend-denoise": BlendDenoiseStage, "blend-denoise": BlendDenoiseFastNLMeansStage,
"blend-denoise-fastnlmeans": BlendDenoiseFastNLMeansStage,
"blend-denoise-localstd": BlendDenoiseLocalStdStage,
"blend-img2img": BlendImg2ImgStage, "blend-img2img": BlendImg2ImgStage,
"blend-inpaint": UpscaleOutpaintStage, "blend-inpaint": UpscaleOutpaintStage,
"blend-grid": BlendGridStage, "blend-grid": BlendGridStage,

View File

@ -5,7 +5,7 @@ from typing import Any, List, Optional
from PIL import Image, ImageOps from PIL import Image, ImageOps
from ..chain import ( from ..chain import (
BlendDenoiseStage, BlendDenoiseLocalStdStage,
BlendImg2ImgStage, BlendImg2ImgStage,
BlendMaskStage, BlendMaskStage,
ChainPipeline, ChainPipeline,
@ -78,9 +78,9 @@ def run_txt2img_pipeline(
# apply upscaling and correction, before highres # apply upscaling and correction, before highres
highres_size = get_highres_tile(server, params, highres, tile_size) highres_size = get_highres_tile(server, params, highres, tile_size)
if params.is_panorama(): if params.is_xl():
chain.stage( chain.stage(
BlendDenoiseStage(), BlendDenoiseLocalStdStage(),
StageParams(tile_size=highres_size), StageParams(tile_size=highres_size),
) )