lint(api): add class for params, size, other common data
This commit is contained in:
parent
1146118b1a
commit
ff70d36a05
|
@ -4,8 +4,13 @@ from typing import Tuple
|
|||
|
||||
import numpy as np
|
||||
|
||||
from .utils import (
|
||||
Border,
|
||||
Point,
|
||||
)
|
||||
|
||||
def mask_filter_none(mask_image: Image, dims: Tuple[int, int], origin: Tuple[int, int], fill='white') -> Image:
|
||||
|
||||
def mask_filter_none(mask_image: Image, dims: Point, origin: Point, fill='white') -> Image:
|
||||
width, height = dims
|
||||
|
||||
noise = Image.new('RGB', (width, height), fill)
|
||||
|
@ -14,7 +19,7 @@ def mask_filter_none(mask_image: Image, dims: Tuple[int, int], origin: Tuple[int
|
|||
return noise
|
||||
|
||||
|
||||
def mask_filter_gaussian_multiply(mask_image: Image, dims: Tuple[int, int], origin: Tuple[int, int], rounds=3) -> Image:
|
||||
def mask_filter_gaussian_multiply(mask_image: Image, dims: Point, origin: Point, rounds=3) -> Image:
|
||||
'''
|
||||
Gaussian blur with multiply, source image centered on white canvas.
|
||||
'''
|
||||
|
@ -27,7 +32,7 @@ def mask_filter_gaussian_multiply(mask_image: Image, dims: Tuple[int, int], orig
|
|||
return noise
|
||||
|
||||
|
||||
def mask_filter_gaussian_screen(mask_image: Image, dims: Tuple[int, int], origin: Tuple[int, int], rounds=3) -> Image:
|
||||
def mask_filter_gaussian_screen(mask_image: Image, dims: Point, origin: Point, rounds=3) -> Image:
|
||||
'''
|
||||
Gaussian blur, source image centered on white canvas.
|
||||
'''
|
||||
|
@ -40,7 +45,7 @@ def mask_filter_gaussian_screen(mask_image: Image, dims: Tuple[int, int], origin
|
|||
return noise
|
||||
|
||||
|
||||
def noise_source_fill_edge(source_image: Image, dims: Tuple[int, int], origin: Tuple[int, int], fill='white') -> Image:
|
||||
def noise_source_fill_edge(source_image: Image, dims: Point, origin: Point, fill='white') -> Image:
|
||||
'''
|
||||
Identity transform, source image centered on white canvas.
|
||||
'''
|
||||
|
@ -52,7 +57,7 @@ def noise_source_fill_edge(source_image: Image, dims: Tuple[int, int], origin: T
|
|||
return noise
|
||||
|
||||
|
||||
def noise_source_fill_mask(source_image: Image, dims: Tuple[int, int], origin: Tuple[int, int], fill='white') -> Image:
|
||||
def noise_source_fill_mask(source_image: Image, dims: Point, origin: Point, fill='white') -> Image:
|
||||
'''
|
||||
Fill the whole canvas, no source or noise.
|
||||
'''
|
||||
|
@ -63,7 +68,7 @@ def noise_source_fill_mask(source_image: Image, dims: Tuple[int, int], origin: T
|
|||
return noise
|
||||
|
||||
|
||||
def noise_source_gaussian(source_image: Image, dims: Tuple[int, int], origin: Tuple[int, int], rounds=3) -> Image:
|
||||
def noise_source_gaussian(source_image: Image, dims: Point, origin: Point, rounds=3) -> Image:
|
||||
'''
|
||||
Gaussian blur, source image centered on white canvas.
|
||||
'''
|
||||
|
@ -76,7 +81,7 @@ def noise_source_gaussian(source_image: Image, dims: Tuple[int, int], origin: Tu
|
|||
return noise
|
||||
|
||||
|
||||
def noise_source_uniform(source_image: Image, dims: Tuple[int, int], origin: Tuple[int, int]) -> Image:
|
||||
def noise_source_uniform(source_image: Image, dims: Point, origin: Point) -> Image:
|
||||
width, height = dims
|
||||
size = width * height
|
||||
|
||||
|
@ -98,7 +103,7 @@ def noise_source_uniform(source_image: Image, dims: Tuple[int, int], origin: Tup
|
|||
return noise
|
||||
|
||||
|
||||
def noise_source_normal(source_image: Image, dims: Tuple[int, int], origin: Tuple[int, int]) -> Image:
|
||||
def noise_source_normal(source_image: Image, dims: Point, origin: Point) -> Image:
|
||||
width, height = dims
|
||||
size = width * height
|
||||
|
||||
|
@ -120,7 +125,7 @@ def noise_source_normal(source_image: Image, dims: Tuple[int, int], origin: Tupl
|
|||
return noise
|
||||
|
||||
|
||||
def noise_source_histogram(source_image: Image, dims: Tuple[int, int], origin: Tuple[int, int]) -> Image:
|
||||
def noise_source_histogram(source_image: Image, dims: Point, origin: Point) -> Image:
|
||||
r, g, b = source_image.split()
|
||||
width, height = dims
|
||||
size = width * height
|
||||
|
@ -154,18 +159,16 @@ def noise_source_histogram(source_image: Image, dims: Tuple[int, int], origin: T
|
|||
def expand_image(
|
||||
source_image: Image,
|
||||
mask_image: Image,
|
||||
expand_by: Tuple[int, int, int, int],
|
||||
expand: Border,
|
||||
fill='white',
|
||||
noise_source=noise_source_histogram,
|
||||
mask_filter=mask_filter_none,
|
||||
):
|
||||
left, right, top, bottom = expand_by
|
||||
|
||||
full_width = left + source_image.width + right
|
||||
full_height = top + source_image.height + bottom
|
||||
full_width = expand.left + source_image.width + expand.right
|
||||
full_height = expand.top + source_image.height + expand.bottom
|
||||
|
||||
dims = (full_width, full_height)
|
||||
origin = (top, left)
|
||||
origin = (expand.top, expand.left)
|
||||
|
||||
full_source = Image.new('RGB', dims, fill)
|
||||
full_source.paste(source_image, origin)
|
||||
|
|
|
@ -7,7 +7,7 @@ from diffusers import (
|
|||
)
|
||||
from os import environ
|
||||
from PIL import Image
|
||||
from typing import Any, Union
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
@ -18,7 +18,10 @@ from .upscale import (
|
|||
upscale_resrgan,
|
||||
)
|
||||
from .utils import (
|
||||
safer_join
|
||||
safer_join,
|
||||
BaseParams,
|
||||
Border,
|
||||
Size,
|
||||
)
|
||||
|
||||
last_pipeline_instance = None
|
||||
|
@ -28,9 +31,9 @@ last_pipeline_scheduler = None
|
|||
# from https://www.travelneil.com/stable-diffusion-updates.html
|
||||
|
||||
|
||||
def get_latents_from_seed(seed: int, width: int, height: int) -> np.ndarray:
|
||||
def get_latents_from_seed(seed: int, size: Size) -> np.ndarray:
|
||||
# 1 is batch size
|
||||
latents_shape = (1, 4, height // 8, width // 8)
|
||||
latents_shape = (1, 4, size.height // 8, size.width // 8)
|
||||
# Gotta use numpy instead of torch, because torch's randn() doesn't support DML
|
||||
rng = np.random.default_rng(seed)
|
||||
image_latents = rng.standard_normal(latents_shape).astype(np.float32)
|
||||
|
@ -67,82 +70,70 @@ def load_pipeline(pipeline: DiffusionPipeline, model: str, provider: str, schedu
|
|||
return pipe
|
||||
|
||||
|
||||
def run_txt2img_pipeline(model, provider, scheduler, prompt, negative_prompt, cfg, steps, seed, output, height, width):
|
||||
def run_txt2img_pipeline(params: BaseParams, size: Size):
|
||||
pipe = load_pipeline(OnnxStableDiffusionPipeline,
|
||||
model, provider, scheduler)
|
||||
params.model, params.provider, params.scheduler)
|
||||
|
||||
latents = get_latents_from_seed(seed, width, height)
|
||||
rng = np.random.RandomState(seed)
|
||||
latents = get_latents_from_seed(params.seed, size.width, size.height)
|
||||
rng = np.random.RandomState(params.seed)
|
||||
|
||||
image = pipe(
|
||||
prompt,
|
||||
height,
|
||||
width,
|
||||
params.prompt,
|
||||
size.width,
|
||||
size.height,
|
||||
generator=rng,
|
||||
guidance_scale=cfg,
|
||||
guidance_scale=params.cfg,
|
||||
latents=latents,
|
||||
negative_prompt=negative_prompt,
|
||||
num_inference_steps=steps,
|
||||
negative_prompt=params.negative_prompt,
|
||||
num_inference_steps=params.steps,
|
||||
).images[0]
|
||||
image = upscale_resrgan(image, model_path)
|
||||
image.save(output)
|
||||
image.save(params.output.file)
|
||||
|
||||
print('saved txt2img output: %s' % (output))
|
||||
print('saved txt2img output: %s' % (params.output.file))
|
||||
|
||||
|
||||
def run_img2img_pipeline(model, provider, scheduler, prompt, negative_prompt, cfg, steps, seed, output, strength, input_image):
|
||||
def run_img2img_pipeline(params: BaseParams, strength, input_image):
|
||||
pipe = load_pipeline(OnnxStableDiffusionImg2ImgPipeline,
|
||||
model, provider, scheduler)
|
||||
params.model, params.provider, params.scheduler)
|
||||
|
||||
rng = np.random.RandomState(seed)
|
||||
rng = np.random.RandomState(params.seed)
|
||||
|
||||
image = pipe(
|
||||
prompt,
|
||||
params.prompt,
|
||||
generator=rng,
|
||||
guidance_scale=cfg,
|
||||
guidance_scale=params.cfg,
|
||||
image=input_image,
|
||||
negative_prompt=negative_prompt,
|
||||
num_inference_steps=steps,
|
||||
negative_prompt=params.negative_prompt,
|
||||
num_inference_steps=params.steps,
|
||||
strength=strength,
|
||||
).images[0]
|
||||
image = upscale_resrgan(image, model_path)
|
||||
image.save(output)
|
||||
image.save(params.output.file)
|
||||
|
||||
print('saved img2img output: %s' % (output))
|
||||
print('saved img2img output: %s' % (params.output.file))
|
||||
|
||||
|
||||
def run_inpaint_pipeline(
|
||||
model: str,
|
||||
provider: str,
|
||||
scheduler: Any,
|
||||
prompt: str,
|
||||
negative_prompt: Union[str, None],
|
||||
cfg: float,
|
||||
steps: int,
|
||||
seed: int,
|
||||
output: str,
|
||||
height: int,
|
||||
width: int,
|
||||
params: BaseParams,
|
||||
size: Size,
|
||||
source_image: Image,
|
||||
mask_image: Image,
|
||||
left: int,
|
||||
right: int,
|
||||
top: int,
|
||||
bottom: int,
|
||||
expand: Border,
|
||||
noise_source: Any,
|
||||
mask_filter: Any
|
||||
):
|
||||
pipe = load_pipeline(OnnxStableDiffusionInpaintPipeline,
|
||||
model, provider, scheduler)
|
||||
params.model, params.provider, params.scheduler)
|
||||
|
||||
latents = get_latents_from_seed(seed, width, height)
|
||||
rng = np.random.RandomState(seed)
|
||||
latents = get_latents_from_seed(params.seed, size)
|
||||
rng = np.random.RandomState(params.seed)
|
||||
|
||||
print('applying mask filter and generating noise source')
|
||||
source_image, mask_image, noise_image, _full_dims = expand_image(
|
||||
source_image,
|
||||
mask_image,
|
||||
(left, right, top, bottom),
|
||||
expand,
|
||||
noise_source=noise_source,
|
||||
mask_filter=mask_filter)
|
||||
|
||||
|
@ -152,18 +143,18 @@ def run_inpaint_pipeline(
|
|||
noise_image.save(safer_join(output_path, 'last-noise.png'))
|
||||
|
||||
image = pipe(
|
||||
prompt,
|
||||
params.prompt,
|
||||
generator=rng,
|
||||
guidance_scale=cfg,
|
||||
height=height,
|
||||
guidance_scale=params.cfg,
|
||||
height=size.height,
|
||||
image=source_image,
|
||||
latents=latents,
|
||||
mask_image=mask_image,
|
||||
negative_prompt=negative_prompt,
|
||||
num_inference_steps=steps,
|
||||
width=width,
|
||||
negative_prompt=params.negative_prompt,
|
||||
num_inference_steps=params.steps,
|
||||
width=size.width,
|
||||
).images[0]
|
||||
|
||||
image.save(output)
|
||||
image.save(params.output.file)
|
||||
|
||||
print('saved inpaint output: %s' % (output))
|
||||
print('saved inpaint output: %s' % (params.output.file))
|
||||
|
|
|
@ -46,7 +46,10 @@ from .utils import (
|
|||
get_and_clamp_float,
|
||||
get_and_clamp_int,
|
||||
get_from_map,
|
||||
safer_join
|
||||
safer_join,
|
||||
BaseParams,
|
||||
OutputPath,
|
||||
Size,
|
||||
)
|
||||
|
||||
import json
|
||||
|
@ -111,11 +114,14 @@ def serve_bundle_file(filename='index.html'):
|
|||
return send_from_directory(path.join('..', bundle_path), filename)
|
||||
|
||||
|
||||
def make_output_path(mode: str, seed: int, params: Tuple[Union[str, int, float]]):
|
||||
def make_output_path(mode: str, params: BaseParams, size: Size, extras: Tuple[Union[str, int, float]]) -> OutputPath:
|
||||
now = int(time.time())
|
||||
sha = sha256()
|
||||
sha.update(mode.encode('utf-8'))
|
||||
|
||||
# TODO: add params
|
||||
# TODO: add size
|
||||
|
||||
for param in params:
|
||||
if param is None:
|
||||
continue
|
||||
|
@ -128,10 +134,10 @@ def make_output_path(mode: str, seed: int, params: Tuple[Union[str, int, float]]
|
|||
else:
|
||||
print('cannot hash param: %s, %s' % (param, type(param)))
|
||||
|
||||
output_file = '%s_%s_%s_%s.png' % (mode, seed, sha.hexdigest(), now)
|
||||
output_file = '%s_%s_%s_%s.png' % (mode, params.seed, sha.hexdigest(), now)
|
||||
output_full = safer_join(output_path, output_file)
|
||||
|
||||
return (output_file, output_full)
|
||||
return OutputPath(output_full, output_file)
|
||||
|
||||
|
||||
def url_from_rule(rule):
|
||||
|
@ -142,7 +148,7 @@ def url_from_rule(rule):
|
|||
return url_for(rule.endpoint, **options)
|
||||
|
||||
|
||||
def pipeline_from_request():
|
||||
def pipeline_from_request() -> Tuple[BaseParams, Size]:
|
||||
user = request.remote_addr
|
||||
|
||||
# pipeline stuff
|
||||
|
@ -189,7 +195,9 @@ def pipeline_from_request():
|
|||
print("request from %s: %s rounds of %s using %s on %s, %sx%s, %s, %s - %s" %
|
||||
(user, steps, scheduler.__name__, model, provider, width, height, cfg, seed, prompt))
|
||||
|
||||
return (model, provider, scheduler, prompt, negative_prompt, cfg, steps, height, width, seed)
|
||||
params = BaseParams(model, provider, scheduler, prompt, negative_prompt, cfg, steps, seed)
|
||||
size = Size(width, height)
|
||||
return (params, size)
|
||||
|
||||
|
||||
def check_paths():
|
||||
|
@ -283,27 +291,17 @@ def img2img():
|
|||
|
||||
strength = get_and_clamp_float(request.args, 'strength', 0.5, 1.0)
|
||||
|
||||
(model, provider, scheduler, prompt, negative_prompt, cfg, steps, height,
|
||||
width, seed) = pipeline_from_request()
|
||||
params, size = pipeline_from_request()
|
||||
|
||||
(output_file, output_full) = make_output_path(
|
||||
output = make_output_path(
|
||||
'img2img',
|
||||
seed, (
|
||||
model,
|
||||
provider,
|
||||
scheduler.__name__,
|
||||
prompt,
|
||||
negative_prompt,
|
||||
cfg,
|
||||
steps,
|
||||
strength,
|
||||
height,
|
||||
width))
|
||||
print("img2img output: %s" % (output_full))
|
||||
params,
|
||||
size,
|
||||
extras=(strength))
|
||||
print("img2img output: %s" % (output.path))
|
||||
|
||||
input_image.thumbnail((width, height))
|
||||
executor.submit_stored(output_file, run_img2img_pipeline, model, provider,
|
||||
scheduler, prompt, negative_prompt, cfg, steps, seed, output_full, strength, input_image)
|
||||
input_image.thumbnail((size.width, size.height))
|
||||
executor.submit_stored(output.file, run_img2img_pipeline, params, output, strength, input_image)
|
||||
|
||||
return jsonify({
|
||||
'output': output_file,
|
||||
|
|
|
@ -2,9 +2,7 @@ from os import path
|
|||
from typing import Any, Dict, Tuple
|
||||
|
||||
|
||||
Border = Tuple[int, int, int, int]
|
||||
Point = Tuple[int, int]
|
||||
Size = Tuple[int, int]
|
||||
|
||||
|
||||
def get_and_clamp_float(args, key: str, default_value: float, max_value: float, min_value=0.0) -> float:
|
||||
|
@ -26,3 +24,36 @@ def get_from_map(args, key: str, values: Dict[str, Any], default: Any):
|
|||
def safer_join(base, tail):
|
||||
safer_path = path.relpath(path.normpath(path.join('/', tail)), '/')
|
||||
return path.join(base, safer_path)
|
||||
|
||||
|
||||
class OutputPath:
|
||||
def __init__(self, path, file):
|
||||
self.path = path
|
||||
self.file = file
|
||||
|
||||
|
||||
class BaseParams:
|
||||
def __init__(self, model, provider, scheduler, prompt, negative_prompt, cfg, steps, seed):
|
||||
self.model = model
|
||||
self.provider = provider
|
||||
self.scheduler = scheduler
|
||||
self.prompt = prompt
|
||||
self.negative_prompt = negative_prompt
|
||||
self.cfg = cfg
|
||||
self.steps = steps
|
||||
self.seed = seed
|
||||
self.output = output
|
||||
|
||||
|
||||
class Border:
|
||||
def __init__(self, left, right, top, bottom):
|
||||
self.left = left
|
||||
self.right = right
|
||||
self.top = top
|
||||
self.bottom = bottom
|
||||
|
||||
|
||||
class Size:
|
||||
def __init__(self, width, height):
|
||||
self.width = width
|
||||
self.height = height
|
||||
|
|
Loading…
Reference in New Issue