feat(api): collect progress from chain pipelines (#90)
This commit is contained in:
parent
27a3fa8f51
commit
d9fc908592
|
@ -5,7 +5,7 @@ from typing import Any, List, Optional, Protocol, Tuple
|
||||||
|
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
from ..device_pool import JobContext
|
from ..device_pool import JobContext, ProgressCallback
|
||||||
from ..output import save_image
|
from ..output import save_image
|
||||||
from ..params import ImageParams, StageParams
|
from ..params import ImageParams, StageParams
|
||||||
from ..utils import ServerContext, is_debug
|
from ..utils import ServerContext, is_debug
|
||||||
|
@ -30,6 +30,24 @@ class StageCallback(Protocol):
|
||||||
PipelineStage = Tuple[StageCallback, StageParams, Optional[dict]]
|
PipelineStage = Tuple[StageCallback, StageParams, Optional[dict]]
|
||||||
|
|
||||||
|
|
||||||
|
class ChainProgress:
|
||||||
|
def __init__(self, parent: ProgressCallback, start=0) -> None:
|
||||||
|
self.parent = parent
|
||||||
|
self.step = start
|
||||||
|
self.total = 0
|
||||||
|
|
||||||
|
def __call__(self, step: int, timestep: int, latents: Any) -> None:
|
||||||
|
if step < self.step:
|
||||||
|
# accumulate on resets
|
||||||
|
self.total += self.step
|
||||||
|
|
||||||
|
self.step = step
|
||||||
|
self.parent(self.get_total(), timestep, latents)
|
||||||
|
|
||||||
|
def get_total(self) -> int:
|
||||||
|
return self.step + self.total
|
||||||
|
|
||||||
|
|
||||||
class ChainPipeline:
|
class ChainPipeline:
|
||||||
"""
|
"""
|
||||||
Run many stages in series, passing the image results from each to the next, and processing
|
Run many stages in series, passing the image results from each to the next, and processing
|
||||||
|
@ -57,11 +75,15 @@ class ChainPipeline:
|
||||||
server: ServerContext,
|
server: ServerContext,
|
||||||
params: ImageParams,
|
params: ImageParams,
|
||||||
source: Image.Image,
|
source: Image.Image,
|
||||||
|
callback: ProgressCallback = None,
|
||||||
**pipeline_kwargs
|
**pipeline_kwargs
|
||||||
) -> Image.Image:
|
) -> Image.Image:
|
||||||
"""
|
"""
|
||||||
TODO: handle List[Image] outputs
|
TODO: handle List[Image] inputs and outputs
|
||||||
"""
|
"""
|
||||||
|
if callback is not None:
|
||||||
|
callback = ChainProgress(callback, start=callback.step)
|
||||||
|
|
||||||
start = monotonic()
|
start = monotonic()
|
||||||
logger.info(
|
logger.info(
|
||||||
"running pipeline on source image with dimensions %sx%s",
|
"running pipeline on source image with dimensions %sx%s",
|
||||||
|
|
|
@ -6,7 +6,7 @@ import torch
|
||||||
from diffusers import OnnxStableDiffusionImg2ImgPipeline
|
from diffusers import OnnxStableDiffusionImg2ImgPipeline
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
from ..device_pool import JobContext
|
from ..device_pool import JobContext, ProgressCallback
|
||||||
from ..diffusion.load import load_pipeline
|
from ..diffusion.load import load_pipeline
|
||||||
from ..params import ImageParams, StageParams
|
from ..params import ImageParams, StageParams
|
||||||
from ..utils import ServerContext
|
from ..utils import ServerContext
|
||||||
|
@ -23,6 +23,7 @@ def blend_img2img(
|
||||||
*,
|
*,
|
||||||
strength: float,
|
strength: float,
|
||||||
prompt: Optional[str] = None,
|
prompt: Optional[str] = None,
|
||||||
|
callback: ProgressCallback,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> Image.Image:
|
) -> Image.Image:
|
||||||
prompt = prompt or params.prompt
|
prompt = prompt or params.prompt
|
||||||
|
@ -46,6 +47,7 @@ def blend_img2img(
|
||||||
negative_prompt=params.negative_prompt,
|
negative_prompt=params.negative_prompt,
|
||||||
num_inference_steps=params.steps,
|
num_inference_steps=params.steps,
|
||||||
strength=strength,
|
strength=strength,
|
||||||
|
callback=callback,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
rng = np.random.RandomState(params.seed)
|
rng = np.random.RandomState(params.seed)
|
||||||
|
@ -57,6 +59,7 @@ def blend_img2img(
|
||||||
negative_prompt=params.negative_prompt,
|
negative_prompt=params.negative_prompt,
|
||||||
num_inference_steps=params.steps,
|
num_inference_steps=params.steps,
|
||||||
strength=strength,
|
strength=strength,
|
||||||
|
callback=callback,
|
||||||
)
|
)
|
||||||
|
|
||||||
output = result.images[0]
|
output = result.images[0]
|
||||||
|
|
|
@ -6,7 +6,7 @@ import torch
|
||||||
from diffusers import OnnxStableDiffusionInpaintPipeline
|
from diffusers import OnnxStableDiffusionInpaintPipeline
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
from ..device_pool import JobContext
|
from ..device_pool import JobContext, ProgressCallback
|
||||||
from ..diffusion.load import get_latents_from_seed, load_pipeline
|
from ..diffusion.load import get_latents_from_seed, load_pipeline
|
||||||
from ..image import expand_image, mask_filter_none, noise_source_histogram
|
from ..image import expand_image, mask_filter_none, noise_source_histogram
|
||||||
from ..output import save_image
|
from ..output import save_image
|
||||||
|
@ -29,6 +29,7 @@ def blend_inpaint(
|
||||||
fill_color: str = "white",
|
fill_color: str = "white",
|
||||||
mask_filter: Callable = mask_filter_none,
|
mask_filter: Callable = mask_filter_none,
|
||||||
noise_source: Callable = noise_source_histogram,
|
noise_source: Callable = noise_source_histogram,
|
||||||
|
callback: ProgressCallback,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> Image.Image:
|
) -> Image.Image:
|
||||||
logger.info("upscaling image by expanding borders", expand)
|
logger.info("upscaling image by expanding borders", expand)
|
||||||
|
@ -83,6 +84,7 @@ def blend_inpaint(
|
||||||
negative_prompt=params.negative_prompt,
|
negative_prompt=params.negative_prompt,
|
||||||
num_inference_steps=params.steps,
|
num_inference_steps=params.steps,
|
||||||
width=size.width,
|
width=size.width,
|
||||||
|
callback=callback,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
rng = np.random.RandomState(params.seed)
|
rng = np.random.RandomState(params.seed)
|
||||||
|
@ -97,6 +99,7 @@ def blend_inpaint(
|
||||||
negative_prompt=params.negative_prompt,
|
negative_prompt=params.negative_prompt,
|
||||||
num_inference_steps=params.steps,
|
num_inference_steps=params.steps,
|
||||||
width=size.width,
|
width=size.width,
|
||||||
|
callback=callback,
|
||||||
)
|
)
|
||||||
|
|
||||||
return result.images[0]
|
return result.images[0]
|
||||||
|
|
|
@ -5,7 +5,7 @@ import torch
|
||||||
from diffusers import OnnxStableDiffusionPipeline
|
from diffusers import OnnxStableDiffusionPipeline
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
from ..device_pool import JobContext
|
from ..device_pool import JobContext, ProgressCallback
|
||||||
from ..diffusion.load import get_latents_from_seed, load_pipeline
|
from ..diffusion.load import get_latents_from_seed, load_pipeline
|
||||||
from ..params import ImageParams, Size, StageParams
|
from ..params import ImageParams, Size, StageParams
|
||||||
from ..utils import ServerContext
|
from ..utils import ServerContext
|
||||||
|
@ -22,6 +22,7 @@ def source_txt2img(
|
||||||
*,
|
*,
|
||||||
size: Size,
|
size: Size,
|
||||||
prompt: str = None,
|
prompt: str = None,
|
||||||
|
callback: ProgressCallback = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> Image.Image:
|
) -> Image.Image:
|
||||||
prompt = prompt or params.prompt
|
prompt = prompt or params.prompt
|
||||||
|
@ -53,6 +54,7 @@ def source_txt2img(
|
||||||
latents=latents,
|
latents=latents,
|
||||||
negative_prompt=params.negative_prompt,
|
negative_prompt=params.negative_prompt,
|
||||||
num_inference_steps=params.steps,
|
num_inference_steps=params.steps,
|
||||||
|
callback=callback,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
rng = np.random.RandomState(params.seed)
|
rng = np.random.RandomState(params.seed)
|
||||||
|
@ -65,6 +67,7 @@ def source_txt2img(
|
||||||
latents=latents,
|
latents=latents,
|
||||||
negative_prompt=params.negative_prompt,
|
negative_prompt=params.negative_prompt,
|
||||||
num_inference_steps=params.steps,
|
num_inference_steps=params.steps,
|
||||||
|
callback=callback,
|
||||||
)
|
)
|
||||||
|
|
||||||
output = result.images[0]
|
output = result.images[0]
|
||||||
|
|
|
@ -6,7 +6,7 @@ import torch
|
||||||
from diffusers import OnnxStableDiffusionInpaintPipeline
|
from diffusers import OnnxStableDiffusionInpaintPipeline
|
||||||
from PIL import Image, ImageDraw
|
from PIL import Image, ImageDraw
|
||||||
|
|
||||||
from ..device_pool import JobContext
|
from ..device_pool import JobContext, ProgressCallback
|
||||||
from ..diffusion.load import get_latents_from_seed, get_tile_latents, load_pipeline
|
from ..diffusion.load import get_latents_from_seed, get_tile_latents, load_pipeline
|
||||||
from ..image import expand_image, mask_filter_none, noise_source_histogram
|
from ..image import expand_image, mask_filter_none, noise_source_histogram
|
||||||
from ..output import save_image
|
from ..output import save_image
|
||||||
|
@ -30,6 +30,7 @@ def upscale_outpaint(
|
||||||
fill_color: str = "white",
|
fill_color: str = "white",
|
||||||
mask_filter: Callable = mask_filter_none,
|
mask_filter: Callable = mask_filter_none,
|
||||||
noise_source: Callable = noise_source_histogram,
|
noise_source: Callable = noise_source_histogram,
|
||||||
|
callback: ProgressCallback,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> Image.Image:
|
) -> Image.Image:
|
||||||
prompt = prompt or params.prompt
|
prompt = prompt or params.prompt
|
||||||
|
@ -92,6 +93,7 @@ def upscale_outpaint(
|
||||||
negative_prompt=params.negative_prompt,
|
negative_prompt=params.negative_prompt,
|
||||||
num_inference_steps=params.steps,
|
num_inference_steps=params.steps,
|
||||||
width=size.width,
|
width=size.width,
|
||||||
|
callback=callback,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
rng = np.random.RandomState(params.seed)
|
rng = np.random.RandomState(params.seed)
|
||||||
|
@ -106,6 +108,7 @@ def upscale_outpaint(
|
||||||
negative_prompt=params.negative_prompt,
|
negative_prompt=params.negative_prompt,
|
||||||
num_inference_steps=params.steps,
|
num_inference_steps=params.steps,
|
||||||
width=size.width,
|
width=size.width,
|
||||||
|
callback=callback,
|
||||||
)
|
)
|
||||||
|
|
||||||
# once part of the image has been drawn, keep it
|
# once part of the image has been drawn, keep it
|
||||||
|
|
|
@ -5,7 +5,7 @@ import torch
|
||||||
from diffusers import StableDiffusionUpscalePipeline
|
from diffusers import StableDiffusionUpscalePipeline
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
from ..device_pool import JobContext
|
from ..device_pool import JobContext, ProgressCallback
|
||||||
from ..diffusion.pipeline_onnx_stable_diffusion_upscale import (
|
from ..diffusion.pipeline_onnx_stable_diffusion_upscale import (
|
||||||
OnnxStableDiffusionUpscalePipeline,
|
OnnxStableDiffusionUpscalePipeline,
|
||||||
)
|
)
|
||||||
|
@ -67,6 +67,7 @@ def upscale_stable_diffusion(
|
||||||
*,
|
*,
|
||||||
upscale: UpscaleParams,
|
upscale: UpscaleParams,
|
||||||
prompt: str = None,
|
prompt: str = None,
|
||||||
|
callback: ProgressCallback,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> Image.Image:
|
) -> Image.Image:
|
||||||
prompt = prompt or params.prompt
|
prompt = prompt or params.prompt
|
||||||
|
@ -80,4 +81,5 @@ def upscale_stable_diffusion(
|
||||||
source,
|
source,
|
||||||
generator=generator,
|
generator=generator,
|
||||||
num_inference_steps=params.steps,
|
num_inference_steps=params.steps,
|
||||||
|
callback=callback,
|
||||||
).images[0]
|
).images[0]
|
||||||
|
|
|
@ -10,6 +10,8 @@ from .utils import run_gc
|
||||||
|
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
ProgressCallback = Callable[[int, int, Any], None]
|
||||||
|
|
||||||
|
|
||||||
class JobContext:
|
class JobContext:
|
||||||
cancel: Value = None
|
cancel: Value = None
|
||||||
|
@ -51,8 +53,9 @@ class JobContext:
|
||||||
def get_progress(self) -> int:
|
def get_progress(self) -> int:
|
||||||
return self.progress.value
|
return self.progress.value
|
||||||
|
|
||||||
def get_progress_callback(self) -> Callable[..., None]:
|
def get_progress_callback(self) -> ProgressCallback:
|
||||||
def on_progress(step: int, timestep: int, latents: Any):
|
def on_progress(step: int, timestep: int, latents: Any):
|
||||||
|
on_progress.step = step
|
||||||
if self.is_cancelled():
|
if self.is_cancelled():
|
||||||
raise Exception("job has been cancelled")
|
raise Exception("job has been cancelled")
|
||||||
else:
|
else:
|
||||||
|
|
Loading…
Reference in New Issue