import os
import multiprocessing
import multiprocessing as mp
from datetime import UTC, datetime
from concurrent.futures import ProcessPoolExecutor
import astropy
import numba
import numpy as np
import scipy.optimize
from astropy.nddata import StdDevUncertainty
from dateutil.parser import parse as parse_datetime_str
from ndcube import NDCube
from numpy.polynomial import polynomial
from prefect import get_run_logger
from quadprog import solve_qp
from scipy.interpolate import griddata
from threadpoolctl import threadpool_limits
from punchbowl.data import NormalizedMetadata
from punchbowl.data.punch_io import load_ndcube_from_fits
from punchbowl.data.wcs import load_trefoil_wcs
from punchbowl.exceptions import InvalidDataError
from punchbowl.prefect import punch_flow, punch_task
from punchbowl.util import ShmPickleableNDArray, average_datetime, interpolate_data, nan_percentile
[docs]
def solve_qp_cube(input_vals: np.ndarray, cube: np.ndarray,
                  n_nonnan_required: int = 7) -> tuple[np.ndarray, np.ndarray]:
    """
    Fast solver for the quadratic programming problem.

    For every pixel ``(i, j)`` the coefficients ``x`` of the linear model
    ``input_vals @ x`` are fitted to that pixel's finite time-series values ``y`` with
    ``quadprog.solve_qp(G, a, C, b)``, using ``C = input_vals[good].T``, ``G = C C^T``,
    ``a = C y`` and ``b = y``. Under quadprog's convention (minimize
    ``1/2 x^T G x - a^T x`` subject to ``C^T x >= b``) this is a least-squares fit
    constrained so that the model is on or above every data point.

    Parameters
    ----------
    input_vals : np.ndarray
        Design matrix of shape ``(n_frames, n_coefficients)``; row ``k`` holds the basis
        functions evaluated at frame ``k`` (e.g. powers of the frame time).
    cube : np.ndarray
        Data of shape ``(n_frames, ny, nx)``. Non-finite values are ignored.
    n_nonnan_required : int
        The number of non-nan values that must be present in each pixel's time series.
        Any pixels with fewer will not be fit, with zeros returned instead.

    Returns
    -------
    np.ndarray
        Coefficients of shape ``(n_coefficients, ny, nx)``, ordered like the columns of
        ``input_vals``. Skipped pixels, and pixels for which the solver raised
        ``ValueError``, are all zeros.
    np.ndarray
        Array of shape ``(ny, nx)`` counting the finite values in each pixel's time series.
    """
    # quadprog operates on float64 buffers; other dtypes (e.g. float32 data) would make it
    # raise ValueError, which the handler below would silently turn into zero coefficients.
    design_t = np.asarray(input_vals, dtype=np.float64).T
    is_finite = np.isfinite(cube)
    num_inputs = np.sum(is_finite, axis=0)
    solution = np.zeros((input_vals.shape[1], cube.shape[1], cube.shape[2]))
    for i in range(cube.shape[1]):
        for j in range(cube.shape[2]):
            is_good = is_finite[:, i, j]
            time_series = cube[:, i, j][is_good].astype(np.float64, copy=False)
            if time_series.size < n_nonnan_required:
                # Too few samples for a trustworthy fit; leave this pixel's zeros in place.
                continue
            c_iter = design_t[:, is_good]
            g_iter = c_iter @ c_iter.T
            a = c_iter @ time_series
            try:
                solution[:, i, j] = solve_qp(g_iter, a, c_iter, time_series)[0]
            except ValueError:
                # quadprog raises ValueError e.g. when G is not positive definite or the
                # constraints are inconsistent; such pixels are reported as zeros.
                continue
    return solution, num_inputs
