Source code for playnano.analysis.utils.particles

"""
Particle-based postprocessing helpers.

These functions take the raw outputs of feature detection and
particle tracking modules and turn them into tabular data,
plots, and CSV/HDF5 exports.
"""

from pathlib import Path
from typing import Any, Mapping, Optional

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd


[docs] def flatten_particle_features( grouping_output: Mapping[str, Any], detection_output: Mapping[str, Any], *, object_key: Optional[str] = None, object_id_field: str = "cluster_id", frame_key: str = "frames", index_key: str = "point_indices", ) -> pd.DataFrame: """ Build a DataFrame linking each grouping analysis results to detected features. Each object (e.g. cluster or track) is linked to its corresponding detected feature metadata using (frame index, point index) pairs to locate features in the output of a feature detection step and merges metadata into one flattened table. Parameters ---------- grouping_output : dict Dictionary from a grouping module (e.g. clustering or tracking). Must contain a list of group objects under the `object_key`, where each object has lists of `frames` and `point_indices`. detection_output : dict Dictionary from a detection module (e.g. feature_detection), which must contain the key 'features_per_frame': a list of feature dicts per frame. object_key : str, optional Key in `grouping_output` pointing to the list of group objects. Default is "clusters". object_id_field : str, optional Column name to use in the output DataFrame to identify the group, e.g., "cluster_id" or "track_id". Default is "cluster_id". frame_key : str, optional Key in each group object listing the frames the object appears in. Default is "frames". index_key : str, optional Key in each group object listing the per-frame point indices (used to match detections in `features_per_frame`). Default is "point_indices". Returns ------- pd.DataFrame Flattened DataFrame linking features to group membership. Includes feature metadata and: - object_id_field (e.g. "cluster_id") - frame - timestamp - label - centroid_x, centroid_y - area - mean_intensity - min_intensity - max_intensity """ if object_key is None: if "tracks" in grouping_output: object_key = "tracks" object_id_field = "track_id" elif "clusters" in grouping_output: object_key = "clusters" object_id_field = "cluster_id" else: raise ValueError("Unable to autodetect object_key. Please specify.") features_per_frame = detection_output.get("features_per_frame", []) rows = [] for obj in grouping_output.get(object_key, []): cid = obj["id"] frames = obj.get(frame_key) point_indices = obj.get(index_key) if frames is None or point_indices is None: raise KeyError( f"Grouping objects must have '{frame_key}' and '{index_key}' lists" ) for frame_idx, pt_idx in zip(frames, point_indices, strict=False): # Defensive: skip if index out of range if frame_idx >= len(features_per_frame): continue frame_features = features_per_frame[frame_idx] if pt_idx >= len(frame_features): continue feat = frame_features[pt_idx] # Build row dict row = row = { object_id_field: cid, "frame": frame_idx, "timestamp": feat.get("frame_timestamp", np.nan), "label": feat.get("label", None), # Follow scikit‑image’s convention for coordinatles, row, col i.e. y, x "centroid_x": feat["centroid"][1], # col (x) "centroid_y": feat["centroid"][0], # row (y) "area": feat.get("area", np.nan), "mean": feat.get("mean", np.nan), "min": feat.get("min", np.nan), "max": feat.get("max", np.nan), } rows.append(row) return pd.DataFrame(rows)
[docs] def plot_particle_labels_3d( df: pd.DataFrame, object_id_field: str = "track_id", ax: Optional[plt.Axes] = None, save_to: Optional[Path] = None, cmap: str = "tab10", ) -> plt.Axes: """ Plot particle ids in 3D (x, y, time), colored by object ID. Parameters ---------- df : pandas.DataFrame Must contain ['centroid_x','centroid_y','timestamp', object_id_field] object_id_field : str Column to use for color grouping (e.g. "track_id", "cluster_id") ax : matplotlib Axes, optional A 3D Axes to draw into, or None to create a new one. save_to : Path, optional If given, save the figure to file. cmap : str Colormap name for particle group colors. Returns ------- ax : Axes The 3D axes used. """ from mpl_toolkits.mplot3d import Axes3D # noqa if ax is None: fig = plt.figure() ax = fig.add_subplot(111, projection="3d") ids = df[object_id_field].unique() colors = plt.get_cmap(cmap)(np.linspace(0, 1, len(ids))) for oid, c in zip(ids, colors, strict=False): sub = df[df[object_id_field] == oid] ax.scatter( sub["centroid_x"], sub["centroid_y"], sub["timestamp"], label=f"{object_id_field} {oid}", color=c, ) ax.set_xlabel("X (px)") ax.set_ylabel("Y (px)") ax.set_zlabel("Time (s)") ax.legend() if save_to: ax.get_figure().savefig(save_to, dpi=150) return ax
[docs] def export_particle_csv(df: pd.DataFrame, out_path: Path) -> None: """ Write the flattened track DataFrame to CSV. Parameters ---------- df : pandas.DataFrame out_path : Path Path to write the .csv file (will create parent dirs). """ out_path.parent.mkdir(parents=True, exist_ok=True) df.to_csv(out_path, index=False)