Source code for punchbowl.levelq.pca

import os
import logging
import multiprocessing as mp
import logging.handlers
from itertools import repeat
from concurrent.futures import ThreadPoolExecutor

import numpy as np
import scipy.signal
from astropy.coordinates import EarthLocation, get_body
from astropy.io import fits
from astropy.time import Time
from astropy.wcs import WCS
from ndcube import NDCube
from prefect import get_run_logger
from sklearn.decomposition import PCA
from threadpoolctl import threadpool_limits

from punchbowl.data import NormalizedMetadata
from punchbowl.prefect import punch_task
from punchbowl.util import DataLoader, load_image_task


@punch_task
def pca_filter(input_cubes: list[NDCube], files_to_fit: list[NDCube | DataLoader | str],
               n_components: int = 50, med_filt: int = 5, n_strides: int = 8, blend_size: int = 70) -> None:
    """
    Run PCA-based filtering, overwriting the data of each input cube with its filtered image.

    The images are loaded once by `load_files`, then split into `n_strides` interleaved subsets. Each subset is
    handled by one call to `pca_filter_one_stride`, with the calls running in a thread pool.

    Parameters
    ----------
    input_cubes : list[NDCube]
        Cubes to filter. Their `.data` arrays are modified in place.
    files_to_fit : list[NDCube | DataLoader | str]
        Extra images that are only used for fitting the PCA components.
    n_components : int
        Number of PCA components, forwarded to `pca_filter_one_stride`.
    med_filt : int
        Median-filter kernel size for the PCA components, forwarded to `pca_filter_one_stride`.
    n_strides : int
        Number of interleaved subsets the image stack is divided into.
    blend_size : int
        Width of the blending region between image quarters, forwarded to `load_files` and
        `pca_filter_one_stride`.

    Returns
    -------
    None
    """
    logger = get_run_logger()

    all_files_to_fit, bodies_in_quarter, to_subtract, good_data_mask, is_outlier = load_files(
        input_cubes, files_to_fit, blend_size)

    # os.cpu_count() returns None when the number of CPUs can't be determined; fall back to a single CPU rather
    # than letting min() fail on a None.
    n_cpus = os.cpu_count() or 1

    # 25 threads per worker would saturate all our cores if they all run at once, but experience shows they don't.
    with threadpool_limits(min(25, n_cpus)), ThreadPoolExecutor(min(n_strides, n_cpus)) as p:
        stride_results = p.map(
            pca_filter_one_stride, repeat(all_files_to_fit), range(n_strides), repeat(n_strides),
            repeat(bodies_in_quarter), repeat(to_subtract), repeat(n_components), repeat(med_filt),
            repeat(blend_size), repeat(good_data_mask), repeat(is_outlier), repeat(logger))
        for subtracted_cube_indices, subtracted_images in stride_results:
            # Each stride reports which input cubes it handled, so the results can be written straight back
            for index, image in zip(subtracted_cube_indices, subtracted_images, strict=False):
                input_cubes[index].data[...] = image
    logger.info("PCA filtering finished")