[docs]
def model_fcorona_for_cube_real(xt: np.ndarray,
                                reference_xt: float,
                                cube: np.ndarray,
                                min_brightness: float = 1E-18,
                                clip_factor: float | None = 1,
                                return_full_curves: bool = False,
                                num_workers: int | None = 8,
                                detrend: bool = True,
                                ) -> tuple[np.ndarray, np.ndarray] | tuple[np.ndarray, np.ndarray, np.ndarray]:
    """
    Model the F corona given a list of times and a corresponding data cube.

    The cube is cut into 32x32 tiles over its first two axes. Each tile is fitted by
    ``_model_fcorona_for_cube_inner`` in a process pool, and the per-tile results are
    stitched back together.

    Parameters
    ----------
    xt : np.ndarray
        time array, one entry per frame (the last axis of ``cube``)
    reference_xt: float
        timestamp to evaluate the model for
    cube : np.ndarray
        observation array
    min_brightness: float
        pixels dimmer than this value are set to nan and considered empty
    clip_factor : float | None
        If None, no smoothing is applied.
        Otherwise, the difference between the 25th and 75th percentile is computed and values that vary from the median
        by more than `clip_factor` times the difference data are rejected.
    return_full_curves: bool
        If True, return the full curve fitted to each pixel's time series together with the
        smoothed data cube. If False (default), return only the curve's value at
        ``reference_xt``, i.e. a model at one instant in time.
    num_workers: int | None
        Number of worker processes. If None, this matches the number of cores.
    detrend : bool
        Whether to detrend each time series before outlier rejection

    Returns
    -------
    np.ndarray
        The F-corona model at ``reference_xt``, or, if return_full_curves is True, the
        F-corona model at every time in ``xt``.
    np.ndarray
        The number of data points used in solving the F-corona model for each output pixel.
    np.ndarray
        The smoothed data cube. Returned only if return_full_curves is True.
    """
    # TODO : re-enable F corona modeling
    tile = 32
    # Top-left corner of every tile, row-major; shared by task generation and stitching so
    # the results land back where they came from.
    offsets = [(row, col)
               for row in range(0, cube.shape[0], tile)
               for col in range(0, cube.shape[1], tile)]

    def task_args():
        # One argument tuple for _model_fcorona_for_cube_inner per tile.
        for row, col in offsets:
            yield (xt, reference_xt, cube[row:row + tile, col:col + tile, :], min_brightness, clip_factor,
                   return_full_curves, detrend)

    def stitch(pieces: tuple) -> np.ndarray:
        first = pieces[0]
        full = np.empty((cube.shape[0], cube.shape[1], *first.shape[2:]), dtype=first.dtype)
        for (row, col), piece in zip(offsets, pieces, strict=False):
            full[row:row + tile, col:col + tile] = piece
        return full

    # The work is spread over processes, so keep each one's native thread pools small.
    with threadpool_limits(2), mp.Pool(processes=num_workers) as pool:
        chunks = pool.starmap(_model_fcorona_for_cube_inner, task_args(), chunksize=4)

    # Regroup per-tile tuples into per-output sequences, then stitch each one.
    columns = list(zip(*chunks, strict=False))
    if return_full_curves:
        curves, counts, cubes = columns
        return stitch(curves), stitch(counts), stitch(cubes)
    model, counts = columns
    return stitch(model), stitch(counts)
[docs]
def model_fcorona_for_cube(cube: np.ndarray, percentile: float = 3) -> np.ndarray:
    """
    Model the F corona as a low percentile of each pixel's series of frames.

    This is the simplified model used by ``construct_f_corona_model`` in place of the
    polynomial fit in ``model_fcorona_for_cube_real``.

    Parameters
    ----------
    cube : np.ndarray
        Observation array whose first axis runs over frames, e.g. ``(n_frames, ny, nx)``.
    percentile : float
        Percentile (0-100) taken over the frames for each pixel. Defaults to 3, the value
        previously hard-coded here.

    Returns
    -------
    np.ndarray
        The per-pixel percentile, computed with ``nan_percentile`` (presumably ignoring NaN
        samples — see that helper), with the frame axis reduced away.
    """
    return nan_percentile(cube, percentile)
[docs]
def _model_fcorona_for_cube_inner(xt: np.ndarray,
                                  reference_xt: float,
                                  cube: np.ndarray,
                                  min_brightness: float = 1E-18,
                                  clip_factor: float | None = 1,
                                  return_full_curves: bool=False,
                                  detrend: bool = True,
                                  ) -> tuple[np.ndarray, np.ndarray] | tuple[np.ndarray, np.ndarray, np.ndarray]:
    """
    Fit the F-corona model to one spatial tile of a data cube.

    This is the per-task worker started by ``model_fcorona_for_cube_real``; the arguments
    have the same meaning as there. ``cube`` is indexed ``(ny, nx, n_frames)`` on input
    (its last axis is matched against ``xt``).

    Steps
    -----
    1. Move the frame axis to the front and set values below ``min_brightness`` to NaN.
       NOTE(review): the transpose is a view, so this writes into the caller's array
       unless it is a copy (as it is after pickling into a worker process).
    2. If ``detrend``, fit one quadratic trend in time to all finite samples of the tile
       (robust Cauchy loss) and subtract it. This is skipped when the tile has fewer than
       20 finite samples.
    3. If ``clip_factor`` is not None, set to NaN (in ``cube``) every sample whose detrended
       value lies outside ``median +/- clip_factor * (p75 - p25)`` of its pixel.
    4. Fit a cubic polynomial in time per pixel with ``solve_qp_cube`` applied to ``-cube``.
       Under quadprog's ``C^T x >= b`` convention this keeps the curve on or below the
       remaining data points.

    Returns
    -------
    If ``return_full_curves`` is False: the fitted curve evaluated at ``reference_xt``
    (shape ``(ny, nx)``) and the per-pixel count of finite samples used in the fit.
    Otherwise: the curve evaluated at every ``xt`` (shape ``(ny, nx, n_frames)``), the
    counts, and the masked cube transposed back to ``(ny, nx, n_frames)``.
    """
    # Frames first: (n_frames, ny, nx).
    cube = cube.transpose((2, 0, 1))
    cube[cube < min_brightness] = np.nan
    # Work with times relative to the first frame; np.array copies, so the caller's xt is untouched.
    xt = np.array(xt)
    reference_xt -= xt[0]
    xt -= xt[0]

    def trend_fcn(x: np.ndarray, xvals: np.ndarray) -> np.ndarray:
        # Quadratic trend c0 + c1*t + c2*t**2.
        c0, c1, c2 = x
        return c0 + c1 * xvals + c2 * xvals ** 2

    def trend_resid(x: np.ndarray, xvals: np.ndarray, yvals: np.ndarray) -> np.ndarray:
        # Residuals of the trend against the flattened finite samples.
        return trend_fcn(x, xvals.ravel()) - yvals.ravel()

    good_px = np.isfinite(cube)
    if detrend:
        if np.sum(good_px) < 20:
            # Too few finite samples in the tile for a stable trend fit.
            detrended_cube = cube
        else:
            x = np.broadcast_to(xt[:, None, None], cube.shape)
            # The trend is linear in (c0, c1, c2), so its Jacobian is constant: columns 1, t, t**2.
            jacobian = np.stack((
                0*x[good_px].ravel() + 1,
                x[good_px],
                x[good_px] ** 2,
            ), axis=1)
            # Cauchy loss down-weights residuals much larger than f_scale.
            # NOTE(review): f_scale=.5e-13 is presumably tuned to the data's brightness units — confirm.
            res = scipy.optimize.least_squares(trend_resid, (np.median(cube[good_px]), 0, 0), loss="cauchy",
                                               f_scale=.5e-13, kwargs={"xvals": x[good_px], "yvals": cube[good_px]},
                                               jac=lambda *a, **kw: jacobian) #noqa: ARG005
            trend = trend_fcn(res.x, xt)
            detrended_cube = cube - trend[:, None, None]
    else:
        detrended_cube = cube
    if clip_factor is not None and np.any(good_px):
        # Per-pixel quartiles and median over the frame axis.
        low, center, high = nan_percentile(detrended_cube, [25, 50, 75])
        width = high - low
        # Reject high outliers. When detrended_cube is cube itself, the NaNs written here
        # compare False below, so the second pass is unaffected.
        a, b, c = np.where(detrended_cube[:, ...] > (center + (clip_factor * width)))
        cube[a, b, c] = np.nan
        # Reject low outliers.
        a, b, c = np.where(detrended_cube[:, ...] < (center - (clip_factor * width)))
        cube[a, b, c] = np.nan
    # Design matrix with columns t**3, t**2, t, 1 (highest power first).
    input_array = np.c_[np.power(xt, 3), np.square(xt), xt, np.ones(len(xt))]
    # Fitting -cube and negating the result turns the solver's lower-bound constraint into an upper bound.
    coefficients, counts = solve_qp_cube(input_array, -cube)
    coefficients *= -1
    # polyval expects the lowest-order coefficient first, hence the reversal.
    if return_full_curves:
        return polynomial.polyval(xt, coefficients[::-1, :, :]), counts, cube.transpose((1, 2, 0))
    return polynomial.polyval(reference_xt, coefficients[::-1, :, :]), counts
[docs]
def fill_nans_with_interpolation(image: np.ndarray) -> np.ndarray:
    """
    Fill NaN values in an image using interpolation.

    Missing pixels are estimated with ``scipy.interpolate.griddata`` (``method="cubic"``)
    from all non-NaN pixels. Known pixels are copied through unchanged.

    Parameters
    ----------
    image : np.ndarray
        2D image in which NaN marks missing pixels. It is not modified.

    Returns
    -------
    np.ndarray
        A new float array with the same shape as ``image``. NaN pixels outside the convex
        hull of the known pixels cannot be interpolated and stay NaN. If the image has no
        NaNs, or no known pixels at all, a float copy is returned as-is.
    """
    filled = np.array(image, dtype=float)
    mask = np.isnan(filled)
    if not mask.any() or mask.all():
        # Nothing to fill, or nothing to interpolate from (griddata fails on an empty point set).
        return filled
    known_rows, known_cols = np.nonzero(~mask)
    missing_rows, missing_cols = np.nonzero(mask)
    # Evaluate the interpolant only where data is missing: cheaper, and known pixels stay exact.
    filled[mask] = griddata((known_rows, known_cols), filled[~mask],
                            (missing_rows, missing_cols), method="cubic")
    return filled
[docs]
def _load_file(path: str, data_destination: ShmPickleableNDArray) -> tuple[np.ndarray, datetime, str] | str:
    """
    Load one FITS file into its slot of the shared data cube.

    Parameters
    ----------
    path : str
        Path of the FITS file to load.
    data_destination : ShmPickleableNDArray
        Destination for this file's data. It is first filled with NaN. The file's data is
        then written into its crop window (``CROPY1:CROPY2``, ``CROPX1:CROPX2`` over the
        last two axes), with pixels whose uncertainty is non-finite stored as NaN.

    Returns
    -------
    tuple[np.ndarray, datetime, str] | str
        On success, a tuple of:
        - the squared uncertainty, zero where the uncertainty was non-finite or outside the
          crop window, laid out like ``data_destination`` and then squeezed;
        - the observation datetime;
        - the ``OBSCODE`` value.
        If loading the file raises, the exception message is returned instead, so the
        caller can log the failure and continue.
    """
    data_destination[:] = np.nan
    try:
        cube = load_ndcube_from_fits(path, include_provenance=False, dtype=np.float32)
    except Exception as e:  # noqa: BLE001
        return str(e)
    cropx = cube.meta["CROPX1"].value, cube.meta["CROPX2"].value
    cropy = cube.meta["CROPY1"].value, cube.meta["CROPY2"].value
    # Data with a non-finite uncertainty is treated as missing.
    data_destination[:, cropy[0]:cropy[1], cropx[0]:cropx[1]] = (
        np.where(np.isfinite(cube.uncertainty.array), cube.data, np.nan)
    )
    # Non-finite uncertainties become 0, which the caller reads as "no sample here".
    np.nan_to_num(cube.uncertainty.array, nan=0, posinf=0, neginf=0, copy=False)
    # Square the array in-place (variance, summed across files by the caller).
    cube.uncertainty.array *= cube.uncertainty.array
    uncert = np.zeros(data_destination.shape, dtype=np.float32)
    uncert[..., cropy[0]:cropy[1], cropx[0]:cropx[1]] = cube.uncertainty.array
    return uncert.squeeze(), cube.meta.datetime, cube.meta["OBSCODE"].value
[docs]
@punch_flow(log_prints=True)
def construct_f_corona_model(filenames: list[str],  # noqa: C901
                             reference_time: str | None = None,
                             num_workers: int = 8,
                             num_loaders: int | None = None,
                             fill_nans: bool = False,
                             polarized: bool = False) -> list[NDCube]:
    """
    Construct a full F corona model.

    Parameters
    ----------
    filenames : list[str]
        Input FITS files. They are processed in sorted order; the caller's list is not modified.
    reference_time : str | None
        Time written to ``DATE-OBS`` of the output. Strings are parsed with
        ``dateutil.parser.parse``; ``None`` means the current UTC time; other values
        (e.g. a ``datetime``) are used as-is.
    num_workers : int
        Number of numba threads.
    num_loaders : int | None
        Number of file-loading processes (``None`` lets ``ProcessPoolExecutor`` choose).
    fill_nans : bool
        If True, NaN pixels of each model are filled by interpolation.
    polarized : bool
        If True, three planes are loaded per file and a stacked ``PF`` model is built;
        otherwise a single-plane ``CF`` model is built.

    Returns
    -------
    list[NDCube]
        A one-element list holding the model. Its uncertainty is, per pixel,
        ``sqrt(sum of per-file squared uncertainties) / number of contributing files``.
        It is NaN where no file contributed and has the same shape as the model data.

    Raises
    ------
    ValueError
        If ``filenames`` is empty.
    RuntimeError
        If more than 5% of the files fail to load.
    """
    numba.set_num_threads(num_workers)
    logger = get_run_logger()

    if reference_time is None:
        reference_time = datetime.now(UTC)
    elif isinstance(reference_time, str):
        reference_time = parse_datetime_str(reference_time)

    trefoil_wcs, trefoil_shape = load_trefoil_wcs()

    logger.info("construct_f_corona_background started")

    if len(filenames) == 0:
        msg = "Require at least one input file"
        raise ValueError(msg)

    # Sort a copy so the caller's list is left untouched.
    filenames = sorted(filenames)
    number_of_data_frames = len(filenames)
    n_planes = 3 if polarized else 1
    uncertainty = np.zeros((3, *trefoil_shape) if polarized else trefoil_shape)
    sample_counts = np.zeros((n_planes, *trefoil_shape), dtype=int)
    data_cube = ShmPickleableNDArray((number_of_data_frames, n_planes, *trefoil_shape), dtype=np.float32)

    logger.info("beginning data loading")
    dates = []
    n_failed = 0
    ctx = multiprocessing.get_context("forkserver")
    with ProcessPoolExecutor(num_loaders, mp_context=ctx) as pool:
        for i, result in enumerate(pool.map(_load_file, filenames, data_cube)):
            if isinstance(result, str):
                # _load_file reports failures as the exception message.
                logger.warning(f"Loading {filenames[i]} failed")
                logger.warning(result)
                n_failed += 1
                if n_failed > 0.05 * len(filenames):
                    raise RuntimeError(f"{n_failed} files failed to load, stopping")
                continue
            this_uncertainty, date, obscode = result
            dates.append(date)
            # A zero (squared) uncertainty marks a pixel this file did not contribute.
            sample_counts += this_uncertainty != 0
            uncertainty += this_uncertainty
            if (i + 1) % 50 == 0:
                logger.info(f"Loaded {i+1}/{len(filenames)} files")

    logger.info(f"end of data loading, saw {n_failed} failures")

    models = []
    for plane in range(data_cube.shape[1]):
        model_fcorona = model_fcorona_for_cube(data_cube[:, plane])
        # Pixels no file contributed to have no meaningful model.
        model_fcorona[sample_counts[plane] == 0] = np.nan
        if fill_nans:
            model_fcorona = fill_nans_with_interpolation(model_fcorona)
        models.append(model_fcorona)

    if polarized:
        output_data = np.stack(models, axis=0)
        meta = NormalizedMetadata.load_template("PF" + obscode, "3")
        trefoil_wcs = astropy.wcs.utils.add_stokes_axis_to_wcs(trefoil_wcs, 2)
    else:
        output_data = models[0]
        meta = NormalizedMetadata.load_template("CF" + obscode, "3")

    # Pixels without samples become NaN (as 0/0 did before, but without a divide warning).
    # The reshape drops the singleton plane axis in the unpolarized case so the uncertainty
    # has the same shape as the data.
    uncertainty = np.divide(np.sqrt(uncertainty), sample_counts,
                            out=np.full(sample_counts.shape, np.nan),
                            where=sample_counts > 0).reshape(output_data.shape)

    def _fmt(moment: datetime) -> str:
        # ISO-like timestamp truncated to milliseconds.
        return moment.strftime("%Y-%m-%dT%H:%M:%S.%f")[:-3]

    meta.provenance = sorted([os.path.basename(f) for f in filenames])
    meta["DATE"] = _fmt(datetime.now(UTC))
    meta["DATE-AVG"] = _fmt(average_datetime(dates))
    meta["DATE-OBS"] = _fmt(reference_time)
    meta["DATE-BEG"] = _fmt(min(dates))
    meta["DATE-END"] = _fmt(max(dates))

    output_cube = NDCube(data=output_data,
                         meta=meta,
                         wcs=trefoil_wcs,
                         uncertainty=StdDevUncertainty(uncertainty))

    return [output_cube]
[docs]
def subtract_f_corona_background(data_object: NDCube,
                                 before_f_background_model: NDCube,
                                 after_f_background_model: NDCube,
                                 allow_extrapolation: bool = False) -> NDCube:
    """
    Subtract f corona background.

    The background at the observation's time is interpolated between the two models with
    ``interpolate_data`` and subtracted from ``data_object`` in place. Pixels flagged as
    missing in the observation (data == 0 with infinite uncertainty) are kept at zero.
    The interpolated model uncertainty is added to the observation's uncertainty in
    quadrature.

    Parameters
    ----------
    data_object : NDCube
        Observation to correct; modified in place.
    before_f_background_model : NDCube
        Background model from before the observation.
    after_f_background_model : NDCube
        Background model from after the observation.
    allow_extrapolation : bool
        Passed to ``interpolate_data``; allows evaluating outside the models' time span.

    Returns
    -------
    NDCube
        ``data_object`` itself, after modification.

    Raises
    ------
    InvalidDataError
        If either model's data shape differs from the observation's.
    """
    for label, model in (("before_f_background_model", before_f_background_model),
                         ("after_f_background_model", after_f_background_model)):
        if data_object.data.shape != model.data.shape:
            msg = (
                "f_background_subtraction expects the data_object and "
                "f_background arrays to have the same dimensions. "
                f"data_array dims: {data_object.data.shape} "
                f"and {label} dims: {model.data.shape}"
            )
            raise InvalidDataError(msg)

    interpolated_model, interpolated_uncertainty = interpolate_data(
        before_f_background_model,
        after_f_background_model,
        data_object.meta.datetime,
        allow_extrapolation=allow_extrapolation,
        and_uncertainty=True)

    # Missing pixels of the observation; must be captured before the data is modified.
    original_mask = (data_object.data == 0) & np.isinf(data_object.uncertainty.array)
    data_object.data[...] -= interpolated_model
    # Restore missing pixels to zero (this also discards any NaN/inf the model had there).
    data_object.data[original_mask] = 0
    data_object.uncertainty.array[...] = np.sqrt(data_object.uncertainty.array**2 + interpolated_uncertainty**2)
    return data_object
[docs]
@punch_task
def subtract_f_corona_background_task(observation: NDCube,
                                      before_f_background_models: list[NDCube | str],
                                      after_f_background_models: list[NDCube | str],
                                      allow_extrapolation: bool = False) -> NDCube:
    """
    Subtracts a background f corona model from an observation.

    This algorithm linearly interpolates between the before and after models. From each
    list, the first model whose ``OBSCODE`` matches the observation and whose type suits it
    is used: ``C*`` models for observations with ``TYPECODE[1] == "R"`` and ``P*`` models
    for ``TYPECODE[1] == "P"``.

    Parameters
    ----------
    observation : NDCube
        an observation to subtract an f corona model from
    before_f_background_models : list[NDCube | str]
        NDCube f corona background maps (or paths to them) before the observations
    after_f_background_models : list[NDCube | str]
        NDCube f corona background maps (or paths to them) after the observations
    allow_extrapolation : bool
        If true, allow extrapolation beyond the time range spanned by the two F corona models

    Returns
    -------
    NDCube
        A background subtracted data frame

    Raises
    ------
    RuntimeError
        If no suitable before or after model is found.
    """
    logger = get_run_logger()
    logger.info("subtract_f_corona_background started")

    # Paths are loaded from FITS; NDCube entries are used directly.
    before_f_background_models = [load_ndcube_from_fits(f) if isinstance(f, str) else f
                                  for f in before_f_background_models]
    after_f_background_models = [load_ndcube_from_fits(f) if isinstance(f, str) else f
                                 for f in after_f_background_models]

    before_model = _select_background_model(before_f_background_models, observation, "before")
    after_model = _select_background_model(after_f_background_models, observation, "after")

    output = subtract_f_corona_background(observation, before_model, after_model,
                                          allow_extrapolation=allow_extrapolation)
    output.meta.history.add_now("LEVEL3-subtract_f_corona_background", "subtracted f corona background")

    logger.info("subtract_f_corona_background finished")
    return output


def _select_background_model(models: list[NDCube], observation: NDCube, which: str) -> NDCube:
    """
    Return the first model in ``models`` compatible with ``observation``.

    A model is compatible when its ``OBSCODE`` equals the observation's and the
    (observation ``TYPECODE[1]``, model ``TYPECODE[0]``) pair is ``("R", "C")`` or
    ``("P", "P")``.

    Raises
    ------
    RuntimeError
        If no compatible model exists; ``which`` ("before"/"after") names the list in the message.
    """
    compatible = {("R", "C"), ("P", "P")}
    for model in models:
        if model.meta["OBSCODE"].value != observation.meta["OBSCODE"].value:
            continue
        pair = (observation.meta["TYPECODE"].value[1], model.meta["TYPECODE"].value[0])
        if pair in compatible:
            return model
    raise RuntimeError(f"Could not find {which} model for {observation.meta['FILENAME']}")
[docs]
def create_empty_f_background_model(data_object: NDCube) -> np.ndarray:
    """
    Create an empty background model.

    Returns an all-zero array built with ``np.zeros_like`` from ``data_object.data``, so it
    has that array's shape and dtype.
    """
    template = data_object.data
    return np.zeros_like(template)