1
0
Fork 0

return output names

This commit is contained in:
Sean Sube 2024-01-03 20:54:11 -06:00
parent 6f0adcbae3
commit 9d05d9baac
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
2 changed files with 13 additions and 13 deletions

View File

@ -16,18 +16,12 @@ from .utils import base_join, hash_value
logger = getLogger(__name__)
def make_output_name(
def make_output_names(
server: ServerContext,
mode: str,
params: ImageParams,
size: Size,
extras: Optional[List[Optional[Param]]] = None,
count: Optional[int] = None,
job_name: str,
count: int = 1,
offset: int = 0,
) -> List[str]:
count = count or params.batch
job_name = make_job_name(mode, params, size, extras)
return [
f"{job_name}_{i}.{server.image_format}" for i in range(offset, count + offset)
]
@ -68,12 +62,14 @@ def save_result(
result: StageResult,
base_name: str,
) -> List[str]:
images = result.as_image()
outputs = make_output_names(server, base_name, len(images))
results = []
for i, (image, metadata) in enumerate(zip(result.as_image(), result.metadata)):
for image, metadata, filename in zip(images, result.metadata, outputs):
results.append(
save_image(
server,
base_name + f"_{i}.{server.image_format}",
filename,
image,
metadata,
)

View File

@ -18,7 +18,7 @@ from ..diffusers.run import (
run_upscale_pipeline,
)
from ..diffusers.utils import replace_wildcards
from ..output import make_job_name
from ..output import make_job_name, make_output_names
from ..params import Progress, Size, StageParams, TileOrder
from ..transformers.run import run_txt2txt_pipeline
from ..utils import (
@ -668,7 +668,10 @@ def job_status(server: ServerContext, pool: DevicePoolExecutor):
# TODO: accumulate results
if progress is not None:
# TODO: add output paths based on progress.results counter
outputs = None
if progress.results > 0:
outputs = make_output_names(server, job_name, progress.results)
return image_reply(
job_name,
status,
@ -676,6 +679,7 @@ def job_status(server: ServerContext, pool: DevicePoolExecutor):
stages=Progress(progress.stages, 0),
steps=Progress(progress.steps, 0),
tiles=Progress(progress.tiles, 0),
outputs=outputs,
)
return image_reply(job_name, status, "TODO")