1
0
Fork 0

apply lint, add missing file

This commit is contained in:
Sean Sube 2023-11-18 18:13:13 -06:00
parent a63669c76b
commit 7e6749e0d7
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
32 changed files with 176 additions and 84 deletions

View File

@ -1,2 +1,2 @@
from .pipeline import ChainPipeline, PipelineStage, StageParams from .pipeline import ChainPipeline, PipelineStage, StageParams
from .stages import * from .stages import * # NOQA

View File

@ -1,11 +1,11 @@
from typing import List, Optional from typing import Optional
from PIL import Image from PIL import Image
from .result import StageResult
from ..params import ImageParams, Size, SizeChart, StageParams from ..params import ImageParams, Size, SizeChart, StageParams
from ..server.context import ServerContext from ..server.context import ServerContext
from ..worker.context import WorkerContext from ..worker.context import WorkerContext
from .result import StageResult
class BaseStage: class BaseStage:

View File

@ -1,8 +1,7 @@
from logging import getLogger from logging import getLogger
from typing import List, Optional from typing import Optional
import cv2 import cv2
import numpy as np
from PIL import Image from PIL import Image
from ..params import ImageParams, SizeChart, StageParams from ..params import ImageParams, SizeChart, StageParams

View File

@ -1,5 +1,5 @@
from logging import getLogger from logging import getLogger
from typing import List, Optional from typing import Optional
from PIL import Image from PIL import Image

View File

@ -1,5 +1,5 @@
from logging import getLogger from logging import getLogger
from typing import List, Optional from typing import Optional
import numpy as np import numpy as np
import torch import torch

View File

@ -1,5 +1,5 @@
from logging import getLogger from logging import getLogger
from typing import List, Optional from typing import Optional
from PIL import Image from PIL import Image
@ -28,4 +28,6 @@ class BlendLinearStage(BaseStage):
) -> StageResult: ) -> StageResult:
logger.info("blending source images using linear interpolation") logger.info("blending source images using linear interpolation")
return StageResult(images=[Image.blend(source, stage_source, alpha) for source in sources]) return StageResult(
images=[Image.blend(source, stage_source, alpha) for source in sources]
)

View File

@ -1,5 +1,5 @@
from logging import getLogger from logging import getLogger
from typing import List, Optional from typing import Optional
from PIL import Image from PIL import Image
@ -38,4 +38,8 @@ class BlendMaskStage(BaseStage):
save_image(server, "last-mask.png", stage_mask) save_image(server, "last-mask.png", stage_mask)
save_image(server, "last-mult-mask.png", mult_mask) save_image(server, "last-mult-mask.png", mult_mask)
return StageResult(images=[Image.composite(stage_source, source, mult_mask) for source in sources]) return StageResult(
images=[
Image.composite(stage_source, source, mult_mask) for source in sources
]
)

View File

@ -1,5 +1,5 @@
from logging import getLogger from logging import getLogger
from typing import List, Optional from typing import Optional
from PIL import Image from PIL import Image

View File

@ -1,8 +1,7 @@
from logging import getLogger from logging import getLogger
from os import path from os import path
from typing import List, Optional from typing import Optional
import numpy as np
from PIL import Image from PIL import Image
from ..params import DeviceParams, ImageParams, StageParams, UpscaleParams from ..params import DeviceParams, ImageParams, StageParams, UpscaleParams
@ -74,12 +73,15 @@ class CorrectGFPGANStage(BaseStage):
device = worker.get_device() device = worker.get_device()
gfpgan = self.load(server, stage, upscale, device) gfpgan = self.load(server, stage, upscale, device)
outputs = [gfpgan.enhance( outputs = [
gfpgan.enhance(
source, source,
has_aligned=False, has_aligned=False,
only_center_face=False, only_center_face=False,
paste_back=True, paste_back=True,
weight=upscale.face_strength, weight=upscale.face_strength,
) for source in sources.as_numpy()] )
for source in sources.as_numpy()
]
return StageResult(images=outputs) return StageResult(images=outputs)

View File

@ -1,11 +1,11 @@
from logging import getLogger from logging import getLogger
from typing import Optional from typing import Optional
from .pipeline import ChainPipeline
from ..chain.blend_img2img import BlendImg2ImgStage from ..chain.blend_img2img import BlendImg2ImgStage
from ..chain.upscale import stage_upscale_correction from ..chain.upscale import stage_upscale_correction
from ..chain.upscale_simple import UpscaleSimpleStage from ..chain.upscale_simple import UpscaleSimpleStage
from ..params import HighresParams, ImageParams, StageParams, UpscaleParams from ..params import HighresParams, ImageParams, StageParams, UpscaleParams
from .pipeline import ChainPipeline
logger = getLogger(__name__) logger = getLogger(__name__)

View File

