# This is the alignment code that's run by the parallel workers. Since we can't fork under Prefect, each worker has
# to freshly import the file containing the code it'll run. Moving that code into its own file cuts the number of
# imports, dropping each worker's import time from ~4.5 s to ~3 s.
import astropy.units as u
import numpy as np
from astropy.coordinates import SkyCoord
from astropy.wcs import WCS, NoConvergence
from lmfit import Parameters, minimize
from scipy.spatial import KDTree
def refine_pointing_single_step(
    guess_wcs: WCS,
    observed_tree: KDTree,
    catalog_stars: SkyCoord,
    method: str = "least_squares",
    ra_tolerance: float = 10,
    dec_tolerance: float = 5,
    fix_crval: bool = False,
    fix_crota: bool = False,
    fix_pv: bool = True,
) -> tuple[float, float, float, float, float]:
    """
    Perform a single step of pointing refinement.

    A copy of ``guess_wcs`` is optimized with lmfit (minimizing ``_residual``)
    so that the catalog stars, projected through the WCS, land close to the
    observed stars. ``guess_wcs`` itself is not modified.

    Parameters
    ----------
    guess_wcs : WCS
        the initial guess for the world coordinate system
    observed_tree : KDTree
        coordinates of the observed star positions extracted from the image, as a tree
    catalog_stars : SkyCoord
        the coordinates of known stars to be matched with the observed stars
    method : str
        method used by lmfit for minimization
    ra_tolerance : float
        how many degrees the guess WCS is allowed to be incorrect by in right ascension
    dec_tolerance : float
        how many degrees the guess WCS is allowed to be incorrect by in declination
    fix_crval : bool
        if True the crval is not allowed to vary, otherwise it can be fit
    fix_crota : bool
        if True the crota is not allowed to vary, otherwise it can be fit
    fix_pv : bool
        if True the pv is not allowed to vary, otherwise it can be fit

    Returns
    -------
    tuple[float, float, float, float, float]
        the fitted ``(platescale, crval1, crval2, crota, pv)``. platescale and
        crval1/crval2 are in the units of the WCS's CDELT/CRVAL; crota is in radians.
    """
    # Work on a copy: the residual function writes the trial parameters into it.
    guess_wcs = guess_wcs.deepcopy()

    # NOTE(review): extract_crota_from_wcs is not imported in the visible part of
    # this file -- confirm it is defined or imported elsewhere in the module.
    initial_crota = extract_crota_from_wcs(guess_wcs).to(u.rad).value
    if not -np.pi <= initial_crota <= np.pi:
        # lmfit clips an out-of-bounds starting value to the nearest bound; wrap the
        # angle into [-pi, pi] instead so the fit starts from the equivalent rotation.
        initial_crota = float(np.arctan2(np.sin(initial_crota), np.cos(initial_crota)))

    crval1, crval2 = guess_wcs.wcs.crval[0], guess_wcs.wcs.crval[1]
    existing_pv = guess_wcs.wcs.get_pv()  # list of (axis, index, value) tuples
    initial_pv = existing_pv[0][-1] if existing_pv else 0.0

    # set up the optimization
    params = Parameters()
    params.add("crota", value=initial_crota, min=-np.pi, max=np.pi, vary=not fix_crota)
    params.add("crval1", value=crval1,
               min=crval1 - ra_tolerance, max=crval1 + ra_tolerance, vary=not fix_crval)
    params.add("crval2", value=crval2,
               min=crval2 - dec_tolerance, max=crval2 + dec_tolerance, vary=not fix_crval)
    # The plate scale is carried through the fit but never varied.
    params.add("platescale", value=abs(guess_wcs.wcs.cdelt[0]), min=0, max=1, vary=False)
    params.add("pv", value=initial_pv, min=0.0, max=1.0, vary=not fix_pv)

    ra = catalog_stars.ra.to_value(u.deg)
    dec = catalog_stars.dec.to_value(u.deg)

    # Trial WCSs can project stars to invalid positions; silence the resulting
    # floating-point warnings during the fit.
    with np.errstate(all="ignore"):
        out = minimize(_residual, params, method=method,
                       args=(ra, dec, observed_tree, guess_wcs),
                       max_nfev=1000, calc_covar=False)

    fitted = out.params
    return (fitted["platescale"].value, fitted["crval1"].value, fitted["crval2"].value,
            fitted["crota"].value, fitted["pv"].value)
def _residual(params: Parameters,
              ra: np.ndarray,
              dec: np.ndarray,
              observed_tree: KDTree,
              guess_wcs: WCS,
              max_error: float = 30) -> float:
    """
    Residual used when optimizing the pointing.

    The trial parameters are written into ``guess_wcs`` IN PLACE (CDELT, CRVAL,
    PC and, when the WCS already has PV entries, the value of the first one).
    The function then measures how far the projected catalog stars fall from
    their nearest observed stars.

    Parameters
    ----------
    params : Parameters
        optimization parameters from lmfit ("platescale", "crval1", "crval2",
        "crota" in radians, and optionally "pv")
    ra, dec : np.ndarray
        expected coordinates of the stars, in degrees
    observed_tree : KDTree
        a KDTree of the pixel coordinates of the observed stars
    guess_wcs : WCS
        initial guess of the world coordinate system, must overlap with the true WCS;
        mutated by this call
    max_error : float
        stars whose nearest observed star is at least this far away are complete
        misses; they are excluded from the mean

    Returns
    -------
    float
        mean distance from each non-missed catalog star to its nearest observed
        star, or ``max_error`` if every star was a miss
    """
    platescale = params["platescale"].value
    crota = params["crota"].value
    guess_wcs.wcs.cdelt = (-platescale, platescale)
    guess_wcs.wcs.crval = (params["crval1"].value, params["crval2"].value)
    cos_rot, sin_rot = np.cos(crota), np.sin(crota)
    guess_wcs.wcs.pc = np.array(
        [
            [cos_rot, -sin_rot],
            [sin_rot, cos_rot],
        ],
    )

    # Write the fitted pv back into the first PV entry. refine_pointing_single_step
    # seeds "pv" from that entry's value. Without this step the residual does not
    # depend on "pv", so letting it vary fits nothing.
    if "pv" in params:
        current_pv = guess_wcs.wcs.get_pv()
        if current_pv:
            axis, index, _ = current_pv[0]
            guess_wcs.wcs.set_pv([(axis, index, params["pv"].value), *current_pv[1:]])

    errors, _ = get_errors(guess_wcs, (ra, dec), observed_tree, catalog_stars_in_pixels=False)
    # NaN distances also fail this comparison, so they are dropped along with the misses.
    matched = errors[errors < max_error]
    if matched.size == 0:
        # Everything missed: the mean would be 0/0 = NaN, which makes lmfit abort.
        # Report the worst achievable mean instead.
        return float(max_error)
    return np.nansum(np.abs(matched)) / len(matched)
def get_errors(wcs: WCS, catalog_stars: SkyCoord | tuple[np.ndarray, np.ndarray],
               observed_stars: np.ndarray | KDTree,
               catalog_stars_in_pixels: bool = True) -> tuple[np.ndarray, np.ndarray]:
    """
    Compute errors between expected and observed star locations.

    Parameters
    ----------
    wcs : WCS
        world coordinate system used to project the catalog stars into pixels
        (unused when ``catalog_stars`` is already a tuple of pixel coordinates)
    catalog_stars : SkyCoord or tuple of np.ndarray
        expected star positions: a SkyCoord; an ``(ra, dec)`` tuple of world
        coordinates when ``catalog_stars_in_pixels`` is False; or an ``(x, y)``
        tuple of pixel coordinates otherwise
    observed_stars : np.ndarray or KDTree
        observed star positions, either as an array of points or as a prebuilt KDTree
    catalog_stars_in_pixels : bool
        whether a tuple ``catalog_stars`` already holds pixel coordinates;
        ignored for SkyCoord input

    Returns
    -------
    errors : np.ndarray
        distance from each catalog star to its nearest observed star
        (``inf`` where scipy reports no neighbour)
    closest_stars : np.ndarray
        the nearest observed star for each catalog star, one row per catalog star;
        rows are NaN where no neighbour was found
    """
    if isinstance(observed_stars, np.ndarray):
        observed_stars = KDTree(observed_stars)

    if isinstance(catalog_stars, SkyCoord):
        try:
            xs, ys = catalog_stars.to_pixel(wcs, mode="all")
        except NoConvergence as e:
            # Fall back to the best solution the iterative inversion reached.
            xs, ys = e.best_solution[:, 0], e.best_solution[:, 1]
    elif not catalog_stars_in_pixels:
        xs, ys = wcs.world_to_pixel_values(*catalog_stars)
    else:
        xs, ys = catalog_stars

    refined_coords = np.stack([xs, ys], axis=-1)
    closest_stars = np.full(refined_coords.shape, np.nan)
    if refined_coords.shape[0] == 0:
        return np.empty(0), closest_stars

    # A single vectorised query replaces a Python-level loop over the stars.
    distances, indices = observed_stars.query(refined_coords, k=1)
    errors = np.asarray(distances, dtype=float)
    # scipy signals a missing neighbour (e.g. an empty tree) with index == tree.n.
    # Indexing tree.data with that sentinel would raise IndexError.
    found = indices < observed_stars.n
    closest_stars[found] = observed_stars.data[indices[found]]
    return errors, closest_stars