1
0
Fork 0

add helper to remove NaN values in UNet output

This commit is contained in:
Sean Sube 2023-07-16 12:03:14 -05:00
parent 8d6ea63330
commit 01c7392a61
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
1 changed files with 13 additions and 0 deletions

View File

@ -412,3 +412,16 @@ def pop_random(list: List[str]) -> str:
i = random.randrange(len(list))
list[i], list[-1] = list[-1], list[i]
return list.pop()
def repair_nan(tile: np.ndarray) -> np.ndarray:
flat_tile = tile.flatten()
flat_mask = np.isnan(flat_tile)
if np.any(flat_mask):
logger.warning("repairing NaN values in image")
indices = np.where(~flat_mask, np.arange(flat_mask.shape[0]), 0)
np.maximum.accumulate(indices, out=indices)
return np.reshape(flat_tile[indices], tile.shape)
else:
return tile