add helper to remove NaN values in UNet output
This commit is contained in:
parent
8d6ea63330
commit
01c7392a61
|
@ -412,3 +412,16 @@ def pop_random(list: List[str]) -> str:
|
|||
i = random.randrange(len(list))
|
||||
list[i], list[-1] = list[-1], list[i]
|
||||
return list.pop()
|
||||
|
||||
|
||||
def repair_nan(tile: np.ndarray) -> np.ndarray:
|
||||
flat_tile = tile.flatten()
|
||||
flat_mask = np.isnan(flat_tile)
|
||||
|
||||
if np.any(flat_mask):
|
||||
logger.warning("repairing NaN values in image")
|
||||
indices = np.where(~flat_mask, np.arange(flat_mask.shape[0]), 0)
|
||||
np.maximum.accumulate(indices, out=indices)
|
||||
return np.reshape(flat_tile[indices], tile.shape)
|
||||
else:
|
||||
return tile
|
||||
|
|
Loading…
Reference in New Issue