Source code for punchbowl.level1.stray_light

import time
import pathlib
import warnings
import multiprocessing
from logging import Logger
from datetime import UTC, datetime, timedelta
from functools import cached_property
from itertools import repeat, pairwise
from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor

import numba
import numpy as np
import reproject
import scipy
from astropy.nddata import StdDevUncertainty
from astropy.wcs import WCS
from dateutil.parser import parse as parse_datetime
from lmfit import Parameters, minimize
from lmfit.minimizer import MinimizerResult
from ndcube import NDCube
from prefect import get_run_logger

from punchbowl.data import NormalizedMetadata, load_ndcube_from_fits, load_trefoil_wcs
from punchbowl.exceptions import (
    CantInterpolateWarning,
    IncorrectPolarizationStateError,
    IncorrectTelescopeError,
    InvalidDataError,
)
from punchbowl.level2.polarization import resolve_polarization
from punchbowl.level2.resample import reproject_cube
from punchbowl.prefect import punch_flow, punch_task
from punchbowl.util import (
    DataLoader,
    ShmPickleableNDArray,
    average_datetime,
    bundle_matched_mzp,
    inpaint_nans,
    interpolate_data,
    load_mask_file,
    nan_gaussian,
    nan_percentile,
)


class SkewFitResult:
    """Hold the inputs to, and the outcome of, one skewed-Gaussian histogram fit."""

    def __init__(self, fit: MinimizerResult, bin_centers: np.ndarray, scaled_x_values: np.ndarray,
                 bin_values: np.ndarray, stack: np.ndarray, scale_factor: float, weights: np.ndarray,
                 target_center: float) -> None:
        """Store the fit inputs and unpack the best-fit parameter values from ``fit.params``."""
        self.fit = fit
        self.bin_centers = bin_centers
        self.scaled_x_values = scaled_x_values
        self.bin_values = bin_values
        self.stack = stack
        self.scale_factor = scale_factor
        self.weights = weights
        self.target_center = target_center

        best = {name: fit.params[name].value for name in ("x0", "A", "alpha", "sigma", "m", "b")}
        self.x0 = best["x0"]
        self.A = best["A"]
        self.alpha = best["alpha"]
        self.sigma = best["sigma"]
        self.m = best["m"]
        self.b = best["b"]

        # Smallest spacing between adjacent bin centers, in unscaled data units
        self.dx = np.min(np.diff(bin_centers))

    @cached_property
    def result(self) -> float:
        """Return the mode of the fitted skewed Gaussian, converted back to unscaled data units."""
        # Mode approximation from https://en.wikipedia.org/wiki/Skew_normal_distribution. skew_gaussian leaves the
        # 1/sqrt(2) out of the CDF argument, so the fitted alpha is really (textbook alpha) / sqrt(2); undo that.
        shape = self.alpha * np.sqrt(2)
        delta = shape / np.sqrt(1 + shape ** 2)
        mu_z = np.sqrt(2 / np.pi) * delta
        mode = (
            mu_z
            - (1 - np.pi / 4) * mu_z ** 3 / (1 - 2 / np.pi * delta ** 2)
            - np.sign(shape) / 2 * np.exp(-2 * np.pi / np.abs(shape))
        )
        return (mode * self.sigma + self.x0) / self.scale_factor

    def fit_is_sus(self) -> bool:
        """Return True when the fit looks untrustworthy.

        A fit is flagged if its amplitude is far above the tallest bin, its skew is extreme, its width is
        implausibly narrow or wide relative to the binning, the linear background evaluated at the mode is out of
        range, or any parameter finished exactly on one of its bounds.
        """
        peak_count = self.bin_values.max()
        histogram_span = self.bin_centers[-1] - self.bin_centers[0]
        if 3 * peak_count < self.A:
            return True
        if np.abs(self.alpha) > 10000:
            return True
        if self.sigma < 0.5 * self.dx * self.scale_factor:
            return True
        if self.sigma > 2 * histogram_span * self.scale_factor:
            return True
        if not (-4 * peak_count < self.m * self.result * self.scale_factor + self.b < 2 * peak_count):
            return True
        return any(p.value == p.min or p.value == p.max for p in self.fit.params.values())

    def plot(self, mark_result: bool = True, mark_tcenter: bool = True) -> None:
        """Plot the histogram, each fitted component, and optional markers for the result and target peak."""
        import matplotlib.pyplot as plt  # noqa: PLC0415

        plt.step(self.bin_centers, self.bin_values, where="mid")
        plt.scatter(self.bin_centers, self.bin_values)

        xs = np.linspace(self.bin_centers[0], self.bin_centers[-1], 200)
        scaled_xs = xs * self.scale_factor
        curves = (
            ((self.A, self.alpha, self.x0, self.sigma, 0, 0), "C1", "Skew Gaussian"),
            ((0, 0, 0, 1, self.m, self.b), "C2", "Linear"),
            ((self.A, self.alpha, self.x0, self.sigma, self.m, self.b), "C3", "Model"),
        )
        for curve_args, color, label in curves:
            plt.plot(xs, skew_gaussian(scaled_xs, *curve_args), color=color, label=label)

        if mark_result:
            plt.axvline(self.result, color="C4", label="Our result")
        if mark_tcenter:
            plt.axvline(self.target_center, color="C5", label="Targeted peak", ls=":")
        plt.legend()
# Skew Gaussians are evaluated a *lot* during fitting, so instead of calling exp/erf each time we interpolate into
# lookup tables built once at import. The tables are accurate to about 0.0002 (on values in [0, 1]), and to about
# 0.02% wherever the function values aren't minute; using them saves roughly 30% of the computation time.
pdf_table_vals = np.arange(-4.2, 4.2, 0.02)
cdf_table_vals = np.arange(-3, 3, 0.02)
pdf_vals = np.exp(-0.5 * pdf_table_vals ** 2)
cdf_vals = 1 + scipy.special.erf(cdf_table_vals)


def skew_gaussian(x: np.ndarray, A: float, alpha: float, x0: float, sigma: float, m: float, b: float,  # noqa: N803
                  ) -> np.ndarray:
    """Evaluate an (unnormalized) skewed Gaussian plus a linear background at ``x``.

    The Gaussian factor is ``exp(-z**2 / 2)`` and the skew factor is ``1 + erf(alpha * z)`` with
    ``z = (x - x0) / sigma``; both come from the lookup tables. Outside the tables the Gaussian factor is 0 and
    the skew factor saturates at 0 (left) or 2 (right).
    """
    z = (x - x0) / sigma
    gaussian_part = np.interp(z, pdf_table_vals, pdf_vals, left=0, right=0)
    skew_part = np.interp(alpha * z, cdf_table_vals, cdf_vals, left=0, right=2)
    return A * gaussian_part * skew_part + m * x + b
def _resid_skew(params: Parameters, scaled_x_values: np.ndarray, y_values: np.ndarray, bin_weights: np.ndarray,
                ) -> np.ndarray:
    """lmfit objective: weighted residuals ``(y - skew_gaussian(x)) * weights`` for the current parameters."""
    p = params.valuesdict()
    model = skew_gaussian(scaled_x_values, p["A"], p["alpha"], p["x0"], p["sigma"], p["m"], p["b"])
    return (y_values - model) * bin_weights
def pick_peak(bin_values: np.ndarray) -> int:
    """Return the index of the first (left-most) significant peak in a histogram.

    Bins are scanned left to right while tracking the tallest bin seen so far. That bin is accepted as the peak
    once a later, non-empty bin falls below 85% of it (or the counts have dropped for 3 consecutive bins), provided
    the candidate is at least 15% of the histogram's maximum. A shallow dip followed by further climbing therefore
    does not end the search. If no bin triggers acceptance, the tallest bin seen is returned.
    """
    min_peak_height = 0.15 * bin_values.max()
    best_value = -1
    best_index = -1
    downhill_run = 0
    for idx, value in enumerate(bin_values):
        if idx:
            downhill_run = 0 if value >= bin_values[idx - 1] else downhill_run + 1
        if value > best_value:
            best_value, best_index = value, idx
            continue
        looks_like_peak = value < 0.85 * best_value or downhill_run >= 3
        if looks_like_peak and best_value >= min_peak_height and value > 0:
            return best_index
    return best_index
def find_peak_end(bin_values: np.ndarray, peak_location: int, direction: int) -> int:
    """Walk from ``peak_location`` in steps of ``direction`` and return the index of the peak's edge.

    The lowest bin seen so far is tracked. As soon as a positive bin rises above 1.25x that minimum (appreciable
    up-hill), or the walk leaves the array, the index of that lowest bin is returned.
    """
    n_bins = len(bin_values)
    floor_value = np.inf
    floor_index = -1
    idx = peak_location
    while True:
        value = bin_values[idx]
        if value < floor_value:
            floor_value, floor_index = value, idx
        if value > 0 and value > 1.25 * floor_value:
            return floor_index
        idx += direction
        if not 0 <= idx < n_bins:
            return floor_index
class OutOfPointsError(RuntimeError):
    """Signal that too few samples remain in the histogram to attempt a skewed-Gaussian fit."""
def fit_skew(stack: np.ndarray, ret_all: bool = False, x_scale_factor: float = 1e13, weight: bool = True,  # noqa: C901
             plot_histogram_steps: bool = False, exclude_above_percentile: float = 0) -> float | SkewFitResult:
    """Fit a skewed Gaussian to a histogram of data values to estimate the stray light value.

    The bulk of this function builds a histogram of ``stack`` and progressively zooms in on it until the
    left-most peak of the distribution is well resolved. The final histogram is then fit with `skew_gaussian`
    using lmfit's ``least_squares`` method.

    Parameters
    ----------
    stack : np.ndarray
        Sample values; histogrammed as a flat array.
    ret_all : bool
        If True, return the full `SkewFitResult` (even when the fit looks suspicious) instead of a number.
    x_scale_factor : float
        Multiplier applied to bin positions before fitting. The fit works in these scaled units, and the
        result is divided back out.
    weight : bool
        If True, weight each bin by ``1 / (40 + distance in bins from the peak)``, normalized to a maximum of 1.
        Otherwise all bins get weight 1.
    plot_histogram_steps : bool
        Show a matplotlib figure at each zoom step (debugging aid).
    exclude_above_percentile : float
        If nonzero, drop samples at or above this percentile of ``stack`` before histogramming.

    Returns
    -------
    float | SkewFitResult
        The mode of the fitted skewed Gaussian, in data units. Returns NaN if `SkewFitResult.fit_is_sus` flags
        the fit, or the full result object when ``ret_all`` is True.

    Raises
    ------
    OutOfPointsError
        If fewer than 100 samples are available, if a zoom step leaves fewer than 100 samples in the histogram,
        or if the final histogram has fewer than 5 usable bins.
    """
    # Every successful path below requires a histogram holding at least 100 samples, so a smaller stack can never
    # be fit. Checking here also keeps an empty stack from crashing np.percentile / np.histogram
    # (the outlier-trimming loop would otherwise build an inverted range).
    if stack.size < 100:
        raise OutOfPointsError
    if exclude_above_percentile:
        percentile_value = np.percentile(stack, exclude_above_percentile)
        stack = stack[stack < percentile_value]
        if stack.size < 100:
            raise OutOfPointsError

    # Build our first histogram
    bin_values, bin_edges, *_ = np.histogram(stack, bins=50)
    dx = bin_edges[1] - bin_edges[0]

    if plot_histogram_steps:
        import matplotlib.pyplot as plt  # noqa: PLC0415
        plt.bar(bin_edges[:-1] + dx/2, bin_values, width=dx)

    # We start by trimming outliers. We do that by making a histogram, finding the tallest bin, and then working out
    # from there until we hit bins with little to no counts. We care about the main part of the distribution,
    # so anything beyond those (nearly-) empty bins is an outlier. So we exclude those points and "zoom in" by
    # re-making the histogram using only points between those two identified bins. This process repeats until there
    # aren't any (nearly-) empty bins.
    min_count = 0.01 * bin_values.max()
    # Safety valve to avoid infinite loops
    max_loops_remaining = 10
    while np.any(bin_values <= min_count) and max_loops_remaining:
        max_loops_remaining -= 1
        peak = np.argmax(bin_values)

        # Go to the left, looking for nearly-empty bins
        istart = peak
        while bin_values[istart] > min_count and istart > 0:
            istart -= 1
        # Set our new low bound to be the high edge of the bin if it's empty, or the low end if it's full but it's
        # the last bin.
        stopped_on_small_bin = istart > 0 or bin_values[istart] <= min_count
        low = bin_edges[istart + 1] if stopped_on_small_bin else bin_edges[istart]

        # Go to the right, looking for nearly-empty bins
        istop = peak
        while bin_values[istop] > min_count and istop < len(bin_values) - 1:
            istop += 1
        stopped_on_small_bin = istop < len(bin_values) - 1 or bin_values[istop] <= min_count
        high = bin_edges[istop] if stopped_on_small_bin else bin_edges[istop + 1]

        if plot_histogram_steps:
            plt.axvline(low)
            plt.axvline(high)
            plt.title("Zooming to cut outlier bins")
            plt.show()

        # Re-make the histogram within these bounds
        bin_values, bin_edges, *_ = np.histogram(stack, bins=50, range=(low, high))
        dx = bin_edges[1] - bin_edges[0]
        if plot_histogram_steps:
            plt.bar(bin_edges[:-1] + dx/2, bin_values, width=dx)
        if np.sum(bin_values) < 100:
            raise OutOfPointsError
        min_count = 0.01 * bin_values.max()

    # Now the outliers should be gone. When present, they were dragging the range of our histogram way out,
    # so the core distribution had very poor resolution. Now we should have good resolution on the core area,
    # and we can refine our zoom range better. Here we identify the target peak, walk downhill from it to find its
    # edges, and we zoom there to isolate our targeted peak and avoid fitting a different peak.
    imax = pick_peak(bin_values)
    peak_location = bin_edges[imax] + dx / 2
    ilow = find_peak_end(bin_values, imax, -1)
    ihigh = find_peak_end(bin_values, imax, 1)
    low = bin_edges[ilow]
    high = bin_edges[ihigh + 1]

    if plot_histogram_steps:
        plt.axvline(peak_location, ls="--")
        plt.axvline(low)
        plt.axvline(high)
        plt.title("Zooming in to isolate peak")
        plt.show()

    bin_values, bin_edges, *_ = np.histogram(stack, bins=20, range=(low, high))
    dx = bin_edges[1] - bin_edges[0]
    if plot_histogram_steps:
        plt.bar(bin_edges[:-1] + dx/2, bin_values, width=dx)
    if np.sum(bin_values) < 100:
        raise OutOfPointsError

    # Next we walk out from the target peak until we find bins that are low relative to the peak, to chop off the
    # tails of the distribution.
    imax = pick_peak(bin_values)
    peak_val = bin_values[imax]
    for istart in range(imax - 1, -1, -1):
        if bin_values[istart] < .4 * peak_val:
            break
    else:
        # Also covers imax == 0, where the range above is empty
        istart = 0
    low = bin_edges[istart]
    # The range is never empty (imax < len(bin_values)), so istop always ends within bounds
    for istop in range(imax, len(bin_values)):
        if bin_values[istop] < .5 * peak_val:
            break
    high = bin_edges[1 + istop]

    if plot_histogram_steps:
        plt.title("Zooming in to exclude tail")
        plt.axvline(bin_edges[imax] + dx/2, ls="--")
        plt.axvline(low)
        plt.axvline(high)
        plt.show()

    bin_values, bin_edges, *_ = np.histogram(stack, bins=20, range=(low, high))
    if np.sum(bin_values) < 100:
        raise OutOfPointsError
    dx = bin_edges[1] - bin_edges[0]
    bin_centers = bin_edges[:-1] + dx / 2
    imax = pick_peak(bin_values)
    peak_val = bin_values[imax]
    peak_location = bin_centers[imax]

    if plot_histogram_steps:
        plt.bar(bin_centers, bin_values, width=dx)

    # Now, if the target peak doesn't seem wide enough (in terms of number of bins), we'll zoom in further. (If we
    # only had a couple of bins in the actual peak, we'll probably get a really poor fit.) Note that this step may
    # be unnecessary---it was added before the "zoom to target peak" step earlier, but that step may well always give
    # good resolution on that peak.
    while True:
        # We need to compute how many bins wide our peak is (roughly)
        p2p = peak_val - np.min(bin_values)
        # We're comparing bins' height above the minimum bin value, not the height above 0!
        n_in_peak = np.sum(bin_values > peak_val - 0.4 * p2p)
        if plot_histogram_steps:
            plt.suptitle(f"p2p {p2p}, thresh {peak_val - 0.4 * p2p}, {n_in_peak} bins above thresh")
        if n_in_peak > 0.2 * len(bin_values):
            break

        # Shrink the range by 20% on each side, toward the peak bin's center
        center = bin_centers[imax]
        dlow = center - low
        low = center - 0.8 * dlow
        dhigh = high - center
        high = center + 0.8 * dhigh

        if plot_histogram_steps:
            plt.title("Zooming in to widen peak")
            plt.axvline(bin_edges[imax] + dx/2, ls="--")
            plt.axvline(low)
            plt.axvline(high)
            plt.show()

        bin_values, bin_edges, *_ = np.histogram(stack, bins=20, range=(low, high))
        if np.sum(bin_values) < 100:
            raise OutOfPointsError
        dx = bin_edges[1] - bin_edges[0]
        bin_centers = bin_edges[:-1] + dx / 2
        if plot_histogram_steps:
            plt.bar(bin_centers, bin_values, width=dx)
        imax = pick_peak(bin_values)
        peak_val = bin_values[imax]
        peak_location = bin_centers[imax]

    if plot_histogram_steps:
        plt.axvline(peak_location, ls="--")
        plt.title("Final distribution")
        plt.show()

    # This concludes the histogram preparation. Now we get ready for fitting.

    # Sometimes there are empty bins that just seem to make the fit worse, so exclude them
    full_bins = bin_values > .05 * peak_val
    bin_values = bin_values[full_bins]
    if np.sum(bin_values) < 100 or len(bin_values) < 5:
        raise OutOfPointsError
    bin_centers = bin_centers[full_bins]
    imax = np.where(bin_values == peak_val)[0][0]
    # No longer valid now that bins have been dropped
    del bin_edges

    if weight:
        # Assign weights that just taper off with distance from the peak bin.
        bin_weights = 1 / (40 + np.abs(np.arange(0, len(bin_values)) - imax))
        bin_weights /= bin_weights.max()
    else:
        bin_weights = np.ones_like(bin_values)

    params = Parameters()
    params.add("A", value=0.5/np.sqrt(2*np.pi) * np.max(bin_values), min=0, max=2 * peak_val)
    params.add("alpha", value=0, min=0)
    params.add("x0", value=x_scale_factor * peak_location,
               min=(bin_centers[0] - dx) * x_scale_factor, max=(bin_centers[-1] + dx) * x_scale_factor)
    params.add("sigma", value=6 * dx * x_scale_factor, min=1e-20, max=10)
    params.add("m", value=0, vary=True)
    params.add("b", value=0, vary=True)

    scaled_x_values = bin_centers * x_scale_factor
    with np.errstate(all="ignore"):
        out = minimize(_resid_skew, params, args=(scaled_x_values, bin_values, bin_weights),
                       method="least_squares", calc_covar=False, ftol=2e-4, gtol=2e-4)

    r = SkewFitResult(out, bin_centers=bin_centers, scaled_x_values=scaled_x_values, bin_values=bin_values,
                      stack=stack, scale_factor=x_scale_factor, weights=bin_weights, target_center=peak_location)
    if ret_all:
        return r
    if r.fit_is_sus():
        return np.nan
    return r.result
