diff --git a/api/onnx_web/__init__.py b/api/onnx_web/__init__.py
index 9697d7b3..575a3fef 100644
--- a/api/onnx_web/__init__.py
+++ b/api/onnx_web/__init__.py
@@ -17,18 +17,32 @@ from .diffusers.run import (
 )
 from .diffusers.stub_scheduler import StubScheduler
 from .diffusers.upscale import run_upscale_correction
-from .image import (
+from .image.utils import (
     expand_image,
+    valid_image,
+)
+from .image.mask_filter import (
     mask_filter_gaussian_multiply,
     mask_filter_gaussian_screen,
     mask_filter_none,
+)
+from .image.noise_source import (
     noise_source_fill_edge,
     noise_source_fill_mask,
     noise_source_gaussian,
     noise_source_histogram,
     noise_source_normal,
     noise_source_uniform,
-    valid_image,
+)
+from .image.source_filter import (
+    source_filter_canny,
+    source_filter_depth,
+    source_filter_hed,
+    source_filter_mlsd,
+    source_filter_normal,
+    source_filter_openpose,
+    source_filter_scribble,
+    source_filter_segment,
 )
 from .onnx import OnnxRRDBNet, OnnxTensor
 from .params import (
diff --git a/api/onnx_web/diffusers/run.py b/api/onnx_web/diffusers/run.py
index 8b1d47ea..d203bc4e 100644
--- a/api/onnx_web/diffusers/run.py
+++ b/api/onnx_web/diffusers/run.py
@@ -1,5 +1,5 @@
 from logging import getLogger
-from typing import Any, List
+from typing import Any, List, Optional
 
 import numpy as np
 import torch
@@ -19,6 +19,7 @@ from ..params import (
     UpscaleParams,
 )
 from ..server import ServerContext
+from ..server.load import get_source_filters
 from ..utils import run_gc
 from ..worker import WorkerContext
 from .load import get_latents_from_seed, load_pipeline
@@ -222,11 +223,16 @@ def run_img2img_pipeline(
     upscale: UpscaleParams,
     source: Image.Image,
     strength: float,
+    source_filter: Optional[str] = None,
 ) -> None:
     (prompt, loras) = get_loras_from_prompt(params.prompt)
     (prompt, inversions) = get_inversions_from_prompt(prompt)
     params.prompt = prompt
 
+    # filter the source image before loading the pipeline
+    if source_filter is not None:
+        source = get_source_filters()[source_filter](server, source)
+
     pipe = load_pipeline(
         server,
         params.pipeline,  # this is one of the only places this can actually vary between different pipelines
@@ -243,6 +249,8 @@ def run_img2img_pipeline(
         pipe_params["controlnet_conditioning_scale"] = strength
     elif params.pipeline == "img2img":
         pipe_params["strength"] = strength
+    elif params.pipeline == "pix2pix":
+        pipe_params["image_guidance_scale"] = strength
 
     progress = job.get_progress_callback()
     if params.lpw():
diff --git a/api/onnx_web/image/__init__.py b/api/onnx_web/image/__init__.py
new file mode 100644
index 00000000..894e85f8
--- /dev/null
+++ b/api/onnx_web/image/__init__.py
@@ -0,0 +1,30 @@
+from .utils import (
+    expand_image,
+    valid_image,
+)
+from .mask_filter import (
+    mask_filter_gaussian_multiply,
+    mask_filter_gaussian_screen,
+    mask_filter_none,
+)
+from .noise_source import (
+    noise_source_fill_edge,
+    noise_source_fill_mask,
+    noise_source_gaussian,
+    noise_source_histogram,
+    noise_source_normal,
+    noise_source_uniform,
+)
+from .source_filter import (
+    source_filter_canny,
+    source_filter_depth,
+    source_filter_face,
+    source_filter_gaussian,
+    source_filter_hed,
+    source_filter_mlsd,
+    source_filter_noise,
+    source_filter_normal,
+    source_filter_openpose,
+    source_filter_scribble,
+    source_filter_segment,
+)
diff --git a/api/onnx_web/image/laion_face.py b/api/onnx_web/image/laion_face.py
new file mode 100644
index 00000000..d5a70dc5
--- /dev/null
+++ b/api/onnx_web/image/laion_face.py
@@ -0,0 +1,167 @@
+# from 
https://github.com/ForserX/StableDiffusionUI/blob/main/data/repo/diffusion_scripts/modules/controlnet/laion_face_common.py + +from typing import Mapping + +import mediapipe as mp +import numpy + +mp_drawing = mp.solutions.drawing_utils +mp_drawing_styles = mp.solutions.drawing_styles +mp_face_detection = mp.solutions.face_detection # Only for counting faces. +mp_face_mesh = mp.solutions.face_mesh +mp_face_connections = mp.solutions.face_mesh_connections.FACEMESH_TESSELATION +mp_hand_connections = mp.solutions.hands_connections.HAND_CONNECTIONS +mp_body_connections = mp.solutions.pose_connections.POSE_CONNECTIONS + +DrawingSpec = mp.solutions.drawing_styles.DrawingSpec +PoseLandmark = mp.solutions.drawing_styles.PoseLandmark + +min_face_size_pixels: int = 64 +f_thick = 2 +f_rad = 1 +right_iris_draw = DrawingSpec( + color=(10, 200, 250), thickness=f_thick, circle_radius=f_rad +) +right_eye_draw = DrawingSpec( + color=(10, 200, 180), thickness=f_thick, circle_radius=f_rad +) +right_eyebrow_draw = DrawingSpec( + color=(10, 220, 180), thickness=f_thick, circle_radius=f_rad +) +left_iris_draw = DrawingSpec( + color=(250, 200, 10), thickness=f_thick, circle_radius=f_rad +) +left_eye_draw = DrawingSpec( + color=(180, 200, 10), thickness=f_thick, circle_radius=f_rad +) +left_eyebrow_draw = DrawingSpec( + color=(180, 220, 10), thickness=f_thick, circle_radius=f_rad +) +mouth_draw = DrawingSpec(color=(10, 180, 10), thickness=f_thick, circle_radius=f_rad) +head_draw = DrawingSpec(color=(10, 200, 10), thickness=f_thick, circle_radius=f_rad) + +# mp_face_mesh.FACEMESH_CONTOURS has all the items we care about. +face_connection_spec = {} +for edge in mp_face_mesh.FACEMESH_FACE_OVAL: + face_connection_spec[edge] = head_draw +for edge in mp_face_mesh.FACEMESH_LEFT_EYE: + face_connection_spec[edge] = left_eye_draw +for edge in mp_face_mesh.FACEMESH_LEFT_EYEBROW: + face_connection_spec[edge] = left_eyebrow_draw +# for edge in mp_face_mesh.FACEMESH_LEFT_IRIS: +# face_connection_spec[edge] = left_iris_draw +for edge in mp_face_mesh.FACEMESH_RIGHT_EYE: + face_connection_spec[edge] = right_eye_draw +for edge in mp_face_mesh.FACEMESH_RIGHT_EYEBROW: + face_connection_spec[edge] = right_eyebrow_draw +# for edge in mp_face_mesh.FACEMESH_RIGHT_IRIS: +# face_connection_spec[edge] = right_iris_draw +for edge in mp_face_mesh.FACEMESH_LIPS: + face_connection_spec[edge] = mouth_draw +iris_landmark_spec = {468: right_iris_draw, 473: left_iris_draw} + + +def draw_pupils(image, landmark_list, drawing_spec, halfwidth: int = 2): + """We have a custom function to draw the pupils because the mp.draw_landmarks method requires a parameter for all + landmarks. 
Until our PR is merged into mediapipe, we need this separate method.""" + if len(image.shape) != 3: + raise ValueError("Input image must be H,W,C.") + image_rows, image_cols, image_channels = image.shape + if image_channels != 3: # BGR channels + raise ValueError("Input image must contain three channel bgr data.") + for idx, landmark in enumerate(landmark_list.landmark): + if (landmark.HasField("visibility") and landmark.visibility < 0.9) or ( + landmark.HasField("presence") and landmark.presence < 0.5 + ): + continue + if landmark.x >= 1.0 or landmark.x < 0 or landmark.y >= 1.0 or landmark.y < 0: + continue + image_x = int(image_cols * landmark.x) + image_y = int(image_rows * landmark.y) + draw_color = None + if isinstance(drawing_spec, Mapping): + if drawing_spec.get(idx) is None: + continue + else: + draw_color = drawing_spec[idx].color + elif isinstance(drawing_spec, DrawingSpec): + draw_color = drawing_spec.color + image[ + image_y - halfwidth : image_y + halfwidth, + image_x - halfwidth : image_x + halfwidth, + :, + ] = draw_color + + +def reverse_channels(image): + """Given a numpy array in RGB form, convert to BGR. Will also convert from BGR to RGB.""" + # im[:,:,::-1] is a neat hack to convert BGR to RGB by reversing the indexing order. + # im[:,:,::[2,1,0]] would also work but makes a copy of the data. + return image[:, :, ::-1] + + +def generate_annotation(img_rgb, max_faces: int, min_confidence: float): + """ + Find up to 'max_faces' inside the provided input image. + If min_face_size_pixels is provided and nonzero it will be used to filter faces that occupy less than this many + pixels in the image. + """ + with mp_face_mesh.FaceMesh( + static_image_mode=True, + max_num_faces=max_faces, + refine_landmarks=True, + min_detection_confidence=min_confidence, + ) as facemesh: + img_height, img_width, img_channels = img_rgb.shape + assert img_channels == 3 + + results = facemesh.process(img_rgb).multi_face_landmarks + + if results is None: + print("No faces detected in controlnet image for Mediapipe face annotator.") + return numpy.zeros_like(img_rgb) + + # Filter faces that are too small + filtered_landmarks = [] + for lm in results: + landmarks = lm.landmark + face_rect = [ + landmarks[0].x, + landmarks[0].y, + landmarks[0].x, + landmarks[0].y, + ] # Left, up, right, down. + for i in range(len(landmarks)): + face_rect[0] = min(face_rect[0], landmarks[i].x) + face_rect[1] = min(face_rect[1], landmarks[i].y) + face_rect[2] = max(face_rect[2], landmarks[i].x) + face_rect[3] = max(face_rect[3], landmarks[i].y) + if min_face_size_pixels > 0: + face_width = abs(face_rect[2] - face_rect[0]) + face_height = abs(face_rect[3] - face_rect[1]) + face_width_pixels = face_width * img_width + face_height_pixels = face_height * img_height + face_size = min(face_width_pixels, face_height_pixels) + if face_size >= min_face_size_pixels: + filtered_landmarks.append(lm) + else: + filtered_landmarks.append(lm) + + # Annotations are drawn in BGR for some reason, but we don't need to flip a zero-filled image at the start. + empty = numpy.zeros_like(img_rgb) + + # Draw detected faces: + for face_landmarks in filtered_landmarks: + mp_drawing.draw_landmarks( + empty, + face_landmarks, + connections=face_connection_spec.keys(), + landmark_drawing_spec=None, + connection_drawing_spec=face_connection_spec, + ) + draw_pupils(empty, face_landmarks, iris_landmark_spec, 2) + + # Flip BGR back to RGB. 
+    empty = reverse_channels(empty).copy()
+
+    return empty
diff --git a/api/onnx_web/image/mask_filter.py b/api/onnx_web/image/mask_filter.py
new file mode 100644
index 00000000..8eef1c75
--- /dev/null
+++ b/api/onnx_web/image/mask_filter.py
@@ -0,0 +1,44 @@
+from PIL import Image, ImageChops, ImageFilter
+
+from ..params import Point
+
+
+def mask_filter_none(
+    mask: Image.Image, dims: Point, origin: Point, fill="white", **kw
+) -> Image.Image:
+    width, height = dims
+
+    noise = Image.new("RGB", (width, height), fill)
+    noise.paste(mask, origin)
+
+    return noise
+
+
+def mask_filter_gaussian_multiply(
+    mask: Image.Image, dims: Point, origin: Point, rounds=3, **kw
+) -> Image.Image:
+    """
+    Gaussian blur with multiply, source image centered on white canvas.
+    """
+    noise = mask_filter_none(mask, dims, origin)
+
+    for _i in range(rounds):
+        blur = noise.filter(ImageFilter.GaussianBlur(5))
+        noise = ImageChops.multiply(noise, blur)
+
+    return noise
+
+
+def mask_filter_gaussian_screen(
+    mask: Image.Image, dims: Point, origin: Point, rounds=3, **kw
+) -> Image.Image:
+    """
+    Gaussian blur, source image centered on white canvas.
+    """
+    noise = mask_filter_none(mask, dims, origin)
+
+    for _i in range(rounds):
+        blur = noise.filter(ImageFilter.GaussianBlur(5))
+        noise = ImageChops.screen(noise, blur)
+
+    return noise
diff --git a/api/onnx_web/image.py b/api/onnx_web/image/noise_source.py
similarity index 52%
rename from api/onnx_web/image.py
rename to api/onnx_web/image/noise_source.py
index 2e658cf0..2b6935e9 100644
--- a/api/onnx_web/image.py
+++ b/api/onnx_web/image/noise_source.py
@@ -1,57 +1,14 @@
-from typing import Tuple, Union
-
 import numpy as np
 from numpy import random
-from PIL import Image, ImageChops, ImageFilter, ImageOps
+from PIL import Image, ImageFilter
 
-from .params import Border, Point, Size
+from ..params import Point
 
 
 def get_pixel_index(x: int, y: int, width: int) -> int:
     return (y * width) + x
 
 
-def mask_filter_none(
-    mask: Image.Image, dims: Point, origin: Point, fill="white", **kw
-) -> Image.Image:
-    width, height = dims
-
-    noise = Image.new("RGB", (width, height), fill)
-    noise.paste(mask, origin)
-
-    return noise
-
-
-def mask_filter_gaussian_multiply(
-    mask: Image.Image, dims: Point, origin: Point, rounds=3, **kw
-) -> Image.Image:
-    """
-    Gaussian blur with multiply, source image centered on white canvas.
-    """
-    noise = mask_filter_none(mask, dims, origin)
-
-    for _i in range(rounds):
-        blur = noise.filter(ImageFilter.GaussianBlur(5))
-        noise = ImageChops.multiply(noise, blur)
-
-    return noise
-
-
-def mask_filter_gaussian_screen(
-    mask: Image.Image, dims: Point, origin: Point, rounds=3, **kw
-) -> Image.Image:
-    """
-    Gaussian blur, source image centered on white canvas.
- """ - noise = mask_filter_none(mask, dims, origin) - - for _i in range(rounds): - blur = noise.filter(ImageFilter.GaussianBlur(5)) - noise = ImageChops.screen(noise, blur) - - return noise - - def noise_source_fill_edge( source: Image.Image, dims: Point, origin: Point, fill="white", **kw ) -> Image.Image: @@ -163,52 +120,3 @@ def noise_source_histogram( noise.putpixel((x, y), (noise_r[i], noise_g[i], noise_b[i])) return noise - - -# very loosely based on https://github.com/AUTOMATIC1111/stable-diffusion-webui/blob/master/scripts/outpainting_mk_2.py#L175-L232 -def expand_image( - source: Image.Image, - mask: Image.Image, - expand: Border, - fill="white", - noise_source=noise_source_histogram, - mask_filter=mask_filter_none, -): - full_width = expand.left + source.width + expand.right - full_height = expand.top + source.height + expand.bottom - - dims = (full_width, full_height) - origin = (expand.left, expand.top) - - full_source = Image.new("RGB", dims, fill) - full_source.paste(source, origin) - - # new mask pixels need to be filled with white so they will be replaced - full_mask = mask_filter(mask, dims, origin, fill="white") - full_noise = noise_source(source, dims, origin, fill=fill) - full_noise = ImageChops.multiply(full_noise, full_mask) - - full_source = Image.composite(full_noise, full_source, full_mask.convert("L")) - - return (full_source, full_mask, full_noise, (full_width, full_height)) - - -def valid_image( - image: Image.Image, - min_dims: Union[Size, Tuple[int, int]] = [512, 512], - max_dims: Union[Size, Tuple[int, int]] = [512, 512], -) -> Image.Image: - min_x, min_y = min_dims - max_x, max_y = max_dims - - if image.width > max_x or image.height > max_y: - image = ImageOps.contain(image, (max_x, max_y)) - - if image.width < min_x or image.height < min_y: - blank = Image.new(image.mode, (min_x, min_y), "black") - blank.paste(image) - image = blank - - # check for square - - return image diff --git a/api/onnx_web/image/source_filter.py b/api/onnx_web/image/source_filter.py new file mode 100644 index 00000000..1ea0d77d --- /dev/null +++ b/api/onnx_web/image/source_filter.py @@ -0,0 +1,191 @@ +# https://github.com/ForserX/StableDiffusionUI/blob/main/data/repo/diffusion_scripts/controlnet_pipe.py + +from logging import getLogger +from os import path + +import cv2 +import numpy as np +import torch +import transformers +from controlnet_aux import HEDdetector, MLSDdetector, OpenposeDetector +from huggingface_hub import snapshot_download +from PIL import Image + +from ..server.context import ServerContext +from .laion_face import generate_annotation +from .utils import ade_palette + +logger = getLogger(__name__) + + +def pil_to_cv2(source: Image.Image) -> np.ndarray: + return cv2.cvtColor(np.array(source), cv2.COLOR_RGB2BGR) + + +def filter_model_path(server: ServerContext, filter_name: str) -> str: + return path.join(server.model_path, "filter", filter_name) + + +def source_filter_gaussian(): + pass + + +def source_filter_noise(): + pass + + +def source_filter_face( + server: ServerContext, + source: Image.Image, + max_faces: int = 1, + min_confidence: float = 0.5, +) -> Image.Image: + logger.debug("running face detection on source image") + + image = generate_annotation(pil_to_cv2(source), max_faces, min_confidence) + image = Image.fromarray(image) + + return image + + +def source_filter_segment(server: ServerContext, source: Image.Image) -> Image.Image: + logger.debug("running segmentation on source image") + + openmm_model = snapshot_download( + 
"openmmlab/upernet-convnext-small", + allow_patterns=["*.bin", "*.json"], + cache_dir=filter_model_path(server, "upernet-convnext-small"), + ) + + image_processor = transformers.AutoImageProcessor.from_pretrained(openmm_model) + image_segmentor = transformers.UperNetForSemanticSegmentation.from_pretrained( + openmm_model + ) + + in_img = source.convert("RGB") + + pixel_values = image_processor(in_img, return_tensors="pt").pixel_values + + with torch.no_grad(): + outputs = image_segmentor(pixel_values) + + seg = image_processor.post_process_semantic_segmentation( + outputs, target_sizes=[in_img.size[::-1]] + )[0] + + color_seg = np.zeros( + (seg.shape[0], seg.shape[1], 3), dtype=np.uint8 + ) # height, width, 3 + + palette = np.array(ade_palette()) + + for label, color in enumerate(palette): + color_seg[seg == label, :] = color + + color_seg = color_seg.astype(np.uint8) + + image = Image.fromarray(color_seg) + + return image + + +def source_filter_mlsd(server: ServerContext, source: Image.Image) -> Image.Image: + logger.debug("running MLSD on source image") + + # TODO: get model name + mlsd = MLSDdetector.from_pretrained(filter_model_path(server, "mlsd")) + image = mlsd(source) + + return image + + +def source_filter_normal(server: ServerContext, source: Image.Image) -> Image.Image: + logger.debug("running normal detection on source image") + + depth_estimator = transformers.pipeline( + "depth-estimation", + model=snapshot_download( + "Intel/dpt-hybrid-midas", + allow_patterns=["*.bin", "*.json"], + cache_dir=filter_model_path(server, "dpt-hybrid-midas"), + ), + ) + + image = depth_estimator(source)["predicted_depth"][0] + + image = image.numpy() + + image_depth = image.copy() + image_depth -= np.min(image_depth) + image_depth /= np.max(image_depth) + + bg_threhold = 0.4 + + x = cv2.Sobel(image, cv2.CV_32F, 1, 0, ksize=3) + x[image_depth < bg_threhold] = 0 + + y = cv2.Sobel(image, cv2.CV_32F, 0, 1, ksize=3) + y[image_depth < bg_threhold] = 0 + + z = np.ones_like(x) * np.pi * 2.0 + + image = np.stack([x, y, z], axis=2) + image /= np.sum(image**2.0, axis=2, keepdims=True) ** 0.5 + image = (image * 127.5 + 127.5).clip(0, 255).astype(np.uint8) + image = Image.fromarray(image) + + return image + + +def source_filter_hed(server: ServerContext, source: Image.Image) -> Image.Image: + logger.debug("running HED detection on source image") + + # TODO: get model name + hed = HEDdetector.from_pretrained(filter_model_path(server, "hed")) + image = hed(source) + + return image + + +def source_filter_scribble(server: ServerContext, source: Image.Image) -> Image.Image: + logger.debug("running scribble detection on source image") + + # TODO: get model name + hed = HEDdetector.from_pretrained(filter_model_path(server, "hed")) + image = hed(source, scribble=True) + + return image + + +def source_filter_depth(server: ServerContext, source: Image.Image) -> Image.Image: + logger.debug("running depth detection on source image") + depth_estimator = transformers.pipeline("depth-estimation") + + image = depth_estimator(source)["depth"] + image = np.array(image) + image = image[:, :, None] + image = np.concatenate([image, image, image], axis=2) + image = Image.fromarray(image) + + return image + + +def source_filter_canny( + server: ServerContext, source: Image.Image, low_threshold=100, high_threshold=200 +) -> Image.Image: + logger.debug("running Canny detection on source image") + + image = cv2.Canny(pil_to_cv2(source), low_threshold, high_threshold) + image = Image.fromarray(image) + + return image + + +def 
source_filter_openpose(server: ServerContext, source: Image.Image) -> Image.Image: + logger.debug("running OpenPose detection on source image") + + # TODO: get model name + model = OpenposeDetector.from_pretrained(filter_model_path(server, "openpose")) + image = model(source) + + return image diff --git a/api/onnx_web/image/utils.py b/api/onnx_web/image/utils.py new file mode 100644 index 00000000..956c4411 --- /dev/null +++ b/api/onnx_web/image/utils.py @@ -0,0 +1,214 @@ +from typing import Tuple, Union + +from PIL import Image, ImageChops, ImageOps + +from .mask_filter import mask_filter_none +from .noise_source import noise_source_histogram +from .params import Border, Size + + +# very loosely based on https://github.com/AUTOMATIC1111/stable-diffusion-webui/blob/master/scripts/outpainting_mk_2.py#L175-L232 +def expand_image( + source: Image.Image, + mask: Image.Image, + expand: Border, + fill="white", + noise_source=noise_source_histogram, + mask_filter=mask_filter_none, +): + full_width = expand.left + source.width + expand.right + full_height = expand.top + source.height + expand.bottom + + dims = (full_width, full_height) + origin = (expand.left, expand.top) + + full_source = Image.new("RGB", dims, fill) + full_source.paste(source, origin) + + # new mask pixels need to be filled with white so they will be replaced + full_mask = mask_filter(mask, dims, origin, fill="white") + full_noise = noise_source(source, dims, origin, fill=fill) + full_noise = ImageChops.multiply(full_noise, full_mask) + + full_source = Image.composite(full_noise, full_source, full_mask.convert("L")) + + return (full_source, full_mask, full_noise, (full_width, full_height)) + + +def valid_image( + image: Image.Image, + min_dims: Union[Size, Tuple[int, int]] = [512, 512], + max_dims: Union[Size, Tuple[int, int]] = [512, 512], +) -> Image.Image: + min_x, min_y = min_dims + max_x, max_y = max_dims + + if image.width > max_x or image.height > max_y: + image = ImageOps.contain(image, (max_x, max_y)) + + if image.width < min_x or image.height < min_y: + blank = Image.new(image.mode, (min_x, min_y), "black") + blank.paste(image) + image = blank + + # check for square + + return image + + +# from https://github.com/ForserX/StableDiffusionUI/blob/main/data/repo/diffusion_scripts/modules/controlnet/palette.py +# and others +def ade_palette(): + """ADE20K palette that maps each class to RGB values.""" + return [ + [120, 120, 120], + [180, 120, 120], + [6, 230, 230], + [80, 50, 50], + [4, 200, 3], + [120, 120, 80], + [140, 140, 140], + [204, 5, 255], + [230, 230, 230], + [4, 250, 7], + [224, 5, 255], + [235, 255, 7], + [150, 5, 61], + [120, 120, 70], + [8, 255, 51], + [255, 6, 82], + [143, 255, 140], + [204, 255, 4], + [255, 51, 7], + [204, 70, 3], + [0, 102, 200], + [61, 230, 250], + [255, 6, 51], + [11, 102, 255], + [255, 7, 71], + [255, 9, 224], + [9, 7, 230], + [220, 220, 220], + [255, 9, 92], + [112, 9, 255], + [8, 255, 214], + [7, 255, 224], + [255, 184, 6], + [10, 255, 71], + [255, 41, 10], + [7, 255, 255], + [224, 255, 8], + [102, 8, 255], + [255, 61, 6], + [255, 194, 7], + [255, 122, 8], + [0, 255, 20], + [255, 8, 41], + [255, 5, 153], + [6, 51, 255], + [235, 12, 255], + [160, 150, 20], + [0, 163, 255], + [140, 140, 140], + [250, 10, 15], + [20, 255, 0], + [31, 255, 0], + [255, 31, 0], + [255, 224, 0], + [153, 255, 0], + [0, 0, 255], + [255, 71, 0], + [0, 235, 255], + [0, 173, 255], + [31, 0, 255], + [11, 200, 200], + [255, 82, 0], + [0, 255, 245], + [0, 61, 255], + [0, 255, 112], + [0, 255, 133], + [255, 0, 0], + 
[255, 163, 0], + [255, 102, 0], + [194, 255, 0], + [0, 143, 255], + [51, 255, 0], + [0, 82, 255], + [0, 255, 41], + [0, 255, 173], + [10, 0, 255], + [173, 255, 0], + [0, 255, 153], + [255, 92, 0], + [255, 0, 255], + [255, 0, 245], + [255, 0, 102], + [255, 173, 0], + [255, 0, 20], + [255, 184, 184], + [0, 31, 255], + [0, 255, 61], + [0, 71, 255], + [255, 0, 204], + [0, 255, 194], + [0, 255, 82], + [0, 10, 255], + [0, 112, 255], + [51, 0, 255], + [0, 194, 255], + [0, 122, 255], + [0, 255, 163], + [255, 153, 0], + [0, 255, 10], + [255, 112, 0], + [143, 255, 0], + [82, 0, 255], + [163, 255, 0], + [255, 235, 0], + [8, 184, 170], + [133, 0, 255], + [0, 255, 92], + [184, 0, 255], + [255, 0, 31], + [0, 184, 255], + [0, 214, 255], + [255, 0, 112], + [92, 255, 0], + [0, 224, 255], + [112, 224, 255], + [70, 184, 160], + [163, 0, 255], + [153, 0, 255], + [71, 255, 0], + [255, 0, 163], + [255, 204, 0], + [255, 0, 143], + [0, 255, 235], + [133, 255, 0], + [255, 0, 235], + [245, 0, 255], + [255, 0, 122], + [255, 245, 0], + [10, 190, 212], + [214, 255, 0], + [0, 204, 255], + [20, 0, 255], + [255, 255, 0], + [0, 153, 255], + [0, 41, 255], + [0, 255, 204], + [41, 0, 255], + [41, 255, 0], + [173, 0, 255], + [0, 245, 255], + [71, 0, 255], + [122, 0, 255], + [0, 255, 184], + [0, 92, 255], + [184, 255, 0], + [0, 133, 255], + [255, 214, 0], + [25, 194, 194], + [102, 255, 0], + [92, 0, 255], + ] diff --git a/api/onnx_web/server/api.py b/api/onnx_web/server/api.py index 6fc84e51..c34e8b54 100644 --- a/api/onnx_web/server/api.py +++ b/api/onnx_web/server/api.py @@ -42,6 +42,7 @@ from .load import ( get_mask_filters, get_network_models, get_noise_sources, + get_source_filters, get_upscaling_models, ) from .params import ( @@ -151,6 +152,9 @@ def img2img(server: ServerContext, pool: DevicePoolExecutor): device, params, size = pipeline_from_request(server, "img2img") upscale = upscale_from_request() + source_filter = get_from_list( + request.args, "sourceFilter", list(get_source_filters().keys()) + ) strength = get_and_clamp_float( request.args, @@ -160,6 +164,7 @@ def img2img(server: ServerContext, pool: DevicePoolExecutor): get_config_value("strength", "min"), ) + # TODO: add filtered source to outputs output = make_output_name(server, "img2img", params, size, extras=[strength]) job_name = output[0] logger.info("img2img job queued for: %s", job_name) @@ -175,6 +180,7 @@ def img2img(server: ServerContext, pool: DevicePoolExecutor): source, strength, needs_device=device, + source_filter=source_filter, ) return jsonify(json_params(output, params, size, upscale=upscale)) diff --git a/api/onnx_web/server/load.py b/api/onnx_web/server/load.py index 1bea44e0..c11666a6 100644 --- a/api/onnx_web/server/load.py +++ b/api/onnx_web/server/load.py @@ -19,6 +19,17 @@ from ..image import ( # mask filters; noise sources noise_source_histogram, noise_source_normal, noise_source_uniform, + source_filter_canny, + source_filter_depth, + source_filter_face, + source_filter_gaussian, + source_filter_hed, + source_filter_mlsd, + source_filter_noise, + source_filter_normal, + source_filter_openpose, + source_filter_scribble, + source_filter_segment, ) from ..models.meta import NetworkModel from ..params import DeviceParams @@ -32,11 +43,15 @@ logger = getLogger(__name__) config_params: Dict[str, Dict[str, Union[float, int, str]]] = {} # pipeline params -platform_providers = { - "cpu": "CPUExecutionProvider", - "cuda": "CUDAExecutionProvider", - "directml": "DmlExecutionProvider", - "rocm": "ROCMExecutionProvider", +highres_methods = [ + 
"bilinear", + "lanczos", + "upscale", +] +mask_filters = { + "none": mask_filter_none, + "gaussian-multiply": mask_filter_gaussian_multiply, + "gaussian-screen": mask_filter_gaussian_screen, } noise_sources = { "fill-edge": noise_source_fill_edge, @@ -46,17 +61,25 @@ noise_sources = { "normal": noise_source_normal, "uniform": noise_source_uniform, } -mask_filters = { - "none": mask_filter_none, - "gaussian-multiply": mask_filter_gaussian_multiply, - "gaussian-screen": mask_filter_gaussian_screen, +platform_providers = { + "cpu": "CPUExecutionProvider", + "cuda": "CUDAExecutionProvider", + "directml": "DmlExecutionProvider", + "rocm": "ROCMExecutionProvider", +} +source_filters = { + "gaussian": source_filter_gaussian, + "noise": source_filter_noise, + "face": source_filter_face, + "segment": source_filter_segment, + "mlsd": source_filter_mlsd, + "normal": source_filter_normal, + "hed": source_filter_hed, + "scribble": source_filter_scribble, + "depth": source_filter_depth, + "canny": source_filter_canny, + "openpose": source_filter_openpose, } -highres_methods = [ - "bilinear", - "lanczos", - "upscale", -] - # Available ORT providers available_platforms: List[DeviceParams] = [] @@ -111,6 +134,10 @@ def get_noise_sources(): return noise_sources +def get_source_filters(): + return source_filters + + def get_config_value(key: str, subkey: str = "default", default=None): return config_params.get(key, {}).get(subkey, default) diff --git a/api/requirements/base.txt b/api/requirements/base.txt index 53b76c39..9bb6d9bf 100644 --- a/api/requirements/base.txt +++ b/api/requirements/base.txt @@ -5,6 +5,7 @@ protobuf<4,>=3.20.2 ### AI packages ### accelerate coloredlogs +controlnet_aux diffusers onnx # onnxruntime has many platform-specific packages