1
0
Fork 0

feat(api): collect progress from chain pipelines (#90)

This commit is contained in:
Sean Sube 2023-02-12 12:17:36 -06:00
parent 27a3fa8f51
commit d9fc908592
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
7 changed files with 47 additions and 8 deletions

View File

@ -5,7 +5,7 @@ from typing import Any, List, Optional, Protocol, Tuple
from PIL import Image
from ..device_pool import JobContext
from ..device_pool import JobContext, ProgressCallback
from ..output import save_image
from ..params import ImageParams, StageParams
from ..utils import ServerContext, is_debug
@ -30,6 +30,24 @@ class StageCallback(Protocol):
PipelineStage = Tuple[StageCallback, StageParams, Optional[dict]]
class ChainProgress:
def __init__(self, parent: ProgressCallback, start=0) -> None:
self.parent = parent
self.step = start
self.total = 0
def __call__(self, step: int, timestep: int, latents: Any) -> None:
if step < self.step:
# accumulate on resets
self.total += self.step
self.step = step
self.parent(self.get_total(), timestep, latents)
def get_total(self) -> int:
return self.step + self.total
class ChainPipeline:
"""
Run many stages in series, passing the image results from each to the next, and processing
@ -57,11 +75,15 @@ class ChainPipeline:
server: ServerContext,
params: ImageParams,
source: Image.Image,
callback: ProgressCallback = None,
**pipeline_kwargs
) -> Image.Image:
"""
TODO: handle List[Image] outputs
TODO: handle List[Image] inputs and outputs
"""
if callback is not None:
callback = ChainProgress(callback, start=callback.step)
start = monotonic()
logger.info(
"running pipeline on source image with dimensions %sx%s",

View File

@ -6,7 +6,7 @@ import torch
from diffusers import OnnxStableDiffusionImg2ImgPipeline
from PIL import Image
from ..device_pool import JobContext
from ..device_pool import JobContext, ProgressCallback
from ..diffusion.load import load_pipeline
from ..params import ImageParams, StageParams
from ..utils import ServerContext
@ -23,6 +23,7 @@ def blend_img2img(
*,
strength: float,
prompt: Optional[str] = None,
callback: ProgressCallback,
**kwargs,
) -> Image.Image:
prompt = prompt or params.prompt
@ -46,6 +47,7 @@ def blend_img2img(
negative_prompt=params.negative_prompt,
num_inference_steps=params.steps,
strength=strength,
callback=callback,
)
else:
rng = np.random.RandomState(params.seed)
@ -57,6 +59,7 @@ def blend_img2img(
negative_prompt=params.negative_prompt,
num_inference_steps=params.steps,
strength=strength,
callback=callback,
)
output = result.images[0]

View File

@ -6,7 +6,7 @@ import torch
from diffusers import OnnxStableDiffusionInpaintPipeline
from PIL import Image
from ..device_pool import JobContext
from ..device_pool import JobContext, ProgressCallback
from ..diffusion.load import get_latents_from_seed, load_pipeline
from ..image import expand_image, mask_filter_none, noise_source_histogram
from ..output import save_image
@ -29,6 +29,7 @@ def blend_inpaint(
fill_color: str = "white",
mask_filter: Callable = mask_filter_none,
noise_source: Callable = noise_source_histogram,
callback: ProgressCallback,
**kwargs,
) -> Image.Image:
logger.info("upscaling image by expanding borders", expand)
@ -83,6 +84,7 @@ def blend_inpaint(
negative_prompt=params.negative_prompt,
num_inference_steps=params.steps,
width=size.width,
callback=callback,
)
else:
rng = np.random.RandomState(params.seed)
@ -97,6 +99,7 @@ def blend_inpaint(
negative_prompt=params.negative_prompt,
num_inference_steps=params.steps,
width=size.width,
callback=callback,
)
return result.images[0]

View File

@ -5,7 +5,7 @@ import torch
from diffusers import OnnxStableDiffusionPipeline
from PIL import Image
from ..device_pool import JobContext
from ..device_pool import JobContext, ProgressCallback
from ..diffusion.load import get_latents_from_seed, load_pipeline
from ..params import ImageParams, Size, StageParams
from ..utils import ServerContext
@ -22,6 +22,7 @@ def source_txt2img(
*,
size: Size,
prompt: str = None,
callback: ProgressCallback = None,
**kwargs,
) -> Image.Image:
prompt = prompt or params.prompt
@ -53,6 +54,7 @@ def source_txt2img(
latents=latents,
negative_prompt=params.negative_prompt,
num_inference_steps=params.steps,
callback=callback,
)
else:
rng = np.random.RandomState(params.seed)
@ -65,6 +67,7 @@ def source_txt2img(
latents=latents,
negative_prompt=params.negative_prompt,
num_inference_steps=params.steps,
callback=callback,
)
output = result.images[0]

View File

@ -6,7 +6,7 @@ import torch
from diffusers import OnnxStableDiffusionInpaintPipeline
from PIL import Image, ImageDraw
from ..device_pool import JobContext
from ..device_pool import JobContext, ProgressCallback
from ..diffusion.load import get_latents_from_seed, get_tile_latents, load_pipeline
from ..image import expand_image, mask_filter_none, noise_source_histogram
from ..output import save_image
@ -30,6 +30,7 @@ def upscale_outpaint(
fill_color: str = "white",
mask_filter: Callable = mask_filter_none,
noise_source: Callable = noise_source_histogram,
callback: ProgressCallback,
**kwargs,
) -> Image.Image:
prompt = prompt or params.prompt
@ -92,6 +93,7 @@ def upscale_outpaint(
negative_prompt=params.negative_prompt,
num_inference_steps=params.steps,
width=size.width,
callback=callback,
)
else:
rng = np.random.RandomState(params.seed)
@ -106,6 +108,7 @@ def upscale_outpaint(
negative_prompt=params.negative_prompt,
num_inference_steps=params.steps,
width=size.width,
callback=callback,
)
# once part of the image has been drawn, keep it

View File

@ -5,7 +5,7 @@ import torch
from diffusers import StableDiffusionUpscalePipeline
from PIL import Image
from ..device_pool import JobContext
from ..device_pool import JobContext, ProgressCallback
from ..diffusion.pipeline_onnx_stable_diffusion_upscale import (
OnnxStableDiffusionUpscalePipeline,
)
@ -67,6 +67,7 @@ def upscale_stable_diffusion(
*,
upscale: UpscaleParams,
prompt: str = None,
callback: ProgressCallback,
**kwargs,
) -> Image.Image:
prompt = prompt or params.prompt
@ -80,4 +81,5 @@ def upscale_stable_diffusion(
source,
generator=generator,
num_inference_steps=params.steps,
callback=callback,
).images[0]

View File

@ -10,6 +10,8 @@ from .utils import run_gc
logger = getLogger(__name__)
ProgressCallback = Callable[[int, int, Any], None]
class JobContext:
cancel: Value = None
@ -51,8 +53,9 @@ class JobContext:
def get_progress(self) -> int:
return self.progress.value
def get_progress_callback(self) -> Callable[..., None]:
def get_progress_callback(self) -> ProgressCallback:
def on_progress(step: int, timestep: int, latents: Any):
on_progress.step = step
if self.is_cancelled():
raise Exception("job has been cancelled")
else: