feat(api): cache pipeline when changing scheduler, make txt2img logging more verbose
This commit is contained in:
parent
f2ee2bb0e7
commit
cab13f665a
54
api/serve.py
54
api/serve.py
|
@ -32,8 +32,9 @@ output_path = environ.get('ONNX_WEB_OUTPUT_PATH', "../outputs")
|
||||||
|
|
||||||
# pipeline caching
|
# pipeline caching
|
||||||
available_models = []
|
available_models = []
|
||||||
pipeline_options = (None, None, None)
|
last_pipeline_instance = None
|
||||||
pipeline_instance = None
|
last_pipeline_options = (None, None)
|
||||||
|
last_pipeline_scheduler = None
|
||||||
|
|
||||||
# pipeline params
|
# pipeline params
|
||||||
platform_providers = {
|
platform_providers = {
|
||||||
|
@ -74,23 +75,31 @@ def get_latents_from_seed(seed: int, width: int, height: int) -> np.ndarray:
|
||||||
|
|
||||||
|
|
||||||
def load_pipeline(model, provider, scheduler):
|
def load_pipeline(model, provider, scheduler):
|
||||||
global pipeline_instance
|
global last_pipeline_instance
|
||||||
global pipeline_options
|
global last_pipeline_scheduler
|
||||||
|
global last_pipeline_options
|
||||||
|
|
||||||
options = (model, provider, scheduler)
|
options = (model, provider)
|
||||||
if pipeline_instance != None and pipeline_options == options:
|
if last_pipeline_instance != None and last_pipeline_options == options:
|
||||||
print('reusing existing pipeline')
|
print('reusing existing pipeline')
|
||||||
return pipeline_instance
|
pipe = last_pipeline_instance
|
||||||
|
else:
|
||||||
|
print('loading different pipeline')
|
||||||
|
pipe = OnnxStableDiffusionPipeline.from_pretrained(
|
||||||
|
model,
|
||||||
|
provider=provider,
|
||||||
|
safety_checker=None,
|
||||||
|
scheduler=scheduler.from_pretrained(model, subfolder="scheduler")
|
||||||
|
)
|
||||||
|
last_pipeline_instance = pipe
|
||||||
|
last_pipeline_options = options
|
||||||
|
last_pipeline_scheduler = scheduler
|
||||||
|
|
||||||
|
if last_pipeline_scheduler != scheduler:
|
||||||
|
print('changing pipeline scheduler')
|
||||||
|
pipe.scheduler = scheduler.from_pretrained(model, subfolder="scheduler")
|
||||||
|
last_pipeline_scheduler = scheduler
|
||||||
|
|
||||||
print('loading different pipeline')
|
|
||||||
pipe = OnnxStableDiffusionPipeline.from_pretrained(
|
|
||||||
model,
|
|
||||||
provider=provider,
|
|
||||||
safety_checker=None,
|
|
||||||
scheduler=scheduler.from_pretrained(model, subfolder="scheduler")
|
|
||||||
)
|
|
||||||
pipeline_options = options
|
|
||||||
pipeline_instance = pipe
|
|
||||||
return pipe
|
return pipe
|
||||||
|
|
||||||
|
|
||||||
|
@ -101,6 +110,11 @@ def json_with_cors(data, origin='*'):
|
||||||
return res
|
return res
|
||||||
|
|
||||||
|
|
||||||
|
def safer_join(base, tail):
|
||||||
|
safer_path = path.relpath(path.normpath(path.join('/', tail)), '/')
|
||||||
|
return path.join(base, safer_path)
|
||||||
|
|
||||||
|
|
||||||
def url_from_rule(rule):
|
def url_from_rule(rule):
|
||||||
options = {}
|
options = {}
|
||||||
for arg in rule.arguments:
|
for arg in rule.arguments:
|
||||||
|
@ -162,7 +176,7 @@ def txt2img():
|
||||||
user = request.remote_addr
|
user = request.remote_addr
|
||||||
|
|
||||||
# pipeline stuff
|
# pipeline stuff
|
||||||
model = path.join(model_path, request.args.get('model'))
|
model = safer_join(model_path, request.args.get('model', 'stable-diffusion-onnx-v1-5'))
|
||||||
provider = get_from_map(request.args, 'platform',
|
provider = get_from_map(request.args, 'platform',
|
||||||
platform_providers, 'amd')
|
platform_providers, 'amd')
|
||||||
scheduler = get_from_map(request.args, 'scheduler',
|
scheduler = get_from_map(request.args, 'scheduler',
|
||||||
|
@ -181,8 +195,8 @@ def txt2img():
|
||||||
|
|
||||||
latents = get_latents_from_seed(seed, width, height)
|
latents = get_latents_from_seed(seed, width, height)
|
||||||
|
|
||||||
print("txt2img from %s: %s/%s, %sx%s, %s, %s" %
|
print("txt2img from %s: %s rounds of %s using %s on %s, %sx%s, %s, %s - %s" %
|
||||||
(user, cfg, steps, width, height, seed, prompt))
|
(user, steps, scheduler.__name__, model, provider, width, height, cfg, seed, prompt))
|
||||||
|
|
||||||
pipe = load_pipeline(model, provider, scheduler)
|
pipe = load_pipeline(model, provider, scheduler)
|
||||||
image = pipe(
|
image = pipe(
|
||||||
|
@ -195,7 +209,7 @@ def txt2img():
|
||||||
).images[0]
|
).images[0]
|
||||||
|
|
||||||
output_file = "txt2img_%s_%s.png" % (seed, spinalcase(prompt[0:64]))
|
output_file = "txt2img_%s_%s.png" % (seed, spinalcase(prompt[0:64]))
|
||||||
output_full = '%s/%s' % (output_path, output_file)
|
output_full = safer_join(output_path, output_file)
|
||||||
print("txt2img output: %s" % output_full)
|
print("txt2img output: %s" % output_full)
|
||||||
image.save(output_full)
|
image.save(output_full)
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue