import os
import abc
import warnings
import threading
from typing import Any, Self, Generic, TypeVar
from datetime import UTC, datetime
from functools import cached_property
from multiprocessing.shared_memory import SharedMemory
import numba
import numpy as np
from astropy.wcs import WCS
from dateutil.parser import parse as parse_datetime
from ndcube import NDCube
from numpy.typing import ArrayLike
from scipy.ndimage import gaussian_filter
from scipy.signal import convolve2d
from punchbowl.data import load_ndcube_from_fits, write_ndcube_to_fits
from punchbowl.exceptions import InvalidDataError, MissingTimezoneWarning
from punchbowl.prefect import punch_task
[docs]
def validate_image_is_square(image: np.ndarray) -> None:
    """
    Raise if ``image`` is not a square, two-dimensional numpy array.

    Parameters
    ----------
    image : np.ndarray
        array to validate

    Raises
    ------
    TypeError
        if ``image`` is not an ``np.ndarray``
    ValueError
        if ``image`` is not 2-D, or if its two dimensions have different lengths
    """
    if not isinstance(image, np.ndarray):
        msg = f"Image must be of type np.ndarray. Found: {type(image)}."
        raise TypeError(msg)
    n_dims = image.ndim
    if n_dims != 2:
        msg = f"Image must be a 2-D array. Input has {n_dims} dimensions."
        raise ValueError(msg)
    n_rows, n_cols = image.shape
    if n_rows != n_cols:
        msg = f"Image must be square. Found: {image.shape}."
        raise ValueError(msg)
[docs]
def load_mask_file(path: str) -> np.ndarray:
    """
    Load a PUNCH instrument mask from a packed-bit ``.bin`` file.

    The file holds 2048 * 2048 bits packed with ``np.packbits`` from the transposed mask. They are unpacked, reshaped
    to 2048x2048, transposed back, and returned as a boolean array.

    To write a .bin file that this function can read, use:
        with open('PUNCH_L2_MS1_20250101000000_v0j.bin', 'wb') as f:
            np.packbits(np.isfinite(mask).T).tofile(f)

    Parameters
    ----------
    path : str
        path to the packed mask file

    Returns
    -------
    np.ndarray
        boolean mask of shape (2048, 2048)
    """
    with open(path, "rb") as handle:
        packed = np.frombuffer(handle.read(), dtype=np.uint8)
    unpacked = np.unpackbits(packed)
    return unpacked.reshape((2048, 2048)).transpose().astype(bool)
[docs]
@punch_task
def output_image_task(data: NDCube, output_filename: str) -> None:
    """
    Prefect task to write an image to disk.

    Any missing parent directories of ``output_filename`` are created before writing.

    Parameters
    ----------
    data : NDCube
        data that is to be written
    output_filename : str
        where to write the file out

    Returns
    -------
    None
    """
    output_dir = os.path.dirname(output_filename)
    if output_dir:
        # exist_ok=True avoids a race between concurrently running tasks: with a separate isdir() check, another task
        # could create the directory between the check and makedirs(), causing a spurious FileExistsError.
        os.makedirs(output_dir, exist_ok=True)
    write_ndcube_to_fits(data, output_filename)
[docs]
@punch_task(tags=["image_loader"])
def load_image_task(input_filename: str, include_provenance: bool = True, include_uncertainty: bool = True,
                    dtype: type = float) -> NDCube:
    """
    Prefect task that reads a FITS file into an NDCube for processing.

    This delegates directly to ``load_ndcube_from_fits``, forwarding every option by keyword.

    Parameters
    ----------
    input_filename : str
        path of the file to read
    include_provenance : bool
        if True, the provenance layer is loaded as well
    include_uncertainty : bool
        if True, the uncertainty layer is loaded as well
    dtype : type
        dtype the data are cast to

    Returns
    -------
    NDCube
        the loaded image
    """
    options = {
        "include_provenance": include_provenance,
        "include_uncertainty": include_uncertainty,
        "dtype": dtype,
    }
    return load_ndcube_from_fits(input_filename, **options)
[docs]
def average_datetime(datetimes: list[datetime]) -> datetime:
"""Compute average datetime from a list of datetimes."""
timestamps = [dt.replace(tzinfo=UTC).timestamp() for dt in datetimes]
average_timestamp = sum(timestamps) / len(timestamps)
return datetime.fromtimestamp(average_timestamp).astimezone(UTC)
[docs]
@numba.njit(parallel=True, cache=True)
def nan_percentile(array: np.ndarray, percentile: float | list[float]) -> np.ndarray:
    """
    Calculate the nan percentile of a 3D cube. Isn't as fast as possible on a single core, but parallelizes very well.

    It's documented that numba's sort is slower than numpy's, and this runs single-threaded ~half as fast as the old
    implementation using numpy. But this parallelizes extremely well, even up to 128 cores for a 1kx2kx2k cube! Thread
    count can be configured by setting numba.config.NUMBA_NUM_THREADS

    The .copy() for each sequence means that, even though percentiling along the zeroth dimension seems wrong from a CPU
    cache standpoint, transposing the input cube makes very little difference (much less than the time cost of copying
    the cube into a transposed orientation!). Disabling the copy for a well-dimensioned array doesn't make a clear
    difference to execution time.

    The nan handling appears to add only negligible computation time

    Parameters
    ----------
    array : np.ndarray
        3-D cube; percentiles are taken along the first axis for every ``[:, i, j]`` sequence
    percentile : float or list[float]
        percentile(s) in the range 0-100

    Returns
    -------
    np.ndarray
        array of shape ``(len(percentile), *array.shape[1:])``, or ``array.shape[1:]`` when ``percentile`` is a
        scalar int/float. Values are linearly interpolated between the sorted non-NaN samples; pixels with no
        non-NaN samples are NaN.
    """
    percentiles = np.atleast_1d(np.array(percentile))
    percentiles = percentiles / 100
    output = np.empty((len(percentiles), *array.shape[1:]))
    for i in numba.prange(array.shape[1]):
        for j in range(array.shape[2]):
            sequence = array[:, i, j].copy()
            n_valid_obs = len(sequence)
            # Overwrite NaNs with the sequence max so they sort to the top; only the first n_valid_obs sorted
            # entries are ever indexed below.
            sequence_max = np.nanmax(sequence)
            for index in range(len(sequence)):
                if np.isnan(sequence[index]):
                    sequence[index] = sequence_max
                    n_valid_obs -= 1
            if n_valid_obs == 0:
                for k in range(len(percentiles)):
                    output[k, i, j] = np.nan
                # Skip the sort/interpolation: with no valid samples the index below would be negative.
                continue
            sequence.sort()
            for k in range(len(percentiles)):
                index = (n_valid_obs - 1) * percentiles[k]
                f = int(np.floor(index))
                c = int(np.ceil(index))
                if f == c:
                    output[k, i, j] = sequence[f]
                else:
                    f_val = sequence[f]
                    c_val = sequence[c]
                    output[k, i, j] = f_val + (c_val - f_val) * (index - f)
    if isinstance(percentile, (int, float)):
        return output[0]
    return output
[docs]
@numba.njit(parallel=True, cache=True)
def parallel_sort_first_axis(array: np.ndarray, handle_nans: bool = False, inplace: bool = False) -> np.ndarray:
    """
    Sorts a 3D cube along the first axis.

    Parallelizes very well on punch190 and phoenix.

    It's documented that numba's sort is slower than numpy's, but this parallelizes extremely well, even up to 64 cores
    for a 1kx2kx2k cube! Thread count can be configured by setting numba.config.NUMBA_NUM_THREADS

    The .copy() for each sequence means that, even though sorting along the zeroth dimension seems wrong from a CPU
    cache standpoint, transposing the input cube makes very little difference (much less than the time cost of copying
    the cube into a transposed orientation!).

    If handle_nans is True, NaNs are explicitly sorted to the high end of the array. Numba's sort appears to do this
    anyway and still sorts the rest of the array correctly, but the flag ensures this behavior with a speed penalty.
    NaNs are temporarily replaced by +inf and the trailing slots are restored to NaN after sorting. Real values,
    including genuine +inf, are therefore never mistaken for NaN.

    Sorting in-place offers a ~50% speed boost in a 1kx2kx2k test case.

    Parameters
    ----------
    array : np.ndarray
        3-D cube to sort along axis 0
    handle_nans : bool
        if True, force NaNs to the end of each sorted sequence
    inplace : bool
        if True, write the sorted values back into ``array`` and return it

    Returns
    -------
    np.ndarray
        the sorted cube (``array`` itself when ``inplace`` is True)
    """
    output = array if inplace else np.empty_like(array)
    for i in numba.prange(array.shape[1]):
        for j in range(array.shape[2]):
            sequence = array[:, i, j].copy()
            n_nans = 0
            if handle_nans:
                # A "max + 1" sentinel is unsafe: it equals the max itself when the max is +inf or too large for +1 to
                # change it (>= 2**24 in float32), turning real data into NaN. +inf ties only with genuine +inf values,
                # and because exactly n_nans trailing slots are restored below, the count of genuine +inf is kept.
                for index in range(len(sequence)):
                    if np.isnan(sequence[index]):
                        sequence[index] = np.inf
                        n_nans += 1
            sequence.sort()
            if handle_nans:
                for index in range(len(sequence) - n_nans, len(sequence)):
                    sequence[index] = np.nan
            output[:, i, j] = sequence
    return output
[docs]
@numba.njit(parallel=True, cache=True)
def nan_percentile_2d(array: np.ndarray, percentile: float | list[float], # noqa: C901
                      window_size: int, preserve_nans: bool = True) -> np.ndarray:
    """
    Percentile-filter a 2D cube with NaN awareness. Parallelizes well.

    Each pixel is replaced with a percentile of the non-NaN pixels in a local window. At the image edges, the local
    window is clamped at the image boundary.

    See nan_percentile for performance notes

    When preserve_nans is True, NaN pixels will remain NaN. Otherwise they will be replaced with the percentile value.

    Parameters
    ----------
    array : np.ndarray
        2-D image (indexed as ``array[i, j]``)
    percentile : float or list[float]
        percentile(s) in the range 0-100 (divided by 100 internally)
    window_size : int
        nominal window width. The window spans ``window_size // 2`` pixels on each side of the center, so its
        full width is ``2 * (window_size // 2) + 1``. An even ``window_size`` therefore gives a window one pixel
        wider than requested.
    preserve_nans : bool
        if True, NaN input pixels are NaN in the output

    Returns
    -------
    np.ndarray
        array of shape ``(len(percentile), *array.shape)``, or ``array.shape`` when ``percentile`` is a scalar
        int/float. Pixels whose window holds no non-NaN values are NaN.
    """
    percentiles = np.atleast_1d(np.array(percentile))
    percentiles = percentiles / 100
    half_window_size = window_size // 2
    output = np.empty((len(percentiles), *array.shape))
    for i in numba.prange(array.shape[0]):
        for j in range(array.shape[1]):
            if preserve_nans and np.isnan(array[i, j]):
                for k in range(len(percentiles)):
                    output[k, i, j] = np.nan
                continue
            # Window bounds, clamped to the image; the upper bounds are exclusive.
            imin = max(0, i - half_window_size)
            jmin = max(0, j - half_window_size)
            imax = min(array.shape[0], i + half_window_size + 1)
            jmax = min(array.shape[1], j + half_window_size + 1)
            # flatten() returns a copy, so the in-place edits below do not touch the input.
            sequence = array[imin:imax, jmin:jmax].flatten()
            n_valid_obs = len(sequence)
            # NaNs are overwritten with the window max so they sort to the top; only the first n_valid_obs
            # sorted entries are indexed below.
            sequence_max = np.nanmax(sequence)
            for index in range(len(sequence)):
                if np.isnan(sequence[index]):
                    sequence[index] = sequence_max
                    n_valid_obs -= 1
            if n_valid_obs == 0:
                for k in range(len(percentiles)):
                    output[k, i, j] = np.nan
                continue
            sequence.sort()
            # Linear interpolation between the two order statistics bracketing the fractional rank.
            for k in range(len(percentiles)):
                index = (n_valid_obs - 1) * percentiles[k]
                f = int(np.floor(index))
                c = int(np.ceil(index))
                if f == c:
                    output[k, i, j] = sequence[f]
                else:
                    f_val = sequence[f]
                    c_val = sequence[c]
                    output[k, i, j] = f_val + (c_val - f_val) * (index - f)
    if isinstance(percentile, (int, float)):
        return output[0]
    return output
