Sorting Components module

Spike sorting consists of several steps, or components. In the spikeinterface.sortingcomponents module, we are building a library of methods and steps that can be assembled to build full spike sorting pipelines.

The goal is to allow for the modularization of spike sorting algorithms. Currently, spike sorters are shipped as full packages with all the steps needed to perform end-to-end spike sorting.

However, this might not be the best option. In fact, it is very likely that a sorter has one excellent step, say the clustering, but another step that is sub-optimal. Decoupling the different steps into separate components would allow one to mix and match sorting steps from different sorters.

Another advantage of modularization is that we can accurately benchmark every step of a spike sorting pipeline. For example, what is the performance of peak detection method 1 or 2, provided that the rest of the pipeline is the same?

Currently, we have methods for:

  • peak detection

  • peak localization

  • peak selection

  • motion estimation

  • motion interpolation

  • clustering

  • template matching

An important concept is the node pipeline machinery, which uses the run_node_pipeline() function, and will be covered in the Node pipelines section.

You can also have a look at spikeinterface, where there are more detailed notebooks on sorting components.

Peak detection

Peak detection is usually the first step of spike sorting and it consists of finding peaks in the traces that could be actual spikes.

Peaks can be detected with the detect_peaks() function as follows:

from spikeinterface.sortingcomponents.peak_detection import detect_peaks

job_kwargs = dict(chunk_duration='1s', n_jobs=8, progress_bar=True)

peaks = detect_peaks(
    recording=recording,
    method='by_channel',
    peak_sign='neg',
    detect_threshold=5,
    exclude_sweep_ms=0.2,
    noise_levels=None,
    random_chunk_kwargs={},
    job_kwargs=job_kwargs,
)

The output peaks is a NumPy array with a length of the number of peaks found and the following dtype:

peak_dtype = [('sample_index', 'int64'), ('channel_index', 'int64'), ('amplitude', 'float64'), ('segment_index', 'int64')]
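Because peaks is a NumPy structured array, ordinary NumPy indexing works on its fields. The sketch below builds a tiny array with the dtype above (the values are made up for illustration) and shows a few common operations: selecting one segment, counting peaks per channel, and converting sample indices to seconds, assuming a hypothetical sampling frequency of 30 kHz.

```python
import numpy as np

peak_dtype = [
    ("sample_index", "int64"),
    ("channel_index", "int64"),
    ("amplitude", "float64"),
    ("segment_index", "int64"),
]

# Made-up peaks for illustration only: (sample, channel, amplitude, segment).
peaks = np.array(
    [
        (100, 0, -55.0, 0),
        (250, 3, -80.0, 0),
        (900, 3, -62.0, 0),
        (40, 1, -70.0, 1),
    ],
    dtype=peak_dtype,
)

sampling_frequency = 30_000.0  # hypothetical sampling rate in Hz

# Keep only the peaks of the first segment.
seg0 = peaks[peaks["segment_index"] == 0]

# Number of peaks per channel (position in the array = channel_index).
counts_per_channel = np.bincount(seg0["channel_index"])

# Peak times in seconds.
peak_times_s = seg0["sample_index"] / sampling_frequency
```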

There are two different methods available with the method argument:

  • ‘locally_exclusive’ (requires numba): peaks on neighboring channels within a certain radius are excluded (not counted multiple times)

  • ‘matched_filtering’ (requires numba): a method based on convolution by a kernel that “looks like a spike” at several spatial scales. This is a bit slower but can detect spikes with lower amplitude.

Other variants are also implemented (but less tested or not so useful):

  • ‘by_channel’ : peaks are detected separately for each channel. This should not be used with high-density probe layouts, where the same spike is seen (and would be counted) on several neighboring channels.

  • ‘by_channel_torch’ (requires torch): pytorch implementation (GPU-compatible) that uses max pooling for time deduplication

  • ‘locally_exclusive_torch’ (requires torch): pytorch implementation (GPU-compatible) that uses max pooling for space-time deduplication

NOTE: the torch implementations give slightly different results because they use max pooling for deduplication.
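To make the difference between ‘by_channel’ and ‘locally_exclusive’ concrete, here is a minimal NumPy sketch of a locally exclusive detector for negative peaks. It is not SpikeInterface's numba implementation: the function detect_locally_exclusive and its arguments are hypothetical, and it assumes traces of shape (num_samples, num_channels), one noise level per channel, and 2D channel positions in µm. A threshold crossing is kept only if it is the most negative value (in noise units) within ±exclude_sweep samples on all channels within radius_um.

```python
import numpy as np

def detect_locally_exclusive(traces, noise_levels, channel_locations,
                             detect_threshold=5.0, exclude_sweep=5, radius_um=50.0):
    """Hypothetical sketch of locally exclusive detection of negative peaks."""
    # Normalize each channel by its noise level so one threshold fits all channels.
    z = traces / noise_levels[np.newaxis, :]
    num_samples = z.shape[0]

    # Boolean neighborhood matrix: channels within radius_um (including itself).
    diff = channel_locations[:, None, :] - channel_locations[None, :, :]
    neighbors = np.linalg.norm(diff, axis=2) <= radius_um

    detected = []
    # Candidates: every threshold crossing on every channel.
    cand_samples, cand_channels = np.nonzero(z < -detect_threshold)
    for s, c in zip(cand_samples, cand_channels):
        t0 = max(0, s - exclude_sweep)
        t1 = min(num_samples, s + exclude_sweep + 1)
        window = z[t0:t1, neighbors[c]]
        # Keep the candidate only if it is the minimum of its space-time
        # neighborhood (exact ties are all kept, for simplicity).
        if z[s, c] <= window.min():
            detected.append((s, c, traces[s, c]))
    return detected


# Usage: one spike seen on two close channels, one spike on a distant channel.
traces = np.zeros((100, 3))
traces[50, 0] = -10.0  # spike on channel 0
traces[50, 1] = -8.0   # same spike seen on a neighboring channel
traces[52, 2] = -9.0   # another spike on a distant channel
noise_levels = np.ones(3)
channel_locations = np.array([[0.0, 0.0], [0.0, 20.0], [0.0, 200.0]])
detected = detect_locally_exclusive(traces, noise_levels, channel_locations)
```

A per-channel detector would report three peaks here; the locally exclusive rule drops the duplicate on channel 1.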

Like many of the other sorting components, peak detection can be run in parallel; this is controlled by job_kwargs (e.g. n_jobs and chunk_duration).

Peak localization

Peak localization estimates the spike location on the probe. An estimate of location can be important to correct for drift or cluster spikes into different units.

Peak localization can be run using localize_peaks() as follows:

from spikeinterface.sortingcomponents.peak_localization import localize_peaks

job_kwargs = dict(chunk_duration='1s', n_jobs=8, progress_bar=True)

peak_locations = localize_peaks(
    recording=recording,
    peaks=peaks,
    method='center_of_mass',
    method_kwargs=dict(
      radius_um=70.,
      ms_before=0.3,
      ms_after=0.6,
    ),
    job_kwargs=job_kwargs,
)

Currently, the following methods are implemented:

  • ‘center_of_mass’ : the fastest and most intuitive method. It is not accurate at the borders of the probe, so for Neuropixels probes only the ‘y’ axis will be well estimated. For in vitro recordings with a square MEA, all spikes on the borders will also be biased.

  • ‘monopolar_triangulation’ with optimizer=’least_square’ : this method is from Julien Boussard and Erdem Varol from the Paninski lab, and was presented at NeurIPS. ‘monopolar_triangulation’ has several variants with different optimizers (the default is ‘minimize_with_log_penality’).

  • ‘grid_convolution’ : inspired by the Kilosort approach. This consists of convolving the traces with waveform prototypes that have varying local spatial footprints on the probe.
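The ‘center_of_mass’ idea fits in a few lines: the location is the amplitude-weighted average of the positions of the channels around the peak. The sketch below is hypothetical (the function name and the use of the absolute peak value as per-channel weight are assumptions; the library may compute the per-channel feature differently). It also shows why the estimate is biased at the probe border: a weighted average can never leave the region spanned by the channels used.

```python
import numpy as np

def center_of_mass(waveform, channel_locations, peak_channel, radius_um=70.0):
    """Hypothetical sketch; waveform has shape (num_samples, num_channels)."""
    # Channels within radius_um of the peak channel.
    dist = np.linalg.norm(channel_locations - channel_locations[peak_channel], axis=1)
    local = dist <= radius_um
    # Per-channel weight: absolute peak amplitude (an assumption of this sketch).
    weights = np.abs(waveform[:, local]).max(axis=0)
    # Weighted average of the local channel positions.
    return (weights[:, None] * channel_locations[local]).sum(axis=0) / weights.sum()


channel_locations = np.array([[0.0, 0.0], [0.0, 20.0], [0.0, 40.0]])

# A unit centered on the middle channel.
waveform = np.zeros((30, 3))
waveform[10] = [-1.0, -2.0, -1.0]
loc = center_of_mass(waveform, channel_locations, peak_channel=1)

# A unit beyond the last channel: amplitude keeps growing toward the border,
# yet the estimate stays inside the probe.
waveform_border = np.zeros((30, 3))
waveform_border[10] = [-0.5, -1.0, -2.0]
loc_border = center_of_mass(waveform_border, channel_locations, peak_channel=2)
```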

Please have a look at [Scopin2024] for details on these methods.

These methods are the same as those implemented in spikeinterface.postprocessing.unit_localization.

The output peak_locations is a 1d NumPy array with a dtype that depends on the chosen method.

For instance, the ‘monopolar_triangulation’ method will have:

localization_dtype = [('x', 'float64'),  ('y', 'float64'), ('z', 'float64'), ('alpha', 'float64')]

Note

By convention in SpikeInterface, when a probe is described in 3D:
  • ‘x’ is the width of the probe

  • ‘y’ is the depth

  • ‘z’ is orthogonal to the probe plane

Node pipelines

Both detect_peaks() and localize_peaks() need to go through the entire recording, so running them one after the other reads the traces from disk and applies the preprocessing twice. This can be very slow!

Luckily, there is an internal machinery to avoid the multiple passes over the traces: the run_node_pipeline() function.

The node pipeline is an API that runs user-selected nodes in parallel on all chunks of the traces, performing computations such as peak detection, peak localization, SVD feature extraction, …

Here is a small example that does peak detection and localization at once. In the following, please note that there is an intermediate node, the ExtractDenseWaveforms node, that does not output final results (notice the return_output=False), but is needed to extract waveforms for the localization node.

import spikeinterface.full as si

# generate
recording, _, _ = si.generate_drifting_recording(
    probe_name="Neuropixels1-128",
    num_units=200,
    duration=300.,
    seed=2205,
    extra_outputs=False,
)

# Let's make 3 nodes

# Node 0: detect peaks
noise_levels = si.get_noise_levels(recording, return_in_uV=False)
from spikeinterface.sortingcomponents.peak_detection.method_list import LocallyExclusivePeakDetector
node0 = LocallyExclusivePeakDetector(
    recording,
    return_output=True, # We want output from this node!!
    # then specific params
    noise_levels=noise_levels,
    peak_sign="neg",
    detect_threshold=5.,
    exclude_sweep_ms=0.5
)

# Node 1 : extract local waveforms
from spikeinterface.core.node_pipeline import ExtractDenseWaveforms
node1 = ExtractDenseWaveforms(
    recording,
    parents=[node0],
    return_output=False, # We do NOT want to output all dense waveforms!!!!
    # then specific params
    ms_before=1.,
    ms_after=1.5,
)

# Node 2 : localize peaks using local waveforms
from spikeinterface.sortingcomponents.peak_localization.method_list import LocalizeMonopolarTriangulation
node2 = LocalizeMonopolarTriangulation(
    recording,
    parents=[node0, node1],
    return_output=True, # We want output from this node!!
    # then specific params
    radius_um=75.0,
    optimizer="minimize_with_log_penality",
)

nodes = [node0, node1, node2]

# our dear jobs kwargs dict
job_kwargs = dict(n_jobs=-1, chunk_duration="500ms", progress_bar=True)

# only 2 nodes give outputs
from spikeinterface.core.node_pipeline import run_node_pipeline
peaks, peak_locations = run_node_pipeline(recording, nodes, job_kwargs, job_name="my pipeline", gather_mode="memory")

# We strongly hope that geeks from various labs will appreciate the design.
# We spent hours debating how to do it.
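The benefit of the node pipeline is that every chunk of traces is read (and preprocessed) only once, and each node receives its parents' outputs for that chunk. The toy sketch below reproduces only this control flow, with a detection node, an intermediate node that is not gathered (like ExtractDenseWaveforms with return_output=False), and a final node. It is not the spikeinterface.core.node_pipeline API: ToyNode, run_toy_pipeline and the get_chunk callback are hypothetical.

```python
import numpy as np

class ToyNode:
    """Hypothetical node: func(traces, start, *parent_outputs) -> output."""

    def __init__(self, func, parents=(), return_output=True):
        self.func = func
        self.parents = list(parents)
        self.return_output = return_output


def run_toy_pipeline(get_chunk, num_chunks, chunk_size, nodes):
    """Run all nodes chunk by chunk; nodes must be ordered parents-first."""
    collected = {id(node): [] for node in nodes if node.return_output}
    for i in range(num_chunks):
        start = i * chunk_size
        traces = get_chunk(start, start + chunk_size)  # single read per chunk
        outputs = {}
        for node in nodes:
            parent_outputs = [outputs[id(p)] for p in node.parents]
            outputs[id(node)] = node.func(traces, start, *parent_outputs)
            if node.return_output:
                collected[id(node)].append(outputs[id(node)])
    # Only nodes with return_output=True are gathered, in node order.
    return [np.concatenate(chunks) for chunks in collected.values()]


# Usage on a toy recording with two spikes.
data = np.zeros((1000, 4))
data[100, 1] = -10.0
data[600, 2] = -10.0
reads = []

def get_chunk(start, stop):
    reads.append((start, stop))
    return data[start:stop]

# Node 0: threshold detection, returns global (sample_index, channel_index).
detect = ToyNode(lambda tr, start: np.argwhere(tr < -5.0) + [start, 0])
# Node 1: intermediate node, values on all channels at each peak sample.
snapshot = ToyNode(lambda tr, start, pk: tr[pk[:, 0] - start],
                   parents=[detect], return_output=False)
# Node 2: amplitude on the peak channel, computed from node 1.
amplitude = ToyNode(lambda tr, start, pk, snap: snap[np.arange(len(pk)), pk[:, 1]],
                    parents=[detect, snapshot])

peaks, amplitudes = run_toy_pipeline(get_chunk, 4, 250, [detect, snapshot, amplitude])
```

Each of the four chunks is read exactly once, even though three nodes use it.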

Peak selection

When too many peaks are detected, a strategy can be used to select (or sub-sample) only some of them before clustering. This is the strategy used by spyking-circus and tridesclous, for instance. Clustering is then run on this subset of peaks, templates are extracted, and a template-matching step is run to find all spikes.

The way the peak vector is reduced (or sub-sampled) is a crucial step, because units with low firing rates can be hidden by this process.

from spikeinterface.sortingcomponents.peak_detection import detect_peaks

many_peaks = detect_peaks(...) # as in above example

from spikeinterface.sortingcomponents.peak_selection import select_peaks

some_peaks = select_peaks(peaks=many_peaks, method='uniform', n_peaks=10000)

Implemented methods are the following:

  • ‘uniform’

  • ‘uniform_locations’

  • ‘smart_sampling_amplitudes’

  • ‘smart_sampling_locations’

  • ‘smart_sampling_locations_and_time’
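Why can uniform sub-sampling hide low-firing-rate units? A uniform subset preserves the ratio between units, so a rare unit may end up with only a handful of peaks. The sketch below contrasts uniform sampling with a simple amplitude-balanced scheme that draws about the same number of peaks from each amplitude bin. It only illustrates the idea behind the ‘smart_sampling_*’ methods: the function names, the binning and the equal-per-bin rule are assumptions, not the library's algorithm.

```python
import numpy as np

def select_uniform(amplitudes, n_peaks, rng):
    # Uniform random subset of peak indices.
    return np.sort(rng.choice(len(amplitudes), size=n_peaks, replace=False))


def select_amplitude_balanced(amplitudes, n_peaks, rng, n_bins=10):
    # Hypothetical scheme: draw about the same number of peaks per amplitude bin.
    edges = np.histogram_bin_edges(amplitudes, bins=n_bins)
    bin_ids = np.digitize(amplitudes, edges[1:-1])  # values in 0..n_bins-1
    per_bin = n_peaks // n_bins
    selected = []
    for b in range(n_bins):
        idx = np.flatnonzero(bin_ids == b)
        take = min(per_bin, len(idx))
        if take:
            selected.append(rng.choice(idx, size=take, replace=False))
    return np.sort(np.concatenate(selected))


rng = np.random.default_rng(0)
# Toy data: a frequent unit around -50 and a rare unit around -200.
amplitudes = np.concatenate([rng.normal(-50, 5, 100_000), rng.normal(-200, 5, 200)])
is_rare = np.arange(len(amplitudes)) >= 100_000

n_rare_uniform = is_rare[select_uniform(amplitudes, 1000, rng)].sum()
n_rare_balanced = is_rare[select_amplitude_balanced(amplitudes, 1000, rng)].sum()
```

With uniform sampling the rare unit keeps only a couple of peaks out of 1000, while the balanced scheme keeps over a hundred. Note that this simple version returns fewer than n_peaks when some bins are sparse; a practical scheme would redistribute the unused budget.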

Motion estimation

Drift estimation is implemented directly in SpikeInterface. As a result, even sorters that do not handle drift can benefit from drift estimation/correction. This is a crucial step, especially for acute Neuropixels-like probes.

The motion estimation step comes after peak detection and peak localization. Read more about it in the motion correction module documentation, and for a more practical guide see the “Handle motion/drift in your recording” How To.

Here is an example with non-rigid motion estimation:

from spikeinterface.sortingcomponents.peak_detection import detect_peaks
peaks = detect_peaks(recording=recording, ...) # as in above example

from spikeinterface.sortingcomponents.peak_localization import localize_peaks
peak_locations = localize_peaks(recording=recording, peaks=peaks, ...) # as above


from spikeinterface.sortingcomponents.motion import estimate_motion
motion = estimate_motion(
    recording=recording,
    peaks=peaks,
    peak_locations=peak_locations,
    method="dredge_ap",
    rigid=False,
    win_shape="gaussian",
    win_step_um=200.0,
    win_scale_um=300.0,
    win_margin_um=None,
    bin_um=1.0,
    bin_s=1.0,
    direction='y',
    progress_bar=True,
    verbose=True
)

In this example, because the estimation is non-rigid, motion holds a 2D array of shape (num_time_bins, num_spatial_bins). We can now inspect the motion object to decide whether a correction is needed.

Available methods are:

  • ‘dredge_ap’ : the most mature method at the moment, developed by [Windolf_b].

  • ‘decentralized’ : more or less the ancestor of ‘dredge_ap’.

  • ‘iterative_template’ : this mimics the Kilosort approach.

  • ‘medicine’ : a more recent approach described in [Watters].

A comparison of these methods can be read in [Garcia2024].
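To give an intuition of what rigid motion estimation computes, here is a deliberately naive sketch that is none of the methods listed above. It bins peak depths into a (time bin × depth bin) histogram, then, for each time bin, finds the integer depth shift that best aligns its histogram with the first time bin. The function name, the choice of the first bin as reference and the integer-bin search are assumptions made for illustration; the library's methods are considerably more elaborate and, with rigid=False, non-rigid.

```python
import numpy as np

def naive_rigid_motion(peak_times_s, peak_depths_um, bin_s=1.0, bin_um=5.0,
                       max_shift_um=50.0):
    """Hypothetical sketch: one displacement (um) per time bin, relative to bin 0."""
    time_edges = np.arange(0.0, peak_times_s.max() + bin_s, bin_s)
    depth_edges = np.arange(peak_depths_um.min(), peak_depths_um.max() + bin_um, bin_um)
    hist, _, _ = np.histogram2d(peak_times_s, peak_depths_um,
                                bins=[time_edges, depth_edges])
    max_shift = int(max_shift_um // bin_um)
    shifts = np.arange(-max_shift, max_shift + 1)
    reference = hist[0]
    motion = np.zeros(hist.shape[0])
    for t in range(hist.shape[0]):
        # Score each integer shift by the overlap with the reference histogram.
        # (np.roll wraps around the edges; a careful implementation would pad.)
        scores = [np.dot(np.roll(hist[t], -s), reference) for s in shifts]
        motion[t] = shifts[int(np.argmax(scores))] * bin_um
    return motion


# Usage: 5 units, 10 s, a 20 um step drift after 5 s.
rng = np.random.default_rng(0)
times = rng.uniform(0, 10, size=5000)
unit_depths = rng.choice([100.0, 250.0, 400.0, 600.0, 800.0], size=5000)
drift = np.where(times >= 5.0, 20.0, 0.0)
depths = unit_depths + drift + rng.normal(0, 2.0, size=5000)
motion = naive_rigid_motion(times, depths)
```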

Motion interpolation

The estimated motion can be used to interpolate traces in an attempt to correct for drift. One possible way is to interpolate sample by sample to compensate for the motion. InterpolateMotionRecording is a preprocessing step that does this. This preprocessing is lazy, so interpolation is done on the fly. However, the class needs the “motion vector” as input, which requires a relatively long computation (peak detection, localization and motion estimation).

Here is a short example that depends on the output of “Motion estimation”:

from spikeinterface.sortingcomponents.motion import InterpolateMotionRecording

recording_corrected = InterpolateMotionRecording(
    recording=recording,  # the same recording used for motion estimation
    motion=motion,
    spatial_interpolation_method='kriging',
    border_mode='remove_channels'
)

Notes:
  • The choice of spatial_interpolation_method (“kriging” or “idw”) does not play a big role.

  • border_mode is a very important parameter. It controls how the probe borders are handled: because of motion, units near the border are not present throughout the entire recording. We highly recommend border_mode='remove_channels', because it removes the border channels that are impacted by drift. Of course, the larger the motion, the more channels are removed.
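The principle of sample-by-sample interpolation with border removal can be sketched for a single column of channels and one rigid displacement per sample. The corrected signal for a channel at depth y is read from the recorded traces at depth y + displacement, by linear interpolation between neighboring channels (np.interp), and channels whose shifted depth ever leaves the probe are dropped, which is the idea behind border_mode='remove_channels'. The function name, the linear interpolation (instead of ‘kriging’ or ‘idw’) and the 1D layout are assumptions made for illustration.

```python
import numpy as np

def interpolate_motion_1d(traces, channel_depths_um, displacement_um):
    """Hypothetical sketch.

    traces: (num_samples, num_channels), channels sorted by increasing depth.
    displacement_um: (num_samples,) rigid displacement for every sample.
    """
    # Border handling: keep only channels whose shifted depth stays on the probe.
    lo, hi = channel_depths_um.min(), channel_depths_um.max()
    keep = ((channel_depths_um + displacement_um.min() >= lo)
            & (channel_depths_um + displacement_um.max() <= hi))

    corrected = np.empty((traces.shape[0], int(keep.sum())))
    for t in range(traces.shape[0]):
        # Read the recorded signal at depth y + displacement for each kept channel.
        corrected[t] = np.interp(channel_depths_um[keep] + displacement_um[t],
                                 channel_depths_um, traces[t])
    return corrected, np.flatnonzero(keep)


# Usage: a unit at 40 um that moves to 60 um halfway through.
depths = np.arange(0.0, 100.0, 10.0)
displacement = np.r_[np.zeros(5), np.full(5, 20.0)]
unit_depth = 40.0 + displacement
traces = np.exp(-(((depths[None, :] - unit_depth[:, None]) / 10.0) ** 2))
corrected, kept_channels = interpolate_motion_1d(traces, depths, displacement)
```

After correction, the unit peaks on the channel at 40 µm throughout, and the two top channels (80 and 90 µm) are removed because a 20 µm shift would read beyond the probe.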

Clustering

The clustering step remains the central step of spike sorting. Historically, this step was separated into two distinct parts: feature reduction and clustering. In SpikeInterface, we decided to regroup these two steps into the same module. This allows one to compute feature reduction ‘on-the-fly’ and to avoid long computations and the storage of large features.

The clustering step takes the recording and detected (and optionally selected) peaks as input and returns a label for every peak.

Some methods have been implemented with various ideas in mind. We really hope that this list will be extended soon by talented people willing to improve it. This is a crucial and not totally resolved step.

  • ‘iterative-hdbscan’ : the method used in spyking-circus2. This performs local HDBSCAN clusterings on SVD waveform features.

  • ‘iterative-isosplit’ : the method used in tridesclous2. This performs local isosplit clusterings on SVD waveform features.

  • ‘hdbscan-positions’ : this performs an HDBSCAN clustering based on the localizations of the spikes. It mimics the herdingspikes approach (clustering on spike positions only), but is more flexible because more localization methods are available.

  • ‘random-projections’ : attempts to compute features from waveforms with random projections instead of good old-school PCA.

  • ‘graph-clustering’ : attempts to resolve the clustering globally rather than locally. This constructs a global but sparse distance matrix between all spikes, which can be slow, and then runs ‘classical’ graph algorithms (Louvain, Leiden or even HDBSCAN). A promising method, but not as efficient as ‘iterative-isosplit’ or ‘iterative-hdbscan’.

from spikeinterface.sortingcomponents.peak_detection import detect_peaks
peaks = detect_peaks(recording, ...) # as in above example

from spikeinterface.sortingcomponents.clustering import find_clusters_from_peaks
labels, peak_labels = find_clusters_from_peaks(recording=recording, peaks=peaks, method="iterative-isosplit")

  • labels : contains all possible labels (aka unit_ids)

  • peak_labels : vector with the same size as peaks containing the label for each peak
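Since peak_labels is aligned with peaks, turning the clustering output into per-unit spike trains is just a grouping operation. A minimal NumPy sketch with made-up values, assuming a single segment; the helper spike_trains_from_labels is hypothetical.

```python
import numpy as np

def spike_trains_from_labels(peaks, labels, peak_labels):
    # One sorted array of sample indices per label (single segment assumed).
    return {
        int(label): np.sort(peaks["sample_index"][peak_labels == label])
        for label in labels
    }


peak_dtype = [("sample_index", "int64"), ("channel_index", "int64"),
              ("amplitude", "float64"), ("segment_index", "int64")]
# Made-up peaks and labels for illustration only.
peaks = np.array([(300, 2, -60.0, 0), (120, 0, -90.0, 0),
                  (450, 2, -58.0, 0), (800, 0, -85.0, 0)], dtype=peak_dtype)
labels = np.array([0, 1])
peak_labels = np.array([1, 0, 1, 0])

spike_trains = spike_trains_from_labels(peaks, labels, peak_labels)
```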

Extract SVD from peaks

Importantly, many clustering methods internally use extract_peaks_svd(). It runs a node pipeline on a selected set of peaks that extracts waveforms, sparsifies them, and compresses them along the time axis using SVD.
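The temporal compression can be illustrated with plain NumPy: fit a few temporal components on waveform snippets (all channels pooled), then project every channel's waveform onto them, so each peak is described by n_components values per channel instead of num_samples. This is only a sketch of the idea behind extract_peaks_svd(); the function names, the use of np.linalg.svd and the absence of sparsification are assumptions.

```python
import numpy as np

def fit_temporal_svd(waveforms, n_components=3):
    """waveforms: (num_peaks, num_samples, num_channels)."""
    num_samples = waveforms.shape[1]
    # Pool all (peak, channel) snippets as rows of length num_samples.
    snippets = waveforms.transpose(0, 2, 1).reshape(-1, num_samples)
    # The right singular vectors are the temporal components.
    _, _, vt = np.linalg.svd(snippets, full_matrices=False)
    return vt[:n_components]  # (n_components, num_samples)


def project(waveforms, components):
    # (num_peaks, num_samples, num_channels) -> (num_peaks, n_components, num_channels)
    return np.einsum("ks,psc->pkc", components, waveforms)


# Usage: waveforms that all share one temporal shape, with random per-channel scales.
rng = np.random.default_rng(0)
t = np.linspace(-1, 1, 60)
shape = -np.exp(-(t / 0.1) ** 2)                   # spike-like temporal shape
scales = rng.uniform(0.5, 2.0, size=(200, 1, 8))   # per-peak, per-channel scale
waveforms = scales * shape[None, :, None]          # (200, 60, 8)

components = fit_temporal_svd(waveforms, n_components=1)
features = project(waveforms, components)          # (200, 1, 8)
reconstructed = np.einsum("pkc,ks->psc", features, components)
```

Here a single component is enough to reconstruct the waveforms, because they share one temporal shape; real waveforms need a few more.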

Template matching

Template matching is the final step used in many sorters (Kilosort, SpyKING-Circus, YASS, Tridesclous, HDsort, …).

In this step, given a catalogue (or dictionary) of templates (or atoms), the algorithms try to explain the traces as a linear sum of templates plus residual noise.
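Written out with our own notation, the model is the following: $s(t)$ is the multichannel trace, $T_u$ the template of unit $u$, and each spike $i$ has a time $t_i$, a unit $u_i$ and an amplitude scaling $a_i$. The methods below differ mainly in how (and how approximately) they search for the set of spikes that minimizes the residual.

```latex
s(t) \;=\; \sum_{i} a_i \, T_{u_i}(t - t_i) \;+\; \varepsilon(t),
\qquad
\{(t_i, u_i, a_i)\}_i \;=\; \arg\min \Big\lVert\, s - \sum_{i} a_i \, T_{u_i}(\cdot - t_i) \Big\rVert^2
```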

At the moment, there are five methods implemented:

  • ‘nearest’: a simple implementation, which is more or less an np.argmin over the distances between each spike waveform and all templates.

  • ‘nearest-svd’: a smarter implementation than ‘nearest’ using svd compression and spatial sparsity.

  • ‘tdc-peeler’: a simple idea similar to ‘nearest’. It runs ‘nearest’ on locally detected peaks, fits the amplitudes and subtracts the matched templates from the traces, then re-runs on the residual. A bit naive, but very fast.

  • ‘circus-omp’: a more serious implementation based on orthogonal matching pursuit. Internally, it convolves the traces with all templates, using SVD decomposition tricks to be faster. This is quite accurate but needs a lot of memory.

  • ‘wobble’: a re-implementation of the YASS template matching code, also very similar to ‘circus-omp’. This is the most accurate method for discovering spike collisions.
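A toy version of the peeling idea may help: repeatedly take the most negative remaining sample, fit every template there by least squares, subtract the template that removes the most energy, and continue on the residual. This is only loosely ‘tdc-peeler’-like; the function greedy_peel, the stopping threshold, and the single-channel, fixed-alignment setup are assumptions made for illustration, not the library's implementation.

```python
import numpy as np

def greedy_peel(trace, templates, peak_offset, threshold, max_iter=100):
    """Hypothetical sketch.

    trace: (num_samples,); templates: (num_templates, template_len), each with
    its trough at index peak_offset. Returns (spikes, residual), where spikes
    is a list of (sample_index, template_index, amplitude).
    """
    residual = np.array(trace, dtype=float)
    template_len = templates.shape[1]
    norms2 = np.sum(templates ** 2, axis=1)
    spikes = []
    for _ in range(max_iter):
        # Candidate: the most negative remaining sample.
        center = int(np.argmin(residual))
        if residual[center] > -threshold:
            break  # nothing left above threshold
        start = center - peak_offset
        if start < 0 or start + template_len > len(residual):
            break  # too close to the border for this toy version
        snippet = residual[start:start + template_len]
        # Least-squares amplitude per template and the energy it would remove.
        amps = templates @ snippet / norms2
        gains = amps ** 2 * norms2
        best = int(np.argmax(gains))
        residual[start:start + template_len] -= amps[best] * templates[best]
        spikes.append((center, best, float(amps[best])))
    return spikes, residual


# Usage: a narrow and a wide template, one spike of each.
k = np.arange(20)
templates = np.stack([-np.exp(-((k - 5) / 1.5) ** 2),
                      -np.exp(-((k - 5) / 4.0) ** 2)])
trace = np.zeros(200)
trace[40:60] += 2.0 * templates[0]
trace[120:140] += 1.5 * templates[1]
spikes, residual = greedy_peel(trace, templates, peak_offset=5, threshold=0.5)
```

Both spikes are recovered with the right template and amplitude, and the residual is left at (numerically) zero.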