1
0
Fork 0

feat(api): cache pipeline when changing scheduler, make txt2img logging more verbose

This commit is contained in:
Sean Sube 2023-01-05 23:28:25 -06:00
parent f2ee2bb0e7
commit cab13f665a
1 changed files with 34 additions and 20 deletions

View File

@ -32,8 +32,9 @@ output_path = environ.get('ONNX_WEB_OUTPUT_PATH', "../outputs")
# pipeline caching
available_models = []
pipeline_options = (None, None, None)
pipeline_instance = None
last_pipeline_instance = None
last_pipeline_options = (None, None)
last_pipeline_scheduler = None
# pipeline params
platform_providers = {
@ -74,14 +75,15 @@ def get_latents_from_seed(seed: int, width: int, height: int) -> np.ndarray:
def load_pipeline(model, provider, scheduler):
global pipeline_instance
global pipeline_options
global last_pipeline_instance
global last_pipeline_scheduler
global last_pipeline_options
options = (model, provider, scheduler)
if pipeline_instance != None and pipeline_options == options:
options = (model, provider)
if last_pipeline_instance != None and last_pipeline_options == options:
print('reusing existing pipeline')
return pipeline_instance
pipe = last_pipeline_instance
else:
print('loading different pipeline')
pipe = OnnxStableDiffusionPipeline.from_pretrained(
model,
@ -89,8 +91,15 @@ def load_pipeline(model, provider, scheduler):
safety_checker=None,
scheduler=scheduler.from_pretrained(model, subfolder="scheduler")
)
pipeline_options = options
pipeline_instance = pipe
last_pipeline_instance = pipe
last_pipeline_options = options
last_pipeline_scheduler = scheduler
if last_pipeline_scheduler != scheduler:
print('changing pipeline scheduler')
pipe.scheduler = scheduler.from_pretrained(model, subfolder="scheduler")
last_pipeline_scheduler = scheduler
return pipe
@ -101,6 +110,11 @@ def json_with_cors(data, origin='*'):
return res
def safer_join(base, tail):
safer_path = path.relpath(path.normpath(path.join('/', tail)), '/')
return path.join(base, safer_path)
def url_from_rule(rule):
options = {}
for arg in rule.arguments:
@ -162,7 +176,7 @@ def txt2img():
user = request.remote_addr
# pipeline stuff
model = path.join(model_path, request.args.get('model'))
model = safer_join(model_path, request.args.get('model', 'stable-diffusion-onnx-v1-5'))
provider = get_from_map(request.args, 'platform',
platform_providers, 'amd')
scheduler = get_from_map(request.args, 'scheduler',
@ -181,8 +195,8 @@ def txt2img():
latents = get_latents_from_seed(seed, width, height)
print("txt2img from %s: %s/%s, %sx%s, %s, %s" %
(user, cfg, steps, width, height, seed, prompt))
print("txt2img from %s: %s rounds of %s using %s on %s, %sx%s, %s, %s - %s" %
(user, steps, scheduler.__name__, model, provider, width, height, cfg, seed, prompt))
pipe = load_pipeline(model, provider, scheduler)
image = pipe(
@ -195,7 +209,7 @@ def txt2img():
).images[0]
output_file = "txt2img_%s_%s.png" % (seed, spinalcase(prompt[0:64]))
output_full = '%s/%s' % (output_path, output_file)
output_full = safer_join(output_path, output_file)
print("txt2img output: %s" % output_full)
image.save(output_full)