[docs]
def nan_gaussian(image: np.ndarray, sigma: float) -> np.ndarray:
    """
    Gaussian filter, where NaN pixels are ignored in the convolution, and NaN inputs become NaN outputs.

    The image (with NaNs zeroed) and a map of valid pixels are smoothed with the same kernel. Dividing one by the other
    renormalizes each output pixel by the kernel weight that actually fell on valid data. This also handles the image
    edges, since pixels beyond the boundary are treated as zero-valued and invalid.

    Parameters
    ----------
    image : np.ndarray
        input image; it is not modified
    sigma : float
        standard deviation of the Gaussian kernel, passed to ``scipy.ndimage.gaussian_filter``

    Returns
    -------
    np.ndarray
        the filtered image. Floating-point inputs keep their dtype; integer inputs are promoted to float. Pixels that
        were NaN in the input are NaN in the output. So is any pixel whose smoothed valid-pixel weight is exactly zero.
    """
    nans = np.isnan(image)
    # np.where returns a new array, so the caller's image is never modified. A float fill value promotes integer
    # images to float, which the in-place division below requires, while leaving floating dtypes unchanged.
    filled = np.where(nans, 0.0, image)
    smoothed = gaussian_filter(filled, sigma, mode="constant", cval=0)
    valid_weight = gaussian_filter((~nans).astype(float), sigma, mode="constant", cval=0)
    with np.errstate(all="ignore"):
        smoothed /= valid_weight
    smoothed[nans] = np.nan
    return smoothed
[docs]
def interpolate_data(data_before: NDCube, data_after: NDCube, reference_time: datetime, time_key: str = "DATE-OBS",
                     allow_extrapolation: bool = False, and_uncertainty: bool = False,
                     ) -> np.ndarray | tuple[np.ndarray, np.ndarray]:
    """
    Linearly interpolate between two data objects to a reference time.

    Parameters
    ----------
    data_before : NDCube
        cube observed at or before ``reference_time``
    data_after : NDCube
        cube observed at or after ``reference_time``
    reference_time : datetime
        time to interpolate to. A naive datetime triggers a ``MissingTimezoneWarning`` and is interpreted as UTC,
        the same way the cubes' ``time_key`` values are parsed.
    time_key : str
        metadata key holding each cube's observation time
    allow_extrapolation : bool
        if False, raise when ``reference_time`` lies outside the interval spanned by the two cubes
    and_uncertainty : bool
        if True, also interpolate the uncertainty arrays and return them. The uncertainty is only accessed when
        this is True.

    Returns
    -------
    np.ndarray or tuple[np.ndarray, np.ndarray]
        the interpolated data, followed by the interpolated uncertainty if ``and_uncertainty`` is True. When
        ``reference_time`` coincides with one cube's time, that cube's arrays are returned directly (not copied).

    Raises
    ------
    InvalidDataError
        if ``reference_time`` is outside [before, after] and extrapolation is not allowed. Also raised if both cubes
        share an observation time that differs from ``reference_time``, which would mean dividing by zero.
    """
    before_date = parse_datetime(data_before.meta[time_key].value + " UTC").timestamp()
    after_date = parse_datetime(data_after.meta[time_key].value + " UTC").timestamp()
    if reference_time.tzinfo is None:
        warnings.warn("Reference time has no timezone, but should probably be set to UTC", MissingTimezoneWarning,
                      stacklevel=2)
        # datetime.timestamp() would treat a naive value as *local* time; the cube times above are UTC, so interpret
        # the reference time consistently.
        reference_time = reference_time.replace(tzinfo=UTC)
    observation_date = reference_time.timestamp()

    if before_date > observation_date and not allow_extrapolation:
        msg = "Before data was after the observation date"
        raise InvalidDataError(msg)
    if after_date < observation_date and not allow_extrapolation:
        msg = "After data was before the observation date"
        raise InvalidDataError(msg)

    if before_date == observation_date or after_date == observation_date:
        exact = data_before if before_date == observation_date else data_after
        if and_uncertainty:
            return exact.data, exact.uncertainty.array
        return exact.data

    if after_date == before_date:
        msg = "Cannot interpolate between two observations taken at the same time"
        raise InvalidDataError(msg)

    def _lerp(before_array: np.ndarray, after_array: np.ndarray) -> np.ndarray:
        # Same operation order as a plain (after - before) * dt / span + before expression.
        return ((after_array - before_array)
                * (observation_date - before_date) / (after_date - before_date)
                + before_array)

    data_interpolated = _lerp(data_before.data, data_after.data)
    if and_uncertainty:
        return data_interpolated, _lerp(data_before.uncertainty.array, data_after.uncertainty.array)
    return data_interpolated
