return output names
This commit is contained in:
parent
6f0adcbae3
commit
9d05d9baac
|
@ -16,18 +16,12 @@ from .utils import base_join, hash_value
|
|||
logger = getLogger(__name__)
|
||||
|
||||
|
||||
def make_output_name(
|
||||
def make_output_names(
|
||||
server: ServerContext,
|
||||
mode: str,
|
||||
params: ImageParams,
|
||||
size: Size,
|
||||
extras: Optional[List[Optional[Param]]] = None,
|
||||
count: Optional[int] = None,
|
||||
job_name: str,
|
||||
count: int = 1,
|
||||
offset: int = 0,
|
||||
) -> List[str]:
|
||||
count = count or params.batch
|
||||
job_name = make_job_name(mode, params, size, extras)
|
||||
|
||||
return [
|
||||
f"{job_name}_{i}.{server.image_format}" for i in range(offset, count + offset)
|
||||
]
|
||||
|
@ -68,12 +62,14 @@ def save_result(
|
|||
result: StageResult,
|
||||
base_name: str,
|
||||
) -> List[str]:
|
||||
images = result.as_image()
|
||||
outputs = make_output_names(server, base_name, len(images))
|
||||
results = []
|
||||
for i, (image, metadata) in enumerate(zip(result.as_image(), result.metadata)):
|
||||
for image, metadata, filename in zip(images, result.metadata, outputs):
|
||||
results.append(
|
||||
save_image(
|
||||
server,
|
||||
base_name + f"_{i}.{server.image_format}",
|
||||
filename,
|
||||
image,
|
||||
metadata,
|
||||
)
|
||||
|
|
|
@ -18,7 +18,7 @@ from ..diffusers.run import (
|
|||
run_upscale_pipeline,
|
||||
)
|
||||
from ..diffusers.utils import replace_wildcards
|
||||
from ..output import make_job_name
|
||||
from ..output import make_job_name, make_output_names
|
||||
from ..params import Progress, Size, StageParams, TileOrder
|
||||
from ..transformers.run import run_txt2txt_pipeline
|
||||
from ..utils import (
|
||||
|
@ -668,7 +668,10 @@ def job_status(server: ServerContext, pool: DevicePoolExecutor):
|
|||
|
||||
# TODO: accumulate results
|
||||
if progress is not None:
|
||||
# TODO: add output paths based on progress.results counter
|
||||
outputs = None
|
||||
if progress.results > 0:
|
||||
outputs = make_output_names(server, job_name, progress.results)
|
||||
|
||||
return image_reply(
|
||||
job_name,
|
||||
status,
|
||||
|
@ -676,6 +679,7 @@ def job_status(server: ServerContext, pool: DevicePoolExecutor):
|
|||
stages=Progress(progress.stages, 0),
|
||||
steps=Progress(progress.steps, 0),
|
||||
tiles=Progress(progress.tiles, 0),
|
||||
outputs=outputs,
|
||||
)
|
||||
|
||||
return image_reply(job_name, status, "TODO")
|
||||
|
|
Loading…
Reference in New Issue