def load_files(input_cubes: list[NDCube], files_to_fit: list[NDCube | str | DataLoader],
               blend_size: int = 70) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """
    Load every image needed for PCA filtering into one array and collect per-image bookkeeping.

    Parameters
    ----------
    input_cubes : list[NDCube]
        The cubes that will be filtered. They are always kept in the stack, even when flagged as outliers.
    files_to_fit : list[NDCube | str | DataLoader]
        Extra images used only for fitting. Entries flagged as outliers are dropped.
    blend_size : int
        Passed to `find_bodies_in_image_quarters` as extra padding around each body.

    Returns
    -------
    all_files_to_fit : np.ndarray
        Image data of every retained image, stacked along the first axis in sorted order.
    bodies_in_quarter : np.ndarray
        The result of `find_bodies_in_image_quarters` for each retained image.
    loaded_input_list_indices : np.ndarray
        For each retained image, its index in `input_cubes`, or -1 if it came from `files_to_fit`.
    good_data_mask : np.ndarray
        Boolean image that is False only where every retained image has data == 0 and infinite uncertainty.
    is_outlier : np.ndarray
        Boolean flag per retained image, True where the image's OUTLIER metadata value is nonzero.

    Raises
    ------
    TypeError
        If the loading loop meets an entry that is not an NDCube, str, or DataLoader.
    """
    logger = get_run_logger()

    # Join these two sets of things into one list, sorted by observation time, and keep track of which ones need to
    # be subtracted and where they are in the original list of input cubes. We sort by observation time to ensure the
    # staggered dropping of files is spread evenly over time.
    things_to_load = np.array(input_cubes + files_to_fit, dtype=object)
    input_list_indices = np.concatenate((range(len(input_cubes)), [-1] * len(files_to_fit))).astype(int)

    def sort_key(thing: str | NDCube | DataLoader) -> str:
        # The sort is on file names (presumably these embed the observation time---confirm against the naming
        # convention)
        if isinstance(thing, str):
            return os.path.basename(thing)
        if isinstance(thing, NDCube):
            return thing.meta["FILENAME"].value
        return os.path.basename(thing.src_repr())

    keys = np.array([sort_key(t) for t in things_to_load])
    sort_by_date = np.argsort(keys)
    things_to_load = things_to_load[sort_by_date]
    input_list_indices = input_list_indices[sort_by_date]

    # We'll pre-allocate an array to load files into. Since we'll reject any bad files, we'll track an insertion
    # index as we add good files, and at the end we'll slice the array to drop any empty spots at the end. When
    # np.empty allocates memory, the OS doesn't *actually* allocate those pages until they're used, so we won't
    # actually use any more RAM than needed.
    all_files_to_fit = np.empty((len(things_to_load), *input_cubes[0].data.shape), dtype=input_cubes[0].data.dtype)
    index_to_insert = 0
    loaded_input_list_indices = []
    n_outliers = 0
    body_finding_inputs = []
    # If a file-to-be-subtracted is an outlier, we want to keep it in the stack so we can still try to subtract it,
    # but we don't want it to factor in to the fitting, so we mark it here.
    is_outlier = []
    # We want to know which pixels are masked in every image. It's possible we'll have files made with two different
    # versions of the mask, so here we detect which pixels are masked by looking for data == 0 and uncertainty == inf,
    # and we track which pixels satisfy that for every image.
    is_masked = np.ones(input_cubes[0].data.shape, dtype=bool)

    for input_file, input_list_index in zip(things_to_load, input_list_indices, strict=False):
        if isinstance(input_file, NDCube):
            data, meta = input_file.data, input_file.meta
            body_finding_input = (input_file.meta, input_file.wcs)
            uncertainty_is_inf = np.isinf(input_file.uncertainty.array)
        elif isinstance(input_file, str):
            cube = load_image_task(input_file, include_provenance=False, include_uncertainty=True)
            data, meta = cube.data, cube.meta
            body_finding_input = (cube.meta, cube.wcs)
            uncertainty_is_inf = np.isinf(cube.uncertainty.array)
        elif isinstance(input_file, DataLoader):
            data, meta, wcs, uncertainty_is_inf = input_file.load()
            body_finding_input = (meta, wcs)
        else:
            raise TypeError(f"Invalid type {type(input_file)} for input file")

        is_good = meta["OUTLIER"].value == 0
        if not is_good:
            n_outliers += 1
            if input_list_index < 0:
                # An outlier that was only here for fitting has no use, so it never enters the stack
                continue

        loaded_input_list_indices.append(input_list_index)
        all_files_to_fit[index_to_insert] = data
        index_to_insert += 1
        body_finding_inputs.append(body_finding_input)
        is_masked *= uncertainty_is_inf * (data == 0)
        is_outlier.append(not is_good)

    # Crop the unused end of the array
    all_files_to_fit = all_files_to_fit[:index_to_insert]

    logger.info(f"Total of {len(all_files_to_fit)} images to fit and/or subtract, filling "
                f"{all_files_to_fit.nbytes / 1024 ** 3:.2f} GB")
    logger.info(f"(Drawn from {len(input_cubes)} images to subtract and {len(files_to_fit)} extra images for fitting)")
    logger.info(f"({n_outliers} outliers were rejected from fitting)")

    loaded_input_list_indices = np.array(loaded_input_list_indices)
    is_outlier = np.array(is_outlier)

    logger.info("Locating planets")
    # We have a lot of data in memory right now, so forking is expensive as all that memory has to be marked as
    # copy-on-write. Using a forkserver avoids that work.
    ctx = mp.get_context("forkserver")
    # os.cpu_count() returns None when the CPU count can't be determined, so fall back to one worker
    with ctx.Pool(min(25, os.cpu_count() or 1)) as p:
        bodies_in_quarter = np.array(p.starmap(find_bodies_in_image_quarters,
                                               zip(body_finding_inputs, repeat(blend_size))))

    good_data_mask = ~is_masked

    return all_files_to_fit, bodies_in_quarter, loaded_input_list_indices, good_data_mask, is_outlier