@ -1,6 +1,6 @@
from io import BytesIO from io import BytesIO
from logging import getLogger from logging import getLogger
from typing import List, Optional from typing import Optional
from boto3 import Session from boto3 import Session
from PIL import Image from PIL import Image

View File

@ -12,8 +12,8 @@ from ..server import ServerContext
from ..utils import is_debug, run_gc from ..utils import is_debug, run_gc
from ..worker import ProgressCallback, WorkerContext from ..worker import ProgressCallback, WorkerContext
from .base import BaseStage from .base import BaseStage
from .tile import needs_tile, process_tile_order
from .result import StageResult from .result import StageResult
from .tile import needs_tile, process_tile_order
logger = getLogger(__name__) logger = getLogger(__name__)
@ -76,7 +76,7 @@ class ChainPipeline:
params: ImageParams, params: ImageParams,
sources: StageResult, sources: StageResult,
callback: Optional[ProgressCallback], callback: Optional[ProgressCallback],
**kwargs **kwargs,
) -> StageResult: ) -> StageResult:
result = self( result = self(
worker, server, params, sources=sources, callback=callback, **kwargs worker, server, params, sources=sources, callback=callback, **kwargs
@ -108,10 +108,10 @@ class ChainPipeline:
params: ImageParams, params: ImageParams,
sources: StageResult, sources: StageResult,
callback: Optional[ProgressCallback] = None, callback: Optional[ProgressCallback] = None,
**pipeline_kwargs **pipeline_kwargs,
) -> StageResult: ) -> StageResult:
""" """
DEPRECATED: use `run` instead DEPRECATED: use `.run()` instead
""" """
if callback is None: if callback is None:
callback = worker.get_progress_callback() callback = worker.get_progress_callback()

View File

@ -1,5 +1,5 @@
from logging import getLogger from logging import getLogger
from typing import List, Optional from typing import Optional
from PIL import Image from PIL import Image

View File

@ -1,5 +1,4 @@
from logging import getLogger from logging import getLogger
from typing import List
from PIL import Image from PIL import Image

View File

@ -1,7 +1,8 @@
from PIL import Image
from typing import List, Optional from typing import List, Optional
import numpy as np import numpy as np
from PIL import Image
class StageResult: class StageResult:
""" """
@ -10,6 +11,7 @@ class StageResult:
This class intentionally does not provide `__iter__`, to ensure clients get results in the format This class intentionally does not provide `__iter__`, to ensure clients get results in the format
they are expected. they are expected.
""" """
arrays: Optional[List[np.ndarray]] arrays: Optional[List[np.ndarray]]
images: Optional[List[Image.Image]] images: Optional[List[Image.Image]]
@ -38,4 +40,4 @@ class StageResult:
if self.images is not None: if self.images is not None:
return self.images return self.images
return [Image.fromarray(i, "RGB") for i in self.arrays] return [Image.fromarray(np.uint8(i), "RGB") for i in self.arrays]

View File

@ -1,5 +1,5 @@
from logging import getLogger from logging import getLogger
from typing import Callable, List, Optional from typing import Callable, Optional
from PIL import Image from PIL import Image

View File

@ -1,5 +1,5 @@
from logging import getLogger from logging import getLogger
from typing import List, Optional, Tuple from typing import Optional, Tuple
import numpy as np import numpy as np
import torch import torch
@ -132,7 +132,9 @@ class SourceTxt2ImgStage(BaseStage):
else: else:
# encode and record alternative prompts outside of LPW # encode and record alternative prompts outside of LPW
if params.is_panorama() or params.is_xl(): if params.is_panorama() or params.is_xl():
logger.debug("prompt alternatives are not supported for panorama or SDXL") logger.debug(
"prompt alternatives are not supported for panorama or SDXL"
)
else: else:
prompt_embeds = encode_prompt( prompt_embeds = encode_prompt(
pipe, prompt_pairs, params.batch, params.do_cfg() pipe, prompt_pairs, params.batch, params.do_cfg()

View File

@ -2,8 +2,8 @@ from logging import getLogger
from .base import BaseStage from .base import BaseStage
from .blend_denoise import BlendDenoiseStage from .blend_denoise import BlendDenoiseStage
from .blend_img2img import BlendImg2ImgStage
from .blend_grid import BlendGridStage from .blend_grid import BlendGridStage
from .blend_img2img import BlendImg2ImgStage
from .blend_linear import BlendLinearStage from .blend_linear import BlendLinearStage
from .blend_mask import BlendMaskStage from .blend_mask import BlendMaskStage
from .correct_codeformer import CorrectCodeformerStage from .correct_codeformer import CorrectCodeformerStage

View File

@ -1,6 +1,6 @@
from logging import getLogger from logging import getLogger
from os import path from os import path
from typing import List, Optional from typing import Optional
import numpy as np import numpy as np
from PIL import Image from PIL import Image

View File

@ -1,5 +1,5 @@
from logging import getLogger from logging import getLogger
from typing import List, Optional from typing import Optional
from PIL import Image from PIL import Image

View File

@ -1,5 +1,5 @@
from logging import getLogger from logging import getLogger
from typing import Callable, List, Optional, Tuple from typing import Callable, Optional, Tuple
import numpy as np import numpy as np
import torch import torch

View File

@ -1,8 +1,7 @@
from logging import getLogger from logging import getLogger
from os import path from os import path
from typing import List, Optional from typing import Optional
import numpy as np
from PIL import Image from PIL import Image
from ..onnx import OnnxRRDBNet from ..onnx import OnnxRRDBNet

View File

@ -1,5 +1,5 @@
from logging import getLogger from logging import getLogger
from typing import List, Optional from typing import Optional
from PIL import Image from PIL import Image

View File

@ -1,6 +1,6 @@
from logging import getLogger from logging import getLogger
from os import path from os import path
from typing import List, Optional from typing import Optional
import torch import torch
from PIL import Image from PIL import Image

View File

@ -1,6 +1,6 @@
from logging import getLogger from logging import getLogger
from os import path from os import path
from typing import List, Optional from typing import Optional
import numpy as np import numpy as np
from PIL import Image from PIL import Image
@ -80,12 +80,15 @@ class UpscaleSwinIRStage(BaseStage):
logger.trace("SwinIR input shape: %s", image.shape) logger.trace("SwinIR input shape: %s", image.shape)
scale = upscale.outscale scale = upscale.outscale
logger.trace("SwinIR output shape: %s", ( logger.trace(
"SwinIR output shape: %s",
(
image.shape[0], image.shape[0],
image.shape[1], image.shape[1],
image.shape[2] * scale, image.shape[2] * scale,
image.shape[3] * scale, image.shape[3] * scale,
)) ),
)
output = swinir(image) output = swinir(image)
output = np.clip(np.squeeze(output, axis=0), 0, 1) output = np.clip(np.squeeze(output, axis=0), 0, 1)

View File

@ -81,7 +81,9 @@ def convert_diffusion_diffusers_xl(
output=dest_path, output=dest_path,
task="stable-diffusion-xl", task="stable-diffusion-xl",
device=device, device=device,
fp16=conversion.has_optimization("torch-fp16"), # optimum's fp16 mode only works on CUDA or ROCm fp16=conversion.has_optimization(
"torch-fp16"
), # optimum's fp16 mode only works on CUDA or ROCm
framework="pt", framework="pt",
) )

View File

@ -574,10 +574,9 @@ def optimize_pipeline(
server: ServerContext, server: ServerContext,
pipe: StableDiffusionPipeline, pipe: StableDiffusionPipeline,
) -> None: ) -> None:
if ( if server.has_optimization(
server.has_optimization("diffusers-attention-slicing") "diffusers-attention-slicing"
or server.has_optimization("diffusers-attention-slicing-auto") ) or server.has_optimization("diffusers-attention-slicing-auto"):
):
logger.debug("enabling auto attention slicing on SD pipeline") logger.debug("enabling auto attention slicing on SD pipeline")
try: try:
pipe.enable_attention_slicing(slice_size="auto") pipe.enable_attention_slicing(slice_size="auto")

View File

@ -13,6 +13,7 @@ from ..chain import (
UpscaleOutpaintStage, UpscaleOutpaintStage,
) )
from ..chain.highres import stage_highres from ..chain.highres import stage_highres
from ..chain.result import StageResult
from ..chain.upscale import split_upscale, stage_upscale_correction from ..chain.upscale import split_upscale, stage_upscale_correction
from ..image import expand_image from ..image import expand_image
from ..output import save_image from ..output import save_image

View File

@ -208,6 +208,7 @@ class ImageParams:
unet_overlap: float unet_overlap: float
vae_tile: int vae_tile: int
vae_overlap: float vae_overlap: float
denoise: int
def __init__( def __init__(
self, self,
@ -230,6 +231,7 @@ class ImageParams:
unet_tile: int = 512, unet_tile: int = 512,
vae_overlap: float = 0.25, vae_overlap: float = 0.25,
vae_tile: int = 512, vae_tile: int = 512,
denoise: int = 3,
) -> None: ) -> None:
self.model = model self.model = model
self.pipeline = pipeline self.pipeline = pipeline
@ -250,6 +252,7 @@ class ImageParams:
self.unet_tile = unet_tile self.unet_tile = unet_tile
self.vae_overlap = vae_overlap self.vae_overlap = vae_overlap
self.vae_tile = vae_tile self.vae_tile = vae_tile
self.denoise = denoise
def do_cfg(self): def do_cfg(self):
return self.cfg > 1.0 return self.cfg > 1.0
@ -320,6 +323,7 @@ class ImageParams:
"unet_tile": self.unet_tile, "unet_tile": self.unet_tile,
"vae_overlap": self.vae_overlap, "vae_overlap": self.vae_overlap,
"vae_tile": self.vae_tile, "vae_tile": self.vae_tile,
"denoise": self.denoise,
} }
def with_args(self, **kwargs): def with_args(self, **kwargs):
@ -343,6 +347,7 @@ class ImageParams:
kwargs.get("unet_tile", self.unet_tile), kwargs.get("unet_tile", self.unet_tile),
kwargs.get("vae_overlap", self.vae_overlap), kwargs.get("vae_overlap", self.vae_overlap),
kwargs.get("vae_tile", self.vae_tile), kwargs.get("vae_tile", self.vae_tile),
kwargs.get("denoise", self.denoise),
) )

