# Source code for punchbowl.level1.alignment_parallel

# This is the alignment code run by the parallel workers. Since we can't fork under Prefect, each worker has to
# freshly import the file containing the code it runs. Moving that code into its own file cuts the number of
# imports, which drops each worker's import time from ~4.5 s to ~3 s.


import astropy.units as u
import numpy as np
from astropy.coordinates import SkyCoord
from astropy.wcs import WCS, NoConvergence
from lmfit import Parameters, minimize
from scipy.spatial import KDTree


def refine_pointing_single_step(
    guess_wcs: WCS,
    observed_tree: KDTree,
    catalog_stars: SkyCoord,
    method: str = "least_squares",
    ra_tolerance: float = 10,
    dec_tolerance: float = 5,
    fix_crval: bool = False,
    fix_crota: bool = False,
    fix_pv: bool = True,
) -> tuple[float, float, float, float, float]:
    """
    Perform a single step of pointing refinement.

    Parameters
    ----------
    guess_wcs : WCS
        the initial guess for the world coordinate system; it is not modified
    observed_tree: KDTree
        coordinates of the observed star positions extracted from the image, as a tree
    catalog_stars : SkyCoord
        the coordinates of known stars to be matched with the observed stars
    method : str
        method used by lmfit for minimization
    ra_tolerance : float
        how many degrees the guess WCS is allowed to be incorrect by in right ascension
    dec_tolerance : float
        how many degrees the guess WCS is allowed to be incorrect by in declination
    fix_crval : bool
        if True the crval is not allowed to vary, otherwise it can be fit
    fix_crota : bool
        if True the crota is not allowed to vary, otherwise it can be fit
    fix_pv : bool
        if True the pv is not allowed to vary, otherwise it can be fit

    Returns
    -------
    tuple[float, float, float, float, float]
        the fitted ``(platescale, crval1, crval2, crota, pv)``. ``crota`` is in radians; ``platescale`` and
        the crvals are in the units of the WCS's ``cdelt``/``crval`` (presumably degrees). The platescale
        is never varied, so it is returned unchanged.
    """
    # _residual mutates the WCS it is handed on every evaluation, so never touch the caller's copy
    guess_wcs = guess_wcs.deepcopy()

    # set up the optimization
    params = Parameters()
    initial_crota = extract_crota_from_wcs(guess_wcs)
    params.add("crota", value=initial_crota.to(u.rad).value, min=-np.pi, max=np.pi, vary=not fix_crota)

    crval1, crval2 = guess_wcs.wcs.crval[0], guess_wcs.wcs.crval[1]
    params.add("crval1", value=crval1, min=crval1 - ra_tolerance, max=crval1 + ra_tolerance, vary=not fix_crval)
    params.add("crval2", value=crval2, min=crval2 - dec_tolerance, max=crval2 + dec_tolerance, vary=not fix_crval)
    params.add("platescale", value=abs(guess_wcs.wcs.cdelt[0]), min=0, max=1, vary=False)

    # the fitted pv is the value of the first PV entry; fall back to 0 when the WCS has none
    pv_entries = guess_wcs.wcs.get_pv()
    pv = pv_entries[0][-1] if pv_entries else 0.0
    params.add("pv", value=pv, min=0.0, max=1.0, vary=not fix_pv)

    ra = catalog_stars.ra.to_value(u.deg)
    dec = catalog_stars.dec.to_value(u.deg)

    # presumably silences numerical warnings from stars that project poorly during trial steps
    with np.errstate(all="ignore"):
        out = minimize(
            _residual,
            params,
            method=method,
            args=(ra, dec, observed_tree, guess_wcs),
            max_nfev=1000,
            calc_covar=False,
        )

    return (
        out.params["platescale"].value,
        out.params["crval1"].value,
        out.params["crval2"].value,
        out.params["crota"].value,
        out.params["pv"].value,
    )
def _residual(
    params: Parameters,
    ra: np.ndarray,
    dec: np.ndarray,
    observed_tree: KDTree,
    guess_wcs: WCS,
    max_error: float = 30,
) -> float:
    """
    Residual used when optimizing the pointing.

    ``guess_wcs`` is updated IN PLACE from ``params`` (cdelt, crval, pc and, when the WCS has a PV entry,
    that entry's value), so callers must pass a copy they are willing to have modified.

    Parameters
    ----------
    params : Parameters
        optimization parameters from lmfit (``platescale``, ``crval1``, ``crval2``, ``crota`` in radians,
        and ``pv``)
    ra, dec : np.ndarray
        expected coordinates of the stars, in degrees
    observed_tree : KDTree
        a KDTree of the pixel coordinates of the observed stars
    guess_wcs : WCS
        initial guess of the world coordinate system, must overlap with the true WCS
    max_error: float
        catalog stars whose nearest observed star is at least this many pixels away are complete misses
        and are excluded from the mean (as are NaN distances)

    Returns
    -------
    float
        mean pixel distance between the projected catalog stars and their nearest observed star, over
        the non-missed stars; ``max_error`` when every star is a miss, so a pointing that matches nothing
        scores worse than any pointing that matches something (instead of producing NaN)
    """
    crota = params["crota"].value
    platescale = params["platescale"].value
    guess_wcs.wcs.cdelt = (-platescale, platescale)
    guess_wcs.wcs.crval = (params["crval1"].value, params["crval2"].value)
    guess_wcs.wcs.pc = np.array(
        [
            [np.cos(crota), -np.sin(crota)],
            [np.sin(crota), np.cos(crota)],
        ],
    )

    # Apply the fitted pv to the same PV entry refine_pointing_single_step read it from; previously the
    # pv parameter never reached the WCS, so fitting it (fix_pv=False) had no effect.
    pv_entries = guess_wcs.wcs.get_pv()
    if pv_entries and "pv" in params:
        i, m, _ = pv_entries[0]
        guess_wcs.wcs.set_pv([(i, m, params["pv"].value), *pv_entries[1:]])

    errors, _ = get_errors(guess_wcs, (ra, dec), observed_tree, catalog_stars_in_pixels=False)
    # NaN compares False, so NaN distances are dropped here together with the misses
    errors = errors[errors < max_error]
    if errors.size == 0:
        # avoid 0/0 -> NaN, which makes lmfit abort the fit
        return float(max_error)
    return float(np.mean(errors))
def get_errors(
    wcs: WCS,
    catalog_stars: SkyCoord | tuple[np.ndarray, np.ndarray],
    observed_stars: np.ndarray | KDTree,
    catalog_stars_in_pixels: bool = True,
) -> tuple[np.ndarray, np.ndarray]:
    """
    Compute errors between expected and observed star locations.

    Parameters
    ----------
    wcs : WCS
        world coordinate system used to project world-coordinate catalog stars to pixels
    catalog_stars : SkyCoord | tuple[np.ndarray, np.ndarray]
        expected star positions: a SkyCoord, or an ``(x, y)`` / ``(ra, dec)`` pair of 1-D arrays
    observed_stars : np.ndarray | KDTree
        observed star pixel positions, either as an ``(n, 2)`` array or an already-built KDTree
    catalog_stars_in_pixels : bool
        only used for tuple input: if True the tuple is already ``(x, y)`` in pixels, otherwise it is
        world coordinates that are projected with ``wcs``

    Returns
    -------
    tuple[np.ndarray, np.ndarray]
        per catalog star, the distance to its nearest observed star, and that observed star's coordinates
    """
    if isinstance(observed_stars, np.ndarray):
        observed_stars = KDTree(observed_stars)

    if isinstance(catalog_stars, SkyCoord):
        try:
            xs, ys = catalog_stars.to_pixel(wcs, mode="all")
        except NoConvergence as e:
            # use the best (partially converged) solution rather than failing the whole evaluation
            xs, ys = e.best_solution[:, 0], e.best_solution[:, 1]
    elif not catalog_stars_in_pixels:
        xs, ys = wcs.world_to_pixel_values(*catalog_stars)
    else:
        xs, ys = catalog_stars

    refined_coords = np.stack([xs, ys], axis=-1)
    if refined_coords.shape[0] == 0:
        return np.empty(0), np.empty(refined_coords.shape)

    # One vectorized nearest-neighbour query instead of a Python loop over stars; this runs on every
    # optimizer evaluation, so the per-star loop dominated the cost.
    errors, indices = observed_stars.query(refined_coords, k=1)
    closest_stars = observed_stars.data[indices]
    return errors, closest_stars
def extract_crota_from_wcs(wcs: WCS) -> u.Quantity:
    """
    Extract CROTA from a WCS.

    The angle is recovered from the PC matrix's first column. ``pc[1, 0]`` is divided by
    ``|cdelt2| / |cdelt1|`` to undo the pixel aspect-ratio scaling; for square pixels that ratio is 1.

    Parameters
    ----------
    wcs : WCS
        world coordinate system whose PC matrix and cdelt are read

    Returns
    -------
    u.Quantity
        the rotation angle, in radians, in the range [-pi, pi] (as produced by ``np.arctan2``)
    """
    delta_ratio = abs(wcs.wcs.cdelt[1]) / abs(wcs.wcs.cdelt[0])
    return np.arctan2(wcs.wcs.pc[1, 0] / delta_ratio, wcs.wcs.pc[0, 0]) * u.rad