# Source code for macro_eeg_model.evaluation.simulation_data_extractor

# external imports
import numpy as np

# local imports
from macro_eeg_model.utils.paths import paths
from macro_eeg_model.simulation.simulation_info import SimulationInfo
from macro_eeg_model.simulation.data_processor import DataProcessor


class SimulationDataExtractor:
    """
    Loads every saved simulation and organizes its data per node.

    On construction, the class loads the simulation information found in the
    output directory, derives raw-data, power-spectrum and epoched-power
    views from it, and regroups each view so that it can be looked up first
    by node name and then by simulation name.

    Attributes
    ----------
    nodes : numpy.ndarray
        The node names shared by all loaded simulations (``None`` if no
        simulation was found).
    surface_nodes : list
        The entries of ``nodes`` that are not ``"thalamus"``.
    simulations_info : dict
        The loaded simulation information objects keyed by simulation name.
    simulation_names : list
        The simulation names, sorted.
    sample_rates : dict
        A dictionary mapping simulation names to their sample rates.
    simulations_data_per_node : dict
        Segmented simulation data, keyed by surface node and then simulation name.
    simulations_power_per_node : dict
        ``(frequencies, power)`` tuples, keyed by node and then simulation name.
    simulations_epoched_power_per_node : dict
        Per-epoch power spectra, keyed by node and then simulation name.
    """

    def __init__(self):
        """
        Initializes the SimulationDataExtractor by loading and processing the
        simulation data using the methods of this class.
        """
        # Both are filled in by _get_simulations_info from the first simulation found.
        self.nodes = None
        self.surface_nodes = None

        self.simulations_info, self.sample_rates = self._get_simulations_info()
        self.simulation_names = sorted(self.simulations_info)

        processed_simulations_data = self._get_processed_simulations_data(self.simulations_info)
        self.simulations_data_per_node = self._get_simulations_data_per_node(processed_simulations_data)

        processed_simulations_power = self._get_processed_simulations_power(self.simulations_info)
        self.simulations_power_per_node = self._get_simulations_power_per_node(processed_simulations_power)

        processed_simulations_epoched_power = self._get_processed_simulations_epoched_power(self.simulations_info)
        self.simulations_epoched_power_per_node = self._get_simulations_epoched_power_per_node(
            processed_simulations_epoched_power
        )

    @staticmethod
    def _regroup_by_node(processed_by_simulation, nodes):
        """
        Inverts a ``{simulation name: {node: value}}`` mapping.

        Parameters
        ----------
        processed_by_simulation : dict
            Values organized by simulation name and then node.
        nodes : iterable
            The node names to include in the result.

        Returns
        -------
        dict
            The same values organized by node and then simulation name.
        """
        return {
            node: {
                simulation_name: per_node[node]
                for simulation_name, per_node in processed_by_simulation.items()
            }
            for node in nodes
        }

    def _get_simulations_data_per_node(self, processed_simulations_data):
        """
        Organizes the processed simulation data by node and then simulation name.

        Only the surface nodes are included.

        Parameters
        ----------
        processed_simulations_data : dict
            The dictionary containing processed simulation data organized by
            simulation name and then node.

        Returns
        -------
        dict
            A dictionary organizing the simulation data by node and then simulation name.
        """
        return self._regroup_by_node(processed_simulations_data, self.surface_nodes)

    def _get_simulations_epoched_power_per_node(self, processed_simulations_epoched_power):
        """
        Organizes the processed epoched power spectra by node and then simulation name.

        Parameters
        ----------
        processed_simulations_epoched_power : dict
            The dictionary containing processed epoched power spectra organized by
            simulation name and then node.

        Returns
        -------
        dict
            A dictionary organizing the epoched power spectra by node and then simulation name.
        """
        return self._regroup_by_node(processed_simulations_epoched_power, self.nodes)

    def _get_processed_simulations_epoched_power(self, simulations_info, epoch_len=1000):
        """
        Computes per-epoch power spectra, organized by simulation name and then node.

        Each node's signal (one column of ``simulation_data``) is cut into
        consecutive epochs of ``epoch_len`` samples; trailing samples that do
        not fill a whole epoch are discarded. The power of an epoch is
        ``abs(fft(epoch) / epoch_len) ** 2``.

        Parameters
        ----------
        simulations_info : dict
            A dictionary containing simulation information objects.
        epoch_len : int, optional
            The length of each epoch in samples (default is 1000). This only
            corresponds to milliseconds when the data is sampled at 1000 Hz.

        Returns
        -------
        dict
            A dictionary mapping simulation name, then node, to an array of
            shape ``(nr epochs, epoch_len)``.
        """
        processed_simulations_epoched_power = {}

        for simulation_name, simulation_info in simulations_info.items():
            simulation_data = simulation_info.simulation_data

            epoched_power_per_node = {}
            for i, node in enumerate(self.nodes):
                node_simulation_data = simulation_data[:, i]

                # segment data into whole epochs, dropping the incomplete tail
                nr_epochs = len(node_simulation_data) // epoch_len
                node_epoched_data = np.reshape(
                    node_simulation_data[:nr_epochs * epoch_len], (nr_epochs, epoch_len)
                )

                # one FFT call over the epoch axis instead of a Python loop per epoch
                fourier = np.fft.fft(node_epoched_data, axis=1) / epoch_len
                epoched_power_per_node[node] = np.asarray(np.abs(fourier) ** 2, dtype=float)

            processed_simulations_epoched_power[simulation_name] = epoched_power_per_node

        return processed_simulations_epoched_power

    def _get_simulations_power_per_node(self, processed_simulations_power):
        """
        Organizes the processed power spectra by node and then simulation name.

        Parameters
        ----------
        processed_simulations_power : dict
            The dictionary containing processed power spectra organized by
            simulation name and then node.

        Returns
        -------
        dict
            A dictionary organizing the power spectra by node and then simulation name.
        """
        return self._regroup_by_node(processed_simulations_power, self.nodes)

    def _get_processed_simulations_power(self, simulations_info):
        """
        Organizes the stored power spectra by simulation name and then node.

        Parameters
        ----------
        simulations_info : dict
            A dictionary containing simulation information objects.

        Returns
        -------
        dict
            A dictionary mapping simulation name, then node, to a
            ``(frequencies, power)`` tuple of arrays.
        """
        processed_simulations_power = {}

        for simulation_name, simulation_info in simulations_info.items():
            simulations_frequencies = simulation_info.frequencies
            # swap the first two axes so that the power can be indexed by node position
            simulations_power = np.swapaxes(simulation_info.power, 0, 1)

            processed_simulations_power[simulation_name] = {
                # every node gets its own copy of the frequency axis
                node: (np.array(simulations_frequencies), np.array(simulations_power[i]))
                for i, node in enumerate(self.nodes)
            }

        return processed_simulations_power

    def _get_processed_simulations_data(self, simulations_info):
        """
        Segments the raw simulation data and organizes it by simulation name and then node.

        Parameters
        ----------
        simulations_info : dict
            A dictionary containing simulation information objects.

        Returns
        -------
        dict
            A dictionary organizing the processed simulation data by simulation
            name and then surface node.
        """
        processed_simulations_data = {}
        node_positions = {node: i for i, node in enumerate(self.nodes)}

        for simulation_name, simulation_info in simulations_info.items():
            simulation_data = simulation_info.simulation_data
            sample_rate = simulation_info.sample_rate
            nr_nodes = len(simulation_info.nodes)

            data_processor = DataProcessor()
            data = data_processor.segment_data(simulation_data, sample_rate=sample_rate, nr_nodes=nr_nodes)
            # data has shape (nr samples, nr nodes, nr epochs)

            # reorder data to (nr nodes, nr epochs, nr samples)
            data = np.transpose(data, (1, 2, 0))

            # merge every nr_secs consecutive epochs into one longer window;
            # a trailing group with fewer than nr_secs epochs is dropped so the
            # reshape cannot fail on an odd epoch count
            nr_secs = 2
            nr_windows = data.shape[1] // nr_secs
            data = data[:, :nr_windows * nr_secs, :]
            # NOTE(review): assumes each segmented epoch holds sample_rate samples - confirm against DataProcessor
            data = np.reshape(data, (data.shape[0], nr_windows, nr_secs * sample_rate))

            if data.shape[0] == len(self.nodes):
                # rows cover all nodes: look each surface node up by its position
                # in self.nodes so labels stay correct even if "thalamus" is not last
                rows = [node_positions[node] for node in self.surface_nodes]
            else:
                # rows do not cover all nodes: keep the positional pairing
                rows = range(len(self.surface_nodes))

            processed_simulations_data[simulation_name] = {
                node: np.array(data[row]) for node, row in zip(self.surface_nodes, rows)
            }

        return processed_simulations_data

    def _get_simulations_info(self):
        """
        Loads the simulation information saved in each directory of the output
        path (``paths.output_path`` from :py:mod:`macro_eeg_model.utils.paths`)
        using ``SimulationInfo.load_simulation_info``, and checks that all
        simulations use the same node names.

        The first simulation loaded (in sorted directory order) defines
        ``self.nodes`` and ``self.surface_nodes``.

        Returns
        -------
        tuple
            A tuple containing:
            - simulations_info (dict): A dictionary of SimulationInfo objects keyed by simulation name.
            - sample_rates (dict): A dictionary of sample rates keyed by simulation name.

        Raises
        ------
        AssertionError
            If the nodes in any simulation do not match the expected node names.
        """
        simulations_info = {}
        sample_rates = {}

        # sorted so that the reference simulation does not depend on filesystem order
        for folder in sorted(paths.output_path.iterdir()):
            if not folder.is_dir():
                continue

            simulation_info = SimulationInfo(output_dir=folder)
            simulation_info.load_simulation_info()

            if self.nodes is None:
                self.nodes = simulation_info.nodes
                self.surface_nodes = self._get_surface_nodes(self.nodes)

            # raised explicitly (rather than via `assert`) so the check survives
            # `python -O`; array_equal also copes with lists and length mismatches
            if not np.array_equal(simulation_info.nodes, self.nodes):
                raise AssertionError(
                    f"Nodes do not match for simulation {folder.name}. Expected {self.nodes}, got {simulation_info.nodes}"
                )

            simulations_info[folder.name] = simulation_info
            sample_rates[folder.name] = simulation_info.sample_rate

        return simulations_info, sample_rates

    @staticmethod
    def _get_surface_nodes(nodes):
        """
        Returns the surface nodes from the given node names.

        Currently, the surface nodes are all nodes except the thalamus.

        Parameters
        ----------
        nodes : iterable
            The node names.

        Returns
        -------
        list
            A list of surface node names, in their original order.
        """
        return [node for node in nodes if node != "thalamus"]