"""
Video processing functions for AFM time-series (stacks of frames).
This module provides functions that operate on 3D numpy arrays
(time-series of 2D AFM frames). These include:
- Frame alignment to compensate for drift
- Cropping and padding utilities
- Temporal (time-domain) filters
- Future extensions such as spatio-temporal denoising
All functions follow a NumPy-style API: input stacks are 3D arrays
with shape (n_frames, height, width). Outputs are processed
stacks and a metadata dictionary.
"""
from collections import deque
from typing import Callable, Literal, Optional
import numpy as np
from scipy.ndimage import gaussian_filter
from scipy.signal import correlate2d, fftconvolve
from skimage.registration import phase_cross_correlation
from playnano.utils.param_utils import param_conditions
from playnano.utils.versioning import versioned_filter
# -----------------------------------------------------------------------------#
# Alignment
# -----------------------------------------------------------------------------#
@versioned_filter("0.1.0")
def align_frames(
    stack: np.ndarray,
    reference_frame: int = 0,
    method: str = "fft_cross_correlation",
    mode: str = "pad",
    debug: bool = False,
    max_shift: Optional[int] = None,
    pre_filter_sigma: Optional[float] = None,
    max_jump: Optional[int] = None,
):
    """
    Align a stack of AFM frames to a reference frame using integer-pixel shifts.

    Alignment is performed using either FFT-based or full cross-correlation.
    Jump smoothing prevents abrupt unrealistic displacements between consecutive
    frames by limiting the change in shift relative to the previous frame.

    Parameters
    ----------
    stack : np.ndarray[float]
        3D array of shape (n_frames, height, width) containing the input AFM image
        stack.
    reference_frame : int, optional
        Index of the frame to use as the alignment reference (default 0). Must be
        within [0, n_frames-1].
    method : {"fft_cross_correlation", "full_cross_correlation"}, optional
        Alignment method (default "fft_cross_correlation"). FFT-based
        cross-correlation is generally faster and uses less memory for large
        frames.
    mode : {"pad", "crop", "crop_square"}, optional
        How to handle borders after shifting:
        - "pad": keep all frames with NaN padding (default)
        - "crop": crop to intersection of all frames
        - "crop_square": crop to largest centered square
    debug : bool, optional
        If True, returns additional diagnostic outputs.
    max_shift : int, optional
        Maximum allowed shift in pixels. Detected shifts are clipped to this range.
    pre_filter_sigma : float, optional
        Standard deviation of Gaussian filter applied to frames before
        cross-correlation.
    max_jump : int, optional
        Maximum allowed change in shift between consecutive frames. If exceeded, the
        shift is replaced by a linear extrapolation from the previous two frames.

    Returns
    -------
    aligned_stack : np.ndarray[float]
        Aligned 3D stack of frames. Shape may be larger than input to accommodate
        all shifts.
    metadata : dict
        Dictionary containing alignment information:
        - "reference_frame": int, index of the reference frame
        - "method": str, the alignment method used
        - "mode": str, border approach used
        - "shifts": np.ndarray of shape (n_frames, 2), detected (dy, dx) shifts
        - "original_shape": tuple of (height, width)
        - "aligned_shape": tuple of (height, width) of the output canvas
        - "border_mask": np.ndarray[bool], True where valid frame pixels exist
        - "pre_filter_sigma": float or None
        - "max_shift": int or None
        - "max_jump": int or None
    debug_outputs : dict, optional
        Returned only if ``debug=True``. Contains:
        - "shifts": copy of the shifts array.

    Raises
    ------
    ValueError
        If ``stack.ndim`` is not 3.
    ValueError
        If ``method`` is not one of {"fft_cross_correlation", "full_cross_correlation"}.
    ValueError
        If ``reference_frame`` is not in the range [0, n_frames-1].

    Notes
    -----
    - Using ``fft_cross_correlation`` reduces memory usage compared to full
      cross-correlation because it leverages the FFT algorithm and avoids creating
      large full correlation matrices.
    - Padding with NaNs allows all frames to be placed without clipping, but may
      increase memory usage for large shifts.
    - The function does not interpolate subpixel shifts; all shifts are
      integer-valued.

    Examples
    --------
    >>> import numpy as np
    >>> from playnano.processing.video_processing import align_frames
    >>> stack = np.random.rand(10, 200, 200)  # 10 frames of 200x200 pixels
    >>> aligned_stack, metadata = align_frames(stack, reference_frame=0)
    >>> aligned_stack.shape
    (10, 210, 210)  # padded to accommodate shifts
    >>> metadata['shifts']
    array([[ 0,  0],
           [ 1, -2],
           ...])
    """
    stack = stack.astype(np.float32, copy=False)
    if stack.ndim != 3:
        raise ValueError(
            f"stack must be a 3D array (n_frames, H, W), got shape {stack.shape}"
        )
    n_frames, H, W = stack.shape
    # Validate reference_frame
    if not (0 <= reference_frame < n_frames):
        raise ValueError(
            f"reference_frame must be in [0, {n_frames-1}], got {reference_frame}"
        )
    # Preprocess reference frame: optional smoothing, then a zero-mean copy for
    # the FFT correlation path.
    ref = stack[reference_frame]
    if pre_filter_sigma is not None:
        ref = gaussian_filter(ref, sigma=pre_filter_sigma)
    ref0 = ref.copy()
    ref0 -= np.mean(ref0)
    shifts = np.zeros((n_frames, 2), dtype=int)
    for i in range(n_frames):
        if i == reference_frame:
            continue
        frame = stack[i]
        if pre_filter_sigma is not None:
            frame = gaussian_filter(frame, sigma=pre_filter_sigma)
        # Compute shift
        if method == "fft_cross_correlation":
            frame0 = frame.copy()
            frame0 -= np.mean(frame0)
            # Convolving the flipped frame with the reference equals
            # cross-correlating reference with frame, but runs via the FFT.
            cc = fftconvolve(frame0[::-1, ::-1], ref0, mode="full")
            y_center, x_center = H - 1, W - 1
        elif method == "full_cross_correlation":
            cc = correlate2d(
                ref, frame, mode="full", boundary="fill", fillvalue=0
            ).astype(np.float32)
            y_center, x_center = H - 1, W - 1
        else:
            raise ValueError(f"Unknown method: {method}")
        # Apply max_shift restriction: only search the correlation peak inside
        # a (2*max_shift+1)-wide window around the zero-shift position.
        if max_shift is not None:
            y_min = max(0, y_center - max_shift)
            y_max = min(cc.shape[0], y_center + max_shift + 1)
            x_min = max(0, x_center - max_shift)
            x_max = min(cc.shape[1], x_center + max_shift + 1)
            cc_window = cc[y_min:y_max, x_min:x_max]
            y_rel, x_rel = np.unravel_index(np.argmax(cc_window), cc_window.shape)
            dy = (y_min + y_rel) - y_center
            dx = (x_min + x_rel) - x_center
        else:
            y_max_idx, x_max_idx = np.unravel_index(np.argmax(cc), cc.shape)
            dy = y_max_idx - y_center
            dx = x_max_idx - x_center
        # Jump smoothing: for i == 1 there is only one prior shift, so clip the
        # change; from i >= 2 onward replace outliers by linear extrapolation
        # from the previous two shifts.
        if max_jump is not None:
            if i == 1:
                dy_prev, dx_prev = shifts[i - 1]
                dy = dy_prev + np.clip(dy - dy_prev, -max_jump, max_jump)
                dx = dx_prev + np.clip(dx - dx_prev, -max_jump, max_jump)
            elif i >= 2:
                prev_shift = shifts[i - 1]
                prev_prev_shift = shifts[i - 2]
                expected = prev_shift + (prev_shift - prev_prev_shift)
                if abs(dy - prev_shift[0]) > max_jump:
                    dy = int(expected[0])
                if abs(dx - prev_shift[1]) > max_jump:
                    dx = int(expected[1])
        shifts[i] = (int(dy), int(dx))
    # Pad canvas to fit all shifted frames
    max_dy_pos, max_dy_neg = shifts[:, 0].max(), shifts[:, 0].min()
    max_dx_pos, max_dx_neg = shifts[:, 1].max(), shifts[:, 1].min()
    new_H = H + max_dy_pos - max_dy_neg
    new_W = W + max_dx_pos - max_dx_neg
    # Allocate aligned stack (NaN marks pixels outside any placed frame)
    aligned_stack = np.empty((n_frames, new_H, new_W), dtype=stack.dtype)
    aligned_stack.fill(np.nan)
    border_mask = np.zeros((new_H, new_W), dtype=bool)
    for i in range(n_frames):
        dy, dx = shifts[i]
        y_start = dy - max_dy_neg
        x_start = dx - max_dx_neg
        aligned_stack[i, y_start : y_start + H, x_start : x_start + W] = stack[i]
        border_mask[y_start : y_start + H, x_start : x_start + W] = True
    metadata = {
        "reference_frame": reference_frame,
        "method": method,
        "mode": mode,
        "shifts": shifts,
        "original_shape": (H, W),
        "aligned_shape": (new_H, new_W),
        "border_mask": border_mask,
        "pre_filter_sigma": pre_filter_sigma,
        "max_shift": max_shift,
        "max_jump": max_jump,
    }
    # Apply cropping according to mode, keeping border_mask in sync with the
    # cropped canvas. Note both bound tuples below use exclusive-end indexing.
    if mode == "crop":
        aligned_stack, crop_meta = intersection_crop(aligned_stack)
        metadata["crop"] = crop_meta
        # intersection_crop is called with pad=0, so the cropped region is
        # exactly "intersection_bounds" (None if no finite pixels were found).
        bounds = crop_meta.get("intersection_bounds")
        if bounds is not None:
            y0, y1, x0, x1 = bounds
            border_mask = border_mask[y0:y1, x0:x1]
            metadata["border_mask"] = border_mask
    elif mode == "crop_square":
        aligned_stack, crop_meta = crop_square(aligned_stack)
        metadata["crop"] = crop_meta
        # "square_bounds" are absolute canvas coordinates of the square crop
        # (crop_meta["offset"] is relative to the intersection, not the canvas).
        sq_bounds = crop_meta.get("square_bounds")
        if sq_bounds is not None:
            y0_sq, y1_sq, x0_sq, x1_sq = sq_bounds
            border_mask = border_mask[y0_sq:y1_sq, x0_sq:x1_sq]
            metadata["border_mask"] = border_mask
    elif mode != "pad":
        raise ValueError(f"Unknown mode: {mode}")
    if debug:
        return aligned_stack, metadata, {"shifts": shifts.copy()}
    return aligned_stack, metadata
@versioned_filter("0.1.0")
def rolling_frame_align(
    stack: np.ndarray,
    window: int = 5,
    mode: str = "pad",
    debug: bool = False,
    max_shift: Optional[int] = None,
    pre_filter_sigma: Optional[float] = None,
    max_jump: Optional[int] = None,
):
    """
    Align a stack of AFM frames using a rolling reference and integer pixel shifts.

    This function computes frame-to-frame shifts relative to a rolling reference
    (average of the last `window` aligned frames) using phase cross-correlation.
    Each frame is then placed on a canvas large enough to accommodate all shifts.
    Optional jump smoothing prevents sudden unrealistic displacements between
    consecutive frames, and optional Gaussian pre-filtering can improve correlation
    robustness for noisy data.

    Parameters
    ----------
    stack : np.ndarray[float]
        3D array of shape (n_frames, height, width) containing the image frames.
    window : int, optional
        Number of previous aligned frames to average when building the rolling
        reference. Default is 5.
    mode : {"pad", "crop", "crop_square"}, optional
        How to handle borders after shifting:
        - "pad": keep all frames with NaN padding (default)
        - "crop": crop to intersection of all frames
        - "crop_square": crop to largest centered square
    debug : bool, optional
        If True, returns additional diagnostic outputs such as the rolling
        reference frames. Default is False.
    max_shift : int, optional
        Maximum allowed shift in pixels along either axis. Detected shifts are
        clipped. Default is None (no clipping).
    pre_filter_sigma : float, optional
        Standard deviation of Gaussian filter applied to both reference and moving
        frames prior to cross-correlation. Helps reduce noise. Default is None.
    max_jump : int, optional
        Maximum allowed jump in pixels between consecutive frame shifts. If
        exceeded, the shift is replaced by a linear extrapolation from the previous
        two shifts. Default is None (no jump smoothing).

    Returns
    -------
    aligned_stack : np.ndarray[float]
        3D array of shape (n_frames, canvas_height, canvas_width) containing the
        aligned frames. NaN values indicate areas outside the original frames after
        alignment.
    metadata : dict
        Dictionary containing alignment information:
        - "window": int, rolling reference window used
        - "method": str, alignment method used
        - "mode": str, border approach used
        - "shifts": ndarray of shape (n_frames, 2), detected integer shifts (dy, dx)
        - "original_shape": tuple of (height, width)
        - "aligned_shape": tuple of (canvas_height, canvas_width)
        - "border_mask": ndarray of shape (canvas_height, canvas_width), True where
          valid pixels exist
        - "pre_filter_sigma": float or None
        - "max_shift": int or None
        - "max_jump": int or None
    debug_outputs : dict, optional
        Returned only if `debug=True`. Contains:
        - "shifts": copy of the detected shifts array
        - "aligned_refs": deque of indices used for rolling reference

    Raises
    ------
    ValueError
        If ``stack.ndim`` is not 3.
    ValueError
        If ``window`` < 1.

    Notes
    -----
    - The rolling reference is computed using the last `window` aligned frames,
      ignoring NaN pixels.
    - Shifts are integer-valued; no subpixel interpolation is performed.
    - Padding ensures all frames fit without clipping, but increases memory usage.
    - Internally, a deque ``aligned_refs`` tracks which patches of which frames
      contribute to the rolling reference. Each entry stores:
      (frame_index, y0c, y1c, x0c, x1c, fy0, fy1, fx0, fx1),
      i.e. both the region of the canvas updated and the corresponding slice in
      the original frame. This allows exact removal of old contributions from
      ``rolling_sum`` and ``rolling_count`` when the window is exceeded, ensuring
      consistency without recomputation.

    Examples
    --------
    >>> import numpy as np
    >>> from playnano.processing.video_processing import rolling_frame_align
    >>> stack = np.random.rand(10, 200, 200)  # 10 frames of 200x200 pixels
    >>> aligned_stack, metadata = rolling_frame_align(stack, window=3)
    >>> aligned_stack.shape
    (10, 210, 210)
    >>> metadata['shifts']
    array([[0, 0],
           [1, -1],
           ...])
    """
    stack = stack.astype(np.float32, copy=False)
    if stack.ndim != 3:
        raise ValueError(
            f"stack must be a 3D array (n_frames, H, W), got shape {stack.shape}"
        )
    if window < 1:
        raise ValueError("window must be >= 1")
    n_frames, H, W = stack.shape
    shifts = np.zeros((n_frames, 2), dtype=int)
    # Safety margin for rolling reference canvas so shifted frames stay inside
    margin = max_shift if max_shift is not None else max(H, W) // 2
    canvas_H, canvas_W = H + 2 * margin, W + 2 * margin
    # Pre-filter stack once if needed
    if pre_filter_sigma is not None:
        filtered_stack = np.empty_like(stack)
        for i in range(n_frames):
            filtered_stack[i] = gaussian_filter(stack[i], sigma=pre_filter_sigma)
    else:
        filtered_stack = stack
    rolling_sum = np.zeros((canvas_H, canvas_W), dtype=np.float32)
    rolling_count = np.zeros((canvas_H, canvas_W), dtype=np.uint32)
    # Rolling-reference bookkeeping; each entry is the 9-tuple
    # (frame_idx, y0c, y1c, x0c, x1c, fy0, fy1, fx0, fx1) described in Notes.
    aligned_refs = deque()
    # Seed first frame at center
    center_y, center_x = margin, margin
    y_end, x_end = center_y + H, center_x + W
    rolling_sum[center_y:y_end, center_x:x_end] += filtered_stack[0]
    rolling_count[center_y:y_end, center_x:x_end] += 1
    # Store a full record (frame_idx + padded box + frame box)
    aligned_refs.append(
        (
            0,  # frame_idx
            center_y,
            y_end,  # padded box y
            center_x,
            x_end,  # padded box x
            0,
            H,  # frame box y
            0,
            W,  # frame box x
        )
    )
    for i in range(1, n_frames):
        # Compute rolling mean over the contributing frames (count==0 pixels
        # are masked out below, so the dummy denominator of 1 is never used)
        with np.errstate(divide="ignore", invalid="ignore"):
            denom_safe = np.where(rolling_count == 0, 1.0, rolling_count)
            ref = rolling_sum / denom_safe
        ref_mask_bool = rolling_count != 0
        frame_corr = filtered_stack[i]
        # NOTE(review): masked phase_cross_correlation returned only the shift
        # in some older scikit-image releases; the 3-tuple unpacking here
        # assumes a version that always returns (shift, error, phasediff).
        shift, _, _ = phase_cross_correlation(
            ref,
            frame_corr,
            upsample_factor=1,
            reference_mask=ref_mask_bool,
            moving_mask=np.ones_like(frame_corr, dtype=bool),
        )
        dy, dx = shift.astype(int)
        # Jump smoothing (needs two prior shifts for linear extrapolation)
        if max_jump is not None and i >= 2:
            prev_shift = shifts[i - 1]
            prev_prev_shift = shifts[i - 2]
            expected = prev_shift + (prev_shift - prev_prev_shift)
            if abs(dy - prev_shift[0]) > max_jump:
                dy = int(expected[0])
            if abs(dx - prev_shift[1]) > max_jump:
                dx = int(expected[1])
        # Clip by max_shift
        if max_shift is not None:
            dy = int(np.clip(dy, -max_shift, max_shift))
            dx = int(np.clip(dx, -max_shift, max_shift))
        shifts[i] = (dy, dx)
        # Compute canvas region
        y_start, x_start = center_y + dy, center_x + dx
        y_end, x_end = y_start + H, x_start + W
        # Clip to canvas
        y0c, x0c = max(0, y_start), max(0, x_start)
        y1c, x1c = min(canvas_H, y_end), min(canvas_W, x_end)
        # Corresponding patch in frame
        fy0, fx0 = y0c - y_start, x0c - x_start
        fy1, fx1 = fy0 + (y1c - y0c), fx0 + (x1c - x0c)
        if fy1 <= fy0 or fx1 <= fx0:
            continue  # skip empty patch
        patch = frame_corr[fy0:fy1, fx0:fx1]
        # Update rolling accumulators
        rolling_sum[y0c:y1c, x0c:x1c] += patch
        rolling_count[y0c:y1c, x0c:x1c] += 1
        # store frame index & frame coords
        aligned_refs.append((i, y0c, y1c, x0c, x1c, fy0, fy1, fx0, fx1))
        # Remove oldest frame from rolling sum/count if window exceeded
        if len(aligned_refs) > window:
            (
                frame_idx,
                old_y0,
                old_y1,
                old_x0,
                old_x1,
                old_fy0,
                old_fy1,
                old_fx0,
                old_fx1,
            ) = aligned_refs.popleft()
            old_patch = filtered_stack[frame_idx, old_fy0:old_fy1, old_fx0:old_fx1]
            rolling_sum[old_y0:old_y1, old_x0:old_x1] -= old_patch
            rolling_count[old_y0:old_y1, old_x0:old_x1] -= 1
    # Compute final canvas
    min_dy, max_dy = shifts[:, 0].min(), shifts[:, 0].max()
    min_dx, max_dx = shifts[:, 1].min(), shifts[:, 1].max()
    final_H, final_W = H + (max_dy - min_dy), W + (max_dx - min_dx)
    y_offset, x_offset = -min_dy, -min_dx
    aligned_stack = np.full((n_frames, final_H, final_W), np.nan, dtype=stack.dtype)
    border_mask = np.zeros((final_H, final_W), dtype=bool)
    for i in range(n_frames):
        dy, dx = shifts[i]
        y_start, x_start = y_offset + dy, x_offset + dx
        aligned_stack[i, y_start : y_start + H, x_start : x_start + W] = stack[i]
        border_mask[y_start : y_start + H, x_start : x_start + W] = True
    metadata = {
        "window": window,
        "method": "phase_cross_correlation",
        "mode": mode,
        "shifts": shifts,
        "original_shape": (H, W),
        "aligned_shape": aligned_stack.shape[1:],
        "border_mask": border_mask,
        "pre_filter_sigma": pre_filter_sigma,
        "max_shift": max_shift,
        "max_jump": max_jump,
    }
    # Apply cropping, keeping border_mask in sync with the cropped canvas.
    # Both bound tuples below use exclusive-end indexing.
    if mode == "crop":
        aligned_stack, crop_meta = intersection_crop(aligned_stack)
        metadata["crop"] = crop_meta
        # intersection_crop is called with pad=0, so the cropped region is
        # exactly "intersection_bounds" (None if no finite pixels were found).
        bounds = crop_meta.get("intersection_bounds")
        if bounds is not None:
            y0, y1, x0, x1 = bounds
            metadata["border_mask"] = metadata["border_mask"][y0:y1, x0:x1]
    elif mode == "crop_square":
        aligned_stack, crop_meta = crop_square(aligned_stack)
        metadata["crop"] = crop_meta
        # "square_bounds" are absolute canvas coordinates of the square crop
        # (crop_meta["offset"] is relative to the intersection, not the canvas).
        sq_bounds = crop_meta.get("square_bounds")
        if sq_bounds is not None:
            y0_sq, y1_sq, x0_sq, x1_sq = sq_bounds
            metadata["border_mask"] = metadata["border_mask"][
                y0_sq:y1_sq, x0_sq:x1_sq
            ]
    elif mode != "pad":
        raise ValueError(f"Unknown mode: {mode}")
    if debug:
        return (
            aligned_stack,
            metadata,
            {"shifts": shifts.copy(), "aligned_refs": list(aligned_refs)},
        )
    return aligned_stack, metadata
