from typing import Optional
from typing import Tuple
from typing import Union
import numpy
import polars
from . import t4_camera
from . import t4_hist
from . import t4_plutils
[docs]
def add_bin_count(
    events: Union[polars.DataFrame, polars.LazyFrame],
    x_bin_size: int,
    y_bin_size: int,
    toa_bin_size: float,
) -> Union[polars.DataFrame, polars.LazyFrame]:
    """
    Add a ``bin_count`` column holding the population of each event's voxel.

    A voxel is identified by the floor division of ``x``, ``y`` and
    ``toa48_ns`` by ``x_bin_size``, ``y_bin_size`` and ``toa_bin_size``.
    ``bin_count`` is the number of rows sharing the voxel of the row,
    the row itself included.
    """
    voxel_columns = ["x_bin", "y_bin", "toa_bin"]

    # Integer voxel coordinates along each axis
    x_index = polars.col("x") // x_bin_size
    y_index = polars.col("y") // y_bin_size
    toa_index = polars.col("toa48_ns") // toa_bin_size

    # Window count: every row receives the size of its own voxel group
    voxel_population = polars.count().over(voxel_columns)

    return (
        events.with_columns(x_bin=x_index, y_bin=y_index, toa_bin=toa_index)
        .with_columns(bin_count=voxel_population)
        # The voxel coordinates were only needed for the grouping
        .drop(voxel_columns)
    )
[docs]
def add_dist_to_poni(
    events: Union[polars.DataFrame, polars.LazyFrame],
    x_cen: Optional[float] = None,
    y_cen: Optional[float] = None,
) -> Tuple[Union[polars.DataFrame, polars.LazyFrame], int, int]:
    """
    Add distance to beam center for each event and return beam center coordinates.

    A ``dist_to_poni`` column is added with the Euclidean distance between
    the ``x``/``y`` columns and ``(x_cen, y_cen)``.

    Any coordinate left as ``None`` is taken from the position of the
    maximum of ``t4_hist.hist_2d(events)``. A coordinate that is supplied
    explicitly is always used as given, even when the other one is ``None``.

    Returns a tuple ``(events, x_cen, y_cen)`` with the coordinates
    actually used.
    """
    if x_cen is None or y_cen is None:
        # Locate the beam center at the histogram maximum.
        # NOTE(review): assumes hist_2d returns a 2D array indexed [y, x],
        # as implied by the unpacking order below - confirm in t4_hist.
        counts_2d = t4_hist.hist_2d(events)
        peak_y, peak_x = numpy.unravel_index(
            numpy.argmax(counts_2d), counts_2d.shape
        )
        # Only fill in what the caller did not provide, and hand back plain
        # Python ints rather than numpy scalars.
        if x_cen is None:
            x_cen = int(peak_x)
        if y_cen is None:
            y_cen = int(peak_y)

    # Add distance to primary beam
    squared_dist = (polars.col("x") - x_cen) ** 2 + (polars.col("y") - y_cen) ** 2
    events = events.with_columns(dist_to_poni=squared_dist.sqrt())
    return events, x_cen, y_cen
[docs]
def filter_events(
    events: Union[polars.DataFrame, polars.LazyFrame], radius: float
) -> Union[polars.DataFrame, polars.LazyFrame]:
    """
    Keep the events that are alone in their voxel or close to the beam center.

    An event is kept when ``bin_count`` equals 1 or when ``dist_to_poni``
    is at most ``radius``. Both helper columns are removed from the result.
    """
    is_singleton = polars.col("bin_count") == 1
    in_primary_beam = polars.col("dist_to_poni") <= radius

    selected = events.filter(is_singleton | in_primary_beam)
    return selected.drop(["bin_count", "dist_to_poni"])
[docs]
def filter_missed_dpx(
    events: Union[polars.DataFrame, polars.LazyFrame],
    period: float,
    threshold: float = 50.0,
) -> Tuple[Union[polars.DataFrame, polars.LazyFrame], int, int]:
    """
    Drop the events that fall in irregular DPX trigger intervals.

    The events are sorted by ``toa48_ns``. Among the rows selected by
    ``t4_camera.is_dpx_bot_pixel``, a row is flagged when the time since
    the previous such row deviates from ``period`` by more than
    ``threshold``. Each event then takes the flag of the nearest DPX row at
    or after it, and flagged events are removed.

    Returns a tuple ``(events, n_missed, n_total)``: the resulting events,
    the number of flagged DPX rows and the total number of DPX rows. When
    nothing is flagged, the sorted events are returned unfiltered.
    """
    toa = polars.col("toa48_ns")
    flag = polars.col("dpx_miss")

    ordered = events.sort(toa)
    # The row index is the key used to bring the DPX flags back to all events
    indexed = ordered.with_row_index()

    # Evaluated on DPX rows only: the spacing to the previous DPX row
    # differs from the nominal period by more than the threshold
    irregular = (toa.diff() - period).abs() > threshold
    triggers = indexed.filter(t4_camera.is_dpx_bot_pixel).select(
        polars.col("index"),
        dpx_miss=irregular,
    )

    # Total number of DPX rows and how many of them are flagged
    counts = t4_plutils.collect(
        triggers.select(n_total=polars.count(), n_missed=polars.sum("dpx_miss"))
    )
    n_total = counts["n_total"][0]
    n_missed = counts["n_missed"][0]
    if n_missed == 0:
        # Nothing to remove
        return ordered, n_missed, n_total

    # Attach the flags to the DPX rows, spread each flag backwards over the
    # events preceding its DPX row, then keep the unflagged events.
    # NOTE(review): rows whose flag is still null after the backward fill
    # (e.g. up to the first DPX row, or after the last one) appear to be
    # dropped by the filter as well - confirm this is intended.
    kept = (
        indexed.join(triggers, on="index", how="left")
        .with_columns(flag.backward_fill())
        .filter(~flag)
        .drop(["dpx_miss", "index"])
    )
    return kept, n_missed, n_total
[docs]
def compute_reltoa(
    events: Union[polars.DataFrame, polars.LazyFrame],
) -> Union[polars.DataFrame, polars.LazyFrame]:
    """
    Add a ``rel_toa_ns`` column with the time elapsed since the last DPX trigger.

    Rows selected by ``t4_camera.is_dpx_bot_pixel`` mark the start of a
    timing window. Their ``toa48_ns`` is carried forward to the following
    rows and subtracted from each row's own ``toa48_ns``.

    The carry-forward follows the row order and no sorting is done here.
    NOTE(review): presumably the caller passes events ordered by
    ``toa48_ns`` - confirm.
    """
    toa = polars.col("toa48_ns")

    # TOA of the most recent trigger row, propagated down to later rows
    last_trigger_toa = (
        polars.when(t4_camera.is_dpx_bot_pixel).then(toa).forward_fill()
    )

    with_trigger = events.with_columns(trigger_toa48_ns=last_trigger_toa)
    with_reltoa = with_trigger.with_columns(
        rel_toa_ns=toa - polars.col("trigger_toa48_ns")
    )
    # The trigger time was only an intermediate value
    return with_reltoa.drop("trigger_toa48_ns")