def _half_axis_taper(n_pixels: int, blend_size: int) -> np.ndarray:
    """Return a 1-D weight that is 1 over the first half of an axis, 0 over the second, with a linear ramp between."""
    # Pixel indices run from 0 to n_pixels - 1, so the axis is symmetric about (n_pixels - 1) / 2. Centering the
    # ramp there means this taper and its mirror image add up to 1 at every pixel, so blending the quarters
    # doesn't change the overall scale of the image along the seams.
    offset_from_center = (n_pixels - 1) / 2 - np.arange(n_pixels)
    if blend_size <= 0:
        # No blending requested: use a hard edge. (A pixel sitting exactly on the center is split half-and-half.)
        return np.clip(0.5 + np.sign(offset_from_center), 0, 1)
    return np.clip(0.5 + offset_from_center / blend_size, 0, 1)


def _quarter_blend_masks(image_shape: tuple[int, int], blend_size: int) -> list[np.ndarray]:
    """Return one blending mask per image quarter, ordered to match the last axis of `bodies_in_quarter`."""
    row_taper = _half_axis_taper(image_shape[0], blend_size)
    column_taper = _half_axis_taper(image_shape[1], blend_size)
    # This mask is 1 in the core of the low-row, low-column quarter and tapers to 0 through the blend region
    corner_mask = row_taper[:, np.newaxis] * column_taper[np.newaxis, :]
    # Flip it around to make one for each quarter: (low row, low col), (low row, high col), (high row, low col),
    # (high row, high col)
    return [corner_mask, corner_mask[:, ::-1], corner_mask[::-1], corner_mask[::-1, ::-1]]


