1
0
Fork 0

feat(api): collect progress from chain pipelines (#90)

This commit is contained in:
Sean Sube 2023-02-12 12:17:36 -06:00
parent 27a3fa8f51
commit d9fc908592
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
7 changed files with 47 additions and 8 deletions

View File

@ -5,7 +5,7 @@ from typing import Any, List, Optional, Protocol, Tuple
from PIL import Image from PIL import Image
from ..device_pool import JobContext from ..device_pool import JobContext, ProgressCallback
from ..output import save_image from ..output import save_image
from ..params import ImageParams, StageParams from ..params import ImageParams, StageParams
from ..utils import ServerContext, is_debug from ..utils import ServerContext, is_debug
@ -30,6 +30,24 @@ class StageCallback(Protocol):
PipelineStage = Tuple[StageCallback, StageParams, Optional[dict]] PipelineStage = Tuple[StageCallback, StageParams, Optional[dict]]
class ChainProgress:
def __init__(self, parent: ProgressCallback, start=0) -> None:
self.parent = parent
self.step = start
self.total = 0
def __call__(self, step: int, timestep: int, latents: Any) -> None:
if step < self.step:
# accumulate on resets
self.total += self.step
self.step = step
self.parent(self.get_total(), timestep, latents)
def get_total(self) -> int:
return self.step + self.total
class ChainPipeline: class ChainPipeline:
""" """
Run many stages in series, passing the image results from each to the next, and processing Run many stages in series, passing the image results from each to the next, and processing
@ -57,11 +75,15 @@ class ChainPipeline:
server: ServerContext, server: ServerContext,
params: ImageParams, params: ImageParams,
source: Image.Image, source: Image.Image,
callback: ProgressCallback = None,
**pipeline_kwargs **pipeline_kwargs
) -> Image.Image: ) -> Image.Image:
""" """
TODO: handle List[Image] outputs TODO: handle List[Image] inputs and outputs
""" """
if callback is not None:
callback = ChainProgress(callback, start=callback.step)
start = monotonic() start = monotonic()
logger.info( logger.info(
"running pipeline on source image with dimensions %sx%s", "running pipeline on source image with dimensions %sx%s",

View File

@ -6,7 +6,7 @@ import torch
from diffusers import OnnxStableDiffusionImg2ImgPipeline from diffusers import OnnxStableDiffusionImg2ImgPipeline
from PIL import Image from PIL import Image
from ..device_pool import JobContext from ..device_pool import JobContext, ProgressCallback
from ..diffusion.load import load_pipeline from ..diffusion.load import load_pipeline
from ..params import ImageParams, StageParams from ..params import ImageParams, StageParams
from ..utils import ServerContext from ..utils import ServerContext
@ -23,6 +23,7 @@ def blend_img2img(
*, *,
strength: float, strength: float,
prompt: Optional[str] = None, prompt: Optional[str] = None,
callback: ProgressCallback,
**kwargs, **kwargs,
) -> Image.Image: ) -> Image.Image:
prompt = prompt or params.prompt prompt = prompt or params.prompt
@ -46,6 +47,7 @@ def blend_img2img(
negative_prompt=params.negative_prompt, negative_prompt=params.negative_prompt,
num_inference_steps=params.steps, num_inference_steps=params.steps,
strength=strength, strength=strength,
callback=callback,
) )
else: else:
rng = np.random.RandomState(params.seed) rng = np.random.RandomState(params.seed)
@ -57,6 +59,7 @@ def blend_img2img(
negative_prompt=params.negative_prompt, negative_prompt=params.negative_prompt,
num_inference_steps=params.steps, num_inference_steps=params.steps,
strength=strength, strength=strength,
callback=callback,
) )
output = result.images[0] output = result.images[0]

View File

@ -6,7 +6,7 @@ import torch
from diffusers import OnnxStableDiffusionInpaintPipeline from diffusers import OnnxStableDiffusionInpaintPipeline
from PIL import Image from PIL import Image
from ..device_pool import JobContext from ..device_pool import JobContext, ProgressCallback
from ..diffusion.load import get_latents_from_seed, load_pipeline from ..diffusion.load import get_latents_from_seed, load_pipeline
from ..image import expand_image, mask_filter_none, noise_source_histogram from ..image import expand_image, mask_filter_none, noise_source_histogram
from ..output import save_image from ..output import save_image
@ -29,6 +29,7 @@ def blend_inpaint(
fill_color: str = "white", fill_color: str = "white",
mask_filter: Callable = mask_filter_none, mask_filter: Callable = mask_filter_none,
noise_source: Callable = noise_source_histogram, noise_source: Callable = noise_source_histogram,
callback: ProgressCallback,
**kwargs, **kwargs,
) -> Image.Image: ) -> Image.Image:
logger.info("upscaling image by expanding borders", expand) logger.info("upscaling image by expanding borders", expand)
@ -83,6 +84,7 @@ def blend_inpaint(
negative_prompt=params.negative_prompt, negative_prompt=params.negative_prompt,
num_inference_steps=params.steps, num_inference_steps=params.steps,
width=size.width, width=size.width,
callback=callback,
) )
else: else:
rng = np.random.RandomState(params.seed) rng = np.random.RandomState(params.seed)
@ -97,6 +99,7 @@ def blend_inpaint(
negative_prompt=params.negative_prompt, negative_prompt=params.negative_prompt,
num_inference_steps=params.steps, num_inference_steps=params.steps,
width=size.width, width=size.width,
callback=callback,
) )
return result.images[0] return result.images[0]

