1
0
Fork 0

add resizing logic to preparation script

This commit is contained in:
Sean Sube 2023-09-03 16:09:44 -05:00
parent 2047f7d8cf
commit 4fd889180b
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
1 changed file with 35 additions and 20 deletions

View File

@ -1,6 +1,6 @@
from argparse import ArgumentParser from argparse import ArgumentParser
from typing import Any, List, Tuple from typing import Any, List, Tuple
from PIL.Image import Image, open as pil_open from PIL.Image import Image, open as pil_open, merge, Resampling
from torchvision.transforms import RandomCrop, Resize, Normalize, ToTensor from torchvision.transforms import RandomCrop, Resize, Normalize, ToTensor
from os import environ, path from os import environ, path
from logging import getLogger from logging import getLogger
@ -37,6 +37,7 @@ def parse_args():
parser.add_argument("--crops", type=int) parser.add_argument("--crops", type=int)
parser.add_argument("--height", type=int, default=512) parser.add_argument("--height", type=int, default=512)
parser.add_argument("--width", type=int, default=512) parser.add_argument("--width", type=int, default=512)
parser.add_argument("--scale", type=float, default=1.5)
parser.add_argument("--threshold", type=float, default=0.75) parser.add_argument("--threshold", type=float, default=0.75)
return parser.parse_args() return parser.parse_args()
@ -51,9 +52,17 @@ def load_images(root: str) -> List[Tuple[str, Image]]:
prefix, _ext = path.splitext(name) prefix, _ext = path.splitext(name)
prefix = path.basename(prefix) prefix = path.basename(prefix)
try:
image = pil_open(name) image = pil_open(name)
image = ImageOps.exif_transpose(image) image = ImageOps.exif_transpose(image)
if image.mode == "L":
image = merge("RGB", (image, image, image))
logger.info("adding %s to sources", name)
images.append((prefix, image)) images.append((prefix, image))
except:
logger.exception("error loading image")
return images return images
@ -66,10 +75,18 @@ def save_images(root: str, images: List[Tuple[str, Image]]):
logger.info("saved %s images to %s", len(images), root) logger.info("saved %s images to %s", len(images), root)
def resize_images(images: List[Tuple[str, Image]], size: Tuple[int, int], min_scale: float) -> List[Tuple[str, Image]]:
    """Resize every image so that it just covers `size`, keeping aspect ratio.

    For each image the scale is the smaller of `width / size[0]` and
    `height / size[1]`. Dividing both dimensions by that value brings the
    limiting dimension to the requested size and leaves the other one at
    least as large as requested.

    Args:
        images: `(name, image)` pairs to resize.
        size: `(width, height)` that every result must cover.
        min_scale: images whose scale is below this value are dropped with a
            warning instead of being resized. A scale under 1.0 means the
            image would have to be enlarged to cover `size`.

    Returns:
        `(name, resized_image)` pairs for the images that were kept, in
        their original order.
    """
    results = []
    for name, image in images:
        scale = min(image.width / size[0], image.height / size[1])
        if scale <= 0:
            # A zero-sized image would otherwise divide by zero below.
            logger.warning("image %s is too small: %s", name, image.size)
            continue

        # Truncating with int() can land one pixel short of the target when
        # the division is off by a float epsilon (767.999... -> 767). Round
        # to the nearest pixel and clamp so the result always covers `size`.
        resize = (
            max(size[0], round(image.width / scale)),
            max(size[1], round(image.height / scale)),
        )
        logger.info("resize %s from %s to %s (%s scale)", name, image.size, resize, scale)

        if scale < min_scale:
            # Report the source dimensions: they are what make it too small.
            logger.warning("image %s is too small: %s", name, image.size)
            continue

        results.append((name, image.resize(resize, Resampling.LANCZOS)))

    return results
@ -97,25 +114,13 @@ def remove_duplicates(sources: List[Tuple[str, Image]], threshold: float, vector
score = similarity(source_vector, cache_vector) score = similarity(source_vector, cache_vector)
logger.debug("similarity score for %s: %s", name, score) logger.debug("similarity score for %s: %s", name, score)
if score > threshold: if score.max() > threshold:
cached = True cached = True
if cached == False: if cached == False:
vector_cache.append(source_vector) vector_cache.append(source_vector)
results.append((name, source)) results.append((name, source))
# count = len(sources)
# for i in range(count):
# if i not in duplicates:
# for j in range(i + 1, count):
# if j not in duplicates and i != j:
# score = similarity(vectors[i], vectors[j])
# logger.info("similarity score between %s and %s: %s", i, j, score)
# if score > threshold:
# duplicates.add(j)
logger.info("keeping %s of %s images", len(results), len(sources)) logger.info("keeping %s of %s images", len(results), len(sources))
return results return results
@ -126,6 +131,12 @@ def crop_images(sources: List[Tuple[str, Image]], size: Tuple[int, int], crops:
results = [] results = []
for name, source in sources: for name, source in sources:
logger.info("cropping %s", name)
if source.width < size[0] or source.height < size[1]:
logger.info("a small image leaked into the set: %s", name)
continue
for i in range(crops): for i in range(crops):
results.append((f"{name}_{i}", transform(source))) results.append((f"{name}_{i}", transform(source)))
@ -134,11 +145,15 @@ def crop_images(sources: List[Tuple[str, Image]], size: Tuple[int, int], crops:
if __name__ == "__main__": if __name__ == "__main__":
args = parse_args() args = parse_args()
size = (int(args.width * args.scale), int(args.height * args.scale))
# load unique sources # load unique sources
sources = load_images(args.src) sources = load_images(args.src)
sources = resize_images(sources, (args.width * 2, args.height * 2)) logger.info("loaded %s source images, resizing", len(sources))
sources = resize_images(sources, size, 0.5)
logger.info("resized images, removing duplicates")
sources = remove_duplicates(sources, args.threshold, []) sources = remove_duplicates(sources, args.threshold, [])
logger.info("removed duplicated, kept %s source images", len(sources))
# randomly crop # randomly crop
cache = [] cache = []