def pca_filter_one_stride(all_files_to_fit: np.ndarray,
                          stride: int,
                          n_strides: int,
                          bodies_in_quarter: np.ndarray,
                          input_list_indices: np.ndarray,
                          n_components: int,
                          med_filt: int,
                          blend_size: int,
                          good_data_mask: np.ndarray,
                          is_outlier: np.ndarray,
                          logger: logging.Logger,
                          ) -> tuple[np.ndarray, np.ndarray]:
    """
    Run PCA-based filtering for one stride position.

    Parameters
    ----------
    all_files_to_fit : np.ndarray
        Stack of images, indexed by image along the first axis.
    stride : int
        Which stride to process: images whose position in the stack equals `stride` modulo `n_strides`.
    n_strides : int
        Total number of strides.
    bodies_in_quarter : np.ndarray
        Boolean array indexed as (image, body, quarter) flagging which bodies fall in which image quarter.
    input_list_indices : np.ndarray
        For each image in the stack, its index in the list of cubes to filter, or a negative value if the image is
        only used for fitting.
    n_components : int
        Number of PCA components, forwarded to `run_pca_filtering`.
    med_filt : int
        Median-filter kernel size, forwarded to `run_pca_filtering`.
    blend_size : int
        Width, in pixels, of the linear blend between neighboring quarters. Zero or less gives a hard edge.
    good_data_mask : np.ndarray
        Boolean image marking the pixels to be used, forwarded to `run_pca_filtering`.
    is_outlier : np.ndarray
        Boolean flag per image; flagged images are excluded from fitting.
    logger : logging.Logger
        Logger for progress messages.

    Returns
    -------
    subtracted_cube_indices : np.ndarray
        For each filtered image, its index in the list of cubes to filter.
    final_reconstruction : np.ndarray
        The filtered images, in the same order. Both return values are empty lists if this stride contains no
        images to filter.
    """
    stride_filter = np.arange(len(all_files_to_fit)) % n_strides == stride
    # This will mark the images we'll be subtracting from---those are the only ones we'll drop from the fitting
    to_subtract_filter = stride_filter & (input_list_indices >= 0)

    if not np.any(to_subtract_filter):
        logger.info(f"Stride {stride} has no images to subtract")
        return [], []

    images_to_subtract = all_files_to_fit[to_subtract_filter]
    # This tracks where each image-to-be-subtracted is in the main list of NDCubes
    subtracted_cube_indices = input_list_indices[to_subtract_filter]

    # The quartering approach that protects from planets/the Moon wrecking the PCA components can leave seams. To
    # reduce that, we have a small blend region at those seams.
    blend_masks = _quarter_blend_masks(images_to_subtract.shape[1:], blend_size)

    # These images are eligible for fitting no matter which quarter we're working on
    usable_for_fitting = ~to_subtract_filter & ~is_outlier

    final_reconstruction = None
    # We need to PCA separately for each quarter of the image
    for i, mask in enumerate(blend_masks):
        # We mark the images that don't have any planets in the quarter we're filtering for (since those can
        # contaminate the PCA components)
        no_bodies_in_quarter = ~np.any(bodies_in_quarter[:, :, i], axis=1)
        images_to_fit = all_files_to_fit[no_bodies_in_quarter & usable_for_fitting]

        tag = f"stride {stride}, quarter {i+1}"
        logger.info(f"Starting to filter {tag}, fitting {len(images_to_fit)} images")
        filtered_by_quarter = run_pca_filtering(images_to_subtract, images_to_fit, n_components, med_filt, tag,
                                                good_data_mask, logger)

        if final_reconstruction is None:
            final_reconstruction = mask * filtered_by_quarter
        else:
            final_reconstruction += mask * filtered_by_quarter

    return subtracted_cube_indices, final_reconstruction
def run_pca_filtering(images_to_subtract: np.ndarray, images_to_fit: np.ndarray, n_components: int, med_filt: int,
                      tag: str, good_data_mask: np.ndarray, logger: logging.Logger) -> np.ndarray:
    """
    Fit PCA components to one set of images and subtract the resulting model from another set.

    Parameters
    ----------
    images_to_subtract : np.ndarray
        Stack of images to be filtered, indexed by image along the first axis.
    images_to_fit : np.ndarray
        Stack of images the PCA components are fit to.
    n_components : int
        Number of PCA components to fit.
    med_filt : int
        Kernel size for median-filtering each component as an image. A falsy value skips the smoothing.
    tag : str
        Label included in the log messages.
    good_data_mask : np.ndarray
        Boolean image selecting the pixels handed to the PCA. Pixels outside the mask get a model value of zero.
    logger : logging.Logger
        Logger for progress messages.

    Returns
    -------
    np.ndarray
        `images_to_subtract` minus the PCA reconstruction of each image.
    """
    model = PCA(n_components=n_components)
    # PCA wants each image as a 1-D vector. Indexing with the 2-D boolean mask flattens every image down to a
    # vector of just its good pixels.
    model.fit(images_to_fit[:, good_data_mask])
    logger.info(f"Fitting finished for {tag}")

    # The coefficients are computed with the components as fitted, before any smoothing is applied to them
    coefficients = model.transform(images_to_subtract[:, good_data_mask])

    if med_filt:
        image_shape = images_to_fit.shape[1:]
        # Each row of components_ is a view, so writing into it updates the fitted model in place
        for component in model.components_:
            # Put the component back on the image grid so it can be smoothed in 2-D
            component_image = np.zeros(image_shape, dtype=component.dtype)
            component_image[good_data_mask] = component
            smoothed_image = scipy.signal.medfilt2d(component_image, med_filt)
            component[:] = smoothed_image[good_data_mask]
        logger.info(f"Median smoothing finished for {tag}")

    model_images = np.zeros_like(images_to_subtract)
    model_images[:, good_data_mask] = model.inverse_transform(coefficients)
    return images_to_subtract - model_images
