1
0
Fork 0

start adding queue progress

This commit is contained in:
Sean Sube 2024-01-07 08:16:13 -06:00
parent fd92d52339
commit 237accc973
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
9 changed files with 212 additions and 47 deletions

View File

@ -1,6 +1,13 @@
from typing import Any, Optional from typing import Optional
from ..params import HighresParams, ImageParams, Size, SizeChart, StageParams, UpscaleParams from ..params import (
HighresParams,
ImageParams,
Size,
SizeChart,
StageParams,
UpscaleParams,
)
from ..server import ServerContext from ..server import ServerContext
from ..worker import WorkerContext from ..worker import WorkerContext
from .base import BaseStage from .base import BaseStage
@ -24,7 +31,7 @@ class EditMetadataStage(BaseStage):
note: Optional[str] = None, note: Optional[str] = None,
replace_params: Optional[ImageParams] = None, replace_params: Optional[ImageParams] = None,
**kwargs, **kwargs,
) -> Any: ) -> StageResult:
# Modify the source image's metadata using the provided parameters # Modify the source image's metadata using the provided parameters
for metadata in source.metadata: for metadata in source.metadata:
if note is not None: if note is not None:

View File

@ -0,0 +1,41 @@
from typing import Tuple
from PIL import ImageDraw
from ..params import ImageParams, SizeChart, StageParams
from ..server import ServerContext
from ..worker import WorkerContext
from .base import BaseStage
from .result import StageResult
class EditTextStage(BaseStage):
max_tile = SizeChart.max
def run(
self,
_worker: WorkerContext,
_server: ServerContext,
_stage: StageParams,
_params: ImageParams,
source: StageResult,
*,
text: str,
position: Tuple[int, int],
fill: str = "white",
stroke: str = "black",
stroke_width: int = 1,
**kwargs,
) -> StageResult:
# Add text to each image in source at the given position
results = []
for image in source.as_images():
image = image.copy()
draw = ImageDraw.Draw(image)
draw.text(
position, text, fill=fill, stroke_width=stroke_width, stroke_fill=stroke
)
results.append(image)
return StageResult.from_images(results, source.metadata)

View File

@ -23,7 +23,7 @@ PipelineStage = Tuple[BaseStage, StageParams, Optional[dict]]
class ChainProgress: class ChainProgress:
parent: ProgressCallback parent: ProgressCallback
step: int # same as steps.current, left for legacy purposes step: int # current number of steps
prev: int # accumulator when step resets prev: int # accumulator when step resets
# TODO: should probably be moved to worker context as well # TODO: should probably be moved to worker context as well

View File

