start adding queue progress
This commit is contained in:
parent
fd92d52339
commit
237accc973
|
@ -1,6 +1,13 @@
|
|||
from typing import Any, Optional
|
||||
from typing import Optional
|
||||
|
||||
from ..params import HighresParams, ImageParams, Size, SizeChart, StageParams, UpscaleParams
|
||||
from ..params import (
|
||||
HighresParams,
|
||||
ImageParams,
|
||||
Size,
|
||||
SizeChart,
|
||||
StageParams,
|
||||
UpscaleParams,
|
||||
)
|
||||
from ..server import ServerContext
|
||||
from ..worker import WorkerContext
|
||||
from .base import BaseStage
|
||||
|
@ -24,7 +31,7 @@ class EditMetadataStage(BaseStage):
|
|||
note: Optional[str] = None,
|
||||
replace_params: Optional[ImageParams] = None,
|
||||
**kwargs,
|
||||
) -> Any:
|
||||
) -> StageResult:
|
||||
# Modify the source image's metadata using the provided parameters
|
||||
for metadata in source.metadata:
|
||||
if note is not None:
|
||||
|
|
|
@ -0,0 +1,41 @@
|
|||
from typing import Tuple
|
||||
|
||||
from PIL import ImageDraw
|
||||
|
||||
from ..params import ImageParams, SizeChart, StageParams
|
||||
from ..server import ServerContext
|
||||
from ..worker import WorkerContext
|
||||
from .base import BaseStage
|
||||
from .result import StageResult
|
||||
|
||||
|
||||
class EditTextStage(BaseStage):
|
||||
max_tile = SizeChart.max
|
||||
|
||||
def run(
|
||||
self,
|
||||
_worker: WorkerContext,
|
||||
_server: ServerContext,
|
||||
_stage: StageParams,
|
||||
_params: ImageParams,
|
||||
source: StageResult,
|
||||
*,
|
||||
text: str,
|
||||
position: Tuple[int, int],
|
||||
fill: str = "white",
|
||||
stroke: str = "black",
|
||||
stroke_width: int = 1,
|
||||
**kwargs,
|
||||
) -> StageResult:
|
||||
# Add text to each image in source at the given position
|
||||
results = []
|
||||
|
||||
for image in source.as_images():
|
||||
image = image.copy()
|
||||
draw = ImageDraw.Draw(image)
|
||||
draw.text(
|
||||
position, text, fill=fill, stroke_width=stroke_width, stroke_fill=stroke
|
||||
)
|
||||
results.append(image)
|
||||
|
||||
return StageResult.from_images(results, source.metadata)
|
|
@ -23,7 +23,7 @@ PipelineStage = Tuple[BaseStage, StageParams, Optional[dict]]
|
|||
|
||||
class ChainProgress:
|
||||
parent: ProgressCallback
|
||||
step: int # same as steps.current, left for legacy purposes
|
||||
step: int # current number of steps
|
||||
prev: int # accumulator when step resets
|
||||
|
||||
# TODO: should probably be moved to worker context as well
|
||||
|
|
|
@ -93,15 +93,18 @@ def error_reply(err: str):
|
|||
return response
|
||||
|
||||
|
||||
def job_reply(name: str):
|
||||
EMPTY_PROGRESS = Progress(0, 0)
|
||||
|
||||
|
||||
def job_reply(name: str, queue: int = 0):
|
||||
return jsonify(
|
||||
{
|
||||
"name": name,
|
||||
"queue": Progress(0, 0).tojson(), # TODO: use real queue position
|
||||
"queue": Progress(queue, queue).tojson(),
|
||||
"status": JobStatus.PENDING,
|
||||
"stages": Progress(0, 0).tojson(),
|
||||
"steps": Progress(0, 0).tojson(),
|
||||
"tiles": Progress(0, 0).tojson(),
|
||||
"stages": EMPTY_PROGRESS.tojson(),
|
||||
"steps": EMPTY_PROGRESS.tojson(),
|
||||
"tiles": EMPTY_PROGRESS.tojson(),
|
||||
}
|
||||
)
|
||||
|
||||
|
@ -110,24 +113,29 @@ def image_reply(
|
|||
server: ServerContext,
|
||||
name: str,
|
||||
status: str,
|
||||
queue: Progress = None,
|
||||
stages: Progress = None,
|
||||
steps: Progress = None,
|
||||
tiles: Progress = None,
|
||||
outputs: List[str] = None,
|
||||
metadata: List[ImageMetadata] = None,
|
||||
) -> Dict[str, Any]:
|
||||
if queue is None:
|
||||
queue = EMPTY_PROGRESS
|
||||
|
||||
if stages is None:
|
||||
stages = Progress(0, 0)
|
||||
stages = EMPTY_PROGRESS
|
||||
|
||||
if steps is None:
|
||||
steps = Progress(0, 0)
|
||||
steps = EMPTY_PROGRESS
|
||||
|
||||
if tiles is None:
|
||||
tiles = Progress(0, 0)
|
||||
tiles = EMPTY_PROGRESS
|
||||
|
||||
data = {
|
||||
"name": name,
|
||||
"status": status,
|
||||
"queue": queue.tojson(),
|
||||
"stages": stages.tojson(),
|
||||
"steps": steps.tojson(),
|
||||
"tiles": tiles.tojson(),
|
||||
|
@ -263,7 +271,7 @@ def img2img(server: ServerContext, pool: DevicePoolExecutor):
|
|||
output_count += 1
|
||||
|
||||
job_name = make_job_name("img2img", params, size, extras=[strength])
|
||||
pool.submit(
|
||||
queue = pool.submit(
|
||||
job_name,
|
||||
JobType.IMG2IMG,
|
||||
run_img2img_pipeline,
|
||||
|
@ -279,7 +287,7 @@ def img2img(server: ServerContext, pool: DevicePoolExecutor):
|
|||
|
||||
logger.info("img2img job queued for: %s", job_name)
|
||||
|
||||
return job_reply(job_name)
|
||||
return job_reply(job_name, queue=queue)
|
||||
|
||||
|
||||
def txt2img(server: ServerContext, pool: DevicePoolExecutor):
|
||||
|
@ -291,7 +299,7 @@ def txt2img(server: ServerContext, pool: DevicePoolExecutor):
|
|||
|
||||
job_name = make_job_name("txt2img", params, size)
|
||||
|
||||
pool.submit(
|
||||
queue = pool.submit(
|
||||
job_name,
|
||||
JobType.TXT2IMG,
|
||||
run_txt2img_pipeline,
|
||||
|
@ -305,7 +313,7 @@ def txt2img(server: ServerContext, pool: DevicePoolExecutor):
|
|||
|
||||
logger.info("txt2img job queued for: %s", job_name)
|
||||
|
||||
return job_reply(job_name)
|
||||
return job_reply(job_name, queue=queue)
|
||||
|
||||
|
||||
def inpaint(server: ServerContext, pool: DevicePoolExecutor):
|
||||
|
@ -367,7 +375,7 @@ def inpaint(server: ServerContext, pool: DevicePoolExecutor):
|
|||
],
|
||||
)
|
||||
|
||||
pool.submit(
|
||||
queue = pool.submit(
|
||||
job_name,
|
||||
JobType.INPAINT,
|
||||
run_inpaint_pipeline,
|
||||
|
@ -390,7 +398,7 @@ def inpaint(server: ServerContext, pool: DevicePoolExecutor):
|
|||
|
||||
logger.info("inpaint job queued for: %s", job_name)
|
||||
|
||||
return job_reply(job_name)
|
||||
return job_reply(job_name, queue=queue)
|
||||
|
||||
|
||||
def upscale(server: ServerContext, pool: DevicePoolExecutor):
|
||||
|
@ -407,7 +415,7 @@ def upscale(server: ServerContext, pool: DevicePoolExecutor):
|
|||
replace_wildcards(params, get_wildcard_data())
|
||||
|
||||
job_name = make_job_name("upscale", params, size)
|
||||
pool.submit(
|
||||
queue = pool.submit(
|
||||
job_name,
|
||||
JobType.UPSCALE,
|
||||
run_upscale_pipeline,
|
||||
|
@ -422,7 +430,7 @@ def upscale(server: ServerContext, pool: DevicePoolExecutor):
|
|||
|
||||
logger.info("upscale job queued for: %s", job_name)
|
||||
|
||||
return job_reply(job_name)
|
||||
return job_reply(job_name, queue=queue)
|
||||
|
||||
|
||||
# keys that are specially parsed by params and should not show up in with_args
|
||||
|
@ -521,7 +529,7 @@ def chain(server: ServerContext, pool: DevicePoolExecutor):
|
|||
job_name = make_job_name("chain", base_params, base_size)
|
||||
|
||||
# build and run chain pipeline
|
||||
pool.submit(
|
||||
queue = pool.submit(
|
||||
job_name,
|
||||
JobType.CHAIN,
|
||||
pipeline,
|
||||
|
@ -532,7 +540,7 @@ def chain(server: ServerContext, pool: DevicePoolExecutor):
|
|||
needs_device=device,
|
||||
)
|
||||
|
||||
return job_reply(job_name)
|
||||
return job_reply(job_name, queue=queue)
|
||||
|
||||
|
||||
def blend(server: ServerContext, pool: DevicePoolExecutor):
|
||||
|
@ -557,7 +565,7 @@ def blend(server: ServerContext, pool: DevicePoolExecutor):
|
|||
upscale = build_upscale()
|
||||
|
||||
job_name = make_job_name("blend", params, size)
|
||||
pool.submit(
|
||||
queue = pool.submit(
|
||||
job_name,
|
||||
JobType.BLEND,
|
||||
run_blend_pipeline,
|
||||
|
@ -573,7 +581,7 @@ def blend(server: ServerContext, pool: DevicePoolExecutor):
|
|||
|
||||
logger.info("upscale job queued for: %s", job_name)
|
||||
|
||||
return job_reply(job_name)
|
||||
return job_reply(job_name, queue=queue)
|
||||
|
||||
|
||||
def txt2txt(server: ServerContext, pool: DevicePoolExecutor):
|
||||
|
@ -582,7 +590,7 @@ def txt2txt(server: ServerContext, pool: DevicePoolExecutor):
|
|||
job_name = make_job_name("txt2txt", params, size)
|
||||
logger.info("upscale job queued for: %s", job_name)
|
||||
|
||||
pool.submit(
|
||||
queue = pool.submit(
|
||||
job_name,
|
||||
JobType.TXT2TXT,
|
||||
run_txt2txt_pipeline,
|
||||
|
@ -592,7 +600,7 @@ def txt2txt(server: ServerContext, pool: DevicePoolExecutor):
|
|||
needs_device=device,
|
||||
)
|
||||
|
||||
return job_reply(job_name)
|
||||
return job_reply(job_name, queue=queue)
|
||||
|
||||
|
||||
def cancel(server: ServerContext, pool: DevicePoolExecutor):
|
||||
|
@ -612,7 +620,7 @@ def ready(server: ServerContext, pool: DevicePoolExecutor):
|
|||
return error_reply("output name is required")
|
||||
|
||||
output_file = sanitize_name(output_file)
|
||||
status, progress = pool.status(output_file)
|
||||
status, progress, _queue = pool.status(output_file)
|
||||
|
||||
if status == JobStatus.PENDING:
|
||||
return ready_reply(pending=True)
|
||||
|
@ -677,7 +685,7 @@ def job_status(server: ServerContext, pool: DevicePoolExecutor):
|
|||
|
||||
for job_name in job_list:
|
||||
job_name = sanitize_name(job_name)
|
||||
status, progress = pool.status(job_name)
|
||||
status, progress, queue = pool.status(job_name)
|
||||
|
||||
if progress is not None:
|
||||
outputs = None
|
||||
|
@ -700,7 +708,7 @@ def job_status(server: ServerContext, pool: DevicePoolExecutor):
|
|||
)
|
||||
)
|
||||
else:
|
||||
records.append(image_reply(server, job_name, status))
|
||||
records.append(image_reply(server, job_name, status, queue=queue))
|
||||
|
||||
return jsonify(records)
|
||||
|
||||
|
|
|
@ -36,9 +36,18 @@ class Progress:
|
|||
self.current = current
|
||||
self.total = total
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return "Progress(%d, %d)" % (self.current, self.total)
|
||||
|
||||
def __str__(self) -> str:
|
||||
return "%s/%s" % (self.current, self.total)
|
||||
|
||||
def __eq__(self, other: Any) -> bool:
|
||||
if isinstance(other, Progress):
|
||||
return self.current == other.current and self.total == other.total
|
||||
|
||||
return False
|
||||
|
||||
def tojson(self):
|
||||
return {
|
||||
"current": self.current,
|
||||
|
|
|
@ -8,7 +8,7 @@ from torch.multiprocessing import Process, Queue, Value
|
|||
|
||||
from ..params import DeviceParams
|
||||
from ..server import ServerContext
|
||||
from .command import JobCommand, JobStatus, ProgressCommand
|
||||
from .command import JobCommand, JobStatus, Progress, ProgressCommand
|
||||
from .context import WorkerContext
|
||||
from .utils import Interval
|
||||
from .worker import worker_main
|
||||
|
@ -228,31 +228,33 @@ class DevicePoolExecutor:
|
|||
self.cancelled_jobs.append(key)
|
||||
return True
|
||||
|
||||
def status(self, key: str) -> Tuple[JobStatus, Optional[ProgressCommand]]:
|
||||
def status(
|
||||
self, key: str
|
||||
) -> Tuple[JobStatus, Optional[ProgressCommand], Optional[Progress]]:
|
||||
"""
|
||||
Check if a job has been finished and report the last progress update.
|
||||
"""
|
||||
|
||||
if key in self.cancelled_jobs:
|
||||
logger.debug("checking status for cancelled job: %s", key)
|
||||
return (JobStatus.CANCELLED, None)
|
||||
return (JobStatus.CANCELLED, None, None)
|
||||
|
||||
if key in self.running_jobs:
|
||||
logger.debug("checking status for running job: %s", key)
|
||||
return (JobStatus.RUNNING, self.running_jobs[key])
|
||||
return (JobStatus.RUNNING, self.running_jobs[key], None)
|
||||
|
||||
for job in self.finished_jobs:
|
||||
if job.job == key:
|
||||
logger.debug("checking status for finished job: %s", key)
|
||||
return (job.status, job)
|
||||
return (job.status, job, None)
|
||||
|
||||
for job in self.pending_jobs:
|
||||
for i, job in enumerate(self.pending_jobs):
|
||||
if job.name == key:
|
||||
logger.debug("checking status for pending job: %s", key)
|
||||
return (JobStatus.PENDING, None)
|
||||
return (JobStatus.PENDING, None, Progress(i, len(self.pending_jobs)))
|
||||
|
||||
logger.trace("checking status for unknown job: %s", key)
|
||||
return (JobStatus.UNKNOWN, None)
|
||||
return (JobStatus.UNKNOWN, None, None)
|
||||
|
||||
def join(self):
|
||||
logger.info("stopping worker pool")
|
||||
|
@ -399,7 +401,7 @@ class DevicePoolExecutor:
|
|||
*args,
|
||||
needs_device: Optional[DeviceParams] = None,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
) -> int:
|
||||
device_idx = self.get_next_device(needs_device=needs_device)
|
||||
device = self.devices[device_idx].device
|
||||
logger.info(
|
||||
|
@ -413,6 +415,9 @@ class DevicePoolExecutor:
|
|||
job = JobCommand(key, device, job_type, fn, args, kwargs)
|
||||
self.pending_jobs.append(job)
|
||||
|
||||
# return position in queue
|
||||
return len(self.pending_jobs)
|
||||
|
||||
def summary(self) -> Dict[str, List[Tuple[str, int, JobStatus]]]:
|
||||
"""
|
||||
Returns a tuple of: job/device, progress, progress, finished, cancelled, failed
|
||||
|
|
|
@ -0,0 +1,41 @@
|
|||
import unittest
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from onnx_web.chain.edit_metadata import EditMetadataStage
|
||||
|
||||
|
||||
class TestEditMetadataStage(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.stage = EditMetadataStage()
|
||||
|
||||
def test_run_with_no_changes(self):
|
||||
source = MagicMock()
|
||||
source.metadata = []
|
||||
|
||||
result = self.stage.run(None, None, None, None, source)
|
||||
|
||||
self.assertEqual(result, source)
|
||||
|
||||
def test_run_with_note_change(self):
|
||||
source = MagicMock()
|
||||
source.metadata = [MagicMock()]
|
||||
note = "New note"
|
||||
|
||||
result = self.stage.run(None, None, None, None, source, note=note)
|
||||
|
||||
self.assertEqual(result, source)
|
||||
self.assertEqual(result.metadata[0].note, note)
|
||||
|
||||
def test_run_with_replace_params_change(self):
|
||||
source = MagicMock()
|
||||
source.metadata = [MagicMock()]
|
||||
replace_params = MagicMock()
|
||||
|
||||
result = self.stage.run(
|
||||
None, None, None, None, source, replace_params=replace_params
|
||||
)
|
||||
|
||||
self.assertEqual(result, source)
|
||||
self.assertEqual(result.metadata[0].params, replace_params)
|
||||
|
||||
# Add more test cases for other parameters...
|
|
@ -0,0 +1,48 @@
|
|||
import unittest
|
||||
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
from onnx_web.chain.edit_text import EditTextStage
|
||||
from onnx_web.chain.result import StageResult
|
||||
|
||||
|
||||
class TestEditTextStage(unittest.TestCase):
|
||||
def test_run(self):
|
||||
# Create a sample image
|
||||
image = Image.new("RGB", (100, 100), color="black")
|
||||
|
||||
# Create an instance of EditTextStage
|
||||
stage = EditTextStage()
|
||||
|
||||
# Define the input parameters
|
||||
text = "Hello, World!"
|
||||
position = (10, 10)
|
||||
fill = "white"
|
||||
stroke = "white"
|
||||
stroke_width = 2
|
||||
|
||||
# Create a mock source StageResult
|
||||
source = StageResult.from_images([image], metadata={})
|
||||
|
||||
# Call the run method
|
||||
result = stage.run(
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
source,
|
||||
text=text,
|
||||
position=position,
|
||||
fill=fill,
|
||||
stroke=stroke,
|
||||
stroke_width=stroke_width,
|
||||
)
|
||||
|
||||
# Assert the output
|
||||
self.assertEqual(len(result.as_images()), 1)
|
||||
# self.assertEqual(result.metadata, {})
|
||||
|
||||
# Verify the modified image
|
||||
modified_image = result.as_images()[0]
|
||||
self.assertEqual(np.max(np.array(modified_image)), 255)
|
|
@ -5,7 +5,7 @@ from typing import Optional
|
|||
|
||||
from onnx_web.params import DeviceParams
|
||||
from onnx_web.server.context import ServerContext
|
||||
from onnx_web.worker.command import JobStatus
|
||||
from onnx_web.worker.command import JobStatus, Progress
|
||||
from onnx_web.worker.pool import DevicePoolExecutor
|
||||
from tests.helpers import test_device
|
||||
|
||||
|
@ -61,10 +61,12 @@ class TestWorkerPool(unittest.TestCase):
|
|||
self.pool.start()
|
||||
|
||||
self.pool.submit("test", "test", sleep_job, lock=lock)
|
||||
self.assertEqual(self.pool.status("test"), (JobStatus.PENDING, None))
|
||||
self.assertEqual(
|
||||
self.pool.status("test"), (JobStatus.PENDING, None, Progress(0, 1))
|
||||
)
|
||||
|
||||
self.assertTrue(self.pool.cancel("test"))
|
||||
self.assertEqual(self.pool.status("test"), (JobStatus.CANCELLED, None))
|
||||
self.assertEqual(self.pool.status("test"), (JobStatus.CANCELLED, None, None))
|
||||
|
||||
def test_cancel_running(self):
|
||||
pass
|
||||
|
@ -104,7 +106,7 @@ class TestWorkerPool(unittest.TestCase):
|
|||
self.pool.submit("test", "test", lock_job)
|
||||
sleep(5.0)
|
||||
|
||||
status, _progress = self.pool.status("test")
|
||||
status, _progress, _status = self.pool.status("test")
|
||||
self.assertEqual(status, JobStatus.RUNNING)
|
||||
|
||||
def test_done_pending(self):
|
||||
|
@ -116,7 +118,9 @@ class TestWorkerPool(unittest.TestCase):
|
|||
|
||||
self.pool.submit("test1", "test", lock_job)
|
||||
self.pool.submit("test2", "test", lock_job)
|
||||
self.assertEqual(self.pool.status("test2"), (JobStatus.PENDING, None))
|
||||
self.assertEqual(
|
||||
self.pool.status("test2"), (JobStatus.PENDING, None, Progress(1, 2))
|
||||
)
|
||||
|
||||
lock.set()
|
||||
|
||||
|
@ -132,10 +136,12 @@ class TestWorkerPool(unittest.TestCase):
|
|||
)
|
||||
self.pool.start()
|
||||
self.pool.submit("test", "test", sleep_job)
|
||||
self.assertEqual(self.pool.status("test"), (JobStatus.PENDING, None))
|
||||
self.assertEqual(
|
||||
self.pool.status("test"), (JobStatus.PENDING, None, Progress(0, 1))
|
||||
)
|
||||
|
||||
sleep(5.0)
|
||||
status, _progress = self.pool.status("test")
|
||||
status, _progress, _queue = self.pool.status("test")
|
||||
self.assertEqual(status, JobStatus.SUCCESS)
|
||||
|
||||
def test_recycle_live(self):
|
||||
|
@ -162,7 +168,7 @@ class TestWorkerPool(unittest.TestCase):
|
|||
self.pool.submit("test", "test", progress_job)
|
||||
sleep(5.0)
|
||||
|
||||
status, progress = self.pool.status("test")
|
||||
status, progress, _queue = self.pool.status("test")
|
||||
self.assertEqual(status, JobStatus.SUCCESS)
|
||||
self.assertEqual(progress.steps.current, 1)
|
||||
|
||||
|
@ -178,6 +184,6 @@ class TestWorkerPool(unittest.TestCase):
|
|||
self.pool.submit("test", "test", fail_job)
|
||||
sleep(5.0)
|
||||
|
||||
status, progress = self.pool.status("test")
|
||||
status, progress, _queue = self.pool.status("test")
|
||||
self.assertEqual(status, JobStatus.FAILED)
|
||||
self.assertEqual(progress.steps.current, 0)
|
||||
|
|
Loading…
Reference in New Issue