# Source code for macro_eeg_model.simulation.eeg_analyzer
# external imports
import numpy as np
import matplotlib.pyplot as plt
# local imports
from macro_eeg_model.utils.plotting_setup import notation, PLOT_SIZE, PLOT_FORMAT
[docs]
class EEGAnalyzer:
    """
    The EEGAnalyzer class is responsible for computing and plotting the power spectrum of EEG data.
    """

    @staticmethod
    def calculate_power(data, sample_rate, frequency_range=(0, 50)):
        """
        Applies the Fast Fourier Transform (FFT) to the EEG data to calculate the power spectrum.

        It returns the frequencies and the power spectrum averaged across epochs, restricted to
        the FFT bins that fall inside the requested frequency range.

        Parameters
        ----------
        data : numpy.ndarray
            The EEG data to be analyzed: a 3D array with dimensions (time, nodes, epochs),
            or a 2D array with dimensions (time, nodes), which is treated as a single epoch.
        sample_rate : int
            The sample rate of the EEG data in Hz.
        frequency_range : tuple[float, float], optional
            The (minimum, maximum) frequency in Hz to return. Defaults to (0, 50).
            The maximum is clipped to 0.4 * sample_rate. A non-zero minimum below the frequency
            resolution (sample_rate / number of time samples) is raised to that resolution.
            Any minimum of 1 Hz or less is then treated as 0.

        Returns
        -------
        tuple
            A tuple containing:
            - frequencies (numpy.ndarray): The frequencies (Hz) of the returned FFT bins, spaced
              by sample_rate / number of time samples and never exceeding the (clipped) maximum.
            - power (numpy.ndarray): The power spectrum averaged over epochs, with shape
              (len(frequencies), nodes), in (data units)^2/Hz.

        Raises
        ------
        ValueError
            If data is not 2D or 3D, has no time samples, sample_rate is not positive, or the
            minimum frequency exceeds the maximum frequency after clipping.
        """
        data = np.asarray(data)
        if data.ndim == 2:
            # No epoch axis: treat the recording as a single epoch.
            data = data[:, :, np.newaxis]
        if data.ndim != 3:
            raise ValueError(
                f"Expected data with dimensions (time, nodes[, epochs]), got shape {data.shape}")
        n_samples = data.shape[0]
        if n_samples == 0:
            raise ValueError("data must contain at least one time sample")
        if sample_rate <= 0:
            raise ValueError(f"sample_rate must be positive, got {sample_rate}")

        min_frequency, max_frequency = frequency_range

        # Upper limit used by this analysis: 0.4 * sample_rate (= sample_rate / 2.5), which is
        # below the true Nyquist frequency (sample_rate / 2).
        # NOTE(review): confirm the margin below Nyquist is intentional.
        frequency_limit = (2 / 5) * sample_rate
        if max_frequency > frequency_limit:
            print(
                f"User defined maximum frequency ({max_frequency}) is larger than the Nyquist frequency ({frequency_limit})")
            print("Using Nyquist frequency as maximum")
            max_frequency = frequency_limit

        # Spacing between neighbouring FFT bins (frequency resolution).
        frequency_resolution = sample_rate / n_samples
        if min_frequency < frequency_resolution and min_frequency != 0:
            print(
                f"User defined minimum frequency ({min_frequency}) is smaller than Nyquist sampling ({frequency_resolution})")
            print("Using Nyquist sampling rate as minimum frequency")
            min_frequency = frequency_resolution
        # Minima of 1 Hz or less are snapped to 0 so the spectrum starts at the DC bin.
        if min_frequency <= 1:
            min_frequency = 0

        if min_frequency > max_frequency:
            raise ValueError(
                f"Minimum frequency ({min_frequency}) is larger than maximum frequency ({max_frequency})")

        # Number of bins k with k * resolution <= max_frequency. Counted explicitly (with a tiny
        # tolerance for round-off) instead of a float-step np.arange, which can emit an extra
        # bin beyond max_frequency when max_frequency / resolution is not an integer.
        n_bins = int(np.floor(max_frequency / frequency_resolution + 1e-6)) + 1
        used_frequencies = np.arange(n_bins) * frequency_resolution

        # rfft returns bins 0..n_samples // 2, which always covers n_bins because
        # max_frequency <= 0.4 * sample_rate.
        fourier = np.fft.rfft(data, axis=0)[:n_bins] / n_samples
        # Power averaged over epochs, in (data units)^2/Hz; the factor 2 folds in the negative
        # frequencies. NOTE(review): the DC bin is doubled as well — confirm this is intended.
        power = np.mean(np.abs(fourier) ** 2, axis=2) * (2 / frequency_resolution)

        # First returned bin: the one closest to the requested minimum frequency.
        min_index = int(np.argmin(np.abs(min_frequency - used_frequencies)))
        return used_frequencies[min_index:], power[min_index:, :]

    @staticmethod
    def plot_power(frequencies, power, nodes, plots_dir):
        """
        Visualizes the power spectrum of the EEG data (for each node/channel) as a line plot.

        The figure is saved as ``Power_spectrum.<PLOT_FORMAT>`` inside ``plots_dir`` and then closed.

        Parameters
        ----------
        frequencies : numpy.ndarray
            The array of frequencies corresponding to the power spectrum.
        power : numpy.ndarray
            The calculated power spectrum for each frequency and node (frequencies x nodes).
        nodes : list[str]
            The list of node/channel names corresponding to the data.
        plots_dir : pathlib.Path
            The directory where the plots are saved.

        Raises
        ------
        AssertionError
            If the plots directory does not exist.
        """
        if not plots_dir.exists():
            # Raised explicitly (not via ``assert``) so the check is not stripped under ``python -O``.
            raise AssertionError(f"Directory not found: {plots_dir}")

        fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(1.5 * PLOT_SIZE, PLOT_SIZE))
        try:
            ax.plot(frequencies, power)
            ax.grid(which='both')
            ax.legend([notation(nodes[i]) for i in range(power.shape[1])], loc="upper right")
            ax.set_xticks(np.arange(0, 51, 5))
            ax.set_xlim([0, 30])
            fig.savefig(plots_dir / f"Power_spectrum.{PLOT_FORMAT}")
        finally:
            # Release the figure; otherwise every call leaves an open pyplot figure behind.
            plt.close(fig)