From 4fd889180bb62d8d1d232c870e057b4c450e5f42 Mon Sep 17 00:00:00 2001
From: Sean Sube
Date: Sun, 3 Sep 2023 16:09:44 -0500
Subject: [PATCH] add resizing logic to preparation script

---
 api/scripts/prepare-training.py | 55 +++++++++++++++++++++++------------
 1 file changed, 35 insertions(+), 20 deletions(-)

diff --git a/api/scripts/prepare-training.py b/api/scripts/prepare-training.py
index 22f2309e..7f3fa905 100644
--- a/api/scripts/prepare-training.py
+++ b/api/scripts/prepare-training.py
@@ -1,6 +1,6 @@
 from argparse import ArgumentParser
 from typing import Any, List, Tuple
-from PIL.Image import Image, open as pil_open
+from PIL.Image import Image, open as pil_open, merge, Resampling
 from torchvision.transforms import RandomCrop, Resize, Normalize, ToTensor
 from os import environ, path
 from logging import getLogger
@@ -37,6 +37,7 @@ def parse_args():
     parser.add_argument("--crops", type=int)
     parser.add_argument("--height", type=int, default=512)
     parser.add_argument("--width", type=int, default=512)
+    parser.add_argument("--scale", type=float, default=1.5)
     parser.add_argument("--threshold", type=float, default=0.75)

     return parser.parse_args()
@@ -51,9 +52,17 @@ def load_images(root: str) -> List[Tuple[str, Image]]:
             prefix, _ext = path.splitext(name)
             prefix = path.basename(prefix)

-            image = pil_open(name)
-            image = ImageOps.exif_transpose(image)
-            images.append((prefix, image))
+            try:
+                image = pil_open(name)
+                image = ImageOps.exif_transpose(image)
+
+                if image.mode == "L":
+                    image = merge("RGB", (image, image, image))
+
+                logger.info("adding %s to sources", name)
+                images.append((prefix, image))
+            except Exception:
+                logger.exception("error loading image")

     return images

@@ -66,10 +75,18 @@ def save_images(root: str, images: List[Tuple[str, Image]]):
     logger.info("saved %s images to %s", len(images), root)


-def resize_images(images: List[Tuple[str, Image]], size: Tuple[int, int]) -> List[Tuple[str, Image]]:
+def resize_images(images: List[Tuple[str, Image]], size: Tuple[int, int], min_scale: float) -> List[Tuple[str, Image]]:
     results = []
     for name, image in images:
-        results.append((name, ImageOps.contain(image, size)))
+        scale = min(image.width / size[0], image.height / size[1])
+        resize = (int(image.width / scale), int(image.height / scale))
+        logger.info("resize %s from %s to %s (%s scale)", name, image.size, resize, scale)
+
+        if scale < min_scale:
+            logger.warning("image %s is too small: %s", name, resize)
+            continue
+
+        results.append((name, image.resize(resize, Resampling.LANCZOS)))

     return results

@@ -97,25 +114,13 @@ def remove_duplicates(sources: List[Tuple[str, Image]], threshold: float, vector
                 score = similarity(source_vector, cache_vector)
                 logger.debug("similarity score for %s: %s", name, score)

-                if score > threshold:
+                if score.max() > threshold:
                     cached = True

         if cached == False:
             vector_cache.append(source_vector)
             results.append((name, source))

-
-    # count = len(sources)
-    # for i in range(count):
-    #     if i not in duplicates:
-    #         for j in range(i + 1, count):
-    #             if j not in duplicates and i != j:
-    #                 score = similarity(vectors[i], vectors[j])
-    #                 logger.info("similarity score between %s and %s: %s", i, j, score)
-
-    #                 if score > threshold:
-    #                     duplicates.add(j)
-
     logger.info("keeping %s of %s images", len(results), len(sources))
     return results

@@ -126,6 +131,12 @@ def crop_images(sources: List[Tuple[str, Image]], size: Tuple[int, int], crops:

     results = []
     for name, source in sources:
+        logger.info("cropping %s", name)
+
+        if source.width < size[0] or source.height < size[1]:
+            logger.info("a small image leaked into the set: %s", name)
+            continue
+
         for i in range(crops):
             results.append((f"{name}_{i}", transform(source)))

@@ -134,11 +145,15 @@ def crop_images(sources: List[Tuple[str, Image]], size: Tuple[int, int], crops:

 if __name__ == "__main__":
     args = parse_args()
+    size = (int(args.width * args.scale), int(args.height * args.scale))

     # load unique sources
     sources = load_images(args.src)
-    sources = resize_images(sources, (args.width * 2, args.height * 2))
+    logger.info("loaded %s source images, resizing", len(sources))
+    sources = resize_images(sources, size, 0.5)
+    logger.info("resized images, removing duplicates")
     sources = remove_duplicates(sources, args.threshold, [])
+    logger.info("removed duplicates, kept %s source images", len(sources))

     # randomly crop
     cache = []