Alpha CSC on empirical time-series with strong artifacts
This example illustrates how to learn univariate atoms on a univariate time series affected by strong artifacts. The data is a single LFP channel recorded in a rodent's striatum [1]. In this time series, the amplitude of the high-frequency oscillations around 80 Hz is modulated by the low-frequency oscillation around 3 Hz, a phenomenon known as cross-frequency coupling (CFC).
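To make the notion of CFC concrete, here is a minimal synthetic sketch, not the recorded data: an 80 Hz oscillation whose amplitude follows a 3 Hz oscillation. The envelope shape and the mixing weight 0.3 are arbitrary assumptions chosen for illustration.

```python
import numpy as np

sfreq = 350.0                                   # same sampling rate as the example
t = np.arange(int(10 * sfreq)) / sfreq          # 10 seconds of samples

slow = np.sin(2 * np.pi * 3.0 * t)              # low-frequency driver (3 Hz)
envelope = 0.5 * (1.0 + slow)                   # amplitude in [0, 1], largest at the slow peaks
fast = envelope * np.sin(2 * np.pi * 80.0 * t)  # 80 Hz oscillation, amplitude-modulated
signal = slow + 0.3 * fast                      # assumed mixing weight, illustration only

# The magnitude of the fast component co-varies with the slow oscillation.
coupling = np.corrcoef(np.abs(fast), slow)[0, 1]
```

A clearly positive `coupling` indicates that the fast oscillation is strongest near the peaks of the slow one, which is the kind of structure the learned atoms are expected to show.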
The convolutional sparse coding (CSC) model is able to learn the prototypical waveforms of the signal, on which we can clearly see the CFC. However, when the CSC model is fitted on a data section with strong artifacts, the learned atoms do not show the expected CFC waveforms. To solve this problem, another model can be used, called alpha-CSC [2], which is less affected by strong artifacts in the data.
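As a rough sketch of what a CSC model represents, a signal can be synthesized as a sum of short atoms convolved with sparse activations. The helper `synthesize_csc` and the array shapes below are assumptions made for illustration, not the alphacsc implementation.

```python
import numpy as np


def synthesize_csc(d, z):
    """Return sum_k (z[k] * d[k]), where * is a full 1D convolution.

    d : array, shape (n_atoms, n_times_atom), the atoms
    z : array, shape (n_atoms, n_times_valid), the sparse activations
    """
    n_atoms, n_times_atom = d.shape
    n_times_valid = z.shape[1]
    x = np.zeros(n_times_valid + n_times_atom - 1)
    for k in range(n_atoms):
        x += np.convolve(z[k], d[k], mode="full")
    return x


rng = np.random.default_rng(0)
n_atoms, n_times_atom, n_times = 2, 50, 1000
d = rng.standard_normal((n_atoms, n_times_atom))
d /= np.linalg.norm(d, axis=1, keepdims=True)    # unit-norm atoms

z = np.zeros((n_atoms, n_times - n_times_atom + 1))
z[0, 100] = 2.0                                  # atom 0 occurs at sample 100
z[1, 600] = 1.5                                  # atom 1 occurs at sample 600
x = synthesize_csc(d, z)
```

Each nonzero activation places a scaled copy of its atom in the signal. Learning a CSC model goes the other way: it estimates both the atoms and the activations from the signal alone.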
# Authors: Tom Dupre La Tour <tom.duprelatour@telecom-paristech.fr>
# Mainak Jas <mainak.jas@telecom-paristech.fr>
# Umut Simsekli <umut.simsekli@telecom-paristech.fr>
# Alexandre Gramfort <alexandre.gramfort@telecom-paristech.fr>
#
# License: BSD (3-clause)
Let us first load the data sample.
import mne
import numpy as np
import matplotlib.pyplot as plt
# sample frequency
sfreq = 350.
# We load the signal. It is an LFP channel recorded on a rodent's striatum.
data = np.load('../rodent_striatum.npy')
print(data.shape)
(1, 630000)
Now let us take a closer look, plotting the first 100 seconds of the signal.
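A segment like this can be selected by converting seconds to sample indices. The sketch below is a stand-in under stated assumptions: `select_segment` is a hypothetical helper, and random data of the same shape replaces the recording.

```python
import numpy as np


def select_segment(data, sfreq, start_sec, stop_sec):
    """Return the time axis (in seconds) and the samples of the first channel."""
    start, stop = int(start_sec * sfreq), int(stop_sec * sfreq)
    time = np.arange(start, stop) / sfreq
    return time, data[0, start:stop]


sfreq = 350.0
# Stand-in for the recording: same shape as the example data, random content.
data = np.random.default_rng(0).standard_normal((1, 630000))
time, segment = select_segment(data, sfreq, 0, 100)
```

The two returned arrays have the same length and can be passed directly to `plt.plot(time, segment)`.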
As we can see, the data contains severe artifacts. We will thus compare three approaches to tackle these artifacts:
First, we will fit a CSC model on a section not affected by artifacts.
Then, we will fit a CSC model on a section affected by artifacts.
Finally, we will fit an alpha-CSC model on a section affected by artifacts.
# Define a clean data section.
start, stop = [100, 600] # in seconds
start, stop = int(start * sfreq), int(stop * sfreq)
data_clean = data[:, start:stop].copy()
# Define a dirty data section (same length).
start, stop = [0, 500] # in seconds
start, stop = int(start * sfreq), int(stop * sfreq)
data_dirty = data[:, start:stop].copy()
# We also remove the slow drift, which accounts for a lot of variance.
data_clean = mne.filter.filter_data(data_clean, sfreq, 1, None)
data_dirty = mne.filter.filter_data(data_dirty, sfreq, 1, None)
# To make the most of parallel computing, we split the data into trials.
data_clean = data_clean.reshape(50, -1)
data_dirty = data_dirty.reshape(50, -1)
# We scale the data, since parameter alpha is scale dependent.
scale = data_clean.std()
data_clean /= scale
data_dirty /= scale
Setting up high-pass filter at 1 Hz
FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal highpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 1.00
- Lower transition bandwidth: 1.00 Hz (-6 dB cutoff frequency: 0.50 Hz)
- Filter length: 1155 samples (3.300 s)
Setting up high-pass filter at 1 Hz
FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal highpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 1.00
- Lower transition bandwidth: 1.00 Hz (-6 dB cutoff frequency: 0.50 Hz)
- Filter length: 1155 samples (3.300 s)
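The trial splitting and scaling steps above can be checked on stand-in data. This is a minimal sketch using random arrays in place of the filtered sections; only the shapes and the shared scale factor mirror the example.

```python
import numpy as np

sfreq = 350.0
rng = np.random.default_rng(0)
# Stand-ins for two 500 s sections; the "dirty" one has a larger spread.
section_clean = rng.standard_normal((1, int(500 * sfreq)))
section_dirty = 5.0 * rng.standard_normal((1, int(500 * sfreq)))

n_trials = 50
trials_clean = section_clean.reshape(n_trials, -1)   # (50, 3500), 10 s per trial
trials_dirty = section_dirty.reshape(n_trials, -1)

# Both sections are divided by the standard deviation of the clean one,
# so the clean and dirty data remain on a comparable scale.
scale = trials_clean.std()
trials_clean = trials_clean / scale
trials_dirty = trials_dirty / scale
```

Because `reshape` keeps samples in order, trial `i` holds the contiguous samples `[i * 3500, (i + 1) * 3500)` of the original section.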
This sample contains CFC between 3 Hz and 80 Hz. This phenomenon can be described with a comodulogram, computed for instance with the pactools Python library.
from pactools import Comodulogram
comod = Comodulogram(fs=sfreq, low_fq_range=np.arange(0.2, 10.2, 0.2),
                     low_fq_width=2., method='duprelatour')
comod.fit(data_clean)
comod.plot()
plt.show()
[........................................] 100% | 5.44 sec | comodulogram: DAR(10, 1)
Here we define the plotting function that displays the learned atoms.
def plot_atoms(d_hat):
    n_atoms, n_times_atom = d_hat.shape
    n_columns = min(6, n_atoms)
    # Ceiling of the true division, so a partially filled last row is kept.
    n_rows = int(np.ceil(n_atoms / n_columns))
    figsize = (4 * n_columns, 3 * n_rows)
    # squeeze=False always returns a 2D array of axes, even for a single atom.
    fig, axes = plt.subplots(n_rows, n_columns, figsize=figsize, sharey=True,
                             squeeze=False)
    axes = axes.ravel()
    # Plot the temporal pattern of each atom
    for kk in range(n_atoms):
        ax = axes[kk]
        time = np.arange(n_times_atom) / sfreq
        ax.plot(time, d_hat[kk], color='C%d' % kk)
        ax.set_xlim(0, n_times_atom / sfreq)
        ax.set(xlabel='Time (sec)', title="Temporal pattern %d" % kk)
        ax.grid(True)
    fig.tight_layout()
    plt.show()
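The subplot grid needs enough rows to hold every atom, including a partially filled last row. A small sketch of that computation follows, with `grid_shape` as a hypothetical helper name; it uses a ceiling of the true division rather than a floor division.

```python
import math


def grid_shape(n_atoms, max_columns=6):
    """Return (n_rows, n_columns) for a grid with at most max_columns columns."""
    n_columns = min(max_columns, n_atoms)
    n_rows = math.ceil(n_atoms / n_columns)   # round up to keep the last row
    return n_rows, n_columns


shape_3 = grid_shape(3)     # the 3 atoms used in this example
shape_7 = grid_shape(7)     # one atom more than a full row
```

With 7 atoms, a floor division would give a single row of 6 axes and leave no axis for the seventh atom.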
Then we define the common parameters of the different models.
common_params = dict(
    n_atoms=3,
    n_times_atom=int(sfreq * 1.0),  # 1000. ms
    reg=3.,
    solver_z='l-bfgs',
    solver_z_kwargs=dict(factr=1e9),
    solver_d_kwargs=dict(factr=1e2),
    random_state=42,
    n_jobs=5,
    verbose=1)
# number of iterations
n_iter = 10
# Parameter of the alpha-stable distribution for the alpha-CSC model.
# 0 < alpha < 2
# A value of 2 would correspond to the Gaussian noise model, as in vanilla CSC.
alpha = 1.2
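To give some intuition for why a smaller alpha tolerates artifacts, the sketch below draws symmetric alpha-stable samples with the Chambers-Mallows-Stuck method and compares alpha = 1.2 with the Gaussian case alpha = 2. The sampler is an illustrative assumption written for this comparison (unit scale, zero skewness, alpha different from 1); it is not taken from alphacsc.

```python
import numpy as np


def sample_symmetric_alpha_stable(alpha, size, rng):
    """Chambers-Mallows-Stuck sampler: unit scale, zero skewness, alpha != 1."""
    u = rng.uniform(-np.pi / 2, np.pi / 2, size)   # random angle
    w = rng.exponential(1.0, size)                 # unit-mean exponential
    return (np.sin(alpha * u) / np.cos(u) ** (1.0 / alpha)
            * (np.cos((1.0 - alpha) * u) / w) ** ((1.0 - alpha) / alpha))


rng = np.random.default_rng(0)
heavy = sample_symmetric_alpha_stable(1.2, 100_000, rng)
gauss = sample_symmetric_alpha_stable(2.0, 100_000, rng)   # Gaussian, variance 2

# Fraction of samples with a magnitude above 10.
frac_heavy = np.mean(np.abs(heavy) > 10)
frac_gauss = np.mean(np.abs(gauss) > 10)
```

With alpha = 1.2, a noticeable fraction of samples is very large, so a noise model of this kind can account for artifact-sized deviations that a Gaussian model would treat as signal.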
First, we fit a CSC model on the clean data. Interestingly, we obtain prototypical waveforms of the signal on which we can clearly see the CFC.
from alphacsc import learn_d_z, learn_d_z_weighted
X = data_clean
_, _, d_hat, z_hat, _ = learn_d_z(X, n_iter=n_iter, **common_params)
plot_atoms(d_hat)
V_0/10 .........
Then, if we fit a CSC model on the dirty data, the model is strongly affected by the artifacts, and we cannot see CFC anymore in the temporal waveforms.
X = data_dirty
_, _, d_hat, z_hat, _ = learn_d_z(X, n_iter=n_iter, **common_params)
plot_atoms(d_hat)
V_0/10 .........
Finally, if we fit an alpha-CSC model on the dirty data, the model is less affected by the artifacts, and we are able to see CFC in the temporal waveforms.
X = data_dirty
d_hat, z_hat, tau = learn_d_z_weighted(
    X, n_iter_optim=n_iter, n_iter_global=3, n_iter_mcmc=300,
    n_burnin_mcmc=100, alpha=alpha, **common_params)
plot_atoms(d_hat)
Global Iter: 0/3 V_0/10 .........
Global Iter: 1/3 V_0/10 .........
Global Iter: 2/3 V_0/10 .........
Total running time of the script: (5 minutes 40.308 seconds)