feat(api): add source image filters for controlnet and others
This commit is contained in:
parent
bd992398ae
commit
80d00e4477
|
@ -17,18 +17,32 @@ from .diffusers.run import (
|
|||
)
|
||||
from .diffusers.stub_scheduler import StubScheduler
|
||||
from .diffusers.upscale import run_upscale_correction
|
||||
from .image import (
|
||||
from .image.utils import (
|
||||
expand_image,
|
||||
valid_image,
|
||||
)
|
||||
from .image.mask_filter import (
|
||||
mask_filter_gaussian_multiply,
|
||||
mask_filter_gaussian_screen,
|
||||
mask_filter_none,
|
||||
)
|
||||
from .image.noise_source import (
|
||||
noise_source_fill_edge,
|
||||
noise_source_fill_mask,
|
||||
noise_source_gaussian,
|
||||
noise_source_histogram,
|
||||
noise_source_normal,
|
||||
noise_source_uniform,
|
||||
valid_image,
|
||||
)
|
||||
from .image.source_filter import (
|
||||
source_filter_canny,
|
||||
source_filter_depth,
|
||||
source_filter_hed,
|
||||
source_filter_mlsd,
|
||||
source_filter_normal,
|
||||
source_filter_pose,
|
||||
source_filter_scribble,
|
||||
source_filter_segment,
|
||||
)
|
||||
from .onnx import OnnxRRDBNet, OnnxTensor
|
||||
from .params import (
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
from logging import getLogger
|
||||
from typing import Any, List
|
||||
from typing import Any, List, Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
@ -19,6 +19,7 @@ from ..params import (
|
|||
UpscaleParams,
|
||||
)
|
||||
from ..server import ServerContext
|
||||
from ..server.load import get_source_filters
|
||||
from ..utils import run_gc
|
||||
from ..worker import WorkerContext
|
||||
from .load import get_latents_from_seed, load_pipeline
|
||||
|
@ -222,11 +223,16 @@ def run_img2img_pipeline(
|
|||
upscale: UpscaleParams,
|
||||
source: Image.Image,
|
||||
strength: float,
|
||||
source_filter: Optional[str] = None,
|
||||
) -> None:
|
||||
(prompt, loras) = get_loras_from_prompt(params.prompt)
|
||||
(prompt, inversions) = get_inversions_from_prompt(prompt)
|
||||
params.prompt = prompt
|
||||
|
||||
# filter the source image
|
||||
if source_filter is not None:
|
||||
source = get_source_filters(source_filter)(source)
|
||||
|
||||
pipe = load_pipeline(
|
||||
server,
|
||||
params.pipeline, # this is one of the only places this can actually vary between different pipelines
|
||||
|
@ -243,6 +249,8 @@ def run_img2img_pipeline(
|
|||
pipe_params["controlnet_conditioning_scale"] = strength
|
||||
elif params.pipeline == "img2img":
|
||||
pipe_params["strength"] = strength
|
||||
elif params.pipeline == "pix2pix":
|
||||
pipe_params["image_guidance_scale"] = strength
|
||||
|
||||
progress = job.get_progress_callback()
|
||||
if params.lpw():
|
||||
|
|
|
@ -0,0 +1,30 @@
|
|||
from .utils import (
|
||||
expand_image,
|
||||
valid_image,
|
||||
)
|
||||
from .mask_filter import (
|
||||
mask_filter_gaussian_multiply,
|
||||
mask_filter_gaussian_screen,
|
||||
mask_filter_none,
|
||||
)
|
||||
from .noise_source import (
|
||||
noise_source_fill_edge,
|
||||
noise_source_fill_mask,
|
||||
noise_source_gaussian,
|
||||
noise_source_histogram,
|
||||
noise_source_normal,
|
||||
noise_source_uniform,
|
||||
)
|
||||
from .source_filter import (
|
||||
source_filter_canny,
|
||||
source_filter_depth,
|
||||
source_filter_face,
|
||||
source_filter_gaussian,
|
||||
source_filter_hed,
|
||||
source_filter_mlsd,
|
||||
source_filter_noise,
|
||||
source_filter_normal,
|
||||
source_filter_openpose,
|
||||
source_filter_scribble,
|
||||
source_filter_segment,
|
||||
)
|
|
@ -0,0 +1,167 @@
|
|||
# from https://github.com/ForserX/StableDiffusionUI/blob/main/data/repo/diffusion_scripts/modules/controlnet/laion_face_common.py
|
||||
|
||||
from typing import Mapping
|
||||
|
||||
import mediapipe as mp
|
||||
import numpy
|
||||
|
||||
mp_drawing = mp.solutions.drawing_utils
|
||||
mp_drawing_styles = mp.solutions.drawing_styles
|
||||
mp_face_detection = mp.solutions.face_detection # Only for counting faces.
|
||||
mp_face_mesh = mp.solutions.face_mesh
|
||||
mp_face_connections = mp.solutions.face_mesh_connections.FACEMESH_TESSELATION
|
||||
mp_hand_connections = mp.solutions.hands_connections.HAND_CONNECTIONS
|
||||
mp_body_connections = mp.solutions.pose_connections.POSE_CONNECTIONS
|
||||
|
||||
DrawingSpec = mp.solutions.drawing_styles.DrawingSpec
|
||||
PoseLandmark = mp.solutions.drawing_styles.PoseLandmark
|
||||
|
||||
min_face_size_pixels: int = 64
|
||||
f_thick = 2
|
||||
f_rad = 1
|
||||
right_iris_draw = DrawingSpec(
|
||||
color=(10, 200, 250), thickness=f_thick, circle_radius=f_rad
|
||||
)
|
||||
right_eye_draw = DrawingSpec(
|
||||
color=(10, 200, 180), thickness=f_thick, circle_radius=f_rad
|
||||
)
|
||||
right_eyebrow_draw = DrawingSpec(
|
||||
color=(10, 220, 180), thickness=f_thick, circle_radius=f_rad
|
||||
)
|
||||
left_iris_draw = DrawingSpec(
|
||||
color=(250, 200, 10), thickness=f_thick, circle_radius=f_rad
|
||||
)
|
||||
left_eye_draw = DrawingSpec(
|
||||
color=(180, 200, 10), thickness=f_thick, circle_radius=f_rad
|
||||
)
|
||||
left_eyebrow_draw = DrawingSpec(
|
||||
color=(180, 220, 10), thickness=f_thick, circle_radius=f_rad
|
||||
)
|
||||
mouth_draw = DrawingSpec(color=(10, 180, 10), thickness=f_thick, circle_radius=f_rad)
|
||||
head_draw = DrawingSpec(color=(10, 200, 10), thickness=f_thick, circle_radius=f_rad)
|
||||
|
||||
# mp_face_mesh.FACEMESH_CONTOURS has all the items we care about.
|
||||
face_connection_spec = {}
|
||||
for edge in mp_face_mesh.FACEMESH_FACE_OVAL:
|
||||
face_connection_spec[edge] = head_draw
|
||||
for edge in mp_face_mesh.FACEMESH_LEFT_EYE:
|
||||
face_connection_spec[edge] = left_eye_draw
|
||||
for edge in mp_face_mesh.FACEMESH_LEFT_EYEBROW:
|
||||
face_connection_spec[edge] = left_eyebrow_draw
|
||||
# for edge in mp_face_mesh.FACEMESH_LEFT_IRIS:
|
||||
# face_connection_spec[edge] = left_iris_draw
|
||||
for edge in mp_face_mesh.FACEMESH_RIGHT_EYE:
|
||||
face_connection_spec[edge] = right_eye_draw
|
||||
for edge in mp_face_mesh.FACEMESH_RIGHT_EYEBROW:
|
||||
face_connection_spec[edge] = right_eyebrow_draw
|
||||
# for edge in mp_face_mesh.FACEMESH_RIGHT_IRIS:
|
||||
# face_connection_spec[edge] = right_iris_draw
|
||||
for edge in mp_face_mesh.FACEMESH_LIPS:
|
||||
face_connection_spec[edge] = mouth_draw
|
||||
iris_landmark_spec = {468: right_iris_draw, 473: left_iris_draw}
|
||||
|
||||
|
||||
def draw_pupils(image, landmark_list, drawing_spec, halfwidth: int = 2):
|
||||
"""We have a custom function to draw the pupils because the mp.draw_landmarks method requires a parameter for all
|
||||
landmarks. Until our PR is merged into mediapipe, we need this separate method."""
|
||||
if len(image.shape) != 3:
|
||||
raise ValueError("Input image must be H,W,C.")
|
||||
image_rows, image_cols, image_channels = image.shape
|
||||
if image_channels != 3: # BGR channels
|
||||
raise ValueError("Input image must contain three channel bgr data.")
|
||||
for idx, landmark in enumerate(landmark_list.landmark):
|
||||
if (landmark.HasField("visibility") and landmark.visibility < 0.9) or (
|
||||
landmark.HasField("presence") and landmark.presence < 0.5
|
||||
):
|
||||
continue
|
||||
if landmark.x >= 1.0 or landmark.x < 0 or landmark.y >= 1.0 or landmark.y < 0:
|
||||
continue
|
||||
image_x = int(image_cols * landmark.x)
|
||||
image_y = int(image_rows * landmark.y)
|
||||
draw_color = None
|
||||
if isinstance(drawing_spec, Mapping):
|
||||
if drawing_spec.get(idx) is None:
|
||||
continue
|
||||
else:
|
||||
draw_color = drawing_spec[idx].color
|
||||
elif isinstance(drawing_spec, DrawingSpec):
|
||||
draw_color = drawing_spec.color
|
||||
image[
|
||||
image_y - halfwidth : image_y + halfwidth,
|
||||
image_x - halfwidth : image_x + halfwidth,
|
||||
:,
|
||||
] = draw_color
|
||||
|
||||
|
||||
def reverse_channels(image):
|
||||
"""Given a numpy array in RGB form, convert to BGR. Will also convert from BGR to RGB."""
|
||||
# im[:,:,::-1] is a neat hack to convert BGR to RGB by reversing the indexing order.
|
||||
# im[:,:,::[2,1,0]] would also work but makes a copy of the data.
|
||||
return image[:, :, ::-1]
|
||||
|
||||
|
||||
def generate_annotation(img_rgb, max_faces: int, min_confidence: float):
|
||||
"""
|
||||
Find up to 'max_faces' inside the provided input image.
|
||||
If min_face_size_pixels is provided and nonzero it will be used to filter faces that occupy less than this many
|
||||
pixels in the image.
|
||||
"""
|
||||
with mp_face_mesh.FaceMesh(
|
||||
static_image_mode=True,
|
||||
max_num_faces=max_faces,
|
||||
refine_landmarks=True,
|
||||
min_detection_confidence=min_confidence,
|
||||
) as facemesh:
|
||||
img_height, img_width, img_channels = img_rgb.shape
|
||||
assert img_channels == 3
|
||||
|
||||
results = facemesh.process(img_rgb).multi_face_landmarks
|
||||
|
||||
if results is None:
|
||||
print("No faces detected in controlnet image for Mediapipe face annotator.")
|
||||
return numpy.zeros_like(img_rgb)
|
||||
|
||||
# Filter faces that are too small
|
||||
filtered_landmarks = []
|
||||
for lm in results:
|
||||
landmarks = lm.landmark
|
||||
face_rect = [
|
||||
landmarks[0].x,
|
||||
landmarks[0].y,
|
||||
landmarks[0].x,
|
||||
landmarks[0].y,
|
||||
] # Left, up, right, down.
|
||||
for i in range(len(landmarks)):
|
||||
face_rect[0] = min(face_rect[0], landmarks[i].x)
|
||||
face_rect[1] = min(face_rect[1], landmarks[i].y)
|
||||
face_rect[2] = max(face_rect[2], landmarks[i].x)
|
||||
face_rect[3] = max(face_rect[3], landmarks[i].y)
|
||||
if min_face_size_pixels > 0:
|
||||
face_width = abs(face_rect[2] - face_rect[0])
|
||||
face_height = abs(face_rect[3] - face_rect[1])
|
||||
face_width_pixels = face_width * img_width
|
||||
face_height_pixels = face_height * img_height
|
||||
face_size = min(face_width_pixels, face_height_pixels)
|
||||
if face_size >= min_face_size_pixels:
|
||||
filtered_landmarks.append(lm)
|
||||
else:
|
||||
filtered_landmarks.append(lm)
|
||||
|
||||
# Annotations are drawn in BGR for some reason, but we don't need to flip a zero-filled image at the start.
|
||||
empty = numpy.zeros_like(img_rgb)
|
||||
|
||||
# Draw detected faces:
|
||||
for face_landmarks in filtered_landmarks:
|
||||
mp_drawing.draw_landmarks(
|
||||
empty,
|
||||
face_landmarks,
|
||||
connections=face_connection_spec.keys(),
|
||||
landmark_drawing_spec=None,
|
||||
connection_drawing_spec=face_connection_spec,
|
||||
)
|
||||
draw_pupils(empty, face_landmarks, iris_landmark_spec, 2)
|
||||
|
||||
# Flip BGR back to RGB.
|
||||
empty = reverse_channels(empty).copy()
|
||||
|
||||
return empty
|
|
@ -0,0 +1,44 @@
|
|||
from PIL import Image, ImageChops, ImageFilter
|
||||
|
||||
from .params import Point
|
||||
|
||||
|
||||
def mask_filter_none(
|
||||
mask: Image.Image, dims: Point, origin: Point, fill="white", **kw
|
||||
) -> Image.Image:
|
||||
width, height = dims
|
||||
|
||||
noise = Image.new("RGB", (width, height), fill)
|
||||
noise.paste(mask, origin)
|
||||
|
||||
return noise
|
||||
|
||||
|
||||
def mask_filter_gaussian_multiply(
|
||||
mask: Image.Image, dims: Point, origin: Point, rounds=3, **kw
|
||||
) -> Image.Image:
|
||||
"""
|
||||
Gaussian blur with multiply, source image centered on white canvas.
|
||||
"""
|
||||
noise = mask_filter_none(mask, dims, origin)
|
||||
|
||||
for _i in range(rounds):
|
||||
blur = noise.filter(ImageFilter.GaussianBlur(5))
|
||||
noise = ImageChops.multiply(noise, blur)
|
||||
|
||||
return noise
|
||||
|
||||
|
||||
def mask_filter_gaussian_screen(
|
||||
mask: Image.Image, dims: Point, origin: Point, rounds=3, **kw
|
||||
) -> Image.Image:
|
||||
"""
|
||||
Gaussian blur, source image centered on white canvas.
|
||||
"""
|
||||
noise = mask_filter_none(mask, dims, origin)
|
||||
|
||||
for _i in range(rounds):
|
||||
blur = noise.filter(ImageFilter.GaussianBlur(5))
|
||||
noise = ImageChops.screen(noise, blur)
|
||||
|
||||
return noise
|
|
@ -1,57 +1,14 @@
|
|||
from typing import Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
from numpy import random
|
||||
from PIL import Image, ImageChops, ImageFilter, ImageOps
|
||||
from PIL import Image, ImageFilter
|
||||
|
||||
from .params import Border, Point, Size
|
||||
from .params import Point
|
||||
|
||||
|
||||
def get_pixel_index(x: int, y: int, width: int) -> int:
|
||||
return (y * width) + x
|
||||
|
||||
|
||||
def mask_filter_none(
|
||||
mask: Image.Image, dims: Point, origin: Point, fill="white", **kw
|
||||
) -> Image.Image:
|
||||
width, height = dims
|
||||
|
||||
noise = Image.new("RGB", (width, height), fill)
|
||||
noise.paste(mask, origin)
|
||||
|
||||
return noise
|
||||
|
||||
|
||||
def mask_filter_gaussian_multiply(
|
||||
mask: Image.Image, dims: Point, origin: Point, rounds=3, **kw
|
||||
) -> Image.Image:
|
||||
"""
|
||||
Gaussian blur with multiply, source image centered on white canvas.
|
||||
"""
|
||||
noise = mask_filter_none(mask, dims, origin)
|
||||
|
||||
for _i in range(rounds):
|
||||
blur = noise.filter(ImageFilter.GaussianBlur(5))
|
||||
noise = ImageChops.multiply(noise, blur)
|
||||
|
||||
return noise
|
||||
|
||||
|
||||
def mask_filter_gaussian_screen(
|
||||
mask: Image.Image, dims: Point, origin: Point, rounds=3, **kw
|
||||
) -> Image.Image:
|
||||
"""
|
||||
Gaussian blur, source image centered on white canvas.
|
||||
"""
|
||||
noise = mask_filter_none(mask, dims, origin)
|
||||
|
||||
for _i in range(rounds):
|
||||
blur = noise.filter(ImageFilter.GaussianBlur(5))
|
||||
noise = ImageChops.screen(noise, blur)
|
||||
|
||||
return noise
|
||||
|
||||
|
||||
def noise_source_fill_edge(
|
||||
source: Image.Image, dims: Point, origin: Point, fill="white", **kw
|
||||
) -> Image.Image:
|
||||
|
@ -163,52 +120,3 @@ def noise_source_histogram(
|
|||
noise.putpixel((x, y), (noise_r[i], noise_g[i], noise_b[i]))
|
||||
|
||||
return noise
|
||||
|
||||
|
||||
# very loosely based on https://github.com/AUTOMATIC1111/stable-diffusion-webui/blob/master/scripts/outpainting_mk_2.py#L175-L232
|
||||
def expand_image(
|
||||
source: Image.Image,
|
||||
mask: Image.Image,
|
||||
expand: Border,
|
||||
fill="white",
|
||||
noise_source=noise_source_histogram,
|
||||
mask_filter=mask_filter_none,
|
||||
):
|
||||
full_width = expand.left + source.width + expand.right
|
||||
full_height = expand.top + source.height + expand.bottom
|
||||
|
||||
dims = (full_width, full_height)
|
||||
origin = (expand.left, expand.top)
|
||||
|
||||
full_source = Image.new("RGB", dims, fill)
|
||||
full_source.paste(source, origin)
|
||||
|
||||
# new mask pixels need to be filled with white so they will be replaced
|
||||
full_mask = mask_filter(mask, dims, origin, fill="white")
|
||||
full_noise = noise_source(source, dims, origin, fill=fill)
|
||||
full_noise = ImageChops.multiply(full_noise, full_mask)
|
||||
|
||||
full_source = Image.composite(full_noise, full_source, full_mask.convert("L"))
|
||||
|
||||
return (full_source, full_mask, full_noise, (full_width, full_height))
|
||||
|
||||
|
||||
def valid_image(
|
||||
image: Image.Image,
|
||||
min_dims: Union[Size, Tuple[int, int]] = [512, 512],
|
||||
max_dims: Union[Size, Tuple[int, int]] = [512, 512],
|
||||
) -> Image.Image:
|
||||
min_x, min_y = min_dims
|
||||
max_x, max_y = max_dims
|
||||
|
||||
if image.width > max_x or image.height > max_y:
|
||||
image = ImageOps.contain(image, (max_x, max_y))
|
||||
|
||||
if image.width < min_x or image.height < min_y:
|
||||
blank = Image.new(image.mode, (min_x, min_y), "black")
|
||||
blank.paste(image)
|
||||
image = blank
|
||||
|
||||
# check for square
|
||||
|
||||
return image
|
|
@ -0,0 +1,191 @@
|
|||
# https://github.com/ForserX/StableDiffusionUI/blob/main/data/repo/diffusion_scripts/controlnet_pipe.py
|
||||
|
||||
from logging import getLogger
|
||||
from os import path
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import torch
|
||||
import transformers
|
||||
from controlnet_aux import HEDdetector, MLSDdetector, OpenposeDetector
|
||||
from huggingface_hub import snapshot_download
|
||||
from PIL import Image
|
||||
|
||||
from ..server.context import ServerContext
|
||||
from .laion_face import generate_annotation
|
||||
from .utils import ade_palette
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
|
||||
def pil_to_cv2(source: Image.Image) -> np.ndarray:
|
||||
return cv2.cvtColor(np.array(source), cv2.COLOR_RGB2BGR)
|
||||
|
||||
|
||||
def filter_model_path(server: ServerContext, filter_name: str) -> str:
|
||||
return path.join(server.model_path, "filter", filter_name)
|
||||
|
||||
|
||||
def source_filter_gaussian():
|
||||
pass
|
||||
|
||||
|
||||
def source_filter_noise():
|
||||
pass
|
||||
|
||||
|
||||
def source_filter_face(
|
||||
server: ServerContext,
|
||||
source: Image.Image,
|
||||
max_faces: int = 1,
|
||||
min_confidence: float = 0.5,
|
||||
) -> Image.Image:
|
||||
logger.debug("running face detection on source image")
|
||||
|
||||
image = generate_annotation(pil_to_cv2(source), max_faces, min_confidence)
|
||||
image = Image.fromarray(image)
|
||||
|
||||
return image
|
||||
|
||||
|
||||
def source_filter_segment(server: ServerContext, source: Image.Image) -> Image.Image:
|
||||
logger.debug("running segmentation on source image")
|
||||
|
||||
openmm_model = snapshot_download(
|
||||
"openmmlab/upernet-convnext-small",
|
||||
allow_patterns=["*.bin", "*.json"],
|
||||
cache_dir=filter_model_path(server, "upernet-convnext-small"),
|
||||
)
|
||||
|
||||
image_processor = transformers.AutoImageProcessor.from_pretrained(openmm_model)
|
||||
image_segmentor = transformers.UperNetForSemanticSegmentation.from_pretrained(
|
||||
openmm_model
|
||||
)
|
||||
|
||||
in_img = source.convert("RGB")
|
||||
|
||||
pixel_values = image_processor(in_img, return_tensors="pt").pixel_values
|
||||
|
||||
with torch.no_grad():
|
||||
outputs = image_segmentor(pixel_values)
|
||||
|
||||
seg = image_processor.post_process_semantic_segmentation(
|
||||
outputs, target_sizes=[in_img.size[::-1]]
|
||||
)[0]
|
||||
|
||||
color_seg = np.zeros(
|
||||
(seg.shape[0], seg.shape[1], 3), dtype=np.uint8
|
||||
) # height, width, 3
|
||||
|
||||
palette = np.array(ade_palette())
|
||||
|
||||
for label, color in enumerate(palette):
|
||||
color_seg[seg == label, :] = color
|
||||
|
||||
color_seg = color_seg.astype(np.uint8)
|
||||
|
||||
image = Image.fromarray(color_seg)
|
||||
|
||||
return image
|
||||
|
||||
|
||||
def source_filter_mlsd(server: ServerContext, source: Image.Image) -> Image.Image:
|
||||
logger.debug("running MLSD on source image")
|
||||
|
||||
# TODO: get model name
|
||||
mlsd = MLSDdetector.from_pretrained(filter_model_path(server, "mlsd"))
|
||||
image = mlsd(source)
|
||||
|
||||
return image
|
||||
|
||||
|
||||
def source_filter_normal(server: ServerContext, source: Image.Image) -> Image.Image:
|
||||
logger.debug("running normal detection on source image")
|
||||
|
||||
depth_estimator = transformers.pipeline(
|
||||
"depth-estimation",
|
||||
model=snapshot_download(
|
||||
"Intel/dpt-hybrid-midas",
|
||||
allow_patterns=["*.bin", "*.json"],
|
||||
cache_dir=filter_model_path(server, "dpt-hybrid-midas"),
|
||||
),
|
||||
)
|
||||
|
||||
image = depth_estimator(source)["predicted_depth"][0]
|
||||
|
||||
image = image.numpy()
|
||||
|
||||
image_depth = image.copy()
|
||||
image_depth -= np.min(image_depth)
|
||||
image_depth /= np.max(image_depth)
|
||||
|
||||
bg_threhold = 0.4
|
||||
|
||||
x = cv2.Sobel(image, cv2.CV_32F, 1, 0, ksize=3)
|
||||
x[image_depth < bg_threhold] = 0
|
||||
|
||||
y = cv2.Sobel(image, cv2.CV_32F, 0, 1, ksize=3)
|
||||
y[image_depth < bg_threhold] = 0
|
||||
|
||||
z = np.ones_like(x) * np.pi * 2.0
|
||||
|
||||
image = np.stack([x, y, z], axis=2)
|
||||
image /= np.sum(image**2.0, axis=2, keepdims=True) ** 0.5
|
||||
image = (image * 127.5 + 127.5).clip(0, 255).astype(np.uint8)
|
||||
image = Image.fromarray(image)
|
||||
|
||||
return image
|
||||
|
||||
|
||||
def source_filter_hed(server: ServerContext, source: Image.Image) -> Image.Image:
|
||||
logger.debug("running HED detection on source image")
|
||||
|
||||
# TODO: get model name
|
||||
hed = HEDdetector.from_pretrained(filter_model_path(server, "hed"))
|
||||
image = hed(source)
|
||||
|
||||
return image
|
||||
|
||||
|
||||
def source_filter_scribble(server: ServerContext, source: Image.Image) -> Image.Image:
|
||||
logger.debug("running scribble detection on source image")
|
||||
|
||||
# TODO: get model name
|
||||
hed = HEDdetector.from_pretrained(filter_model_path(server, "hed"))
|
||||
image = hed(source, scribble=True)
|
||||
|
||||
return image
|
||||
|
||||
|
||||
def source_filter_depth(server: ServerContext, source: Image.Image) -> Image.Image:
|
||||
logger.debug("running depth detection on source image")
|
||||
depth_estimator = transformers.pipeline("depth-estimation")
|
||||
|
||||
image = depth_estimator(source)["depth"]
|
||||
image = np.array(image)
|
||||
image = image[:, :, None]
|
||||
image = np.concatenate([image, image, image], axis=2)
|
||||
image = Image.fromarray(image)
|
||||
|
||||
return image
|
||||
|
||||
|
||||
def source_filter_canny(
|
||||
server: ServerContext, source: Image.Image, low_threshold=100, high_threshold=200
|
||||
) -> Image.Image:
|
||||
logger.debug("running Canny detection on source image")
|
||||
|
||||
image = cv2.Canny(pil_to_cv2(source), low_threshold, high_threshold)
|
||||
image = Image.fromarray(image)
|
||||
|
||||
return image
|
||||
|
||||
|
||||
def source_filter_openpose(server: ServerContext, source: Image.Image) -> Image.Image:
|
||||
logger.debug("running OpenPose detection on source image")
|
||||
|
||||
# TODO: get model name
|
||||
model = OpenposeDetector.from_pretrained(filter_model_path(server, "openpose"))
|
||||
image = model(source)
|
||||
|
||||
return image
|
|
@ -0,0 +1,214 @@
|
|||
from typing import Tuple, Union
|
||||
|
||||
from PIL import Image, ImageChops, ImageOps
|
||||
|
||||
from .mask_filter import mask_filter_none
|
||||
from .noise_source import noise_source_histogram
|
||||
from .params import Border, Size
|
||||
|
||||
|
||||
# very loosely based on https://github.com/AUTOMATIC1111/stable-diffusion-webui/blob/master/scripts/outpainting_mk_2.py#L175-L232
|
||||
def expand_image(
|
||||
source: Image.Image,
|
||||
mask: Image.Image,
|
||||
expand: Border,
|
||||
fill="white",
|
||||
noise_source=noise_source_histogram,
|
||||
mask_filter=mask_filter_none,
|
||||
):
|
||||
full_width = expand.left + source.width + expand.right
|
||||
full_height = expand.top + source.height + expand.bottom
|
||||
|
||||
dims = (full_width, full_height)
|
||||
origin = (expand.left, expand.top)
|
||||
|
||||
full_source = Image.new("RGB", dims, fill)
|
||||
full_source.paste(source, origin)
|
||||
|
||||
# new mask pixels need to be filled with white so they will be replaced
|
||||
full_mask = mask_filter(mask, dims, origin, fill="white")
|
||||
full_noise = noise_source(source, dims, origin, fill=fill)
|
||||
full_noise = ImageChops.multiply(full_noise, full_mask)
|
||||
|
||||
full_source = Image.composite(full_noise, full_source, full_mask.convert("L"))
|
||||
|
||||
return (full_source, full_mask, full_noise, (full_width, full_height))
|
||||
|
||||
|
||||
def valid_image(
|
||||
image: Image.Image,
|
||||
min_dims: Union[Size, Tuple[int, int]] = [512, 512],
|
||||
max_dims: Union[Size, Tuple[int, int]] = [512, 512],
|
||||
) -> Image.Image:
|
||||
min_x, min_y = min_dims
|
||||
max_x, max_y = max_dims
|
||||
|
||||
if image.width > max_x or image.height > max_y:
|
||||
image = ImageOps.contain(image, (max_x, max_y))
|
||||
|
||||
if image.width < min_x or image.height < min_y:
|
||||
blank = Image.new(image.mode, (min_x, min_y), "black")
|
||||
blank.paste(image)
|
||||
image = blank
|
||||
|
||||
# check for square
|
||||
|
||||
return image
|
||||
|
||||
|
||||
# from https://github.com/ForserX/StableDiffusionUI/blob/main/data/repo/diffusion_scripts/modules/controlnet/palette.py
|
||||
# and others
|
||||
def ade_palette():
|
||||
"""ADE20K palette that maps each class to RGB values."""
|
||||
return [
|
||||
[120, 120, 120],
|
||||
[180, 120, 120],
|
||||
[6, 230, 230],
|
||||
[80, 50, 50],
|
||||
[4, 200, 3],
|
||||
[120, 120, 80],
|
||||
[140, 140, 140],
|
||||
[204, 5, 255],
|
||||
[230, 230, 230],
|
||||
[4, 250, 7],
|
||||
[224, 5, 255],
|
||||
[235, 255, 7],
|
||||
[150, 5, 61],
|
||||
[120, 120, 70],
|
||||
[8, 255, 51],
|
||||
[255, 6, 82],
|
||||
[143, 255, 140],
|
||||
[204, 255, 4],
|
||||
[255, 51, 7],
|
||||
[204, 70, 3],
|
||||
[0, 102, 200],
|
||||
[61, 230, 250],
|
||||
[255, 6, 51],
|
||||
[11, 102, 255],
|
||||
[255, 7, 71],
|
||||
[255, 9, 224],
|
||||
[9, 7, 230],
|
||||
[220, 220, 220],
|
||||
[255, 9, 92],
|
||||
[112, 9, 255],
|
||||
[8, 255, 214],
|
||||
[7, 255, 224],
|
||||
[255, 184, 6],
|
||||
[10, 255, 71],
|
||||
[255, 41, 10],
|
||||
[7, 255, 255],
|
||||
[224, 255, 8],
|
||||
[102, 8, 255],
|
||||
[255, 61, 6],
|
||||
[255, 194, 7],
|
||||
[255, 122, 8],
|
||||
[0, 255, 20],
|
||||
[255, 8, 41],
|
||||
[255, 5, 153],
|
||||
[6, 51, 255],
|
||||
[235, 12, 255],
|
||||
[160, 150, 20],
|
||||
[0, 163, 255],
|
||||
[140, 140, 140],
|
||||
[250, 10, 15],
|
||||
[20, 255, 0],
|
||||
[31, 255, 0],
|
||||
[255, 31, 0],
|
||||
[255, 224, 0],
|
||||
[153, 255, 0],
|
||||
[0, 0, 255],
|
||||
[255, 71, 0],
|
||||
[0, 235, 255],
|
||||
[0, 173, 255],
|
||||
[31, 0, 255],
|
||||
[11, 200, 200],
|
||||
[255, 82, 0],
|
||||
[0, 255, 245],
|
||||
[0, 61, 255],
|
||||
[0, 255, 112],
|
||||
[0, 255, 133],
|
||||
[255, 0, 0],
|
||||
[255, 163, 0],
|
||||
[255, 102, 0],
|
||||
[194, 255, 0],
|
||||
[0, 143, 255],
|
||||
[51, 255, 0],
|
||||
[0, 82, 255],
|
||||
[0, 255, 41],
|
||||
[0, 255, 173],
|
||||
[10, 0, 255],
|
||||
[173, 255, 0],
|
||||
[0, 255, 153],
|
||||
[255, 92, 0],
|
||||
[255, 0, 255],
|
||||
[255, 0, 245],
|
||||
[255, 0, 102],
|
||||
[255, 173, 0],
|
||||
[255, 0, 20],
|
||||
[255, 184, 184],
|
||||
[0, 31, 255],
|
||||
[0, 255, 61],
|
||||
[0, 71, 255],
|
||||
[255, 0, 204],
|
||||
[0, 255, 194],
|
||||
[0, 255, 82],
|
||||
[0, 10, 255],
|
||||
[0, 112, 255],
|
||||
[51, 0, 255],
|
||||
[0, 194, 255],
|
||||
[0, 122, 255],
|
||||
[0, 255, 163],
|
||||
[255, 153, 0],
|
||||
[0, 255, 10],
|
||||
[255, 112, 0],
|
||||
[143, 255, 0],
|
||||
[82, 0, 255],
|
||||
[163, 255, 0],
|
||||
[255, 235, 0],
|
||||
[8, 184, 170],
|
||||
[133, 0, 255],
|
||||
[0, 255, 92],
|
||||
[184, 0, 255],
|
||||
[255, 0, 31],
|
||||
[0, 184, 255],
|
||||
[0, 214, 255],
|
||||
[255, 0, 112],
|
||||
[92, 255, 0],
|
||||
[0, 224, 255],
|
||||
[112, 224, 255],
|
||||
[70, 184, 160],
|
||||
[163, 0, 255],
|
||||
[153, 0, 255],
|
||||
[71, 255, 0],
|
||||
[255, 0, 163],
|
||||
[255, 204, 0],
|
||||
[255, 0, 143],
|
||||
[0, 255, 235],
|
||||
[133, 255, 0],
|
||||
[255, 0, 235],
|
||||
[245, 0, 255],
|
||||
[255, 0, 122],
|
||||
[255, 245, 0],
|
||||
[10, 190, 212],
|
||||
[214, 255, 0],
|
||||
[0, 204, 255],
|
||||
[20, 0, 255],
|
||||
[255, 255, 0],
|
||||
[0, 153, 255],
|
||||
[0, 41, 255],
|
||||
[0, 255, 204],
|
||||
[41, 0, 255],
|
||||
[41, 255, 0],
|
||||
[173, 0, 255],
|
||||
[0, 245, 255],
|
||||
[71, 0, 255],
|
||||
[122, 0, 255],
|
||||
[0, 255, 184],
|
||||
[0, 92, 255],
|
||||
[184, 255, 0],
|
||||
[0, 133, 255],
|
||||
[255, 214, 0],
|
||||
[25, 194, 194],
|
||||
[102, 255, 0],
|
||||
[92, 0, 255],
|
||||
]
|
|
@ -42,6 +42,7 @@ from .load import (
|
|||
get_mask_filters,
|
||||
get_network_models,
|
||||
get_noise_sources,
|
||||
get_source_filters,
|
||||
get_upscaling_models,
|
||||
)
|
||||
from .params import (
|
||||
|
@ -151,6 +152,9 @@ def img2img(server: ServerContext, pool: DevicePoolExecutor):
|
|||
|
||||
device, params, size = pipeline_from_request(server, "img2img")
|
||||
upscale = upscale_from_request()
|
||||
source_filter = get_from_list(
|
||||
request.args, "sourceFilter", list(get_source_filters().keys())
|
||||
)
|
||||
|
||||
strength = get_and_clamp_float(
|
||||
request.args,
|
||||
|
@ -160,6 +164,7 @@ def img2img(server: ServerContext, pool: DevicePoolExecutor):
|
|||
get_config_value("strength", "min"),
|
||||
)
|
||||
|
||||
# TODO: add filtered source to outputs
|
||||
output = make_output_name(server, "img2img", params, size, extras=[strength])
|
||||
job_name = output[0]
|
||||
logger.info("img2img job queued for: %s", job_name)
|
||||
|
@ -175,6 +180,7 @@ def img2img(server: ServerContext, pool: DevicePoolExecutor):
|
|||
source,
|
||||
strength,
|
||||
needs_device=device,
|
||||
source_filter=source_filter,
|
||||
)
|
||||
|
||||
return jsonify(json_params(output, params, size, upscale=upscale))
|
||||
|
|
|
@ -19,6 +19,17 @@ from ..image import ( # mask filters; noise sources
|
|||
noise_source_histogram,
|
||||
noise_source_normal,
|
||||
noise_source_uniform,
|
||||
source_filter_canny,
|
||||
source_filter_depth,
|
||||
source_filter_face,
|
||||
source_filter_gaussian,
|
||||
source_filter_hed,
|
||||
source_filter_mlsd,
|
||||
source_filter_noise,
|
||||
source_filter_normal,
|
||||
source_filter_openpose,
|
||||
source_filter_scribble,
|
||||
source_filter_segment,
|
||||
)
|
||||
from ..models.meta import NetworkModel
|
||||
from ..params import DeviceParams
|
||||
|
@ -32,11 +43,15 @@ logger = getLogger(__name__)
|
|||
config_params: Dict[str, Dict[str, Union[float, int, str]]] = {}
|
||||
|
||||
# pipeline params
|
||||
platform_providers = {
|
||||
"cpu": "CPUExecutionProvider",
|
||||
"cuda": "CUDAExecutionProvider",
|
||||
"directml": "DmlExecutionProvider",
|
||||
"rocm": "ROCMExecutionProvider",
|
||||
highres_methods = [
|
||||
"bilinear",
|
||||
"lanczos",
|
||||
"upscale",
|
||||
]
|
||||
mask_filters = {
|
||||
"none": mask_filter_none,
|
||||
"gaussian-multiply": mask_filter_gaussian_multiply,
|
||||
"gaussian-screen": mask_filter_gaussian_screen,
|
||||
}
|
||||
noise_sources = {
|
||||
"fill-edge": noise_source_fill_edge,
|
||||
|
@ -46,17 +61,25 @@ noise_sources = {
|
|||
"normal": noise_source_normal,
|
||||
"uniform": noise_source_uniform,
|
||||
}
|
||||
mask_filters = {
|
||||
"none": mask_filter_none,
|
||||
"gaussian-multiply": mask_filter_gaussian_multiply,
|
||||
"gaussian-screen": mask_filter_gaussian_screen,
|
||||
platform_providers = {
|
||||
"cpu": "CPUExecutionProvider",
|
||||
"cuda": "CUDAExecutionProvider",
|
||||
"directml": "DmlExecutionProvider",
|
||||
"rocm": "ROCMExecutionProvider",
|
||||
}
|
||||
source_filters = {
|
||||
"gaussian": source_filter_gaussian,
|
||||
"noise": source_filter_noise,
|
||||
"face": source_filter_face,
|
||||
"segment": source_filter_segment,
|
||||
"mlsd": source_filter_mlsd,
|
||||
"normal": source_filter_normal,
|
||||
"hed": source_filter_hed,
|
||||
"scribble": source_filter_scribble,
|
||||
"depth": source_filter_depth,
|
||||
"canny": source_filter_canny,
|
||||
"openpose": source_filter_openpose,
|
||||
}
|
||||
highres_methods = [
|
||||
"bilinear",
|
||||
"lanczos",
|
||||
"upscale",
|
||||
]
|
||||
|
||||
|
||||
# Available ORT providers
|
||||
available_platforms: List[DeviceParams] = []
|
||||
|
@ -111,6 +134,10 @@ def get_noise_sources():
|
|||
return noise_sources
|
||||
|
||||
|
||||
def get_source_filters():
|
||||
return source_filters
|
||||
|
||||
|
||||
def get_config_value(key: str, subkey: str = "default", default=None):
|
||||
return config_params.get(key, {}).get(subkey, default)
|
||||
|
||||
|
|
|
@ -5,6 +5,7 @@ protobuf<4,>=3.20.2
|
|||
### AI packages ###
|
||||
accelerate
|
||||
coloredlogs
|
||||
controlnet_aux
|
||||
diffusers
|
||||
onnx
|
||||
# onnxruntime has many platform-specific packages
|
||||
|
|
Loading…
Reference in New Issue