1
0
Fork 0

lint(api): encapsulate paths in server context class

This commit is contained in:
Sean Sube 2023-01-16 07:31:42 -06:00
parent a76793d105
commit 13a4fa2278
3 changed files with 73 additions and 40 deletions

View File

@ -22,13 +22,13 @@ from .utils import (
BaseParams,
Border,
OutputPath,
ServerContext,
Size,
)
last_pipeline_instance = None
last_pipeline_options = (None, None, None)
last_pipeline_scheduler = None
model_path = None
# from https://www.travelneil.com/stable-diffusion-updates.html
@ -42,17 +42,7 @@ def get_latents_from_seed(seed: int, size: Size) -> np.ndarray:
return image_latents
def get_model_path(model: str):
return safer_join(model_path, model)
# TODO: hax
def set_model_path(model: str):
global model_path
model_path = model
def load_pipeline(pipeline: DiffusionPipeline, model: str, provider: str, scheduler):
def load_pipeline(pipeline: DiffusionPipeline, model: str, provider: str, scheduler: Any):
global last_pipeline_instance
global last_pipeline_scheduler
global last_pipeline_options
@ -82,7 +72,12 @@ def load_pipeline(pipeline: DiffusionPipeline, model: str, provider: str, schedu
return pipe
def run_txt2img_pipeline(params: BaseParams, size: Size, output: OutputPath):
def run_txt2img_pipeline(
ctx: ServerContext,
params: BaseParams,
size: Size,
output: OutputPath
):
pipe = load_pipeline(OnnxStableDiffusionPipeline,
params.model, params.provider, params.scheduler)
@ -99,13 +94,19 @@ def run_txt2img_pipeline(params: BaseParams, size: Size, output: OutputPath):
negative_prompt=params.negative_prompt,
num_inference_steps=params.steps,
).images[0]
image = upscale_resrgan(image, model_path)
image = upscale_resrgan(image, ctx.models)
image.save(output.path)
print('saved txt2img output: %s' % (output.file))
def run_img2img_pipeline(params: BaseParams, output: OutputPath, strength: float, input_image: Image):
def run_img2img_pipeline(
ctx: ServerContext,
params: BaseParams,
output: OutputPath,
source_image: Image,
strength: float
):
pipe = load_pipeline(OnnxStableDiffusionImg2ImgPipeline,
params.model, params.provider, params.scheduler)
@ -115,18 +116,19 @@ def run_img2img_pipeline(params: BaseParams, output: OutputPath, strength: float
params.prompt,
generator=rng,
guidance_scale=params.cfg,
image=input_image,
image=source_image,
negative_prompt=params.negative_prompt,
num_inference_steps=params.steps,
strength=strength,
).images[0]
image = upscale_resrgan(image, model_path)
image = upscale_resrgan(image, ctx.model_path)
image.save(output.path)
print('saved img2img output: %s' % (output.file))
def run_inpaint_pipeline(
ctx: ServerContext,
params: BaseParams,
size: Size,
output: OutputPath,

View File

@ -46,8 +46,10 @@ from .utils import (
get_and_clamp_int,
get_from_map,
make_output_path,
safer_join,
BaseParams,
Border,
ServerContext,
Size,
)
@ -116,6 +118,10 @@ def url_from_rule(rule) -> str:
return url_for(rule.endpoint, **options)
def get_model_path(model: str):
return safer_join(model_path, model)
def pipeline_from_request() -> Tuple[BaseParams, Size]:
user = request.remote_addr
@ -200,6 +206,9 @@ app.config['EXECUTOR_PROPAGATE_EXCEPTIONS'] = True
CORS(app, origins=cors_origin)
executor = Executor(app)
context = ServerContext(app, executor, bundle_path,
model_path, output_path, params_path)
# routes
@ -256,8 +265,8 @@ def list_schedulers():
@app.route('/api/img2img', methods=['POST'])
def img2img():
input_file = request.files.get('source')
input_image = Image.open(BytesIO(input_file.read())).convert('RGB')
source_file = request.files.get('source')
source_image = Image.open(BytesIO(source_file.read())).convert('RGB')
strength = get_and_clamp_float(
request.args,
@ -275,9 +284,9 @@ def img2img():
extras=(strength))
print("img2img output: %s" % (output.path))
input_image.thumbnail((size.width, size.height))
source_image.thumbnail((size.width, size.height))
executor.submit_stored(output.file, run_img2img_pipeline,
params, output, strength, input_image)
context, params, output, source_image, strength)
return jsonify({
'output': output.file,
@ -298,7 +307,7 @@ def txt2img():
print("txt2img output: %s" % (output.file))
executor.submit_stored(
output.file, run_txt2img_pipeline, params, size, output)
output.file, run_txt2img_pipeline, context, params, size, output)
return jsonify({
'output': output.file,
@ -352,6 +361,7 @@ def inpaint():
executor.submit_stored(
output.file,
run_inpaint_pipeline,
context,
params,
size,
output,

View File

@ -1,5 +1,5 @@
from os import path
import time
from time import time
from struct import pack
from typing import Any, Dict, Tuple, Union
from hashlib import sha256
@ -9,18 +9,18 @@ Param = Union[str, int, float]
Point = Tuple[int, int]
class OutputPath:
'''
TODO: .path is only used in one place, can probably just be a str
'''
def __init__(self, path, file):
self.path = path
self.file = file
class BaseParams:
def __init__(self, model, provider, scheduler, prompt, negative_prompt, cfg, steps, seed):
def __init__(
self,
model: str,
provider: str,
scheduler: Any,
prompt: str,
negative_prompt: Union[None, str],
cfg: float,
steps: int,
seed: int
) -> None:
self.model = model
self.provider = provider
self.scheduler = scheduler
@ -44,15 +44,39 @@ class BaseParams:
class Border:
def __init__(self, left: int, right: int, top: int, bottom: int):
def __init__(self, left: int, right: int, top: int, bottom: int) -> None:
self.left = left
self.right = right
self.top = top
self.bottom = bottom
class OutputPath:
'''
TODO: .path is only used in one place, can probably just be a str
'''
def __init__(self, path, file) -> None:
self.path = path
self.file = file
class ServerContext:
def __init__(
self,
bundle_path: str,
model_path: str,
output_path: str,
params_path: str
) -> None:
self.bundle_path = bundle_path
self.model_path = model_path
self.output_path = output_path
self.params_path = params_path
class Size:
def __init__(self, width: int, height: int):
def __init__(self, width: int, height: int) -> None:
self.width = width
self.height = height
@ -85,9 +109,6 @@ def safer_join(base: str, tail: str) -> str:
def hash_value(sha, param: Param):
'''
TODO: include functions by name
'''
if param is None:
return
elif isinstance(param, float):
@ -107,7 +128,7 @@ def make_output_path(
size: Size,
extras: Union[None, Tuple[Param]] = None
) -> OutputPath:
now = int(time.time())
now = int(time())
sha = sha256()
hash_value(sha, mode)