1
0
Fork 0

feat(api): backend support for multiple GPUs in diffusion pipelines

This commit is contained in:
Sean Sube 2023-01-21 15:17:33 -06:00
parent 88c5113e37
commit a868c8cf6b
1 changed files with 12 additions and 3 deletions

View File

@ -6,7 +6,7 @@ from diffusers import (
OnnxStableDiffusionInpaintPipeline, OnnxStableDiffusionInpaintPipeline,
) )
from PIL import Image, ImageChops from PIL import Image, ImageChops
from typing import Any from typing import Any, Union
import gc import gc
import numpy as np import numpy as np
@ -44,7 +44,7 @@ def get_latents_from_seed(seed: int, size: Size) -> np.ndarray:
return image_latents return image_latents
def load_pipeline(pipeline: DiffusionPipeline, model: str, provider: str, scheduler: Any): def load_pipeline(pipeline: DiffusionPipeline, model: str, provider: str, scheduler: Any, device: Union[str, None]):
global last_pipeline_instance global last_pipeline_instance
global last_pipeline_scheduler global last_pipeline_scheduler
global last_pipeline_options global last_pipeline_options
@ -61,14 +61,23 @@ def load_pipeline(pipeline: DiffusionPipeline, model: str, provider: str, schedu
safety_checker=None, safety_checker=None,
scheduler=scheduler.from_pretrained(model, subfolder='scheduler') scheduler=scheduler.from_pretrained(model, subfolder='scheduler')
) )
if device is not None:
pipe = pipe.to(device)
last_pipeline_instance = pipe last_pipeline_instance = pipe
last_pipeline_options = options last_pipeline_options = options
last_pipeline_scheduler = scheduler last_pipeline_scheduler = scheduler
if last_pipeline_scheduler != scheduler: if last_pipeline_scheduler != scheduler:
print('changing pipeline scheduler') print('changing pipeline scheduler')
pipe.scheduler = scheduler.from_pretrained( scheduler = scheduler.from_pretrained(
model, subfolder='scheduler') model, subfolder='scheduler')
if device is not None:
scheduler = scheduler.to(device)
pipe.scheduler = scheduler
last_pipeline_scheduler = scheduler last_pipeline_scheduler = scheduler
print('running garbage collection during pipeline change') print('running garbage collection during pipeline change')