View File

@ -138,7 +138,9 @@ def apply_patch_basicsr(server: ServerContext):
import basicsr.utils.download_util import basicsr.utils.download_util
basicsr.utils.download_util.download_file_from_google_drive = patch_not_impl basicsr.utils.download_util.download_file_from_google_drive = patch_not_impl
basicsr.utils.download_util.load_file_from_url = partial(patch_cache_path, server) basicsr.utils.download_util.load_file_from_url = partial(
patch_cache_path, server
)
except ImportError: except ImportError:
logger.info("unable to import basicsr utils for patching") logger.info("unable to import basicsr utils for patching")
except AttributeError: except AttributeError:
@ -151,7 +153,9 @@ def apply_patch_codeformer(server: ServerContext):
import codeformer.facelib.utils.misc import codeformer.facelib.utils.misc
codeformer.facelib.utils.misc.download_pretrained_models = patch_not_impl codeformer.facelib.utils.misc.download_pretrained_models = patch_not_impl
codeformer.facelib.utils.misc.load_file_from_url = partial(patch_cache_path, server) codeformer.facelib.utils.misc.load_file_from_url = partial(
patch_cache_path, server
)
except ImportError: except ImportError:
logger.info("unable to import codeformer utils for patching") logger.info("unable to import codeformer utils for patching")
except AttributeError: except AttributeError:

View File

@ -0,0 +1,61 @@
from importlib import import_module
from logging import getLogger
from typing import Any, Callable, Dict
from onnx_web.chain.stages import add_stage
from onnx_web.diffusers.load import add_pipeline
from onnx_web.server.context import ServerContext
logger = getLogger(__name__)
class PluginExports:
pipelines: Dict[str, Any]
stages: Dict[str, Any]
def __init__(self, pipelines=None, stages=None) -> None:
self.pipelines = pipelines or {}
self.stages = stages or {}
PluginModule = Callable[[ServerContext], PluginExports]
def load_plugins(server: ServerContext) -> PluginExports:
combined_exports = PluginExports()
for plugin in server.plugins:
logger.info("loading plugin module: %s", plugin)
try:
module: PluginModule = import_module(plugin)
exports = module(server)
for name, pipeline in exports.pipelines.items():
if name in combined_exports.pipelines:
logger.warning(
"multiple plugins exported a pipeline named %s", name
)
else:
combined_exports.pipelines[name] = pipeline
for name, stage in exports.stages.items():
if name in combined_exports.stages:
logger.warning("multiple plugins exported a stage named %s", name)
else:
combined_exports.stages[name] = stage
except Exception:
logger.exception("error importing plugin")
return combined_exports
def register_plugins(exports: PluginExports) -> bool:
success = True
for name, pipeline in exports.pipelines.items():
success = success and add_pipeline(name, pipeline)
for name, stage in exports.stages.items():
success = success and add_stage(name, stage)
return success

View File

@ -1,14 +1,22 @@
import unittest import unittest
from multiprocessing import Queue, Value from multiprocessing import Queue, Value
from os import getpid from os import getpid
from onnx_web.errors import RetryException
from onnx_web.errors import RetryException
from onnx_web.server.context import ServerContext from onnx_web.server.context import ServerContext
from onnx_web.worker.command import JobCommand from onnx_web.worker.command import JobCommand
from onnx_web.worker.context import WorkerContext from onnx_web.worker.context import WorkerContext
from onnx_web.worker.worker import EXIT_ERROR, EXIT_INTERRUPT, EXIT_MEMORY, EXIT_REPLACED, MEMORY_ERRORS, worker_main from onnx_web.worker.worker import (
EXIT_ERROR,
EXIT_INTERRUPT,
EXIT_MEMORY,
EXIT_REPLACED,
MEMORY_ERRORS,
worker_main,
)
from tests.helpers import test_device from tests.helpers import test_device
def main_memory(_worker): def main_memory(_worker):
raise Exception(MEMORY_ERRORS[0]) raise Exception(MEMORY_ERRORS[0])