Motion/drift correction
See a practical guide to motion correction in our How To guide: Handle motion/drift in your recording.
Overview
Mechanical drift, often observed in recordings when a probe is acutely inserted, is currently a major issue for spike sorting. This is especially striking with the new generation of high-density devices used for in-vivo electrophysiology, such as Neuropixels probes. The first sorter that introduced motion/drift correction as a preprocessing step was Kilosort2.5 (see [Steinmetz2021] [SteinmetzDataset] [Pachitariu2023]).
The first algorithm used the same ideas as those used for non-rigid image registration, for example with calcium imaging. However, because with extracellular recordings we do not have a proper image to use as a reference, the main idea of the algorithm is to create an “image” from the activity profile of the cells during a given time window. Assuming this activity profile stays constant over time, the motion can be estimated, by blocks, along the probe’s insertion axis (i.e. depth), so that we can interpolate the traces to compensate for this estimated motion.
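The idea of registering activity profiles can be illustrated with a minimal pure-Python sketch. It is not the Kilosort or SpikeInterface implementation: the helper names (`activity_histogram`, `estimate_rigid_shift`), the peak depths, and the bin size are all made up for illustration, and only a single rigid shift between two time windows is estimated.

```python
from collections import Counter

def activity_histogram(depths_um, bin_um, n_bins):
    """Count detected peaks per depth bin: a 1D "image" of the activity."""
    counts = Counter(int(d // bin_um) for d in depths_um)
    return [counts.get(i, 0) for i in range(n_bins)]

def estimate_rigid_shift(reference, target, max_shift_bins):
    """Return the shift (in bins) that best aligns `target` onto `reference`.

    The score is a plain cross-correlation; a positive result means the
    activity in `target` sits at larger depth values than in `reference`.
    """
    n = len(reference)
    best_shift, best_score = 0, float("-inf")
    for shift in range(-max_shift_bins, max_shift_bins + 1):
        score = sum(
            reference[i] * target[i + shift]
            for i in range(n)
            if 0 <= i + shift < n
        )
        if score > best_score:
            best_shift, best_score = shift, score
    return best_shift

# Hypothetical peak depths (um) in two time windows.
# The second window is the first one displaced by 20 um.
window_a = [100, 105, 110, 300, 305, 500, 505, 510, 515]
window_b = [d + 20 for d in window_a]

bin_um = 5
hist_a = activity_histogram(window_a, bin_um, n_bins=120)
hist_b = activity_histogram(window_b, bin_um, n_bins=120)
shift_um = estimate_rigid_shift(hist_a, hist_b, max_shift_bins=10) * bin_um
print(shift_um)  # 20
```

A non-rigid estimate would repeat the same comparison separately for several spatial blocks along the probe.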
There are now several algorithms which try to correct for drift as a preprocessing step: the Paninski group from Columbia University introduced DREDGE (see [Varol2021] and [Windolf2023]), and the Jazayeri lab introduced MEDiCINe ([Watters]).
Because motion registration is a hard topic, with numerous hypotheses and/or implementation details that might have a large impact on spike sorting performance (see [Garcia2023]), we developed in SpikeInterface a full motion estimation and interpolation framework to make all these methods accessible in one place. This modular approach offers a major benefit: the drift correction can be applied to a recording as a preprocessing step, and then used for any sorter! In short, the motion correction is decoupled from the sorter itself.
This gives the user flexibility to check/test and correct the drift before the sorting process.
Here is an overview of motion correction as part of preprocessing a recording:
The motion correction process can be split into 3 steps:
activity profile: detect peaks and localize them along time and depth
motion inference: estimate the drift motion (by spatial blocks for non-rigid motion)
motion interpolation: interpolate traces using the estimated motion
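To make step 3 more concrete, here is a minimal pure-Python sketch of inverse-distance-weighted interpolation applied to a single time sample. It is a toy under stated assumptions, not the SpikeInterface implementation: the function names, the sign convention (a positive `motion_um` means the activity is recorded at larger depth values), and the use of only the two closest channels are all illustrative choices.

```python
def idw_interpolate(channel_depths_um, sample, query_depth_um, num_closest=2, eps=1e-9):
    """Inverse-distance-weighted estimate of the signal at `query_depth_um`."""
    distances = [abs(depth - query_depth_um) for depth in channel_depths_um]
    nearest = sorted(range(len(distances)), key=distances.__getitem__)[:num_closest]
    # An exact hit on a channel returns that channel's value unchanged.
    for i in nearest:
        if distances[i] < eps:
            return sample[i]
    weights = [1.0 / distances[i] for i in nearest]
    return sum(w * sample[i] for w, i in zip(weights, nearest)) / sum(weights)

def correct_sample(channel_depths_um, sample, motion_um):
    """Re-sample one time point so each channel reads what it would have seen without drift."""
    return [
        idw_interpolate(channel_depths_um, sample, depth + motion_um)
        for depth in channel_depths_um
    ]

channel_depths_um = [0.0, 20.0, 40.0, 60.0, 80.0]
# Hypothetical sample: a spike normally centred on the 40 um channel is
# recorded between the 40 um and 60 um channels after a 10 um drift.
sample = [0.0, 0.0, 5.0, 5.0, 0.0]
corrected = correct_sample(channel_depths_um, sample, motion_um=10.0)
print(corrected)  # the largest value is back on the 40 um channel
```

Note that the deepest channel has to extrapolate beyond the probe in this toy, so its value is less reliable than the others.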
For each step, we have implemented several methods. The combination of the yellow boxes should give more or less what Kilosort2.5/3 is doing. Similarly, the combination of the green boxes gives the method developed by the Paninski group. Of course the end user can combine any of these methods to get the best motion correction possible. This also makes an incredible framework for testing new ideas.
For a better overview, check out our recent paper, which validates, benchmarks, and compares these motion correction methods (see [Garcia2023]).
SpikeInterface offers two levels for motion correction:
A high level with a unique function and predefined parameter presets
A low level where the user needs to call one by one all functions for better control
High-level API
One challenging task for motion correction is to determine the parameters.
The high level correct_motion() proposes the concept of a “preset” that already
has predefined parameters, in order to achieve a calibrated behavior.
We currently have these presets:
“dredge”: The official implementation of DREDGE, used in [Windolf2023].
“dredge_fast”: A faster implementation of DREDGE, which should give similar results.
“nonrigid_accurate”: A precursor to DREDGE, introduced by the Paninski group ([Varol2021], [Windolf2023]). This consists of monopolar triangulation + decentralized + inverse distance weighted. It is the slowest combination, but maybe the most accurate. The main bottleneck of this preset is the monopolar triangulation used to estimate the peak positions. To speed it up, one could think about subsampling the detected peaks.
“medicine”: A wrapped version of MEDiCINe, [Watters].
“nonrigid_fast_and_accurate”: A mixture of Kilosort and DREDGE ideas. Roughly: grid_convolution + decentralized motion estimation.
“rigid_fast”: A fast, but not very accurate method. It uses center of mass + decentralized + inverse distance weighted. To be used as a check and/or control on a recording to detect the presence of drift. Note that in this case the drift is considered “rigid” over the electrode.
“kilosort_like”: This consists of grid convolution + iterative_template + kriging, to mimic what is done in Kilosort (see [Pachitariu2023]). Note that this is not exactly what Kilosort is doing, because the peak detection is done with template matching in Kilosort, while in SpikeInterface we use a threshold-based method. However, this preset gives similar results to Kilosort2.5.
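from spikeinterface.extractors import read_spikeglx
from spikeinterface.preprocessing import bandpass_filter, common_reference, correct_motion

# read and preprocess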
rec = read_spikeglx(folder_path='/my/Neuropixel/recording')
rec = bandpass_filter(recording=rec)
rec = common_reference(recording=rec)
# then correction is one line of code
rec_corrected = correct_motion(recording=rec, preset="nonrigid_accurate")
The process is quite long due to the first two steps (activity profile + motion inference),
but the returned rec_corrected is a lazy recording object that will interpolate traces on the
fly (step 3, motion interpolation).
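What “lazy” means here can be sketched with a toy class. This is only an illustration of the pattern, not how SpikeInterface recordings are implemented: `LazyCorrectedRecording`, `shift_channels`, and the whole-channel motion values are hypothetical. The object only stores its inputs, and the correction is computed for the frames actually requested.

```python
class LazyCorrectedRecording:
    """Toy stand-in for a lazy recording: stores parent data and motion, computes on demand."""

    def __init__(self, parent_traces, motion_per_frame, correct_frame):
        self.parent_traces = parent_traces        # one list of channel values per frame
        self.motion_per_frame = motion_per_frame  # estimated displacement for each frame
        self.correct_frame = correct_frame        # function(frame, motion) -> corrected frame
        self.num_frames_computed = 0              # bookkeeping to show the laziness

    def get_traces(self, start_frame, end_frame):
        corrected = []
        for i in range(start_frame, end_frame):
            corrected.append(self.correct_frame(self.parent_traces[i], self.motion_per_frame[i]))
            self.num_frames_computed += 1
        return corrected

def shift_channels(frame, motion_channels):
    """Hypothetical correction: undo a drift expressed as a whole number of channels."""
    n = len(frame)
    return [frame[i + motion_channels] if 0 <= i + motion_channels < n else 0.0 for i in range(n)]

# The same unit drifts by one channel per frame.
traces = [[0.0, 1.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0], [0.0, 0.0, 0.0, 1.0]]
motion = [0, 1, 2]

lazy_rec = LazyCorrectedRecording(traces, motion, shift_channels)
chunk = lazy_rec.get_traces(1, 3)     # only frames 1 and 2 are corrected
print(chunk)                          # [[0.0, 1.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0]]
print(lazy_rec.num_frames_computed)   # 2
```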
If you want to use other presets, this is as easy as:
# mimic kilosort motion
rec_corrected = correct_motion(recording=rec, preset="kilosort_like")
# super fast but less accurate and rigid
rec_corrected = correct_motion(recording=rec, preset="rigid_fast")
Optionally any parameter from the preset can be overwritten:
rec_corrected = correct_motion(
    recording=rec,
    preset="nonrigid_accurate",
    detect_kwargs=dict(
        detect_threshold=10.0,
    ),
    estimate_motion_kwargs=dict(
        histogram_depth_smooth_um=8.0,
        time_horizon_s=120.0,
    ),
    correct_motion_kwargs=dict(
        spatial_interpolation_method="kriging",
    ),
)
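Conceptually, overriding a preset amounts to updating per-step dictionaries of parameters. The sketch below shows one way this could work. It is an assumption for illustration, not the code inside correct_motion(): the helper `apply_overrides` and the contents of `toy_preset` are hypothetical and do not reflect the values shipped with SpikeInterface.

```python
import copy

def apply_overrides(preset_params, **overrides):
    """Return a copy of `preset_params` where each given `*_kwargs` dict updates the matching step."""
    params = copy.deepcopy(preset_params)
    for step_name, step_overrides in overrides.items():
        if step_name not in params:
            raise KeyError(f"unknown step: {step_name}")
        params[step_name].update(step_overrides)
    return params

# Hypothetical preset content, for illustration only.
toy_preset = {
    "detect_kwargs": {"method": "locally_exclusive", "detect_threshold": 8.0},
    "estimate_motion_kwargs": {"method": "decentralized", "time_horizon_s": None},
}

params = apply_overrides(
    toy_preset,
    detect_kwargs=dict(detect_threshold=10.0),
    estimate_motion_kwargs=dict(time_horizon_s=120.0),
)
print(params["detect_kwargs"])
# {'method': 'locally_exclusive', 'detect_threshold': 10.0}
```

Only the keys that are passed change; every other parameter keeps its preset value, and the preset itself is left untouched.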
Importantly, all the results and intermediate computations can be returned as a motion object, for later loading, verification and visualization.
motion_folder = '/somewhere/to/save/the/motion'
rec_corrected, motion = correct_motion(recording=rec, preset="nonrigid_accurate", output_motion=True)
from spikeinterface.widgets import plot_motion
plot_motion(motion)
Alternatively, you can save the motion (and related motion info) in a folder. The folder will contain the motion vector itself, as well as detected peaks, peak locations, and more.
motion_folder = '/somewhere/to/save/the/motion'
rec_corrected = correct_motion(recording=rec, preset="nonrigid_accurate", folder=motion_folder)
# and then
motion_info = load_motion_info(motion_folder)
Low-level API
All steps (activity profile, motion inference, motion interpolation) can be launched with distinct functions.
This can be useful to find the best method and finely tune/optimize parameters at each step.
All functions are implemented in the sortingcomponents module.
They all have a simple API with SpikeInterface objects or numpy arrays as inputs.
Since motion correction is a hot topic, these functions have many possible methods and also many possible parameters.
Finding the best combination of method/parameters is not that easy, but it should be doable if the presets are not
working properly for your particular case.
The high-level correct_motion() is internally equivalent to this:
# each import is needed
from spikeinterface.sortingcomponents.peak_detection import detect_peaks
from spikeinterface.sortingcomponents.peak_selection import select_peaks
from spikeinterface.sortingcomponents.peak_localization import localize_peaks
from spikeinterface.sortingcomponents.motion import estimate_motion, interpolate_motion
job_kwargs = dict(chunk_duration="1s", n_jobs=20, progress_bar=True)
# Step 1 : activity profile
peaks = detect_peaks(recording=rec, method="locally_exclusive", detect_threshold=8.0, **job_kwargs)
# (optional) sub-select some peaks to speed up the localization
peaks = select_peaks(peaks=peaks, ...)
peak_locations = localize_peaks(
    recording=rec,
    peaks=peaks,
    method="monopolar_triangulation",
    method_kwargs=dict(radius_um=75.0, max_distance_um=150.0),
    job_kwargs=job_kwargs,
)
# Step 2: motion inference
motion = estimate_motion(
    recording=rec,
    peaks=peaks,
    peak_locations=peak_locations,
    method="decentralized",
    direction="y",
    bin_um=5.0,
)
# Step 3: motion interpolation
# this step is lazy
rec_corrected = interpolate_motion(
    recording=rec,
    motion=motion,
    border_mode="remove_channels",
    spatial_interpolation_method="kriging",
    sigma_um=30.0,
)
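Step 1 assigns a depth to every detected peak. The simplest localization method named in the presets, center of mass, can be sketched in a few lines of pure Python. This is a toy illustration, not the code behind localize_peaks(): the function name, the channel depths, and the amplitudes are hypothetical.

```python
def center_of_mass_depth(channel_depths_um, peak_amplitudes):
    """Amplitude-weighted average depth of the channels around a peak."""
    weights = [abs(a) for a in peak_amplitudes]
    total = sum(weights)
    if total == 0:
        raise ValueError("all amplitudes are zero")
    return sum(w * d for w, d in zip(weights, channel_depths_um)) / total

# Hypothetical neighbouring channels (depth in um) and the peak amplitude seen on each (uV).
channel_depths_um = [100.0, 120.0, 140.0]

symmetric = center_of_mass_depth(channel_depths_um, [-20.0, -60.0, -20.0])
skewed = center_of_mass_depth(channel_depths_um, [-20.0, -60.0, -40.0])
print(symmetric)  # 120.0: the peak sits on the middle channel
print(skewed)     # slightly deeper than 120 um, pulled towards the 140 um channel
```

Because the estimate is a weighted average, it can resolve positions between channels, which is what makes small drifts visible in the activity profile.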
Preprocessing details
The function correct_motion() requires a preprocessed recording.
It is important to keep in mind that the preprocessing can have a strong impact on the motion estimation.
In the context of motion correction we advise:
to not use whitening before motion estimation (as it interferes with spatial amplitude information)
to remove high frequencies in traces, to reduce noise in peak location (e.g. using a bandpass filter)
if you use Neuropixels, then use phase_shift() in preprocessing
Note that given the flexibility and lazy preprocessing layer of SpikeInterface, it is very easy to implement two different preprocessing chains: one for motion correction and one for spike sorting. See the following example:
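import spikeinterface.full as si
from spikeinterface.extractors import read_spikeglx
from spikeinterface.preprocessing import correct_motion
from spikeinterface.preprocessing.motion import load_motion_info
from spikeinterface.sortingcomponents.motion import interpolate_motion
from spikeinterface.sorters import run_sorter

raw_rec = read_spikeglx(folder_path='/spikeglx_folder')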
# preprocessing 1 : bandpass (this is smoother) + cmr
rec1 = si.bandpass_filter(recording=raw_rec, freq_min=300., freq_max=5000.)
rec1 = si.common_reference(recording=rec1, reference='global', operator='median')
# here the corrected recording is done on the preprocessing 1
# rec_corrected1 will not be used for sorting!
motion_folder = '/my/folder'
rec_corrected1 = correct_motion(recording=rec1, preset="nonrigid_accurate", folder=motion_folder)
# preprocessing 2 : highpass + cmr
rec2 = si.highpass_filter(recording=raw_rec, freq_min=300.)
rec2 = si.common_reference(recording=rec2, reference='global', operator='median')
# we use another preprocessing for the final interpolation
motion_info = load_motion_info(motion_folder)
rec_corrected2 = interpolate_motion(
    recording=rec2,
    motion=motion_info['motion'],
    **motion_info['parameters']['interpolate_motion_kwargs'],
)
sorting = run_sorter(sorter_name="mountainsort5", recording=rec_corrected2)