fix(api): run torch gc alongside python (#156)
This commit is contained in:
parent
1ca0c01529
commit
0ed4af18ad
|
@ -26,5 +26,5 @@ def correct_codeformer(
|
|||
from codeformer import CodeFormer
|
||||
|
||||
device = job.get_device()
|
||||
pipe = CodeFormer(upscale=upscale.face_outscale).to(device.torch_device())
|
||||
pipe = CodeFormer(upscale=upscale.face_outscale).to(device.torch_str())
|
||||
return pipe(stage_source or source)
|
||||
|
|
|
@ -15,7 +15,7 @@ def load_gfpgan(
|
|||
server: ServerContext,
|
||||
_stage: StageParams,
|
||||
upscale: UpscaleParams,
|
||||
_device: DeviceParams,
|
||||
device: DeviceParams,
|
||||
):
|
||||
# must be within the load function for patch to take effect
|
||||
from gfpgan import GFPGANer
|
||||
|
@ -40,7 +40,7 @@ def load_gfpgan(
|
|||
)
|
||||
|
||||
server.cache.set("gfpgan", cache_key, gfpgan)
|
||||
run_gc()
|
||||
run_gc([device])
|
||||
|
||||
return gfpgan
|
||||
|
||||
|
|
|
@ -89,7 +89,7 @@ def load_resrgan(
|
|||
)
|
||||
|
||||
server.cache.set("resrgan", cache_key, upsampler)
|
||||
run_gc()
|
||||
run_gc([device])
|
||||
|
||||
return upsampler
|
||||
|
||||
|
|
|
@ -50,7 +50,7 @@ def load_stable_diffusion(
|
|||
)
|
||||
|
||||
server.cache.set("diffusion", cache_key, pipe)
|
||||
run_gc()
|
||||
run_gc([device])
|
||||
|
||||
return pipe
|
||||
|
||||
|
|
|
@ -115,16 +115,16 @@ def load_pipeline(
|
|||
)
|
||||
|
||||
if device is not None and hasattr(scheduler, "to"):
|
||||
scheduler = scheduler.to(device.torch_device())
|
||||
scheduler = scheduler.to(device.torch_str())
|
||||
|
||||
pipe.scheduler = scheduler
|
||||
server.cache.set("scheduler", scheduler_key, scheduler)
|
||||
run_gc()
|
||||
run_gc([device])
|
||||
|
||||
else:
|
||||
logger.debug("unloading previous diffusion pipeline")
|
||||
server.cache.drop("diffusion", pipe_key)
|
||||
run_gc()
|
||||
run_gc([device])
|
||||
|
||||
if lpw:
|
||||
custom_pipeline = "./onnx_web/diffusion/lpw_stable_diffusion_onnx.py"
|
||||
|
@ -149,7 +149,7 @@ def load_pipeline(
|
|||
)
|
||||
|
||||
if device is not None and hasattr(pipe, "to"):
|
||||
pipe = pipe.to(device.torch_device())
|
||||
pipe = pipe.to(device.torch_str())
|
||||
|
||||
server.cache.set("diffusion", pipe_key, pipe)
|
||||
server.cache.set("scheduler", scheduler_key, scheduler)
|
||||
|
|
|
@ -81,9 +81,11 @@ def run_txt2img_pipeline(
|
|||
dest = save_image(server, output, image)
|
||||
save_params(server, output, params, size, upscale=upscale)
|
||||
|
||||
del pipe
|
||||
del image
|
||||
del result
|
||||
run_gc()
|
||||
|
||||
run_gc([job.get_device()])
|
||||
|
||||
logger.info("finished txt2img job: %s", dest)
|
||||
|
||||
|
@ -147,9 +149,11 @@ def run_img2img_pipeline(
|
|||
size = Size(*source_image.size)
|
||||
save_params(server, output, params, size, upscale=upscale)
|
||||
|
||||
del pipe
|
||||
del image
|
||||
del result
|
||||
run_gc()
|
||||
|
||||
run_gc([job.get_device()])
|
||||
|
||||
logger.info("finished img2img job: %s", dest)
|
||||
|
||||
|
@ -200,7 +204,8 @@ def run_inpaint_pipeline(
|
|||
save_params(server, output, params, size, upscale=upscale, border=border)
|
||||
|
||||
del image
|
||||
run_gc()
|
||||
|
||||
run_gc([job.get_device()])
|
||||
|
||||
logger.info("finished inpaint job: %s", dest)
|
||||
|
||||
|
@ -226,7 +231,8 @@ def run_upscale_pipeline(
|
|||
save_params(server, output, params, size, upscale=upscale)
|
||||
|
||||
del image
|
||||
run_gc()
|
||||
|
||||
run_gc([job.get_device()])
|
||||
|
||||
logger.info("finished upscale job: %s", dest)
|
||||
|
||||
|
@ -263,6 +269,7 @@ def run_blend_pipeline(
|
|||
save_params(server, output, params, size, upscale=upscale)
|
||||
|
||||
del image
|
||||
run_gc()
|
||||
|
||||
run_gc([job.get_device()])
|
||||
|
||||
logger.info("finished blend job: %s", dest)
|
||||
|
|
|
@ -90,7 +90,7 @@ class DeviceParams:
|
|||
def sess_options(self) -> SessionOptions:
|
||||
return SessionOptions()
|
||||
|
||||
def torch_device(self) -> str:
|
||||
def torch_str(self) -> str:
|
||||
if self.device.startswith("cuda"):
|
||||
return self.device
|
||||
else:
|
||||
|
|
|
@ -248,7 +248,7 @@ class DevicePoolExecutor:
|
|||
key,
|
||||
format_exception(type(err), err, err.__traceback__),
|
||||
)
|
||||
run_gc()
|
||||
run_gc([self.devices[device]])
|
||||
|
||||
future.add_done_callback(job_done)
|
||||
|
||||
|
|
|
@ -5,7 +5,7 @@ from typing import Any, Dict, List, Optional, Union
|
|||
|
||||
import torch
|
||||
|
||||
from .params import SizeChart
|
||||
from .params import DeviceParams, SizeChart
|
||||
from .server.model_cache import ModelCache
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
@ -134,7 +134,13 @@ def get_size(val: Union[int, str, None]) -> SizeChart:
|
|||
raise Exception("invalid size")
|
||||
|
||||
|
||||
def run_gc():
|
||||
def run_gc(devices: Optional[List[DeviceParams]] = None):
    """Run Python garbage collection, then Torch's CUDA garbage collection
    for each of the given devices (when CUDA is available).

    devices: devices whose CUDA caches should be emptied and whose IPC
        handles should be collected; defaults to no devices.
    """
    # NOTE: a mutable default (devices=[]) is shared across every call to
    # this function; use None as the sentinel and rebind per-call instead.
    if devices is None:
        devices = []

    logger.debug("running garbage collection")
    gc.collect()

    if torch.cuda.is_available():
        for device in devices:
            logger.debug("running Torch garbage collection for device: %s", device)
            # scope the cache/IPC cleanup to the device's CUDA context
            with torch.cuda.device(device.torch_str()):
                torch.cuda.empty_cache()
                torch.cuda.ipc_collect()
|
||||
|
|
Loading…
Reference in New Issue