Source code for macro_eeg_model.evaluation.evaluator

# standard imports
import sys

# external imports
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D
from tqdm import tqdm
import numpy as np

# local imports
from .simulation_data_extractor import SimulationDataExtractor
from .coherence_computer import CoherenceComputer
from .peak_tester import PeakTester
from .fooof_tester import FooofTester
from macro_eeg_model.utils.plotting_setup import PLOT_SIZE, COLORS, notation, PLOT_FORMAT
from macro_eeg_model.utils.paths import paths


class Evaluator:
    """
    A class responsible for evaluating simulated EEG data.

    It computes metrics such as coherence and power spectra across different brain regions (nodes).

    Attributes
    ----------
    frequencies : tuple
        The frequency range for evaluating the data ((0, 30) Hz).
    simulation_data_extractor : SimulationDataExtractor
        An instance of the
        :py:class:`macro_eeg_model.evaluation.simulation_data_extractor.SimulationDataExtractor`
        class used to extract and process simulated EEG data.
    """

    def __init__(self):
        """
        Initializes the Evaluator class, setting up the frequency range and the simulated-data extractor.
        """
        self.frequencies = (0, 30)
        self.simulation_data_extractor = SimulationDataExtractor()

    def evaluate(self, plot_overview=True):
        """
        Evaluates and plots the coherence, power, peak and fooof metrics using :py:meth:`_evaluate_metric`.

        Parameters
        ----------
        plot_overview : bool, optional
            If True, generates overview plots for the evaluated metrics;
            if False, generates individual plots for (pairs of) brain regions.
            (default is True).
        """
        # Coherence is evaluated for every unordered pair of surface nodes, laid out on a two-column grid.
        nr_nodes = len(self.simulation_data_extractor.surface_nodes)
        nr_pairwise_plots = nr_nodes * (nr_nodes - 1) // 2
        nr_rows = (nr_pairwise_plots + 1) // 2
        nr_cols = 2 if nr_pairwise_plots > 1 else 1
        self._evaluate_metric(
            self._evaluate_coherence_node_pair, "Evaluating coherence",
            plot_overview, nr_rows, nr_cols, "Coherences_summary"
        )

        # The per-node metrics are stacked in a single column, one row per node.
        nr_rows = len(self.simulation_data_extractor.nodes)
        nr_cols = 1
        per_node_metrics = (
            (self._evaluate_power_node, "Evaluating power", "Powers_summary"),
            (self._evaluate_peaks, "Evaluating peaks", "Peaks_summary"),
            (self._evaluate_fooof, "Evaluating fooof", "Fooof_summary"),
        )
        for evaluation_func, desc, save_file_name in per_node_metrics:
            self._evaluate_metric(evaluation_func, desc, plot_overview, nr_rows, nr_cols, save_file_name)

        print("The evaluation plots have been saved in the 'plots' directory.")

    def _evaluate_metric(self, evaluation_func, desc, plot_overview, rows, cols, save_file_name):
        """
        A helper function to evaluate a specific metric (e.g., coherence or power) across nodes or node pairs.

        Parameters
        ----------
        evaluation_func : function
            The function to evaluate the metric (:py:meth:`_evaluate_coherence_node_pair`,
            :py:meth:`_evaluate_power_node`, :py:meth:`_evaluate_peaks` or :py:meth:`_evaluate_fooof`).
        desc : str
            The description for the tqdm progress bar.
        plot_overview : bool
            If True, generates overview plots for the evaluated metrics;
            if False, generates individual plots for (pairs of) brain regions.
        rows : int
            The number of rows in the overview plot.
        cols : int
            The number of columns in the overview plot.
        save_file_name : str
            The file name for saving the overview plot.
        """
        node_combos = self._get_nodes(pairwise=evaluation_func == self._evaluate_coherence_node_pair)
        if not node_combos:
            # Nothing to evaluate; also avoids asking matplotlib for a grid with zero rows.
            return

        fig, ax = None, None
        if plot_overview:
            fig, ax = plt.subplots(
                nrows=rows, ncols=cols, sharex=True, sharey=True,
                figsize=(1.5 * PLOT_SIZE * cols, PLOT_SIZE * rows)
            )

        last_plot_id = len(node_combos) - 1
        try:
            with tqdm(
                desc=desc, total=len(node_combos), unit=" iter", ascii=True, leave=False, file=sys.stdout
            ) as pbar:
                for plot_id, nodes in enumerate(node_combos):
                    pbar.update(1)
                    sys.stdout.flush()
                    evaluation_func(
                        *nodes,
                        fig=fig,
                        ax=self._get_ax(ax, rows, cols, plot_id),
                        # only the last subplot is asked to draw a legend
                        show_legend=(plot_id == last_plot_id)
                    )

            if plot_overview:
                path = paths.plots_path / f"{save_file_name}.{PLOT_FORMAT}"
                fig.savefig(path)
        finally:
            # pyplot keeps every figure alive until it is closed explicitly.
            if fig is not None:
                plt.close(fig)

    def _get_nodes(self, pairwise=False):
        """
        Generates nodes or node pairs for evaluation.

        Parameters
        ----------
        pairwise : bool, optional
            If True, generates pairs of surface nodes (for coherence evaluation),
            otherwise generates individual nodes (for power evaluation) (default is False).

        Returns
        -------
        list
            A list of tuples, each containing one or two nodes, depending on the value of `pairwise`.
        """
        if pairwise:
            nodes = list(self.simulation_data_extractor.surface_nodes)
            # every unordered pair exactly once, in order of appearance
            return [(node1, node2) for i, node1 in enumerate(nodes) for node2 in nodes[i + 1:]]
        return [(node,) for node in self.simulation_data_extractor.nodes]

    def _evaluate_peaks(self, node, fig=None, ax=None, show_legend=True):
        """
        Evaluates the presence of alpha peaks (using :py:meth:`_get_peaks`) and
        plots (using :py:meth:`_plot_metric`) detrended power spectrum for a given node.

        Parameters
        ----------
        node : str
            The name of the brain region to evaluate.
        fig : matplotlib.figure.Figure, optional
            The figure object for plotting (default is None).
        ax : matplotlib.axes.Axes, optional
            The axis object for plotting (default is None).
        show_legend : bool, optional
            Accepted for a uniform signature with the other evaluation functions;
            currently ignored, the legend is never drawn for this metric (default is True).
        """
        frequencies, powers, p_values, test_names = self._get_peaks(node)

        yes_peak = "Y"
        no_peak = "N"
        label_addons = {
            key: f": {test_names[key]}, p={p_values[key]:.3f}, {yes_peak if p_values[key] < 0.05 else no_peak}"
            for key in p_values
        }
        mean_p_value = np.mean(list(p_values.values()))

        self._plot_metric(
            f"{notation(node)}: p={mean_p_value:.3f} ({yes_peak if mean_p_value < 0.05 else no_peak})",
            frequencies, powers,
            fig=fig, ax=ax, show_legend=False,
            y_label="Peaks",
            xlim=[self.frequencies[0], self.frequencies[1]],
            ylim=[0, 2e7],
            file_label=f"peaks_{node}",
            label_addons=label_addons
        )

    def _evaluate_fooof(self, node, fig=None, ax=None, show_legend=True):
        """
        Evaluates the presence of peaks (using :py:meth:`_get_fooof_peaks`) and
        plots (using :py:meth:`_plot_metric`) the peaks.

        The plot title reports the percentage of simulations with at least one peak in the
        alpha band (``%in_a``) and the percentage of all peaks that lie in the alpha band (``%of_a``).
        Both are reported as 0 when there is nothing to count.

        Parameters
        ----------
        node : str
            The name of the brain region to evaluate.
        fig : matplotlib.figure.Figure, optional
            The figure object for plotting (default is None).
        ax : matplotlib.axes.Axes, optional
            The axis object for plotting (default is None).
        show_legend : bool, optional
            Accepted for a uniform signature with the other evaluation functions;
            currently ignored, the legend is never drawn for this metric (default is True).
        """
        frequencies, binary_peaks = self._get_fooof_peaks(node)
        alpha = (8, 13)

        upd_binary_peaks = {}
        nr_sims = len(binary_peaks)
        nr_sims_in_alpha = 0
        alpha_count = 0
        total_count = 0
        for sim_counter, (key, values) in enumerate(binary_peaks.items(), start=1):
            values = np.asarray(values)

            curr_alpha_count = np.sum(self._band_values(frequencies, values, alpha))
            alpha_count += curr_alpha_count
            total_count += np.sum(values)
            if curr_alpha_count > 0:
                nr_sims_in_alpha += 1

            # Give every simulation its own height on the scatter plot and hide the non-peaks.
            upd_binary_peaks[key] = np.where(values == 0, np.nan, sim_counter * values)

        percentage_in_alpha = 100 * nr_sims_in_alpha / nr_sims if nr_sims else 0.0
        percentage_of_alpha = 100 * alpha_count / total_count if total_count else 0.0

        self._plot_metric(
            f"{notation(node)}: %in_a={percentage_in_alpha:.1f}, %of_a={percentage_of_alpha:.1f}",
            frequencies, upd_binary_peaks,
            plot_type="scatter",
            fig=fig, ax=ax, show_legend=False,
            y_label="Peak presence",
            file_label=f"fooof_{node}"
        )

    @staticmethod
    def _band_values(frequencies, values, band):
        """
        Selects the entries of `values` whose frequency lies in the half-open range ``[band[0], band[1])``.

        Parameters
        ----------
        frequencies : numpy.ndarray or None
            The frequency of every entry in `values`.
        values : numpy.ndarray
            The values to select from.
        band : tuple
            The lower (inclusive) and upper (exclusive) bound of the band.

        Returns
        -------
        numpy.ndarray
            The selected values. If `frequencies` is None or does not have the same shape as `values`,
            the bounds are used as positions instead (``values[band[0]:band[1]]``).
        """
        low, high = band
        if frequencies is not None:
            freqs = np.asarray(frequencies)
            if freqs.shape == values.shape:
                return values[(freqs >= low) & (freqs < high)]
        return values[low:high]

    def _evaluate_power_node(self, node, fig=None, ax=None, show_legend=True):
        """
        Evaluates (using :py:meth:`_get_simulated_power`) and
        plots (using :py:meth:`_plot_metric`) the power spectrum for a given node.

        Parameters
        ----------
        node : str
            The name of the brain region to evaluate.
        fig : matplotlib.figure.Figure, optional
            The figure object for plotting (default is None).
        ax : matplotlib.axes.Axes, optional
            The axis object for plotting (default is None).
        show_legend : bool, optional
            If True, shows the legend on the plot; always shown when `ax` is None (default is True).
        """
        sim_frequencies, sim_powers = self._get_simulated_power(node)

        self._plot_metric(
            f"{notation(node)}",
            sim_frequencies, sim_powers,
            fig=fig, ax=ax,
            show_legend=show_legend if ax is not None else True,
            y_label="Power",
            xlim=[self.frequencies[0], self.frequencies[1]],
            ylim=[0, 2e7],
            file_label=f"power_{node}"
        )

    def _evaluate_coherence_node_pair(self, node1, node2, fig=None, ax=None, show_legend=True):
        """
        Evaluates (using :py:meth:`_get_simulated_coherences`) and
        plots (using :py:meth:`_plot_metric`) the coherence between a pair of nodes.

        Parameters
        ----------
        node1 : str
            The name of the first brain region (node).
        node2 : str
            The name of the second brain region (node).
        fig : matplotlib.figure.Figure, optional
            The figure object for plotting (default is None).
        ax : matplotlib.axes.Axes, optional
            The axis object for plotting (default is None).
        show_legend : bool, optional
            If True, shows the legend on the plot; always shown when `ax` is None (default is True).
        """
        sim_frequencies_coherence, sim_coherences = self._get_simulated_coherences(node1, node2)

        self._plot_metric(
            f"{notation(node1)}{notation(node2)}",
            sim_frequencies_coherence, sim_coherences,
            fig=fig, ax=ax,
            show_legend=show_legend if ax is not None else True,
            y_label="Coherence",
            xlim=[1, self.frequencies[1]],
            ylim=[0, 0.6],
            file_label=f"coherence_{node1}_{node2}"
        )

    def _get_fooof_peaks(self, node):
        """
        Computes the peaks in the power spectrum for a given node using :py:class:`FooofTester`.

        Parameters
        ----------
        node : str
            The name of the brain region for which to compute the peaks.

        Returns
        -------
        tuple
            A tuple containing:
             - frequencies (numpy.ndarray): The array of frequencies of the last simulation
               (None if there are no simulations).
             - all_binary_peaks (dict): A dictionary of binary peaks for each simulation, keyed by simulation name.
        """
        extractor = self.simulation_data_extractor
        frequencies = None
        all_binary_peaks = {}
        for key in extractor.simulation_names:
            # entry [1] holds the powers; entry [0] is not used here
            powers = extractor.simulations_power_per_node[node][key][1]
            frequencies = extractor.simulations_info[key].frequencies
            fooof_tester = FooofTester(frequencies=frequencies)
            all_binary_peaks[key] = fooof_tester.get_peaks_positions(powers)

        return frequencies, all_binary_peaks

    def _get_peaks(self, node):
        """
        Computes the peaks in the power spectrum for a given node using :py:class:`PeakTester`.

        Parameters
        ----------
        node : str
            The name of the brain region for which to compute the peaks.

        Returns
        -------
        tuple
            A tuple containing:
             - frequencies (numpy.ndarray): The array of frequencies returned by the test of the last
               simulation (None if there are no simulations).
             - powers (dict): A dictionary of detrended power spectra, keyed by simulation name.
             - p_values (dict): A dictionary of p-values for the peak test, keyed by simulation name.
             - test_names (dict): A dictionary of test names for the peak test, keyed by simulation name.
        """
        extractor = self.simulation_data_extractor
        frequencies = None
        powers = {}
        p_values = {}
        test_names = {}
        for key in extractor.simulation_names:
            epoched_powers = extractor.simulations_epoched_power_per_node[node][key]
            peak_tester = PeakTester(
                frequencies=extractor.simulations_info[key].frequencies,
                peaks_range=(8, 13),
                others_range=(13, 30)
            )
            frequencies, powers[key], p_values[key], test_names[key] = peak_tester.compute_test_result(
                key, epoched_powers
            )

        return frequencies, powers, p_values, test_names

    def _get_simulated_power(self, node):
        """
        Retrieves the simulated power spectrum for a given node.

        Parameters
        ----------
        node : str
            The name of the brain region for which to retrieve the simulated power spectrum.

        Returns
        -------
        tuple
            A tuple containing:
             - frequencies (numpy.ndarray): The array of frequencies of the first simulation
               (None if there are no simulations).
             - powers (dict): A dictionary of simulated power spectra, keyed by simulation name.
        """
        simulations = self.simulation_data_extractor.simulations_power_per_node[node]
        names = list(self.simulation_data_extractor.simulation_names)
        if not names:
            # zip(*[]) cannot be unpacked into two names
            return None, {}

        frequencies, powers = zip(*[simulations[key] for key in names])
        return frequencies[0], dict(zip(names, powers))

    def _get_simulated_coherences(self, node1, node2):
        """
        Computes the simulated coherence between a pair of nodes for each simulation using
        :py:meth:`macro_eeg_model.evaluation.coherence_computer.CoherenceComputer.compute_coherence_matched`.

        Parameters
        ----------
        node1 : str
            The name of the first brain region.
        node2 : str
            The name of the second brain region.

        Returns
        -------
        tuple
            A tuple containing:
             - frequencies (numpy.ndarray): The array of frequencies for coherence of the last simulation
               (None if there are no simulations).
             - coherences (dict): A dictionary of simulated coherence values, keyed by simulation name.
        """
        extractor = self.simulation_data_extractor
        simulations1 = extractor.simulations_data_per_node[node1]
        simulations2 = extractor.simulations_data_per_node[node2]

        freq_coh_sim = None
        coherences = {}
        for key in extractor.simulation_names:
            coherence_computer_sim = CoherenceComputer(fs=extractor.sample_rates[key])
            freq_coh_sim, coherences[key] = coherence_computer_sim.compute_coherence_matched(
                simulations1[key], simulations2[key]
            )

        return freq_coh_sim, coherences

    def _plot_metric(
            self, title, sim_frequencies, sim_data, plot_type="line",
            fig=None, ax=None, show_legend=True,
            y_label=None, xlim=None, ylim=None, file_label=None, label_addons=None
    ):
        """
        Plots a metric (e.g., coherence or power) of data using :py:meth:`_plot_simulated_data`.

        If no figure or no axis is given, a new figure is created, saved in the plots directory
        and closed again.

        Parameters
        ----------
        title : str
            The title of the plot.
        sim_frequencies : numpy.ndarray
            The array of frequencies for the simulated data.
        sim_data : dict
            The simulated data (e.g., power or coherence) to plot, keyed by simulation name.
        plot_type : str, optional
            The type of plot to create (default is "line").
            Currently, "line", "scatter", "mean_std" are supported.
        fig : matplotlib.figure.Figure, optional
            The figure object for plotting (default is None).
        ax : matplotlib.axes.Axes, optional
            The axis object for plotting (default is None).
        show_legend : bool, optional
            If True, shows the legend on the plot (default is True).
        y_label : str, optional
            The label for the y-axis; currently not drawn (default is None).
        xlim : list, optional
            The x-axis limits for the plot (default is None).
        ylim : list, optional
            The y-axis limits for the plot (default is None).
        file_label : str, optional
            The file name label for saving the plot (default is None).
        label_addons : dict, optional
            The dictionary of label addons to append to the name of the data (default is None).
        """
        independent = fig is None or ax is None
        if independent:
            fig, ax = plt.subplots(figsize=(PLOT_SIZE * 2, PLOT_SIZE))

        if label_addons is None:
            label_addons = {key: "" for key in sim_data}

        self._plot_simulated_data(ax, sim_frequencies, sim_data, label_addons, plot_type)

        ax.set_title(title)
        ax.set_xlim(xlim)
        ax.set_ylim(ylim)
        ax.grid(which='both')
        ax.set_axisbelow(True)

        if show_legend:
            handles, _ = ax.get_legend_handles_labels()
            ax.legend(handles=handles, loc='upper right')

        if independent:
            path = paths.plots_path / f"{file_label}_{title}.{PLOT_FORMAT}"
            try:
                fig.savefig(path)
            finally:
                # One figure is created per node (pair); close it so figures do not pile up in pyplot.
                plt.close(fig)

    @staticmethod
    def _plot_simulated_data(ax, frequencies, data, label_addons, plot_type="line"):
        """
        Plots the simulated EEG data on a given axis.

        Parameters
        ----------
        ax : matplotlib.axes.Axes
            The axis object for plotting.
        frequencies : numpy.ndarray
            The array of frequencies for the simulated data.
        data : dict
            The simulated data (e.g., power or coherence) to plot, keyed by simulation name.
        label_addons : dict
            The dictionary of label addons to append to the name of the data.
        plot_type : str, optional
            The type of plot to create (default is "line").
            Currently, "line", "scatter", "mean_std" are supported; any other value is drawn as "line".
        """
        if plot_type == "mean_std":
            # Summarise all simulations as one mean curve with a band of one standard deviation around it.
            stacked = list(data.values())
            mean_data = np.mean(stacked, axis=0)
            std_data = np.std(stacked, axis=0)
            ax.fill_between(
                frequencies, mean_data - std_data, mean_data + std_data,
                label="std", color=COLORS[1], alpha=0.5
            )
            ax.plot(frequencies, mean_data, label="mean", color=COLORS[0], alpha=1.0)
            return

        draw = ax.scatter if plot_type == "scatter" else ax.plot
        for i, (name, d) in enumerate(data.items()):
            # colors are reused cyclically when there are more simulations than colors
            draw(frequencies, d, label=f"{name}{label_addons[name]}", color=COLORS[i % len(COLORS)], alpha=1.0)

    @staticmethod
    def _get_ax(ax, rows, cols, i):
        """
        Helper function to get the appropriate subplot axis.

        Parameters
        ----------
        ax : numpy.ndarray or matplotlib.axes.Axes or None
            The axis object(s) for subplots, as returned by ``matplotlib.pyplot.subplots``;
            None if no overview figure is used.
        rows : int
            The number of rows in the subplot grid.
        cols : int
            The number of columns in the subplot grid.
        i : int
            The index of the current plot (row-major).

        Returns
        -------
        matplotlib.axes.Axes or None
            The appropriate axis object for the current subplot, or None if `ax` is None.
        """
        # Without an overview figure there is nothing to index, whatever the grid shape.
        if ax is None:
            return None
        if rows == 1 and cols == 1:
            return ax
        if rows == 1 or cols == 1:
            # a single row or column comes back from plt.subplots as a 1-D array
            return ax[i]
        return ax[i // cols, i % cols]