1
0
Fork 0

lint all the new stuff

This commit is contained in:
Sean Sube 2023-02-26 14:15:30 -06:00
parent b880b7a121
commit 584dddb5d6
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
34 changed files with 183 additions and 174 deletions

View File

@ -57,4 +57,4 @@ from .utils import (
)
from .worker import (
DevicePoolExecutor,
)
)

View File

@ -29,4 +29,4 @@ CHAIN_STAGES = {
"upscale-outpaint": upscale_outpaint,
"upscale-resrgan": upscale_resrgan,
"upscale-stable-diffusion": upscale_stable_diffusion,
}
}

View File

@ -7,9 +7,9 @@ from PIL import Image
from ..output import save_image
from ..params import ImageParams, StageParams
from ..worker import WorkerContext, ProgressCallback
from ..server import ServerContext
from ..utils import is_debug
from ..worker import ProgressCallback, WorkerContext
from .utils import process_tile_order
logger = getLogger(__name__)

View File

@ -7,8 +7,8 @@ from PIL import Image
from ..diffusion.load import load_pipeline
from ..params import ImageParams, StageParams
from ..worker import WorkerContext, ProgressCallback
from ..server import ServerContext
from ..worker import ProgressCallback, WorkerContext
logger = getLogger(__name__)

View File

@ -10,9 +10,9 @@ from ..diffusion.load import get_latents_from_seed, load_pipeline
from ..image import expand_image, mask_filter_none, noise_source_histogram
from ..output import save_image
from ..params import Border, ImageParams, Size, SizeChart, StageParams
from ..worker import WorkerContext, ProgressCallback
from ..server import ServerContext
from ..utils import is_debug
from ..worker import ProgressCallback, WorkerContext
from .utils import process_tile_order
logger = getLogger(__name__)

View File

@ -7,9 +7,9 @@ from onnx_web.image import valid_image
from onnx_web.output import save_image
from ..params import ImageParams, StageParams
from ..worker import WorkerContext, ProgressCallback
from ..server import ServerContext
from ..utils import is_debug
from ..worker import ProgressCallback, WorkerContext
logger = getLogger(__name__)

View File

@ -3,8 +3,8 @@ from logging import getLogger
from PIL import Image
from ..params import ImageParams, StageParams, UpscaleParams
from ..worker import WorkerContext
from ..server import ServerContext
from ..worker import WorkerContext
logger = getLogger(__name__)

View File

@ -5,10 +5,9 @@ import numpy as np
from PIL import Image
from ..params import DeviceParams, ImageParams, StageParams, UpscaleParams
from ..server import ServerContext
from ..utils import run_gc
from ..worker import WorkerContext
from ..server import ServerContext
logger = getLogger(__name__)

View File

@ -4,9 +4,8 @@ from PIL import Image
from ..output import save_image
from ..params import ImageParams, StageParams
from ..worker import WorkerContext
from ..server import ServerContext
from ..worker import WorkerContext
logger = getLogger(__name__)

View File

@ -5,9 +5,8 @@ from boto3 import Session
from PIL import Image
from ..params import ImageParams, StageParams
from ..worker import WorkerContext
from ..server import ServerContext
from ..worker import WorkerContext
logger = getLogger(__name__)

View File

@ -3,9 +3,8 @@ from logging import getLogger
from PIL import Image
from ..params import ImageParams, Size, StageParams
from ..worker import WorkerContext
from ..server import ServerContext
from ..worker import WorkerContext
logger = getLogger(__name__)

View File

@ -3,9 +3,8 @@ from logging import getLogger
from PIL import Image
from ..params import ImageParams, Size, StageParams
from ..worker import WorkerContext
from ..server import ServerContext
from ..worker import WorkerContext
logger = getLogger(__name__)

View File

@ -4,9 +4,8 @@ from typing import Callable
from PIL import Image
from ..params import ImageParams, Size, StageParams
from ..worker import WorkerContext
from ..server import ServerContext
from ..worker import WorkerContext
logger = getLogger(__name__)

View File

