# Source code for playnano.analysis.modules.particle_tracking

"""
Module for linking particle features across frames to build trajectories.

This module defines the ParticleTrackingModule, which links features
detected in sequential frames of an AFM image stack using nearest-neighbor
matching based on feature coordinates.

Features are matched across frames if they lie within a specified maximum
distance. Tracks are formed by chaining these matches over time.

Each resulting track includes:
    - A unique track ID
    - A list of frames spanned by the track (first→last detection)
    - A list of point indices aligned with frames; missing detections are None
    - A list of coordinates describing the particle's positions

Optionally, per-track masks are extracted from the labeled feature masks.

See Also
--------
playnano.analysis.modules.feature_detection : Mask-based particle detection method.
playnano.analysis.modules.log_blob_detection : LoG-based particle detection method.

.. versionadded:: 0.2.0

Author
------
Daniel E. Rollins (d.e.rollins@leeds.ac.uk) / GitHub: derollins

AI Transparency Note
--------------------
AI-based tools were used for limited typing/formatting assistance
and for debugging, refactoring, and documentation suggestions. All code paths,
algorithms, and final behaviour were reviewed and validated by the author.
"""

from typing import Any, Optional, Sequence

import numpy as np

from playnano.afm_stack import AFMImageStack
from playnano.analysis.base import AnalysisModule


