feat(api): parse named tile sizes
This commit is contained in:
parent
db9189fd3d
commit
8f1cbc83f8
|
@ -30,12 +30,12 @@ def blend_img2img(
|
|||
prompt: str = None,
|
||||
**kwargs,
|
||||
) -> Image.Image:
|
||||
logger.info('generating image using img2img, %s steps: %s', params.steps, params.prompt)
|
||||
prompt = prompt or params.prompt
|
||||
logger.info('generating image using img2img, %s steps: %s', params.steps, prompt)
|
||||
|
||||
pipe = load_pipeline(OnnxStableDiffusionImg2ImgPipeline,
|
||||
params.model, params.provider, params.scheduler)
|
||||
|
||||
prompt = prompt or params.prompt
|
||||
rng = np.random.RandomState(params.seed)
|
||||
|
||||
result = pipe(
|
||||
|
|
|
@ -67,9 +67,11 @@ def upscale_stable_diffusion(
|
|||
source: Image.Image,
|
||||
*,
|
||||
upscale: UpscaleParams,
|
||||
prompt: str = None,
|
||||
**kwargs,
|
||||
) -> Image.Image:
|
||||
logger.info('upscaling with Stable Diffusion')
|
||||
prompt = prompt or params.prompt
|
||||
logger.info('upscaling with Stable Diffusion, %s steps: %s', params.steps, prompt)
|
||||
|
||||
pipeline = load_stable_diffusion(ctx, upscale)
|
||||
generator = torch.manual_seed(params.seed)
|
||||
|
|
|
@ -72,6 +72,7 @@ from .utils import (
|
|||
get_from_list,
|
||||
get_from_map,
|
||||
get_not_empty,
|
||||
get_size,
|
||||
make_output_name,
|
||||
base_join,
|
||||
ServerContext,
|
||||
|
@ -577,8 +578,8 @@ def chain():
|
|||
|
||||
stage = StageParams(
|
||||
stage_data.get('name', callback.__name__),
|
||||
tile_size=int(kwargs.get('tile_size', SizeChart.auto)),
|
||||
outscale=int(kwargs.get('outscale', 1)),
|
||||
tile_size=get_size(kwargs.get('tile_size')),
|
||||
outscale=get_and_clamp_int(kwargs,'outscale', 1, 4),
|
||||
)
|
||||
|
||||
# TODO: create Border from border
|
||||
|
|
|
@ -3,12 +3,13 @@ from logging import getLogger
|
|||
from os import environ, path
|
||||
from struct import pack
|
||||
from time import time
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
from typing import Any, Dict, List, Optional, Union, Tuple
|
||||
|
||||
from .params import (
|
||||
ImageParams,
|
||||
Param,
|
||||
Size,
|
||||
SizeChart,
|
||||
)
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
@ -92,6 +93,22 @@ def get_not_empty(args: Any, key: str, default: Any) -> Any:
|
|||
return val
|
||||
|
||||
|
||||
def get_size(val: Union[int, str, None]) -> SizeChart:
|
||||
if val is None:
|
||||
return SizeChart.auto
|
||||
|
||||
if type(val) is str:
|
||||
if val in SizeChart:
|
||||
return SizeChart[val]
|
||||
else:
|
||||
return int(val)
|
||||
|
||||
if type(val) is int:
|
||||
return val
|
||||
|
||||
raise Exception('invalid size')
|
||||
|
||||
|
||||
def hash_value(sha, param: Param):
|
||||
if param is None:
|
||||
return
|
||||
|
|
|
@ -23,13 +23,15 @@
|
|||
"prompt": "a magical wizard in a robe fighting a dragon",
|
||||
"scale": 4,
|
||||
"outscale": 4,
|
||||
"tile_size": 128
|
||||
"tile_size": "mini"
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "save-local",
|
||||
"type": "persist-disk",
|
||||
"params": {}
|
||||
"params": {
|
||||
"tile_size": "8k"
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "save-ceph",
|
||||
|
@ -37,7 +39,8 @@
|
|||
"params": {
|
||||
"bucket": "storage-stable-diffusion",
|
||||
"endpoint_url": "http://scylla.home.holdmyran.ch:8000",
|
||||
"profile_name": "ceph"
|
||||
"profile_name": "ceph",
|
||||
"tile_size": "8k"
|
||||
}
|
||||
}
|
||||
]
|
||||
|
|
Loading…
Reference in New Issue