Source code for macro_eeg_model.config.model_config

# standard imports
from typing import Optional

# external imports
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

# local imports
from macro_eeg_model.utils.plotting_setup import notation, PLOT_SIZE, PLOT_FORMAT, COLOR_MAP
from macro_eeg_model.simulation.delay_calculator import DelayCalculator
from .connectivity_model import ConnectivityModel


[docs] class ModelConfig: """ A class to configure parameters and model the connectivity between brain nodes, including the distances, connectivity weights, and the generation of delays. Attributes ---------- nodes : list[str] The list of processed nodes used in the model. nr_nodes : int The total number of nodes in the model. relay_station : str, optional The relay station node name, if any. sample_rate : int The sampling rate of the model, in Hz. nr_lags : int The number of time lags calculated based on the sample rate and the total time (ms) in lags. t_secs : int The total time of the simulation in seconds. t_burnit : int The burn-in time for the simulation, in seconds. noise_color : str The color of the noise to be used in the simulation. std_noise : int The standard deviation of the noise to be used in the simulation. distances : numpy.ndarray A matrix containing the distances between the nodes. connectivity_weights : numpy.ndarray A matrix containing the connectivity weights between the nodes. delay_calculator : DelayCalculator An instance of the :py:class:`src.simulation.delay_calculator.DelayCalculator` class used to calculate delay distributions. _dist_shape : float The shape parameter for the delay distribution (xi in GEV distribution). _dist_scale : float The scale parameter for the delay distribution (sigma in GEV distribution). _dist_location : float The location parameter for the delay distribution (mu in GEV distribution). _truncation_percentile : float The percentile at which to truncate the delay distribution. """
[docs] def __init__( self, nodes: list[str], relay_station: Optional[str], sample_rate: int, t_lags: int, t_secs: int, t_burnit: int, noise_color: str, std_noise: int, dist_shape: float, dist_scale: float, dist_location: float, dist_trunc_percent: float, custom_connectivity=False ): """ Initializes the ModelConfig with specified parameters for nodes, connectivity, simulation, and delay distribution. Parameters ---------- nodes : list[str] The list of nodes to be used in the connectivity model. relay_station : str, optional The relay station name, if any. sample_rate : int The sampling rate of the model, in Hz. t_lags : int The total time in lags for the simulation. t_secs : int The total time of the simulation in seconds. t_burnit : int The burn-in time for the simulation, in seconds. noise_color : str The color of the noise to be used in the simulation. std_noise : int The standard deviation of the noise to be used in the simulation. dist_shape : float The shape parameter for the delay distribution (xi in GEV distribution). dist_scale : float The scale parameter for the delay distribution (sigma in GEV distribution). dist_location : float The location parameter for the delay distribution (mu in GEV distribution). dist_trunc_percent : float The percentile at which to truncate the delay distribution. custom_connectivity : bool, optional If True, use custom connectivity weights from a pre-specified file. """ # given params: # nodes self.relay_station = relay_station self.sample_rate = sample_rate x_sample_coeff = 1000.0 / self.sample_rate self.nr_lags = int(t_lags / x_sample_coeff) self.t_secs = t_secs self.t_burnit = t_burnit self.noise_color = noise_color self.std_noise = std_noise # generating self._dist_shape, self._dist_scale, self._dist_location = dist_shape, dist_scale, dist_location self._truncation_percentile = dist_trunc_percent self.delay_calculator = DelayCalculator( shape_param=dist_shape, scale_param=dist_scale, location_param=dist_location, truncation_percentile=dist_trunc_percent ) connectivity_model = ConnectivityModel(given_nodes=nodes, relay_station=self.relay_station) connectivity_model.set_connectivity(custom_connectivity=custom_connectivity) self.distances = connectivity_model.distances self.connectivity_weights = connectivity_model.connectivity_weights self.nodes = connectivity_model.nodes self.nr_nodes = connectivity_model.nr_nodes
[docs] def __str__(self): """ Returns a string representation of the ModelConfig object, including details about the nodes, connectivity, simulation parameters, and GEV distribution parameters. Returns ------- str A formatted string representation of the ModelConfig object. """ notation_nodes = [notation(node) for node in self.nodes] distances_df = pd.DataFrame(self.distances, index=notation_nodes, columns=notation_nodes) connectivity_weights_df = pd.DataFrame(self.connectivity_weights, index=notation_nodes, columns=notation_nodes) pd.set_option('display.max_columns', None) model_config_str = f""" ModelConfig --- NODES ----------------------------------------------- nodes = {self.nodes} relay_station = {self.relay_station} --- CONNECTIVITY ---------------------------------------- distances = {distances_df} connectivity_weights = {connectivity_weights_df} --- SIMULATION PARAMETERS ------------------------------- sample_rate = {self.sample_rate} nr_lags = {self.nr_lags} t_secs = {self.t_secs} t_burnit = {self.t_burnit} noise_color = {self.noise_color} std_noise = {self.std_noise} --- GEV DISTRIBUTION ------------------------------------ dist_shape (xi) = {self._dist_shape} dist_scale (sigma) = {self._dist_scale} dist_location (mu) = {self._dist_location} dist_trunc_percent = {self._truncation_percentile} """ # return textwrap.dedent(model_config_str) lines = model_config_str.splitlines() trimmed_lines = [line.lstrip() for line in lines] return "\n".join(trimmed_lines)
[docs] def plot(self, plots_dir): """ Plots (using :py:meth:`_plot_properties`) the connectivity model's distances (summed through the relay, if applicable) and normalized weights matrices using heatmaps. Parameters ---------- plots_dir : pathlib.Path The directory where the plots are saved. Raises ------ AssertionError If the plots directory does not exist. """ assert plots_dir.exists(), f"Directory not found: {plots_dir}" if not all([distance is None for distance in self.distances.flatten()]): if self.relay_station is None: distances = self.distances title = "Distances" else: distances = np.zeros((self.nr_nodes, self.nr_nodes)) np.fill_diagonal(distances, np.nan) for i in range(self.nr_nodes): for j in range(i + 1, self.nr_nodes): distances[i, j] = distances[j, i] = self.distances[i, j][0] + self.distances[i, j][1] title = "Distances_relayed" distances = np.triu(distances) distances[distances == 0] = np.nan self._plot_properties(distances[:-1, 1:], title, plots_dir, factor=0.1) connectivity_weights = np.triu(self.connectivity_weights) connectivity_weights = connectivity_weights / np.nanmax(connectivity_weights) # normalize the matrix connectivity_weights[connectivity_weights == 0] = np.nan self._plot_properties(connectivity_weights[:-1, 1:], "Connectivity_weights", plots_dir)
[docs] def _plot_properties(self, matrix, title, plots_dir, factor=1.0): """ Helper method to plot a heatmap of a given matrix with specified properties. Parameters ---------- matrix : numpy.ndarray The matrix to be plotted as a heatmap. title : str The title for the plot, used to label the saved file. plots_dir : pathlib.Path The directory where the plots are saved. factor : float, optional A scaling factor applied to the matrix values (default is 1.0). """ fig, ax = plt.subplots(figsize=(1.2 * PLOT_SIZE, 0.7 * PLOT_SIZE)) notations = [notation(node) for node in self.nodes] ax = sns.heatmap( matrix * factor, annot=True, fmt="0.3f", annot_kws={"size": 3 * PLOT_SIZE}, cbar=True, cbar_kws={ "pad": 0.1 }, cmap=COLOR_MAP, xticklabels=notations[1:], yticklabels=notations[:-1], linewidths=PLOT_SIZE / 2, linecolor="#FFFFFF" ) ax.xaxis.tick_top() ax.xaxis.set_label_position('top') ax.yaxis.tick_right() ax.yaxis.set_label_position('right') ax.set_yticklabels(ax.get_yticklabels(), rotation=0) # plt.show() path = plots_dir / f"{title}.{PLOT_FORMAT}" fig.savefig(path)