From 237accc973b7e13740f560db53c22baa27128127 Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Sun, 7 Jan 2024 08:16:13 -0600 Subject: [PATCH] start adding queue progress --- api/onnx_web/chain/edit_metadata.py | 13 ++++-- api/onnx_web/chain/edit_text.py | 41 +++++++++++++++++++ api/onnx_web/chain/pipeline.py | 2 +- api/onnx_web/server/api.py | 58 +++++++++++++++------------ api/onnx_web/worker/command.py | 9 +++++ api/onnx_web/worker/pool.py | 23 ++++++----- api/tests/chain/test_edit_metadata.py | 41 +++++++++++++++++++ api/tests/chain/test_edit_text.py | 48 ++++++++++++++++++++++ api/tests/worker/test_pool.py | 24 ++++++----- 9 files changed, 212 insertions(+), 47 deletions(-) create mode 100644 api/tests/chain/test_edit_metadata.py create mode 100644 api/tests/chain/test_edit_text.py diff --git a/api/onnx_web/chain/edit_metadata.py b/api/onnx_web/chain/edit_metadata.py index 4ddebe17..0ca4d1c5 100644 --- a/api/onnx_web/chain/edit_metadata.py +++ b/api/onnx_web/chain/edit_metadata.py @@ -1,6 +1,13 @@ -from typing import Any, Optional +from typing import Optional -from ..params import HighresParams, ImageParams, Size, SizeChart, StageParams, UpscaleParams +from ..params import ( + HighresParams, + ImageParams, + Size, + SizeChart, + StageParams, + UpscaleParams, +) from ..server import ServerContext from ..worker import WorkerContext from .base import BaseStage @@ -24,7 +31,7 @@ class EditMetadataStage(BaseStage): note: Optional[str] = None, replace_params: Optional[ImageParams] = None, **kwargs, - ) -> Any: + ) -> StageResult: # Modify the source image's metadata using the provided parameters for metadata in source.metadata: if note is not None: diff --git a/api/onnx_web/chain/edit_text.py b/api/onnx_web/chain/edit_text.py index e69de29b..ddbe8cd9 100644 --- a/api/onnx_web/chain/edit_text.py +++ b/api/onnx_web/chain/edit_text.py @@ -0,0 +1,41 @@ +from typing import Tuple + +from PIL import ImageDraw + +from ..params import ImageParams, SizeChart, StageParams +from 
..server import ServerContext +from ..worker import WorkerContext +from .base import BaseStage +from .result import StageResult + + +class EditTextStage(BaseStage): + max_tile = SizeChart.max + + def run( + self, + _worker: WorkerContext, + _server: ServerContext, + _stage: StageParams, + _params: ImageParams, + source: StageResult, + *, + text: str, + position: Tuple[int, int], + fill: str = "white", + stroke: str = "black", + stroke_width: int = 1, + **kwargs, + ) -> StageResult: + # Add text to each image in source at the given position + results = [] + + for image in source.as_images(): + image = image.copy() + draw = ImageDraw.Draw(image) + draw.text( + position, text, fill=fill, stroke_width=stroke_width, stroke_fill=stroke + ) + results.append(image) + + return StageResult.from_images(results, source.metadata) diff --git a/api/onnx_web/chain/pipeline.py b/api/onnx_web/chain/pipeline.py index ac1243d2..e78d12bd 100644 --- a/api/onnx_web/chain/pipeline.py +++ b/api/onnx_web/chain/pipeline.py @@ -23,7 +23,7 @@ PipelineStage = Tuple[BaseStage, StageParams, Optional[dict]] class ChainProgress: parent: ProgressCallback - step: int # same as steps.current, left for legacy purposes + step: int # current number of steps prev: int # accumulator when step resets # TODO: should probably be moved to worker context as well diff --git a/api/onnx_web/server/api.py b/api/onnx_web/server/api.py index 40fb6c10..83820e87 100644 --- a/api/onnx_web/server/api.py +++ b/api/onnx_web/server/api.py @@ -93,15 +93,18 @@ def error_reply(err: str): return response -def job_reply(name: str): +EMPTY_PROGRESS = Progress(0, 0) + + +def job_reply(name: str, queue: int = 0): return jsonify( { "name": name, - "queue": Progress(0, 0).tojson(), # TODO: use real queue position + "queue": Progress(queue, queue).tojson(), "status": JobStatus.PENDING, - "stages": Progress(0, 0).tojson(), - "steps": Progress(0, 0).tojson(), - "tiles": Progress(0, 0).tojson(), + "stages": EMPTY_PROGRESS.tojson(), + 
"steps": EMPTY_PROGRESS.tojson(), + "tiles": EMPTY_PROGRESS.tojson(), } ) @@ -110,24 +113,29 @@ def image_reply( server: ServerContext, name: str, status: str, + queue: Progress = None, stages: Progress = None, steps: Progress = None, tiles: Progress = None, outputs: List[str] = None, metadata: List[ImageMetadata] = None, ) -> Dict[str, Any]: + if queue is None: + queue = EMPTY_PROGRESS + if stages is None: - stages = Progress(0, 0) + stages = EMPTY_PROGRESS if steps is None: - steps = Progress(0, 0) + steps = EMPTY_PROGRESS if tiles is None: - tiles = Progress(0, 0) + tiles = EMPTY_PROGRESS data = { "name": name, "status": status, + "queue": queue.tojson(), "stages": stages.tojson(), "steps": steps.tojson(), "tiles": tiles.tojson(), @@ -263,7 +271,7 @@ def img2img(server: ServerContext, pool: DevicePoolExecutor): output_count += 1 job_name = make_job_name("img2img", params, size, extras=[strength]) - pool.submit( + queue = pool.submit( job_name, JobType.IMG2IMG, run_img2img_pipeline, @@ -279,7 +287,7 @@ def img2img(server: ServerContext, pool: DevicePoolExecutor): logger.info("img2img job queued for: %s", job_name) - return job_reply(job_name) + return job_reply(job_name, queue=queue) def txt2img(server: ServerContext, pool: DevicePoolExecutor): @@ -291,7 +299,7 @@ def txt2img(server: ServerContext, pool: DevicePoolExecutor): job_name = make_job_name("txt2img", params, size) - pool.submit( + queue = pool.submit( job_name, JobType.TXT2IMG, run_txt2img_pipeline, @@ -305,7 +313,7 @@ def txt2img(server: ServerContext, pool: DevicePoolExecutor): logger.info("txt2img job queued for: %s", job_name) - return job_reply(job_name) + return job_reply(job_name, queue=queue) def inpaint(server: ServerContext, pool: DevicePoolExecutor): @@ -367,7 +375,7 @@ def inpaint(server: ServerContext, pool: DevicePoolExecutor): ], ) - pool.submit( + queue = pool.submit( job_name, JobType.INPAINT, run_inpaint_pipeline, @@ -390,7 +398,7 @@ def inpaint(server: ServerContext, pool: 
DevicePoolExecutor): logger.info("inpaint job queued for: %s", job_name) - return job_reply(job_name) + return job_reply(job_name, queue=queue) def upscale(server: ServerContext, pool: DevicePoolExecutor): @@ -407,7 +415,7 @@ def upscale(server: ServerContext, pool: DevicePoolExecutor): replace_wildcards(params, get_wildcard_data()) job_name = make_job_name("upscale", params, size) - pool.submit( + queue = pool.submit( job_name, JobType.UPSCALE, run_upscale_pipeline, @@ -422,7 +430,7 @@ def upscale(server: ServerContext, pool: DevicePoolExecutor): logger.info("upscale job queued for: %s", job_name) - return job_reply(job_name) + return job_reply(job_name, queue=queue) # keys that are specially parsed by params and should not show up in with_args @@ -521,7 +529,7 @@ def chain(server: ServerContext, pool: DevicePoolExecutor): job_name = make_job_name("chain", base_params, base_size) # build and run chain pipeline - pool.submit( + queue = pool.submit( job_name, JobType.CHAIN, pipeline, @@ -532,7 +540,7 @@ def chain(server: ServerContext, pool: DevicePoolExecutor): needs_device=device, ) - return job_reply(job_name) + return job_reply(job_name, queue=queue) def blend(server: ServerContext, pool: DevicePoolExecutor): @@ -557,7 +565,7 @@ def blend(server: ServerContext, pool: DevicePoolExecutor): upscale = build_upscale() job_name = make_job_name("blend", params, size) - pool.submit( + queue = pool.submit( job_name, JobType.BLEND, run_blend_pipeline, @@ -573,7 +581,7 @@ def blend(server: ServerContext, pool: DevicePoolExecutor): logger.info("upscale job queued for: %s", job_name) - return job_reply(job_name) + return job_reply(job_name, queue=queue) def txt2txt(server: ServerContext, pool: DevicePoolExecutor): @@ -582,7 +590,7 @@ def txt2txt(server: ServerContext, pool: DevicePoolExecutor): job_name = make_job_name("txt2txt", params, size) logger.info("upscale job queued for: %s", job_name) - pool.submit( + queue = pool.submit( job_name, JobType.TXT2TXT, 
run_txt2txt_pipeline, @@ -592,7 +600,7 @@ def txt2txt(server: ServerContext, pool: DevicePoolExecutor): needs_device=device, ) - return job_reply(job_name) + return job_reply(job_name, queue=queue) def cancel(server: ServerContext, pool: DevicePoolExecutor): @@ -612,7 +620,7 @@ def ready(server: ServerContext, pool: DevicePoolExecutor): return error_reply("output name is required") output_file = sanitize_name(output_file) - status, progress = pool.status(output_file) + status, progress, _queue = pool.status(output_file) if status == JobStatus.PENDING: return ready_reply(pending=True) @@ -677,7 +685,7 @@ def job_status(server: ServerContext, pool: DevicePoolExecutor): for job_name in job_list: job_name = sanitize_name(job_name) - status, progress = pool.status(job_name) + status, progress, queue = pool.status(job_name) if progress is not None: outputs = None @@ -700,7 +708,7 @@ def job_status(server: ServerContext, pool: DevicePoolExecutor): ) ) else: - records.append(image_reply(server, job_name, status)) + records.append(image_reply(server, job_name, status, queue=queue)) return jsonify(records) diff --git a/api/onnx_web/worker/command.py b/api/onnx_web/worker/command.py index 8e303652..698bad22 100644 --- a/api/onnx_web/worker/command.py +++ b/api/onnx_web/worker/command.py @@ -36,9 +36,18 @@ class Progress: self.current = current self.total = total + def __repr__(self) -> str: + return "Progress(%d, %d)" % (self.current, self.total) + def __str__(self) -> str: return "%s/%s" % (self.current, self.total) + def __eq__(self, other: Any) -> bool: + if isinstance(other, Progress): + return self.current == other.current and self.total == other.total + + return False + def tojson(self): return { "current": self.current, diff --git a/api/onnx_web/worker/pool.py b/api/onnx_web/worker/pool.py index a1dd0da0..90ecb503 100644 --- a/api/onnx_web/worker/pool.py +++ b/api/onnx_web/worker/pool.py @@ -8,7 +8,7 @@ from torch.multiprocessing import Process, Queue, Value from 
..params import DeviceParams from ..server import ServerContext -from .command import JobCommand, JobStatus, ProgressCommand +from .command import JobCommand, JobStatus, Progress, ProgressCommand from .context import WorkerContext from .utils import Interval from .worker import worker_main @@ -228,31 +228,33 @@ class DevicePoolExecutor: self.cancelled_jobs.append(key) return True - def status(self, key: str) -> Tuple[JobStatus, Optional[ProgressCommand]]: + def status( + self, key: str + ) -> Tuple[JobStatus, Optional[ProgressCommand], Optional[Progress]]: """ Check if a job has been finished and report the last progress update. """ if key in self.cancelled_jobs: logger.debug("checking status for cancelled job: %s", key) - return (JobStatus.CANCELLED, None) + return (JobStatus.CANCELLED, None, None) if key in self.running_jobs: logger.debug("checking status for running job: %s", key) - return (JobStatus.RUNNING, self.running_jobs[key]) + return (JobStatus.RUNNING, self.running_jobs[key], None) for job in self.finished_jobs: if job.job == key: logger.debug("checking status for finished job: %s", key) - return (job.status, job) + return (job.status, job, None) - for job in self.pending_jobs: + for i, job in enumerate(self.pending_jobs): if job.name == key: logger.debug("checking status for pending job: %s", key) - return (JobStatus.PENDING, None) + return (JobStatus.PENDING, None, Progress(i, len(self.pending_jobs))) logger.trace("checking status for unknown job: %s", key) - return (JobStatus.UNKNOWN, None) + return (JobStatus.UNKNOWN, None, None) def join(self): logger.info("stopping worker pool") @@ -399,7 +401,7 @@ class DevicePoolExecutor: *args, needs_device: Optional[DeviceParams] = None, **kwargs, - ) -> None: + ) -> int: device_idx = self.get_next_device(needs_device=needs_device) device = self.devices[device_idx].device logger.info( @@ -413,6 +415,9 @@ class DevicePoolExecutor: job = JobCommand(key, device, job_type, fn, args, kwargs) 
self.pending_jobs.append(job) + # return position in queue + return len(self.pending_jobs) + def summary(self) -> Dict[str, List[Tuple[str, int, JobStatus]]]: """ Returns a tuple of: job/device, progress, progress, finished, cancelled, failed diff --git a/api/tests/chain/test_edit_metadata.py b/api/tests/chain/test_edit_metadata.py new file mode 100644 index 00000000..fda0bf01 --- /dev/null +++ b/api/tests/chain/test_edit_metadata.py @@ -0,0 +1,41 @@ +import unittest +from unittest.mock import MagicMock + +from onnx_web.chain.edit_metadata import EditMetadataStage + + +class TestEditMetadataStage(unittest.TestCase): + def setUp(self): + self.stage = EditMetadataStage() + + def test_run_with_no_changes(self): + source = MagicMock() + source.metadata = [] + + result = self.stage.run(None, None, None, None, source) + + self.assertEqual(result, source) + + def test_run_with_note_change(self): + source = MagicMock() + source.metadata = [MagicMock()] + note = "New note" + + result = self.stage.run(None, None, None, None, source, note=note) + + self.assertEqual(result, source) + self.assertEqual(result.metadata[0].note, note) + + def test_run_with_replace_params_change(self): + source = MagicMock() + source.metadata = [MagicMock()] + replace_params = MagicMock() + + result = self.stage.run( + None, None, None, None, source, replace_params=replace_params + ) + + self.assertEqual(result, source) + self.assertEqual(result.metadata[0].params, replace_params) + + # Add more test cases for other parameters... 
diff --git a/api/tests/chain/test_edit_text.py b/api/tests/chain/test_edit_text.py new file mode 100644 index 00000000..4d0b80ce --- /dev/null +++ b/api/tests/chain/test_edit_text.py @@ -0,0 +1,48 @@ +import unittest + +import numpy as np +from PIL import Image + +from onnx_web.chain.edit_text import EditTextStage +from onnx_web.chain.result import StageResult + + +class TestEditTextStage(unittest.TestCase): + def test_run(self): + # Create a sample image + image = Image.new("RGB", (100, 100), color="black") + + # Create an instance of EditTextStage + stage = EditTextStage() + + # Define the input parameters + text = "Hello, World!" + position = (10, 10) + fill = "white" + stroke = "white" + stroke_width = 2 + + # Create a mock source StageResult + source = StageResult.from_images([image], metadata={}) + + # Call the run method + result = stage.run( + None, + None, + None, + None, + source, + text=text, + position=position, + fill=fill, + stroke=stroke, + stroke_width=stroke_width, + ) + + # Assert the output + self.assertEqual(len(result.as_images()), 1) + # self.assertEqual(result.metadata, {}) + + # Verify the modified image + modified_image = result.as_images()[0] + self.assertEqual(np.max(np.array(modified_image)), 255) diff --git a/api/tests/worker/test_pool.py b/api/tests/worker/test_pool.py index 1c8a2b83..2df53576 100644 --- a/api/tests/worker/test_pool.py +++ b/api/tests/worker/test_pool.py @@ -5,7 +5,7 @@ from typing import Optional from onnx_web.params import DeviceParams from onnx_web.server.context import ServerContext -from onnx_web.worker.command import JobStatus +from onnx_web.worker.command import JobStatus, Progress from onnx_web.worker.pool import DevicePoolExecutor from tests.helpers import test_device @@ -61,10 +61,12 @@ class TestWorkerPool(unittest.TestCase): self.pool.start() self.pool.submit("test", "test", sleep_job, lock=lock) - self.assertEqual(self.pool.status("test"), (JobStatus.PENDING, None)) + self.assertEqual( + 
self.pool.status("test"), (JobStatus.PENDING, None, Progress(0, 1)) + ) self.assertTrue(self.pool.cancel("test")) - self.assertEqual(self.pool.status("test"), (JobStatus.CANCELLED, None)) + self.assertEqual(self.pool.status("test"), (JobStatus.CANCELLED, None, None)) def test_cancel_running(self): pass @@ -104,7 +106,7 @@ class TestWorkerPool(unittest.TestCase): self.pool.submit("test", "test", lock_job) sleep(5.0) - status, _progress = self.pool.status("test") + status, _progress, _queue = self.pool.status("test") self.assertEqual(status, JobStatus.RUNNING) def test_done_pending(self): @@ -116,7 +118,9 @@ class TestWorkerPool(unittest.TestCase): self.pool.submit("test1", "test", lock_job) self.pool.submit("test2", "test", lock_job) - self.assertEqual(self.pool.status("test2"), (JobStatus.PENDING, None)) + self.assertEqual( + self.pool.status("test2"), (JobStatus.PENDING, None, Progress(1, 2)) + ) lock.set() @@ -132,10 +136,12 @@ class TestWorkerPool(unittest.TestCase): ) self.pool.start() self.pool.submit("test", "test", sleep_job) - self.assertEqual(self.pool.status("test"), (JobStatus.PENDING, None)) + self.assertEqual( + self.pool.status("test"), (JobStatus.PENDING, None, Progress(0, 1)) + ) sleep(5.0) - status, _progress = self.pool.status("test") + status, _progress, _queue = self.pool.status("test") self.assertEqual(status, JobStatus.SUCCESS) def test_recycle_live(self): @@ -162,7 +168,7 @@ class TestWorkerPool(unittest.TestCase): self.pool.submit("test", "test", progress_job) sleep(5.0) - status, progress = self.pool.status("test") + status, progress, _queue = self.pool.status("test") self.assertEqual(status, JobStatus.SUCCESS) self.assertEqual(progress.steps.current, 1) @@ -178,6 +184,6 @@ class TestWorkerPool(unittest.TestCase): self.pool.submit("test", "test", fail_job) sleep(5.0) - status, progress = self.pool.status("test") + status, progress, _queue = self.pool.status("test") self.assertEqual(status, JobStatus.FAILED)
self.assertEqual(progress.steps.current, 0)