1
0
Fork 0

feat(api): backend support for multiple GPUs in diffusion pipelines

This commit is contained in:
Sean Sube 2023-01-21 15:17:33 -06:00
parent 88c5113e37
commit a868c8cf6b
1 changed files with 12 additions and 3 deletions

View File

@ -6,7 +6,7 @@ from diffusers import (
OnnxStableDiffusionInpaintPipeline,
)
from PIL import Image, ImageChops
from typing import Any
from typing import Any, Union
import gc
import numpy as np
@ -44,7 +44,7 @@ def get_latents_from_seed(seed: int, size: Size) -> np.ndarray:
return image_latents
def load_pipeline(pipeline: DiffusionPipeline, model: str, provider: str, scheduler: Any):
def load_pipeline(pipeline: DiffusionPipeline, model: str, provider: str, scheduler: Any, device: Union[str, None]):
global last_pipeline_instance
global last_pipeline_scheduler
global last_pipeline_options
@ -61,14 +61,23 @@ def load_pipeline(pipeline: DiffusionPipeline, model: str, provider: str, schedu
safety_checker=None,
scheduler=scheduler.from_pretrained(model, subfolder='scheduler')
)
if device is not None:
pipe = pipe.to(device)
last_pipeline_instance = pipe
last_pipeline_options = options
last_pipeline_scheduler = scheduler
if last_pipeline_scheduler != scheduler:
print('changing pipeline scheduler')
pipe.scheduler = scheduler.from_pretrained(
scheduler = scheduler.from_pretrained(
model, subfolder='scheduler')
if device is not None:
scheduler = scheduler.to(device)
pipe.scheduler = scheduler
last_pipeline_scheduler = scheduler
print('running garbage collection during pipeline change')