Source code for macro_eeg_model.evaluation.simulation_data_extractor
# external imports
import numpy as np
# local imports
from macro_eeg_model.utils.paths import paths
from macro_eeg_model.simulation.simulation_info import SimulationInfo
from macro_eeg_model.simulation.data_processor import DataProcessor
[docs]
class SimulationDataExtractor:
    """
    The SimulationDataExtractor class is responsible for extracting and processing simulation data.
    It organizes the data by nodes and simulations, allowing for easy access to both raw and processed data.

    Attributes
    ----------
    nodes : numpy.ndarray
        An array of node names used in the simulations. Taken from the first loaded simulation;
        every other simulation must use the same nodes.
    surface_nodes : list
        The node names without the thalamus (see :py:meth:`_get_surface_nodes`).
    simulations_info : dict
        A dictionary mapping simulation names (output folder names) to their loaded SimulationInfo objects.
    simulation_names : list
        A sorted list of simulation names.
    sample_rates : dict
        A dictionary mapping simulation names to their corresponding sample rates.
    simulations_data_per_node : dict
        A dictionary organizing the processed (segmented) simulation data by surface node and then
        simulation name.
    simulations_power_per_node : dict
        A dictionary organizing the (frequencies, power) pairs by node and then simulation name.
    simulations_epoched_power_per_node : dict
        A dictionary organizing the processed epoched power spectra by node and then simulation name.
    """

    def __init__(self):
        """
        Initializes the SimulationDataExtractor by loading every simulation found in the output
        directory and pre-computing per-node views of its segmented data, power spectra and
        epoched power spectra.

        Raises
        ------
        FileNotFoundError
            If the output directory contains no simulation folders.
        AssertionError
            If the simulations do not all use the same nodes.
        """
        # Both are filled in by _get_simulations_info from the first simulation it loads.
        self.nodes = None
        self.surface_nodes = None
        self.simulations_info, self.sample_rates = self._get_simulations_info()
        self.simulation_names = sorted(self.simulations_info)

        processed_simulations_data = self._get_processed_simulations_data(self.simulations_info)
        self.simulations_data_per_node = self._get_simulations_data_per_node(processed_simulations_data)

        processed_simulations_power = self._get_processed_simulations_power(self.simulations_info)
        self.simulations_power_per_node = self._get_simulations_power_per_node(processed_simulations_power)

        processed_simulations_epoched_power = self._get_processed_simulations_epoched_power(self.simulations_info)
        self.simulations_epoched_power_per_node = self._get_simulations_epoched_power_per_node(
            processed_simulations_epoched_power
        )

    @staticmethod
    def _regroup_by_node(processed_per_simulation, nodes):
        """
        Inverts a ``{simulation_name: {node: value}}`` mapping into ``{node: {simulation_name: value}}``.

        Parameters
        ----------
        processed_per_simulation : dict
            Values organized by simulation name and then node.
        nodes : iterable
            The nodes to include; each must be a key of every inner dictionary.

        Returns
        -------
        dict
            The same values organized by node and then simulation name. Simulation names keep the
            iteration order of ``processed_per_simulation``.
        """
        return {
            node: {
                simulation_name: per_node[node]
                for simulation_name, per_node in processed_per_simulation.items()
            }
            for node in nodes
        }

    def _get_simulations_data_per_node(self, processed_simulations_data):
        """
        Organizes the processed simulation data by surface node and then simulation name.

        Parameters
        ----------
        processed_simulations_data : dict
            The dictionary containing processed simulation data organized by simulation name and then node.

        Returns
        -------
        dict
            A dictionary organizing the simulation data by surface node and then simulation name.
        """
        return self._regroup_by_node(processed_simulations_data, self.surface_nodes)

    def _get_simulations_epoched_power_per_node(self, processed_simulations_epoched_power):
        """
        Organizes the processed epoched power spectra by node and then simulation name.

        Parameters
        ----------
        processed_simulations_epoched_power : dict
            The dictionary containing processed epoched power spectra organized by simulation name and then node.

        Returns
        -------
        dict
            A dictionary organizing the epoched power spectra by node and then simulation name.
        """
        return self._regroup_by_node(processed_simulations_epoched_power, self.nodes)

    @staticmethod
    def _get_epoched_power(node_data, epoch_samples):
        """
        Splits a 1-D signal into non-overlapping epochs and returns each epoch's power spectrum.

        Parameters
        ----------
        node_data : numpy.ndarray
            The 1-D signal of one node.
        epoch_samples : int
            The number of samples per epoch (must be at least 1).

        Returns
        -------
        numpy.ndarray
            A float array of shape (nr epochs, epoch_samples) holding ``|FFT(epoch) / epoch_samples| ** 2``.
            Trailing samples that do not fill a whole epoch are dropped.
        """
        nr_epochs = len(node_data) // epoch_samples
        if nr_epochs == 0:
            # Signal shorter than one epoch: keep the (0, epoch_samples) shape.
            return np.zeros((0, epoch_samples))
        epochs = np.reshape(node_data[: nr_epochs * epoch_samples], (nr_epochs, epoch_samples))
        # One FFT over all epochs at once (row-wise), normalized by the epoch length.
        fourier = np.fft.fft(epochs, axis=1) / epoch_samples
        return (np.abs(fourier) ** 2).astype(float, copy=False)

    def _get_processed_simulations_epoched_power(self, simulations_info, epoch_len=1000):
        """
        Processes and organizes the epoched power spectra data by simulation name and then node.

        Parameters
        ----------
        simulations_info : dict
            A dictionary containing simulation information objects.
        epoch_len : int or float, optional
            The length of each epoch in milliseconds (default is 1000). It is converted to a number
            of samples using each simulation's ``sample_rate``.

        Returns
        -------
        dict
            A dictionary organizing the processed epoched power spectra data by simulation name and
            then node. Each value has shape (nr epochs, nr samples per epoch).

        Raises
        ------
        ValueError
            If ``epoch_len`` amounts to less than one sample at a simulation's sample rate.
        """
        processed_simulations_epoched_power = {}
        for simulation_name, simulation_info in simulations_info.items():
            # epoch_len is in milliseconds; converting to samples makes the epoch duration (and hence the
            # frequency resolution) independent of the sample rate.
            epoch_samples = int(round(epoch_len * simulation_info.sample_rate / 1000))
            if epoch_samples < 1:
                raise ValueError(
                    f"epoch_len={epoch_len} ms is shorter than one sample at "
                    f"{simulation_info.sample_rate} Hz (simulation {simulation_name})"
                )
            simulation_data = simulation_info.simulation_data
            # Column i of simulation_data belongs to self.nodes[i].
            processed_simulations_epoched_power[simulation_name] = {
                node: self._get_epoched_power(simulation_data[:, i], epoch_samples)
                for i, node in enumerate(self.nodes)
            }
        return processed_simulations_epoched_power

    def _get_simulations_power_per_node(self, processed_simulations_power):
        """
        Organizes the processed power spectra by node and then simulation name.

        Parameters
        ----------
        processed_simulations_power : dict
            The dictionary containing processed power spectra organized by simulation name and then node.

        Returns
        -------
        dict
            A dictionary organizing the power spectra by node and then simulation name.
        """
        return self._regroup_by_node(processed_simulations_power, self.nodes)

    def _get_processed_simulations_power(self, simulations_info):
        """
        Processes and organizes the power spectra data by simulation name and then node.

        Parameters
        ----------
        simulations_info : dict
            A dictionary containing simulation information objects.

        Returns
        -------
        dict
            A dictionary organizing the processed power spectra data by simulation name and then node.
            Each value is a ``(frequencies, power)`` tuple of freshly copied arrays.
        """
        processed_simulations_power = {}
        for simulation_name, simulation_info in simulations_info.items():
            simulations_frequencies = simulation_info.frequencies
            # Nodes are along axis 1 of simulation_info.power; swap them to axis 0 so they can be indexed directly.
            simulations_power = np.swapaxes(simulation_info.power, 0, 1)
            processed_simulations_power[simulation_name] = {
                node: (np.array(simulations_frequencies), np.array(simulations_power[i]))
                for i, node in enumerate(self.nodes)
            }
        return processed_simulations_power

    def _get_node_rows(self, nr_rows):
        """
        Maps each node name to its row in segmented data that has ``nr_rows`` node rows.

        Parameters
        ----------
        nr_rows : int
            The number of node rows in the segmented data.

        Returns
        -------
        dict
            A dictionary mapping node names to row indices.
        """
        if nr_rows == len(self.nodes):
            # Rows follow self.nodes, so each surface node must be looked up by its position among
            # all nodes (the thalamus need not be last).
            return {node: row for row, node in enumerate(self.nodes)}
        # NOTE(review): fallback assumes the rows contain only the surface nodes, in order -
        # confirm against DataProcessor.segment_data.
        return {node: row for row, node in enumerate(self.surface_nodes)}

    def _get_processed_simulations_data(self, simulations_info, nr_secs=2):
        """
        Processes and organizes the raw simulation data by simulation name and then surface node.

        Each simulation is segmented with ``DataProcessor.segment_data``. Every run of ``nr_secs``
        consecutive segments is then concatenated into one longer segment.

        Parameters
        ----------
        simulations_info : dict
            A dictionary containing simulation information objects.
        nr_secs : int, optional
            The number of consecutive segments to concatenate (default is 2). If segment_data returns
            one-second segments (presumably - TODO confirm), this is the resulting segment length in seconds.

        Returns
        -------
        dict
            A dictionary organizing the processed simulation data by simulation name and then surface node.
            Each value has shape (nr segments // nr_secs, nr_secs * nr samples per segment). Trailing
            segments that cannot fill a complete group are dropped.

        Raises
        ------
        ValueError
            If ``nr_secs`` is smaller than 1.
        """
        if nr_secs < 1:
            raise ValueError(f"nr_secs must be at least 1, got {nr_secs}")
        processed_simulations_data = {}
        for simulation_name, simulation_info in simulations_info.items():
            data_processor = DataProcessor()
            data = data_processor.segment_data(
                simulation_info.simulation_data,
                sample_rate=simulation_info.sample_rate,
                nr_nodes=len(simulation_info.nodes),
            )
            # data has shape (nr samples, nr nodes, nr epochs);
            # transpose it to (nr nodes, nr epochs, nr samples)
            data = np.transpose(data, (1, 2, 0))
            nr_node_rows, nr_epochs, nr_samples = data.shape
            nr_groups = nr_epochs // nr_secs
            # Concatenate each run of nr_secs consecutive epochs; a leftover partial group is dropped
            # instead of making the reshape fail.
            data = np.reshape(
                data[:, : nr_groups * nr_secs, :],
                (nr_node_rows, nr_groups, nr_secs * nr_samples),
            )
            node_rows = self._get_node_rows(nr_node_rows)
            processed_simulations_data[simulation_name] = {
                node: np.array(data[node_rows[node]]) for node in self.surface_nodes
            }
        return processed_simulations_data

    def _get_simulations_info(self):
        """
        Loads simulation information from saved files in the directories within the
        output path (see :py:class:`macro_eeg_model.utils.paths.Paths`) using
        :py:meth:`macro_eeg_model.simulation.simulation_info.SimulationInfo.load_simulation_info`,
        and checks for consistency in node names.

        The first simulation loaded sets ``self.nodes`` and ``self.surface_nodes``.

        Returns
        -------
        tuple
            A tuple containing:
            - simulations_info (dict): A dictionary of SimulationInfo objects keyed by simulation name.
            - sample_rates (dict): A dictionary of sample rates keyed by simulation name.

        Raises
        ------
        AssertionError
            If the nodes in any simulation do not match the expected node names.
        FileNotFoundError
            If the output path contains no simulation directories.
        """
        simulations_info = {}
        sample_rates = {}
        for folder in paths.output_path.iterdir():
            if not folder.is_dir():
                continue
            # iterdir() already yields paths rooted at paths.output_path
            simulation_info = SimulationInfo(output_dir=folder)
            simulation_info.load_simulation_info()
            if self.nodes is None:
                self.nodes = simulation_info.nodes
                self.surface_nodes = self._get_surface_nodes(self.nodes)
            # An explicit raise (not `assert`) so the check survives `python -O`. np.array_equal also
            # handles plain lists and node sequences of different lengths.
            if not np.array_equal(simulation_info.nodes, self.nodes):
                raise AssertionError(
                    f"Nodes do not match for simulation {folder.name}. Expected {self.nodes}, got {simulation_info.nodes}"
                )
            simulations_info[folder.name] = simulation_info
            sample_rates[folder.name] = simulation_info.sample_rate
        if not simulations_info:
            raise FileNotFoundError(f"No simulation output directories found in {paths.output_path}")
        return simulations_info, sample_rates

    @staticmethod
    def _get_surface_nodes(nodes):
        """
        Returns the surface nodes from the given list of nodes.
        Currently, the surface nodes are all nodes except the thalamus.

        Parameters
        ----------
        nodes : list
            A list of node names.

        Returns
        -------
        list
            A list of surface node names, in the original order.
        """
        return [node for node in nodes if node != "thalamus"]