def find_bodies_in_image_quarters(frame: str | NDCube | tuple[NormalizedMetadata, WCS],
                                  extra_padding: int = 0) -> list:
    """
    Find which quarters of an image contain each of a fixed set of solar-system bodies.

    Parameters
    ----------
    frame : str | NDCube | tuple[NormalizedMetadata, WCS]
        The image to check: a path to a FITS file (the header of HDU 1 is read), an NDCube, or a
        (metadata, WCS) pair.
    extra_padding : int
        Extra margin, in pixels, added around each body on top of the built-in margin.

    Returns
    -------
    list
        One entry per body, in the order Mercury, Venus, Moon, Mars, Jupiter, Saturn. Each entry is a list of four
        flags, ordered (left, bottom), (right, bottom), (left, top), (right, top), where "left" and "bottom" are
        the low-x and low-y halves of the image. A body within its margin of a dividing line is flagged in the
        quarters on both sides.

    Raises
    ------
    TypeError
        If `frame` is not one of the supported types.
    """
    if isinstance(frame, str):
        header = fits.getheader(frame, 1)
        wcs = WCS(header)
        location = header["GEOD_LON"], header["GEOD_LAT"], header["GEOD_ALT"]
        image_shape = header["NAXIS2"], header["NAXIS1"]
    elif isinstance(frame, NDCube):
        location = frame.meta["GEOD_LON"].value, frame.meta["GEOD_LAT"].value, frame.meta["GEOD_ALT"].value
        wcs = frame.wcs
        image_shape = frame.data.shape
    elif isinstance(frame, tuple):
        meta, wcs = frame
        location = meta["GEOD_LON"].value, meta["GEOD_LAT"].value, meta["GEOD_ALT"].value
        image_shape = wcs.array_shape
    else:
        msg = "Type of 'frame' not recognized"
        raise TypeError(msg)

    # The observation time and observer location are the same for every body, so build them only once
    observation_time = Time(wcs.wcs.dateobs)
    observer_location = EarthLocation.from_geodetic(*location)
    n_rows, n_columns = image_shape[0], image_shape[1]
    half_rows, half_columns = n_rows / 2, n_columns / 2

    results = []
    for body in ("Mercury", "Venus", "Moon", "Mars", "Jupiter", "Saturn"):
        body_loc = get_body(body, time=observation_time, location=observer_location)
        x, y = wcs.world_to_pixel(body_loc)
        # Extra margin for the big moon
        w = (100 if body == "Moon" else 10) + extra_padding

        in_left = -w <= x <= half_columns + w
        in_right = half_columns - w <= x <= n_columns + w
        in_bottom = -w <= y <= half_rows + w
        in_top = half_rows - w <= y <= n_rows + w

        results.append([
            in_left and in_bottom,
            in_right and in_bottom,
            in_left and in_top,
            in_right and in_top,
        ])
    return results