View File

@ -5,7 +5,7 @@ import torch
from diffusers import OnnxStableDiffusionPipeline from diffusers import OnnxStableDiffusionPipeline
from PIL import Image from PIL import Image
from ..device_pool import JobContext from ..device_pool import JobContext, ProgressCallback
from ..diffusion.load import get_latents_from_seed, load_pipeline from ..diffusion.load import get_latents_from_seed, load_pipeline
from ..params import ImageParams, Size, StageParams from ..params import ImageParams, Size, StageParams
from ..utils import ServerContext from ..utils import ServerContext
@ -22,6 +22,7 @@ def source_txt2img(
*, *,
size: Size, size: Size,
prompt: str = None, prompt: str = None,
callback: ProgressCallback = None,
**kwargs, **kwargs,
) -> Image.Image: ) -> Image.Image:
prompt = prompt or params.prompt prompt = prompt or params.prompt
@ -53,6 +54,7 @@ def source_txt2img(
latents=latents, latents=latents,
negative_prompt=params.negative_prompt, negative_prompt=params.negative_prompt,
num_inference_steps=params.steps, num_inference_steps=params.steps,
callback=callback,
) )
else: else:
rng = np.random.RandomState(params.seed) rng = np.random.RandomState(params.seed)
@ -65,6 +67,7 @@ def source_txt2img(
latents=latents, latents=latents,
negative_prompt=params.negative_prompt, negative_prompt=params.negative_prompt,
num_inference_steps=params.steps, num_inference_steps=params.steps,
callback=callback,
) )
output = result.images[0] output = result.images[0]

View File

@ -6,7 +6,7 @@ import torch
from diffusers import OnnxStableDiffusionInpaintPipeline from diffusers import OnnxStableDiffusionInpaintPipeline
from PIL import Image, ImageDraw from PIL import Image, ImageDraw
from ..device_pool import JobContext from ..device_pool import JobContext, ProgressCallback
from ..diffusion.load import get_latents_from_seed, get_tile_latents, load_pipeline from ..diffusion.load import get_latents_from_seed, get_tile_latents, load_pipeline
from ..image import expand_image, mask_filter_none, noise_source_histogram from ..image import expand_image, mask_filter_none, noise_source_histogram
from ..output import save_image from ..output import save_image
@ -30,6 +30,7 @@ def upscale_outpaint(
fill_color: str = "white", fill_color: str = "white",
mask_filter: Callable = mask_filter_none, mask_filter: Callable = mask_filter_none,
noise_source: Callable = noise_source_histogram, noise_source: Callable = noise_source_histogram,
callback: ProgressCallback,
**kwargs, **kwargs,
) -> Image.Image: ) -> Image.Image:
prompt = prompt or params.prompt prompt = prompt or params.prompt
@ -92,6 +93,7 @@ def upscale_outpaint(
negative_prompt=params.negative_prompt, negative_prompt=params.negative_prompt,
num_inference_steps=params.steps, num_inference_steps=params.steps,
width=size.width, width=size.width,
callback=callback,
) )
else: else:
rng = np.random.RandomState(params.seed) rng = np.random.RandomState(params.seed)
@ -106,6 +108,7 @@ def upscale_outpaint(
negative_prompt=params.negative_prompt, negative_prompt=params.negative_prompt,
num_inference_steps=params.steps, num_inference_steps=params.steps,
width=size.width, width=size.width,
callback=callback,
) )
# once part of the image has been drawn, keep it # once part of the image has been drawn, keep it

View File

@ -5,7 +5,7 @@ import torch
from diffusers import StableDiffusionUpscalePipeline from diffusers import StableDiffusionUpscalePipeline
from PIL import Image from PIL import Image
from ..device_pool import JobContext from ..device_pool import JobContext, ProgressCallback
from ..diffusion.pipeline_onnx_stable_diffusion_upscale import ( from ..diffusion.pipeline_onnx_stable_diffusion_upscale import (
OnnxStableDiffusionUpscalePipeline, OnnxStableDiffusionUpscalePipeline,
) )
@ -67,6 +67,7 @@ def upscale_stable_diffusion(
*, *,
upscale: UpscaleParams, upscale: UpscaleParams,
prompt: str = None, prompt: str = None,
callback: ProgressCallback,
**kwargs, **kwargs,
) -> Image.Image: ) -> Image.Image:
prompt = prompt or params.prompt prompt = prompt or params.prompt
@ -80,4 +81,5 @@ def upscale_stable_diffusion(
source, source,
generator=generator, generator=generator,
num_inference_steps=params.steps, num_inference_steps=params.steps,
callback=callback,
).images[0] ).images[0]

View File

@ -10,6 +10,8 @@ from .utils import run_gc
logger = getLogger(__name__) logger = getLogger(__name__)
ProgressCallback = Callable[[int, int, Any], None]
class JobContext: class JobContext:
cancel: Value = None cancel: Value = None
@ -51,8 +53,9 @@ class JobContext:
def get_progress(self) -> int: def get_progress(self) -> int:
return self.progress.value return self.progress.value
def get_progress_callback(self) -> Callable[..., None]: def get_progress_callback(self) -> ProgressCallback:
def on_progress(step: int, timestep: int, latents: Any): def on_progress(step: int, timestep: int, latents: Any):
on_progress.step = step
if self.is_cancelled(): if self.is_cancelled():
raise Exception("job has been cancelled") raise Exception("job has been cancelled")
else: else: