from collections.abc import Iterable
import matplotlib.pyplot as plt
import numpy as np
from astropy.io.fits import Header
from punchbowl.data import NormalizedMetadata
[docs]
class Limit:
"""
Represents a limit for identifying outliers.
A limit is an x-y curve, with x and y being values from an image's FITS header. A comparison sense is defined (i.e.
'<' or '>', and a given images is considered "good" or "bad" based on how its (x, y) value compares to that curve.
"""
def __init__(self, xkey: str, xs: np.ndarray, ykey: str, ys: np.ndarray, comp: str) -> None:
"""
Create a Limit.
Parameters
----------
xkey : str
The key to read from a header for the x value
xs : np.ndarray
The x values defining the limit curve
ykey : str
The key to read from a header for the y value
ys : np.ndarray
The y values defining the limit curve
comp : str
How to compare a given image to the curve
"""
self.xkey = xkey
self.xs = xs
self.ykey = ykey
self.ys = ys
self.comp = comp
[docs]
def is_good(self, point: Header | NormalizedMetadata | Iterable) -> bool | np.ndarray:
"""Check if a point satisfies a limit."""
if isinstance(point, Header):
x = point[self.xkey]
y = point[self.ykey]
elif isinstance(point, NormalizedMetadata):
x = point[self.xkey].value
y = point[self.ykey].value
elif isinstance(point, Iterable) and isinstance(point[0], (Header, NormalizedMetadata)):
return np.array([self.is_good(p) for p in point])
else:
x, y = point
limit_value = np.interp(x, self.xs, self.ys)
match self.comp:
case "<":
return y < limit_value
case ">":
return y > limit_value
case "=" | "==":
return y == limit_value
case ">=":
return y >= limit_value
case "<=":
return y <= limit_value
case _:
ruff_says = "Unrecognized comparison type"
raise ValueError(ruff_says)
[docs]
def plot(self, points: list[Header | NormalizedMetadata | Iterable] | None = None) -> None:
"""Plot the limit."""
plt.plot(self.xs, self.ys, color="C1")
if points:
if isinstance(points[0], NormalizedMetadata):
xs = [p[self.xkey].value for p in points]
ys = [p[self.ykey].value for p in points]
elif isinstance(points[0], Header):
xs = [p[self.xkey] for p in points]
ys = [p[self.ykey] for p in points]
else:
xs, ys = points
xs = np.asarray(xs)
ys = np.asarray(ys)
plt.scatter(xs, ys, s=10, color="C0")
is_bad = ~np.array([self.is_good(p) for p in zip(xs, ys, strict=False)])
plt.scatter(xs[is_bad], ys[is_bad], s=10, alpha=.5, color="C3")
plt.xlabel(self.xkey)
plt.ylabel(self.ykey)
[docs]
def serialize(self) -> tuple:
"""Convert the limit to a tuple."""
return (np.array((self.xkey, *self.xs)),
np.array((self.ykey, *self.ys)),
np.array((self.comp,)))
[docs]
@staticmethod
def from_serialized(serialized: tuple) -> "Limit":
"""Convert a tuple to a limit."""
x, y, comp = serialized
xkey = x[0].item()
xs = np.array([float(xx) for xx in x[1:]])
ykey = y[0].item()
ys = np.array([float(yy) for yy in y[1:]])
comp = comp.item()
return Limit(xkey, xs, ykey, ys, comp)
[docs]
def __repr__(self) -> str:
"""Repr."""
return f"Limit[{self.xkey}, {self.ykey}, {self.comp}]"
[docs]
class LimitSet:
"""
Represents a set of limits.
A given image will be determined to be "good" or "bad" based on whether it satisfies all the limits.
"""
def __init__(self, limits: list[Limit] | None = None) -> None:
"""Create a LimitSet."""
self.limits = limits or []
[docs]
def add(self, limit: Limit) -> None:
"""Add a Limit to the set."""
self.limits.append(limit)
[docs]
def is_good(self, point: Header | NormalizedMetadata | Iterable) -> bool | np.ndarray:
"""Check if a point satisfies all limits."""
ok = self.limits[0].is_good(point)
for limit in self.limits[1:]:
ok = np.logical_and(ok, limit.is_good(point))
return ok
[docs]
def plot(self, points: list[Header] | None = None, xkey: str | None = None, ykey: str | None = None) -> None:
"""Plot the limits."""
xkey = xkey or self.limits[0].xkey
ykey = ykey or self.limits[0].ykey
for limit in self.limits:
if limit.xkey == xkey and limit.ykey == ykey:
limit.plot()
if points:
xs = np.asarray([p[xkey] for p in points])
ys = np.asarray([p[ykey] for p in points])
plt.scatter(xs, ys, s=10, color="C0")
is_bad = ~np.array([self.is_good(p) for p in points])
plt.scatter(xs[is_bad], ys[is_bad], s=10, alpha=.5, color="C3")
plt.xlabel(xkey)
plt.ylabel(ykey)
[docs]
def to_file(self, path: str) -> None:
"""Write to a file."""
data = {}
for i, limit in enumerate(self.limits):
x, y, comp = limit.serialize()
data[f"x{i}"] = x
data[f"y{i}"] = y
data[f"comp{i}"] = comp
data["n_limits"] = len(self.limits)
np.savez(path, **data)
[docs]
@staticmethod
def from_file(path: str) -> "LimitSet":
"""Load from a file."""
data = np.load(path)
n_limits = data["n_limits"]
limit_set = LimitSet()
for i in range(n_limits):
x = data[f"x{i}"]
y = data[f"y{i}"]
comp = data[f"comp{i}"]
limit_set.add(Limit.from_serialized((x, y, comp)))
return limit_set