add resizing logic to preparation script
This commit is contained in:
parent
2047f7d8cf
commit
4fd889180b
|
@ -1,6 +1,6 @@
|
||||||
from argparse import ArgumentParser
|
from argparse import ArgumentParser
|
||||||
from typing import Any, List, Tuple
|
from typing import Any, List, Tuple
|
||||||
from PIL.Image import Image, open as pil_open
|
from PIL.Image import Image, open as pil_open, merge, Resampling
|
||||||
from torchvision.transforms import RandomCrop, Resize, Normalize, ToTensor
|
from torchvision.transforms import RandomCrop, Resize, Normalize, ToTensor
|
||||||
from os import environ, path
|
from os import environ, path
|
||||||
from logging import getLogger
|
from logging import getLogger
|
||||||
|
@ -37,6 +37,7 @@ def parse_args():
|
||||||
parser.add_argument("--crops", type=int)
|
parser.add_argument("--crops", type=int)
|
||||||
parser.add_argument("--height", type=int, default=512)
|
parser.add_argument("--height", type=int, default=512)
|
||||||
parser.add_argument("--width", type=int, default=512)
|
parser.add_argument("--width", type=int, default=512)
|
||||||
|
parser.add_argument("--scale", type=float, default=1.5)
|
||||||
parser.add_argument("--threshold", type=float, default=0.75)
|
parser.add_argument("--threshold", type=float, default=0.75)
|
||||||
|
|
||||||
return parser.parse_args()
|
return parser.parse_args()
|
||||||
|
@ -51,9 +52,17 @@ def load_images(root: str) -> List[Tuple[str, Image]]:
|
||||||
prefix, _ext = path.splitext(name)
|
prefix, _ext = path.splitext(name)
|
||||||
prefix = path.basename(prefix)
|
prefix = path.basename(prefix)
|
||||||
|
|
||||||
image = pil_open(name)
|
try:
|
||||||
image = ImageOps.exif_transpose(image)
|
image = pil_open(name)
|
||||||
images.append((prefix, image))
|
image = ImageOps.exif_transpose(image)
|
||||||
|
|
||||||
|
if image.mode == "L":
|
||||||
|
image = merge("RGB", (image, image, image))
|
||||||
|
|
||||||
|
logger.info("adding %s to sources", name)
|
||||||
|
images.append((prefix, image))
|
||||||
|
except:
|
||||||
|
logger.exception("error loading image")
|
||||||
|
|
||||||
return images
|
return images
|
||||||
|
|
||||||
|
@ -66,10 +75,18 @@ def save_images(root: str, images: List[Tuple[str, Image]]):
|
||||||
logger.info("saved %s images to %s", len(images), root)
|
logger.info("saved %s images to %s", len(images), root)
|
||||||
|
|
||||||
|
|
||||||
def resize_images(images: List[Tuple[str, Image]], size: Tuple[int, int]) -> List[Tuple[str, Image]]:
|
def resize_images(images: List[Tuple[str, Image]], size: Tuple[int, int], min_scale: float) -> List[Tuple[str, Image]]:
|
||||||
results = []
|
results = []
|
||||||
for name, image in images:
|
for name, image in images:
|
||||||
results.append((name, ImageOps.contain(image, size)))
|
scale = min(image.width / size[0], image.height / size[1])
|
||||||
|
resize = (int(image.width / scale), int(image.height / scale))
|
||||||
|
logger.info("resize %s from %s to %s (%s scale)", name, image.size, resize, scale)
|
||||||
|
|
||||||
|
if scale < min_scale:
|
||||||
|
logger.warning("image %s is too small: %s", name, resize)
|
||||||
|
continue
|
||||||
|
|
||||||
|
results.append((name, image.resize(resize, Resampling.LANCZOS)))
|
||||||
|
|
||||||
return results
|
return results
|
||||||
|
|
||||||
|
@ -97,25 +114,13 @@ def remove_duplicates(sources: List[Tuple[str, Image]], threshold: float, vector
|
||||||
score = similarity(source_vector, cache_vector)
|
score = similarity(source_vector, cache_vector)
|
||||||
logger.debug("similarity score for %s: %s", name, score)
|
logger.debug("similarity score for %s: %s", name, score)
|
||||||
|
|
||||||
if score > threshold:
|
if score.max() > threshold:
|
||||||
cached = True
|
cached = True
|
||||||
|
|
||||||
if cached == False:
|
if cached == False:
|
||||||
vector_cache.append(source_vector)
|
vector_cache.append(source_vector)
|
||||||
results.append((name, source))
|
results.append((name, source))
|
||||||
|
|
||||||
|
|
||||||
# count = len(sources)
|
|
||||||
# for i in range(count):
|
|
||||||
# if i not in duplicates:
|
|
||||||
# for j in range(i + 1, count):
|
|
||||||
# if j not in duplicates and i != j:
|
|
||||||
# score = similarity(vectors[i], vectors[j])
|
|
||||||
# logger.info("similarity score between %s and %s: %s", i, j, score)
|
|
||||||
|
|
||||||
# if score > threshold:
|
|
||||||
# duplicates.add(j)
|
|
||||||
|
|
||||||
logger.info("keeping %s of %s images", len(results), len(sources))
|
logger.info("keeping %s of %s images", len(results), len(sources))
|
||||||
|
|
||||||
return results
|
return results
|
||||||
|
@ -126,6 +131,12 @@ def crop_images(sources: List[Tuple[str, Image]], size: Tuple[int, int], crops:
|
||||||
results = []
|
results = []
|
||||||
|
|
||||||
for name, source in sources:
|
for name, source in sources:
|
||||||
|
logger.info("cropping %s", name)
|
||||||
|
|
||||||
|
if source.width < size[0] or source.height < size[1]:
|
||||||
|
logger.info("a small image leaked into the set: %s", name)
|
||||||
|
continue
|
||||||
|
|
||||||
for i in range(crops):
|
for i in range(crops):
|
||||||
results.append((f"{name}_{i}", transform(source)))
|
results.append((f"{name}_{i}", transform(source)))
|
||||||
|
|
||||||
|
@ -134,11 +145,15 @@ def crop_images(sources: List[Tuple[str, Image]], size: Tuple[int, int], crops:
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
args = parse_args()
|
args = parse_args()
|
||||||
|
size = (int(args.width * args.scale), int(args.height * args.scale))
|
||||||
|
|
||||||
# load unique sources
|
# load unique sources
|
||||||
sources = load_images(args.src)
|
sources = load_images(args.src)
|
||||||
sources = resize_images(sources, (args.width * 2, args.height * 2))
|
logger.info("loaded %s source images, resizing", len(sources))
|
||||||
|
sources = resize_images(sources, size, 0.5)
|
||||||
|
logger.info("resized images, removing duplicates")
|
||||||
sources = remove_duplicates(sources, args.threshold, [])
|
sources = remove_duplicates(sources, args.threshold, [])
|
||||||
|
logger.info("removed duplicated, kept %s source images", len(sources))
|
||||||
|
|
||||||
# randomly crop
|
# randomly crop
|
||||||
cache = []
|
cache = []
|
||||||
|
|
Loading…
Reference in New Issue