1
0
Fork 0

lint(api): consolidate output fns

This commit is contained in:
Sean Sube 2023-02-02 08:31:35 -06:00
parent 6ce13096ed
commit a1ef6c4c77
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
5 changed files with 66 additions and 62 deletions

View File

@ -71,7 +71,7 @@ def run_txt2img_pipeline(
del result
run_gc()
logger.info('saved txt2img output: %s', dest)
logger.info('finished txt2img job: %s', dest)
def run_img2img_pipeline(
@ -108,7 +108,7 @@ def run_img2img_pipeline(
del result
run_gc()
logger.info('saved img2img output: %s', dest)
logger.info('finished img2img job: %s', dest)
def run_inpaint_pipeline(
@ -154,7 +154,7 @@ def run_inpaint_pipeline(
del image
run_gc()
logger.info('saved inpaint output: %s', dest)
logger.info('finished inpaint job: %s', dest)
def run_upscale_pipeline(
@ -174,4 +174,4 @@ def run_upscale_pipeline(
del image
run_gc()
logger.info('saved img2img output: %s', dest)
logger.info('finished upscale job: %s', dest)

View File

@ -1,10 +1,15 @@
from hashlib import sha256
from json import dumps
from logging import getLogger
from PIL import Image
from typing import Any, Optional
from struct import pack
from time import time
from typing import Any, Optional, Tuple
from .params import (
Border,
ImageParams,
Param,
Size,
UpscaleParams,
)
@ -13,6 +18,20 @@ from .utils import (
ServerContext,
)
logger = getLogger(__name__)
def hash_value(sha, param: Param):
if param is None:
return
elif isinstance(param, float):
sha.update(bytearray(pack('!f', param)))
elif isinstance(param, int):
sha.update(bytearray(pack('!I', param)))
elif isinstance(param, str):
sha.update(param.encode('utf-8'))
else:
logger.warn('cannot hash param: %s, %s', param, type(param))
def json_params(
output: str,
@ -42,9 +61,38 @@ def json_params(
return json
def make_output_name(
mode: str,
params: ImageParams,
size: Size,
extras: Optional[Tuple[Param]] = None
) -> str:
now = int(time())
sha = sha256()
hash_value(sha, mode)
hash_value(sha, params.model)
hash_value(sha, params.provider)
hash_value(sha, params.scheduler.__name__)
hash_value(sha, params.prompt)
hash_value(sha, params.negative_prompt)
hash_value(sha, params.cfg)
hash_value(sha, params.steps)
hash_value(sha, params.seed)
hash_value(sha, size.width)
hash_value(sha, size.height)
if extras is not None:
for param in extras:
hash_value(sha, param)
return '%s_%s_%s_%s' % (mode, params.seed, sha.hexdigest(), now)
def save_image(ctx: ServerContext, output: str, image: Image.Image) -> str:
path = base_join(ctx.output_path, '%s.%s' % (output, ctx.image_format))
image.save(path, format=ctx.image_format)
logger.debug('saved output image to: %s', path)
return path
@ -60,3 +108,5 @@ def save_params(
json = json_params(output, params, size, upscale=upscale, border=border)
with open(path, 'w') as f:
f.write(dumps(json))
logger.debug('saved image params to: %s', path)
return path

View File

@ -436,7 +436,7 @@ def img2img():
params,
size,
extras=(strength,))
logger.info("img2img output saved: %s", output)
logger.info("img2img job queued for: %s", output)
source_image.thumbnail((size.width, size.height))
executor.submit_stored(output, run_img2img_pipeline,
@ -454,7 +454,7 @@ def txt2img():
'txt2img',
params,
size)
logger.info("txt2img output saved: %s", output)
logger.info("txt2img job queued for: %s", output)
executor.submit_stored(
output, run_txt2img_pipeline, context, params, size, output, upscale)
@ -506,7 +506,7 @@ def inpaint():
fill_color,
)
)
logger.info("inpaint output saved: %s", output)
logger.info("inpaint job queued for: %s", output)
source_image.thumbnail((size.width, size.height))
mask_image.thumbnail((size.width, size.height))

View File

@ -1,17 +1,11 @@
from hashlib import sha256
from logging import getLogger
from os import environ, path
from struct import pack
from time import time
from typing import Any, Dict, List, Optional, Union, Tuple
from typing import Any, Dict, List, Optional, Union
import gc
import torch
from .params import (
ImageParams,
Param,
Size,
SizeChart,
)
@ -64,6 +58,11 @@ class ServerContext:
)
def base_join(base: str, tail: str) -> str:
tail_path = path.relpath(path.normpath(path.join('/', tail)), '/')
return path.join(base, tail_path)
def is_debug() -> bool:
return environ.get('DEBUG') is not None
@ -122,53 +121,7 @@ def get_size(val: Union[int, str, None]) -> SizeChart:
raise Exception('invalid size')
def hash_value(sha, param: Param):
if param is None:
return
elif isinstance(param, float):
sha.update(bytearray(pack('!f', param)))
elif isinstance(param, int):
sha.update(bytearray(pack('!I', param)))
elif isinstance(param, str):
sha.update(param.encode('utf-8'))
else:
logger.warn('cannot hash param: %s, %s', param, type(param))
def make_output_name(
mode: str,
params: ImageParams,
size: Size,
extras: Optional[Tuple[Param]] = None
) -> str:
now = int(time())
sha = sha256()
hash_value(sha, mode)
hash_value(sha, params.model)
hash_value(sha, params.provider)
hash_value(sha, params.scheduler.__name__)
hash_value(sha, params.prompt)
hash_value(sha, params.negative_prompt)
hash_value(sha, params.cfg)
hash_value(sha, params.steps)
hash_value(sha, params.seed)
hash_value(sha, size.width)
hash_value(sha, size.height)
if extras is not None:
for param in extras:
hash_value(sha, param)
return '%s_%s_%s_%s' % (mode, params.seed, sha.hexdigest(), now)
def base_join(base: str, tail: str) -> str:
tail_path = path.relpath(path.normpath(path.join('/', tail)), '/')
return path.join(base, tail_path)
def run_gc():
logger.debug('running garbage collection')
gc.collect()
torch.cuda.empty_cache()
torch.cuda.empty_cache()

View File

@ -60,6 +60,7 @@
"stringcase",
"timestep",
"timesteps",
"tojson",
"uncond",
"unet",
"untruncated",