# standard imports
import itertools
import sys
# external imports
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D
from tqdm import tqdm
import numpy as np
# local imports
from .simulation_data_extractor import SimulationDataExtractor
from .coherence_computer import CoherenceComputer
from .peak_tester import PeakTester
from .fooof_tester import FooofTester
from macro_eeg_model.utils.plotting_setup import PLOT_SIZE, COLORS, notation, PLOT_FORMAT
from macro_eeg_model.utils.paths import paths
class Evaluator:
    """
    A class responsible for evaluating simulated EEG data.
    It computes metrics such as coherence, power spectra and spectral peaks
    across different brain regions (nodes) and saves the resulting plots.

    Attributes
    ----------
    frequencies : tuple
        The frequency range used for plotting the data: (0, 30) Hz.
    simulation_data_extractor : SimulationDataExtractor
        An instance of the :py:class:`SimulationDataExtractor` class
        used to extract and process simulated EEG data.
    """

    # Frequency band (Hz) in which (alpha) peaks are looked for.
    _ALPHA_RANGE = (8, 13)
    # Reference band passed to PeakTester as ``others_range``.
    _OTHERS_RANGE = (13, 30)
    # p-values below this threshold are reported as a present peak ("Y").
    _SIGNIFICANCE_LEVEL = 0.05

    def __init__(self):
        """
        Initializes the Evaluator: sets the frequency range used for plotting
        and creates the :py:class:`SimulationDataExtractor` that provides the simulated data.
        """
        self.frequencies = (0, 30)
        self.simulation_data_extractor = SimulationDataExtractor()

    def evaluate(self, plot_overview=True):
        """
        Evaluates the coherence, power, peak and FOOOF metrics
        using :py:meth:`_evaluate_metric` and saves the resulting plots.

        Parameters
        ----------
        plot_overview : bool, optional
            If True, generates overview plots for the evaluated metrics;
            if False, generates individual plots for (pairs of) brain regions.
            (default is True).
        """
        extractor = self.simulation_data_extractor

        # Coherence: one subplot per unordered pair of surface nodes, laid out in two columns.
        nr_surface_nodes = len(extractor.surface_nodes)
        nr_pairwise_plots = nr_surface_nodes * (nr_surface_nodes - 1) // 2
        nr_rows = (nr_pairwise_plots + 1) // 2
        nr_cols = 2 if nr_pairwise_plots > 1 else 1
        self._evaluate_metric(
            self._evaluate_coherence_node_pair, "Evaluating coherence", plot_overview,
            nr_rows, nr_cols, "Coherences_summary"
        )

        # Per-node metrics: one subplot per node, stacked in a single column.
        nr_rows = len(extractor.nodes)
        nr_cols = 1
        self._evaluate_metric(
            self._evaluate_power_node, "Evaluating power", plot_overview, nr_rows, nr_cols, "Powers_summary"
        )
        self._evaluate_metric(
            self._evaluate_peaks, "Evaluating peaks", plot_overview, nr_rows, nr_cols, "Peaks_summary"
        )
        self._evaluate_metric(
            self._evaluate_fooof, "Evaluating fooof", plot_overview, nr_rows, nr_cols, "Fooof_summary"
        )
        print("The evaluation plots have been saved in the 'plots' directory.")

    def _evaluate_metric(self, evaluation_func, desc, plot_overview, rows, cols, save_file_name):
        """
        A helper function to evaluate a specific metric across nodes or node pairs.

        Parameters
        ----------
        evaluation_func : callable
            The function to evaluate the metric (e.g. :py:meth:`_evaluate_coherence_node_pair`
            or :py:meth:`_evaluate_power_node`). Node pairs are used only for
            :py:meth:`_evaluate_coherence_node_pair`; single nodes otherwise.
        desc : str
            The description for the tqdm progress bar.
        plot_overview : bool
            If True, generates one overview figure for the evaluated metric;
            if False, generates individual plots for (pairs of) brain regions.
        rows : int
            The number of rows in the overview plot.
        cols : int
            The number of columns in the overview plot.
        save_file_name : str
            The file name (without extension) for saving the overview plot.
        """
        node_combos = self._get_nodes(pairwise=evaluation_func == self._evaluate_coherence_node_pair)
        if not node_combos:
            # Nothing to evaluate; this also avoids plt.subplots(nrows=0, ...), which raises.
            return

        fig, ax = None, None
        if plot_overview:
            fig, ax = plt.subplots(
                nrows=rows, ncols=cols,
                sharex=True, sharey=True,
                figsize=(1.5 * PLOT_SIZE * cols, PLOT_SIZE * rows)
            )

        last_plot_id = len(node_combos) - 1
        with tqdm(
            total=len(node_combos), desc=desc, unit=" iter", ascii=True, leave=False, file=sys.stdout
        ) as pbar:
            for plot_id, nodes in enumerate(node_combos):
                evaluation_func(
                    *nodes,
                    fig=fig,
                    ax=self._get_ax(ax, rows, cols, plot_id),
                    # In the overview only the last subplot carries the legend.
                    show_legend=(plot_id == last_plot_id),
                )
                pbar.update(1)
                sys.stdout.flush()

        if plot_overview:
            path = paths.plots_path / f"{save_file_name}.{PLOT_FORMAT}"
            fig.savefig(path)
            # Release the figure; it is only needed for saving.
            plt.close(fig)

    def _get_nodes(self, pairwise=False):
        """
        Generates nodes or node pairs for evaluation.

        Parameters
        ----------
        pairwise : bool, optional
            If True, generates every unordered pair of surface nodes (for coherence evaluation),
            otherwise generates individual nodes (for power evaluation)
            (default is False).

        Returns
        -------
        list of tuple
            1-tuples ``(node,)``, or 2-tuples ``(node1, node2)`` when `pairwise` is True.
        """
        if pairwise:
            # combinations() yields the pairs in the same order as a nested i < j loop.
            return list(itertools.combinations(self.simulation_data_extractor.surface_nodes, 2))
        return [(node,) for node in self.simulation_data_extractor.nodes]

    def _evaluate_peaks(self, node, fig=None, ax=None, show_legend=True):
        """
        Evaluates the presence of alpha peaks (using :py:meth:`_get_peaks`)
        and plots (using :py:meth:`_plot_metric`) the detrended power spectrum for a given node.

        Parameters
        ----------
        node : str
            The name of the brain region to evaluate.
        fig : matplotlib.figure.Figure, optional
            The figure object for plotting (default is None).
        ax : matplotlib.axes.Axes, optional
            The axis object for plotting (default is None).
        show_legend : bool, optional
            Accepted for interface compatibility; the legend is always suppressed here,
            so the per-simulation test results only appear in the artist labels.
        """
        frequencies, powers, p_values, test_names = self._get_peaks(node)

        def verdict(p_value):
            # "Y" = peak detected at the significance level, "N" otherwise (also for NaN).
            return "Y" if p_value < self._SIGNIFICANCE_LEVEL else "N"

        label_addons = {
            key: f": {test_names[key]}, p={p_value:.3f}, {verdict(p_value)}"
            for key, p_value in p_values.items()
        }
        mean_p_value = np.mean(list(p_values.values()))
        self._plot_metric(
            f"{notation(node)}: p={mean_p_value:.3f} ({verdict(mean_p_value)})",
            frequencies, powers,
            fig=fig, ax=ax, show_legend=False,
            y_label="Peaks", xlim=[self.frequencies[0], self.frequencies[1]], ylim=[0, 2e7],
            file_label=f"peaks_{node}",
            label_addons=label_addons
        )

    def _evaluate_fooof(self, node, fig=None, ax=None, show_legend=True):
        """
        Evaluates the presence of peaks (using :py:meth:`_get_fooof_peaks`)
        and plots (using :py:meth:`_plot_metric`) the peak positions of every simulation.

        The title reports the percentage of simulations with at least one peak in the
        alpha band (``%in_a``) and the percentage of all peaks lying in the alpha band (``%of_a``).

        Parameters
        ----------
        node : str
            The name of the brain region to evaluate.
        fig : matplotlib.figure.Figure, optional
            The figure object for plotting (default is None).
        ax : matplotlib.axes.Axes, optional
            The axis object for plotting (default is None).
        show_legend : bool, optional
            Accepted for interface compatibility; the legend is always suppressed here.
        """
        frequencies, binary_peaks = self._get_fooof_peaks(node)
        upd_binary_peaks, percentage_in_alpha, percentage_of_alpha = self._summarize_fooof_peaks(
            frequencies, binary_peaks
        )
        self._plot_metric(
            f"{notation(node)}: %in_a={percentage_in_alpha:.1f}, %of_a={percentage_of_alpha:.1f}",
            frequencies, upd_binary_peaks,
            plot_type="scatter",
            fig=fig, ax=ax, show_legend=False,
            y_label="Peak presence", file_label=f"fooof_{node}"
        )

    @staticmethod
    def _summarize_fooof_peaks(frequencies, binary_peaks, alpha_range=_ALPHA_RANGE):
        """
        Computes alpha-band statistics of binary peak indicators and prepares them for scatter plotting.

        Parameters
        ----------
        frequencies : array-like or None
            The frequencies the peak indicators are aligned with.
        binary_peaks : dict
            Peak indicators per simulation (same length as `frequencies`), keyed by simulation name.
        alpha_range : tuple, optional
            Half-open frequency band ``[low, high)`` counted as alpha (default is (8, 13)).

        Returns
        -------
        tuple
            - upd_binary_peaks (dict): the indicators of the i-th simulation (1-based) scaled by i,
              with zeros replaced by NaN so that only peaks are drawn, one row per simulation.
            - percentage_in_alpha (float): percentage of simulations with a peak in the alpha band.
            - percentage_of_alpha (float): percentage of all peaks that lie in the alpha band.
            Percentages are NaN when they are undefined (no simulations / no peaks).
        """
        if not binary_peaks:
            return {}, np.nan, np.nan

        freqs = np.asarray(frequencies, dtype=float)
        # Select the band by frequency value, not by array index, so any frequency grid works.
        alpha_mask = (freqs >= alpha_range[0]) & (freqs < alpha_range[1])

        upd_binary_peaks = {}
        nr_sims_in_alpha = 0
        alpha_count = 0
        total_count = 0
        for sim_counter, (key, values) in enumerate(binary_peaks.items(), start=1):
            values = np.asarray(values)
            curr_alpha_count = np.sum(values[alpha_mask])
            alpha_count += curr_alpha_count
            total_count += np.sum(values)
            if curr_alpha_count > 0:
                nr_sims_in_alpha += 1
            scaled_values = sim_counter * values
            upd_binary_peaks[key] = np.where(scaled_values == 0, np.nan, scaled_values)

        percentage_in_alpha = 100 * nr_sims_in_alpha / len(binary_peaks)
        percentage_of_alpha = 100 * alpha_count / total_count if total_count > 0 else np.nan
        return upd_binary_peaks, percentage_in_alpha, percentage_of_alpha

    def _evaluate_power_node(self, node, fig=None, ax=None, show_legend=True):
        """
        Evaluates (using :py:meth:`_get_simulated_power`)
        and plots (using :py:meth:`_plot_metric`) the power spectrum for a given node.

        Parameters
        ----------
        node : str
            The name of the brain region to evaluate.
        fig : matplotlib.figure.Figure, optional
            The figure object for plotting (default is None).
        ax : matplotlib.axes.Axes, optional
            The axis object for plotting (default is None).
        show_legend : bool, optional
            If True, shows the legend on the plot; standalone plots (no `ax`) always
            show it (default is True).
        """
        sim_frequencies, sim_powers = self._get_simulated_power(node)
        self._plot_metric(
            f"{notation(node)}",
            sim_frequencies, sim_powers,
            fig=fig, ax=ax, show_legend=show_legend if ax is not None else True,
            y_label="Power", xlim=[self.frequencies[0], self.frequencies[1]], ylim=[0, 2e7],
            file_label=f"power_{node}"
        )

    def _evaluate_coherence_node_pair(self, node1, node2, fig=None, ax=None, show_legend=True):
        """
        Evaluates (using :py:meth:`_get_simulated_coherences`)
        and plots (using :py:meth:`_plot_metric`)
        the coherence between a pair of nodes.

        Parameters
        ----------
        node1 : str
            The name of the first brain region (node).
        node2 : str
            The name of the second brain region (node).
        fig : matplotlib.figure.Figure, optional
            The figure object for plotting (default is None).
        ax : matplotlib.axes.Axes, optional
            The axis object for plotting (default is None).
        show_legend : bool, optional
            If True, shows the legend on the plot; standalone plots (no `ax`) always
            show it (default is True).
        """
        sim_frequencies_coherence, sim_coherences = self._get_simulated_coherences(node1, node2)
        self._plot_metric(
            f"{notation(node1)} — {notation(node2)}",
            sim_frequencies_coherence, sim_coherences,
            fig=fig, ax=ax, show_legend=show_legend if ax is not None else True,
            y_label="Coherence", xlim=[1, self.frequencies[1]], ylim=[0, 0.6],
            file_label=f"coherence_{node1}_{node2}"
        )

    def _get_fooof_peaks(self, node):
        """
        Computes the peaks in the power spectrum for a given node using :py:class:`FooofTester`.

        Parameters
        ----------
        node : str
            The name of the brain region for which to compute the peaks.

        Returns
        -------
        tuple
            A tuple containing:
            - frequencies (numpy.ndarray or None): The frequencies of the last processed
              simulation (None when there are no simulations).
            - all_binary_peaks (dict): Binary peaks for each simulation, keyed by simulation name.
        """
        extractor = self.simulation_data_extractor
        frequencies = None
        all_binary_peaks = {}
        for key in extractor.simulation_names:
            # Index 1 of the stored (frequencies, powers) entry holds the power values.
            powers = extractor.simulations_power_per_node[node][key][1]
            frequencies = extractor.simulations_info[key].frequencies
            fooof_tester = FooofTester(frequencies=frequencies)
            all_binary_peaks[key] = fooof_tester.get_peaks_positions(powers)
        return frequencies, all_binary_peaks

    def _get_peaks(self, node):
        """
        Computes the peaks in the power spectrum for a given node using :py:class:`PeakTester`.

        Parameters
        ----------
        node : str
            The name of the brain region for which to compute the peaks.

        Returns
        -------
        tuple
            A tuple containing:
            - frequencies (numpy.ndarray or None): The frequencies returned by the peak test
              of the last processed simulation (None when there are no simulations).
            - powers (dict): Detrended power spectra, keyed by simulation name.
            - p_values (dict): p-values of the peak test, keyed by simulation name.
            - test_names (dict): Names of the peak test used, keyed by simulation name.
        """
        extractor = self.simulation_data_extractor
        frequencies = None
        powers = {}
        p_values = {}
        test_names = {}
        for key in extractor.simulation_names:
            epoched_powers = extractor.simulations_epoched_power_per_node[node][key]
            peak_tester = PeakTester(
                frequencies=extractor.simulations_info[key].frequencies,
                peaks_range=self._ALPHA_RANGE,
                others_range=self._OTHERS_RANGE
            )
            frequencies, detrended_powers, p_value, test_name = peak_tester.compute_test_result(key, epoched_powers)
            powers[key] = detrended_powers
            p_values[key] = p_value
            test_names[key] = test_name
        return frequencies, powers, p_values, test_names

    def _get_simulated_power(self, node):
        """
        Retrieves the simulated power spectrum for a given node.

        Parameters
        ----------
        node : str
            The name of the brain region for which to retrieve the simulated power spectrum.

        Returns
        -------
        tuple
            A tuple containing:
            - frequencies: The frequencies of the first simulation (None when there are no simulations).
            - powers (dict): Simulated power spectra, keyed by simulation name.
        """
        # Materialize once: the names are iterated twice below.
        names = list(self.simulation_data_extractor.simulation_names)
        if not names:
            return None, {}
        simulations = self.simulation_data_extractor.simulations_power_per_node[node]
        frequencies, powers = zip(*(simulations[key] for key in names))
        return frequencies[0], dict(zip(names, powers))

    def _get_simulated_coherences(self, node1, node2):
        """
        Computes the simulated coherence between a pair of nodes for each simulation using
        :py:meth:`CoherenceComputer.compute_coherence_matched`.

        Parameters
        ----------
        node1 : str
            The name of the first brain region.
        node2 : str
            The name of the second brain region.

        Returns
        -------
        tuple
            A tuple containing:
            - frequencies (numpy.ndarray or None): The coherence frequencies of the last
              processed simulation (None when there are no simulations).
            - coherences (dict): Simulated coherence values, keyed by simulation name.
        """
        extractor = self.simulation_data_extractor
        simulations1 = extractor.simulations_data_per_node[node1]
        simulations2 = extractor.simulations_data_per_node[node2]
        freq_coh_sim = None
        coherences = {}
        for key in extractor.simulation_names:
            coherence_computer_sim = CoherenceComputer(fs=extractor.sample_rates[key])
            freq_coh_sim, coherences[key] = coherence_computer_sim.compute_coherence_matched(
                simulations1[key], simulations2[key]
            )
        return freq_coh_sim, coherences

    def _plot_metric(
        self,
        title,
        sim_frequencies, sim_data,
        plot_type="line",
        fig=None, ax=None, show_legend=True, y_label=None, xlim=None, ylim=None, file_label=None, label_addons=None
    ):
        """
        Plots a metric (e.g., coherence or power) of data
        using :py:meth:`_plot_simulated_data`.

        When `fig` or `ax` is missing, a standalone figure is created, saved to
        ``paths.plots_path / f"{file_label}_{title}.{PLOT_FORMAT}"`` and closed.

        Parameters
        ----------
        title : str
            The title of the plot.
        sim_frequencies : numpy.ndarray
            The array of frequencies for the simulated data.
        sim_data : dict
            The simulated data (e.g., power or coherence) to plot, keyed by simulation name.
        plot_type : str, optional
            The type of plot to create (default is "line"). Currently, "line", "scatter", "mean_std" are supported.
        fig : matplotlib.figure.Figure, optional
            The figure object for plotting (default is None).
        ax : matplotlib.axes.Axes, optional
            The axis object for plotting (default is None).
        show_legend : bool, optional
            If True, shows the legend on the plot (default is True).
        y_label : str, optional
            The label for the y-axis; currently accepted but not applied (default is None).
        xlim : list, optional
            The x-axis limits for the plot; None leaves them unchanged (default is None).
        ylim : list, optional
            The y-axis limits for the plot; None leaves them unchanged (default is None).
        file_label : str, optional
            The file name label for saving a standalone plot (default is None).
        label_addons : dict, optional
            Text appended to each data name in the labels, keyed by simulation name (default is None).
        """
        independent = fig is None or ax is None
        if independent:
            fig, ax = plt.subplots(figsize=(PLOT_SIZE * 2, PLOT_SIZE))
        if label_addons is None:
            label_addons = {key: "" for key in sim_data}
        self._plot_simulated_data(ax, sim_frequencies, sim_data, label_addons, plot_type)
        ax.set_title(title)
        ax.set_xlim(xlim)
        ax.set_ylim(ylim)
        ax.grid(which='both')
        ax.set_axisbelow(True)
        if show_legend:
            handles, _ = ax.get_legend_handles_labels()
            ax.legend(handles=handles, loc='upper right')
        if independent:
            path = paths.plots_path / f"{file_label}_{title}.{PLOT_FORMAT}"
            fig.savefig(path)
            # Close the standalone figure; otherwise every call leaves one open figure behind.
            plt.close(fig)

    @staticmethod
    def _plot_simulated_data(ax, frequencies, data, label_addons, plot_type="line"):
        """
        Plots the simulated EEG data on a given axis.

        Parameters
        ----------
        ax : matplotlib.axes.Axes
            The axis object for plotting.
        frequencies : numpy.ndarray
            The array of frequencies for the simulated data.
        data : dict
            The simulated data (e.g., power or coherence) to plot, keyed by simulation name.
        label_addons : dict
            Text appended to each data name in the labels; missing keys add nothing.
        plot_type : str, optional
            The type of plot to create (default is "line"). Currently, "line", "scatter", "mean_std" are supported.
        """
        if plot_type == "mean_std":
            # Aggregate all simulations into a mean curve with a +/- one std band.
            stacked = list(data.values())
            mean_data = np.mean(stacked, axis=0)
            std_data = np.std(stacked, axis=0)
            ax.fill_between(frequencies, mean_data - std_data, mean_data + std_data, label="std", color=COLORS[1], alpha=0.5)
            ax.plot(frequencies, mean_data, label="mean", color=COLORS[0], alpha=1.0)
            return
        draw = ax.scatter if plot_type == "scatter" else ax.plot
        for i, (name, d) in enumerate(data.items()):
            draw(
                frequencies, d,
                label=f"{name}{label_addons.get(name, '')}",
                color=COLORS[i % len(COLORS)], alpha=1.0
            )

    @staticmethod
    def _get_ax(ax, rows, cols, i):
        """
        Helper function to get the axis of the i-th subplot (row-major order).

        Parameters
        ----------
        ax : matplotlib.axes.Axes, numpy.ndarray or None
            The axis object(s) returned by ``plt.subplots``, or None when no overview is plotted.
        rows : int
            The number of rows in the subplot grid.
        cols : int
            The number of columns in the subplot grid.
        i : int
            The index of the current plot.

        Returns
        -------
        matplotlib.axes.Axes or None
            The axis object for the current subplot, or None if `ax` is None.
        """
        if ax is None:
            return None
        if rows == 1 and cols == 1:
            # plt.subplots returns a single Axes (not an array) for a 1x1 grid.
            return ax
        # Handles both 1-D (1xN / Nx1) and 2-D axes arrays; flat indexing is row-major.
        return np.asarray(ax, dtype=object).flat[i]