fix(api): run torch gc alongside python (#156)
This commit is contained in:
parent
1ca0c01529
commit
0ed4af18ad
|
@ -26,5 +26,5 @@ def correct_codeformer(
|
||||||
from codeformer import CodeFormer
|
from codeformer import CodeFormer
|
||||||
|
|
||||||
device = job.get_device()
|
device = job.get_device()
|
||||||
pipe = CodeFormer(upscale=upscale.face_outscale).to(device.torch_device())
|
pipe = CodeFormer(upscale=upscale.face_outscale).to(device.torch_str())
|
||||||
return pipe(stage_source or source)
|
return pipe(stage_source or source)
|
||||||
|
|
|
@ -15,7 +15,7 @@ def load_gfpgan(
|
||||||
server: ServerContext,
|
server: ServerContext,
|
||||||
_stage: StageParams,
|
_stage: StageParams,
|
||||||
upscale: UpscaleParams,
|
upscale: UpscaleParams,
|
||||||
_device: DeviceParams,
|
device: DeviceParams,
|
||||||
):
|
):
|
||||||
# must be within the load function for patch to take effect
|
# must be within the load function for patch to take effect
|
||||||
from gfpgan import GFPGANer
|
from gfpgan import GFPGANer
|
||||||
|
@ -40,7 +40,7 @@ def load_gfpgan(
|
||||||
)
|
)
|
||||||
|
|
||||||
server.cache.set("gfpgan", cache_key, gfpgan)
|
server.cache.set("gfpgan", cache_key, gfpgan)
|
||||||
run_gc()
|
run_gc([device])
|
||||||
|
|
||||||
return gfpgan
|
return gfpgan
|
||||||
|
|
||||||
|
|
|
@ -89,7 +89,7 @@ def load_resrgan(
|
||||||
)
|
)
|
||||||
|
|
||||||
server.cache.set("resrgan", cache_key, upsampler)
|
server.cache.set("resrgan", cache_key, upsampler)
|
||||||
run_gc()
|
run_gc([device])
|
||||||
|
|
||||||
return upsampler
|
return upsampler
|
||||||
|
|
||||||
|
|
|
@ -50,7 +50,7 @@ def load_stable_diffusion(
|
||||||
)
|
)
|
||||||
|
|
||||||
server.cache.set("diffusion", cache_key, pipe)
|
server.cache.set("diffusion", cache_key, pipe)
|
||||||
run_gc()
|
run_gc([device])
|
||||||
|
|
||||||
return pipe
|
return pipe
|
||||||
|
|
||||||
|
|
|
@ -115,16 +115,16 @@ def load_pipeline(
|
||||||
)
|
)
|
||||||
|
|
||||||
if device is not None and hasattr(scheduler, "to"):
|
if device is not None and hasattr(scheduler, "to"):
|
||||||
scheduler = scheduler.to(device.torch_device())
|
scheduler = scheduler.to(device.torch_str())
|
||||||
|
|
||||||
pipe.scheduler = scheduler
|
pipe.scheduler = scheduler
|
||||||
server.cache.set("scheduler", scheduler_key, scheduler)
|
server.cache.set("scheduler", scheduler_key, scheduler)
|
||||||
run_gc()
|
run_gc([device])
|
||||||
|
|
||||||
else:
|
else:
|
||||||
logger.debug("unloading previous diffusion pipeline")
|
logger.debug("unloading previous diffusion pipeline")
|
||||||
server.cache.drop("diffusion", pipe_key)
|
server.cache.drop("diffusion", pipe_key)
|
||||||
run_gc()
|
run_gc([device])
|
||||||
|
|
||||||
if lpw:
|
if lpw:
|
||||||
custom_pipeline = "./onnx_web/diffusion/lpw_stable_diffusion_onnx.py"
|
custom_pipeline = "./onnx_web/diffusion/lpw_stable_diffusion_onnx.py"
|
||||||
|
@ -149,7 +149,7 @@ def load_pipeline(
|
||||||
)
|
)
|
||||||
|
|
||||||
if device is not None and hasattr(pipe, "to"):
|
if device is not None and hasattr(pipe, "to"):
|
||||||
pipe = pipe.to(device.torch_device())
|
pipe = pipe.to(device.torch_str())
|
||||||
|
|
||||||
server.cache.set("diffusion", pipe_key, pipe)
|
server.cache.set("diffusion", pipe_key, pipe)
|
||||||
server.cache.set("scheduler", scheduler_key, scheduler)
|
server.cache.set("scheduler", scheduler_key, scheduler)
|
||||||
|
|
|
@ -81,9 +81,11 @@ def run_txt2img_pipeline(
|
||||||
dest = save_image(server, output, image)
|
dest = save_image(server, output, image)
|
||||||
save_params(server, output, params, size, upscale=upscale)
|
save_params(server, output, params, size, upscale=upscale)
|
||||||
|
|
||||||
|
del pipe
|
||||||
del image
|
del image
|
||||||
del result
|
del result
|
||||||
run_gc()
|
|
||||||
|
run_gc([job.get_device()])
|
||||||
|
|
||||||
logger.info("finished txt2img job: %s", dest)
|
logger.info("finished txt2img job: %s", dest)
|
||||||
|
|
||||||
|
@ -147,9 +149,11 @@ def run_img2img_pipeline(
|
||||||
size = Size(*source_image.size)
|
size = Size(*source_image.size)
|
||||||
save_params(server, output, params, size, upscale=upscale)
|
save_params(server, output, params, size, upscale=upscale)
|
||||||
|
|
||||||
|
del pipe
|
||||||
del image
|
del image
|
||||||
del result
|
del result
|
||||||
run_gc()
|
|
||||||
|
run_gc([job.get_device()])
|
||||||
|
|
||||||
logger.info("finished img2img job: %s", dest)
|
logger.info("finished img2img job: %s", dest)
|
||||||
|
|
||||||
|
@ -200,7 +204,8 @@ def run_inpaint_pipeline(
|
||||||
save_params(server, output, params, size, upscale=upscale, border=border)
|
save_params(server, output, params, size, upscale=upscale, border=border)
|
||||||
|
|
||||||
del image
|
del image
|
||||||
run_gc()
|
|
||||||
|
run_gc([job.get_device()])
|
||||||
|
|
||||||
logger.info("finished inpaint job: %s", dest)
|
logger.info("finished inpaint job: %s", dest)
|
||||||
|
|
||||||
|
@ -226,7 +231,8 @@ def run_upscale_pipeline(
|
||||||
save_params(server, output, params, size, upscale=upscale)
|
save_params(server, output, params, size, upscale=upscale)
|
||||||
|
|
||||||
del image
|
del image
|
||||||
run_gc()
|
|
||||||
|
run_gc([job.get_device()])
|
||||||
|
|
||||||
logger.info("finished upscale job: %s", dest)
|
logger.info("finished upscale job: %s", dest)
|
||||||
|
|
||||||
|
@ -263,6 +269,7 @@ def run_blend_pipeline(
|
||||||
save_params(server, output, params, size, upscale=upscale)
|
save_params(server, output, params, size, upscale=upscale)
|
||||||
|
|
||||||
del image
|
del image
|
||||||
run_gc()
|
|
||||||
|
run_gc([job.get_device()])
|
||||||
|
|
||||||
logger.info("finished blend job: %s", dest)
|
logger.info("finished blend job: %s", dest)
|
||||||
|
|
|
@ -90,7 +90,7 @@ class DeviceParams:
|
||||||
def sess_options(self) -> SessionOptions:
|
def sess_options(self) -> SessionOptions:
|
||||||
return SessionOptions()
|
return SessionOptions()
|
||||||
|
|
||||||
def torch_device(self) -> str:
|
def torch_str(self) -> str:
|
||||||
if self.device.startswith("cuda"):
|
if self.device.startswith("cuda"):
|
||||||
return self.device
|
return self.device
|
||||||
else:
|
else:
|
||||||
|
|
|
@ -248,7 +248,7 @@ class DevicePoolExecutor:
|
||||||
key,
|
key,
|
||||||
format_exception(type(err), err, err.__traceback__),
|
format_exception(type(err), err, err.__traceback__),
|
||||||
)
|
)
|
||||||
run_gc()
|
run_gc([self.devices[device]])
|
||||||
|
|
||||||
future.add_done_callback(job_done)
|
future.add_done_callback(job_done)
|
||||||
|
|
||||||
|
|
|
@ -5,7 +5,7 @@ from typing import Any, Dict, List, Optional, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from .params import SizeChart
|
from .params import DeviceParams, SizeChart
|
||||||
from .server.model_cache import ModelCache
|
from .server.model_cache import ModelCache
|
||||||
|
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
@ -134,7 +134,13 @@ def get_size(val: Union[int, str, None]) -> SizeChart:
|
||||||
raise Exception("invalid size")
|
raise Exception("invalid size")
|
||||||
|
|
||||||
|
|
||||||
def run_gc():
|
def run_gc(devices: List[DeviceParams] = []):
|
||||||
logger.debug("running garbage collection")
|
logger.debug("running garbage collection")
|
||||||
gc.collect()
|
gc.collect()
|
||||||
|
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
for device in devices:
|
||||||
|
logger.debug("running Torch garbage collection for device: %s", device)
|
||||||
|
with torch.cuda.device(device.torch_str()):
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
|
torch.cuda.ipc_collect()
|
||||||
|
|
Loading…
Reference in New Issue