"""
Threshold-based feature detection for AFM image stacks.
Detect features in each frame of an AFM image stack through thresholding methods.
"""
import inspect
from typing import Any, Callable, Optional, Union

import numpy as np
from scipy.ndimage import binary_fill_holes
from skimage.measure import label, regionprops
from skimage.morphology import remove_small_holes

from playnano.analysis.base import AnalysisModule
from playnano.processing.mask_generators import register_masking
from playnano.utils.param_utils import param_conditions
# Registry of named mask generators. It lets ``mask_fn`` be passed as a string
# name, which FeatureDetectionModule._get_mask_array resolves through this mapping.
MASK_MAP = register_masking()
class FeatureDetectionModule(AnalysisModule):
    """
    Detect contiguous features in each frame of an AFM image stack.

    This module takes either a user-supplied mask function or a pre-computed boolean
    mask array. For each frame it optionally fills holes in the mask, labels the
    connected regions, and filters them by size and edge contact. It returns
    per-frame feature statistics and labeled masks.

    Parameters
    ----------
    mask_fn : callable or str, optional
        A function ``frame -> bool_2D_array`` used to generate a mask for each
        frame, or the name of a registered mask generator. Required if
        ``mask_key`` is not provided.
    mask_key : str, optional
        Name of a boolean mask array from a previous analysis (e.g.
        ``previous_results["your_mask_key"]``). Required if ``mask_fn`` is not
        provided. Takes precedence over ``mask_fn`` when both are given.
    min_size : int
        Minimum area (in pixels) for a region to be kept. Default is 10.
    remove_edge : bool
        If True, discard any region that touches the frame boundary. Default is True.
    fill_holes : bool
        If True, fill holes in each mask before labeling. Default is False.
    hole_area : int or None
        Only used when ``fill_holes`` is True. If set, only holes below this area
        are filled (via ``skimage.morphology.remove_small_holes``). Default is
        None, which fills all holes.
    **mask_kwargs : Any
        Additional keyword arguments forwarded to ``mask_fn(frame, **mask_kwargs)``.
        Keywords that ``mask_fn``'s signature cannot accept are not forwarded.

    Raises
    ------
    ValueError
        If neither ``mask_fn`` nor ``mask_key`` is provided, if a string
        ``mask_fn`` is not a registered mask, or if a mask array has the wrong
        shape/dtype.
    TypeError
        If ``mask_fn`` is neither callable nor a string.
    KeyError
        If ``mask_key`` is not found in ``previous_results``.

    Returns
    -------
    dict[str, Any]
        Dictionary with the following keys:

        - features_per_frame : list of list of dict
            Per-frame list of feature stats dicts, each with:

            - ``"frame_timestamp"`` : float
            - ``"label"`` : int
            - ``"area"`` : int
            - ``"min"``, ``"max"``, ``"mean"`` : float
            - ``"bbox"`` : (min_row, min_col, max_row, max_col), max bounds exclusive
            - ``"centroid"`` : (row, col)
        - labeled_masks : list of np.ndarray
            The final labeled mask (integer labels) for each frame.
        - summary : dict
            Aggregate metrics:

            - ``"total_frames"`` : int
            - ``"total_features"`` : int
            - ``"avg_features_per_frame"`` : float

    Version
    -------
    0.1.0

    Examples
    --------
    >>> pipeline.add("feature_detection", mask_fn=mask_mean_offset, min_size=20,
    ...              fill_holes=True, hole_area=50)
    >>> result = pipeline.run(stack)
    >>> result["summary"]["total_features"]
    123
    """

    version = "0.1.0"

    @property
    def name(self) -> str:
        """
        Name of the analysis module.

        Returns
        -------
        str
            The string identifier for this module: "feature_detection".
        """
        return "feature_detection"

    @staticmethod
    def _accepted_kwargs(
        fn: Callable[..., Any], kwargs: dict[str, Any]
    ) -> dict[str, Any]:
        """
        Return the subset of ``kwargs`` that ``fn(frame, **subset)`` can accept.

        All kwargs are forwarded when ``fn`` declares ``**kwargs`` or when its
        signature cannot be introspected (some builtins / C callables).
        Otherwise, only names matching a keyword-capable parameter are kept.
        """
        if not kwargs:
            return {}
        try:
            params = list(inspect.signature(fn).parameters.values())
        except (TypeError, ValueError):
            return dict(kwargs)
        if any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params):
            return dict(kwargs)
        keyword_names = {
            p.name
            for p in params
            if p.kind
            in (inspect.Parameter.POSITIONAL_OR_KEYWORD, inspect.Parameter.KEYWORD_ONLY)
        }
        return {k: v for k, v in kwargs.items() if k in keyword_names}

    def _get_mask_array(
        self,
        data: np.ndarray,
        previous_results: Optional[dict[str, Any]],
        mask_fn: Optional[Union[Callable[..., np.ndarray], str]],
        mask_key: Optional[str],
        **mask_kwargs: Any,
    ) -> np.ndarray:
        """
        Resolve the boolean mask stack for ``data``.

        A mask stored in ``previous_results[mask_key]`` takes precedence.
        Otherwise ``mask_fn`` is applied to each frame.

        Parameters
        ----------
        data : np.ndarray
            3D image stack of shape (n_frames, H, W).
        previous_results : dict[str, Any] or None
            Earlier analysis outputs; only consulted when ``mask_key`` is set.
        mask_fn : callable or str or None
            Per-frame mask function, or the name of a registered mask in
            ``MASK_MAP``.
        mask_key : str or None
            Key of a precomputed mask in ``previous_results``.
        **mask_kwargs : Any
            Extra keyword arguments for ``mask_fn``. Those its signature cannot
            accept are dropped.

        Returns
        -------
        np.ndarray
            Boolean array with the same shape as ``data``. A mask taken from
            ``previous_results`` is returned as-is (not copied).

        Raises
        ------
        KeyError
            If ``mask_key`` is given but missing from ``previous_results``.
        ValueError
            If no mask source is given, a string ``mask_fn`` is not registered,
            or a mask has the wrong dtype/shape.
        TypeError
            If ``mask_fn`` is neither callable nor a string.
        """
        n_frames, H, W = data.shape

        if mask_key is not None:
            if not previous_results or mask_key not in previous_results:
                raise KeyError(f"mask_key '{mask_key}' not found in previous_results")
            mask_arr = previous_results[mask_key]
            if not (
                isinstance(mask_arr, np.ndarray)
                and mask_arr.dtype == bool
                and mask_arr.shape == data.shape
            ):
                raise ValueError(
                    f"previous_results[{mask_key}] must be a boolean ndarray of shape {data.shape}"  # noqa
                )
            return mask_arr

        if mask_fn is None:
            raise ValueError("Either mask_fn or mask_key must be provided")

        # Resolve mask_fn if it's a registered string
        if isinstance(mask_fn, str):
            if mask_fn not in MASK_MAP:
                raise ValueError(
                    f"mask_fn '{mask_fn}' is not a known registered mask. "
                    f"Available: {list(MASK_MAP.keys())}"
                )
            mask_fn = MASK_MAP[mask_fn]

        if not callable(mask_fn):
            raise TypeError(
                "mask_fn must be callable or a registered mask name, "
                f"got {type(mask_fn).__name__}"
            )

        # Work out the accepted kwargs once. The alternative, retrying without
        # kwargs after a TypeError, would hide TypeErrors raised inside mask_fn
        # and silently drop kwargs the function does accept.
        call_kwargs = self._accepted_kwargs(mask_fn, mask_kwargs)

        # Compute mask frame-by-frame
        mask_arr = np.zeros_like(data, dtype=bool)
        for i in range(n_frames):
            mf = mask_fn(data[i], **call_kwargs)
            if not (
                isinstance(mf, np.ndarray) and mf.dtype == bool and mf.shape == (H, W)
            ):
                raise ValueError(f"mask_fn returned invalid mask for frame {i}")
            mask_arr[i] = mf
        return mask_arr

    def _process_frame(
        self,
        frame: np.ndarray,
        mask_frame: np.ndarray,
        frame_ts: float,
        *,
        min_size: int,
        remove_edge: bool,
        fill_holes: bool,
        hole_area: Optional[int],
    ) -> tuple[list[dict[str, Any]], np.ndarray]:
        """
        Process a single frame: hole fill, labeling, filtering, stats.

        Parameters
        ----------
        frame : np.ndarray
            2D image of shape (H, W) used for intensity statistics.
        mask_frame : np.ndarray
            2D boolean mask of the same shape as ``frame``.
        frame_ts : float
            Timestamp stored in each feature dict.
        min_size : int
            Regions with area below this are discarded.
        remove_edge : bool
            Discard regions whose bounding box touches the frame border.
        fill_holes : bool
            Fill holes in the mask before labeling.
        hole_area : int or None
            If set (with ``fill_holes``), only holes below this area are
            filled. Otherwise all holes are filled.

        Returns
        -------
        tuple[list[dict[str, Any]], np.ndarray]
            The per-feature stats dicts and the relabeled integer mask, which
            contains only the surviving regions.
        """
        H, W = frame.shape

        # Optionally fill holes
        if fill_holes:
            if hole_area is not None:
                mask_frame = remove_small_holes(mask_frame, area_threshold=hole_area)
            else:
                mask_frame = binary_fill_holes(mask_frame)
        mask_frame = mask_frame.astype(bool)

        # Label connected regions, then mark which labels survive filtering.
        initial_labeled = label(mask_frame)
        # keep[k] is True when region k is retained; index 0 is background.
        keep = np.zeros(int(initial_labeled.max(initial=0)) + 1, dtype=bool)
        for prop in regionprops(initial_labeled):
            if prop.area < min_size:
                continue
            minr, minc, maxr, maxc = prop.bbox
            # bbox max bounds are exclusive, so touching the bottom/right edge
            # shows up as maxr == H / maxc == W.
            if remove_edge and (minr == 0 or minc == 0 or maxr == H or maxc == W):
                continue
            keep[prop.label] = True
        # A single lookup builds the filtered mask, instead of one full-frame
        # comparison per region.
        filtered_mask = keep[initial_labeled]

        # Relabel after filtering so the surviving labels are renumbered
        labeled = label(filtered_mask)

        # Collect stats using each region's own pixel coordinates.
        features: list[dict[str, Any]] = []
        for prop in regionprops(labeled):
            rows, cols = prop.coords.T
            vals = frame[rows, cols]
            features.append(
                {
                    "frame_timestamp": frame_ts,
                    "label": int(prop.label),
                    "area": int(prop.area),
                    "min": float(vals.min()),
                    "max": float(vals.max()),
                    "mean": float(vals.mean()),
                    "bbox": tuple(map(int, prop.bbox)),  # (minr, minc, maxr, maxc)
                    "centroid": tuple(map(float, prop.centroid)),
                }
            )
        return features, labeled

    def _summarize(self, n_frames: int, total_features: int) -> dict[str, Any]:
        """
        Summarize results across frames.

        Parameters
        ----------
        n_frames : int
            Number of frames processed.
        total_features : int
            Number of features found across all frames.

        Returns
        -------
        dict[str, Any]
            ``total_frames``, ``total_features`` and ``avg_features_per_frame``.
            The average is 0.0 when there are no frames.
        """
        return {
            "total_frames": n_frames,
            "total_features": total_features,
            "avg_features_per_frame": (
                total_features / n_frames if n_frames > 0 else 0.0
            ),
        }

    @param_conditions(
        mask_fn=lambda p: not p.get("mask_key"),
        mask_key=lambda p: not p.get("mask_fn"),
        hole_area=lambda p: p.get("fill_holes", False),
    )
    def run(
        self,
        stack,
        previous_results: Optional[dict[str, Any]] = None,
        *,
        # Mask input: either supply a mask function or refer to
        # existing mask in previous_results
        mask_fn: Optional[Union[Callable[..., np.ndarray], str]] = None,
        mask_key: Optional[str] = None,
        # Filtering criteria:
        min_size: int = 10,
        remove_edge: bool = True,
        # Hole-filling options:
        fill_holes: bool = False,
        hole_area: Optional[int] = None,
        # kwargs for mask_fn(frame, **mask_kwargs)
        **mask_kwargs: Any,
    ) -> dict[str, Any]:
        """
        Detect contiguous features on each frame of stack.data.

        Parameters
        ----------
        stack : AFMImageStack
            The AFM stack whose ``.data`` (3D array) and ``.time_for_frame()``
            are used.
        previous_results : dict[str, Any], optional
            Mapping of earlier analysis outputs. If ``mask_key`` is given, it
            must contain a boolean mask array under that key.
        mask_fn : callable or str, optional
            Function frame -> bool array for masking, or a registered mask
            name. Required if ``mask_key`` is None.
        mask_key : str, optional
            Key of a precomputed boolean mask stack in ``previous_results``.
            Takes precedence over ``mask_fn``.
        min_size : int, default 10
            Minimum region area in pixels.
        remove_edge : bool, default True
            Discard regions touching the frame border.
        fill_holes : bool, default False
            Fill holes in each frame's mask before labeling.
        hole_area : int, optional
            With ``fill_holes``, only fill holes below this area.
        **mask_kwargs : Any
            Forwarded to ``mask_fn``. Keywords its signature cannot accept
            are dropped.

        Returns
        -------
        dict[str, Any]
            Dictionary containing:

            - features_per_frame : list of lists of dict
            - labeled_masks : list of np.ndarray
            - summary : dict with total_features, total_frames, avg_features_per_frame

        Raises
        ------
        ValueError
            If ``stack.data`` is None or not 3D, if a mask array is invalid, or
            if neither ``mask_fn`` nor ``mask_key`` is provided.
        TypeError
            If ``mask_fn`` is neither callable nor a string.
        KeyError
            If ``mask_key`` is not found in ``previous_results``.

        Examples
        --------
        >>> pipeline.add("feature_detection", mask_fn=mask_mean_offset, min_size=20)
        >>> result = pipeline.run(stack)
        """
        data = stack.data
        if data is None:
            raise ValueError("AFMImageStack has no data")
        if not isinstance(data, np.ndarray) or data.ndim != 3:
            raise ValueError("stack.data must be a 3D numpy array (n_frames, H, W)")
        n_frames = data.shape[0]

        mask_arr = self._get_mask_array(
            data, previous_results, mask_fn, mask_key, **mask_kwargs
        )

        features_per_frame: list[list[dict[str, Any]]] = []
        labeled_masks: list[np.ndarray] = []
        total_features = 0

        for i, (frame, frame_mask) in enumerate(zip(data, mask_arr)):
            # Best-effort timestamp: fall back to the frame index if the stack
            # cannot supply a time for this frame.
            try:
                frame_ts = float(stack.time_for_frame(i))
            except Exception:
                frame_ts = float(i)

            feats, labeled = self._process_frame(
                frame,
                # Copy so a mask taken from previous_results is never mutated.
                frame_mask.copy(),
                frame_ts,
                min_size=min_size,
                remove_edge=remove_edge,
                fill_holes=fill_holes,
                hole_area=hole_area,
            )
            features_per_frame.append(feats)
            labeled_masks.append(labeled)
            total_features += len(feats)

        return {
            "features_per_frame": features_per_frame,
            "labeled_masks": labeled_masks,
            "summary": self._summarize(n_frames, total_features),
        }