start adding queue progress
This commit is contained in:
parent
fd92d52339
commit
237accc973
|
@ -1,6 +1,13 @@
|
||||||
from typing import Any, Optional
|
from typing import Optional
|
||||||
|
|
||||||
from ..params import HighresParams, ImageParams, Size, SizeChart, StageParams, UpscaleParams
|
from ..params import (
|
||||||
|
HighresParams,
|
||||||
|
ImageParams,
|
||||||
|
Size,
|
||||||
|
SizeChart,
|
||||||
|
StageParams,
|
||||||
|
UpscaleParams,
|
||||||
|
)
|
||||||
from ..server import ServerContext
|
from ..server import ServerContext
|
||||||
from ..worker import WorkerContext
|
from ..worker import WorkerContext
|
||||||
from .base import BaseStage
|
from .base import BaseStage
|
||||||
|
@ -24,7 +31,7 @@ class EditMetadataStage(BaseStage):
|
||||||
note: Optional[str] = None,
|
note: Optional[str] = None,
|
||||||
replace_params: Optional[ImageParams] = None,
|
replace_params: Optional[ImageParams] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> Any:
|
) -> StageResult:
|
||||||
# Modify the source image's metadata using the provided parameters
|
# Modify the source image's metadata using the provided parameters
|
||||||
for metadata in source.metadata:
|
for metadata in source.metadata:
|
||||||
if note is not None:
|
if note is not None:
|
||||||
|
|
|
@ -0,0 +1,41 @@
|
||||||
|
from typing import Tuple
|
||||||
|
|
||||||
|
from PIL import ImageDraw
|
||||||
|
|
||||||
|
from ..params import ImageParams, SizeChart, StageParams
|
||||||
|
from ..server import ServerContext
|
||||||
|
from ..worker import WorkerContext
|
||||||
|
from .base import BaseStage
|
||||||
|
from .result import StageResult
|
||||||
|
|
||||||
|
|
||||||
|
class EditTextStage(BaseStage):
|
||||||
|
max_tile = SizeChart.max
|
||||||
|
|
||||||
|
def run(
|
||||||
|
self,
|
||||||
|
_worker: WorkerContext,
|
||||||
|
_server: ServerContext,
|
||||||
|
_stage: StageParams,
|
||||||
|
_params: ImageParams,
|
||||||
|
source: StageResult,
|
||||||
|
*,
|
||||||
|
text: str,
|
||||||
|
position: Tuple[int, int],
|
||||||
|
fill: str = "white",
|
||||||
|
stroke: str = "black",
|
||||||
|
stroke_width: int = 1,
|
||||||
|
**kwargs,
|
||||||
|
) -> StageResult:
|
||||||
|
# Add text to each image in source at the given position
|
||||||
|
results = []
|
||||||
|
|
||||||
|
for image in source.as_images():
|
||||||
|
image = image.copy()
|
||||||
|
draw = ImageDraw.Draw(image)
|
||||||
|
draw.text(
|
||||||
|
position, text, fill=fill, stroke_width=stroke_width, stroke_fill=stroke
|
||||||
|
)
|
||||||
|
results.append(image)
|
||||||
|
|
||||||
|
return StageResult.from_images(results, source.metadata)
|
|
@ -23,7 +23,7 @@ PipelineStage = Tuple[BaseStage, StageParams, Optional[dict]]
|
||||||
|
|
||||||
class ChainProgress:
|
class ChainProgress:
|
||||||
parent: ProgressCallback
|
parent: ProgressCallback
|
||||||
step: int # same as steps.current, left for legacy purposes
|
step: int # current number of steps
|
||||||
prev: int # accumulator when step resets
|
prev: int # accumulator when step resets
|
||||||
|
|
||||||
# TODO: should probably be moved to worker context as well
|
# TODO: should probably be moved to worker context as well
|
||||||
|
|
|
@ -93,15 +93,18 @@ def error_reply(err: str):
|
||||||
return response
|
return response
|
||||||
|
|
||||||
|
|
||||||
def job_reply(name: str):
|
EMPTY_PROGRESS = Progress(0, 0)
|
||||||
|
|
||||||
|
|
||||||
|
def job_reply(name: str, queue: int = 0):
|
||||||
return jsonify(
|
return jsonify(
|
||||||
{
|
{
|
||||||
"name": name,
|
"name": name,
|
||||||
"queue": Progress(0, 0).tojson(), # TODO: use real queue position
|
"queue": Progress(queue, queue).tojson(),
|
||||||
"status": JobStatus.PENDING,
|
"status": JobStatus.PENDING,
|
||||||
"stages": Progress(0, 0).tojson(),
|
"stages": EMPTY_PROGRESS.tojson(),
|
||||||
"steps": Progress(0, 0).tojson(),
|
"steps": EMPTY_PROGRESS.tojson(),
|
||||||
"tiles": Progress(0, 0).tojson(),
|
"tiles": EMPTY_PROGRESS.tojson(),
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -110,24 +113,29 @@ def image_reply(
|
||||||
server: ServerContext,
|
server: ServerContext,
|
||||||
name: str,
|
name: str,
|
||||||
status: str,
|
status: str,
|
||||||
|
queue: Progress = None,
|
||||||
stages: Progress = None,
|
stages: Progress = None,
|
||||||
steps: Progress = None,
|
steps: Progress = None,
|
||||||
tiles: Progress = None,
|
tiles: Progress = None,
|
||||||
outputs: List[str] = None,
|
outputs: List[str] = None,
|
||||||
metadata: List[ImageMetadata] = None,
|
metadata: List[ImageMetadata] = None,
|
||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
|
if queue is None:
|
||||||
|
queue = EMPTY_PROGRESS
|
||||||
|
|
||||||
if stages is None:
|
if stages is None:
|
||||||
stages = Progress(0, 0)
|
stages = EMPTY_PROGRESS
|
||||||
|
|
||||||
if steps is None:
|
if steps is None:
|
||||||
steps = Progress(0, 0)
|
steps = EMPTY_PROGRESS
|
||||||
|
|
||||||
if tiles is None:
|
if tiles is None:
|
||||||
tiles = Progress(0, 0)
|
tiles = EMPTY_PROGRESS
|
||||||
|
|
||||||
data = {
|
data = {
|
||||||
"name": name,
|
"name": name,
|
||||||
"status": status,
|
"status": status,
|
||||||
|
"queue": queue.tojson(),
|
||||||
"stages": stages.tojson(),
|
"stages": stages.tojson(),
|
||||||
"steps": steps.tojson(),
|
"steps": steps.tojson(),
|
||||||
"tiles": tiles.tojson(),
|
"tiles": tiles.tojson(),
|
||||||
|
@ -263,7 +271,7 @@ def img2img(server: ServerContext, pool: DevicePoolExecutor):
|
||||||
output_count += 1
|
output_count += 1
|
||||||
|
|
||||||
job_name = make_job_name("img2img", params, size, extras=[strength])
|
job_name = make_job_name("img2img", params, size, extras=[strength])
|
||||||
pool.submit(
|
queue = pool.submit(
|
||||||
job_name,
|
job_name,
|
||||||
JobType.IMG2IMG,
|
JobType.IMG2IMG,
|
||||||
run_img2img_pipeline,
|
run_img2img_pipeline,
|
||||||
|
@ -279,7 +287,7 @@ def img2img(server: ServerContext, pool: DevicePoolExecutor):
|
||||||
|
|
||||||
logger.info("img2img job queued for: %s", job_name)
|
logger.info("img2img job queued for: %s", job_name)
|
||||||
|
|
||||||
return job_reply(job_name)
|
return job_reply(job_name, queue=queue)
|
||||||
|
|
||||||
|
|
||||||
def txt2img(server: ServerContext, pool: DevicePoolExecutor):
|
def txt2img(server: ServerContext, pool: DevicePoolExecutor):
|
||||||
|
@ -291,7 +299,7 @@ def txt2img(server: ServerContext, pool: DevicePoolExecutor):
|
||||||
|
|
||||||
job_name = make_job_name("txt2img", params, size)
|
job_name = make_job_name("txt2img", params, size)
|
||||||
|
|
||||||
pool.submit(
|
queue = pool.submit(
|
||||||
job_name,
|
job_name,
|
||||||
JobType.TXT2IMG,
|
JobType.TXT2IMG,
|
||||||
run_txt2img_pipeline,
|
run_txt2img_pipeline,
|
||||||
|
@ -305,7 +313,7 @@ def txt2img(server: ServerContext, pool: DevicePoolExecutor):
|
||||||
|
|
||||||
logger.info("txt2img job queued for: %s", job_name)
|
logger.info("txt2img job queued for: %s", job_name)
|
||||||
|
|
||||||
return job_reply(job_name)
|
return job_reply(job_name, queue=queue)
|
||||||
|
|
||||||
|
|
||||||
def inpaint(server: ServerContext, pool: DevicePoolExecutor):
|
def inpaint(server: ServerContext, pool: DevicePoolExecutor):
|
||||||
|
@ -367,7 +375,7 @@ def inpaint(server: ServerContext, pool: DevicePoolExecutor):
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
pool.submit(
|
queue = pool.submit(
|
||||||
job_name,
|
job_name,
|
||||||
JobType.INPAINT,
|
JobType.INPAINT,
|
||||||
run_inpaint_pipeline,
|
run_inpaint_pipeline,
|
||||||
|
@ -390,7 +398,7 @@ def inpaint(server: ServerContext, pool: DevicePoolExecutor):
|
||||||
|
|
||||||
logger.info("inpaint job queued for: %s", job_name)
|
logger.info("inpaint job queued for: %s", job_name)
|
||||||
|
|
||||||
return job_reply(job_name)
|
return job_reply(job_name, queue=queue)
|
||||||
|
|
||||||
|
|
||||||
def upscale(server: ServerContext, pool: DevicePoolExecutor):
|
def upscale(server: ServerContext, pool: DevicePoolExecutor):
|
||||||
|
@ -407,7 +415,7 @@ def upscale(server: ServerContext, pool: DevicePoolExecutor):
|
||||||
replace_wildcards(params, get_wildcard_data())
|
replace_wildcards(params, get_wildcard_data())
|
||||||
|
|
||||||
job_name = make_job_name("upscale", params, size)
|
job_name = make_job_name("upscale", params, size)
|
||||||
pool.submit(
|
queue = pool.submit(
|
||||||
job_name,
|
job_name,
|
||||||
JobType.UPSCALE,
|
JobType.UPSCALE,
|
||||||
run_upscale_pipeline,
|
run_upscale_pipeline,
|
||||||
|
@ -422,7 +430,7 @@ def upscale(server: ServerContext, pool: DevicePoolExecutor):
|
||||||
|
|
||||||
logger.info("upscale job queued for: %s", job_name)
|
logger.info("upscale job queued for: %s", job_name)
|
||||||
|
|
||||||
return job_reply(job_name)
|
return job_reply(job_name, queue=queue)
|
||||||
|
|
||||||
|
|
||||||
# keys that are specially parsed by params and should not show up in with_args
|
# keys that are specially parsed by params and should not show up in with_args
|
||||||
|
@ -521,7 +529,7 @@ def chain(server: ServerContext, pool: DevicePoolExecutor):
|
||||||
job_name = make_job_name("chain", base_params, base_size)
|
job_name = make_job_name("chain", base_params, base_size)
|
||||||
|
|
||||||
# build and run chain pipeline
|
# build and run chain pipeline
|
||||||
pool.submit(
|
queue = pool.submit(
|
||||||
job_name,
|
job_name,
|
||||||
JobType.CHAIN,
|
JobType.CHAIN,
|
||||||
pipeline,
|
pipeline,
|
||||||
|
@ -532,7 +540,7 @@ def chain(server: ServerContext, pool: DevicePoolExecutor):
|
||||||
needs_device=device,
|
needs_device=device,
|
||||||
)
|
)
|
||||||
|
|
||||||
return job_reply(job_name)
|
return job_reply(job_name, queue=queue)
|
||||||
|
|
||||||
|
|
||||||
def blend(server: ServerContext, pool: DevicePoolExecutor):
|
def blend(server: ServerContext, pool: DevicePoolExecutor):
|
||||||
|
@ -557,7 +565,7 @@ def blend(server: ServerContext, pool: DevicePoolExecutor):
|
||||||
upscale = build_upscale()
|
upscale = build_upscale()
|
||||||
|
|
||||||
job_name = make_job_name("blend", params, size)
|
job_name = make_job_name("blend", params, size)
|
||||||
pool.submit(
|
queue = pool.submit(
|
||||||
job_name,
|
job_name,
|
||||||
JobType.BLEND,
|
JobType.BLEND,
|
||||||
run_blend_pipeline,
|
run_blend_pipeline,
|
||||||
|
@ -573,7 +581,7 @@ def blend(server: ServerContext, pool: DevicePoolExecutor):
|
||||||
|
|
||||||
logger.info("upscale job queued for: %s", job_name)
|
logger.info("upscale job queued for: %s", job_name)
|
||||||
|
|
||||||
return job_reply(job_name)
|
return job_reply(job_name, queue=queue)
|
||||||
|
|
||||||
|
|
||||||
def txt2txt(server: ServerContext, pool: DevicePoolExecutor):
|
def txt2txt(server: ServerContext, pool: DevicePoolExecutor):
|
||||||
|
@ -582,7 +590,7 @@ def txt2txt(server: ServerContext, pool: DevicePoolExecutor):
|
||||||
job_name = make_job_name("txt2txt", params, size)
|
job_name = make_job_name("txt2txt", params, size)
|
||||||
logger.info("upscale job queued for: %s", job_name)
|
logger.info("upscale job queued for: %s", job_name)
|
||||||
|
|
||||||
pool.submit(
|
queue = pool.submit(
|
||||||
job_name,
|
job_name,
|
||||||
JobType.TXT2TXT,
|
JobType.TXT2TXT,
|
||||||
run_txt2txt_pipeline,
|
run_txt2txt_pipeline,
|
||||||
|
@ -592,7 +600,7 @@ def txt2txt(server: ServerContext, pool: DevicePoolExecutor):
|
||||||
needs_device=device,
|
needs_device=device,
|
||||||
)
|
)
|
||||||
|
|
||||||
return job_reply(job_name)
|
return job_reply(job_name, queue=queue)
|
||||||
|
|
||||||
|
|
||||||
def cancel(server: ServerContext, pool: DevicePoolExecutor):
|
def cancel(server: ServerContext, pool: DevicePoolExecutor):
|
||||||
|
@ -612,7 +620,7 @@ def ready(server: ServerContext, pool: DevicePoolExecutor):
|
||||||
return error_reply("output name is required")
|
return error_reply("output name is required")
|
||||||
|
|
||||||
output_file = sanitize_name(output_file)
|
output_file = sanitize_name(output_file)
|
||||||
status, progress = pool.status(output_file)
|
status, progress, _queue = pool.status(output_file)
|
||||||
|
|
||||||
if status == JobStatus.PENDING:
|
if status == JobStatus.PENDING:
|
||||||
return ready_reply(pending=True)
|
return ready_reply(pending=True)
|
||||||
|
@ -677,7 +685,7 @@ def job_status(server: ServerContext, pool: DevicePoolExecutor):
|
||||||
|
|
||||||
for job_name in job_list:
|
for job_name in job_list:
|
||||||
job_name = sanitize_name(job_name)
|
job_name = sanitize_name(job_name)
|
||||||
status, progress = pool.status(job_name)
|
status, progress, queue = pool.status(job_name)
|
||||||
|
|
||||||
if progress is not None:
|
if progress is not None:
|
||||||
outputs = None
|
outputs = None
|
||||||
|
@ -700,7 +708,7 @@ def job_status(server: ServerContext, pool: DevicePoolExecutor):
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
records.append(image_reply(server, job_name, status))
|
records.append(image_reply(server, job_name, status, queue=queue))
|
||||||
|
|
||||||
return jsonify(records)
|
return jsonify(records)
|
||||||
|
|
||||||
|
|
|
@ -36,9 +36,18 @@ class Progress:
|
||||||
self.current = current
|
self.current = current
|
||||||
self.total = total
|
self.total = total
|
||||||
|
|
||||||
|
def __repr__(self) -> str:
|
||||||
|
return "Progress(%d, %d)" % (self.current, self.total)
|
||||||
|
|
||||||
def __str__(self) -> str:
|
def __str__(self) -> str:
|
||||||
return "%s/%s" % (self.current, self.total)
|
return "%s/%s" % (self.current, self.total)
|
||||||
|
|
||||||
|
def __eq__(self, other: Any) -> bool:
|
||||||
|
if isinstance(other, Progress):
|
||||||
|
return self.current == other.current and self.total == other.total
|
||||||
|
|
||||||
|
return False
|
||||||
|
|
||||||
def tojson(self):
|
def tojson(self):
|
||||||
return {
|
return {
|
||||||
"current": self.current,
|
"current": self.current,
|
||||||
|
|
|
@ -8,7 +8,7 @@ from torch.multiprocessing import Process, Queue, Value
|
||||||
|
|
||||||
from ..params import DeviceParams
|
from ..params import DeviceParams
|
||||||
from ..server import ServerContext
|
from ..server import ServerContext
|
||||||
from .command import JobCommand, JobStatus, ProgressCommand
|
from .command import JobCommand, JobStatus, Progress, ProgressCommand
|
||||||
from .context import WorkerContext
|
from .context import WorkerContext
|
||||||
from .utils import Interval
|
from .utils import Interval
|
||||||
from .worker import worker_main
|
from .worker import worker_main
|
||||||
|
@ -228,31 +228,33 @@ class DevicePoolExecutor:
|
||||||
self.cancelled_jobs.append(key)
|
self.cancelled_jobs.append(key)
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def status(self, key: str) -> Tuple[JobStatus, Optional[ProgressCommand]]:
|
def status(
|
||||||
|
self, key: str
|
||||||
|
) -> Tuple[JobStatus, Optional[ProgressCommand], Optional[Progress]]:
|
||||||
"""
|
"""
|
||||||
Check if a job has been finished and report the last progress update.
|
Check if a job has been finished and report the last progress update.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if key in self.cancelled_jobs:
|
if key in self.cancelled_jobs:
|
||||||
logger.debug("checking status for cancelled job: %s", key)
|
logger.debug("checking status for cancelled job: %s", key)
|
||||||
return (JobStatus.CANCELLED, None)
|
return (JobStatus.CANCELLED, None, None)
|
||||||
|
|
||||||
if key in self.running_jobs:
|
if key in self.running_jobs:
|
||||||
logger.debug("checking status for running job: %s", key)
|
logger.debug("checking status for running job: %s", key)
|
||||||
return (JobStatus.RUNNING, self.running_jobs[key])
|
return (JobStatus.RUNNING, self.running_jobs[key], None)
|
||||||
|
|
||||||
for job in self.finished_jobs:
|
for job in self.finished_jobs:
|
||||||
if job.job == key:
|
if job.job == key:
|
||||||
logger.debug("checking status for finished job: %s", key)
|
logger.debug("checking status for finished job: %s", key)
|
||||||
return (job.status, job)
|
return (job.status, job, None)
|
||||||
|
|
||||||
for job in self.pending_jobs:
|
for i, job in enumerate(self.pending_jobs):
|
||||||
if job.name == key:
|
if job.name == key:
|
||||||
logger.debug("checking status for pending job: %s", key)
|
logger.debug("checking status for pending job: %s", key)
|
||||||
return (JobStatus.PENDING, None)
|
return (JobStatus.PENDING, None, Progress(i, len(self.pending_jobs)))
|
||||||
|
|
||||||
logger.trace("checking status for unknown job: %s", key)
|
logger.trace("checking status for unknown job: %s", key)
|
||||||
return (JobStatus.UNKNOWN, None)
|
return (JobStatus.UNKNOWN, None, None)
|
||||||
|
|
||||||
def join(self):
|
def join(self):
|
||||||
logger.info("stopping worker pool")
|
logger.info("stopping worker pool")
|
||||||
|
@ -399,7 +401,7 @@ class DevicePoolExecutor:
|
||||||
*args,
|
*args,
|
||||||
needs_device: Optional[DeviceParams] = None,
|
needs_device: Optional[DeviceParams] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> None:
|
) -> int:
|
||||||
device_idx = self.get_next_device(needs_device=needs_device)
|
device_idx = self.get_next_device(needs_device=needs_device)
|
||||||
device = self.devices[device_idx].device
|
device = self.devices[device_idx].device
|
||||||
logger.info(
|
logger.info(
|
||||||
|
@ -413,6 +415,9 @@ class DevicePoolExecutor:
|
||||||
job = JobCommand(key, device, job_type, fn, args, kwargs)
|
job = JobCommand(key, device, job_type, fn, args, kwargs)
|
||||||
self.pending_jobs.append(job)
|
self.pending_jobs.append(job)
|
||||||
|
|
||||||
|
# return position in queue
|
||||||
|
return len(self.pending_jobs)
|
||||||
|
|
||||||
def summary(self) -> Dict[str, List[Tuple[str, int, JobStatus]]]:
|
def summary(self) -> Dict[str, List[Tuple[str, int, JobStatus]]]:
|
||||||
"""
|
"""
|
||||||
Returns a tuple of: job/device, progress, progress, finished, cancelled, failed
|
Returns a tuple of: job/device, progress, progress, finished, cancelled, failed
|
||||||
|
|
|
@ -0,0 +1,41 @@
|
||||||
|
import unittest
|
||||||
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
|
from onnx_web.chain.edit_metadata import EditMetadataStage
|
||||||
|
|
||||||
|
|
||||||
|
class TestEditMetadataStage(unittest.TestCase):
|
||||||
|
def setUp(self):
|
||||||
|
self.stage = EditMetadataStage()
|
||||||
|
|
||||||
|
def test_run_with_no_changes(self):
|
||||||
|
source = MagicMock()
|
||||||
|
source.metadata = []
|
||||||
|
|
||||||
|
result = self.stage.run(None, None, None, None, source)
|
||||||
|
|
||||||
|
self.assertEqual(result, source)
|
||||||
|
|
||||||
|
def test_run_with_note_change(self):
|
||||||
|
source = MagicMock()
|
||||||
|
source.metadata = [MagicMock()]
|
||||||
|
note = "New note"
|
||||||
|
|
||||||
|
result = self.stage.run(None, None, None, None, source, note=note)
|
||||||
|
|
||||||
|
self.assertEqual(result, source)
|
||||||
|
self.assertEqual(result.metadata[0].note, note)
|
||||||
|
|
||||||
|
def test_run_with_replace_params_change(self):
|
||||||
|
source = MagicMock()
|
||||||
|
source.metadata = [MagicMock()]
|
||||||
|
replace_params = MagicMock()
|
||||||
|
|
||||||
|
result = self.stage.run(
|
||||||
|
None, None, None, None, source, replace_params=replace_params
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assertEqual(result, source)
|
||||||
|
self.assertEqual(result.metadata[0].params, replace_params)
|
||||||
|
|
||||||
|
# Add more test cases for other parameters...
|
|
@ -0,0 +1,48 @@
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
from onnx_web.chain.edit_text import EditTextStage
|
||||||
|
from onnx_web.chain.result import StageResult
|
||||||
|
|
||||||
|
|
||||||
|
class TestEditTextStage(unittest.TestCase):
|
||||||
|
def test_run(self):
|
||||||
|
# Create a sample image
|
||||||
|
image = Image.new("RGB", (100, 100), color="black")
|
||||||
|
|
||||||
|
# Create an instance of EditTextStage
|
||||||
|
stage = EditTextStage()
|
||||||
|
|
||||||
|
# Define the input parameters
|
||||||
|
text = "Hello, World!"
|
||||||
|
position = (10, 10)
|
||||||
|
fill = "white"
|
||||||
|
stroke = "white"
|
||||||
|
stroke_width = 2
|
||||||
|
|
||||||
|
# Create a mock source StageResult
|
||||||
|
source = StageResult.from_images([image], metadata={})
|
||||||
|
|
||||||
|
# Call the run method
|
||||||
|
result = stage.run(
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
source,
|
||||||
|
text=text,
|
||||||
|
position=position,
|
||||||
|
fill=fill,
|
||||||
|
stroke=stroke,
|
||||||
|
stroke_width=stroke_width,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Assert the output
|
||||||
|
self.assertEqual(len(result.as_images()), 1)
|
||||||
|
# self.assertEqual(result.metadata, {})
|
||||||
|
|
||||||
|
# Verify the modified image
|
||||||
|
modified_image = result.as_images()[0]
|
||||||
|
self.assertEqual(np.max(np.array(modified_image)), 255)
|
|
@ -5,7 +5,7 @@ from typing import Optional
|
||||||
|
|
||||||
from onnx_web.params import DeviceParams
|
from onnx_web.params import DeviceParams
|
||||||
from onnx_web.server.context import ServerContext
|
from onnx_web.server.context import ServerContext
|
||||||
from onnx_web.worker.command import JobStatus
|
from onnx_web.worker.command import JobStatus, Progress
|
||||||
from onnx_web.worker.pool import DevicePoolExecutor
|
from onnx_web.worker.pool import DevicePoolExecutor
|
||||||
from tests.helpers import test_device
|
from tests.helpers import test_device
|
||||||
|
|
||||||
|
@ -61,10 +61,12 @@ class TestWorkerPool(unittest.TestCase):
|
||||||
self.pool.start()
|
self.pool.start()
|
||||||
|
|
||||||
self.pool.submit("test", "test", sleep_job, lock=lock)
|
self.pool.submit("test", "test", sleep_job, lock=lock)
|
||||||
self.assertEqual(self.pool.status("test"), (JobStatus.PENDING, None))
|
self.assertEqual(
|
||||||
|
self.pool.status("test"), (JobStatus.PENDING, None, Progress(0, 1))
|
||||||
|
)
|
||||||
|
|
||||||
self.assertTrue(self.pool.cancel("test"))
|
self.assertTrue(self.pool.cancel("test"))
|
||||||
self.assertEqual(self.pool.status("test"), (JobStatus.CANCELLED, None))
|
self.assertEqual(self.pool.status("test"), (JobStatus.CANCELLED, None, None))
|
||||||
|
|
||||||
def test_cancel_running(self):
|
def test_cancel_running(self):
|
||||||
pass
|
pass
|
||||||
|
@ -104,7 +106,7 @@ class TestWorkerPool(unittest.TestCase):
|
||||||
self.pool.submit("test", "test", lock_job)
|
self.pool.submit("test", "test", lock_job)
|
||||||
sleep(5.0)
|
sleep(5.0)
|
||||||
|
|
||||||
status, _progress = self.pool.status("test")
|
status, _progress, _status = self.pool.status("test")
|
||||||
self.assertEqual(status, JobStatus.RUNNING)
|
self.assertEqual(status, JobStatus.RUNNING)
|
||||||
|
|
||||||
def test_done_pending(self):
|
def test_done_pending(self):
|
||||||
|
@ -116,7 +118,9 @@ class TestWorkerPool(unittest.TestCase):
|
||||||
|
|
||||||
self.pool.submit("test1", "test", lock_job)
|
self.pool.submit("test1", "test", lock_job)
|
||||||
self.pool.submit("test2", "test", lock_job)
|
self.pool.submit("test2", "test", lock_job)
|
||||||
self.assertEqual(self.pool.status("test2"), (JobStatus.PENDING, None))
|
self.assertEqual(
|
||||||
|
self.pool.status("test2"), (JobStatus.PENDING, None, Progress(1, 2))
|
||||||
|
)
|
||||||
|
|
||||||
lock.set()
|
lock.set()
|
||||||
|
|
||||||
|
@ -132,10 +136,12 @@ class TestWorkerPool(unittest.TestCase):
|
||||||
)
|
)
|
||||||
self.pool.start()
|
self.pool.start()
|
||||||
self.pool.submit("test", "test", sleep_job)
|
self.pool.submit("test", "test", sleep_job)
|
||||||
self.assertEqual(self.pool.status("test"), (JobStatus.PENDING, None))
|
self.assertEqual(
|
||||||
|
self.pool.status("test"), (JobStatus.PENDING, None, Progress(0, 1))
|
||||||
|
)
|
||||||
|
|
||||||
sleep(5.0)
|
sleep(5.0)
|
||||||
status, _progress = self.pool.status("test")
|
status, _progress, _queue = self.pool.status("test")
|
||||||
self.assertEqual(status, JobStatus.SUCCESS)
|
self.assertEqual(status, JobStatus.SUCCESS)
|
||||||
|
|
||||||
def test_recycle_live(self):
|
def test_recycle_live(self):
|
||||||
|
@ -162,7 +168,7 @@ class TestWorkerPool(unittest.TestCase):
|
||||||
self.pool.submit("test", "test", progress_job)
|
self.pool.submit("test", "test", progress_job)
|
||||||
sleep(5.0)
|
sleep(5.0)
|
||||||
|
|
||||||
status, progress = self.pool.status("test")
|
status, progress, _queue = self.pool.status("test")
|
||||||
self.assertEqual(status, JobStatus.SUCCESS)
|
self.assertEqual(status, JobStatus.SUCCESS)
|
||||||
self.assertEqual(progress.steps.current, 1)
|
self.assertEqual(progress.steps.current, 1)
|
||||||
|
|
||||||
|
@ -178,6 +184,6 @@ class TestWorkerPool(unittest.TestCase):
|
||||||
self.pool.submit("test", "test", fail_job)
|
self.pool.submit("test", "test", fail_job)
|
||||||
sleep(5.0)
|
sleep(5.0)
|
||||||
|
|
||||||
status, progress = self.pool.status("test")
|
status, progress, _queue = self.pool.status("test")
|
||||||
self.assertEqual(status, JobStatus.FAILED)
|
self.assertEqual(status, JobStatus.FAILED)
|
||||||
self.assertEqual(progress.steps.current, 0)
|
self.assertEqual(progress.steps.current, 0)
|
||||||
|
|
Loading…
Reference in New Issue