# -----------------------------------------------------------------------------#
# Cropping and padding helpers
# -----------------------------------------------------------------------------#
def _normalize_pad(pad):
"""
Normalize pad argument to (top, bottom, left, right).
Accepts:
- int: uniform pad on all sides
- tuple/list of length 2: (vertical, horizontal)
- tuple/list of length 4: (top, bottom, left, right)
"""
if isinstance(pad, (int, np.integer)):
return (pad, pad, pad, pad)
if isinstance(pad, (tuple, list)):
if len(pad) == 2:
v, h = pad
return (v, v, h, h)
if len(pad) == 4:
return tuple(pad)
raise ValueError("pad must be int, (v, h), or (top, bottom, left, right)")
def _crop_with_pad(stack, y0, y1, x0, x1, pad_value=np.nan):
"""
Crop stack[:, y0:y1, x0:x1] even if y0<0 or x1>W etc., padding with pad_value.
y0, y1, x0, x1 are *exclusive-end* bounds in original coordinates.
"""
n, H, W = stack.shape
# Compute required padding
pad_top = max(0, -y0)
pad_left = max(0, -x0)
pad_bottom = max(0, y1 - H)
pad_right = max(0, x1 - W)
if pad_top or pad_bottom or pad_left or pad_right:
stack_p = np.pad(
stack,
pad_width=((0, 0), (pad_top, pad_bottom), (pad_left, pad_right)),
mode="constant",
constant_values=pad_value,
)
# Shift requested box into padded coordinate system
y0_p = y0 + pad_top
y1_p = y1 + pad_top
x0_p = x0 + pad_left
x1_p = x1 + pad_left
out = stack_p[:, y0_p:y1_p, x0_p:x1_p]
else:
out = stack[:, y0:y1, x0:x1]
return out, (pad_top, pad_bottom, pad_left, pad_right)
@versioned_filter("0.2.0")
def intersection_crop(stack: np.ndarray, pad=0) -> tuple[np.ndarray, dict]:
    """
    Crop aligned stack to the largest common intersection region (finite across frames).

    Option to add padding to expand the crop beyond the intersection, filling with
    NaN when beyond the data.

    Parameters
    ----------
    stack : ndarray of shape (n_frames, height, width)
        Input aligned stack with NaN padding.
    pad : int or tuple, optional (default=0)
        Extra pixels to add around the intersection bounds.
        - int: uniform pad
        - (v, h): vertical and horizontal pad
        - (top, bottom, left, right): per-side pad

    Returns
    -------
    cropped : ndarray
        Cropped (and possibly padded) stack.
    meta : dict
        Metadata including original shape, intersection bounds, requested bounds,
        actual padding applied, and new shape. Bounds are exclusive-end.
    """
    # Pixels that are finite in every frame form the intersection region.
    valid_mask = np.all(np.isfinite(stack), axis=0)
    rows = np.any(valid_mask, axis=1)
    cols = np.any(valid_mask, axis=0)
    H, W = stack.shape[1:]
    if not np.any(rows) or not np.any(cols):
        # Degenerate case: nothing is finite everywhere — return a copy as-is.
        meta = {
            "operation": "crop_intersection",
            "original_shape": (H, W),
            "intersection_shape": (H, W),
            "new_shape": (H, W),
            "intersection_bounds": None,
            "requested_bounds": None,
            "applied_pad": (0, 0, 0, 0),
            "note": "No finite pixels found, returned original stack",
        }
        return stack.copy(), meta
    # Tight bounding box of the intersection, then exclusive-end bounds.
    y_min, y_max = np.where(rows)[0][[0, -1]]
    x_min, x_max = np.where(cols)[0][[0, -1]]
    y0, y1 = y_min, y_max + 1
    x0, x1 = x_min, x_max + 1
    # Expand outward by the requested per-side padding (may exceed the array;
    # _crop_with_pad fills the overshoot with NaN).
    pt, pb, pl, pr = _normalize_pad(pad)
    y0_req = y0 - pt
    y1_req = y1 + pb
    x0_req = x0 - pl
    x1_req = x1 + pr
    cropped, applied_pad = _crop_with_pad(stack, y0_req, y1_req, x0_req, x1_req)
    meta = {
        "operation": "crop_intersection",
        "original_shape": (H, W),
        "intersection_shape": (y1 - y0, x1 - x0),
        "new_shape": cropped.shape[1:],
        "intersection_bounds": (y0, y1, x0, x1),
        "requested_bounds": (y0_req, y1_req, x0_req, x1_req),
        "applied_pad": applied_pad,
        "pad_param": (pt, pb, pl, pr),
    }
    return cropped, meta