# Minimum fraction of a fitting window's samples that must be usable. `_estimate_stray_light_one_slice` returns NaN
# for a window where fewer than this fraction of samples are non-negligible (|value| > 1e-17). When building the
# strided image mask, a window counts as inside the image only if more than this fraction of its pixels are in
# the image mask.
REQUIRED_FRACTION_OF_NEIGHBORHOOD_PIXELS = 0.5
def _estimate_stray_light_one_slice(data_array: np.ndarray, y: int, x_grid: np.ndarray, half_width: int,
                                    bin_mask: np.ndarray) -> np.ndarray:
    """Compute the stray light estimate for every grid point in one row (``y``) of the strided grid.

    This is the parallel worker. For each ``x`` in ``x_grid``, the samples are all pixels within ``half_width``
    of ``(y, x)`` across the images selected by ``bin_mask``. Near-zero samples (|value| <= 1e-17) are dropped,
    and the rest go to `fit_skew`.

    Returns
    -------
    np.ndarray
        One estimate per entry of ``x_grid``. An entry is NaN when the window has no usable samples, when too
        few of its samples are usable, or when `fit_skew` runs out of points.
    """
    result = np.empty(x_grid.shape)
    for j, x in enumerate(x_grid):
        stack = data_array[bin_mask, y - half_width : y + half_width + 1, x - half_width : x + half_width + 1].ravel()
        n_pts = stack.size
        stack = stack[np.abs(stack) > 1e-17]
        # stack.size == 0 must be checked explicitly: when bin_mask selects no images, n_pts is 0 and the
        # fraction test (0 < 0) would pass an empty array to fit_skew.
        if stack.size == 0 or stack.size < n_pts * REQUIRED_FRACTION_OF_NEIGHBORHOOD_PIXELS:
            result[j] = np.nan
            continue
        try:
            result[j] = fit_skew(stack, False)
        except OutOfPointsError:
            result[j] = np.nan
    return result
def _load_files(filepaths: list[str], mosaic_wcs: WCS, logger: Logger, do_uncertainty: bool,
                pool: ProcessPoolExecutor, polarized: bool) -> tuple[ShmPickleableNDArray, ShmPickleableNDArray,
                                                                      list[WCS], list[NormalizedMetadata], np.ndarray]:
    """Load every input in parallel into shared memory, along with a copy reprojected to ``mosaic_wcs``.

    Each entry of ``filepaths`` (a single file, or a triplet when ``polarized``) is handled by
    `_load_and_reproject` in ``pool``. Failed inputs get ``None`` in the returned metadata/WCS lists.

    Returns
    -------
    tuple
        ``(data_array, reprojected_array, wcses, metas, uncertainties)``. ``uncertainties`` is the per-channel
        sum of squared (non-finite zeroed) uncertainties of non-outlier inputs. It is ``None`` when
        ``do_uncertainty`` is False (or nothing loaded).

    Raises
    ------
    RuntimeError
        If more than 5% of the inputs fail to load.
    """
    n_inputs = len(filepaths)
    n_channels = 3 if polarized else 1
    data_array = ShmPickleableNDArray((n_inputs, n_channels, 2048, 2048), dtype=np.float32)
    reprojected_array = ShmPickleableNDArray((n_inputs, n_channels, *mosaic_wcs.array_shape), dtype=np.float32)

    summed_sq_uncertainty = None
    metas = []
    wcses = []
    n_failed = 0

    logger.info(f"Will read {n_inputs} {'triplets' if polarized else 'images'}")
    outcomes = pool.map(
        _load_and_reproject, filepaths, repeat(mosaic_wcs), data_array, reprojected_array, repeat(polarized))
    for idx, outcome in enumerate(outcomes):
        if isinstance(outcome, str):
            # _load_and_reproject reports a failure as an error-message string
            logger.warning(f"Loading {filepaths[idx]} failed")
            logger.warning(outcome)
            n_failed += 1
            metas.append(None)
            wcses.append(None)
            if n_failed > .05 * n_inputs:
                raise RuntimeError(f"{n_failed} files failed to load, stopping")
            continue

        file_metas, file_wcses, file_uncertainties = outcome
        metas.append(file_metas)
        wcses.append(file_wcses)

        if do_uncertainty:
            if summed_sq_uncertainty is None:
                summed_sq_uncertainty = np.zeros(data_array.shape[1:], dtype=np.float32)
            for channel, (channel_uncert, channel_meta) in enumerate(
                    zip(file_uncertainties, file_metas, strict=True)):
                if channel_uncert is None or channel_meta["OUTLIER"].value:
                    continue
                # The final uncertainty is sqrt(sum(square(input uncertainties))), so we accumulate the squares here
                summed_sq_uncertainty[channel] += np.nan_to_num(channel_uncert, posinf=0, neginf=0) ** 2

        if (idx + 1) % 100 == 0:
            logger.info(f"Loaded {idx + 1}/{n_inputs} {'triplets' if polarized else 'files'}")

    logger.info(f"Finished loading files, saw {n_failed} failures")
    return data_array, reprojected_array, wcses, metas, summed_sq_uncertainty
# Number of image rows trimmed from the bottom edge, looked up as bottom_crops[OBSCODE - 1] (entries for OBSCODE
# 1, 2 and 3). Used to crop the reprojection input and to zero the bottom of the auto-derived image mask.
bottom_crops = [230, 240, 243]
def _load_and_reproject(paths: str | tuple[str], target_wcs: WCS, data_destination: np.ndarray,
                        repro_destination: np.ndarray, polarized: bool) -> tuple[list, list, list] | str:
    """Load one input (a single file, or a bundle of files when ``polarized``) and reproject it to ``target_wcs``.

    Runs in a pool worker and writes its results in place. ``data_destination`` receives the raw data of each
    loaded file (one row per file). ``repro_destination`` receives a cropped, NaN-masked copy (after
    `resolve_polarization` when ``polarized``) reprojected onto ``target_wcs``.

    Parameters
    ----------
    paths : str | tuple[str]
        A single file path, or (when ``polarized``) a bundle of paths loaded together.
    target_wcs : WCS
        WCS of the reprojection target.
    data_destination : np.ndarray
        Output buffer for the raw data; its first axis indexes the files in ``paths``.
    repro_destination : np.ndarray
        Output buffer for the reprojected data; reset to NaN before anything else.
    polarized : bool
        Whether ``paths`` is a bundle to be polarization-resolved.

    Returns
    -------
    tuple[list, list, list] | str
        On success, the per-file metadata, per-file WCSes, and per-file uncertainty arrays (``None`` for a cube
        without uncertainty). If loading fails, or any cube has no finite uncertainty, an error-message string is
        returned instead and ``data_destination`` is filled with NaN. Errors raised during reprojection are not
        caught here.
    """
    repro_destination[:] = np.nan
    if not polarized:
        paths = [paths]
    try:
        cubes = [load_ndcube_from_fits(path, dtype=np.float32, include_provenance=False) for path in paths]
    except Exception as e:  # noqa: BLE001
        data_destination[:] = np.nan
        return str(e)

    for cube in cubes:
        if not np.any(np.isfinite(cube.uncertainty.array)):
            # If this happens, when reproject_cube trims all-bad
            # rows/columns, it makes a zero-size array and the
            # reprojection crashes.
            data_destination[:] = np.nan
            return f"All-bad image {cube.meta['FILENAME'].value}"

    # Rows to drop from the bottom, chosen by the first cube's OBSCODE (1-based index into bottom_crops)
    bottom_crop = bottom_crops[int(cubes[0].meta["OBSCODE"].value) - 1]
    # The raw, uncropped data is stored unmodified
    for i in range(len(cubes)):
        data_destination[i, :] = cubes[i].data

    resolved_cubes = resolve_polarization(cubes) if polarized else cubes
    repro_input = np.empty(data_destination.shape, dtype=data_destination.dtype)
    for i, cube in enumerate(resolved_cubes):
        # Pixels with infinite uncertainty become NaN so the reprojection ignores them
        repro_input[i] = np.where(np.isinf(cube.uncertainty.array), np.nan, cube.data)

    # Now we prepare the image that gets reprojected and used to build the coronal background, and we trim
    # aggressively to remove areas that are usually low-quality
    y, x = np.mgrid[:2048, :2048]
    # Clip the upper corners
    repro_input[:, y > 1300 + x] = np.nan
    repro_input[:, y > 1300 + (2048 - x)] = np.nan
    # Clip the lower-left corner, including a good portion of the bottom edge
    repro_input[:, y < 1200 - 1.3 * x] = np.nan
    repro_input[:, y < 850 - 0.75 * x] = np.nan
    # Clip the lower-right corner, including a good portion of the bottom edge
    repro_input[:, y < 1200 - 1.3 * (2048 - x)] = np.nan
    repro_input[:, y < 850 - 0.75 * (2048 - x)] = np.nan

    # Don't even reproject the sides and bottom
    repro_input = repro_input[:, bottom_crop:, 350:-350]
    # Slice the WCS identically so it still describes the cropped pixels
    wcs_cropped = cubes[0].wcs[bottom_crop:, 350:-350]

    with warnings.catch_warnings(), np.errstate(all="ignore"):
        warnings.filterwarnings(action="ignore", message=".*failed to converge to the requested.*")
        # `.fn` calls the function underneath the @punch_task wrapper (presumably bypassing Prefect task
        # machinery inside this worker -- confirm)
        reproject_cube.fn(NDCube(repro_input, meta=cubes[0].meta, wcs=wcs_cropped), target_wcs,
                          repro_destination.shape, rolloff_strength=0, rolloff_width=0,
                          output_array=repro_destination,
                          repro_args={"boundary_mode": "ignore", "bad_value_mode": "ignore"}, do_uncertainty=False)

    uncerts = [cube.uncertainty.array if cube.uncertainty is not None else None for cube in cubes]
    return [cube.meta for cube in cubes], [cube.wcs for cube in cubes], uncerts
def _subtract_coronal_model(data_slice: np.ndarray, wcses: list[WCS], metas: list[NormalizedMetadata],
                            corona_models: list, corona_model_dates: list, coronal_wcs: WCS) -> None:
    """Subtract a time-interpolated coronal model from one input, in place.

    Runs in a pool worker. The model at the input's observation time is linearly interpolated between the two
    bracketing entries of ``corona_models`` (clamped to the first/last model outside their date range).

    The model is reprojected from ``coronal_wcs`` into the input's frame (rows ``200:`` only) and subtracted
    from those rows of ``data_slice``. Inputs with more than one layer are first converted by
    `resolve_polarization` with ``"mzpinstru"``; this sets ``POLARREF = "Solar"`` on each entry of ``metas``.

    Parameters
    ----------
    data_slice : np.ndarray
        Data for one input (one layer per file); modified in place.
    wcses, metas : list
        Per-file WCS and metadata for this input. ``wcses`` is ``None`` for inputs that failed to load, which
        are skipped.
    corona_models, corona_model_dates : list
        Coronal models and their dates, in increasing date order (assumed -- this code relies on it).
        Must be non-empty, otherwise indexing ``[0]`` raises IndexError.
    coronal_wcs : WCS
        WCS of the coronal models.
    """
    if wcses is None:
        return
    meta = metas[0]
    wcs = wcses[0]
    dobs = meta.datetime
    if dobs <= corona_model_dates[0]:
        model = corona_models[0]
    elif dobs >= corona_model_dates[-1]:
        model = corona_models[-1]
    else:
        # Find the pair of model dates that brackets the observation, then interpolate linearly in time
        for i, j in pairwise(range(len(corona_model_dates))):
            if corona_model_dates[i] < dobs <= corona_model_dates[j]:
                break
        model = ((corona_models[j] - corona_models[i])
                 * ((dobs - corona_model_dates[i]).total_seconds()
                    / (corona_model_dates[j] - corona_model_dates[i]).total_seconds())
                 + corona_models[i])

    # Only rows from bottom_crop upward receive the model
    bottom_crop = 200
    model = reproject.reproject_adaptive((model, coronal_wcs), wcs[bottom_crop:], (2048 - bottom_crop, 2048),
                                         roundtrip_coords=False, return_footprint=False,
                                         bad_value_mode="ignore", boundary_mode="ignore")
    # Where the reprojection produced NaN, subtract nothing
    np.nan_to_num(model, copy=False)

    if data_slice.shape[0] > 1:
        # Multi-layer (polarized) input: label each model layer as Solar-referenced and convert it with
        # resolve_polarization(..., "mzpinstru") so it matches the data layers.
        cubes = []
        for i in range(data_slice.shape[0]):
            metas[i]["POLARREF"] = "Solar"
            cubes.append(NDCube(data=model[i], wcs=wcses[i][bottom_crop:], meta=metas[i]))
        cubes = resolve_polarization(cubes, "mzpinstru")
        for i in range(len(cubes)):
            model[i] = cubes[i].data

    data_slice[:, bottom_crop:, :] -= model
@punch_task
def _build_and_subtract_corona(reprojected_array: np.ndarray, data_array: np.ndarray,
                               metas: list[list[NormalizedMetadata]], wcses: list[list[WCS]], mosaic_wcs: WCS,
                               mask: np.ndarray, pool: ProcessPoolExecutor, polarized: bool) -> None:
    """Build rolling coronal background models and subtract them from every input, in place.

    A model is built for each 30-hour window, with windows stepping by 24 hours. The window's start and end are
    matched to the nearest loaded inputs, and a window must span at least 50 inputs to get a model. Each model is
    the per-pixel 5th percentile of the reprojected inputs in the window, blurred with `nan_gaussian`
    (sigma 3.5) and NaN-filled with 0.

    `_subtract_coronal_model` then subtracts the time-interpolated model from each entry of ``data_array``, and
    the result is multiplied by ``mask``.

    Raises
    ------
    RuntimeError
        If no window contains enough inputs to build a single model.
    """
    logger = get_run_logger()
    logger.info("Making coronal models")
    corona_models = []
    corona_model_dates = []
    valid_dates = [m[0].datetime for m in metas if m is not None]
    dstart = valid_dates[0]
    dstop = dstart + timedelta(hours=30)
    while dstop < valid_dates[-1]:
        istart = np.argmin([np.abs((m[0].datetime - dstart).total_seconds()) if m is not None else 9e99
                            for m in metas])
        istop = np.argmin([np.abs((m[0].datetime - dstop).total_seconds()) if m is not None else 9e99
                           for m in metas])
        # Sparse windows are skipped, but the window must still advance below. A bare `continue` here would
        # never change dstart/dstop and would loop forever.
        if istop - istart >= 50:
            mdate = average_datetime([m[0].datetime for m in metas[istart:istop] if m is not None])
            model = ShmPickleableNDArray.empty_like(reprojected_array[0])
            for i in range(reprojected_array.shape[1]):
                model[i] = nan_percentile(reprojected_array[istart:istop, i], 5)

            # Bind this iteration's model explicitly rather than closing over the loop variable
            def blur_one_image(i: int, model: ShmPickleableNDArray = model) -> None:
                model[i] = nan_gaussian(model[i], 3.5)

            with ThreadPoolExecutor(len(model)) as p:
                # Consume the results so an exception raised while blurring propagates instead of being dropped
                list(p.map(blur_one_image, range(len(model))))

            np.nan_to_num(model, copy=False)
            corona_models.append(model)
            corona_model_dates.append(mdate)
        dstart += timedelta(hours=24)
        dstop += timedelta(hours=24)

    if not corona_models:
        msg = ("No coronal models could be built: the inputs need to contain a 30-hour window spanning at least "
               "50 inputs")
        raise RuntimeError(msg)

    logger.info("Models made; subtracting")
    for i, _ in enumerate(pool.map(_subtract_coronal_model, data_array, wcses, metas, repeat(corona_models),
                                   repeat(corona_model_dates), repeat(mosaic_wcs))):
        data_array[i] *= mask
        if (i + 1) % 100 == 0:
            logger.info(f"Corona-subtracted {i + 1}/{len(data_array)} {'triplets' if polarized else 'files'}")
    logger.info("Models subtracted")

    # Every subtraction task has finished, so release the models' shared memory now rather than holding it
    # through the (memory-hungry) stray light modeling that follows
    for model in corona_models:
        model.free()
def _make_one_sl_model(bin_n: int, bin_mask: np.ndarray, logger: Logger, outliers: np.ndarray,
                       strided_image_mask: np.ndarray, x_grid: np.ndarray, y_grid: np.ndarray,
                       window_half_width: int, image_mask: np.ndarray, data_array: np.ndarray,
                       make_plots_along_the_way: bool, blur_sigma: float, stride: int, window_size: int,
                       pool: ProcessPoolExecutor) -> np.ndarray:
    """Build the stray light model for one CROTA bin of one polarization state.

    Steps:

    1. Fit every point of the strided ``(y_grid, x_grid)`` grid in parallel (`_estimate_stray_light_one_slice`),
       using the images selected by ``bin_mask``.
    2. Inpaint failed (NaN) fits inside ``strided_image_mask``.
    3. Optionally blur with ``nan_gaussian(blur_sigma)``.
    4. Linearly interpolate back to full image resolution when the grid is strided or windowed, then multiply
       by ``image_mask``.

    Outliers are excluded from the bin unless they make up 10% or more of it.

    Returns
    -------
    np.ndarray
        The model, at full ``data_array.shape[1:]`` resolution when ``stride > 1`` or ``window_size > 1``,
        otherwise on the (already full-resolution) grid.
    """
    logger.info(f"Starting bin {bin_n + 1}, containing {np.sum(bin_mask)} images")
    n_outliers = np.sum(outliers[bin_mask])
    logger.info(f"{n_outliers} outliers in this bin")
    if n_outliers < 0.1 * np.sum(bin_mask):
        # If there's a few outliers, ignore them. If there's lots, we're probably at eclipse season and we have to
        # buckle up and use them anyway.
        bin_mask = bin_mask * ~outliers

    logger.info("Beginning model fitting")
    # One task per grid row; each returns the estimates for every x in x_grid, stacked into a (y, x) array
    stray_light_estimate = np.stack(list(pool.map(
        _estimate_stray_light_one_slice, repeat(data_array), y_grid, repeat(x_grid), repeat(window_half_width),
        repeat(bin_mask))), axis=0)
    logger.info("Finished model fitting")

    if make_plots_along_the_way:
        import matplotlib.pyplot as plt  # noqa: PLC0415
        plt.imshow(stray_light_estimate, vmin=0, vmax=.5e-12, origin="lower")
        plt.title("Raw")
        plt.show()

    # Fill spots where the fitting didn't succeed. But don't fill stuff outside the image mask.
    stray_light_estimate[~strided_image_mask] = 0
    stray_light_estimate = inpaint_nans(stray_light_estimate, kernel_size=5)

    if make_plots_along_the_way:
        plt.imshow(stray_light_estimate, vmin=0, vmax=.5e-12, origin="lower")
        plt.title("post inpaint")
        plt.show()

    # Now the outer masked region needs to be NaNs so it doesn't impact the Gaussian blurring we're about to do.
    stray_light_estimate[~strided_image_mask] = np.nan

    if make_plots_along_the_way:
        plt.imshow(stray_light_estimate, vmin=0, vmax=.5e-12, origin="lower")
        plt.title("Filled")
        plt.show()

    if blur_sigma:
        stray_light_estimate = nan_gaussian(stray_light_estimate, blur_sigma)

    if make_plots_along_the_way:
        plt.imshow(stray_light_estimate, vmin=0, vmax=.5e-12, origin="lower")
        plt.title("Blurred")
        plt.show()

    # Back to zeros outside the mask once blurring is done
    stray_light_estimate[~strided_image_mask] = 0

    if make_plots_along_the_way:
        plt.imshow(stray_light_estimate, vmin=0, vmax=.5e-12, origin="lower")
        plt.title("Masked")
        plt.show()

    if stride > 1 or window_size > 1:
        # Upsample to a proper output size
        # fill_value=None makes the interpolator extrapolate beyond the grid (e.g. the edge band the windows skip)
        interper = scipy.interpolate.RegularGridInterpolator(
            (y_grid, x_grid), stray_light_estimate, method="linear", bounds_error=False, fill_value=None)
        out_y, out_x = np.mgrid[:data_array.shape[1], :data_array.shape[2]]
        stray_light_estimate = interper(np.stack((out_y, out_x), axis=-1))
        stray_light_estimate *= image_mask

        if make_plots_along_the_way:
            plt.imshow(stray_light_estimate, vmin=0, vmax=.5e-12, origin="lower")
            plt.title("Interped, final")
            plt.show()

    logger.info(f"Finished with bin {bin_n + 1}")
    return stray_light_estimate
@punch_flow
def estimate_stray_light(filepaths: list[str],  # noqa: C901
                         do_uncertainty: bool = True,
                         reference_time: datetime | str | None = None,
                         stride: int = 10,
                         window_size: int = 5,
                         blur_sigma: float = 1.5,
                         n_crota_bins: int = 30,
                         crota_bin_width: float = 45,
                         image_mask_path: str | None = None,
                         make_plots_along_the_way: bool = False,
                         polarized: bool = False,
                         num_workers: int | None = None) -> list[NDCube]:
    """Estimate the fixed stray light pattern from a set of input images.

    Inputs are loaded and reprojected in parallel, and a rolling coronal model is subtracted from each. The
    inputs are then grouped into ``n_crota_bins`` overlapping CROTA bins, each ``crota_bin_width`` degrees wide.
    For each bin (and each polarization state), a stray light model is fit on a grid with spacing ``stride``,
    using windows of ``window_size`` pixels.

    Parameters
    ----------
    filepaths : list[str]
        Input files; when ``polarized``, they are grouped into matched triplets with `bundle_matched_mzp`.
    do_uncertainty : bool
        Whether to propagate input uncertainties into the output cubes. If False, the cubes have no uncertainty.
    reference_time : datetime | str | None
        Written to DATE-OBS. Strings must look like ``"%Y-%m-%d %H:%M:%S"`` and are treated as UTC. If omitted,
        DATE-AVG is used.
    stride, window_size : int
        Spacing of the fitting grid and (odd) width of each fitting window, in pixels.
    blur_sigma : float
        Sigma of the blur applied to each model; 0 disables blurring.
    n_crota_bins, crota_bin_width : int, float
        Number of CROTA bins (centers evenly spaced over 360 degrees) and the width of each bin.
    image_mask_path : str | None
        Mask file to use. If omitted, the mask is every pixel that is nonzero in some input.
    make_plots_along_the_way : bool
        Show diagnostic matplotlib figures.
    polarized : bool
        Process polarization triplets and produce one output per polarization state.
    num_workers : int | None
        Size of the process pool, and numba's thread count when given. If None, both use their defaults.

    Returns
    -------
    list[NDCube]
        One cube per polarization state, whose data stacks the per-bin models.

    Raises
    ------
    ValueError
        If ``window_size`` is even.
    """
    logger = get_run_logger()

    # numba.set_num_threads only accepts an integer, so the default of None would raise TypeError.
    # In that case, keep numba's default thread count.
    if num_workers is not None:
        numba.set_num_threads(num_workers)

    if window_size % 2 == 0:
        raise ValueError("Window size must be odd")

    logger.info(f"Running with {len(filepaths)} input files")

    if isinstance(reference_time, str):
        reference_time = datetime.strptime(reference_time, "%Y-%m-%d %H:%M:%S").replace(tzinfo=UTC)

    image_mask = load_mask_file(image_mask_path) if image_mask_path is not None else None

    # Make sure things are in temporal order
    filepaths = sorted(filepaths)
    inputs_to_load = bundle_matched_mzp(filepaths) if polarized else filepaths

    mosaic_wcs, _ = load_trefoil_wcs()
    # Fit the edges of every image in the mosaic by zooming out a tad, and down-size the mosaic
    model_downscale = 2
    mosaic_wcs.wcs.cdelt = 0.024 * model_downscale, 0.024 * model_downscale
    mosaic_wcs.wcs.crpix = 2048 // model_downscale, 2048 // model_downscale
    mosaic_wcs.array_shape = (4096 // model_downscale, 4096 // model_downscale)

    ctx = multiprocessing.get_context("forkserver")
    with ProcessPoolExecutor(num_workers, mp_context=ctx) as pool:
        start = time.time()
        data_array, reprojected_array, wcses, metas, uncertainty = _load_files(
            inputs_to_load, mosaic_wcs, logger, do_uncertainty, pool, polarized)
        time_loading = time.time() - start

        # Per polarization state: True for inputs that failed to load or are flagged as outliers
        outliers = [np.array([True if m is None else (m[i]["OUTLIER"].value != 0) for m in metas])
                    for i in range(data_array.shape[1])]

        # Get first non-None value
        valid_meta = next(filter(None, metas))
        valid_wcs = next(filter(None, wcses))

        if image_mask is None:
            image_mask = ~np.all(data_array == 0, axis=(0, 1))
        bottom_crop = bottom_crops[int(valid_meta[0]["OBSCODE"].value) - 1]
        image_mask[:bottom_crop] = 0

        start = time.time()
        _build_and_subtract_corona(reprojected_array, data_array, metas, wcses, mosaic_wcs, image_mask, pool,
                                   polarized)
        time_corona = time.time() - start

        # Free this memory early, as we don't need it anymore
        reprojected_array.free()
        del reprojected_array

        # Build our CROTA bins. Each bin is a (start, stop] interval; the wrap-around at +/-180 is handled by also
        # testing crota -/+ 360.
        bin_centers = np.linspace(-180, 180, n_crota_bins, endpoint=False)
        bin_starts = bin_centers - crota_bin_width / 2
        bin_stops = bin_centers + crota_bin_width / 2

        def crota_is_in_bin(crota: float, binn: int) -> bool:
            return ((bin_starts[binn] < crota <= bin_stops[binn])
                    or (bin_starts[binn] < crota - 360 <= bin_stops[binn])
                    or (bin_starts[binn] < crota + 360 <= bin_stops[binn]))

        bin_masks = []
        for binn in range(n_crota_bins):
            mask = np.array([False if m is None else crota_is_in_bin(m[0]["CROTA"].value, binn) for m in metas])
            bin_masks.append(mask)
            logger.info(f"Bin centered at CROTA {bin_centers[binn]} has {np.sum(mask)} images")

        window_half_width = window_size // 2
        # Build a grid with `stride` as the spacing, but exclude from the edges so that the window we use at each
        # stride position fits
        x_grid = np.arange(window_half_width, data_array.shape[-1] - window_half_width, stride)
        y_grid = np.arange(window_half_width, data_array.shape[-2] - window_half_width, stride)

        # Downsample the image mask carefully, to have each superpixel indicate whether it contains enough pixels
        # inside the mask for this function's inner loop to get enough samples.
        strided_image_mask = np.empty((y_grid.size, x_grid.size), dtype=bool)
        for i, y in enumerate(y_grid):
            for j, x in enumerate(x_grid):
                sample = image_mask[y - window_half_width:y + window_half_width + 1,
                                    x - window_half_width:x + window_half_width + 1]
                strided_image_mask[i, j] = sample.sum() > sample.size * REQUIRED_FRACTION_OF_NEIGHBORHOOD_PIXELS

        models_per_pol = []
        start = time.time()
        for i in range(data_array.shape[1]):
            logger.info(f"Starting SL modeling for polarization state {i+1}")
            models = []
            for bin_n, bin_mask in enumerate(bin_masks):
                models.append(_make_one_sl_model(bin_n, bin_mask, logger, outliers[i], strided_image_mask, x_grid,
                                                 y_grid, window_half_width, image_mask, data_array[:, i],
                                                 make_plots_along_the_way, blur_sigma, stride, window_size, pool))
            models_per_pol.append(models)
        time_making_models = time.time() - start

    # Free this memory early, as we don't need it anymore
    data_array.free()
    del data_array

    logger.info(f"Spent {time_loading:.1f} s loading & reprojecting, {time_corona:.1f} s subtracting corona, and "
                f"{time_making_models:.1f} s making SL models")

    if do_uncertainty and uncertainty is not None:
        # Squared uncertainties were accumulated once per loaded input (one per triplet when polarized). Divide
        # by that count, not len(filepaths), which counts each of a triplet's files separately.
        uncertainty = np.sqrt(uncertainty) / len(inputs_to_load)

    out_cubes = []
    for i, pol_models in enumerate(models_per_pol):
        out_type = "S" + valid_meta[i].product_code[1:]
        meta = NormalizedMetadata.load_template(out_type, "1")
        meta.provenance = sorted([m[i]["FILENAME"].value for m in metas if m is not None])
        all_date_obses = [m[i].datetime for m in metas if m is not None]
        meta["DATE-AVG"] = average_datetime(all_date_obses).strftime("%Y-%m-%dT%H:%M:%S")
        meta["DATE-OBS"] = reference_time.strftime("%Y-%m-%dT%H:%M:%S") if reference_time else meta["DATE-AVG"].value
        meta["DATE-BEG"] = min(all_date_obses).strftime("%Y-%m-%dT%H:%M:%S")
        meta["DATE-END"] = max(all_date_obses).strftime("%Y-%m-%dT%H:%M:%S")
        meta["DATE"] = datetime.now(UTC).strftime("%Y-%m-%dT%H:%M:%S.%f")[:-3]
        meta.history.add_now("stray light",
                             f"Generated with {len(meta.provenance)} files running from "
                             f"{min(all_date_obses).strftime('%Y-%m-%dT%H:%M:%S')} to "
                             f"{max(all_date_obses).strftime('%Y-%m-%dT%H:%M:%S')}")
        meta["FILEVRSN"] = valid_meta[0]["FILEVRSN"].value

        # Let's put in a valid, representative WCS, with the right scale and sun-relative pointing, etc.
        wcs = valid_wcs[0]
        wcs.cpdis1 = None
        wcs.cpdis2 = None

        # With do_uncertainty=False there is no uncertainty array to attach
        cube_uncertainty = StdDevUncertainty(uncertainty[i]) if uncertainty is not None else None
        out_cube = NDCube(data=np.array(pol_models), meta=meta, wcs=wcs, uncertainty=cube_uncertainty)
        out_cubes.append(out_cube)
    return out_cubes
def _resolve_stray_light_model(source: pathlib.Path | str | NDCube | DataLoader) -> NDCube:
    """
    Return a stray light model cube from any of the accepted input forms.

    Parameters
    ----------
    source : pathlib.Path | str | NDCube | DataLoader
        An already-loaded cube (returned unchanged), a ``DataLoader`` (its ``load()`` result is returned), or a
        path to a FITS file that is read with ``load_ndcube_from_fits``.

    Returns
    -------
    NDCube
        The stray light model.

    Raises
    ------
    InvalidDataError
        If ``source`` is a path that does not exist.
    """
    if isinstance(source, NDCube):
        return source
    if isinstance(source, DataLoader):
        return source.load()
    path = pathlib.Path(source)
    if not path.exists():
        msg = f"File {path} does not exist."
        raise InvalidDataError(msg)
    return load_ndcube_from_fits(path)


def _crota_interpolation_weights(crota: float, n_bins: int) -> tuple[int, int, float]:
    """
    Find the two CROTA bins bracketing ``crota`` and the linear weight between them.

    The model's ``n_bins`` slices are treated as sitting at angles ``-180 + k * 360 / n_bins``. The first bin is
    duplicated at +180 so that angles between the last bin and +180 interpolate back toward bin 0.

    Parameters
    ----------
    crota : float
        Roll angle in degrees. Values outside (-180, 180] are wrapped into that range.
    n_bins : int
        Number of CROTA bins (slices) in the stray light model.

    Returns
    -------
    tuple[int, int, float]
        ``(before_bin, after_bin, fpos)``; the model at ``crota`` is
        ``model[before_bin] * (1 - fpos) + model[after_bin] * fpos``.

    Raises
    ------
    InvalidDataError
        If ``crota`` is not finite.
    """
    if not np.isfinite(crota):
        msg = f"CROTA must be finite to select stray light bins, got {crota}"
        raise InvalidDataError(msg)
    # -180 and +180 are the same roll, so fold everything into (-180, 180]. In-range values are left untouched so
    # they are not perturbed by floating-point round-off.
    if not -180 < crota <= 180:
        crota = 180 - (180 - crota) % 360
    # First bin center is duplicated at the end
    bin_centers = np.linspace(-180, 180, n_bins + 1)
    bin_width = 360 / n_bins
    # First index whose center is >= crota, i.e. bin_centers[after_bin - 1] < crota <= bin_centers[after_bin].
    # The clip only matters if round-off in the wrap above lands exactly on -180 (which is equivalent to bin 0).
    after_bin = int(np.clip(np.searchsorted(bin_centers, crota, side="left"), 1, n_bins))
    before_bin = after_bin - 1
    fpos = (crota - bin_centers[before_bin]) / bin_width
    if after_bin == n_bins:
        # The duplicated +180 center is bin 0.
        after_bin = 0
    return before_bin, after_bin, fpos


def _nearest_time_key(model_meta: NormalizedMetadata, obs_time: datetime) -> str:
    """
    Pick which model timestamp ("DATE-BEG", "DATE-AVG" or "DATE-END") lies closest to ``obs_time``.

    Each timestamp is parsed as UTC. Ties are resolved in the order DATE-BEG, DATE-AVG, DATE-END.
    """
    deltas = {
        key: abs(parse_datetime(model_meta[key].value + " UTC") - obs_time)
        for key in ("DATE-BEG", "DATE-AVG", "DATE-END")
    }
    # ``min`` returns the first minimal key, which gives the tie order above.
    return min(deltas, key=deltas.__getitem__)


@punch_task
def remove_stray_light_task(data_object: NDCube,
                            stray_light_before_path: pathlib.Path | str | NDCube | DataLoader | None,
                            stray_light_after_path: pathlib.Path | str | NDCube | DataLoader | None) -> NDCube:
    """
    Prefect task to remove stray light from an image.

    Stray light is light in an optical system which was not intended in the design.

    The PUNCH instrument stray light will be mapped periodically as part of the ongoing in-flight calibration
    effort. The stray light maps will be generated directly from the L0 and L1 science data. Separating
    instrumental stray light from the F-corona has been demonstrated with SOHO/LASCO and with STEREO/COR2
    observations. It requires an instrumental roll to hold the stray light pattern fixed while the F-corona
    rotates in the field of view. PUNCH orbital rolls will be used to create similar effects.

    Each stray light model holds one slice per CROTA bin. The slices are first linearly interpolated to the
    image's CROTA. The before and after models are then inter/extrapolated in time to the image's observation
    time, and the result is subtracted from the image.

    The stray light uncertainty is intended to be combined in quadrature with the input uncertainty layer to
    produce the output uncertainty layer. Currently it is taken as zero, pending real model uncertainties.

    Parameters
    ----------
    data_object : NDCube
        data to operate on
    stray_light_before_path : pathlib.Path | str | NDCube | DataLoader | None
        stray light model before the observation; ``None`` skips the correction
    stray_light_after_path : pathlib.Path | str | NDCube | DataLoader | None
        stray light model after the observation; ``None`` skips the correction

    Returns
    -------
    NDCube
        modified version of the input with the stray light removed

    Raises
    ------
    InvalidDataError
        If a model path does not exist, a model's image shape does not match the data, the two models have
        different numbers of CROTA bins, or CROTA is not finite.
    IncorrectTelescopeError
        If a model's TELESCOP differs from the data's.
    IncorrectPolarizationStateError
        If a model's OBSLAYR1 differs from the data's.
    """
    if stray_light_before_path is None or stray_light_after_path is None:
        data_object.meta.history.add_now("LEVEL1-remove_stray_light", "Stray light correction skipped")
        return data_object

    stray_light_before_model = _resolve_stray_light_model(stray_light_before_path)
    stray_light_after_model = _resolve_stray_light_model(stray_light_after_path)

    for model in stray_light_before_model, stray_light_after_model:
        if model.meta["TELESCOP"].value != data_object.meta["TELESCOP"].value:
            msg = f"Incorrect TELESCOP value within {model.meta['FILENAME'].value}"
            raise IncorrectTelescopeError(msg)
        if model.meta["OBSLAYR1"].value != data_object.meta["OBSLAYR1"].value:
            msg = f"Incorrect polarization state within {model.meta['FILENAME'].value}"
            raise IncorrectPolarizationStateError(msg)
        if model.data.shape[1:] != data_object.data.shape:
            msg = f"Incorrect stray light function shape within {model.meta['FILENAME'].value}"
            raise InvalidDataError(msg)

    # Both models are indexed with the same CROTA bin geometry, so they must have the same number of bins.
    n_bins = stray_light_before_model.data.shape[0]
    if stray_light_after_model.data.shape[0] != n_bins:
        msg = (f"Stray light models {stray_light_before_model.meta['FILENAME'].value} and "
               f"{stray_light_after_model.meta['FILENAME'].value} have different numbers of CROTA bins")
        raise InvalidDataError(msg)

    # Here we handle the CROTA bins: interpolate each model to this image's roll angle.
    before_bin, after_bin, fpos = _crota_interpolation_weights(data_object.meta["CROTA"].value, n_bins)

    before_at_orbit_pos = (stray_light_before_model.data[before_bin] * (1 - fpos)
                           + stray_light_before_model.data[after_bin] * fpos)
    after_at_orbit_pos = (stray_light_after_model.data[before_bin] * (1 - fpos)
                          + stray_light_after_model.data[after_bin] * fpos)

    stray_light_before_model = NDCube(
        data=before_at_orbit_pos, meta=stray_light_before_model.meta, wcs=stray_light_before_model.wcs)
    stray_light_after_model = NDCube(
        data=after_at_orbit_pos, meta=stray_light_after_model.meta, wcs=stray_light_after_model.wcs)

    # Next we do the temporal interpolation.
    # For the quickpunch case, our stray light models run right up to the current time, with their DATE-OBS likely
    # days in the past. It feels reckless to interpolate the six-hour variation in the model over several days, so
    # let's instead interpolate using the nearest of DATE-BEG, DATE-AVG, or DATE-END. (DATE-BEG will be the best
    # choice when reprocessing.)
    time_key = _nearest_time_key(stray_light_before_model.meta, data_object.meta.datetime)

    if stray_light_before_model.meta[time_key].value == stray_light_after_model.meta[time_key].value:
        warnings.warn(
            "Timestamps are identical for the stray light models; can't inter/extrapolate", CantInterpolateWarning)
        stray_light_model = stray_light_before_model.data
    else:
        stray_light_model = interpolate_data(stray_light_before_model,
                                             stray_light_after_model,
                                             data_object.meta.datetime,
                                             time_key=time_key,
                                             allow_extrapolation=True)

    data_object.data[:, :] -= stray_light_model
    uncertainty = 0
    # TODO: when we have real uncertainties, use them
    # uncertainty = stray_light_model.uncertainty.array # noqa: ERA001
    # Combine in quadrature with the existing uncertainty layer.
    data_object.uncertainty.array[...] = np.sqrt(data_object.uncertainty.array**2 + uncertainty**2)
    data_object.meta.history.add_now("LEVEL1-remove_stray_light",
                                     f"stray light removed with {stray_light_before_model.meta['FILENAME'].value} "
                                     f"and {stray_light_after_model.meta['FILENAME'].value}")
    return data_object