[docs]
def find_first_existing_file(inputs: list[NDCube | None]) -> NDCube:
    """
    Return the first cube in ``inputs`` that is not None.

    Parameters
    ----------
    inputs : list[NDCube | None]
        candidate cubes, some of which may be None

    Returns
    -------
    NDCube
        the first non-None entry. None is never returned.

    Raises
    ------
    RuntimeError
        if ``inputs`` is empty or every entry is None
    """
    for cube in inputs:
        if cube is not None:
            return cube
    msg = "No cube found. All inputs are None."
    raise RuntimeError(msg)
[docs]
def get_dateobs(file: str | NDCube) -> datetime:
    """
    Return the observation date of a file path or NDCube.

    For an NDCube this is ``meta.datetime``. For a path, the second-to-last underscore-separated field is parsed with
    the format ``%Y%m%d%H%M%S`` and returned as a naive datetime.
    """
    if not isinstance(file, NDCube):
        stamp = file.rsplit("_", 2)[-2]
        return datetime.strptime(stamp, "%Y%m%d%H%M%S")  # noqa: DTZ007
    return file.meta.datetime
[docs]
def get_polstate(file: str | NDCube) -> str:
    """
    Return the polarization-state character of a file path or NDCube.

    This is the second character of the type code. For an NDCube the type code is the ``TYPECODE`` metadata value; for
    a path it is the third-from-last underscore-separated field.
    """
    typecode = file.meta["TYPECODE"].value if isinstance(file, NDCube) else file.rsplit("_", 3)[-3]
    return typecode[1]
[docs]
def _closest_in_time(reference: datetime, candidates: list[datetime]) -> tuple[int, float]:
    """Return the index of the candidate nearest ``reference`` and the gap in seconds; ``(-1, inf)`` if none."""
    if not candidates:
        return -1, np.inf
    deltas = [abs((reference - candidate).total_seconds()) for candidate in candidates]
    index = int(np.argmin(deltas))
    return index, deltas[index]


def bundle_matched_mzp(m_cubes: list[NDCube | str],
                       z_cubes: list[NDCube | str] | None = None,
                       p_cubes: list[NDCube | str] | None = None,
                       threshold: float = 75.0) -> list[tuple[NDCube | str, NDCube | str, NDCube | str]]:
    """
    Search and bundle MZP triplets closest in time.

    Each Z observation is the reference. The M and P observations closest in time to it are selected, and the triplet
    is kept only if both are within ``threshold`` seconds. Z observations lacking a close-enough M or P produce a
    warning and are skipped. A single M or P observation may be used in several triplets.

    Parameters
    ----------
    m_cubes : list[NDCube | str]
        M-polarizer cubes or file paths. If ``z_cubes`` is None, this is instead a mixed list that is split by
        polarization state.
    z_cubes : list[NDCube | str] or None
        Z-polarizer cubes or file paths
    p_cubes : list[NDCube | str] or None
        P-polarizer cubes or file paths; treated as empty if None while ``z_cubes`` is given
    threshold : float
        maximum allowed separation, in seconds, between a Z observation and its M/P partners

    Returns
    -------
    list[tuple]
        ``(m, z, p)`` triplets, in the order of the Z inputs

    Raises
    ------
    ValueError
        if a mixed input list contains an entry whose polarization state is not M, Z, or P
    """
    if z_cubes is None:
        cubes = m_cubes
        m_cubes, z_cubes, p_cubes = [], [], []
        for cube in cubes:
            pol = get_polstate(cube)
            if pol == "M":
                m_cubes.append(cube)
            elif pol == "Z":
                z_cubes.append(cube)
            elif pol == "P":
                p_cubes.append(cube)
            else:
                raise ValueError("Unrecognized pol state")
    elif p_cubes is None:
        p_cubes = []

    m_dateobs = [get_dateobs(cube) for cube in m_cubes]
    z_dateobs = [get_dateobs(cube) for cube in z_cubes]
    p_dateobs = [get_dateobs(cube) for cube in p_cubes]

    # use Z as the reference
    triplets = []
    for z_cube, z_datetime in zip(z_cubes, z_dateobs):
        matching_m, m_time_diff = _closest_in_time(z_datetime, m_dateobs)
        matching_p, p_time_diff = _closest_in_time(z_datetime, p_dateobs)
        # An empty candidate list (index -1) always counts as missing, even if threshold is infinite.
        missing = [label for label, index, diff in (("M", matching_m, m_time_diff), ("P", matching_p, p_time_diff))
                   if index < 0 or diff > threshold]
        if missing:
            msg = f"No matching {' and '.join(missing)} for Z at {z_datetime.isoformat()}"
            warnings.warn(msg, stacklevel=2)
            continue
        triplets.append((m_cubes[matching_m], z_cube, p_cubes[matching_p]))
    return triplets
[docs]
@numba.njit(cache=True, parallel=True)
def masked_mean(array: ArrayLike,
                mask: ArrayLike)-> np.ndarray:
    """
    Masked nanmean along the first axis of entries where both mask is True and data is finite.

    Parameters
    ----------
    array : ArrayLike
        3-D data indexed as ``array[k, i, j]``; the mean is taken over ``k``
    mask : ArrayLike
        truthy where an entry may be used, indexed the same way as ``array``
        (presumably the same shape as ``array`` -- TODO confirm against callers)

    Returns
    -------
    np.ndarray
        float64 array of shape ``array.shape[1:]``. A pixel is NaN if none of its entries is both finite and
        unmasked.
    """
    output = np.empty(array.shape[1:])
    # The outer loop over the first spatial axis runs in parallel (numba.prange).
    for i in numba.prange(array.shape[1]):
        for j in range(array.shape[2]):
            # Work on a copy so the zeroing below does not modify the input.
            sequence = array[:, i, j].copy()
            n_good = 0
            for k in range(len(sequence)):
                # Excluded entries are zeroed so they contribute nothing to the sum; only included ones are counted.
                if not np.isfinite(sequence[k]) or not mask[k, i, j]:
                    sequence[k] = 0
                else:
                    n_good += 1
            if n_good == 0:
                output[i, j] = np.nan
            else:
                output[i, j] = np.sum(sequence) / n_good
    return output
T = TypeVar("T")


class DataLoader(abc.ABC, Generic[T]):
    """
    Abstract interface for objects passed in place of file paths, loaded only when needed.

    Concrete subclasses must implement both ``load`` and ``src_repr``; the class cannot be instantiated otherwise.
    """

    @abc.abstractmethod
    def load(self) -> T:
        """Load the data and return it."""

    @abc.abstractmethod
    def src_repr(self) -> str:
        """Describe, as a string, where the data comes from."""
