lint(api): consolidate output fns
This commit is contained in:
parent
6ce13096ed
commit
a1ef6c4c77
|
@ -71,7 +71,7 @@ def run_txt2img_pipeline(
|
||||||
del result
|
del result
|
||||||
run_gc()
|
run_gc()
|
||||||
|
|
||||||
logger.info('saved txt2img output: %s', dest)
|
logger.info('finished txt2img job: %s', dest)
|
||||||
|
|
||||||
|
|
||||||
def run_img2img_pipeline(
|
def run_img2img_pipeline(
|
||||||
|
@ -108,7 +108,7 @@ def run_img2img_pipeline(
|
||||||
del result
|
del result
|
||||||
run_gc()
|
run_gc()
|
||||||
|
|
||||||
logger.info('saved img2img output: %s', dest)
|
logger.info('finished img2img job: %s', dest)
|
||||||
|
|
||||||
|
|
||||||
def run_inpaint_pipeline(
|
def run_inpaint_pipeline(
|
||||||
|
@ -154,7 +154,7 @@ def run_inpaint_pipeline(
|
||||||
del image
|
del image
|
||||||
run_gc()
|
run_gc()
|
||||||
|
|
||||||
logger.info('saved inpaint output: %s', dest)
|
logger.info('finished inpaint job: %s', dest)
|
||||||
|
|
||||||
|
|
||||||
def run_upscale_pipeline(
|
def run_upscale_pipeline(
|
||||||
|
@ -174,4 +174,4 @@ def run_upscale_pipeline(
|
||||||
del image
|
del image
|
||||||
run_gc()
|
run_gc()
|
||||||
|
|
||||||
logger.info('saved img2img output: %s', dest)
|
logger.info('finished upscale job: %s', dest)
|
||||||
|
|
|
@ -1,10 +1,15 @@
|
||||||
|
from hashlib import sha256
|
||||||
from json import dumps
|
from json import dumps
|
||||||
|
from logging import getLogger
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from typing import Any, Optional
|
from struct import pack
|
||||||
|
from time import time
|
||||||
|
from typing import Any, Optional, Tuple
|
||||||
|
|
||||||
from .params import (
|
from .params import (
|
||||||
Border,
|
Border,
|
||||||
ImageParams,
|
ImageParams,
|
||||||
|
Param,
|
||||||
Size,
|
Size,
|
||||||
UpscaleParams,
|
UpscaleParams,
|
||||||
)
|
)
|
||||||
|
@ -13,6 +18,20 @@ from .utils import (
|
||||||
ServerContext,
|
ServerContext,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
def hash_value(sha, param: Param):
|
||||||
|
if param is None:
|
||||||
|
return
|
||||||
|
elif isinstance(param, float):
|
||||||
|
sha.update(bytearray(pack('!f', param)))
|
||||||
|
elif isinstance(param, int):
|
||||||
|
sha.update(bytearray(pack('!I', param)))
|
||||||
|
elif isinstance(param, str):
|
||||||
|
sha.update(param.encode('utf-8'))
|
||||||
|
else:
|
||||||
|
logger.warn('cannot hash param: %s, %s', param, type(param))
|
||||||
|
|
||||||
|
|
||||||
def json_params(
|
def json_params(
|
||||||
output: str,
|
output: str,
|
||||||
|
@ -42,9 +61,38 @@ def json_params(
|
||||||
return json
|
return json
|
||||||
|
|
||||||
|
|
||||||
|
def make_output_name(
|
||||||
|
mode: str,
|
||||||
|
params: ImageParams,
|
||||||
|
size: Size,
|
||||||
|
extras: Optional[Tuple[Param]] = None
|
||||||
|
) -> str:
|
||||||
|
now = int(time())
|
||||||
|
sha = sha256()
|
||||||
|
|
||||||
|
hash_value(sha, mode)
|
||||||
|
hash_value(sha, params.model)
|
||||||
|
hash_value(sha, params.provider)
|
||||||
|
hash_value(sha, params.scheduler.__name__)
|
||||||
|
hash_value(sha, params.prompt)
|
||||||
|
hash_value(sha, params.negative_prompt)
|
||||||
|
hash_value(sha, params.cfg)
|
||||||
|
hash_value(sha, params.steps)
|
||||||
|
hash_value(sha, params.seed)
|
||||||
|
hash_value(sha, size.width)
|
||||||
|
hash_value(sha, size.height)
|
||||||
|
|
||||||
|
if extras is not None:
|
||||||
|
for param in extras:
|
||||||
|
hash_value(sha, param)
|
||||||
|
|
||||||
|
return '%s_%s_%s_%s' % (mode, params.seed, sha.hexdigest(), now)
|
||||||
|
|
||||||
|
|
||||||
def save_image(ctx: ServerContext, output: str, image: Image.Image) -> str:
|
def save_image(ctx: ServerContext, output: str, image: Image.Image) -> str:
|
||||||
path = base_join(ctx.output_path, '%s.%s' % (output, ctx.image_format))
|
path = base_join(ctx.output_path, '%s.%s' % (output, ctx.image_format))
|
||||||
image.save(path, format=ctx.image_format)
|
image.save(path, format=ctx.image_format)
|
||||||
|
logger.debug('saved output image to: %s', path)
|
||||||
return path
|
return path
|
||||||
|
|
||||||
|
|
||||||
|
@ -60,3 +108,5 @@ def save_params(
|
||||||
json = json_params(output, params, size, upscale=upscale, border=border)
|
json = json_params(output, params, size, upscale=upscale, border=border)
|
||||||
with open(path, 'w') as f:
|
with open(path, 'w') as f:
|
||||||
f.write(dumps(json))
|
f.write(dumps(json))
|
||||||
|
logger.debug('saved image params to: %s', path)
|
||||||
|
return path
|
||||||
|
|
|
@ -436,7 +436,7 @@ def img2img():
|
||||||
params,
|
params,
|
||||||
size,
|
size,
|
||||||
extras=(strength,))
|
extras=(strength,))
|
||||||
logger.info("img2img output saved: %s", output)
|
logger.info("img2img job queued for: %s", output)
|
||||||
|
|
||||||
source_image.thumbnail((size.width, size.height))
|
source_image.thumbnail((size.width, size.height))
|
||||||
executor.submit_stored(output, run_img2img_pipeline,
|
executor.submit_stored(output, run_img2img_pipeline,
|
||||||
|
@ -454,7 +454,7 @@ def txt2img():
|
||||||
'txt2img',
|
'txt2img',
|
||||||
params,
|
params,
|
||||||
size)
|
size)
|
||||||
logger.info("txt2img output saved: %s", output)
|
logger.info("txt2img job queued for: %s", output)
|
||||||
|
|
||||||
executor.submit_stored(
|
executor.submit_stored(
|
||||||
output, run_txt2img_pipeline, context, params, size, output, upscale)
|
output, run_txt2img_pipeline, context, params, size, output, upscale)
|
||||||
|
@ -506,7 +506,7 @@ def inpaint():
|
||||||
fill_color,
|
fill_color,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
logger.info("inpaint output saved: %s", output)
|
logger.info("inpaint job queued for: %s", output)
|
||||||
|
|
||||||
source_image.thumbnail((size.width, size.height))
|
source_image.thumbnail((size.width, size.height))
|
||||||
mask_image.thumbnail((size.width, size.height))
|
mask_image.thumbnail((size.width, size.height))
|
||||||
|
|
|
@ -1,17 +1,11 @@
|
||||||
from hashlib import sha256
|
|
||||||
from logging import getLogger
|
from logging import getLogger
|
||||||
from os import environ, path
|
from os import environ, path
|
||||||
from struct import pack
|
from typing import Any, Dict, List, Optional, Union
|
||||||
from time import time
|
|
||||||
from typing import Any, Dict, List, Optional, Union, Tuple
|
|
||||||
|
|
||||||
import gc
|
import gc
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from .params import (
|
from .params import (
|
||||||
ImageParams,
|
|
||||||
Param,
|
|
||||||
Size,
|
|
||||||
SizeChart,
|
SizeChart,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -64,6 +58,11 @@ class ServerContext:
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def base_join(base: str, tail: str) -> str:
|
||||||
|
tail_path = path.relpath(path.normpath(path.join('/', tail)), '/')
|
||||||
|
return path.join(base, tail_path)
|
||||||
|
|
||||||
|
|
||||||
def is_debug() -> bool:
|
def is_debug() -> bool:
|
||||||
return environ.get('DEBUG') is not None
|
return environ.get('DEBUG') is not None
|
||||||
|
|
||||||
|
@ -122,53 +121,7 @@ def get_size(val: Union[int, str, None]) -> SizeChart:
|
||||||
raise Exception('invalid size')
|
raise Exception('invalid size')
|
||||||
|
|
||||||
|
|
||||||
def hash_value(sha, param: Param):
|
|
||||||
if param is None:
|
|
||||||
return
|
|
||||||
elif isinstance(param, float):
|
|
||||||
sha.update(bytearray(pack('!f', param)))
|
|
||||||
elif isinstance(param, int):
|
|
||||||
sha.update(bytearray(pack('!I', param)))
|
|
||||||
elif isinstance(param, str):
|
|
||||||
sha.update(param.encode('utf-8'))
|
|
||||||
else:
|
|
||||||
logger.warn('cannot hash param: %s, %s', param, type(param))
|
|
||||||
|
|
||||||
|
|
||||||
def make_output_name(
|
|
||||||
mode: str,
|
|
||||||
params: ImageParams,
|
|
||||||
size: Size,
|
|
||||||
extras: Optional[Tuple[Param]] = None
|
|
||||||
) -> str:
|
|
||||||
now = int(time())
|
|
||||||
sha = sha256()
|
|
||||||
|
|
||||||
hash_value(sha, mode)
|
|
||||||
hash_value(sha, params.model)
|
|
||||||
hash_value(sha, params.provider)
|
|
||||||
hash_value(sha, params.scheduler.__name__)
|
|
||||||
hash_value(sha, params.prompt)
|
|
||||||
hash_value(sha, params.negative_prompt)
|
|
||||||
hash_value(sha, params.cfg)
|
|
||||||
hash_value(sha, params.steps)
|
|
||||||
hash_value(sha, params.seed)
|
|
||||||
hash_value(sha, size.width)
|
|
||||||
hash_value(sha, size.height)
|
|
||||||
|
|
||||||
if extras is not None:
|
|
||||||
for param in extras:
|
|
||||||
hash_value(sha, param)
|
|
||||||
|
|
||||||
return '%s_%s_%s_%s' % (mode, params.seed, sha.hexdigest(), now)
|
|
||||||
|
|
||||||
|
|
||||||
def base_join(base: str, tail: str) -> str:
|
|
||||||
tail_path = path.relpath(path.normpath(path.join('/', tail)), '/')
|
|
||||||
return path.join(base, tail_path)
|
|
||||||
|
|
||||||
|
|
||||||
def run_gc():
|
def run_gc():
|
||||||
logger.debug('running garbage collection')
|
logger.debug('running garbage collection')
|
||||||
gc.collect()
|
gc.collect()
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
|
|
|
@ -60,6 +60,7 @@
|
||||||
"stringcase",
|
"stringcase",
|
||||||
"timestep",
|
"timestep",
|
||||||
"timesteps",
|
"timesteps",
|
||||||
|
"tojson",
|
||||||
"uncond",
|
"uncond",
|
||||||
"unet",
|
"unet",
|
||||||
"untruncated",
|
"untruncated",
|
||||||
|
|
Loading…
Reference in New Issue