Source code for punchbowl.util
import os
import abc
from typing import Generic, TypeVar
from datetime import UTC, datetime
import numpy as np
from ndcube import NDCube
from punchbowl.data import load_ndcube_from_fits, write_ndcube_to_fits
from punchbowl.exceptions import InvalidDataError
from punchbowl.prefect import punch_task
[docs]
def validate_image_is_square(image: np.ndarray) -> None:
    """
    Verify that the input is a two-dimensional NumPy array whose two axes have equal length.

    Parameters
    ----------
    image : np.ndarray
        array to validate

    Returns
    -------
    None

    Raises
    ------
    TypeError
        if `image` is not an `np.ndarray`
    ValueError
        if `image` is not 2-D, or if its two axes differ in length
    """
    if not isinstance(image, np.ndarray):
        msg = f"Image must be of type np.ndarray. Found: {type(image)}."
        raise TypeError(msg)

    if image.ndim != 2:
        msg = f"Image must be a 2-D array. Input has {len(image.shape)} dimensions."
        raise ValueError(msg)

    # the array is known to be 2-D here, so the shape unpacks into exactly two lengths
    n_rows, n_cols = image.shape
    if n_rows != n_cols:
        msg = f"Image must be square. Found: {image.shape}."
        raise ValueError(msg)
[docs]
@punch_task
def output_image_task(data: NDCube, output_filename: str) -> None:
    """
    Prefect task to write an image to disk.

    The parent directory of `output_filename` is created first if it does not exist yet.

    Parameters
    ----------
    data : NDCube
        data that is to be written
    output_filename : str
        where to write the file out

    Returns
    -------
    None
    """
    output_dir = os.path.dirname(output_filename)
    # an empty dirname means the file goes into the current working directory, so nothing needs creating
    if output_dir:
        # exist_ok=True makes the creation safe when several tasks target the same new directory at once;
        # a separate "isdir" check followed by makedirs can fail if another writer creates it in between
        os.makedirs(output_dir, exist_ok=True)
    write_ndcube_to_fits(data, output_filename)
[docs]
@punch_task(tags=["image_loader"])
def load_image_task(input_filename: str, include_provenance: bool = True, include_uncertainty: bool = True) -> NDCube:
    """
    Prefect task that reads a FITS file into an NDCube for processing.

    This simply forwards its arguments to `load_ndcube_from_fits`.

    Parameters
    ----------
    input_filename : str
        path to file to load
    include_provenance : bool
        whether to load the provenance layer
    include_uncertainty : bool
        whether to load the uncertainty layer

    Returns
    -------
    NDCube
        loaded version of the image
    """
    loaded_cube = load_ndcube_from_fits(
        input_filename,
        include_provenance=include_provenance,
        include_uncertainty=include_uncertainty,
    )
    return loaded_cube
[docs]
def average_datetime(datetimes: list[datetime]) -> datetime:
"""Compute average datetime from a list of datetimes."""
timestamps = [dt.replace(tzinfo=UTC).timestamp() for dt in datetimes]
average_timestamp = sum(timestamps) / len(timestamps)
return datetime.fromtimestamp(average_timestamp).astimezone(UTC)
[docs]
def _zvalue_from_index(arr, ind): # noqa: ANN202, ANN001
"""
Do math.
Private helper function to work around the limitation of np.choose() by employing np.take().
arr has to be a 3D array
ind has to be a 2D array containing values for z-indicies to take from arr
See: http://stackoverflow.com/a/32091712/4169585
This is faster and more memory efficient than using the ogrid based solution with fancy indexing.
"""
# get number of columns and rows
_, n_rows, n_cols = arr.shape
# get linear indices and extract elements with np.take()
idx = n_cols*n_rows*ind + n_cols*np.arange(n_rows)[:,None] + np.arange(n_cols)
return np.take(arr, idx)
[docs]
def nan_percentile(arr: np.ndarray, q: list[float] | float) -> np.ndarray:
    """
    Calculate percentiles along the first axis of a 3D cube, ignoring non-finite values.

    This is a faster replacement for np.nanpercentile using linear interpolation between the two nearest
    ranks. Values that are not finite (NaN as well as +/-inf) are treated as missing.

    Parameters
    ----------
    arr : np.ndarray
        3D cube; percentiles are taken along axis 0. The input is not modified.
    q : list[float] | float
        percentile or sequence of percentiles to compute. Values are not range-checked and are expected
        to lie between 0 and 100.

    Returns
    -------
    np.ndarray
        array of shape `(number of percentiles, *arr.shape[1:])`; a scalar `q` still yields a leading
        axis of length one. Positions without any finite value along axis 0 are NaN.
    """
    # np.nanpercentile is slow so use this: https://krstn.eu/np.nanpercentile()-there-has-to-be-a-faster-way/
    # valid (finite) observations along the first axis
    is_good = np.isfinite(arr)
    n_valid_obs = np.sum(is_good, axis=0)

    # replace non-finite entries with the maximum so that sorting moves them to the end
    arr = arr.copy()
    arr[~is_good] = np.nanmax(arr)
    # sort the private copy in place, which avoids allocating a second full-size cube
    arr.sort(axis=0)

    # accept Python scalars, NumPy scalars and sequences alike
    qs = np.atleast_1d(q)

    result = np.empty((len(qs), *arr.shape[1:]))
    for i, quant in enumerate(qs):
        # desired (fractional) position as well as floor and ceiling of it
        k_arr = (n_valid_obs - 1) * (quant / 100)
        # use the platform index type so the flat index computed by the helper cannot overflow int32
        f_arr = np.floor(k_arr).astype(np.intp)
        c_arr = np.ceil(k_arr).astype(np.intp)
        fc_equal_k_mask = f_arr == c_arr

        # linear interpolation (like numpy percentile) takes the fractional part of desired position
        floor_lookup = _zvalue_from_index(arr=arr, ind=f_arr)
        ceil_lookup = _zvalue_from_index(arr=arr, ind=c_arr)
        quant_arr = floor_lookup * (c_arr - k_arr) + ceil_lookup * (k_arr - f_arr)

        # if floor == ceiling both interpolation weights are zero, so take the floor value directly
        # (the position is integral there, hence the floor lookup is the value at the position itself)
        quant_arr[fc_equal_k_mask] = floor_lookup[fc_equal_k_mask]

        result[i] = quant_arr

    # positions with no finite observation have no meaningful percentile
    result[:, n_valid_obs == 0] = np.nan

    return result
[docs]
def interpolate_data(data_before: NDCube, data_after:NDCube, reference_time: datetime) -> np.ndarray:
"""Interpolates between two data objects."""
before_date = data_before.meta.datetime.timestamp()
after_date = data_after.meta.datetime.timestamp()
observation_date = reference_time.timestamp()
if before_date > observation_date:
msg = "Before data was after the observation date"
raise InvalidDataError(msg)
if after_date < observation_date:
msg = "After data was before the observation date"
raise InvalidDataError(msg)
if before_date == observation_date:
data_interpolated = data_before
elif after_date == observation_date:
data_interpolated = data_after
else:
data_interpolated = ((data_after.data - data_before.data)
* (observation_date - before_date) / (after_date - before_date)
+ data_before.data)
return data_interpolated
[docs]
def find_first_existing_file(inputs: list[NDCube]) -> NDCube | None:
    """
    Return the earliest entry of a list of NDCubes that is not None.

    Parameters
    ----------
    inputs : list[NDCube]
        candidate cubes, any of which may be None

    Returns
    -------
    NDCube
        the first entry that is not None

    Raises
    ------
    RuntimeError
        if every entry is None (including when the list is empty)
    """
    # None doubles as the "nothing found" default because None entries are exactly what gets skipped
    first_cube = next((candidate for candidate in inputs if candidate is not None), None)
    if first_cube is None:
        msg = "No cube found. All inputs are None."
        raise RuntimeError(msg)
    return first_cube
T = TypeVar("T")


class DataLoader(abc.ABC, Generic[T]):
    """
    Abstract interface for objects that can be passed around in place of a file path and loaded on demand.

    `T` is the type of the data that `load` produces. Subclasses must implement both `load` and `src_repr`.
    """

    @abc.abstractmethod
    def load(self) -> T:
        """Load and return the data."""

    @abc.abstractmethod
    def src_repr(self) -> str:
        """Return a string describing where the data comes from."""