lint all the new stuff
This commit is contained in:
parent
b880b7a121
commit
584dddb5d6
|
@ -57,4 +57,4 @@ from .utils import (
|
||||||
)
|
)
|
||||||
from .worker import (
|
from .worker import (
|
||||||
DevicePoolExecutor,
|
DevicePoolExecutor,
|
||||||
)
|
)
|
||||||
|
|
|
@ -29,4 +29,4 @@ CHAIN_STAGES = {
|
||||||
"upscale-outpaint": upscale_outpaint,
|
"upscale-outpaint": upscale_outpaint,
|
||||||
"upscale-resrgan": upscale_resrgan,
|
"upscale-resrgan": upscale_resrgan,
|
||||||
"upscale-stable-diffusion": upscale_stable_diffusion,
|
"upscale-stable-diffusion": upscale_stable_diffusion,
|
||||||
}
|
}
|
||||||
|
|
|
@ -7,9 +7,9 @@ from PIL import Image
|
||||||
|
|
||||||
from ..output import save_image
|
from ..output import save_image
|
||||||
from ..params import ImageParams, StageParams
|
from ..params import ImageParams, StageParams
|
||||||
from ..worker import WorkerContext, ProgressCallback
|
|
||||||
from ..server import ServerContext
|
from ..server import ServerContext
|
||||||
from ..utils import is_debug
|
from ..utils import is_debug
|
||||||
|
from ..worker import ProgressCallback, WorkerContext
|
||||||
from .utils import process_tile_order
|
from .utils import process_tile_order
|
||||||
|
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
|
@ -7,8 +7,8 @@ from PIL import Image
|
||||||
|
|
||||||
from ..diffusion.load import load_pipeline
|
from ..diffusion.load import load_pipeline
|
||||||
from ..params import ImageParams, StageParams
|
from ..params import ImageParams, StageParams
|
||||||
from ..worker import WorkerContext, ProgressCallback
|
|
||||||
from ..server import ServerContext
|
from ..server import ServerContext
|
||||||
|
from ..worker import ProgressCallback, WorkerContext
|
||||||
|
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
|
|
@ -10,9 +10,9 @@ from ..diffusion.load import get_latents_from_seed, load_pipeline
|
||||||
from ..image import expand_image, mask_filter_none, noise_source_histogram
|
from ..image import expand_image, mask_filter_none, noise_source_histogram
|
||||||
from ..output import save_image
|
from ..output import save_image
|
||||||
from ..params import Border, ImageParams, Size, SizeChart, StageParams
|
from ..params import Border, ImageParams, Size, SizeChart, StageParams
|
||||||
from ..worker import WorkerContext, ProgressCallback
|
|
||||||
from ..server import ServerContext
|
from ..server import ServerContext
|
||||||
from ..utils import is_debug
|
from ..utils import is_debug
|
||||||
|
from ..worker import ProgressCallback, WorkerContext
|
||||||
from .utils import process_tile_order
|
from .utils import process_tile_order
|
||||||
|
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
|
@ -7,9 +7,9 @@ from onnx_web.image import valid_image
|
||||||
from onnx_web.output import save_image
|
from onnx_web.output import save_image
|
||||||
|
|
||||||
from ..params import ImageParams, StageParams
|
from ..params import ImageParams, StageParams
|
||||||
from ..worker import WorkerContext, ProgressCallback
|
|
||||||
from ..server import ServerContext
|
from ..server import ServerContext
|
||||||
from ..utils import is_debug
|
from ..utils import is_debug
|
||||||
|
from ..worker import ProgressCallback, WorkerContext
|
||||||
|
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
|
|
@ -3,8 +3,8 @@ from logging import getLogger
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
from ..params import ImageParams, StageParams, UpscaleParams
|
from ..params import ImageParams, StageParams, UpscaleParams
|
||||||
from ..worker import WorkerContext
|
|
||||||
from ..server import ServerContext
|
from ..server import ServerContext
|
||||||
|
from ..worker import WorkerContext
|
||||||
|
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
|
|
@ -5,10 +5,9 @@ import numpy as np
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
from ..params import DeviceParams, ImageParams, StageParams, UpscaleParams
|
from ..params import DeviceParams, ImageParams, StageParams, UpscaleParams
|
||||||
|
from ..server import ServerContext
|
||||||
from ..utils import run_gc
|
from ..utils import run_gc
|
||||||
from ..worker import WorkerContext
|
from ..worker import WorkerContext
|
||||||
from ..server import ServerContext
|
|
||||||
|
|
||||||
|
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
|
|
@ -4,9 +4,8 @@ from PIL import Image
|
||||||
|
|
||||||
from ..output import save_image
|
from ..output import save_image
|
||||||
from ..params import ImageParams, StageParams
|
from ..params import ImageParams, StageParams
|
||||||
from ..worker import WorkerContext
|
|
||||||
from ..server import ServerContext
|
from ..server import ServerContext
|
||||||
|
from ..worker import WorkerContext
|
||||||
|
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
|
|
@ -5,9 +5,8 @@ from boto3 import Session
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
from ..params import ImageParams, StageParams
|
from ..params import ImageParams, StageParams
|
||||||
from ..worker import WorkerContext
|
|
||||||
from ..server import ServerContext
|
from ..server import ServerContext
|
||||||
|
from ..worker import WorkerContext
|
||||||
|
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
|
|
@ -3,9 +3,8 @@ from logging import getLogger
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
from ..params import ImageParams, Size, StageParams
|
from ..params import ImageParams, Size, StageParams
|
||||||
from ..worker import WorkerContext
|
|
||||||
from ..server import ServerContext
|
from ..server import ServerContext
|
||||||
|
from ..worker import WorkerContext
|
||||||
|
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
|
|
@ -3,9 +3,8 @@ from logging import getLogger
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
from ..params import ImageParams, Size, StageParams
|
from ..params import ImageParams, Size, StageParams
|
||||||
from ..worker import WorkerContext
|
|
||||||
from ..server import ServerContext
|
from ..server import ServerContext
|
||||||
|
from ..worker import WorkerContext
|
||||||
|
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
|
|
@ -4,9 +4,8 @@ from typing import Callable
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
from ..params import ImageParams, Size, StageParams
|
from ..params import ImageParams, Size, StageParams
|
||||||
from ..worker import WorkerContext
|
|
||||||
from ..server import ServerContext
|
from ..server import ServerContext
|
||||||
|
from ..worker import WorkerContext
|
||||||
|
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
|
|
@ -7,9 +7,8 @@ from PIL import Image
|
||||||
|
|
||||||
from ..diffusion.load import get_latents_from_seed, load_pipeline
|
from ..diffusion.load import get_latents_from_seed, load_pipeline
|
||||||
from ..params import ImageParams, Size, StageParams
|
from ..params import ImageParams, Size, StageParams
|
||||||
from ..worker import WorkerContext, ProgressCallback
|
|
||||||
from ..server import ServerContext
|
from ..server import ServerContext
|
||||||
|
from ..worker import ProgressCallback, WorkerContext
|
||||||
|
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
|
|
@ -10,9 +10,9 @@ from ..diffusion.load import get_latents_from_seed, get_tile_latents, load_pipel
|
||||||
from ..image import expand_image, mask_filter_none, noise_source_histogram
|
from ..image import expand_image, mask_filter_none, noise_source_histogram
|
||||||
from ..output import save_image
|
from ..output import save_image
|
||||||
from ..params import Border, ImageParams, Size, SizeChart, StageParams
|
from ..params import Border, ImageParams, Size, SizeChart, StageParams
|
||||||
from ..worker import WorkerContext, ProgressCallback
|
|
||||||
from ..server import ServerContext
|
from ..server import ServerContext
|
||||||
from ..utils import is_debug
|
from ..utils import is_debug
|
||||||
|
from ..worker import ProgressCallback, WorkerContext
|
||||||
from .utils import process_tile_grid, process_tile_order
|
from .utils import process_tile_grid, process_tile_order
|
||||||
|
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
|
@ -6,9 +6,9 @@ from PIL import Image
|
||||||
|
|
||||||
from ..onnx import OnnxNet
|
from ..onnx import OnnxNet
|
||||||
from ..params import DeviceParams, ImageParams, StageParams, UpscaleParams
|
from ..params import DeviceParams, ImageParams, StageParams, UpscaleParams
|
||||||
from ..worker import WorkerContext
|
|
||||||
from ..server import ServerContext
|
from ..server import ServerContext
|
||||||
from ..utils import run_gc
|
from ..utils import run_gc
|
||||||
|
from ..worker import WorkerContext
|
||||||
|
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
|
|
@ -10,9 +10,9 @@ from ..diffusion.pipeline_onnx_stable_diffusion_upscale import (
|
||||||
OnnxStableDiffusionUpscalePipeline,
|
OnnxStableDiffusionUpscalePipeline,
|
||||||
)
|
)
|
||||||
from ..params import DeviceParams, ImageParams, StageParams, UpscaleParams
|
from ..params import DeviceParams, ImageParams, StageParams, UpscaleParams
|
||||||
from ..worker import WorkerContext, ProgressCallback
|
|
||||||
from ..server import ServerContext
|
from ..server import ServerContext
|
||||||
from ..utils import run_gc
|
from ..utils import run_gc
|
||||||
|
from ..worker import ProgressCallback, WorkerContext
|
||||||
|
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
|
|
@ -4,11 +4,9 @@ from typing import Any, Optional, Tuple
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from diffusers import (
|
from diffusers import (
|
||||||
DiffusionPipeline,
|
|
||||||
OnnxRuntimeModel,
|
|
||||||
StableDiffusionPipeline,
|
|
||||||
DDIMScheduler,
|
DDIMScheduler,
|
||||||
DDPMScheduler,
|
DDPMScheduler,
|
||||||
|
DiffusionPipeline,
|
||||||
DPMSolverMultistepScheduler,
|
DPMSolverMultistepScheduler,
|
||||||
DPMSolverSinglestepScheduler,
|
DPMSolverSinglestepScheduler,
|
||||||
EulerAncestralDiscreteScheduler,
|
EulerAncestralDiscreteScheduler,
|
||||||
|
@ -19,7 +17,9 @@ from diffusers import (
|
||||||
KDPM2AncestralDiscreteScheduler,
|
KDPM2AncestralDiscreteScheduler,
|
||||||
KDPM2DiscreteScheduler,
|
KDPM2DiscreteScheduler,
|
||||||
LMSDiscreteScheduler,
|
LMSDiscreteScheduler,
|
||||||
|
OnnxRuntimeModel,
|
||||||
PNDMScheduler,
|
PNDMScheduler,
|
||||||
|
StableDiffusionPipeline,
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|
|
@ -12,10 +12,10 @@ from onnx_web.chain.base import ChainProgress
|
||||||
from ..chain import upscale_outpaint
|
from ..chain import upscale_outpaint
|
||||||
from ..output import save_image, save_params
|
from ..output import save_image, save_params
|
||||||
from ..params import Border, ImageParams, Size, StageParams, UpscaleParams
|
from ..params import Border, ImageParams, Size, StageParams, UpscaleParams
|
||||||
from ..worker import WorkerContext
|
|
||||||
from ..server import ServerContext
|
from ..server import ServerContext
|
||||||
from ..upscale import run_upscale_correction
|
from ..upscale import run_upscale_correction
|
||||||
from ..utils import run_gc
|
from ..utils import run_gc
|
||||||
|
from ..worker import WorkerContext
|
||||||
from .load import get_latents_from_seed, load_pipeline
|
from .load import get_latents_from_seed, load_pipeline
|
||||||
|
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
|
@ -7,49 +7,53 @@ from huggingface_hub.utils.tqdm import disable_progress_bars
|
||||||
from torch.multiprocessing import set_start_method
|
from torch.multiprocessing import set_start_method
|
||||||
|
|
||||||
from .server.api import register_api_routes
|
from .server.api import register_api_routes
|
||||||
from .server.static import register_static_routes
|
from .server.config import (
|
||||||
from .server.config import get_available_platforms, load_models, load_params, load_platforms
|
get_available_platforms,
|
||||||
from .server.utils import check_paths
|
load_models,
|
||||||
|
load_params,
|
||||||
|
load_platforms,
|
||||||
|
)
|
||||||
from .server.context import ServerContext
|
from .server.context import ServerContext
|
||||||
from .server.hacks import apply_patches
|
from .server.hacks import apply_patches
|
||||||
from .utils import (
|
from .server.static import register_static_routes
|
||||||
is_debug,
|
from .server.utils import check_paths
|
||||||
)
|
from .utils import is_debug
|
||||||
from .worker import DevicePoolExecutor
|
from .worker import DevicePoolExecutor
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
set_start_method("spawn", force=True)
|
set_start_method("spawn", force=True)
|
||||||
|
|
||||||
context = ServerContext.from_environ()
|
context = ServerContext.from_environ()
|
||||||
apply_patches(context)
|
apply_patches(context)
|
||||||
check_paths(context)
|
check_paths(context)
|
||||||
load_models(context)
|
load_models(context)
|
||||||
load_params(context)
|
load_params(context)
|
||||||
load_platforms(context)
|
load_platforms(context)
|
||||||
|
|
||||||
if is_debug():
|
if is_debug():
|
||||||
gc.set_debug(gc.DEBUG_STATS)
|
gc.set_debug(gc.DEBUG_STATS)
|
||||||
|
|
||||||
if not context.show_progress:
|
if not context.show_progress:
|
||||||
disable_progress_bar()
|
disable_progress_bar()
|
||||||
disable_progress_bars()
|
disable_progress_bars()
|
||||||
|
|
||||||
app = Flask(__name__)
|
app = Flask(__name__)
|
||||||
CORS(app, origins=context.cors_origin)
|
CORS(app, origins=context.cors_origin)
|
||||||
|
|
||||||
# any is a fake device, should not be in the pool
|
# any is a fake device, should not be in the pool
|
||||||
pool = DevicePoolExecutor(context, [p for p in get_available_platforms() if p.device != "any"])
|
pool = DevicePoolExecutor(
|
||||||
|
context, [p for p in get_available_platforms() if p.device != "any"]
|
||||||
|
)
|
||||||
|
|
||||||
# register routes
|
# register routes
|
||||||
register_static_routes(app, context, pool)
|
register_static_routes(app, context, pool)
|
||||||
register_api_routes(app, context, pool)
|
register_api_routes(app, context, pool)
|
||||||
|
|
||||||
return app, pool
|
return app, pool
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
app, pool = main()
|
app, pool = main()
|
||||||
app.run("0.0.0.0", 5000, debug=is_debug())
|
app.run("0.0.0.0", 5000, debug=is_debug())
|
||||||
pool.join()
|
pool.join()
|
||||||
|
|
||||||
|
|
|
@ -3,9 +3,9 @@ from typing import Any, Optional
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from onnxruntime import InferenceSession, SessionOptions
|
|
||||||
|
|
||||||
from ..server import ServerContext
|
from ..server import ServerContext
|
||||||
|
from .torch_before_ort import InferenceSession, SessionOptions
|
||||||
|
|
||||||
|
|
||||||
class OnnxTensor:
|
class OnnxTensor:
|
||||||
|
|
|
@ -0,0 +1,5 @@
|
||||||
|
# this file exists to make sure torch is always imported before onnxruntime
|
||||||
|
# to work around https://github.com/microsoft/onnxruntime/issues/11092
|
||||||
|
|
||||||
|
import torch # NOQA
|
||||||
|
from onnxruntime import * # NOQA
|
|
@ -2,8 +2,7 @@ from enum import IntEnum
|
||||||
from logging import getLogger
|
from logging import getLogger
|
||||||
from typing import Any, Dict, List, Literal, Optional, Tuple, Union
|
from typing import Any, Dict, List, Literal, Optional, Tuple, Union
|
||||||
|
|
||||||
import torch
|
from .onnx.torch_before_ort import GraphOptimizationLevel, SessionOptions
|
||||||
from onnxruntime import GraphOptimizationLevel, SessionOptions
|
|
||||||
|
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
|
|
@ -7,27 +7,7 @@ from flask import Flask, jsonify, make_response, request, url_for
|
||||||
from jsonschema import validate
|
from jsonschema import validate
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
from .context import ServerContext
|
from ..chain import CHAIN_STAGES, ChainPipeline
|
||||||
from .utils import wrap_route
|
|
||||||
from ..worker.pool import DevicePoolExecutor
|
|
||||||
|
|
||||||
from .config import (
|
|
||||||
get_available_platforms,
|
|
||||||
get_config_params,
|
|
||||||
get_config_value,
|
|
||||||
get_correction_models,
|
|
||||||
get_diffusion_models,
|
|
||||||
get_inversion_models,
|
|
||||||
get_mask_filters,
|
|
||||||
get_noise_sources,
|
|
||||||
get_upscaling_models,
|
|
||||||
)
|
|
||||||
from .params import border_from_request, pipeline_from_request, upscale_from_request
|
|
||||||
|
|
||||||
from ..chain import (
|
|
||||||
CHAIN_STAGES,
|
|
||||||
ChainPipeline,
|
|
||||||
)
|
|
||||||
from ..diffusion.load import get_pipeline_schedulers
|
from ..diffusion.load import get_pipeline_schedulers
|
||||||
from ..diffusion.run import (
|
from ..diffusion.run import (
|
||||||
run_blend_pipeline,
|
run_blend_pipeline,
|
||||||
|
@ -36,16 +16,9 @@ from ..diffusion.run import (
|
||||||
run_txt2img_pipeline,
|
run_txt2img_pipeline,
|
||||||
run_upscale_pipeline,
|
run_upscale_pipeline,
|
||||||
)
|
)
|
||||||
from ..image import ( # mask filters; noise sources
|
from ..image import valid_image # mask filters; noise sources
|
||||||
valid_image,
|
|
||||||
)
|
|
||||||
from ..output import json_params, make_output_name
|
from ..output import json_params, make_output_name
|
||||||
from ..params import (
|
from ..params import Border, StageParams, TileOrder, UpscaleParams
|
||||||
Border,
|
|
||||||
StageParams,
|
|
||||||
TileOrder,
|
|
||||||
UpscaleParams,
|
|
||||||
)
|
|
||||||
from ..transformers import run_txt2txt_pipeline
|
from ..transformers import run_txt2txt_pipeline
|
||||||
from ..utils import (
|
from ..utils import (
|
||||||
base_join,
|
base_join,
|
||||||
|
@ -56,6 +29,21 @@ from ..utils import (
|
||||||
get_not_empty,
|
get_not_empty,
|
||||||
get_size,
|
get_size,
|
||||||
)
|
)
|
||||||
|
from ..worker.pool import DevicePoolExecutor
|
||||||
|
from .config import (
|
||||||
|
get_available_platforms,
|
||||||
|
get_config_params,
|
||||||
|
get_config_value,
|
||||||
|
get_correction_models,
|
||||||
|
get_diffusion_models,
|
||||||
|
get_inversion_models,
|
||||||
|
get_mask_filters,
|
||||||
|
get_noise_sources,
|
||||||
|
get_upscaling_models,
|
||||||
|
)
|
||||||
|
from .context import ServerContext
|
||||||
|
from .params import border_from_request, pipeline_from_request, upscale_from_request
|
||||||
|
from .utils import wrap_route
|
||||||
|
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
@ -456,22 +444,38 @@ def status(context: ServerContext, pool: DevicePoolExecutor):
|
||||||
|
|
||||||
|
|
||||||
def register_api_routes(app: Flask, context: ServerContext, pool: DevicePoolExecutor):
|
def register_api_routes(app: Flask, context: ServerContext, pool: DevicePoolExecutor):
|
||||||
return [
|
return [
|
||||||
app.route("/api")(wrap_route(introspect, context, app=app)),
|
app.route("/api")(wrap_route(introspect, context, app=app)),
|
||||||
app.route("/api/settings/masks")(wrap_route(list_mask_filters, context)),
|
app.route("/api/settings/masks")(wrap_route(list_mask_filters, context)),
|
||||||
app.route("/api/settings/models")(wrap_route(list_models, context)),
|
app.route("/api/settings/models")(wrap_route(list_models, context)),
|
||||||
app.route("/api/settings/noises")(wrap_route(list_noise_sources, context)),
|
app.route("/api/settings/noises")(wrap_route(list_noise_sources, context)),
|
||||||
app.route("/api/settings/params")(wrap_route(list_params, context)),
|
app.route("/api/settings/params")(wrap_route(list_params, context)),
|
||||||
app.route("/api/settings/platforms")(wrap_route(list_platforms, context)),
|
app.route("/api/settings/platforms")(wrap_route(list_platforms, context)),
|
||||||
app.route("/api/settings/schedulers")(wrap_route(list_schedulers, context)),
|
app.route("/api/settings/schedulers")(wrap_route(list_schedulers, context)),
|
||||||
app.route("/api/img2img", methods=["POST"])(wrap_route(img2img, context, pool=pool)),
|
app.route("/api/img2img", methods=["POST"])(
|
||||||
app.route("/api/txt2img", methods=["POST"])(wrap_route(txt2img, context, pool=pool)),
|
wrap_route(img2img, context, pool=pool)
|
||||||
app.route("/api/txt2txt", methods=["POST"])(wrap_route(txt2txt, context, pool=pool)),
|
),
|
||||||
app.route("/api/inpaint", methods=["POST"])(wrap_route(inpaint, context, pool=pool)),
|
app.route("/api/txt2img", methods=["POST"])(
|
||||||
app.route("/api/upscale", methods=["POST"])(wrap_route(upscale, context, pool=pool)),
|
wrap_route(txt2img, context, pool=pool)
|
||||||
app.route("/api/chain", methods=["POST"])(wrap_route(chain, context, pool=pool)),
|
),
|
||||||
app.route("/api/blend", methods=["POST"])(wrap_route(blend, context, pool=pool)),
|
app.route("/api/txt2txt", methods=["POST"])(
|
||||||
app.route("/api/cancel", methods=["PUT"])(wrap_route(cancel, context, pool=pool)),
|
wrap_route(txt2txt, context, pool=pool)
|
||||||
app.route("/api/ready")(wrap_route(ready, context, pool=pool)),
|
),
|
||||||
app.route("/api/status")(wrap_route(status, context, pool=pool)),
|
app.route("/api/inpaint", methods=["POST"])(
|
||||||
]
|
wrap_route(inpaint, context, pool=pool)
|
||||||
|
),
|
||||||
|
app.route("/api/upscale", methods=["POST"])(
|
||||||
|
wrap_route(upscale, context, pool=pool)
|
||||||
|
),
|
||||||
|
app.route("/api/chain", methods=["POST"])(
|
||||||
|
wrap_route(chain, context, pool=pool)
|
||||||
|
),
|
||||||
|
app.route("/api/blend", methods=["POST"])(
|
||||||
|
wrap_route(blend, context, pool=pool)
|
||||||
|
),
|
||||||
|
app.route("/api/cancel", methods=["PUT"])(
|
||||||
|
wrap_route(cancel, context, pool=pool)
|
||||||
|
),
|
||||||
|
app.route("/api/ready")(wrap_route(ready, context, pool=pool)),
|
||||||
|
app.route("/api/status")(wrap_route(status, context, pool=pool)),
|
||||||
|
]
|
||||||
|
|
|
@ -2,13 +2,11 @@ from functools import cmp_to_key
|
||||||
from glob import glob
|
from glob import glob
|
||||||
from logging import getLogger
|
from logging import getLogger
|
||||||
from os import path
|
from os import path
|
||||||
from typing import Dict, List, Optional, Union
|
from typing import Dict, List, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import yaml
|
import yaml
|
||||||
from onnxruntime import get_available_providers
|
|
||||||
|
|
||||||
from .context import ServerContext
|
|
||||||
from ..image import ( # mask filters; noise sources
|
from ..image import ( # mask filters; noise sources
|
||||||
mask_filter_gaussian_multiply,
|
mask_filter_gaussian_multiply,
|
||||||
mask_filter_gaussian_screen,
|
mask_filter_gaussian_screen,
|
||||||
|
@ -20,9 +18,9 @@ from ..image import ( # mask filters; noise sources
|
||||||
noise_source_normal,
|
noise_source_normal,
|
||||||
noise_source_uniform,
|
noise_source_uniform,
|
||||||
)
|
)
|
||||||
from ..params import (
|
from ..onnx.torch_before_ort import get_available_providers
|
||||||
DeviceParams,
|
from ..params import DeviceParams
|
||||||
)
|
from .context import ServerContext
|
||||||
|
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
@ -221,4 +219,3 @@ def load_platforms(context: ServerContext) -> None:
|
||||||
"available acceleration platforms: %s",
|
"available acceleration platforms: %s",
|
||||||
", ".join([str(p) for p in available_platforms]),
|
", ".join([str(p) for p in available_platforms]),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -4,30 +4,24 @@ from typing import Tuple
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from flask import request
|
from flask import request
|
||||||
|
|
||||||
from .context import ServerContext
|
|
||||||
|
|
||||||
from .config import get_available_platforms, get_config_value, get_correction_models, get_upscaling_models
|
|
||||||
from .utils import get_model_path
|
|
||||||
|
|
||||||
from ..diffusion.load import pipeline_schedulers
|
from ..diffusion.load import pipeline_schedulers
|
||||||
from ..params import (
|
from ..params import Border, DeviceParams, ImageParams, Size, UpscaleParams
|
||||||
Border,
|
from ..utils import get_and_clamp_float, get_and_clamp_int, get_from_list, get_not_empty
|
||||||
DeviceParams,
|
from .config import (
|
||||||
ImageParams,
|
get_available_platforms,
|
||||||
Size,
|
get_config_value,
|
||||||
UpscaleParams,
|
get_correction_models,
|
||||||
)
|
get_upscaling_models,
|
||||||
from ..utils import (
|
|
||||||
get_and_clamp_float,
|
|
||||||
get_and_clamp_int,
|
|
||||||
get_from_list,
|
|
||||||
get_not_empty,
|
|
||||||
)
|
)
|
||||||
|
from .context import ServerContext
|
||||||
|
from .utils import get_model_path
|
||||||
|
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def pipeline_from_request(context: ServerContext) -> Tuple[DeviceParams, ImageParams, Size]:
|
def pipeline_from_request(
|
||||||
|
context: ServerContext,
|
||||||
|
) -> Tuple[DeviceParams, ImageParams, Size]:
|
||||||
user = request.remote_addr
|
user = request.remote_addr
|
||||||
|
|
||||||
# platform stuff
|
# platform stuff
|
||||||
|
@ -43,9 +37,7 @@ def pipeline_from_request(context: ServerContext) -> Tuple[DeviceParams, ImagePa
|
||||||
lpw = get_not_empty(request.args, "lpw", "false") == "true"
|
lpw = get_not_empty(request.args, "lpw", "false") == "true"
|
||||||
model = get_not_empty(request.args, "model", get_config_value("model"))
|
model = get_not_empty(request.args, "model", get_config_value("model"))
|
||||||
model_path = get_model_path(context, model)
|
model_path = get_model_path(context, model)
|
||||||
scheduler = get_from_list(
|
scheduler = get_from_list(request.args, "scheduler", pipeline_schedulers.keys())
|
||||||
request.args, "scheduler", pipeline_schedulers.keys()
|
|
||||||
)
|
|
||||||
|
|
||||||
if scheduler is None:
|
if scheduler is None:
|
||||||
scheduler = get_config_value("scheduler")
|
scheduler = get_config_value("scheduler")
|
||||||
|
|
|
@ -2,9 +2,9 @@ from os import path
|
||||||
|
|
||||||
from flask import Flask, send_from_directory
|
from flask import Flask, send_from_directory
|
||||||
|
|
||||||
from .utils import wrap_route
|
|
||||||
from .context import ServerContext
|
|
||||||
from ..worker.pool import DevicePoolExecutor
|
from ..worker.pool import DevicePoolExecutor
|
||||||
|
from .context import ServerContext
|
||||||
|
from .utils import wrap_route
|
||||||
|
|
||||||
|
|
||||||
def serve_bundle_file(context: ServerContext, filename="index.html"):
|
def serve_bundle_file(context: ServerContext, filename="index.html"):
|
||||||
|
@ -26,9 +26,11 @@ def output(context: ServerContext, filename: str):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def register_static_routes(app: Flask, context: ServerContext, pool: DevicePoolExecutor):
|
def register_static_routes(
|
||||||
return [
|
app: Flask, context: ServerContext, pool: DevicePoolExecutor
|
||||||
app.route("/")(wrap_route(index, context)),
|
):
|
||||||
app.route("/<path:filename>")(wrap_route(index_path, context)),
|
return [
|
||||||
app.route("/output/<path:filename>")(wrap_route(output, context)),
|
app.route("/")(wrap_route(index, context)),
|
||||||
]
|
app.route("/<path:filename>")(wrap_route(index_path, context)),
|
||||||
|
app.route("/output/<path:filename>")(wrap_route(output, context)),
|
||||||
|
]
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
|
from functools import partial, update_wrapper
|
||||||
from os import makedirs, path
|
from os import makedirs, path
|
||||||
from typing import Callable, Dict, List, Tuple
|
from typing import Callable, Dict, List, Tuple
|
||||||
from functools import partial, update_wrapper
|
|
||||||
|
|
||||||
from flask import Flask
|
from flask import Flask
|
||||||
|
|
||||||
|
@ -22,7 +22,12 @@ def get_model_path(context: ServerContext, model: str):
|
||||||
return base_join(context.model_path, model)
|
return base_join(context.model_path, model)
|
||||||
|
|
||||||
|
|
||||||
def register_routes(app: Flask, context: ServerContext, pool: DevicePoolExecutor, routes: List[Tuple[str, Dict, Callable]]):
|
def register_routes(
|
||||||
|
app: Flask,
|
||||||
|
context: ServerContext,
|
||||||
|
pool: DevicePoolExecutor,
|
||||||
|
routes: List[Tuple[str, Dict, Callable]],
|
||||||
|
):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -11,7 +11,7 @@ from .chain import (
|
||||||
)
|
)
|
||||||
from .params import ImageParams, SizeChart, StageParams, UpscaleParams
|
from .params import ImageParams, SizeChart, StageParams, UpscaleParams
|
||||||
from .server import ServerContext
|
from .server import ServerContext
|
||||||
from .worker import WorkerContext, ProgressCallback
|
from .worker import ProgressCallback, WorkerContext
|
||||||
|
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
|
|
@ -1,2 +1,2 @@
|
||||||
from .context import WorkerContext, ProgressCallback
|
from .context import WorkerContext, ProgressCallback
|
||||||
from .pool import DevicePoolExecutor
|
from .pool import DevicePoolExecutor
|
||||||
|
|
|
@ -1,7 +1,8 @@
|
||||||
from logging import getLogger
|
from logging import getLogger
|
||||||
from torch.multiprocessing import Queue, Value
|
|
||||||
from typing import Any, Callable, Tuple
|
from typing import Any, Callable, Tuple
|
||||||
|
|
||||||
|
from torch.multiprocessing import Queue, Value
|
||||||
|
|
||||||
from ..params import DeviceParams
|
from ..params import DeviceParams
|
||||||
|
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
@ -9,6 +10,7 @@ logger = getLogger(__name__)
|
||||||
|
|
||||||
ProgressCallback = Callable[[int, int, Any], None]
|
ProgressCallback = Callable[[int, int, Any], None]
|
||||||
|
|
||||||
|
|
||||||
class WorkerContext:
|
class WorkerContext:
|
||||||
cancel: "Value[bool]" = None
|
cancel: "Value[bool]" = None
|
||||||
key: str = None
|
key: str = None
|
||||||
|
|
|
@ -1 +1 @@
|
||||||
# TODO: queue-based logger
|
# TODO: queue-based logger
|
||||||
|
|
|
@ -1,9 +1,10 @@
|
||||||
from collections import Counter
|
from collections import Counter
|
||||||
from logging import getLogger
|
from logging import getLogger
|
||||||
from multiprocessing import Queue
|
from multiprocessing import Queue
|
||||||
from torch.multiprocessing import Lock, Process, Value
|
|
||||||
from typing import Callable, Dict, List, Optional, Tuple
|
from typing import Callable, Dict, List, Optional, Tuple
|
||||||
|
|
||||||
|
from torch.multiprocessing import Lock, Process, Value
|
||||||
|
|
||||||
from ..params import DeviceParams
|
from ..params import DeviceParams
|
||||||
from ..server import ServerContext
|
from ..server import ServerContext
|
||||||
from .context import WorkerContext
|
from .context import WorkerContext
|
||||||
|
@ -35,7 +36,7 @@ class DevicePoolExecutor:
|
||||||
self.pending = {}
|
self.pending = {}
|
||||||
self.progress = {}
|
self.progress = {}
|
||||||
self.workers = {}
|
self.workers = {}
|
||||||
self.jobs = {} # Dict[Output, Device]
|
self.jobs = {} # Dict[Output, Device]
|
||||||
self.job_count = 0
|
self.job_count = 0
|
||||||
|
|
||||||
# TODO: make this a method
|
# TODO: make this a method
|
||||||
|
@ -58,15 +59,21 @@ class DevicePoolExecutor:
|
||||||
cancel = Value("B", False, lock=lock)
|
cancel = Value("B", False, lock=lock)
|
||||||
finished = Value("B", False)
|
finished = Value("B", False)
|
||||||
self.finished[name] = finished
|
self.finished[name] = finished
|
||||||
progress = Value("I", 0) # , lock=lock) # needs its own lock for some reason. TODO: why?
|
progress = Value(
|
||||||
|
"I", 0
|
||||||
|
) # , lock=lock) # needs its own lock for some reason. TODO: why?
|
||||||
self.progress[name] = progress
|
self.progress[name] = progress
|
||||||
pending = Queue()
|
pending = Queue()
|
||||||
self.pending[name] = pending
|
self.pending[name] = pending
|
||||||
context = WorkerContext(name, cancel, device, pending, progress, self.log_queue, finished)
|
context = WorkerContext(
|
||||||
|
name, cancel, device, pending, progress, self.log_queue, finished
|
||||||
|
)
|
||||||
self.context[name] = context
|
self.context[name] = context
|
||||||
|
|
||||||
logger.debug("starting worker for device %s", device)
|
logger.debug("starting worker for device %s", device)
|
||||||
self.workers[name] = Process(target=worker_init, args=(lock, context, server))
|
self.workers[name] = Process(
|
||||||
|
target=worker_init, args=(lock, context, server)
|
||||||
|
)
|
||||||
self.workers[name].start()
|
self.workers[name].start()
|
||||||
|
|
||||||
def cancel(self, key: str) -> bool:
|
def cancel(self, key: str) -> bool:
|
||||||
|
@ -78,7 +85,7 @@ class DevicePoolExecutor:
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
def done(self, key: str) -> Tuple[Optional[bool], int]:
|
def done(self, key: str) -> Tuple[Optional[bool], int]:
|
||||||
if not key in self.jobs:
|
if key not in self.jobs:
|
||||||
logger.warn("checking status for unknown key: %s", key)
|
logger.warn("checking status for unknown key: %s", key)
|
||||||
return (None, 0)
|
return (None, 0)
|
||||||
|
|
||||||
|
@ -88,7 +95,6 @@ class DevicePoolExecutor:
|
||||||
|
|
||||||
return (finished.value, progress.value)
|
return (finished.value, progress.value)
|
||||||
|
|
||||||
|
|
||||||
def get_next_device(self, needs_device: Optional[DeviceParams] = None) -> int:
|
def get_next_device(self, needs_device: Optional[DeviceParams] = None) -> int:
|
||||||
# respect overrides if possible
|
# respect overrides if possible
|
||||||
if needs_device is not None:
|
if needs_device is not None:
|
||||||
|
@ -96,9 +102,7 @@ class DevicePoolExecutor:
|
||||||
if self.devices[i].device == needs_device.device:
|
if self.devices[i].device == needs_device.device:
|
||||||
return i
|
return i
|
||||||
|
|
||||||
pending = [
|
pending = [self.pending[d.device].qsize() for d in self.devices]
|
||||||
self.pending[d.device].qsize() for d in self.devices
|
|
||||||
]
|
|
||||||
jobs = Counter(range(len(self.devices)))
|
jobs = Counter(range(len(self.devices)))
|
||||||
jobs.update(pending)
|
jobs.update(pending)
|
||||||
|
|
||||||
|
@ -128,7 +132,7 @@ class DevicePoolExecutor:
|
||||||
finished_count - self.finished_limit,
|
finished_count - self.finished_limit,
|
||||||
finished_count,
|
finished_count,
|
||||||
)
|
)
|
||||||
self.finished[:] = self.finished[-self.finished_limit:]
|
self.finished[:] = self.finished[-self.finished_limit :]
|
||||||
|
|
||||||
def recycle(self):
|
def recycle(self):
|
||||||
for name, proc in self.workers.items():
|
for name, proc in self.workers.items():
|
||||||
|
@ -149,10 +153,11 @@ class DevicePoolExecutor:
|
||||||
lock = self.locks[name]
|
lock = self.locks[name]
|
||||||
|
|
||||||
logger.debug("starting worker for device %s", name)
|
logger.debug("starting worker for device %s", name)
|
||||||
self.workers[name] = Process(target=worker_init, args=(lock, context, self.server))
|
self.workers[name] = Process(
|
||||||
|
target=worker_init, args=(lock, context, self.server)
|
||||||
|
)
|
||||||
self.workers[name].start()
|
self.workers[name].start()
|
||||||
|
|
||||||
|
|
||||||
def submit(
|
def submit(
|
||||||
self,
|
self,
|
||||||
key: str,
|
key: str,
|
||||||
|
@ -171,7 +176,10 @@ class DevicePoolExecutor:
|
||||||
self.prune()
|
self.prune()
|
||||||
device_idx = self.get_next_device(needs_device=needs_device)
|
device_idx = self.get_next_device(needs_device=needs_device)
|
||||||
logger.info(
|
logger.info(
|
||||||
"assigning job %s to device %s: %s", key, device_idx, self.devices[device_idx]
|
"assigning job %s to device %s: %s",
|
||||||
|
key,
|
||||||
|
device_idx,
|
||||||
|
self.devices[device_idx],
|
||||||
)
|
)
|
||||||
|
|
||||||
device = self.devices[device_idx]
|
device = self.devices[device_idx]
|
||||||
|
@ -180,7 +188,6 @@ class DevicePoolExecutor:
|
||||||
|
|
||||||
self.jobs[key] = device.device
|
self.jobs[key] = device.device
|
||||||
|
|
||||||
|
|
||||||
def status(self) -> List[Tuple[str, int, bool, int]]:
|
def status(self) -> List[Tuple[str, int, bool, int]]:
|
||||||
pending = [
|
pending = [
|
||||||
(
|
(
|
||||||
|
|
|
@ -1,12 +1,12 @@
|
||||||
from logging import getLogger
|
from logging import getLogger
|
||||||
import torch # has to come before ORT
|
|
||||||
from onnxruntime import get_available_providers
|
|
||||||
from torch.multiprocessing import Lock, Queue
|
|
||||||
from traceback import format_exception
|
from traceback import format_exception
|
||||||
from setproctitle import setproctitle
|
|
||||||
|
|
||||||
from .context import WorkerContext
|
from setproctitle import setproctitle
|
||||||
|
from torch.multiprocessing import Lock, Queue
|
||||||
|
|
||||||
|
from ..onnx.torch_before_ort import get_available_providers
|
||||||
from ..server import ServerContext, apply_patches
|
from ..server import ServerContext, apply_patches
|
||||||
|
from .context import WorkerContext
|
||||||
|
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
@ -29,7 +29,7 @@ def worker_init(lock: Lock, context: WorkerContext, server: ServerContext):
|
||||||
logger.info("checking in from worker, %s, %s", lock, get_available_providers())
|
logger.info("checking in from worker, %s, %s", lock, get_available_providers())
|
||||||
|
|
||||||
apply_patches(server)
|
apply_patches(server)
|
||||||
setproctitle("onnx-web worker: %s", context.device.device)
|
setproctitle("onnx-web worker: %s" % (context.device.device))
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
job = context.pending.get()
|
job = context.pending.get()
|
||||||
|
@ -52,4 +52,3 @@ def worker_init(lock: Lock, context: WorkerContext, server: ServerContext):
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(format_exception(type(e), e, e.__traceback__))
|
logger.error(format_exception(type(e), e, e.__traceback__))
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue