Source code for punchbowl.level1.despike

from concurrent.futures import ThreadPoolExecutor

import numpy as np
from ndcube import NDCube
from prefect import get_run_logger
from scipy.ndimage import binary_dilation, gaussian_filter
from threadpoolctl import threadpool_limits

from punchbowl.data import load_ndcube_from_fits
from punchbowl.level1.deficient_pixel import mean_correct
from punchbowl.level1.sqrt import decode_sqrt_data
from punchbowl.prefect import punch_task


[docs] def despike_polseq( reference: NDCube, neighbors: list[NDCube], filter_width: float=25.0, hpf_zscore_thresh: float=10.0, )->tuple[NDCube, np.ndarray]: """ Remove cosmic ray spikes from a single polarization sequence of images. This code takes as input multiple (N) images from the same roll sequence. It constructs a high-pass-filtered version of the input images. At each pixel, it computes the median and standard deviation of the (N-1) dimmest pixels, and then the z-score of each pixel. If the z-score exceeds a threshold, a cosmic ray is assumed and the pixel is filled in with the mean of its neighbors. Parameters ---------- reference : NDCube an NDCube to correct for cosmic rays neighbors : List[NDCube] a list of NDCube objects representing a polarization image sequence, should not include the reference image filter_width: float width of the gaussian filter used in created the high-pass-filtered image hpf_zscore_thresh: float number of standard deviations above the sequence median[sic] that causes a pixel to be marked as a cosmic ray. Returns ------- (NDCube, np.ndarray) a NDCube with spikes replaced by the average of their neighbors, and a list of spike locations for all neighbors """ sequence = np.stack([cube.data.copy() for cube in [*neighbors, reference]], axis=0) seq_len = sequence.shape[0] # saturated regions can lead to weird uncertainties and leftovers so we try to mask them inf_uncertainty_mask = sequence >= 60_000 inf_uncertainty_mask = binary_dilation(inf_uncertainty_mask) inf_uncertainty_mask = np.any(inf_uncertainty_mask, axis=0) sequence[np.stack([inf_uncertainty_mask for _ in range(seq_len)])] = 0 # create the high-pass-filtered images def blur_one_image(image: np.ndarray)->np.ndarray: return gaussian_filter(image, [filter_width,filter_width], mode="nearest") with ThreadPoolExecutor(len(sequence)) as p: lpf_decoded = np.stack(list(p.map(blur_one_image, sequence))) hpf = sequence.astype(float) - lpf_decoded.astype(float) hpf_sorted = np.sort(hpf, axis=0) #For a polarization sequence of length N, this finds the median #of the N-1 lowest values of each pixel match seq_len: case 7: hpf_median_s = np.nanmean(hpf_sorted[2:4],axis=0) case 6: hpf_median_s = hpf_sorted[2] case 5: hpf_median_s = np.nanmean(hpf_sorted[1:3],axis=0) case 4: hpf_median_s = hpf_sorted[1] case 3: hpf_median_s = np.nanmean(hpf_sorted[0:2],axis=0) case _: raise RuntimeError(f"A sequence length of {seq_len} is not supported.") hpf_stdev = np.std(hpf_sorted[:-1], axis=0, ddof=0) hpf_zscore = (hpf - hpf_median_s)/hpf_stdev sequence_spike_mask = np.zeros_like(sequence, dtype=bool) sequence_spike_mask[hpf_zscore>=hpf_zscore_thresh] = 1 image_bounds = sequence[-1] == 0 # the image is zero where there's no data correction_mask = (hpf_zscore >= hpf_zscore_thresh)[-1] correction_mask[inf_uncertainty_mask] = 1 correction_mask[image_bounds] = 0 sequence_spike_mask[-1] = correction_mask reference.data[correction_mask] = np.nan reference.data = mean_correct(data_array=reference.data, mask_array=~np.isnan(reference.data)) # any remaining nans are bad! reference.uncertainty.array[np.isnan(reference.data)] = np.inf reference.data[np.isnan(reference.data)] = 0 return reference, sequence_spike_mask
[docs] @punch_task def despike_polseq_task(data_object: NDCube, neighbors: list[NDCube] | list[str], filter_width: float=25.0, hpf_zscore_thresh: float=10.0, max_workers: int | None = None)-> NDCube: """ Despike a polarization sequence of images using a simple statistical test. Parameters ---------- data_object : NDCube Image to be despiked. neighbors : list[NDCube] | list[str] Sequence of neighbor images from the same spacecraft and roll sequence to use in despiking. filter_width : float, optional width of the gaussian filter used to construct the high-pass-filtered image, in pixels. hpf_zscore_thresh: float, optional number of standard deviations above the sequence median[sic] that causes a pixel to be marked as a cosmic ray. max_workers : int, optional Max number of threads to use Returns ------- NDCube Despiked cube. """ logger = get_run_logger() neighbors = neighbors if neighbors is not None else [] if 3 <= len(neighbors) <= 7: logger.info(f"Neighbors = {neighbors}") neighbors = [load_ndcube_from_fits(n) if isinstance(n, str) else n for n in neighbors] neighbors = [decode_sqrt_data(n) for n in neighbors] with threadpool_limits(max_workers): data_object, spikes = despike_polseq( data_object, neighbors, filter_width=filter_width, hpf_zscore_thresh=hpf_zscore_thresh) data_object.uncertainty.array[spikes[-1]] = np.inf data_object.meta.history.add_now("LEVEL1-despike", "image despiked") data_object.meta.history.add_now("LEVEL1-despike", f"filter_width={filter_width}") data_object.meta.history.add_now("LEVEL1-despike", f"zscore_thresh={hpf_zscore_thresh}") data_object.meta.history.add_now("LEVEL1-despike", f"neighbor_count={len(neighbors)}") else: data_object.meta.history.add_now("LEVEL1-despike", "Incompatible neighbor count so no correction applied") logger.info(f"Incompatible neighbor count {len(neighbors)} so despiking is skipped") return data_object