45 lines
1.2 KiB
Python
45 lines
1.2 KiB
Python
|
import unittest
|
||
|
|
||
|
from onnx_web.chain.correct_gfpgan import CorrectGFPGANStage
|
||
|
from onnx_web.chain.result import StageResult
|
||
|
from onnx_web.params import HighresParams, UpscaleParams
|
||
|
from onnx_web.server.context import ServerContext
|
||
|
from onnx_web.server.hacks import apply_patches
|
||
|
from onnx_web.worker.context import WorkerContext
|
||
|
from tests.helpers import test_device, test_needs_onnx_models
|
||
|
|
||
|
# Relative path of the GFPGAN face-correction model used by these tests. The
# test below is wrapped in test_needs_onnx_models([TEST_MODEL]); presumably it
# is skipped when this model is not present locally (see tests.helpers).
TEST_MODEL: str = "../models/correction-gfpgan-v1-3"
|
||
|
|
||
|
|
||
|
class CorrectGFPGANStageTests(unittest.TestCase):
    """Unit tests for :class:`CorrectGFPGANStage`."""

    @test_needs_onnx_models([TEST_MODEL])
    def test_empty(self):
        """Running the stage on an empty ``StageResult`` must produce no outputs."""
        server_ctx = ServerContext(model_path="../models", output_path="../outputs")
        apply_patches(server_ctx)

        # WorkerContext takes ten positional arguments here: a name, a device,
        # six arguments passed as None, then 0 and 0.1 (their meaning is defined
        # by WorkerContext -- confirm there before changing them).
        worker_ctx = WorkerContext("test", test_device(), *([None] * 6), 0, 0.1)

        gfpgan_stage = CorrectGFPGANStage()
        no_sources = StageResult.empty()

        # The three positional arguments between the worker and the sources are
        # passed as None, exactly as before.
        # NOTE(review): UpscaleParams(TEST_MODEL) only sets the first positional
        # field; confirm whether the stage's correction path is actually reached.
        outputs = gfpgan_stage.run(
            worker_ctx,
            None,
            None,
            None,
            no_sources,
            highres=HighresParams(False, 1, 0, 0),
            upscale=UpscaleParams(TEST_MODEL),
        )

        self.assertEqual(len(outputs), 0)
|