feat(api): add denoise stage, use before highres
This commit is contained in:
parent
4460625309
commit
95e2d6d710
|
@ -1,4 +1,5 @@
|
||||||
from .base import ChainPipeline, PipelineStage, StageParams
|
from .base import ChainPipeline, PipelineStage, StageParams
|
||||||
|
from .blend_denoise import BlendDenoiseStage
|
||||||
from .blend_img2img import BlendImg2ImgStage
|
from .blend_img2img import BlendImg2ImgStage
|
||||||
from .blend_grid import BlendGridStage
|
from .blend_grid import BlendGridStage
|
||||||
from .blend_linear import BlendLinearStage
|
from .blend_linear import BlendLinearStage
|
||||||
|
@ -22,6 +23,7 @@ from .upscale_stable_diffusion import UpscaleStableDiffusionStage
|
||||||
from .upscale_swinir import UpscaleSwinIRStage
|
from .upscale_swinir import UpscaleSwinIRStage
|
||||||
|
|
||||||
CHAIN_STAGES = {
|
CHAIN_STAGES = {
|
||||||
|
"blend-denoise": BlendDenoiseStage,
|
||||||
"blend-img2img": BlendImg2ImgStage,
|
"blend-img2img": BlendImg2ImgStage,
|
||||||
"blend-inpaint": UpscaleOutpaintStage,
|
"blend-inpaint": UpscaleOutpaintStage,
|
||||||
"blend-grid": BlendGridStage,
|
"blend-grid": BlendGridStage,
|
||||||
|
|
|
@ -232,7 +232,10 @@ class ChainPipeline:
|
||||||
|
|
||||||
stage_sources = stage_outputs
|
stage_sources = stage_outputs
|
||||||
else:
|
else:
|
||||||
logger.debug("image does not contain sources and is within tile size of %s, running stage", tile)
|
logger.debug(
|
||||||
|
"image does not contain sources and is within tile size of %s, running stage",
|
||||||
|
tile,
|
||||||
|
)
|
||||||
for i in range(worker.retries):
|
for i in range(worker.retries):
|
||||||
try:
|
try:
|
||||||
stage_outputs = stage_pipe.run(
|
stage_outputs = stage_pipe.run(
|
||||||
|
|
|
@ -0,0 +1,39 @@
|
||||||
|
from logging import getLogger
|
||||||
|
from typing import List, Optional
|
||||||
|
|
||||||
|
import cv2
|
||||||
|
import numpy as np
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
from ..params import ImageParams, SizeChart, StageParams
|
||||||
|
from ..server import ServerContext
|
||||||
|
from ..worker import ProgressCallback, WorkerContext
|
||||||
|
from .stage import BaseStage
|
||||||
|
|
||||||
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class BlendDenoiseStage(BaseStage):
    """Blend stage that removes noise from each source image using
    OpenCV's non-local means color denoising filter.
    """

    # this stage operates on whole images, so allow the maximum tile size
    max_tile = SizeChart.max

    def run(
        self,
        _worker: WorkerContext,
        _server: ServerContext,
        _stage: StageParams,
        _params: ImageParams,
        sources: List[Image.Image],
        *,
        stage_source: Optional[Image.Image] = None,
        callback: Optional[ProgressCallback] = None,
        **kwargs,
    ) -> List[Image.Image]:
        """Denoise every image in `sources` and return the filtered copies.

        The worker, server, stage, and params arguments are accepted for
        interface compatibility but unused (hence the underscore prefixes).
        """
        logger.info("denoising source images")

        def denoise(image: Image.Image) -> Image.Image:
            # OpenCV expects BGR channel order, so convert there and back
            bgr = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
            bgr = cv2.fastNlMeansDenoisingColored(bgr)
            return Image.fromarray(cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB))

        return [denoise(image) for image in sources]
|
|
@ -702,12 +702,12 @@ class OnnxStableDiffusionPanoramaPipeline(DiffusionPipeline):
|
||||||
# take the MultiDiffusion step. Eq. 5 in MultiDiffusion paper: https://arxiv.org/abs/2302.08113
|
# take the MultiDiffusion step. Eq. 5 in MultiDiffusion paper: https://arxiv.org/abs/2302.08113
|
||||||
latents = np.where(count > 0, value / count, value)
|
latents = np.where(count > 0, value / count, value)
|
||||||
latents = repair_nan(latents)
|
latents = repair_nan(latents)
|
||||||
latents = np.clip(latents, -4, +4)
|
|
||||||
|
|
||||||
# call the callback, if provided
|
# call the callback, if provided
|
||||||
if callback is not None and i % callback_steps == 0:
|
if callback is not None and i % callback_steps == 0:
|
||||||
callback(i, t, latents)
|
callback(i, t, latents)
|
||||||
|
|
||||||
|
latents = np.clip(latents, -4, +4)
|
||||||
latents = 1 / 0.18215 * latents
|
latents = 1 / 0.18215 * latents
|
||||||
# image = self.vae_decoder(latent_sample=latents)[0]
|
# image = self.vae_decoder(latent_sample=latents)[0]
|
||||||
# it seems likes there is a strange result for using half-precision vae decoder if batchsize>1
|
# it seems likes there is a strange result for using half-precision vae decoder if batchsize>1
|
||||||
|
|
|
@ -551,7 +551,6 @@ class StableDiffusionXLPanoramaPipelineMixin(StableDiffusionXLImg2ImgPipelineMix
|
||||||
# take the MultiDiffusion step. Eq. 5 in MultiDiffusion paper: https://arxiv.org/abs/2302.08113
|
# take the MultiDiffusion step. Eq. 5 in MultiDiffusion paper: https://arxiv.org/abs/2302.08113
|
||||||
latents = np.where(count > 0, value / count, value)
|
latents = np.where(count > 0, value / count, value)
|
||||||
latents = repair_nan(latents)
|
latents = repair_nan(latents)
|
||||||
latents = np.clip(latents, -4, +4)
|
|
||||||
|
|
||||||
# call the callback, if provided
|
# call the callback, if provided
|
||||||
if i == len(timesteps) - 1 or (
|
if i == len(timesteps) - 1 or (
|
||||||
|
@ -563,6 +562,7 @@ class StableDiffusionXLPanoramaPipelineMixin(StableDiffusionXLImg2ImgPipelineMix
|
||||||
if output_type == "latent":
|
if output_type == "latent":
|
||||||
image = latents
|
image = latents
|
||||||
else:
|
else:
|
||||||
|
latents = np.clip(latents, -4, +4)
|
||||||
latents = latents / self.vae_decoder.config.get("scaling_factor", 0.18215)
|
latents = latents / self.vae_decoder.config.get("scaling_factor", 0.18215)
|
||||||
# it seems likes there is a strange result for using half-precision vae decoder if batchsize>1
|
# it seems likes there is a strange result for using half-precision vae decoder if batchsize>1
|
||||||
image = np.concatenate(
|
image = np.concatenate(
|
||||||
|
|
|
@ -4,15 +4,15 @@ from typing import Any, List, Optional
|
||||||
|
|
||||||
from PIL import Image, ImageOps
|
from PIL import Image, ImageOps
|
||||||
|
|
||||||
from onnx_web.chain.highres import stage_highres
|
|
||||||
|
|
||||||
from ..chain import (
|
from ..chain import (
|
||||||
|
BlendDenoiseStage,
|
||||||
BlendImg2ImgStage,
|
BlendImg2ImgStage,
|
||||||
BlendMaskStage,
|
BlendMaskStage,
|
||||||
ChainPipeline,
|
ChainPipeline,
|
||||||
SourceTxt2ImgStage,
|
SourceTxt2ImgStage,
|
||||||
UpscaleOutpaintStage,
|
UpscaleOutpaintStage,
|
||||||
)
|
)
|
||||||
|
from ..chain.highres import stage_highres
|
||||||
from ..chain.upscale import split_upscale, stage_upscale_correction
|
from ..chain.upscale import split_upscale, stage_upscale_correction
|
||||||
from ..image import expand_image
|
from ..image import expand_image
|
||||||
from ..output import save_image
|
from ..output import save_image
|
||||||
|
@ -68,6 +68,11 @@ def run_txt2img_pipeline(
|
||||||
highres_size = params.unet_tile
|
highres_size = params.unet_tile
|
||||||
|
|
||||||
stage = StageParams(tile_size=highres_size)
|
stage = StageParams(tile_size=highres_size)
|
||||||
|
chain.stage(
|
||||||
|
BlendDenoiseStage(),
|
||||||
|
stage,
|
||||||
|
)
|
||||||
|
|
||||||
first_upscale, after_upscale = split_upscale(upscale)
|
first_upscale, after_upscale = split_upscale(upscale)
|
||||||
if first_upscale:
|
if first_upscale:
|
||||||
stage_upscale_correction(
|
stage_upscale_correction(
|
||||||
|
|
Loading…
Reference in New Issue