@ -7,9 +7,8 @@ from PIL import Image
from ..diffusion.load import get_latents_from_seed, load_pipeline
from ..params import ImageParams, Size, StageParams
from ..worker import WorkerContext, ProgressCallback
from ..server import ServerContext
from ..worker import ProgressCallback, WorkerContext
logger = getLogger(__name__)

View File

@ -10,9 +10,9 @@ from ..diffusion.load import get_latents_from_seed, get_tile_latents, load_pipel
from ..image import expand_image, mask_filter_none, noise_source_histogram
from ..output import save_image
from ..params import Border, ImageParams, Size, SizeChart, StageParams
from ..worker import WorkerContext, ProgressCallback
from ..server import ServerContext
from ..utils import is_debug
from ..worker import ProgressCallback, WorkerContext
from .utils import process_tile_grid, process_tile_order
logger = getLogger(__name__)

View File

@ -6,9 +6,9 @@ from PIL import Image
from ..onnx import OnnxNet
from ..params import DeviceParams, ImageParams, StageParams, UpscaleParams
from ..worker import WorkerContext
from ..server import ServerContext
from ..utils import run_gc
from ..worker import WorkerContext
logger = getLogger(__name__)

View File

@ -10,9 +10,9 @@ from ..diffusion.pipeline_onnx_stable_diffusion_upscale import (
OnnxStableDiffusionUpscalePipeline,
)
from ..params import DeviceParams, ImageParams, StageParams, UpscaleParams
from ..worker import WorkerContext, ProgressCallback
from ..server import ServerContext
from ..utils import run_gc
from ..worker import ProgressCallback, WorkerContext
logger = getLogger(__name__)

View File