@versioned_filter("0.2.0")
def crop_square(stack: np.ndarray, pad=0) -> tuple[np.ndarray, dict]:
    """
    Crop aligned stack to the largest centered square region.

    This is based on the finite-pixel intersection across frames, with optional
    outward padding (np.nan).

    Parameters
    ----------
    stack : ndarray of shape (n_frames, height, width)
        Input aligned stack with possible NaN padding.
    pad : int or tuple, optional (default=0)
        Extra pixels to add around the square bounds.
        Accepts:
        - int: uniform pad
        - (v, h): vertical and horizontal pad
        - (top, bottom, left, right): per-side pad

    Returns
    -------
    cropped : ndarray
        Cropped (and possibly padded) square stack.
    meta : dict
        Metadata including original shape, intersection shape, square size, bounds,
        padding details, and offset compatible with the original function
        (offset within the intersection crop). Bounds are exclusive-end.
    """
    H, W = stack.shape[1:]
    # Compute finite-pixel intersection mask (across frames)
    valid_mask = np.all(np.isfinite(stack), axis=0)
    rows = np.any(valid_mask, axis=1)
    cols = np.any(valid_mask, axis=0)
    if not np.any(rows) or not np.any(cols):
        # No finite pixels: return original stack unchanged (original behavior)
        # offset is 0,0 (undefined in this case, but keep compatible keys)
        meta = {
            "operation": "crop_square",
            "original_shape": (H, W),
            "intersection_shape": (H, W),
            "new_shape": (H, W),
            "square_size": min(H, W),
            "offset": (0, 0),
            "intersection_bounds": None,
            "square_bounds": None,
            "requested_bounds": None,
            "applied_pad": (0, 0, 0, 0),
            "pad_param": _normalize_pad(pad),
            "note": "No finite pixels found; returned original stack",
        }
        return stack.copy(), meta
    # Tight intersection bounds in exclusive-end coordinates
    y_min, y_max = np.where(rows)[0][[0, -1]]
    x_min, x_max = np.where(cols)[0][[0, -1]]
    y0_i, y1_i = y_min, y_max + 1
    x0_i, x1_i = x_min, x_max + 1
    H_i, W_i = (y1_i - y0_i), (x1_i - x0_i)
    # Largest centered square INSIDE the intersection
    size = min(H_i, W_i)
    r_start = (H_i - size) // 2  # offset within the intersection (for metadata compat)
    c_start = (W_i - size) // 2
    # Convert that to absolute image coordinates
    y0_sq = y0_i + r_start
    x0_sq = x0_i + c_start
    y1_sq = y0_sq + size
    x1_sq = x0_sq + size
    # Apply outward padding around the square
    pt, pb, pl, pr = _normalize_pad(pad)
    y0_req = y0_sq - pt
    y1_req = y1_sq + pb
    x0_req = x0_sq - pl
    x1_req = x1_sq + pr
    # Single crop directly from the original (no prior intersection cropping);
    # _crop_with_pad NaN-fills any overshoot beyond the array edges.
    cropped, applied_pad = _crop_with_pad(stack, y0_req, y1_req, x0_req, x1_req)
    meta = {
        "operation": "crop_square",
        "original_shape": (H, W),
        "intersection_shape": (H_i, W_i),
        "new_shape": cropped.shape[1:],
        "square_size": size,
        "offset": (r_start, c_start),
        "intersection_bounds": (y0_i, y1_i, x0_i, x1_i),
        "square_bounds": (y0_sq, y1_sq, x0_sq, x1_sq),
        "requested_bounds": (y0_req, y1_req, x0_req, x1_req),
        "applied_pad": applied_pad,  # actual pad used beyond image edges
        "pad_param": (pt, pb, pl, pr),  # requested pad
    }
    return cropped, meta
@param_conditions(value=lambda p: p.get("mode") == "constant")
@versioned_filter("0.1.0")
def replace_nan(
    stack: np.ndarray,
    mode: Literal["zero", "mean", "median", "global_mean", "constant"] = "zero",
    value: float | None = None,
) -> tuple[np.ndarray, dict]:
    """
    Replace NaN values in a 2D frame or 3D AFM image stack using various strategies.

    Primarily used in video pipelines after alignment, but also applicable to
    single frames.

    Parameters
    ----------
    stack : np.ndarray
        Input 3D array of shape (n_frames, height, width) or 2D frame
        (height, width) that may contain NaN values.
    mode : {"zero", "mean", "median", "global_mean", "constant"}, optional
        Replacement strategy. Default is "zero".
        - "zero" : Replace NaNs with 0.
        - "mean" : Replace NaNs with the mean of each frame.
        - "median" : Replace NaNs with the median of each frame.
        - "global_mean" : Replace NaNs with the mean of the entire stack.
        - "constant" : Replace NaNs with a user-specified constant `value`.
    value : float, optional
        Constant value to use when `mode="constant"`. Must be provided in that
        case.

    Returns
    -------
    filled : np.ndarray
        Stack of the same shape as `stack` with NaNs replaced according to `mode`.
    meta : dict
        Metadata about the NaN replacement operation (count, mode, constant used).

    Raises
    ------
    ValueError
        If `mode` is unknown or if `mode="constant"` and `value` is not provided.

    Notes
    -----
    - Frame-wise operations like "mean" and "median" compute statistics per frame
      independently; a 2D input is treated as a single frame.
    - Preserves the dtype of the input stack.
    """
    filled = stack.copy()
    nan_count = np.isnan(filled).sum()
    if mode == "zero":
        filled[np.isnan(filled)] = 0
    elif mode in ("mean", "median"):
        stat = np.nanmean if mode == "mean" else np.nanmedian
        # Treat a 2D input as a single frame; iterating a 2D array directly
        # would compute row-wise statistics instead of the frame statistic.
        frames = filled[np.newaxis] if filled.ndim == 2 else filled
        for frame in frames:  # each frame is a view, so filling mutates `filled`
            mask = np.isnan(frame)
            if np.any(mask):
                frame[mask] = stat(frame)
    elif mode == "global_mean":
        global_mean = np.nanmean(filled)
        filled[np.isnan(filled)] = global_mean
    elif mode == "constant":
        if value is None:
            raise ValueError("Must provide 'value' for constant mode.")
        filled[np.isnan(filled)] = value
    else:
        raise ValueError(f"Unknown mode: {mode}")
    meta = {
        "nans_filled": int(nan_count),
        "mode": mode,
        "value_used": value if mode == "constant" else None,
    }
    return filled, meta
