Source code for punchbowl.prefect
import logging
from typing import Any
from functools import cache
from collections.abc import Callable
from httpx import ConnectError
from prefect import Flow, Task, flow, get_run_logger, runtime, task
from prefect.cache_policies import NO_CACHE
from prefect.client.schemas.objects import TaskRun
from prefect.states import State
from prefect.variables import Variable
from punchbowl.data.punch_io import get_base_file_name, write_ndcube_to_fits
from punchbowl.data.punchcube import PUNCHCube
[docs]
def completion_debugger(task: Task, task_run: TaskRun, state: State) -> None:
    """Write a completed task's PUNCHCube result(s) to FITS files when the Prefect ``debug`` variable is set.

    ``punch_task`` attaches this as an ``on_completion`` hook only when module-level debug mode was enabled.

    Args:
        task: The task that completed; its name is appended to each output filename.
        task_run: The completed task run; only used in error messages.
        state: The final state of the run; ``state.result()`` is what gets written.

    A single ``PUNCHCube`` result is written as ``<base>_<task name>.fits``. A list result has each
    ``PUNCHCube`` element written as ``<base>_<task name>_<index>.fits``. Anything that is not a
    ``PUNCHCube`` (the whole result, or an individual list element) is reported through the run logger
    instead of being written, so one bad element does not prevent the others from being saved.
    """
    # Re-read the variable at completion time rather than trusting the import-time ``_debug_mode`` value.
    if not Variable.get("debug", False):
        return

    result = state.result()
    if isinstance(result, PUNCHCube):
        new_filename = f"{get_base_file_name(result)}_{task.name}.fits"
        write_ndcube_to_fits(result, new_filename, overwrite=True, write_hash=False)
    elif isinstance(result, list):
        for i, cube in enumerate(result):
            if isinstance(cube, PUNCHCube):
                new_filename = f"{get_base_file_name(cube)}_{task.name}_{i}.fits"
                write_ndcube_to_fits(cube, new_filename, overwrite=True, write_hash=False)
            else:
                # Previously a non-cube element crashed the hook here and skipped all remaining elements.
                get_run_logger().error(
                    f"Cannot write debug output for element {i} of {task} {task_run} in {state}.",
                )
    else:
        logger = get_run_logger()
        logger.error(f"Cannot write debug output for {task} {task_run} in {state}.")
[docs]
def failure_hook(task: Task, task_run: TaskRun, state: State) -> None:
    """Hook invoked when a ``punch_task`` run fails.

    ``punch_task`` registers this in ``on_failure``. It currently ignores its arguments and does nothing;
    it exists as the single place to add PUNCH-specific failure handling.
    """
    return None
[docs]
def punch_task(func: Callable | None = None, **kwargs: Any) -> Task | Callable:
    """Prefect task that does PUNCH special things.

    Usable as ``@punch_task`` or ``@punch_task(...)``.

    Under Prefect (``detect_if_running_in_prefect()`` is true) this delegates to ``prefect.task`` with:

    - ``completion_debugger`` appended to ``on_completion`` when the module-level ``_debug_mode`` is true,
    - ``failure_hook`` appended to ``on_failure``,
    - ``cache_policy`` defaulting to ``NO_CACHE``.

    Hooks supplied by the caller in ``on_completion`` / ``on_failure`` are kept, and the PUNCH hooks are
    added after them. A caller-supplied ``cache_policy`` takes precedence over the default.

    Outside Prefect, ``kwargs`` are ignored and the plain function is returned with Prefect-like ``.fn`` and
    ``.submit`` attributes attached by ``_compatability_decorator``.
    """
    if detect_if_running_in_prefect():
        # Merge instead of passing these keywords alongside **kwargs: a caller-supplied value would
        # otherwise raise "got multiple values for keyword argument".
        on_completion = list(kwargs.pop("on_completion", None) or [])
        if _debug_mode:
            on_completion.append(completion_debugger)
        on_failure = [*(kwargs.pop("on_failure", None) or []), failure_hook]
        kwargs.setdefault("cache_policy", NO_CACHE)
        # Delegate everything to Prefect
        return task(func, on_completion=on_completion, on_failure=on_failure, **kwargs)
    if func is None:
        # We've been used as @punch_task() or @punch_task(arg=val), so we are to return a function that does the
        # decoration
        return _compatability_decorator
    # We've been used as @punch_task, so we are to do the decoration directly
    return _compatability_decorator(func)
[docs]
def punch_flow(func: Callable | None = None, **kwargs: Any) -> Flow | Callable:
    """Prefect flow that does PUNCH special things.

    Usable as ``@punch_flow`` or ``@punch_flow(...)``.

    Under Prefect (``detect_if_running_in_prefect()`` is true) this delegates to ``prefect.flow`` with
    ``validate_parameters`` defaulting to ``False``. A caller may still pass ``validate_parameters``
    explicitly; that value takes precedence instead of raising a duplicate-keyword ``TypeError``.

    Outside Prefect, ``kwargs`` are ignored and the plain function is returned with Prefect-like ``.fn`` and
    ``.submit`` attributes attached by ``_compatability_decorator``.
    """
    if detect_if_running_in_prefect():
        kwargs.setdefault("validate_parameters", False)
        # Delegate everything to Prefect
        return flow(func, **kwargs)
    if func is None:
        # We've been used as @punch_flow() or @punch_flow(arg=val), so we are to return a function that does the
        # decoration
        return _compatability_decorator
    # We've been used as @punch_flow, so we are to do the decoration directly
    return _compatability_decorator(func)
[docs]
def _compatability_decorator(func: Callable) -> Callable:
"""Make wrapped functions have a .fn attribute like Prefect Flows and Tasks."""
func.fn = func
func.submit = lambda *args, **kwargs: _CompatabilitySubmitResult(func(*args, **kwargs))
return func
[docs]
class _CompatabilitySubmitResult:
def __init__(self, ret_val: Any) -> None:
self.ret_val = ret_val
[docs]
def result(self) -> Any:
return self.ret_val
[docs]
def wait(self) -> None:
return
[docs]
@cache
def detect_if_running_in_prefect() -> bool:
    """Report whether this process is executing inside a Prefect flow run.

    The check is whether Prefect's runtime exposes a flow-run name. Because of ``functools.cache``, the
    first successful call's answer is reused for the rest of the process.
    """
    current_flow_run_name = runtime.flow_run.name
    return current_flow_run_name is not None
[docs]
def get_logger() -> logging.Logger:
    """Return the logger PUNCH code should write to.

    Inside a Prefect run this is Prefect's run logger from ``get_run_logger()``; otherwise it is the standard
    library logger named ``"punchbowl"``.
    """
    in_prefect = detect_if_running_in_prefect()
    return get_run_logger() if in_prefect else logging.getLogger("punchbowl")
# Import-time debug flag read by ``punch_task`` to decide whether to attach ``completion_debugger``.
# Outside Prefect it is always False. Under Prefect it is the ``debug`` variable's value, falling back to
# False if the lookup (or the Prefect detection itself) raises ConnectError or RuntimeError.
try:
    if detect_if_running_in_prefect():
        _debug_mode = Variable.get("debug", False)
    else:
        _debug_mode = False
except (ConnectError, RuntimeError):
    _debug_mode = False