1
0
Fork 0

fix(api): remove prompt from output name

This commit is contained in:
Sean Sube 2023-01-08 12:46:52 -06:00
parent f4ca6a0547
commit 0d4c0a5942
2 changed files with 10 additions and 11 deletions

View File

@ -20,9 +20,9 @@ from flask import Flask, jsonify, request, send_from_directory, url_for
from hashlib import sha256 from hashlib import sha256
from io import BytesIO from io import BytesIO
from PIL import Image from PIL import Image
from stringcase import spinalcase
from struct import pack from struct import pack
from os import environ, makedirs, path, scandir from os import environ, makedirs, path, scandir
from typing import Tuple, Union
import numpy as np import numpy as np
# defaults # defaults
@ -74,15 +74,15 @@ pipeline_schedulers = {
} }
def get_and_clamp_float(args, key, default_value, max_value, min_value=0.0): def get_and_clamp_float(args, key: str, default_value: float, max_value: float, min_value=0.0):
return min(max(float(args.get(key, default_value)), min_value), max_value) return min(max(float(args.get(key, default_value)), min_value), max_value)
def get_and_clamp_int(args, key, default_value, max_value, min_value=1): def get_and_clamp_int(args, key: str, default_value: int, max_value: int, min_value=1):
return min(max(int(args.get(key, default_value)), min_value), max_value) return min(max(int(args.get(key, default_value)), min_value), max_value)
def get_from_map(args, key, values, default): def get_from_map(args, key: str, values, default):
selected = args.get(key, default) selected = args.get(key, default)
if selected in values: if selected in values:
return values[selected] return values[selected]
@ -90,7 +90,7 @@ def get_from_map(args, key, values, default):
return values[default] return values[default]
def get_model_path(model): def get_model_path(model: str):
return safer_join(model_path, model) return safer_join(model_path, model)
@ -104,7 +104,7 @@ def get_latents_from_seed(seed: int, width: int, height: int) -> np.ndarray:
return image_latents return image_latents
def load_pipeline(pipeline, model, provider, scheduler): def load_pipeline(pipeline, model: str, provider: str, scheduler):
global last_pipeline_instance global last_pipeline_instance
global last_pipeline_scheduler global last_pipeline_scheduler
global last_pipeline_options global last_pipeline_options
@ -141,9 +141,9 @@ def json_with_cors(data, origin='*'):
return res return res
def make_output_path(type, params): def make_output_path(type: str, params: Tuple[Union[str, int, float]]):
sha = sha256() sha = sha256()
sha.update(type) sha.update(type.encode('utf-8'))
for param in params: for param in params:
if isinstance(param, str): if isinstance(param, str):
sha.update(param.encode('utf-8')) sha.update(param.encode('utf-8'))
@ -154,7 +154,7 @@ def make_output_path(type, params):
else: else:
print('cannot hash param: %s, %s' % (param, type(param))) print('cannot hash param: %s, %s' % (param, type(param)))
output_file = 'txt2img_%s_%s.png' % (params[0], sha.hexdigest()) output_file = '%s_%s.png' % (type, sha.hexdigest())
output_full = safer_join(output_path, output_file) output_full = safer_join(output_path, output_file)
return (output_file, output_full) return (output_file, output_full)

View File

@ -8,5 +8,4 @@ protobuf<4,>=3.20.2
transformers transformers
### Server packages ### ### Server packages ###
flask flask
stringcase