[docs]
def inpaint_nans(image: np.ndarray, kernel_size: int = 5) -> np.ndarray:
    """
    Fill NaN pixels with the mean of their non-NaN neighbors, repeating until no further progress is possible.

    Each pass replaces every NaN pixel that has at least one valid pixel in its ``kernel_size`` x ``kernel_size``
    neighborhood (excluding itself) with the mean of those valid pixels. Edges are handled with symmetric boundary
    conditions. Passes repeat while NaNs remain. The loop stops early when a pass leaves the set of NaN pixels
    unchanged (e.g. an entirely-NaN image), in which case the remaining NaNs are returned as-is.

    Parameters
    ----------
    image : np.ndarray
        2-D image containing NaNs; it is not modified
    kernel_size : int
        odd integer size for the smoothing kernel

    Returns
    -------
    np.ndarray
        image with NaNs filled where possible

    Raises
    ------
    RuntimeError
        if ``kernel_size`` is even
    """
    if kernel_size % 2 == 0:
        msg = "Kernel size must be odd."
        raise RuntimeError(msg)
    filled = image.copy()

    center = kernel_size // 2
    kernel = np.ones((kernel_size, kernel_size))
    kernel[center, center] = 0  # exclude the pixel itself from its own neighborhood

    previous_missing = np.zeros(filled.shape, dtype=bool)
    while True:
        missing = np.isnan(filled)
        if not missing.any():
            break
        if np.array_equal(missing, previous_missing):
            # No pixel was filled on the last pass (e.g. corrupted packets made every row NaN), so stop.
            break
        previous_missing = missing

        filled[missing] = 0
        counts = convolve2d(~missing, kernel, mode="same", boundary="symm")
        sums = convolve2d(filled, kernel, mode="same", boundary="symm")
        sums[counts > 0] = sums[counts > 0] / counts[counts > 0]
        sums[counts == 0] = np.nan
        sums[~missing] = filled[~missing]  # valid pixels keep their original values
        filled = sums
    return filled
[docs]
def compute_tb(data: NDCube | np.ndarray) -> np.ndarray:
    """
    Compute total brightness from input NDCube or 3D data array of shape (MZP, ...).

    For a plain array, or an NDCube in any observing mode other than ``Polar_BpB``, the result is 2/3 of the sum over
    the first (MZP) axis. For a ``Polar_BpB`` NDCube, the first plane (index 0) is returned as-is.
    """
    is_plain_array = isinstance(data, np.ndarray)
    if not is_plain_array and data.meta["OBS-MODE"].value == "Polar_BpB":
        return data.data[0, ...]
    stack = data if is_plain_array else data.data
    return 2/3 * np.sum(stack, axis=0)
[docs]
def censor_wcs(wcs: WCS, obstime: bool = True, observer: bool = True) -> WCS:
    """
    Return a copy of ``wcs`` with observer location and/or observation times removed.

    When input images have slightly different viewpoints, Sunpy will say this is an invalid coordinate transformation.
    Blanking these fields pacifies Sunpy. The input WCS is not modified.

    Parameters
    ----------
    wcs : WCS
        world coordinate system to censor
    obstime : bool
        if True, blank ``dateobs``, ``dateavg``, ``datebeg`` and ``dateend``
    observer : bool
        if True, clear the auxiliary ``hgln_obs``, ``hglt_obs`` and ``dsun_obs`` values

    Returns
    -------
    WCS
        the censored copy
    """
    censored = wcs.deepcopy()
    if observer:
        for aux_key in ("hgln_obs", "hglt_obs", "dsun_obs"):
            setattr(censored.wcs.aux, aux_key, None)
    if obstime:
        for date_key in ("dateobs", "dateavg", "datebeg", "dateend"):
            setattr(censored.wcs, date_key, "")
    return censored
