diff --git a/api/onnx_web/chain/correct_codeformer.py b/api/onnx_web/chain/correct_codeformer.py index 3743ba69..0989a3e5 100644 --- a/api/onnx_web/chain/correct_codeformer.py +++ b/api/onnx_web/chain/correct_codeformer.py @@ -26,5 +26,5 @@ def correct_codeformer( from codeformer import CodeFormer device = job.get_device() - pipe = CodeFormer(upscale=upscale.face_outscale).to(device.torch_device()) + pipe = CodeFormer(upscale=upscale.face_outscale).to(device.torch_str()) return pipe(stage_source or source) diff --git a/api/onnx_web/chain/correct_gfpgan.py b/api/onnx_web/chain/correct_gfpgan.py index b27d4299..99afce03 100644 --- a/api/onnx_web/chain/correct_gfpgan.py +++ b/api/onnx_web/chain/correct_gfpgan.py @@ -15,7 +15,7 @@ def load_gfpgan( server: ServerContext, _stage: StageParams, upscale: UpscaleParams, - _device: DeviceParams, + device: DeviceParams, ): # must be within the load function for patch to take effect from gfpgan import GFPGANer @@ -40,7 +40,7 @@ def load_gfpgan( ) server.cache.set("gfpgan", cache_key, gfpgan) - run_gc() + run_gc([device]) return gfpgan diff --git a/api/onnx_web/chain/upscale_resrgan.py b/api/onnx_web/chain/upscale_resrgan.py index deff5894..484314ba 100644 --- a/api/onnx_web/chain/upscale_resrgan.py +++ b/api/onnx_web/chain/upscale_resrgan.py @@ -89,7 +89,7 @@ def load_resrgan( ) server.cache.set("resrgan", cache_key, upsampler) - run_gc() + run_gc([device]) return upsampler diff --git a/api/onnx_web/chain/upscale_stable_diffusion.py b/api/onnx_web/chain/upscale_stable_diffusion.py index 0de51883..e6580dd0 100644 --- a/api/onnx_web/chain/upscale_stable_diffusion.py +++ b/api/onnx_web/chain/upscale_stable_diffusion.py @@ -50,7 +50,7 @@ def load_stable_diffusion( ) server.cache.set("diffusion", cache_key, pipe) - run_gc() + run_gc([device]) return pipe diff --git a/api/onnx_web/diffusion/load.py b/api/onnx_web/diffusion/load.py index 937cace2..60072621 100644 --- a/api/onnx_web/diffusion/load.py +++ 
b/api/onnx_web/diffusion/load.py @@ -115,16 +115,16 @@ def load_pipeline( ) if device is not None and hasattr(scheduler, "to"): - scheduler = scheduler.to(device.torch_device()) + scheduler = scheduler.to(device.torch_str()) pipe.scheduler = scheduler server.cache.set("scheduler", scheduler_key, scheduler) - run_gc() + run_gc([device]) else: logger.debug("unloading previous diffusion pipeline") server.cache.drop("diffusion", pipe_key) - run_gc() + run_gc([device]) if lpw: custom_pipeline = "./onnx_web/diffusion/lpw_stable_diffusion_onnx.py" @@ -149,7 +149,7 @@ def load_pipeline( ) if device is not None and hasattr(pipe, "to"): - pipe = pipe.to(device.torch_device()) + pipe = pipe.to(device.torch_str()) server.cache.set("diffusion", pipe_key, pipe) server.cache.set("scheduler", scheduler_key, scheduler) diff --git a/api/onnx_web/diffusion/run.py b/api/onnx_web/diffusion/run.py index 82b4d3f3..da58a42b 100644 --- a/api/onnx_web/diffusion/run.py +++ b/api/onnx_web/diffusion/run.py @@ -81,9 +81,11 @@ def run_txt2img_pipeline( dest = save_image(server, output, image) save_params(server, output, params, size, upscale=upscale) + del pipe del image del result - run_gc() + + run_gc([job.get_device()]) logger.info("finished txt2img job: %s", dest) @@ -147,9 +149,11 @@ def run_img2img_pipeline( size = Size(*source_image.size) save_params(server, output, params, size, upscale=upscale) + del pipe del image del result - run_gc() + + run_gc([job.get_device()]) logger.info("finished img2img job: %s", dest) @@ -200,7 +204,8 @@ def run_inpaint_pipeline( save_params(server, output, params, size, upscale=upscale, border=border) del image - run_gc() + + run_gc([job.get_device()]) logger.info("finished inpaint job: %s", dest) @@ -226,7 +231,8 @@ def run_upscale_pipeline( save_params(server, output, params, size, upscale=upscale) del image - run_gc() + + run_gc([job.get_device()]) logger.info("finished upscale job: %s", dest) @@ -263,6 +269,7 @@ def run_blend_pipeline( 
save_params(server, output, params, size, upscale=upscale) del image - run_gc() + + run_gc([job.get_device()]) logger.info("finished blend job: %s", dest) diff --git a/api/onnx_web/params.py b/api/onnx_web/params.py index b7d10903..1a645056 100644 --- a/api/onnx_web/params.py +++ b/api/onnx_web/params.py @@ -90,7 +90,7 @@ class DeviceParams: def sess_options(self) -> SessionOptions: return SessionOptions() - def torch_device(self) -> str: + def torch_str(self) -> str: if self.device.startswith("cuda"): return self.device else: diff --git a/api/onnx_web/server/device_pool.py b/api/onnx_web/server/device_pool.py index e11c920c..b935d549 100644 --- a/api/onnx_web/server/device_pool.py +++ b/api/onnx_web/server/device_pool.py @@ -248,7 +248,7 @@ class DevicePoolExecutor: key, format_exception(type(err), err, err.__traceback__), ) - run_gc() + run_gc([self.devices[device]]) future.add_done_callback(job_done) diff --git a/api/onnx_web/utils.py b/api/onnx_web/utils.py index b3c01c07..733326e5 100644 --- a/api/onnx_web/utils.py +++ b/api/onnx_web/utils.py @@ -5,7 +5,7 @@ from typing import Any, Dict, List, Optional, Union import torch -from .params import SizeChart +from .params import DeviceParams, SizeChart from .server.model_cache import ModelCache logger = getLogger(__name__) @@ -134,7 +134,13 @@ def get_size(val: Union[int, str, None]) -> SizeChart: raise Exception("invalid size") -def run_gc(): +def run_gc(devices: Optional[List[DeviceParams]] = None): logger.debug("running garbage collection") gc.collect() - torch.cuda.empty_cache() + + if torch.cuda.is_available(): + for device in devices or []: + logger.debug("running Torch garbage collection for device: %s", device) + with torch.cuda.device(device.torch_str()): + torch.cuda.empty_cache() + torch.cuda.ipc_collect()