1
0
Fork 0

feat(api): parse named tile sizes

This commit is contained in:
Sean Sube 2023-01-28 23:06:25 -06:00
parent db9189fd3d
commit 8f1cbc83f8
5 changed files with 32 additions and 9 deletions

View File

@ -30,12 +30,12 @@ def blend_img2img(
prompt: str = None,
**kwargs,
) -> Image.Image:
logger.info('generating image using img2img, %s steps: %s', params.steps, params.prompt)
prompt = prompt or params.prompt
logger.info('generating image using img2img, %s steps: %s', params.steps, prompt)
pipe = load_pipeline(OnnxStableDiffusionImg2ImgPipeline,
params.model, params.provider, params.scheduler)
prompt = prompt or params.prompt
rng = np.random.RandomState(params.seed)
result = pipe(

View File

@ -67,9 +67,11 @@ def upscale_stable_diffusion(
source: Image.Image,
*,
upscale: UpscaleParams,
prompt: str = None,
**kwargs,
) -> Image.Image:
logger.info('upscaling with Stable Diffusion')
prompt = prompt or params.prompt
logger.info('upscaling with Stable Diffusion, %s steps: %s', params.steps, prompt)
pipeline = load_stable_diffusion(ctx, upscale)
generator = torch.manual_seed(params.seed)

View File

@ -72,6 +72,7 @@ from .utils import (
get_from_list,
get_from_map,
get_not_empty,
get_size,
make_output_name,
base_join,
ServerContext,
@ -577,8 +578,8 @@ def chain():
stage = StageParams(
stage_data.get('name', callback.__name__),
tile_size=int(kwargs.get('tile_size', SizeChart.auto)),
outscale=int(kwargs.get('outscale', 1)),
tile_size=get_size(kwargs.get('tile_size')),
outscale=get_and_clamp_int(kwargs,'outscale', 1, 4),
)
# TODO: create Border from border

View File

@ -3,12 +3,13 @@ from logging import getLogger
from os import environ, path
from struct import pack
from time import time
from typing import Any, Dict, List, Optional, Tuple
from typing import Any, Dict, List, Optional, Union, Tuple
from .params import (
ImageParams,
Param,
Size,
SizeChart,
)
logger = getLogger(__name__)
@ -92,6 +93,22 @@ def get_not_empty(args: Any, key: str, default: Any) -> Any:
return val
def get_size(val: Union[int, str, None]) -> SizeChart:
if val is None:
return SizeChart.auto
if type(val) is str:
if val in SizeChart:
return SizeChart[val]
else:
return int(val)
if type(val) is int:
return val
raise Exception('invalid size')
def hash_value(sha, param: Param):
if param is None:
return

View File

@ -23,13 +23,15 @@
"prompt": "a magical wizard in a robe fighting a dragon",
"scale": 4,
"outscale": 4,
"tile_size": 128
"tile_size": "mini"
}
},
{
"name": "save-local",
"type": "persist-disk",
"params": {}
"params": {
"tile_size": "8k"
}
},
{
"name": "save-ceph",
@ -37,7 +39,8 @@
"params": {
"bucket": "storage-stable-diffusion",
"endpoint_url": "http://scylla.home.holdmyran.ch:8000",
"profile_name": "ceph"
"profile_name": "ceph",
"tile_size": "8k"
}
}
]