1
0
Fork 0

fix(api): run GC after changing pipeline (#58)

This commit is contained in:
Sean Sube 2023-01-19 19:46:36 -06:00
parent 9a2e7adfb8
commit 4a3bb97342
3 changed files with 15 additions and 2 deletions

View File

@ -5,10 +5,10 @@ from diffusers import (
OnnxStableDiffusionImg2ImgPipeline, OnnxStableDiffusionImg2ImgPipeline,
OnnxStableDiffusionInpaintPipeline, OnnxStableDiffusionInpaintPipeline,
) )
from os import environ
from PIL import Image, ImageChops from PIL import Image, ImageChops
from typing import Any from typing import Any
import gc
import numpy as np import numpy as np
from .image import ( from .image import (
@ -19,6 +19,7 @@ from .upscale import (
UpscaleParams, UpscaleParams,
) )
from .utils import ( from .utils import (
is_debug,
safer_join, safer_join,
BaseParams, BaseParams,
Border, Border,
@ -70,6 +71,9 @@ def load_pipeline(pipeline: DiffusionPipeline, model: str, provider: str, schedu
model, subfolder='scheduler') model, subfolder='scheduler')
last_pipeline_scheduler = scheduler last_pipeline_scheduler = scheduler
print('running garbage collection during pipeline change')
gc.collect()
return pipe return pipe
@ -167,7 +171,7 @@ def run_inpaint_pipeline(
noise_source=noise_source, noise_source=noise_source,
mask_filter=mask_filter) mask_filter=mask_filter)
if environ.get('DEBUG') is not None: if is_debug():
source_image.save(safer_join(ctx.output_path, 'last-source.png')) source_image.save(safer_join(ctx.output_path, 'last-source.png'))
mask_image.save(safer_join(ctx.output_path, 'last-mask.png')) mask_image.save(safer_join(ctx.output_path, 'last-mask.png'))
noise_image.save(safer_join(ctx.output_path, 'last-noise.png')) noise_image.save(safer_join(ctx.output_path, 'last-noise.png'))

View File

@ -45,6 +45,7 @@ from .upscale import (
UpscaleParams, UpscaleParams,
) )
from .utils import ( from .utils import (
is_debug,
get_and_clamp_float, get_and_clamp_float,
get_and_clamp_int, get_and_clamp_int,
get_from_list, get_from_list,
@ -57,6 +58,7 @@ from .utils import (
Size, Size,
) )
import gc
import json import json
import numpy as np import numpy as np
@ -259,6 +261,9 @@ app.config['EXECUTOR_PROPAGATE_EXCEPTIONS'] = True
CORS(app, origins=context.cors_origin) CORS(app, origins=context.cors_origin)
executor = Executor(app) executor = Executor(app)
if is_debug():
gc.set_debug(gc.DEBUG_STATS)
# TODO: these two use context # TODO: these two use context

View File

@ -99,6 +99,10 @@ class Size:
} }
def is_debug() -> bool:
return environ.get('DEBUG') is not None
def get_and_clamp_float(args: Any, key: str, default_value: float, max_value: float, min_value=0.0) -> float: def get_and_clamp_float(args: Any, key: str, default_value: float, max_value: float, min_value=0.0) -> float:
return min(max(float(args.get(key, default_value)), min_value), max_value) return min(max(float(args.get(key, default_value)), min_value), max_value)