From cab13f665a8d4d42b7ecd64f222eda3f560c8cc5 Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Thu, 5 Jan 2023 23:28:25 -0600 Subject: [PATCH] feat(api): cache pipeline when changing scheduler, make txt2img logging more verbose --- api/serve.py | 54 +++++++++++++++++++++++++++++++++------------------- 1 file changed, 34 insertions(+), 20 deletions(-) diff --git a/api/serve.py b/api/serve.py index bfea065a..107225e5 100644 --- a/api/serve.py +++ b/api/serve.py @@ -32,8 +32,9 @@ output_path = environ.get('ONNX_WEB_OUTPUT_PATH', "../outputs") # pipeline caching available_models = [] -pipeline_options = (None, None, None) -pipeline_instance = None +last_pipeline_instance = None +last_pipeline_options = (None, None) +last_pipeline_scheduler = None # pipeline params platform_providers = { @@ -74,23 +75,31 @@ def get_latents_from_seed(seed: int, width: int, height: int) -> np.ndarray: def load_pipeline(model, provider, scheduler): - global pipeline_instance - global pipeline_options + global last_pipeline_instance + global last_pipeline_scheduler + global last_pipeline_options - options = (model, provider, scheduler) - if pipeline_instance != None and pipeline_options == options: + options = (model, provider) + if last_pipeline_instance != None and last_pipeline_options == options: print('reusing existing pipeline') - return pipeline_instance + pipe = last_pipeline_instance + else: + print('loading different pipeline') + pipe = OnnxStableDiffusionPipeline.from_pretrained( + model, + provider=provider, + safety_checker=None, + scheduler=scheduler.from_pretrained(model, subfolder="scheduler") + ) + last_pipeline_instance = pipe + last_pipeline_options = options + last_pipeline_scheduler = scheduler + + if last_pipeline_scheduler != scheduler: + print('changing pipeline scheduler') + pipe.scheduler = scheduler.from_pretrained(model, subfolder="scheduler") + last_pipeline_scheduler = scheduler - print('loading different pipeline') - pipe = OnnxStableDiffusionPipeline.from_pretrained( - model, - provider=provider, - safety_checker=None, - scheduler=scheduler.from_pretrained(model, subfolder="scheduler") - ) - pipeline_options = options - pipeline_instance = pipe return pipe @@ -101,6 +110,11 @@ def json_with_cors(data, origin='*'): return res +def safer_join(base, tail): + safer_path = path.relpath(path.normpath(path.join('/', tail)), '/') + return path.join(base, safer_path) + + def url_from_rule(rule): options = {} for arg in rule.arguments: @@ -162,7 +176,7 @@ def txt2img(): user = request.remote_addr # pipeline stuff - model = path.join(model_path, request.args.get('model')) + model = safer_join(model_path, request.args.get('model', 'stable-diffusion-onnx-v1-5')) provider = get_from_map(request.args, 'platform', platform_providers, 'amd') scheduler = get_from_map(request.args, 'scheduler', @@ -181,8 +195,8 @@ def txt2img(): latents = get_latents_from_seed(seed, width, height) - print("txt2img from %s: %s/%s, %sx%s, %s, %s" % - (user, cfg, steps, width, height, seed, prompt)) + print("txt2img from %s: %s rounds of %s using %s on %s, %sx%s, %s, %s - %s" % + (user, steps, scheduler.__name__, model, provider, width, height, cfg, seed, prompt)) pipe = load_pipeline(model, provider, scheduler) image = pipe( @@ -195,7 +209,7 @@ def txt2img(): ).images[0] output_file = "txt2img_%s_%s.png" % (seed, spinalcase(prompt[0:64])) - output_full = '%s/%s' % (output_path, output_file) + output_full = safer_join(output_path, output_file) print("txt2img output: %s" % output_full) image.save(output_full)