
fix(api): run torch gc alongside python (#156)

Sean Sube, 2023-02-16 18:11:35 -06:00
parent 1ca0c01529
commit 0ed4af18ad
Signed by: ssube (GPG Key ID: 3EED7B957D362AF1)
9 changed files with 32 additions and 19 deletions
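
The core of the change is in the utils module at the bottom of this diff: run_gc() now takes the list of devices a job touched and runs Torch's CUDA cleanup alongside Python's collector. As a standalone sketch of the pattern (collect_all and device_strs are illustrative names, not from this commit; any Torch device strings such as "cuda:0" work):

    import gc

    import torch

    def collect_all(device_strs):
        # free Python-level objects first, so cached tensors lose their last references
        gc.collect()

        if torch.cuda.is_available():
            for name in device_strs:
                # make each device current, then return cached allocator blocks to
                # the driver and reclaim memory held for inter-process (IPC) tensors
                with torch.cuda.device(name):
                    torch.cuda.empty_cache()
                    torch.cuda.ipc_collect()

The ordering matters: empty_cache() only releases unoccupied cached memory, so gc.collect() has to run first to drop the Python references that keep those blocks occupied.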

View File

@@ -26,5 +26,5 @@ def correct_codeformer(
     from codeformer import CodeFormer

     device = job.get_device()
-    pipe = CodeFormer(upscale=upscale.face_outscale).to(device.torch_device())
+    pipe = CodeFormer(upscale=upscale.face_outscale).to(device.torch_str())
     return pipe(stage_source or source)

View File

@@ -15,7 +15,7 @@ def load_gfpgan(
     server: ServerContext,
     _stage: StageParams,
     upscale: UpscaleParams,
-    _device: DeviceParams,
+    device: DeviceParams,
 ):
     # must be within the load function for patch to take effect
     from gfpgan import GFPGANer
@@ -40,7 +40,7 @@ def load_gfpgan(
     )

     server.cache.set("gfpgan", cache_key, gfpgan)
-    run_gc()
+    run_gc([device])

     return gfpgan

View File

@@ -89,7 +89,7 @@ def load_resrgan(
     )

     server.cache.set("resrgan", cache_key, upsampler)
-    run_gc()
+    run_gc([device])

     return upsampler

View File

@@ -50,7 +50,7 @@ def load_stable_diffusion(
     )

     server.cache.set("diffusion", cache_key, pipe)
-    run_gc()
+    run_gc([device])

     return pipe

View File

@@ -115,16 +115,16 @@ def load_pipeline(
             )

             if device is not None and hasattr(scheduler, "to"):
-                scheduler = scheduler.to(device.torch_device())
+                scheduler = scheduler.to(device.torch_str())

             pipe.scheduler = scheduler
             server.cache.set("scheduler", scheduler_key, scheduler)
-            run_gc()
+            run_gc([device])
     else:
         logger.debug("unloading previous diffusion pipeline")
         server.cache.drop("diffusion", pipe_key)
-        run_gc()
+        run_gc([device])

         if lpw:
             custom_pipeline = "./onnx_web/diffusion/lpw_stable_diffusion_onnx.py"
@@ -149,7 +149,7 @@ def load_pipeline(
         )

         if device is not None and hasattr(pipe, "to"):
-            pipe = pipe.to(device.torch_device())
+            pipe = pipe.to(device.torch_str())

         server.cache.set("diffusion", pipe_key, pipe)
         server.cache.set("scheduler", scheduler_key, scheduler)

View File

@@ -81,9 +81,11 @@ def run_txt2img_pipeline(
     dest = save_image(server, output, image)
     save_params(server, output, params, size, upscale=upscale)

+    del pipe
     del image
     del result
-    run_gc()
+
+    run_gc([job.get_device()])

     logger.info("finished txt2img job: %s", dest)
@@ -147,9 +149,11 @@ def run_img2img_pipeline(
     size = Size(*source_image.size)
     save_params(server, output, params, size, upscale=upscale)

+    del pipe
     del image
     del result
-    run_gc()
+
+    run_gc([job.get_device()])

     logger.info("finished img2img job: %s", dest)
@@ -200,7 +204,8 @@ def run_inpaint_pipeline(
     save_params(server, output, params, size, upscale=upscale, border=border)

     del image
-    run_gc()
+
+    run_gc([job.get_device()])

     logger.info("finished inpaint job: %s", dest)
@@ -226,7 +231,8 @@ def run_upscale_pipeline(
     save_params(server, output, params, size, upscale=upscale)

     del image
-    run_gc()
+
+    run_gc([job.get_device()])

     logger.info("finished upscale job: %s", dest)
@@ -263,6 +269,7 @@ def run_blend_pipeline(
     save_params(server, output, params, size, upscale=upscale)

     del image
-    run_gc()
+
+    run_gc([job.get_device()])

     logger.info("finished blend job: %s", dest)

View File

@@ -90,7 +90,7 @@ class DeviceParams:
     def sess_options(self) -> SessionOptions:
         return SessionOptions()

-    def torch_device(self) -> str:
+    def torch_str(self) -> str:
         if self.device.startswith("cuda"):
             return self.device
         else:
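
The else branch of the renamed helper is cut off in this view. CUDA device strings pass through unchanged; for anything else the method presumably maps to a CPU device string, along these lines (the "cpu" return is an assumption, not shown in the diff):

    def torch_str(self) -> str:
        if self.device.startswith("cuda"):
            return self.device
        else:
            # assumption: non-CUDA devices fall back to the CPU backend
            return "cpu"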

View File

@@ -248,7 +248,7 @@ class DevicePoolExecutor:
                 key,
                 format_exception(type(err), err, err.__traceback__),
             )
-            run_gc()
+            run_gc([self.devices[device]])

         future.add_done_callback(job_done)

View File

@@ -5,7 +5,7 @@ from typing import Any, Dict, List, Optional, Union

 import torch

-from .params import SizeChart
+from .params import DeviceParams, SizeChart
 from .server.model_cache import ModelCache

 logger = getLogger(__name__)
@@ -134,7 +134,13 @@ def get_size(val: Union[int, str, None]) -> SizeChart:
     raise Exception("invalid size")


-def run_gc():
+def run_gc(devices: List[DeviceParams] = []):
     logger.debug("running garbage collection")
     gc.collect()
-    torch.cuda.empty_cache()
+
+    if torch.cuda.is_available():
+        for device in devices:
+            logger.debug("running Torch garbage collection for device: %s", device)
+            with torch.cuda.device(device.torch_str()):
+                torch.cuda.empty_cache()
+                torch.cuda.ipc_collect()
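
Taken together, every call site now passes whichever DeviceParams it has in scope, and the old zero-argument form stays valid because devices defaults to an empty list:

    run_gc([device])                # model loaders (gfpgan, resrgan, stable diffusion)
    run_gc([job.get_device()])      # pipeline runners
    run_gc([self.devices[device]])  # the device pool's job-error callback
    run_gc()                        # no devices: Python gc only

A mutable default like devices=[] is usually a Python pitfall, but it is harmless here because run_gc only iterates the list and never mutates it.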