@ -4,11 +4,9 @@ from typing import Any, Optional, Tuple
import numpy as np
from diffusers import (
DiffusionPipeline,
OnnxRuntimeModel,
StableDiffusionPipeline,
DDIMScheduler,
DDPMScheduler,
DiffusionPipeline,
DPMSolverMultistepScheduler,
DPMSolverSinglestepScheduler,
EulerAncestralDiscreteScheduler,
@ -19,7 +17,9 @@ from diffusers import (
KDPM2AncestralDiscreteScheduler,
KDPM2DiscreteScheduler,
LMSDiscreteScheduler,
OnnxRuntimeModel,
PNDMScheduler,
StableDiffusionPipeline,
)
try:

View File

@ -12,10 +12,10 @@ from onnx_web.chain.base import ChainProgress
from ..chain import upscale_outpaint
from ..output import save_image, save_params
from ..params import Border, ImageParams, Size, StageParams, UpscaleParams
from ..worker import WorkerContext
from ..server import ServerContext
from ..upscale import run_upscale_correction
from ..utils import run_gc
from ..worker import WorkerContext
from .load import get_latents_from_seed, load_pipeline
logger = getLogger(__name__)

View File

@ -7,49 +7,53 @@ from huggingface_hub.utils.tqdm import disable_progress_bars
from torch.multiprocessing import set_start_method
from .server.api import register_api_routes
from .server.static import register_static_routes
from .server.config import get_available_platforms, load_models, load_params, load_platforms
from .server.utils import check_paths
from .server.config import (
get_available_platforms,
load_models,
load_params,
load_platforms,
)
from .server.context import ServerContext
from .server.hacks import apply_patches
from .utils import (
is_debug,
)
from .server.static import register_static_routes
from .server.utils import check_paths
from .utils import is_debug
from .worker import DevicePoolExecutor
def main():
set_start_method("spawn", force=True)
set_start_method("spawn", force=True)
context = ServerContext.from_environ()
apply_patches(context)
check_paths(context)
load_models(context)
load_params(context)
load_platforms(context)
context = ServerContext.from_environ()
apply_patches(context)
check_paths(context)
load_models(context)
load_params(context)
load_platforms(context)
if is_debug():
gc.set_debug(gc.DEBUG_STATS)
if is_debug():
gc.set_debug(gc.DEBUG_STATS)
if not context.show_progress:
disable_progress_bar()
disable_progress_bars()
if not context.show_progress:
disable_progress_bar()
disable_progress_bars()
app = Flask(__name__)
CORS(app, origins=context.cors_origin)
app = Flask(__name__)
CORS(app, origins=context.cors_origin)
# any is a fake device, should not be in the pool
pool = DevicePoolExecutor(context, [p for p in get_available_platforms() if p.device != "any"])
# any is a fake device, should not be in the pool
pool = DevicePoolExecutor(
context, [p for p in get_available_platforms() if p.device != "any"]
)
# register routes
register_static_routes(app, context, pool)
register_api_routes(app, context, pool)
# register routes
register_static_routes(app, context, pool)
register_api_routes(app, context, pool)
return app, pool
return app, pool
if __name__ == "__main__":
app, pool = main()
app.run("0.0.0.0", 5000, debug=is_debug())
pool.join()
app, pool = main()
app.run("0.0.0.0", 5000, debug=is_debug())
pool.join()

View File

@ -3,9 +3,9 @@ from typing import Any, Optional
import numpy as np
import torch
from onnxruntime import InferenceSession, SessionOptions
from ..server import ServerContext
from .torch_before_ort import InferenceSession, SessionOptions
class OnnxTensor:

View File

@ -0,0 +1,5 @@
# this file exists to make sure torch is always imported before onnxruntime
# to work around https://github.com/microsoft/onnxruntime/issues/11092
import torch # NOQA
from onnxruntime import * # NOQA

View File

@ -2,8 +2,7 @@ from enum import IntEnum
from logging import getLogger
from typing import Any, Dict, List, Literal, Optional, Tuple, Union
import torch
from onnxruntime import GraphOptimizationLevel, SessionOptions
from .onnx.torch_before_ort import GraphOptimizationLevel, SessionOptions
logger = getLogger(__name__)

View File

@ -7,27 +7,7 @@ from flask import Flask, jsonify, make_response, request, url_for
from jsonschema import validate
from PIL import Image
from .context import ServerContext
from .utils import wrap_route
from ..worker.pool import DevicePoolExecutor
from .config import (
get_available_platforms,
get_config_params,
get_config_value,
get_correction_models,
get_diffusion_models,
get_inversion_models,
get_mask_filters,
get_noise_sources,
get_upscaling_models,
)
from .params import border_from_request, pipeline_from_request, upscale_from_request
from ..chain import (
CHAIN_STAGES,
ChainPipeline,
)
from ..chain import CHAIN_STAGES, ChainPipeline
from ..diffusion.load import get_pipeline_schedulers
from ..diffusion.run import (
run_blend_pipeline,
@ -36,16 +16,9 @@ from ..diffusion.run import (
run_txt2img_pipeline,
run_upscale_pipeline,
)
from ..image import ( # mask filters; noise sources
valid_image,
)
from ..image import valid_image # mask filters; noise sources
from ..output import json_params, make_output_name
from ..params import (
Border,
StageParams,
TileOrder,
UpscaleParams,
)
from ..params import Border, StageParams, TileOrder, UpscaleParams
from ..transformers import run_txt2txt_pipeline
from ..utils import (
base_join,
@ -56,6 +29,21 @@ from ..utils import (
get_not_empty,
get_size,
)
from ..worker.pool import DevicePoolExecutor
from .config import (
get_available_platforms,
get_config_params,
get_config_value,
get_correction_models,
get_diffusion_models,
get_inversion_models,
get_mask_filters,
get_noise_sources,
get_upscaling_models,
)
from .context import ServerContext
from .params import border_from_request, pipeline_from_request, upscale_from_request
from .utils import wrap_route
logger = getLogger(__name__)
@ -456,22 +444,38 @@ def status(context: ServerContext, pool: DevicePoolExecutor):
def register_api_routes(app: Flask, context: ServerContext, pool: DevicePoolExecutor):
return [
app.route("/api")(wrap_route(introspect, context, app=app)),
app.route("/api/settings/masks")(wrap_route(list_mask_filters, context)),
app.route("/api/settings/models")(wrap_route(list_models, context)),
app.route("/api/settings/noises")(wrap_route(list_noise_sources, context)),
app.route("/api/settings/params")(wrap_route(list_params, context)),
app.route("/api/settings/platforms")(wrap_route(list_platforms, context)),
app.route("/api/settings/schedulers")(wrap_route(list_schedulers, context)),
app.route("/api/img2img", methods=["POST"])(wrap_route(img2img, context, pool=pool)),
app.route("/api/txt2img", methods=["POST"])(wrap_route(txt2img, context, pool=pool)),
app.route("/api/txt2txt", methods=["POST"])(wrap_route(txt2txt, context, pool=pool)),
app.route("/api/inpaint", methods=["POST"])(wrap_route(inpaint, context, pool=pool)),
app.route("/api/upscale", methods=["POST"])(wrap_route(upscale, context, pool=pool)),
app.route("/api/chain", methods=["POST"])(wrap_route(chain, context, pool=pool)),
app.route("/api/blend", methods=["POST"])(wrap_route(blend, context, pool=pool)),
app.route("/api/cancel", methods=["PUT"])(wrap_route(cancel, context, pool=pool)),
app.route("/api/ready")(wrap_route(ready, context, pool=pool)),
app.route("/api/status")(wrap_route(status, context, pool=pool)),
]
return [
app.route("/api")(wrap_route(introspect, context, app=app)),
app.route("/api/settings/masks")(wrap_route(list_mask_filters, context)),
app.route("/api/settings/models")(wrap_route(list_models, context)),
app.route("/api/settings/noises")(wrap_route(list_noise_sources, context)),
app.route("/api/settings/params")(wrap_route(list_params, context)),
app.route("/api/settings/platforms")(wrap_route(list_platforms, context)),
app.route("/api/settings/schedulers")(wrap_route(list_schedulers, context)),
app.route("/api/img2img", methods=["POST"])(
wrap_route(img2img, context, pool=pool)
),
app.route("/api/txt2img", methods=["POST"])(
wrap_route(txt2img, context, pool=pool)
),
app.route("/api/txt2txt", methods=["POST"])(
wrap_route(txt2txt, context, pool=pool)
),
app.route("/api/inpaint", methods=["POST"])(
wrap_route(inpaint, context, pool=pool)
),
app.route("/api/upscale", methods=["POST"])(
wrap_route(upscale, context, pool=pool)
),
app.route("/api/chain", methods=["POST"])(
wrap_route(chain, context, pool=pool)
),
app.route("/api/blend", methods=["POST"])(
wrap_route(blend, context, pool=pool)
),
app.route("/api/cancel", methods=["PUT"])(
wrap_route(cancel, context, pool=pool)
),
app.route("/api/ready")(wrap_route(ready, context, pool=pool)),
app.route("/api/status")(wrap_route(status, context, pool=pool)),
]

View File

@ -2,13 +2,11 @@ from functools import cmp_to_key
from glob import glob
from logging import getLogger
from os import path
from typing import Dict, List, Optional, Union
from typing import Dict, List, Union
import torch
import yaml
from onnxruntime import get_available_providers
from .context import ServerContext
from ..image import ( # mask filters; noise sources
mask_filter_gaussian_multiply,
mask_filter_gaussian_screen,
@ -20,9 +18,9 @@ from ..image import ( # mask filters; noise sources
noise_source_normal,
noise_source_uniform,
)
from ..params import (
DeviceParams,
)
from ..onnx.torch_before_ort import get_available_providers
from ..params import DeviceParams
from .context import ServerContext
logger = getLogger(__name__)
@ -221,4 +219,3 @@ def load_platforms(context: ServerContext) -> None:
"available acceleration platforms: %s",
", ".join([str(p) for p in available_platforms]),
)

View File

@ -4,30 +4,24 @@ from typing import Tuple
import numpy as np
from flask import request
from .context import ServerContext
from .config import get_available_platforms, get_config_value, get_correction_models, get_upscaling_models
from .utils import get_model_path
from ..diffusion.load import pipeline_schedulers
from ..params import (
Border,
DeviceParams,
ImageParams,
Size,
UpscaleParams,
)
from ..utils import (
get_and_clamp_float,
get_and_clamp_int,
get_from_list,
get_not_empty,
from ..params import Border, DeviceParams, ImageParams, Size, UpscaleParams
from ..utils import get_and_clamp_float, get_and_clamp_int, get_from_list, get_not_empty
from .config import (
get_available_platforms,
get_config_value,
get_correction_models,
get_upscaling_models,
)
from .context import ServerContext
from .utils import get_model_path
logger = getLogger(__name__)
def pipeline_from_request(context: ServerContext) -> Tuple[DeviceParams, ImageParams, Size]:
def pipeline_from_request(
context: ServerContext,
) -> Tuple[DeviceParams, ImageParams, Size]:
user = request.remote_addr
# platform stuff
@ -43,9 +37,7 @@ def pipeline_from_request(context: ServerContext) -> Tuple[DeviceParams, ImagePa
lpw = get_not_empty(request.args, "lpw", "false") == "true"
model = get_not_empty(request.args, "model", get_config_value("model"))
model_path = get_model_path(context, model)
scheduler = get_from_list(
request.args, "scheduler", pipeline_schedulers.keys()
)
scheduler = get_from_list(request.args, "scheduler", pipeline_schedulers.keys())
if scheduler is None:
scheduler = get_config_value("scheduler")

View File

@ -2,9 +2,9 @@ from os import path
from flask import Flask, send_from_directory
from .utils import wrap_route
from .context import ServerContext
from ..worker.pool import DevicePoolExecutor
from .context import ServerContext
from .utils import wrap_route
def serve_bundle_file(context: ServerContext, filename="index.html"):
@ -26,9 +26,11 @@ def output(context: ServerContext, filename: str):
)
def register_static_routes(app: Flask, context: ServerContext, pool: DevicePoolExecutor):
return [
app.route("/")(wrap_route(index, context)),
app.route("/<path:filename>")(wrap_route(index_path, context)),
app.route("/output/<path:filename>")(wrap_route(output, context)),
]
def register_static_routes(
app: Flask, context: ServerContext, pool: DevicePoolExecutor
):
return [
app.route("/")(wrap_route(index, context)),
app.route("/<path:filename>")(wrap_route(index_path, context)),
app.route("/output/<path:filename>")(wrap_route(output, context)),
]

View File

@ -1,6 +1,6 @@
from functools import partial, update_wrapper
from os import makedirs, path
from typing import Callable, Dict, List, Tuple
from functools import partial, update_wrapper
from flask import Flask
@ -22,7 +22,12 @@ def get_model_path(context: ServerContext, model: str):
return base_join(context.model_path, model)
def register_routes(app: Flask, context: ServerContext, pool: DevicePoolExecutor, routes: List[Tuple[str, Dict, Callable]]):
def register_routes(
app: Flask,
context: ServerContext,
pool: DevicePoolExecutor,
routes: List[Tuple[str, Dict, Callable]],
):
pass

View File

@ -11,7 +11,7 @@ from .chain import (
)
from .params import ImageParams, SizeChart, StageParams, UpscaleParams
from .server import ServerContext
from .worker import WorkerContext, ProgressCallback
from .worker import ProgressCallback, WorkerContext
logger = getLogger(__name__)

View File

@ -1,2 +1,2 @@
from .context import WorkerContext, ProgressCallback
from .pool import DevicePoolExecutor
from .pool import DevicePoolExecutor

View File

@ -1,7 +1,8 @@
from logging import getLogger
from torch.multiprocessing import Queue, Value
from typing import Any, Callable, Tuple
from torch.multiprocessing import Queue, Value
from ..params import DeviceParams
logger = getLogger(__name__)
@ -9,6 +10,7 @@ logger = getLogger(__name__)
ProgressCallback = Callable[[int, int, Any], None]
class WorkerContext:
cancel: "Value[bool]" = None
key: str = None

View File

@ -1 +1 @@
# TODO: queue-based logger
# TODO: queue-based logger

View File

@ -1,9 +1,10 @@
from collections import Counter
from logging import getLogger
from multiprocessing import Queue
from torch.multiprocessing import Lock, Process, Value
from typing import Callable, Dict, List, Optional, Tuple
from torch.multiprocessing import Lock, Process, Value
from ..params import DeviceParams
from ..server import ServerContext
from .context import WorkerContext
@ -35,7 +36,7 @@ class DevicePoolExecutor:
self.pending = {}
self.progress = {}
self.workers = {}
self.jobs = {} # Dict[Output, Device]
self.jobs = {} # Dict[Output, Device]
self.job_count = 0
# TODO: make this a method
@ -58,15 +59,21 @@ class DevicePoolExecutor:
cancel = Value("B", False, lock=lock)
finished = Value("B", False)
self.finished[name] = finished
progress = Value("I", 0) # , lock=lock) # needs its own lock for some reason. TODO: why?
progress = Value(
"I", 0
) # , lock=lock) # needs its own lock for some reason. TODO: why?
self.progress[name] = progress
pending = Queue()
self.pending[name] = pending
context = WorkerContext(name, cancel, device, pending, progress, self.log_queue, finished)
context = WorkerContext(
name, cancel, device, pending, progress, self.log_queue, finished
)
self.context[name] = context
logger.debug("starting worker for device %s", device)
self.workers[name] = Process(target=worker_init, args=(lock, context, server))
self.workers[name] = Process(
target=worker_init, args=(lock, context, server)
)
self.workers[name].start()
def cancel(self, key: str) -> bool:
@ -78,7 +85,7 @@ class DevicePoolExecutor:
raise NotImplementedError()
def done(self, key: str) -> Tuple[Optional[bool], int]:
if not key in self.jobs:
if key not in self.jobs:
logger.warn("checking status for unknown key: %s", key)
return (None, 0)
@ -88,7 +95,6 @@ class DevicePoolExecutor:
return (finished.value, progress.value)
def get_next_device(self, needs_device: Optional[DeviceParams] = None) -> int:
# respect overrides if possible
if needs_device is not None:
@ -96,9 +102,7 @@ class DevicePoolExecutor:
if self.devices[i].device == needs_device.device:
return i
pending = [
self.pending[d.device].qsize() for d in self.devices
]
pending = [self.pending[d.device].qsize() for d in self.devices]
jobs = Counter(range(len(self.devices)))
jobs.update(pending)
@ -128,7 +132,7 @@ class DevicePoolExecutor:
finished_count - self.finished_limit,
finished_count,
)
self.finished[:] = self.finished[-self.finished_limit:]
self.finished[:] = self.finished[-self.finished_limit :]
def recycle(self):
for name, proc in self.workers.items():
@ -149,10 +153,11 @@ class DevicePoolExecutor:
lock = self.locks[name]
logger.debug("starting worker for device %s", name)
self.workers[name] = Process(target=worker_init, args=(lock, context, self.server))
self.workers[name] = Process(
target=worker_init, args=(lock, context, self.server)
)
self.workers[name].start()
def submit(
self,
key: str,
@ -171,7 +176,10 @@ class DevicePoolExecutor:
self.prune()
device_idx = self.get_next_device(needs_device=needs_device)
logger.info(
"assigning job %s to device %s: %s", key, device_idx, self.devices[device_idx]
"assigning job %s to device %s: %s",
key,
device_idx,
self.devices[device_idx],
)
device = self.devices[device_idx]
@ -180,7 +188,6 @@ class DevicePoolExecutor:
self.jobs[key] = device.device
def status(self) -> List[Tuple[str, int, bool, int]]:
pending = [
(

View File

@ -1,12 +1,12 @@
from logging import getLogger
import torch # has to come before ORT
from onnxruntime import get_available_providers
from torch.multiprocessing import Lock, Queue
from traceback import format_exception
from setproctitle import setproctitle
from .context import WorkerContext
from setproctitle import setproctitle
from torch.multiprocessing import Lock, Queue
from ..onnx.torch_before_ort import get_available_providers
from ..server import ServerContext, apply_patches
from .context import WorkerContext
logger = getLogger(__name__)
@ -29,7 +29,7 @@ def worker_init(lock: Lock, context: WorkerContext, server: ServerContext):
logger.info("checking in from worker, %s, %s", lock, get_available_providers())
apply_patches(server)
setproctitle("onnx-web worker: %s", context.device.device)
setproctitle("onnx-web worker: %s" % (context.device.device))
while True:
job = context.pending.get()
@ -52,4 +52,3 @@ def worker_init(lock: Lock, context: WorkerContext, server: ServerContext):
except Exception as e:
logger.error(format_exception(type(e), e, e.__traceback__))