"""
Module for X-Means clustering as part of the playNano analysis pipeline.

This module implements a version of the X-Means clustering algorithm,
an extension of K-Means that estimates the optimal number of clusters using
the Bayesian Information Criterion (BIC).

Based on:
Pelleg, D., & Moore, A. W. (2000). X-means: Extending K-means with Efficient
Estimation of the Number of Clusters. Carnegie Mellon University.
http://www.cs.cmu.edu/~dpelleg/download/xmeans.pdf
"""
import logging
from typing import Any, Optional, Sequence
import numpy as np
from scipy.spatial.distance import cdist
from sklearn.cluster import KMeans
from playnano.afm_stack import AFMImageStack
from playnano.analysis.base import AnalysisModule
logger = logging.getLogger(__name__)
class XMeansClusteringModule(AnalysisModule):
"""
    Cluster features using the X-Means algorithm over (x, y[, t]) coordinates.

    This module clusters spatial (and optionally temporal) feature coordinates
    extracted from an AFM stack using an X-Means algorithm implemented in pure
    Python.

    Parameters
    ----------
    coord_key : str
        Key in previous_results[detection_module] under which the feature
        list is stored.
    coord_columns : Sequence[str]
        Names of feature dictionary keys to use for coordinates
        (e.g. centroid_x, centroid_y).
    use_time : bool
        Whether to append frame timestamps as the third coordinate.
    min_k : int
        Initial (minimum) number of clusters.
    max_k : int
        Maximum number of clusters to allow.
    normalise : bool
        Whether to min-max normalise the coordinate space before clustering.
    time_weight : float, optional
        Multiplier for the time dimension (applied after normalisation).

    Returns
    -------
    dict
        Dictionary with clustering results:

        - clusters: list of {id, frames, point_indices, coords}
        - cluster_centers: (K, D) ndarray in original units
        - summary: {n_clusters: int, members_per_cluster: dict}
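
    Examples
    --------
    A minimal sketch of invoking this module after a detection step; the
    ``previous_results`` layout here is assumed from the parameter
    descriptions above, and ``stack`` stands in for a real AFMImageStack:

    >>> module = XMeansClusteringModule()
    >>> detections = {
    ...     "features_per_frame": [
    ...         [{"centroid_x": 1.0, "centroid_y": 2.0}],
    ...         [{"centroid_x": 1.1, "centroid_y": 2.1}],
    ...     ]
    ... }
    >>> result = module.run(
    ...     stack, previous_results={"feature_detection": detections}
    ... )  # doctest: +SKIP
    >>> sorted(result)  # doctest: +SKIP
    ['cluster_centers', 'clusters', 'summary']
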
Version
-------
0.1.0
"""
version = "0.1.0"
@property
def name(self) -> str:
"""
Name of the analysis module.
Returns
-------
str
The string identifier for this module.
"""
return "x_means_clustering"
requires = ["feature_detection", "log_blob_detection"]
def run(
self,
stack: AFMImageStack,
previous_results: Optional[dict[str, Any]] = None,
*,
detection_module: str = "feature_detection",
coord_key: str = "features_per_frame",
coord_columns: Sequence[str] = ("centroid_x", "centroid_y"),
use_time: bool = True,
min_k: int = 1,
max_k: int = 10,
normalise: bool = True,
time_weight: Optional[float] = None,
replicates: int = 3,
max_iter: int = 300,
bic_threshold: float = 0.0,
) -> dict[str, Any]:
"""
        Perform X-Means clustering on features extracted from an AFM stack.

        This method extracts (x, y[, t]) coordinates from detected features,
        optionally normalises and time-weights them, and applies the X-Means
        algorithm to select the number of clusters automatically based on the
        BIC score.

        Parameters
----------
stack : AFMImageStack
The input image stack providing frame timing and metadata context.
previous_results : dict[str, Any], optional
Dictionary containing outputs from previous analysis steps.
Must contain the selected detection_module and coord_key.
        detection_module : str
            Key identifying which previous module's output to use.
            Default is "feature_detection".
coord_key : str
Key under the detection module that holds per-frame feature dicts.
Default is "features_per_frame".
        coord_columns : Sequence[str]
            Keys to extract from each feature for clustering coordinates.
            If missing, falls back to the "centroid" tuple when available.
            Default is ("centroid_x", "centroid_y").
use_time : bool
If True and `coord_columns` only gives 2D coordinates, appends the
frame timestamp as a third dimension. Default is True.
min_k : int
Initial number of clusters to start with. Default is 1.
        max_k : int
            Maximum number of clusters allowed. Default is 10.
normalise : bool
Whether to normalize the feature coordinate axes to the [0, 1] range
before clustering. Default is True.
time_weight : float or None, optional
Multiplicative factor applied to the time axis (after normalization).
Used only if time is included as a third coordinate.
replicates : int
Number of times to run k-means internally to choose the best split.
Default is 3.
max_iter : int
Maximum number of iterations for each k-means call.
Default is 300.
bic_threshold : float
Minimum improvement in BIC required to split a cluster.
Default is 0.0 (any improvement allows a split).

        Returns
-------
dict
A dictionary with the following keys:
- "clusters" : list of dicts, each with:
- id : int
- frames : list of int
- point_indices : list of int
- coords : list of tuple (normalized x, y, [t])
- "cluster_centers" : ndarray of shape (k, D)
Final cluster centers in original (denormalized) coordinates.
- "summary" : dict
- "n_clusters" : int
- "members_per_cluster" : dict mapping cluster ID to point count.

        Raises
------
RuntimeError
If the required detection module output is missing from previous_results.
KeyError
If the expected coordinate keys are missing from any feature dictionary.
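
        Examples
        --------
        An illustrative sketch of the min-max rescaling applied internally
        when ``normalise=True`` (each axis is mapped to [0, 1]; constant axes
        keep a span of 1 to avoid division by zero):

        >>> pts = np.array([[0.0, 0.0, 0.0], [2.0, 4.0, 10.0]])
        >>> mins = pts.min(axis=0)
        >>> spans = pts.max(axis=0) - pts.min(axis=0)
        >>> (pts - mins) / spans
        array([[0., 0., 0.],
               [1., 1., 1.]])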
"""
# Validate input dependencies
if previous_results is None:
raise RuntimeError(f"{self.name!r} requires previous results to run.")
# Auto-detect the most recent available detection module
if detection_module not in previous_results:
available = [
mod for mod in reversed(self.requires) if mod in previous_results
]
if not available:
raise RuntimeError(
f"{self.name!r} requires one of {self.requires}, but none were found in previous results." # noqa
)
detection_module = available[0]
fd_out = previous_results[detection_module]
per_frame = fd_out[coord_key]
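        # per_frame is expected to be a list with one entry per frame, each a
        # list of feature dictionaries.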
# Extract and format data points
points, metadata = [], []
for f_idx, feats in enumerate(per_frame):
tval = stack.time_for_frame(f_idx)
for p_idx, feat in enumerate(feats):
try:
coords = [float(feat[c]) for c in coord_columns]
                except KeyError:
                    # Fall back to a generic "centroid" tuple when the named
                    # columns are absent, taking as many components as were
                    # requested.
                    cent = feat.get("centroid")
                    if cent is not None and len(cent) >= len(coord_columns):
                        coords = [float(c) for c in cent[: len(coord_columns)]]
                    else:
                        raise KeyError(
                            f"Missing keys {coord_columns} in feature."
                        ) from None
if use_time and len(coords) == 2:
coords.append(float(tval))
points.append(coords)
metadata.append((f_idx, p_idx))
if not points:
            # Time is only appended when exactly two spatial columns are given
            dim = len(coord_columns) + (1 if use_time and len(coord_columns) == 2 else 0)
return {
"clusters": [],
"cluster_centers": np.empty((0, dim)),
"summary": {"n_clusters": 0, "members_per_cluster": {}},
}
data = np.array(points)
        # Min-max normalise each axis to [0, 1]; constant axes keep span 1.0
        # so the division below is safe
if normalise:
mins, maxs = data.min(axis=0), data.max(axis=0)
spans = maxs - mins
spans[spans == 0] = 1.0
data = (data - mins) / spans
if time_weight is not None and data.shape[1] == 3:
data[:, 2] *= time_weight
# Run X-means
labels, centers = core_xmeans(
data,
init_k=min_k,
max_k=max_k,
min_cluster_size=2,
distance="sqeuclidean",
replicates=replicates,
max_iter=max_iter,
bic_threshold=bic_threshold,
)
# Undo normalization on centers
if normalise:
if time_weight not in (None, 0.0) and centers.shape[1] == 3:
centers[:, 2] /= time_weight
centers = centers * spans + mins
# Format output
clusters_out, members = [], {}
for cid in np.unique(labels):
if cid < 0:
continue
            idxs = np.where(labels == cid)[0]
frames, coords_list, p_inds = [], [], []
for idx in idxs:
f_idx, p_idx = metadata[idx]
frames.append(f_idx)
p_inds.append(p_idx)
coords_list.append(tuple(data[idx]))
clusters_out.append(
{
"id": int(cid),
"frames": frames,
"point_indices": p_inds,
"coords": coords_list,
}
)
members[int(cid)] = len(idxs)
return {
"clusters": clusters_out,
"cluster_centers": centers,
"summary": {"n_clusters": len(members), "members_per_cluster": members},
}
def core_xmeans(
data: np.ndarray,
init_k: int,
max_k: int,
min_cluster_size: int,
distance: str,
replicates: int,
max_iter: int,
bic_threshold: float,
) -> tuple[np.ndarray, np.ndarray]:
"""Core X-Means loop.
Parameters are equivalent to those in `run` above.
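
    Examples
    --------
    A small synthetic check with two well-separated blobs (illustrative; this
    relies on the fixed ``random_state`` used internally):

    >>> X = np.array([[0.0, 0.0], [0.1, 0.0], [0.0, 0.1],
    ...               [5.0, 5.0], [5.1, 5.0], [5.0, 5.1]])
    >>> labels, centers = core_xmeans(
    ...     X, init_k=1, max_k=5, min_cluster_size=2,
    ...     distance="sqeuclidean", replicates=3, max_iter=300,
    ...     bic_threshold=0.0,
    ... )
    >>> len(np.unique(labels)), centers.shape[0]
    (2, 2)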
"""
k = init_k
centers = initialize_centers(data, k)
    while k <= max_k:
        # Refine the current centers with a full k-means pass, warm-started
        # from the centers carried over from the previous iteration (the
        # initial seeding otherwise goes unused).
        km = KMeans(
            n_clusters=k, init=centers, n_init=1, max_iter=max_iter,
            random_state=42,
        ).fit(data)
        labels = km.labels_
        centers = km.cluster_centers_
new_centers = []
split_occurred = False
for j in range(k):
pts = data[labels == j]
if len(pts) < 2:
new_centers.append(centers[j])
continue
km2 = KMeans(
n_clusters=2, n_init=replicates, max_iter=max_iter, random_state=42
).fit(pts)
labels2, centers2 = km2.labels_, km2.cluster_centers_
            if (
                np.sum(labels2 == 0) < min_cluster_size
                or np.sum(labels2 == 1) < min_cluster_size
            ):
new_centers.append(centers[j])
continue
bic_parent = compute_bic(pts, centers[j : j + 1])
bic_children = sum(
compute_bic(pts[labels2 == lab], centers2[lab : lab + 1])
for lab in [0, 1]
)
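            # BIC here is log-likelihood minus a complexity penalty, so higher
            # is better: keep the split only when the children beat the parent
            # by more than bic_threshold.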
if bic_children - bic_parent > bic_threshold:
new_centers.extend(centers2)
split_occurred = True
else:
new_centers.append(centers[j])
if not split_occurred or len(new_centers) > max_k:
break
centers = np.vstack(new_centers)
k = len(centers)
    # Assign every point to its nearest surviving center using the requested
    # distance metric.
    final_dists = cdist(data, centers, metric=distance)
    final_labels = np.argmin(final_dists, axis=1)
    return final_labels, centers
def compute_bic(points: np.ndarray, center: np.ndarray) -> float:
"""Compute Bayesian Information Criterion for a cluster.
Parameters
----------
points : np.ndarray
Points in the cluster.
center : np.ndarray
Cluster center (shape (1, D)).

    Returns
    -------
    float
        BIC value (log-likelihood minus a complexity penalty; higher is
        better under this sign convention).
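
    Examples
    --------
    A small worked check; the value follows directly from the formula used
    below (two points straddling their center):

    >>> pts = np.array([[0.0, 0.0], [1.0, 0.0]])
    >>> ctr = np.array([[0.5, 0.0]])
    >>> round(compute_bic(pts, ctr), 3)
    -3.829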
"""
n, p = points.shape
if n <= 1:
return -np.inf
sse = np.sum((points - center) ** 2)
var = sse / (n - 1)
if var <= 0:
var = np.finfo(float).eps
ll = -0.5 * n * p * np.log(2 * np.pi * var) - 0.5 * sse / var
num_params = p + 1
penalty = 0.5 * num_params * np.log(n)
return ll - penalty
def initialize_centers(points: np.ndarray, k: int) -> np.ndarray:
"""Initialize k centers using a k-means++-like heuristic."""
    n = points.shape[0]
    centers = [points[np.random.choice(n)]]
    for _ in range(1, k):
        # Squared distance from each point to its nearest existing center
        dists = np.min(cdist(points, np.vstack(centers), "sqeuclidean"), axis=1)
        total = dists.sum()
        if total == 0:
            # Degenerate case: every point coincides with an existing center
            idx = np.random.choice(n)
        else:
            probs = dists / total
            cumprobs = np.cumsum(probs)
            r = np.random.rand()
            idx = min(np.searchsorted(cumprobs, r), n - 1)
        centers.append(points[idx])
    return np.vstack(centers)