@ -93,15 +93,18 @@ def error_reply(err: str):
return response return response
def job_reply(name: str): EMPTY_PROGRESS = Progress(0, 0)
def job_reply(name: str, queue: int = 0):
return jsonify( return jsonify(
{ {
"name": name, "name": name,
"queue": Progress(0, 0).tojson(), # TODO: use real queue position "queue": Progress(queue, queue).tojson(),
"status": JobStatus.PENDING, "status": JobStatus.PENDING,
"stages": Progress(0, 0).tojson(), "stages": EMPTY_PROGRESS.tojson(),
"steps": Progress(0, 0).tojson(), "steps": EMPTY_PROGRESS.tojson(),
"tiles": Progress(0, 0).tojson(), "tiles": EMPTY_PROGRESS.tojson(),
} }
) )
@ -110,24 +113,29 @@ def image_reply(
server: ServerContext, server: ServerContext,
name: str, name: str,
status: str, status: str,
queue: Progress = None,
stages: Progress = None, stages: Progress = None,
steps: Progress = None, steps: Progress = None,
tiles: Progress = None, tiles: Progress = None,
outputs: List[str] = None, outputs: List[str] = None,
metadata: List[ImageMetadata] = None, metadata: List[ImageMetadata] = None,
) -> Dict[str, Any]: ) -> Dict[str, Any]:
if queue is None:
queue = EMPTY_PROGRESS
if stages is None: if stages is None:
stages = Progress(0, 0) stages = EMPTY_PROGRESS
if steps is None: if steps is None:
steps = Progress(0, 0) steps = EMPTY_PROGRESS
if tiles is None: if tiles is None:
tiles = Progress(0, 0) tiles = EMPTY_PROGRESS
data = { data = {
"name": name, "name": name,
"status": status, "status": status,
"queue": queue.tojson(),
"stages": stages.tojson(), "stages": stages.tojson(),
"steps": steps.tojson(), "steps": steps.tojson(),
"tiles": tiles.tojson(), "tiles": tiles.tojson(),
@ -263,7 +271,7 @@ def img2img(server: ServerContext, pool: DevicePoolExecutor):
output_count += 1 output_count += 1
job_name = make_job_name("img2img", params, size, extras=[strength]) job_name = make_job_name("img2img", params, size, extras=[strength])
pool.submit( queue = pool.submit(
job_name, job_name,
JobType.IMG2IMG, JobType.IMG2IMG,
run_img2img_pipeline, run_img2img_pipeline,
@ -279,7 +287,7 @@ def img2img(server: ServerContext, pool: DevicePoolExecutor):
logger.info("img2img job queued for: %s", job_name) logger.info("img2img job queued for: %s", job_name)
return job_reply(job_name) return job_reply(job_name, queue=queue)
def txt2img(server: ServerContext, pool: DevicePoolExecutor): def txt2img(server: ServerContext, pool: DevicePoolExecutor):
@ -291,7 +299,7 @@ def txt2img(server: ServerContext, pool: DevicePoolExecutor):
job_name = make_job_name("txt2img", params, size) job_name = make_job_name("txt2img", params, size)
pool.submit( queue = pool.submit(
job_name, job_name,
JobType.TXT2IMG, JobType.TXT2IMG,
run_txt2img_pipeline, run_txt2img_pipeline,
@ -305,7 +313,7 @@ def txt2img(server: ServerContext, pool: DevicePoolExecutor):
logger.info("txt2img job queued for: %s", job_name) logger.info("txt2img job queued for: %s", job_name)
return job_reply(job_name) return job_reply(job_name, queue=queue)
def inpaint(server: ServerContext, pool: DevicePoolExecutor): def inpaint(server: ServerContext, pool: DevicePoolExecutor):
@ -367,7 +375,7 @@ def inpaint(server: ServerContext, pool: DevicePoolExecutor):
], ],
) )
pool.submit( queue = pool.submit(
job_name, job_name,
JobType.INPAINT, JobType.INPAINT,
run_inpaint_pipeline, run_inpaint_pipeline,
@ -390,7 +398,7 @@ def inpaint(server: ServerContext, pool: DevicePoolExecutor):
logger.info("inpaint job queued for: %s", job_name) logger.info("inpaint job queued for: %s", job_name)
return job_reply(job_name) return job_reply(job_name, queue=queue)
def upscale(server: ServerContext, pool: DevicePoolExecutor): def upscale(server: ServerContext, pool: DevicePoolExecutor):
@ -407,7 +415,7 @@ def upscale(server: ServerContext, pool: DevicePoolExecutor):
replace_wildcards(params, get_wildcard_data()) replace_wildcards(params, get_wildcard_data())
job_name = make_job_name("upscale", params, size) job_name = make_job_name("upscale", params, size)
pool.submit( queue = pool.submit(
job_name, job_name,
JobType.UPSCALE, JobType.UPSCALE,
run_upscale_pipeline, run_upscale_pipeline,
@ -422,7 +430,7 @@ def upscale(server: ServerContext, pool: DevicePoolExecutor):
logger.info("upscale job queued for: %s", job_name) logger.info("upscale job queued for: %s", job_name)
return job_reply(job_name) return job_reply(job_name, queue=queue)
# keys that are specially parsed by params and should not show up in with_args # keys that are specially parsed by params and should not show up in with_args
@ -521,7 +529,7 @@ def chain(server: ServerContext, pool: DevicePoolExecutor):
job_name = make_job_name("chain", base_params, base_size) job_name = make_job_name("chain", base_params, base_size)
# build and run chain pipeline # build and run chain pipeline
pool.submit( queue = pool.submit(
job_name, job_name,
JobType.CHAIN, JobType.CHAIN,
pipeline, pipeline,
@ -532,7 +540,7 @@ def chain(server: ServerContext, pool: DevicePoolExecutor):
needs_device=device, needs_device=device,
) )
return job_reply(job_name) return job_reply(job_name, queue=queue)
def blend(server: ServerContext, pool: DevicePoolExecutor): def blend(server: ServerContext, pool: DevicePoolExecutor):
@ -557,7 +565,7 @@ def blend(server: ServerContext, pool: DevicePoolExecutor):
upscale = build_upscale() upscale = build_upscale()
job_name = make_job_name("blend", params, size) job_name = make_job_name("blend", params, size)
pool.submit( queue = pool.submit(
job_name, job_name,
JobType.BLEND, JobType.BLEND,
run_blend_pipeline, run_blend_pipeline,
@ -573,7 +581,7 @@ def blend(server: ServerContext, pool: DevicePoolExecutor):
logger.info("upscale job queued for: %s", job_name) logger.info("upscale job queued for: %s", job_name)
return job_reply(job_name) return job_reply(job_name, queue=queue)
def txt2txt(server: ServerContext, pool: DevicePoolExecutor): def txt2txt(server: ServerContext, pool: DevicePoolExecutor):
@ -582,7 +590,7 @@ def txt2txt(server: ServerContext, pool: DevicePoolExecutor):
job_name = make_job_name("txt2txt", params, size) job_name = make_job_name("txt2txt", params, size)
logger.info("upscale job queued for: %s", job_name) logger.info("upscale job queued for: %s", job_name)
pool.submit( queue = pool.submit(
job_name, job_name,
JobType.TXT2TXT, JobType.TXT2TXT,
run_txt2txt_pipeline, run_txt2txt_pipeline,
@ -592,7 +600,7 @@ def txt2txt(server: ServerContext, pool: DevicePoolExecutor):
needs_device=device, needs_device=device,
) )
return job_reply(job_name) return job_reply(job_name, queue=queue)
def cancel(server: ServerContext, pool: DevicePoolExecutor): def cancel(server: ServerContext, pool: DevicePoolExecutor):
@ -612,7 +620,7 @@ def ready(server: ServerContext, pool: DevicePoolExecutor):
return error_reply("output name is required") return error_reply("output name is required")
output_file = sanitize_name(output_file) output_file = sanitize_name(output_file)
status, progress = pool.status(output_file) status, progress, _queue = pool.status(output_file)
if status == JobStatus.PENDING: if status == JobStatus.PENDING:
return ready_reply(pending=True) return ready_reply(pending=True)
@ -677,7 +685,7 @@ def job_status(server: ServerContext, pool: DevicePoolExecutor):
for job_name in job_list: for job_name in job_list:
job_name = sanitize_name(job_name) job_name = sanitize_name(job_name)
status, progress = pool.status(job_name) status, progress, queue = pool.status(job_name)
if progress is not None: if progress is not None:
outputs = None outputs = None
@ -700,7 +708,7 @@ def job_status(server: ServerContext, pool: DevicePoolExecutor):
) )
) )
else: else:
records.append(image_reply(server, job_name, status)) records.append(image_reply(server, job_name, status, queue=queue))
return jsonify(records) return jsonify(records)

View File

@ -36,9 +36,18 @@ class Progress:
self.current = current self.current = current
self.total = total self.total = total
def __repr__(self) -> str:
return "Progress(%d, %d)" % (self.current, self.total)
def __str__(self) -> str: def __str__(self) -> str:
return "%s/%s" % (self.current, self.total) return "%s/%s" % (self.current, self.total)
def __eq__(self, other: Any) -> bool:
if isinstance(other, Progress):
return self.current == other.current and self.total == other.total
return False
def tojson(self): def tojson(self):
return { return {
"current": self.current, "current": self.current,

View File

@ -8,7 +8,7 @@ from torch.multiprocessing import Process, Queue, Value
from ..params import DeviceParams from ..params import DeviceParams
from ..server import ServerContext from ..server import ServerContext
from .command import JobCommand, JobStatus, ProgressCommand from .command import JobCommand, JobStatus, Progress, ProgressCommand
from .context import WorkerContext from .context import WorkerContext
from .utils import Interval from .utils import Interval
from .worker import worker_main from .worker import worker_main
@ -228,31 +228,33 @@ class DevicePoolExecutor:
self.cancelled_jobs.append(key) self.cancelled_jobs.append(key)
return True return True
def status(self, key: str) -> Tuple[JobStatus, Optional[ProgressCommand]]: def status(
self, key: str
) -> Tuple[JobStatus, Optional[ProgressCommand], Optional[Progress]]:
""" """
Check if a job has been finished and report the last progress update. Check if a job has been finished and report the last progress update.
""" """
if key in self.cancelled_jobs: if key in self.cancelled_jobs:
logger.debug("checking status for cancelled job: %s", key) logger.debug("checking status for cancelled job: %s", key)
return (JobStatus.CANCELLED, None) return (JobStatus.CANCELLED, None, None)
if key in self.running_jobs: if key in self.running_jobs:
logger.debug("checking status for running job: %s", key) logger.debug("checking status for running job: %s", key)
return (JobStatus.RUNNING, self.running_jobs[key]) return (JobStatus.RUNNING, self.running_jobs[key], None)
for job in self.finished_jobs: for job in self.finished_jobs:
if job.job == key: if job.job == key:
logger.debug("checking status for finished job: %s", key) logger.debug("checking status for finished job: %s", key)
return (job.status, job) return (job.status, job, None)
for job in self.pending_jobs: for i, job in enumerate(self.pending_jobs):
if job.name == key: if job.name == key:
logger.debug("checking status for pending job: %s", key) logger.debug("checking status for pending job: %s", key)
return (JobStatus.PENDING, None) return (JobStatus.PENDING, None, Progress(i, len(self.pending_jobs)))
logger.trace("checking status for unknown job: %s", key) logger.trace("checking status for unknown job: %s", key)
return (JobStatus.UNKNOWN, None) return (JobStatus.UNKNOWN, None, None)
def join(self): def join(self):
logger.info("stopping worker pool") logger.info("stopping worker pool")
@ -399,7 +401,7 @@ class DevicePoolExecutor:
*args, *args,
needs_device: Optional[DeviceParams] = None, needs_device: Optional[DeviceParams] = None,
**kwargs, **kwargs,
) -> None: ) -> int:
device_idx = self.get_next_device(needs_device=needs_device) device_idx = self.get_next_device(needs_device=needs_device)
device = self.devices[device_idx].device device = self.devices[device_idx].device
logger.info( logger.info(
@ -413,6 +415,9 @@ class DevicePoolExecutor:
job = JobCommand(key, device, job_type, fn, args, kwargs) job = JobCommand(key, device, job_type, fn, args, kwargs)
self.pending_jobs.append(job) self.pending_jobs.append(job)
# return position in queue
return len(self.pending_jobs)
def summary(self) -> Dict[str, List[Tuple[str, int, JobStatus]]]: def summary(self) -> Dict[str, List[Tuple[str, int, JobStatus]]]:
""" """
Returns a tuple of: job/device, progress, progress, finished, cancelled, failed Returns a tuple of: job/device, progress, progress, finished, cancelled, failed

View File

@ -0,0 +1,41 @@
import unittest
from unittest.mock import MagicMock
from onnx_web.chain.edit_metadata import EditMetadataStage
class TestEditMetadataStage(unittest.TestCase):
def setUp(self):
self.stage = EditMetadataStage()
def test_run_with_no_changes(self):
source = MagicMock()
source.metadata = []
result = self.stage.run(None, None, None, None, source)
self.assertEqual(result, source)
def test_run_with_note_change(self):
source = MagicMock()
source.metadata = [MagicMock()]
note = "New note"
result = self.stage.run(None, None, None, None, source, note=note)
self.assertEqual(result, source)
self.assertEqual(result.metadata[0].note, note)
def test_run_with_replace_params_change(self):
source = MagicMock()
source.metadata = [MagicMock()]
replace_params = MagicMock()
result = self.stage.run(
None, None, None, None, source, replace_params=replace_params
)
self.assertEqual(result, source)
self.assertEqual(result.metadata[0].params, replace_params)
# Add more test cases for other parameters...

View File

@ -0,0 +1,48 @@
import unittest
import numpy as np
from PIL import Image
from onnx_web.chain.edit_text import EditTextStage
from onnx_web.chain.result import StageResult
class TestEditTextStage(unittest.TestCase):
def test_run(self):
# Create a sample image
image = Image.new("RGB", (100, 100), color="black")
# Create an instance of EditTextStage
stage = EditTextStage()
# Define the input parameters
text = "Hello, World!"
position = (10, 10)
fill = "white"
stroke = "white"
stroke_width = 2
# Create a mock source StageResult
source = StageResult.from_images([image], metadata={})
# Call the run method
result = stage.run(
None,
None,
None,
None,
source,
text=text,
position=position,
fill=fill,
stroke=stroke,
stroke_width=stroke_width,
)
# Assert the output
self.assertEqual(len(result.as_images()), 1)
# self.assertEqual(result.metadata, {})
# Verify the modified image
modified_image = result.as_images()[0]
self.assertEqual(np.max(np.array(modified_image)), 255)

View File

@ -5,7 +5,7 @@ from typing import Optional
from onnx_web.params import DeviceParams from onnx_web.params import DeviceParams
from onnx_web.server.context import ServerContext from onnx_web.server.context import ServerContext
from onnx_web.worker.command import JobStatus from onnx_web.worker.command import JobStatus, Progress
from onnx_web.worker.pool import DevicePoolExecutor from onnx_web.worker.pool import DevicePoolExecutor
from tests.helpers import test_device from tests.helpers import test_device
@ -61,10 +61,12 @@ class TestWorkerPool(unittest.TestCase):
self.pool.start() self.pool.start()
self.pool.submit("test", "test", sleep_job, lock=lock) self.pool.submit("test", "test", sleep_job, lock=lock)
self.assertEqual(self.pool.status("test"), (JobStatus.PENDING, None)) self.assertEqual(
self.pool.status("test"), (JobStatus.PENDING, None, Progress(0, 1))
)
self.assertTrue(self.pool.cancel("test")) self.assertTrue(self.pool.cancel("test"))
self.assertEqual(self.pool.status("test"), (JobStatus.CANCELLED, None)) self.assertEqual(self.pool.status("test"), (JobStatus.CANCELLED, None, None))
def test_cancel_running(self): def test_cancel_running(self):
pass pass
@ -104,7 +106,7 @@ class TestWorkerPool(unittest.TestCase):
self.pool.submit("test", "test", lock_job) self.pool.submit("test", "test", lock_job)
sleep(5.0) sleep(5.0)
status, _progress = self.pool.status("test") status, _progress, _status = self.pool.status("test")
self.assertEqual(status, JobStatus.RUNNING) self.assertEqual(status, JobStatus.RUNNING)
def test_done_pending(self): def test_done_pending(self):
@ -116,7 +118,9 @@ class TestWorkerPool(unittest.TestCase):
self.pool.submit("test1", "test", lock_job) self.pool.submit("test1", "test", lock_job)
self.pool.submit("test2", "test", lock_job) self.pool.submit("test2", "test", lock_job)
self.assertEqual(self.pool.status("test2"), (JobStatus.PENDING, None)) self.assertEqual(
self.pool.status("test2"), (JobStatus.PENDING, None, Progress(1, 2))
)
lock.set() lock.set()
@ -132,10 +136,12 @@ class TestWorkerPool(unittest.TestCase):
) )
self.pool.start() self.pool.start()
self.pool.submit("test", "test", sleep_job) self.pool.submit("test", "test", sleep_job)
self.assertEqual(self.pool.status("test"), (JobStatus.PENDING, None)) self.assertEqual(
self.pool.status("test"), (JobStatus.PENDING, None, Progress(0, 1))
)
sleep(5.0) sleep(5.0)
status, _progress = self.pool.status("test") status, _progress, _queue = self.pool.status("test")
self.assertEqual(status, JobStatus.SUCCESS) self.assertEqual(status, JobStatus.SUCCESS)
def test_recycle_live(self): def test_recycle_live(self):
@ -162,7 +168,7 @@ class TestWorkerPool(unittest.TestCase):
self.pool.submit("test", "test", progress_job) self.pool.submit("test", "test", progress_job)
sleep(5.0) sleep(5.0)
status, progress = self.pool.status("test") status, progress, _queue = self.pool.status("test")
self.assertEqual(status, JobStatus.SUCCESS) self.assertEqual(status, JobStatus.SUCCESS)
self.assertEqual(progress.steps.current, 1) self.assertEqual(progress.steps.current, 1)
@ -178,6 +184,6 @@ class TestWorkerPool(unittest.TestCase):
self.pool.submit("test", "test", fail_job) self.pool.submit("test", "test", fail_job)
sleep(5.0) sleep(5.0)
status, progress = self.pool.status("test") status, progress, _queue = self.pool.status("test")
self.assertEqual(status, JobStatus.FAILED) self.assertEqual(status, JobStatus.FAILED)
self.assertEqual(progress.steps.current, 0) self.assertEqual(progress.steps.current, 0)