1
0
Fork 0

feat(api): add denoise stage, use before highres

This commit is contained in:
Sean Sube 2023-11-12 21:13:52 -06:00
parent 4460625309
commit 95e2d6d710
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
6 changed files with 54 additions and 5 deletions

View File

@ -1,4 +1,5 @@
from .base import ChainPipeline, PipelineStage, StageParams from .base import ChainPipeline, PipelineStage, StageParams
from .blend_denoise import BlendDenoiseStage
from .blend_img2img import BlendImg2ImgStage from .blend_img2img import BlendImg2ImgStage
from .blend_grid import BlendGridStage from .blend_grid import BlendGridStage
from .blend_linear import BlendLinearStage from .blend_linear import BlendLinearStage
@ -22,6 +23,7 @@ from .upscale_stable_diffusion import UpscaleStableDiffusionStage
from .upscale_swinir import UpscaleSwinIRStage from .upscale_swinir import UpscaleSwinIRStage
CHAIN_STAGES = { CHAIN_STAGES = {
"blend-denoise": BlendDenoiseStage,
"blend-img2img": BlendImg2ImgStage, "blend-img2img": BlendImg2ImgStage,
"blend-inpaint": UpscaleOutpaintStage, "blend-inpaint": UpscaleOutpaintStage,
"blend-grid": BlendGridStage, "blend-grid": BlendGridStage,

View File

@ -232,7 +232,10 @@ class ChainPipeline:
stage_sources = stage_outputs stage_sources = stage_outputs
else: else:
logger.debug("image does not contain sources and is within tile size of %s, running stage", tile) logger.debug(
"image does not contain sources and is within tile size of %s, running stage",
tile,
)
for i in range(worker.retries): for i in range(worker.retries):
try: try:
stage_outputs = stage_pipe.run( stage_outputs = stage_pipe.run(

View File

@ -0,0 +1,39 @@
from logging import getLogger
from typing import List, Optional
import cv2
import numpy as np
from PIL import Image
from ..params import ImageParams, SizeChart, StageParams
from ..server import ServerContext
from ..worker import ProgressCallback, WorkerContext
from .stage import BaseStage
logger = getLogger(__name__)
class BlendDenoiseStage(BaseStage):
max_tile = SizeChart.max
def run(
self,
_worker: WorkerContext,
_server: ServerContext,
_stage: StageParams,
_params: ImageParams,
sources: List[Image.Image],
*,
stage_source: Optional[Image.Image] = None,
callback: Optional[ProgressCallback] = None,
**kwargs,
) -> List[Image.Image]:
logger.info("denoising source images")
results = []
for source in sources:
data = cv2.cvtColor(np.array(source), cv2.COLOR_RGB2BGR)
data = cv2.fastNlMeansDenoisingColored(data)
results.append(Image.fromarray(cv2.cvtColor(data, cv2.COLOR_BGR2RGB)))
return results

View File

@ -702,12 +702,12 @@ class OnnxStableDiffusionPanoramaPipeline(DiffusionPipeline):
# take the MultiDiffusion step. Eq. 5 in MultiDiffusion paper: https://arxiv.org/abs/2302.08113 # take the MultiDiffusion step. Eq. 5 in MultiDiffusion paper: https://arxiv.org/abs/2302.08113
latents = np.where(count > 0, value / count, value) latents = np.where(count > 0, value / count, value)
latents = repair_nan(latents) latents = repair_nan(latents)
latents = np.clip(latents, -4, +4)
# call the callback, if provided # call the callback, if provided
if callback is not None and i % callback_steps == 0: if callback is not None and i % callback_steps == 0:
callback(i, t, latents) callback(i, t, latents)
latents = np.clip(latents, -4, +4)
latents = 1 / 0.18215 * latents latents = 1 / 0.18215 * latents
# image = self.vae_decoder(latent_sample=latents)[0] # image = self.vae_decoder(latent_sample=latents)[0]
# it seems likes there is a strange result for using half-precision vae decoder if batchsize>1 # it seems likes there is a strange result for using half-precision vae decoder if batchsize>1

View File

@ -551,7 +551,6 @@ class StableDiffusionXLPanoramaPipelineMixin(StableDiffusionXLImg2ImgPipelineMix
# take the MultiDiffusion step. Eq. 5 in MultiDiffusion paper: https://arxiv.org/abs/2302.08113 # take the MultiDiffusion step. Eq. 5 in MultiDiffusion paper: https://arxiv.org/abs/2302.08113
latents = np.where(count > 0, value / count, value) latents = np.where(count > 0, value / count, value)
latents = repair_nan(latents) latents = repair_nan(latents)
latents = np.clip(latents, -4, +4)
# call the callback, if provided # call the callback, if provided
if i == len(timesteps) - 1 or ( if i == len(timesteps) - 1 or (
@ -563,6 +562,7 @@ class StableDiffusionXLPanoramaPipelineMixin(StableDiffusionXLImg2ImgPipelineMix
if output_type == "latent": if output_type == "latent":
image = latents image = latents
else: else:
latents = np.clip(latents, -4, +4)
latents = latents / self.vae_decoder.config.get("scaling_factor", 0.18215) latents = latents / self.vae_decoder.config.get("scaling_factor", 0.18215)
# it seems likes there is a strange result for using half-precision vae decoder if batchsize>1 # it seems likes there is a strange result for using half-precision vae decoder if batchsize>1
image = np.concatenate( image = np.concatenate(

View File

@ -4,15 +4,15 @@ from typing import Any, List, Optional
from PIL import Image, ImageOps from PIL import Image, ImageOps
from onnx_web.chain.highres import stage_highres
from ..chain import ( from ..chain import (
BlendDenoiseStage,
BlendImg2ImgStage, BlendImg2ImgStage,
BlendMaskStage, BlendMaskStage,
ChainPipeline, ChainPipeline,
SourceTxt2ImgStage, SourceTxt2ImgStage,
UpscaleOutpaintStage, UpscaleOutpaintStage,
) )
from ..chain.highres import stage_highres
from ..chain.upscale import split_upscale, stage_upscale_correction from ..chain.upscale import split_upscale, stage_upscale_correction
from ..image import expand_image from ..image import expand_image
from ..output import save_image from ..output import save_image
@ -68,6 +68,11 @@ def run_txt2img_pipeline(
highres_size = params.unet_tile highres_size = params.unet_tile
stage = StageParams(tile_size=highres_size) stage = StageParams(tile_size=highres_size)
chain.stage(
BlendDenoiseStage(),
stage,
)
first_upscale, after_upscale = split_upscale(upscale) first_upscale, after_upscale = split_upscale(upscale)
if first_upscale: if first_upscale:
stage_upscale_correction( stage_upscale_correction(