[docs]
class ShmPickleableNDArray(np.ndarray):
"""
A numpy array backed by shared memory that pickles without copying data.
Pickling happens by only transmitting the shared memory name (and array shape, etc.) and re-connecting to the
shared memory on the receiving side, without ever pickling or copying the array contents. This is extremely
useful when multi-processing with large data arrays, as data can be sent back and forth between workers with zero
copying, and in a very seamless way.
Python spawns a tracker process that ensures the shared memory is freed after the main process terminates. Memory
is also freed when an array is deleted (when Python determines the array's reference count has dropped to
zero)---this implies that any views have also been deleted, since views keep a reference to their base array.
ShmPickleableNDArray supports indexing and slicing, creating views into the same shared-memory array the same way
that normal NDArrays do. Note that operations that produce a copy of the data, suce as "advanced indexing" (
indexing with an array of booleans or integers) produces a new array not backed by shared memory, which will not
enjoy any advantages when pickling. In such a case, the resulting array will raise a RuntimeError if it is pickled.
"""
def __new__(cls, shape: tuple, dtype: np.dtype = np.float64, buffer_name: str | None = None, strides: Any = None,
offset: int = 0, persist: bool = False, **kwargs: dict) -> Self:
"""Create a new array."""
nbytes = 1
if isinstance(shape, int):
nbytes *= shape
else:
for e in shape:
nbytes *= e
nbytes *= np.dtype(dtype).itemsize
# size is ignored if create is False
shm = SharedMemory(create=buffer_name is None, size=nbytes, track=True, name=buffer_name)
obj = super().__new__(cls, shape=shape, dtype=dtype, buffer=shm.buf, strides=strides,
offset=offset, **kwargs)
obj._shm = shm
obj.persist = persist
return obj
[docs]
def __array_finalize__(self, obj: Any) -> None:
"""Finalize array setup."""
if not hasattr(self, "persist"):
self.persist = True
if not hasattr(self, "_shm"):
self._shm = None
self.is_freed = getattr(obj, "is_freed", False)
[docs]
@classmethod
def from_array(cls, array: np.ndarray) -> "ShmPickleableNDArray":
"""Convert an array into a ShmPickleableNDArray."""
obj = ShmPickleableNDArray.empty_like(array)
obj[:] = array
return obj
[docs]
@classmethod
def empty_like(cls, array: np.ndarray) -> "ShmPickleableNDArray":
"""Create an empty array like the given array."""
return ShmPickleableNDArray(array.shape, array.dtype)
@cached_property
def orig_array(self) -> "ShmPickleableNDArray":
"""Get the whole underlying array."""
base = self
while isinstance(base.base, ShmPickleableNDArray):
base = base.base
return base
@property
def shm(self) -> SharedMemory:
"""Access the base shared memory."""
return self.orig_array._shm # noqa: SLF001
@property
def numpy(self) -> np.ndarray:
"""Convert to a plain numpy array."""
return np.array(self)
[docs]
def free(self) -> None:
"""
Free shared memory immediately.
Each shared memory object has a tracker process that ensures it is freed when the process that created it
terminates. This function is only needed to free the memory early, before the process ends. Note that
accessing the array after freeing the backing memory may result in a segfault.
"""
if (shm := self._shm) is None:
return
self._shm = None
self.is_freed = True
shm.unlink()
thread = threading.Thread(target=shm.close, daemon=True)
thread.start()
[docs]
def __del__(self) -> None:
"""Delete the array."""
if hasattr(super(), "__del__"):
super().__del__()
if not self.persist and not self.is_freed:
self.free()
[docs]
def __getitem__(self, *args: tuple, **kwargs: dict) -> "ShmPickleableNDArray":
"""Index the array."""
if self.orig_array.is_freed:
# Guard against segfaults after freeing the shared memory
raise RuntimeError("Attempt to access array that's already been freed")
return super().__getitem__(*args, **kwargs)
[docs]
def __setitem__(self, *args: tuple, **kwargs: dict) -> None:
"""Index the array."""
if self.orig_array.is_freed:
# Guard against segfaults after freeing the shared memory
raise RuntimeError("Attempt to access array that's already been freed")
return super().__setitem__(*args, **kwargs)
[docs]
def __repr__(self, *args: tuple, **kwargs: dict) -> str:
"""Repr the array."""
if self.orig_array.is_freed:
# Guard against segfaults after freeing the shared memory
return "<freed ShmPickleableNDArray>"
return super().__repr__(*args, **kwargs)
@property
def data(self) -> memoryview:
"""Access array data directly."""
if self.orig_array.is_freed:
# Guard against segfaults after freeing the shared memory
raise RuntimeError("Attempt to access array that's already been freed")
return super().data
[docs]
def __reduce__(self) -> tuple:
"""Pickle the object."""
base = self.orig_array
if base._shm is None: # noqa: SLF001
raise RuntimeError(
"Pickling an ShmPickleableNDArray view of a non-ShmPickleableNDArray array. "
"Use ShmPickleableNDArray.numpy?")
base_bounds = np.lib.array_utils.byte_bounds(base)
our_bounds = np.lib.array_utils.byte_bounds(self)
offset = our_bounds[0] - base_bounds[0]
return ShmPickleableNDArray.__new__, (ShmPickleableNDArray, self.shape, self.dtype, self.shm.name,
self.strides, offset, True)