class ParticleTrackingModule(AnalysisModule):
    """
    Link detected features frame-to-frame to produce particle trajectories.

    This module links features detected by a prior feature detection module
    using nearest-neighbor coordinate matching across adjacent frames.
    A new track is created for each unmatched feature.

    Version
    -------
    0.2.0

    Version 0.2.0 allows tracks to continue across missing detections and
    supports scaling the distance threshold with time since last detection.
    """

    version = "0.2.0"

    # Detection modules whose output this module can consume. When the
    # requested ``detection_module`` is absent from previous results, the
    # later entries in this list are preferred (see _get_detection_outputs).
    requires = ["feature_detection", "log_blob_detection"]

    @property
    def name(self) -> str:
        """
        Name of the analysis module.

        Returns
        -------
        str
            Unique identifier: "particle_tracking".
        """
        return "particle_tracking"

    def _get_detection_outputs(
        self,
        previous_results: dict[str, Any],
        *,
        detection_module: str,
        coord_key: str,
    ) -> tuple[list[list[dict]], list[np.ndarray]]:
        """
        Retrieve per-frame features and labeled masks from a detection module.

        Parameters
        ----------
        previous_results : dict[str, Any]
            Results from earlier pipeline steps.
        detection_module : str
            Preferred detection module name. If not available, the most
            recent available module from ``self.requires`` is used.
        coord_key : str
            Key in the detection output containing per-frame features.

        Returns
        -------
        feats : list[list[dict]]
            Per-frame feature dicts.
        masks : list[np.ndarray]
            Per-frame labeled masks.

        Raises
        ------
        RuntimeError
            If no suitable detection module output is available, or
            required keys are missing.
        """
        if detection_module in previous_results:
            chosen = detection_module
        else:
            # Fall back to the last module in ``requires`` that produced
            # output (``reversed`` makes later entries take precedence).
            available = [m for m in reversed(self.requires) if m in previous_results]
            if not available:
                raise RuntimeError(
                    f"{self.name!r} requires one of {self.requires}, but none were found in previous results."  # noqa
                )
            chosen = available[0]

        fd_out = previous_results[chosen]
        if coord_key not in fd_out:
            raise RuntimeError(
                f"{self.name!r} expected detection output {chosen!r} to contain {coord_key!r}."  # noqa
            )
        if "labeled_masks" not in fd_out:
            raise RuntimeError(
                f"{self.name!r} expected detection output {chosen!r} to contain 'labeled_masks'."  # noqa
            )
        return fd_out[coord_key], fd_out["labeled_masks"]

    def _extract_coords(
        self, f: dict, coord_columns: Sequence[str]
    ) -> tuple[float, float]:
        """
        Extract 2D coordinates from a feature dictionary.

        Parameters
        ----------
        f : dict
            Feature dictionary.
        coord_columns : Sequence[str]
            Keys to use for coordinates. If missing, falls back to
            ``f["centroid"]``.

        Returns
        -------
        coords : tuple[float, float]
            (coord0, coord1) coordinate pair.

        Raises
        ------
        KeyError
            If neither ``coord_columns`` nor a valid ``centroid`` entry
            is present.
        """
        try:
            return tuple(float(f[k]) for k in coord_columns)
        except KeyError:
            c = f.get(
                "centroid"
            )  # 'centroid' is an output of playnano.analysis.modules.feature_detection
            if not c or len(c) < 2:
                raise KeyError(
                    f"Missing coordinate keys {coord_columns} and fallback 'centroid'"
                ) from None
            return tuple(c[:2])

    def _scaled_threshold(self, base: float, dt: int, mode: str) -> float:
        """
        Scale a base distance threshold by a frame gap.

        Parameters
        ----------
        base : float
            Base threshold (applied when dt == 1).
        dt : int
            Frame difference since the last detection (>= 1).
        mode : str
            One of {"constant", "linear", "sqrt"}.

        Returns
        -------
        float
            Scaled threshold.

        Raises
        ------
        ValueError
            If ``mode`` is not supported.
        """
        if mode == "constant":
            return base
        if mode == "linear":
            return base * dt
        if mode == "sqrt":
            return base * float(np.sqrt(dt))
        raise ValueError(
            "distance_scale must be one of {'constant','sqrt','linear'}; "
            f"got {mode!r}"
        )

    def _densify_track(self, trk: dict[str, Any]) -> None:
        """
        Convert sparse detections in a track to a dense within-span
        representation.

        The track's ``frames``, ``coords`` and ``point_indices`` arrays are
        replaced in-place so they span from the first to last detection
        (inclusive). Missing detections within the span are represented by
        ``None`` entries.

        Parameters
        ----------
        trk : dict[str, Any]
            Track dictionary with sparse lists:
            - frames : list[int]
            - coords : list[tuple[float, float]]
            - point_indices : list[int]

        Returns
        -------
        None
            The input track dict is modified in-place.
        """
        det_frames = trk["frames"]
        det_coords = trk["coords"]
        det_idxs = trk["point_indices"]

        start = det_frames[0]
        end = det_frames[-1]
        span_frames = list(range(start, end + 1))

        coords_dense: list[Optional[tuple[float, float]]] = [None] * len(span_frames)
        idx_dense: list[Optional[int]] = [None] * len(span_frames)

        for f, c, i in zip(det_frames, det_coords, det_idxs, strict=False):
            j = f - start
            coords_dense[j] = c
            idx_dense[j] = i

        trk["frames"] = span_frames
        trk["coords"] = coords_dense
        trk["point_indices"] = idx_dense

    def _last_detection(self, trk: dict[str, Any]) -> Optional[tuple[int, int]]:
        """
        Find the last valid (frame, point_index) pair in a dense track.

        Parameters
        ----------
        trk : dict[str, Any]
            Track dictionary containing ``frames`` and ``point_indices``
            lists.

        Returns
        -------
        (frame_idx, point_idx) : tuple[int, int] or None
            Last detection pair, or None if no detections are present.
        """
        for t_f, i_f in zip(
            reversed(trk["frames"]), reversed(trk["point_indices"]), strict=False
        ):
            if i_f is not None:
                return t_f, i_f
        return None

    def run(
        self,
        stack: AFMImageStack,
        previous_results: Optional[dict[str, Any]] = None,
        detection_module: str = "feature_detection",
        coord_key: str = "features_per_frame",
        coord_columns: Sequence[str] = ("centroid_x", "centroid_y"),
        max_distance: float = 5.0,
        max_missing: int = 0,
        distance_scale: str = "constant",  # {"constant","sqrt","linear"}
        **params,
    ) -> dict[str, Any]:
        """
        Track particles across frames using nearest-neighbor association.

        Parameters
        ----------
        stack : AFMImageStack
            The input AFM image stack.
        previous_results : dict[str, Any], optional
            Must contain results from a detection module, including:
            - coord_key (e.g., "features_per_frame"): list of dicts with
              per-frame features
            - "labeled_masks": per-frame mask of label regions
        detection_module : str, optional
            Which module to read features from (default: "feature_detection").
        coord_key : str, optional
            Key in previous_results[detection_module] containing per-frame
            feature dicts (default: "features_per_frame").
        coord_columns : Sequence[str], optional
            Keys to extract coordinates from each feature; falls back to
            "centroid" if needed. Default is ("centroid_x", "centroid_y").
        max_distance : float, optional
            Maximum allowed movement per frame in coordinate units
            (default: 5.0).
        max_missing : int, optional
            Maximum allowed consecutive missing detections before terminating
            a track. Default is 0 (no missing allowed).
        distance_scale : str, optional
            How to scale max_distance with time since last detection in
            frames: "constant" (default), "linear" (scale linearly with
            time), or "sqrt" (scale with square root of time).

        Returns
        -------
        dict
            Dictionary with keys:

            - tracks : list of dict
                Per-track dictionaries containing:
                - id : int
                    Track ID
                - frames : list[int]
                    Frame indices spanned by track
                - point_indices : list[Optional[int]]
                    Indices into features_per_frame
                - coords : list[Optional[tuple[float, float]]]
                    Coordinates (coord0, coord1) aligned with frames; None
                    indicates missing detection.
            - track_masks : dict[int, np.ndarray]
                Last mask per track
            - n_tracks : int
                Total number of tracks
        """
        if previous_results is None:
            raise RuntimeError(f"{self.name!r} requires previous results to run.")
        if max_missing < 0:
            raise ValueError("max_missing must be >= 0")
        if max_distance < 0:
            raise ValueError("max_distance must be >= 0")

        feats, masks = self._get_detection_outputs(
            previous_results,
            detection_module=detection_module,
            coord_key=coord_key,
        )
        n_frames = len(feats)

        tracks = []
        next_track_id = 0
        # List of dictionaries, each: {"id": int, "last_coord": (x,y),
        # "last_frame": int, "missing": int}
        active_tracks: list[dict[str, Any]] = []

        for t in range(n_frames):
            this_feats = feats[t]
            # Hoist coordinate extraction: every feature's coordinates are
            # needed by each active track (and again when seeding new
            # tracks), so extract them once per frame instead of once per
            # (track, feature) pair.
            this_coords = [
                self._extract_coords(f, coord_columns) for f in this_feats
            ]
            assigned: set[int] = set()
            new_active: list[dict[str, Any]] = []

            # Match existing tracks to nearest features
            for state in active_tracks:
                trk_id = state["id"]
                last_coord = state["last_coord"]
                last_frame = state["last_frame"]
                missing = state["missing"]

                dt = t - last_frame  # time since last detection (in frames)
                dist_thresh = self._scaled_threshold(max_distance, dt, distance_scale)

                best = None
                best_idx = None
                best_dist = dist_thresh
                for i, coords in enumerate(this_coords):
                    if i in assigned:
                        continue
                    dist = np.hypot(
                        coords[0] - last_coord[0], coords[1] - last_coord[1]
                    )
                    if dist < best_dist:
                        best_dist, best, best_idx = dist, coords, i

                if best is not None:
                    track = tracks[trk_id]  # track id equals index in tracks list
                    track["frames"].append(t)
                    track["coords"].append(best)
                    track["point_indices"].append(best_idx)
                    assigned.add(best_idx)
                    new_active.append(
                        {
                            "id": trk_id,
                            "last_coord": best,
                            "last_frame": t,
                            "missing": 0,
                        }
                    )
                else:
                    missing += 1
                    if missing <= max_missing:
                        new_active.append(
                            {
                                "id": trk_id,
                                "last_coord": last_coord,
                                "last_frame": last_frame,
                                "missing": missing,
                            }
                        )
                    # else: retire (drop from active)

            # Start new tracks for unmatched detections
            for i, coords in enumerate(this_coords):
                if i in assigned:
                    continue
                trk = {
                    "id": next_track_id,
                    "frames": [t],
                    "coords": [coords],
                    "point_indices": [i],
                }
                tracks.append(trk)
                new_active.append(
                    {
                        "id": next_track_id,
                        "last_coord": coords,
                        "last_frame": t,
                        "missing": 0,
                    }
                )
                next_track_id += 1

            active_tracks = new_active

        for trk in tracks:
            self._densify_track(trk)

        # Generate per-track masks from last known frame/feature
        track_masks = {}
        for trk in tracks:
            last_det = self._last_detection(trk)
            if last_det is None:
                continue
            t_last, i_last = last_det
            label = feats[t_last][i_last].get("label")
            if label is not None:
                track_masks[trk["id"]] = masks[t_last] == label

        return {
            "tracks": tracks,
            "track_masks": track_masks,
            "n_tracks": len(tracks),
        }