From a868c8cf6bd249571700b0d1fbf6b2ea2cc54d09 Mon Sep 17 00:00:00 2001
From: Sean Sube
Date: Sat, 21 Jan 2023 15:17:33 -0600
Subject: [PATCH] feat(api): backend support for multiple GPUs in diffusion pipelines

---
 api/onnx_web/pipeline.py | 15 ++++++++++++---
 1 file changed, 12 insertions(+), 3 deletions(-)

diff --git a/api/onnx_web/pipeline.py b/api/onnx_web/pipeline.py
index 58cab46c..b8a23bee 100644
--- a/api/onnx_web/pipeline.py
+++ b/api/onnx_web/pipeline.py
@@ -6,7 +6,7 @@ from diffusers import (
     OnnxStableDiffusionInpaintPipeline,
 )
 from PIL import Image, ImageChops
-from typing import Any
+from typing import Any, Union
 
 import gc
 import numpy as np
@@ -44,7 +44,7 @@ def get_latents_from_seed(seed: int, size: Size) -> np.ndarray:
     return image_latents
 
 
-def load_pipeline(pipeline: DiffusionPipeline, model: str, provider: str, scheduler: Any):
+def load_pipeline(pipeline: DiffusionPipeline, model: str, provider: str, scheduler: Any, device: Union[str, None]):
     global last_pipeline_instance
     global last_pipeline_scheduler
     global last_pipeline_options
@@ -61,14 +61,23 @@ def load_pipeline(pipeline: DiffusionPipeline, model: str, provider: str, schedu
             safety_checker=None,
             scheduler=scheduler.from_pretrained(model, subfolder='scheduler')
         )
+
+        if device is not None:
+            pipe = pipe.to(device)
+
         last_pipeline_instance = pipe
         last_pipeline_options = options
         last_pipeline_scheduler = scheduler
 
     if last_pipeline_scheduler != scheduler:
         print('changing pipeline scheduler')
-        pipe.scheduler = scheduler.from_pretrained(
+        scheduler = scheduler.from_pretrained(
             model, subfolder='scheduler')
+
+        if device is not None:
+            scheduler = scheduler.to(device)
+
+        pipe.scheduler = scheduler
         last_pipeline_scheduler = scheduler
 
         print('running garbage collection during pipeline change')
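
As a usage sketch only (not part of the patch), a caller could target a specific GPU by passing a device string as the new final argument; the pipeline class, scheduler class, model path, and execution provider below are illustrative assumptions:

    # Hypothetical call site for the new `device` parameter; the names used here are assumptions.
    from diffusers import DDIMScheduler, OnnxStableDiffusionPipeline

    # Load the pipeline onto the second GPU; pass None to keep the provider's default device.
    pipe = load_pipeline(
        OnnxStableDiffusionPipeline,   # pipeline class (assumed)
        '../models/stable-diffusion',  # ONNX model path (assumed)
        'CUDAExecutionProvider',       # ONNX runtime execution provider (assumed)
        DDIMScheduler,                 # scheduler class; from_pretrained is called inside load_pipeline
        'cuda:1',                      # target device string, or None
    )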