add helper to remove NaN values in UNet output
This commit is contained in:
parent
8d6ea63330
commit
01c7392a61
|
@ -412,3 +412,16 @@ def pop_random(list: List[str]) -> str:
|
||||||
i = random.randrange(len(list))
|
i = random.randrange(len(list))
|
||||||
list[i], list[-1] = list[-1], list[i]
|
list[i], list[-1] = list[-1], list[i]
|
||||||
return list.pop()
|
return list.pop()
|
||||||
|
|
||||||
|
|
||||||
|
def repair_nan(tile: np.ndarray) -> np.ndarray:
|
||||||
|
flat_tile = tile.flatten()
|
||||||
|
flat_mask = np.isnan(flat_tile)
|
||||||
|
|
||||||
|
if np.any(flat_mask):
|
||||||
|
logger.warning("repairing NaN values in image")
|
||||||
|
indices = np.where(~flat_mask, np.arange(flat_mask.shape[0]), 0)
|
||||||
|
np.maximum.accumulate(indices, out=indices)
|
||||||
|
return np.reshape(flat_tile[indices], tile.shape)
|
||||||
|
else:
|
||||||
|
return tile
|
||||||
|
|
Loading…
Reference in New Issue