fix order, parse right data, use output names
This commit is contained in:
parent
51b10de265
commit
feb9ea9e1a
|
@ -37,11 +37,12 @@ class BlendGridStage(BaseStage):
|
|||
output = Image.new("RGB", (size[0] * width, size[1] * height))
|
||||
|
||||
# TODO: labels
|
||||
for i in order or range(len(sources)):
|
||||
order = order or range(len(sources))
|
||||
for i in len(order):
|
||||
x = i % width
|
||||
y = i / width
|
||||
y = i // width
|
||||
|
||||
output.paste(sources[i], (x * size[0], y * size[1]))
|
||||
n = order[i]
|
||||
output.paste(sources[n], (x * size[0], y * size[1]))
|
||||
|
||||
return [output]
|
||||
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
from logging import getLogger
|
||||
from typing import List
|
||||
from typing import List, Optional
|
||||
|
||||
from PIL import Image
|
||||
|
||||
|
@ -21,13 +21,12 @@ class PersistDiskStage(BaseStage):
|
|||
params: ImageParams,
|
||||
sources: List[Image.Image],
|
||||
*,
|
||||
outputs: List[str],
|
||||
stage_source: Image.Image,
|
||||
output: List[str],
|
||||
stage_source: Optional[Image.Image] = None,
|
||||
**kwargs,
|
||||
) -> List[Image.Image]:
|
||||
for source, output in zip(sources, outputs):
|
||||
# TODO: append index to output name
|
||||
dest = save_image(server, output, source, params=params)
|
||||
for source, name in zip(sources, output):
|
||||
dest = save_image(server, name, source, params=params)
|
||||
logger.info("saved image to %s", dest)
|
||||
|
||||
return sources
|
||||
|
|
|
@ -32,7 +32,7 @@ class SourceTxt2ImgStage(BaseStage):
|
|||
params: ImageParams,
|
||||
sources: List[Image.Image],
|
||||
*,
|
||||
dims: Tuple[int, int, int],
|
||||
dims: Tuple[int, int, int] = None,
|
||||
size: Size,
|
||||
callback: Optional[ProgressCallback] = None,
|
||||
latents: Optional[np.ndarray] = None,
|
||||
|
|
|
@ -387,7 +387,7 @@ def chain(server: ServerContext, pool: DevicePoolExecutor):
|
|||
validate(data, schema)
|
||||
|
||||
# get defaults from the regular parameters
|
||||
device, _params, _size = pipeline_from_request(server)
|
||||
device, _params, _size = pipeline_from_request(server, data=data)
|
||||
pipeline = ChainPipeline()
|
||||
for stage_data in data.get("stages", []):
|
||||
stage_class = CHAIN_STAGES[stage_data.get("type")]
|
||||
|
|
Loading…
Reference in New Issue