# Source code for playnano.analysis.modules.k_means_clustering

"""
K-Means clustering on features over the entire stack in 3D (x, y, time).

This module extracts a point-cloud from per-frame feature dictionaries
(e.g. coordinates + timestamps), optionally normalizes each axis to [0,1],
applies K-Means with a user-supplied k, then returns cluster assignments,
cluster centers (in original coordinate units), and a summary.

Parameters
----------
coord_key : str
    Key in previous_results whose value is `features_per_frame`
    (list of lists of dicts).

coord_columns : Sequence[str]
    Which keys in each feature-dict to use (e.g. ("x","y")).

use_time : bool
    If True and coord_columns length is 2, append frame time as the third dimension.

k : int
    Number of clusters.

normalise : bool
    If True, min-max normalize each axis before clustering.

time_weight : float | None
    If given, multiply the time axis by this weight.

**kmeans_kwargs
    Forwarded to sklearn.cluster.KMeans.
"""

from typing import Any, Optional, Sequence

import numpy as np
from sklearn.cluster import KMeans

from playnano.analysis.base import AnalysisModule


class KMeansClusteringModule(AnalysisModule):
    """
    Cluster features across all frames using K-Means in 2D or 3D (x, y, [time]).

    Extracts point coordinates from per-frame features, applies optional
    normalization and time weighting, then performs K-Means clustering.
    Returns cluster assignments, centers in original scale, and a summary
    report.

    Parameters
    ----------
    coord_key : str
        Key in previous_results pointing to 'features_per_frame' structure.
    coord_columns : Sequence[str]
        Keys to extract coordinates from each feature (e.g. ("x", "y")).
    use_time : bool
        If True, appends frame timestamp as a third clustering dimension.
    k : int
        Number of clusters to fit.
    normalise : bool
        If True, normalize each axis to [0, 1] before clustering.
    time_weight : float or None
        Optional multiplier for time axis after normalization.
    **kmeans_kwargs
        Additional keyword arguments passed to sklearn.cluster.KMeans.

    Version
    -------
    0.1.0
    """

    version = "0.1.0"

    # Detection modules whose output this module can consume; when the
    # requested `detection_module` is absent, the later entries here take
    # precedence (see `run`).
    requires = ["feature_detection", "log_blob_detection"]

    @property
    def name(self) -> str:
        """
        Name of the analysis module.

        Returns
        -------
        str
            The string identifier for this module.
        """
        return "k_means_clustering"

    def run(
        self,
        stack,
        previous_results: Optional[dict[str, Any]] = None,
        *,
        detection_module: str = "feature_detection",
        coord_key: str = "features_per_frame",
        coord_columns: Sequence[str] = ("centroid_x", "centroid_y"),
        use_time: bool = True,
        k: int,
        normalise: bool = True,
        time_weight: Optional[float] = None,
        **kmeans_kwargs: Any,
    ) -> dict[str, Any]:
        """
        Perform K-Means clustering on features extracted from a stack.

        Constructs a coordinate array from features (x, y[, t]), optionally
        applies normalization and time weighting, and fits k-means to assign
        clusters.

        Parameters
        ----------
        stack : AFMImageStack
            The input image stack providing frame times and data context.
        previous_results : dict[str, Any], optional
            Dictionary containing outputs from previous analysis steps.
            Must contain the selected detection_module and coord_key.
        detection_module : str
            Key identifying which previous module's output to use.
            Default is "feature_detection".
        coord_key : str
            Key under the detection module that holds per-frame feature
            dicts. Default is "features_per_frame".
        coord_columns : Sequence[str]
            Keys to extract from each feature for clustering coordinates.
            If missing, fallback to the "centroid" tuple is attempted.
            Default is ("centroid_x", "centroid_y").
        use_time : bool
            If True and `coord_columns` is 2D, append frame timestamp as
            third dimension. Default is True.
        k : int
            Number of clusters to compute.
        normalise : bool
            Whether to min-max normalize each axis of the feature points
            before clustering. Default is True.
        time_weight : float or None, optional
            Weighting factor for time axis (applied after normalization).
            Only used if time is included as a third dimension.
        **kmeans_kwargs : dict
            Additional arguments forwarded to `sklearn.cluster.KMeans`.
            A caller-supplied `random_state` overrides the default of 42.

        Returns
        -------
        dict
            A dictionary with the following keys:

            - "clusters" : list of dicts, each with:
                - id : int
                - frames : list of int
                - point_indices : list of int
                - coords : list of tuple
                    The normalized coordinates used in clustering for each
                    point in the cluster (e.g., (x, y[, t])).
            - "cluster_centers" : ndarray of shape (k, D)
                Cluster centers in original coordinate units.
            - "summary" : dict
                - "n_clusters" : int
                - "members_per_cluster" : dict mapping cluster id to point count

        Raises
        ------
        RuntimeError
            If the required `detection_module` output is not found in
            previous_results.
        KeyError
            If the required coordinate keys are missing in any feature
            dictionary.
        """
        if previous_results is None:
            raise RuntimeError(f"{self.name!r} requires previous results to run.")

        # Auto-detect the most recent available detection module; entries
        # later in `self.requires` win because the list is scanned reversed.
        if detection_module not in previous_results:
            available = [
                mod for mod in reversed(self.requires) if mod in previous_results
            ]
            if not available:
                raise RuntimeError(
                    f"{self.name!r} requires one of {self.requires}, but none were found in previous results."
                )
            detection_module = available[0]

        per_frame = previous_results[detection_module][coord_key]

        # Flatten per-frame features into a point cloud, remembering the
        # (frame, point) origin of each row for the output report.
        points: list[list[float]] = []
        metadata: list[tuple[int, int]] = []
        for f_idx, feats in enumerate(per_frame):
            t = stack.time_for_frame(f_idx)
            for p_idx, feat in enumerate(feats):
                try:
                    coords = [float(feat[c]) for c in coord_columns]
                except KeyError:
                    cent = feat.get("centroid")
                    if not cent or len(cent) < len(coord_columns):
                        raise KeyError(
                            f"Missing keys {coord_columns} in feature"
                        ) from None
                    # BUGFIX: take as many centroid components as requested;
                    # the original always took exactly two, producing
                    # wrong-dimensional points for 3-column requests.
                    # NOTE(review): assumes centroid component order matches
                    # coord_columns order -- confirm against the detector.
                    coords = [float(cent[i]) for i in range(len(coord_columns))]
                if use_time and len(coords) == 2:
                    coords.append(float(t))
                points.append(coords)
                metadata.append((f_idx, p_idx))

        # No features at all: return an empty but well-shaped result.
        if not points:
            dim = 3 if (use_time and len(coord_columns) == 2) else len(coord_columns)
            return {
                "clusters": [],
                "cluster_centers": np.empty((0, dim)),
                "summary": {"n_clusters": 0, "members_per_cluster": {}},
            }

        data = np.array(points)

        # Min-max normalize each column to [0, 1]; constant axes get a span
        # of 1.0 to avoid division by zero.
        if normalise:
            mins, maxs = data.min(0), data.max(0)
            spans = maxs - mins
            spans[spans == 0] = 1.0
            data = (data - mins) / spans

        # Time weighting is applied whether or not `normalise` is True.
        if time_weight is not None and data.shape[1] == 3:
            data[:, 2] *= time_weight

        # Run KMeans. Default random_state=42 for reproducibility, but let a
        # caller-supplied value win (the original hard-coded it and raised
        # TypeError on a duplicate keyword).
        kmeans_kwargs.setdefault("random_state", 42)
        km = KMeans(n_clusters=k, **kmeans_kwargs)
        labels = km.fit_predict(data)
        centers = km.cluster_centers_.copy()

        # Undo weighting/normalization so centers are in original units.
        # BUGFIX: the time weight was applied unconditionally above, so it
        # must also be undone unconditionally -- the original only divided
        # it out when `normalise` was True, leaving weighted centers for
        # normalise=False runs.
        if time_weight is not None and centers.shape[1] == 3:
            centers[:, 2] /= time_weight
        if normalise:
            centers = centers * spans + mins

        # Format per-cluster membership lists keyed back to (frame, point).
        clusters_out: list[dict[str, Any]] = []
        members: dict[int, int] = {}
        for cid in range(k):
            idxs = np.where(labels == cid)[0].tolist()
            frames, p_inds, coords_list = [], [], []
            for idx in idxs:
                f_idx, p_idx = metadata[idx]
                frames.append(f_idx)
                p_inds.append(p_idx)
                # Coordinates as used in clustering (normalized/weighted).
                coords_list.append(tuple(data[idx].tolist()))
            clusters_out.append(
                {
                    "id": cid,
                    "frames": frames,
                    "point_indices": p_inds,
                    "coords": coords_list,
                }
            )
            members[cid] = len(idxs)

        return {
            "clusters": clusters_out,
            "cluster_centers": centers,
            "summary": {"n_clusters": k, "members_per_cluster": members},
        }