return output names
This commit is contained in:
parent
6f0adcbae3
commit
9d05d9baac
|
@ -16,18 +16,12 @@ from .utils import base_join, hash_value
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def make_output_name(
|
def make_output_names(
|
||||||
server: ServerContext,
|
server: ServerContext,
|
||||||
mode: str,
|
job_name: str,
|
||||||
params: ImageParams,
|
count: int = 1,
|
||||||
size: Size,
|
|
||||||
extras: Optional[List[Optional[Param]]] = None,
|
|
||||||
count: Optional[int] = None,
|
|
||||||
offset: int = 0,
|
offset: int = 0,
|
||||||
) -> List[str]:
|
) -> List[str]:
|
||||||
count = count or params.batch
|
|
||||||
job_name = make_job_name(mode, params, size, extras)
|
|
||||||
|
|
||||||
return [
|
return [
|
||||||
f"{job_name}_{i}.{server.image_format}" for i in range(offset, count + offset)
|
f"{job_name}_{i}.{server.image_format}" for i in range(offset, count + offset)
|
||||||
]
|
]
|
||||||
|
@ -68,12 +62,14 @@ def save_result(
|
||||||
result: StageResult,
|
result: StageResult,
|
||||||
base_name: str,
|
base_name: str,
|
||||||
) -> List[str]:
|
) -> List[str]:
|
||||||
|
images = result.as_image()
|
||||||
|
outputs = make_output_names(server, base_name, len(images))
|
||||||
results = []
|
results = []
|
||||||
for i, (image, metadata) in enumerate(zip(result.as_image(), result.metadata)):
|
for image, metadata, filename in zip(images, result.metadata, outputs):
|
||||||
results.append(
|
results.append(
|
||||||
save_image(
|
save_image(
|
||||||
server,
|
server,
|
||||||
base_name + f"_{i}.{server.image_format}",
|
filename,
|
||||||
image,
|
image,
|
||||||
metadata,
|
metadata,
|
||||||
)
|
)
|
||||||
|
|
|
@ -18,7 +18,7 @@ from ..diffusers.run import (
|
||||||
run_upscale_pipeline,
|
run_upscale_pipeline,
|
||||||
)
|
)
|
||||||
from ..diffusers.utils import replace_wildcards
|
from ..diffusers.utils import replace_wildcards
|
||||||
from ..output import make_job_name
|
from ..output import make_job_name, make_output_names
|
||||||
from ..params import Progress, Size, StageParams, TileOrder
|
from ..params import Progress, Size, StageParams, TileOrder
|
||||||
from ..transformers.run import run_txt2txt_pipeline
|
from ..transformers.run import run_txt2txt_pipeline
|
||||||
from ..utils import (
|
from ..utils import (
|
||||||
|
@ -668,7 +668,10 @@ def job_status(server: ServerContext, pool: DevicePoolExecutor):
|
||||||
|
|
||||||
# TODO: accumulate results
|
# TODO: accumulate results
|
||||||
if progress is not None:
|
if progress is not None:
|
||||||
# TODO: add output paths based on progress.results counter
|
outputs = None
|
||||||
|
if progress.results > 0:
|
||||||
|
outputs = make_output_names(server, job_name, progress.results)
|
||||||
|
|
||||||
return image_reply(
|
return image_reply(
|
||||||
job_name,
|
job_name,
|
||||||
status,
|
status,
|
||||||
|
@ -676,6 +679,7 @@ def job_status(server: ServerContext, pool: DevicePoolExecutor):
|
||||||
stages=Progress(progress.stages, 0),
|
stages=Progress(progress.stages, 0),
|
||||||
steps=Progress(progress.steps, 0),
|
steps=Progress(progress.steps, 0),
|
||||||
tiles=Progress(progress.tiles, 0),
|
tiles=Progress(progress.tiles, 0),
|
||||||
|
outputs=outputs,
|
||||||
)
|
)
|
||||||
|
|
||||||
return image_reply(job_name, status, "TODO")
|
return image_reply(job_name, status, "TODO")
|
||||||
|
|
Loading…
Reference in New Issue