# -----------------------------------------------------------------------------#
# Temporal filters
# -----------------------------------------------------------------------------#
def temporal_mean_filter(stack: np.ndarray, window: int = 3) -> np.ndarray:
    """
    Apply a mean filter across the time dimension.

    Each output frame is the NaN-ignoring average of the input frames inside a
    centered window of ``window`` frames. Near the start and end of the stack the
    window is truncated to the available frames.

    Parameters
    ----------
    stack : ndarray of shape (n_frames, height, width)
        Input stack. NaN pixels (e.g. alignment padding) are ignored in the
        average; a pixel that is NaN in every frame of a window stays NaN.
    window : int, optional
        Window size (number of frames). Default is 3.

    Returns
    -------
    filtered : ndarray of shape (n_frames, height, width)
        Stack after temporal mean filtering (floating-point dtype).

    Raises
    ------
    ValueError
        If ``stack.ndim`` is not 3 or ``window`` < 1.
    """
    if stack.ndim != 3:
        raise ValueError(
            f"stack must be a 3D array (n_frames, H, W), got shape {stack.shape}"
        )
    if window < 1:
        raise ValueError("window must be >= 1")
    n_frames = stack.shape[0]
    half = window // 2
    # Promote integer inputs to float so the mean is exact.
    filtered = np.empty(stack.shape, dtype=np.result_type(stack.dtype, np.float32))
    for i in range(n_frames):
        # Centered window [i - half, i - half + window), truncated at the edges.
        lo = max(0, i - half)
        hi = min(n_frames, i - half + window)
        filtered[i] = np.nanmean(stack[lo:hi], axis=0)
    return filtered
# -----------------------------------------------------------------------------#
# Future placeholders (diagnostics, spatio-temporal denoising, etc.)
# -----------------------------------------------------------------------------#
def register_video_processing() -> dict[str, Callable]:
    """
    Return a dictionary of registered video processing filters.

    Keys are names of the operations, values are the functions themselves.
    These functions should take a 3D stack (n_frames, H, W) and return either
    an ndarray (filtered stack) or a tuple (stack, metadata).
    """
    return {
        "align_frames": align_frames,
        "rolling_frame_align": rolling_frame_align,
        "intersection_crop": intersection_crop,
        "crop_square": crop_square,
        "replace_nan": replace_nan,
    }