lint(api): encapsulate paths in server context class
This commit is contained in:
parent
a76793d105
commit
13a4fa2278
|
@ -22,13 +22,13 @@ from .utils import (
|
|||
BaseParams,
|
||||
Border,
|
||||
OutputPath,
|
||||
ServerContext,
|
||||
Size,
|
||||
)
|
||||
|
||||
last_pipeline_instance = None
|
||||
last_pipeline_options = (None, None, None)
|
||||
last_pipeline_scheduler = None
|
||||
model_path = None
|
||||
|
||||
# from https://www.travelneil.com/stable-diffusion-updates.html
|
||||
|
||||
|
@ -42,17 +42,7 @@ def get_latents_from_seed(seed: int, size: Size) -> np.ndarray:
|
|||
return image_latents
|
||||
|
||||
|
||||
def get_model_path(model: str):
|
||||
return safer_join(model_path, model)
|
||||
|
||||
|
||||
# TODO: hax
|
||||
def set_model_path(model: str):
|
||||
global model_path
|
||||
model_path = model
|
||||
|
||||
|
||||
def load_pipeline(pipeline: DiffusionPipeline, model: str, provider: str, scheduler):
|
||||
def load_pipeline(pipeline: DiffusionPipeline, model: str, provider: str, scheduler: Any):
|
||||
global last_pipeline_instance
|
||||
global last_pipeline_scheduler
|
||||
global last_pipeline_options
|
||||
|
@ -82,7 +72,12 @@ def load_pipeline(pipeline: DiffusionPipeline, model: str, provider: str, schedu
|
|||
return pipe
|
||||
|
||||
|
||||
def run_txt2img_pipeline(params: BaseParams, size: Size, output: OutputPath):
|
||||
def run_txt2img_pipeline(
|
||||
ctx: ServerContext,
|
||||
params: BaseParams,
|
||||
size: Size,
|
||||
output: OutputPath
|
||||
):
|
||||
pipe = load_pipeline(OnnxStableDiffusionPipeline,
|
||||
params.model, params.provider, params.scheduler)
|
||||
|
||||
|
@ -99,13 +94,19 @@ def run_txt2img_pipeline(params: BaseParams, size: Size, output: OutputPath):
|
|||
negative_prompt=params.negative_prompt,
|
||||
num_inference_steps=params.steps,
|
||||
).images[0]
|
||||
image = upscale_resrgan(image, model_path)
|
||||
image = upscale_resrgan(image, ctx.models)
|
||||
image.save(output.path)
|
||||
|
||||
print('saved txt2img output: %s' % (output.file))
|
||||
|
||||
|
||||
def run_img2img_pipeline(params: BaseParams, output: OutputPath, strength: float, input_image: Image):
|
||||
def run_img2img_pipeline(
|
||||
ctx: ServerContext,
|
||||
params: BaseParams,
|
||||
output: OutputPath,
|
||||
source_image: Image,
|
||||
strength: float
|
||||
):
|
||||
pipe = load_pipeline(OnnxStableDiffusionImg2ImgPipeline,
|
||||
params.model, params.provider, params.scheduler)
|
||||
|
||||
|
@ -115,18 +116,19 @@ def run_img2img_pipeline(params: BaseParams, output: OutputPath, strength: float
|
|||
params.prompt,
|
||||
generator=rng,
|
||||
guidance_scale=params.cfg,
|
||||
image=input_image,
|
||||
image=source_image,
|
||||
negative_prompt=params.negative_prompt,
|
||||
num_inference_steps=params.steps,
|
||||
strength=strength,
|
||||
).images[0]
|
||||
image = upscale_resrgan(image, model_path)
|
||||
image = upscale_resrgan(image, ctx.model_path)
|
||||
image.save(output.path)
|
||||
|
||||
print('saved img2img output: %s' % (output.file))
|
||||
|
||||
|
||||
def run_inpaint_pipeline(
|
||||
ctx: ServerContext,
|
||||
params: BaseParams,
|
||||
size: Size,
|
||||
output: OutputPath,
|
||||
|
|
|
@ -46,8 +46,10 @@ from .utils import (
|
|||
get_and_clamp_int,
|
||||
get_from_map,
|
||||
make_output_path,
|
||||
safer_join,
|
||||
BaseParams,
|
||||
Border,
|
||||
ServerContext,
|
||||
Size,
|
||||
)
|
||||
|
||||
|
@ -116,6 +118,10 @@ def url_from_rule(rule) -> str:
|
|||
return url_for(rule.endpoint, **options)
|
||||
|
||||
|
||||
def get_model_path(model: str):
|
||||
return safer_join(model_path, model)
|
||||
|
||||
|
||||
def pipeline_from_request() -> Tuple[BaseParams, Size]:
|
||||
user = request.remote_addr
|
||||
|
||||
|
@ -200,6 +206,9 @@ app.config['EXECUTOR_PROPAGATE_EXCEPTIONS'] = True
|
|||
CORS(app, origins=cors_origin)
|
||||
executor = Executor(app)
|
||||
|
||||
context = ServerContext(app, executor, bundle_path,
|
||||
model_path, output_path, params_path)
|
||||
|
||||
# routes
|
||||
|
||||
|
||||
|
@ -256,8 +265,8 @@ def list_schedulers():
|
|||
|
||||
@app.route('/api/img2img', methods=['POST'])
|
||||
def img2img():
|
||||
input_file = request.files.get('source')
|
||||
input_image = Image.open(BytesIO(input_file.read())).convert('RGB')
|
||||
source_file = request.files.get('source')
|
||||
source_image = Image.open(BytesIO(source_file.read())).convert('RGB')
|
||||
|
||||
strength = get_and_clamp_float(
|
||||
request.args,
|
||||
|
@ -275,9 +284,9 @@ def img2img():
|
|||
extras=(strength))
|
||||
print("img2img output: %s" % (output.path))
|
||||
|
||||
input_image.thumbnail((size.width, size.height))
|
||||
source_image.thumbnail((size.width, size.height))
|
||||
executor.submit_stored(output.file, run_img2img_pipeline,
|
||||
params, output, strength, input_image)
|
||||
context, params, output, source_image, strength)
|
||||
|
||||
return jsonify({
|
||||
'output': output.file,
|
||||
|
@ -298,7 +307,7 @@ def txt2img():
|
|||
print("txt2img output: %s" % (output.file))
|
||||
|
||||
executor.submit_stored(
|
||||
output.file, run_txt2img_pipeline, params, size, output)
|
||||
output.file, run_txt2img_pipeline, context, params, size, output)
|
||||
|
||||
return jsonify({
|
||||
'output': output.file,
|
||||
|
@ -352,6 +361,7 @@ def inpaint():
|
|||
executor.submit_stored(
|
||||
output.file,
|
||||
run_inpaint_pipeline,
|
||||
context,
|
||||
params,
|
||||
size,
|
||||
output,
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
from os import path
|
||||
import time
|
||||
from time import time
|
||||
from struct import pack
|
||||
from typing import Any, Dict, Tuple, Union
|
||||
from hashlib import sha256
|
||||
|
@ -9,18 +9,18 @@ Param = Union[str, int, float]
|
|||
Point = Tuple[int, int]
|
||||
|
||||
|
||||
class OutputPath:
|
||||
'''
|
||||
TODO: .path is only used in one place, can probably just be a str
|
||||
'''
|
||||
|
||||
def __init__(self, path, file):
|
||||
self.path = path
|
||||
self.file = file
|
||||
|
||||
|
||||
class BaseParams:
|
||||
def __init__(self, model, provider, scheduler, prompt, negative_prompt, cfg, steps, seed):
|
||||
def __init__(
|
||||
self,
|
||||
model: str,
|
||||
provider: str,
|
||||
scheduler: Any,
|
||||
prompt: str,
|
||||
negative_prompt: Union[None, str],
|
||||
cfg: float,
|
||||
steps: int,
|
||||
seed: int
|
||||
) -> None:
|
||||
self.model = model
|
||||
self.provider = provider
|
||||
self.scheduler = scheduler
|
||||
|
@ -44,15 +44,39 @@ class BaseParams:
|
|||
|
||||
|
||||
class Border:
|
||||
def __init__(self, left: int, right: int, top: int, bottom: int):
|
||||
def __init__(self, left: int, right: int, top: int, bottom: int) -> None:
|
||||
self.left = left
|
||||
self.right = right
|
||||
self.top = top
|
||||
self.bottom = bottom
|
||||
|
||||
|
||||
class OutputPath:
|
||||
'''
|
||||
TODO: .path is only used in one place, can probably just be a str
|
||||
'''
|
||||
|
||||
def __init__(self, path, file) -> None:
|
||||
self.path = path
|
||||
self.file = file
|
||||
|
||||
|
||||
class ServerContext:
|
||||
def __init__(
|
||||
self,
|
||||
bundle_path: str,
|
||||
model_path: str,
|
||||
output_path: str,
|
||||
params_path: str
|
||||
) -> None:
|
||||
self.bundle_path = bundle_path
|
||||
self.model_path = model_path
|
||||
self.output_path = output_path
|
||||
self.params_path = params_path
|
||||
|
||||
|
||||
class Size:
|
||||
def __init__(self, width: int, height: int):
|
||||
def __init__(self, width: int, height: int) -> None:
|
||||
self.width = width
|
||||
self.height = height
|
||||
|
||||
|
@ -85,9 +109,6 @@ def safer_join(base: str, tail: str) -> str:
|
|||
|
||||
|
||||
def hash_value(sha, param: Param):
|
||||
'''
|
||||
TODO: include functions by name
|
||||
'''
|
||||
if param is None:
|
||||
return
|
||||
elif isinstance(param, float):
|
||||
|
@ -107,7 +128,7 @@ def make_output_path(
|
|||
size: Size,
|
||||
extras: Union[None, Tuple[Param]] = None
|
||||
) -> OutputPath:
|
||||
now = int(time.time())
|
||||
now = int(time())
|
||||
sha = sha256()
|
||||
|
||||
hash_value(sha, mode)
|
||||
|
|
Loading…
Reference in New Issue