feat(api): cache pipeline when changing scheduler, make txt2img logging more verbose
This commit is contained in:
parent
f2ee2bb0e7
commit
cab13f665a
42
api/serve.py
42
api/serve.py
|
@ -32,8 +32,9 @@ output_path = environ.get('ONNX_WEB_OUTPUT_PATH', "../outputs")
|
|||
|
||||
# pipeline caching
|
||||
available_models = []
|
||||
pipeline_options = (None, None, None)
|
||||
pipeline_instance = None
|
||||
last_pipeline_instance = None
|
||||
last_pipeline_options = (None, None)
|
||||
last_pipeline_scheduler = None
|
||||
|
||||
# pipeline params
|
||||
platform_providers = {
|
||||
|
@ -74,14 +75,15 @@ def get_latents_from_seed(seed: int, width: int, height: int) -> np.ndarray:
|
|||
|
||||
|
||||
def load_pipeline(model, provider, scheduler):
|
||||
global pipeline_instance
|
||||
global pipeline_options
|
||||
global last_pipeline_instance
|
||||
global last_pipeline_scheduler
|
||||
global last_pipeline_options
|
||||
|
||||
options = (model, provider, scheduler)
|
||||
if pipeline_instance != None and pipeline_options == options:
|
||||
options = (model, provider)
|
||||
if last_pipeline_instance != None and last_pipeline_options == options:
|
||||
print('reusing existing pipeline')
|
||||
return pipeline_instance
|
||||
|
||||
pipe = last_pipeline_instance
|
||||
else:
|
||||
print('loading different pipeline')
|
||||
pipe = OnnxStableDiffusionPipeline.from_pretrained(
|
||||
model,
|
||||
|
@ -89,8 +91,15 @@ def load_pipeline(model, provider, scheduler):
|
|||
safety_checker=None,
|
||||
scheduler=scheduler.from_pretrained(model, subfolder="scheduler")
|
||||
)
|
||||
pipeline_options = options
|
||||
pipeline_instance = pipe
|
||||
last_pipeline_instance = pipe
|
||||
last_pipeline_options = options
|
||||
last_pipeline_scheduler = scheduler
|
||||
|
||||
if last_pipeline_scheduler != scheduler:
|
||||
print('changing pipeline scheduler')
|
||||
pipe.scheduler = scheduler.from_pretrained(model, subfolder="scheduler")
|
||||
last_pipeline_scheduler = scheduler
|
||||
|
||||
return pipe
|
||||
|
||||
|
||||
|
@ -101,6 +110,11 @@ def json_with_cors(data, origin='*'):
|
|||
return res
|
||||
|
||||
|
||||
def safer_join(base, tail):
|
||||
safer_path = path.relpath(path.normpath(path.join('/', tail)), '/')
|
||||
return path.join(base, safer_path)
|
||||
|
||||
|
||||
def url_from_rule(rule):
|
||||
options = {}
|
||||
for arg in rule.arguments:
|
||||
|
@ -162,7 +176,7 @@ def txt2img():
|
|||
user = request.remote_addr
|
||||
|
||||
# pipeline stuff
|
||||
model = path.join(model_path, request.args.get('model'))
|
||||
model = safer_join(model_path, request.args.get('model', 'stable-diffusion-onnx-v1-5'))
|
||||
provider = get_from_map(request.args, 'platform',
|
||||
platform_providers, 'amd')
|
||||
scheduler = get_from_map(request.args, 'scheduler',
|
||||
|
@ -181,8 +195,8 @@ def txt2img():
|
|||
|
||||
latents = get_latents_from_seed(seed, width, height)
|
||||
|
||||
print("txt2img from %s: %s/%s, %sx%s, %s, %s" %
|
||||
(user, cfg, steps, width, height, seed, prompt))
|
||||
print("txt2img from %s: %s rounds of %s using %s on %s, %sx%s, %s, %s - %s" %
|
||||
(user, steps, scheduler.__name__, model, provider, width, height, cfg, seed, prompt))
|
||||
|
||||
pipe = load_pipeline(model, provider, scheduler)
|
||||
image = pipe(
|
||||
|
@ -195,7 +209,7 @@ def txt2img():
|
|||
).images[0]
|
||||
|
||||
output_file = "txt2img_%s_%s.png" % (seed, spinalcase(prompt[0:64]))
|
||||
output_full = '%s/%s' % (output_path, output_file)
|
||||
output_full = safer_join(output_path, output_file)
|
||||
print("txt2img output: %s" % output_full)
|
||||
image.save(output_full)
|
||||
|
||||
|
|
Loading…
Reference in New Issue