# external imports
import numpy as np
# local imports
from .nodes_processor import NodesProcessor
from macro_eeg_model.utils.paths import paths
from macro_eeg_model.data_prep.data_preparator import DataPreparator
class ConnectivityModel:
"""
A class to model the connectivity between brain nodes. It computes
distances and connectivity weights between node pairs, optionally
routing connections through a relay station.
Attributes
----------
nodes : list
The processed list of nodes used in the model.
nr_nodes : int
The total number of nodes in the model.
distances : numpy.ndarray
The matrix of distances between nodes.
connectivity_weights : numpy.ndarray
The matrix of connectivity weights between nodes.
_given_nodes : list
The list of nodes provided for the connectivity model.
_relay_station : str
The relay station node name, if any.
_relay_nodes : list
The list of relay nodes derived from the relay station, if applicable.
_relay_indices : list
The indices of the relay nodes in the connectivity model.
_nodes_indices : dict
A dictionary mapping each node to its corresponding indices.
_avg_counts : numpy.ndarray
The average counts of connections between nodes.
_avg_fc : numpy.ndarray
The average functional connectivity between nodes.
_avg_lengths : numpy.ndarray
The average distances (lengths) between nodes.
_relay_distances : dict
The dictionary of average distances between nodes and the relay station.
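Examples
--------
A minimal usage sketch. The node names are illustrative placeholders, and the
prepared Julich connectivity data must be available on disk.

>>> model = ConnectivityModel(given_nodes=["V1", "M1"], relay_station=None)  # doctest: +SKIP
>>> model.set_connectivity(custom_connectivity=False)  # doctest: +SKIP
>>> model.connectivity_weights.shape == (model.nr_nodes, model.nr_nodes)  # doctest: +SKIP
True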
"""
def __init__(self, given_nodes, relay_station):
"""
Initializes the ConnectivityModel with given nodes and an optional relay station.
Parameters
----------
given_nodes : list
The list of nodes to be used in the connectivity model.
relay_station : str or None
The name of the relay station, or None to connect the nodes directly.
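Examples
--------
Illustrative only; node and relay names must match those handled by
:py:class:`NodesProcessor` ("Thalamus" is a hypothetical relay name).

>>> ConnectivityModel(given_nodes=["V1", "M1"], relay_station=None)  # doctest: +SKIP
>>> ConnectivityModel(given_nodes=["V1", "M1"], relay_station="Thalamus")  # doctest: +SKIP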
"""
self._given_nodes = given_nodes
self._relay_station = relay_station
self._relay_nodes = None
self._relay_indices = None
self.nodes = None
self._nodes_indices = None
self.nr_nodes = None
self._avg_counts = None
self._avg_fc = None
self._avg_lengths = None
self._relay_distances = None
self.distances = None
self.connectivity_weights = None
self._init_data()
self._init_nodes()
self._init_connectivity()
self._init_relay_distances()
def set_connectivity(self, custom_connectivity):
"""
Fills the distance matrix with average streamline lengths between node pairs
(or with node-to-relay length pairs, if a relay station is used) and the
connectivity matrix with weights derived either from custom-provided values
or from functional connectivity (FC).
The per-pair values are obtained from :py:meth:`_get_pair_stats`.
Parameters
----------
custom_connectivity : bool
If True, attempts to load and use custom connectivity weights from
`connectivity_weights.csv` file in the configs path (see :py:class:`src.utils.paths.Paths`).
Raises
------
AssertionError
If, after filling, the connectivity weight matrix does not contain exactly
`nr_nodes` NaN values (one per diagonal entry).
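Examples
--------
A sketch of supplying custom weights, assuming an already constructed `model`;
the values are illustrative.

>>> import numpy as np
>>> from macro_eeg_model.utils.paths import paths
>>> weights = np.full((model.nr_nodes, model.nr_nodes), 0.05)  # doctest: +SKIP
>>> np.savetxt(paths.configs_path / "connectivity_weights.csv", weights, delimiter=",")  # doctest: +SKIP
>>> model.set_connectivity(custom_connectivity=True)  # doctest: +SKIP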
"""
custom_connectivity_path = paths.configs_path / "connectivity_weights.csv"
custom_connectivity_weights = None
if custom_connectivity:
try:
custom_connectivity_weights = np.loadtxt(custom_connectivity_path, delimiter=",")
assert custom_connectivity_weights.shape == (self.nr_nodes, self.nr_nodes), "Custom connectivity matrix has wrong shape."
if not np.allclose(custom_connectivity_weights, custom_connectivity_weights.T):
# make the lower triangle equal to the upper triangle
custom_connectivity_weights = np.triu(custom_connectivity_weights) + np.triu(
custom_connectivity_weights, 1).T
except (OSError, ValueError, AssertionError):
print(f"Could not load custom connectivity weights from {custom_connectivity_path}. Using FC connectivity weights.")
custom_connectivity = False
# loop through every pair of nodes to fill the symmetrical matrix
for i in range(self.nr_nodes):
for j in range(i + 1, self.nr_nodes):
_, fcs, distances = self._get_pair_stats(self.nodes[i], self.nodes[j])
if self._relay_station is not None:
distances_k, distances_l = zip(*distances)
distance = (np.mean(distances_k), np.mean(distances_l))
else:
distance = np.mean(distances)
self.distances[i, j] = self.distances[j, i] = distance
if custom_connectivity:
connectivity_weight = custom_connectivity_weights[i, j] #* (0.9 ** 9)
else:
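# derive the weight from the mean FC of the node pair: raise it to the power pw, then rescale by fixed constants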
pw = 1.5
connectivity_weight = (np.mean(fcs) ** pw) * (10 ** (pw - int(pw / 2))) * (0.9 ** 20)
self.connectivity_weights[i, j] = self.connectivity_weights[j, i] = connectivity_weight
assert np.isnan(self.connectivity_weights).sum() == self.nr_nodes, \
f"Expected connectivity weights to have {self.nr_nodes} NaNs, but got {np.isnan(self.connectivity_weights).sum()}"
def _get_pair_stats(self, node1, node2):
"""
Retrieves statistics for a pair of nodes, including counts, functional connectivity, and distances.
Parameters
----------
node1 : str
The name of the first node.
node2 : str
The name of the second node.
Returns
-------
tuple
A tuple of three lists: connection counts, functional connectivity values,
and distances (node-to-relay distance pairs when a relay station is used)
between the two nodes.
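Examples
--------
Illustrative call on an existing `model`; node names are placeholders and the
returned values depend on the loaded connectivity data.

>>> counts, fcs, distances = model._get_pair_stats("V1", "M1")  # doctest: +SKIP
>>> len(counts) == len(fcs) == len(distances)  # doctest: +SKIP
True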
"""
indices_list1 = self._nodes_indices[node1]
indices_list2 = self._nodes_indices[node2]
counts, fcs, distances = [], [], []
for i in indices_list1:
for j in indices_list2:
if self._relay_station is None:
avg_dist = self._avg_lengths[i, j]
else:
avg_dist = (self._relay_distances[i], self._relay_distances[j])
counts.append(self._avg_counts[i, j])
fcs.append(self._avg_fc[i, j])
distances.append(avg_dist)
return counts, fcs, distances
def _init_relay_distances(self):
"""
Calculates and stores the average distance between each node
and the relay station, if a relay station is specified.
"""
if self._relay_station is not None:
# get all indices from self._nodes_indices
all_indices = set(index for indices in self._nodes_indices.values() for index in indices)
self._relay_distances = {}
for i in all_indices:
distances_node_relay = []
for j in self._relay_indices:
distances_node_relay.append(self._avg_lengths[i, j])
self._relay_distances[i] = np.mean(distances_node_relay)
def _init_nodes(self):
"""
Initializes and processes the nodes via :py:meth:`NodesProcessor.get_nodes_indices`.
"""
nodes_processor = NodesProcessor(
given_nodes=self._given_nodes,
relay_station=self._relay_station
)
self._relay_nodes, self._relay_indices, self.nodes, self._nodes_indices = (
nodes_processor.get_nodes_indices()
)
self.nr_nodes = len(self.nodes)
def _init_connectivity(self):
"""
Initializes the distance and connectivity weight matrices between nodes.
Off-diagonal distances start at zero (or at the tuple (0, 0) when a relay
station is used), connectivity weights start at zero, and the diagonals of
both matrices are set to NaN.
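Examples
--------
Resulting layout of the distance matrix for a three-node model without a
relay station (illustrative).

>>> model.distances  # doctest: +SKIP
array([[nan,  0.,  0.],
       [ 0., nan,  0.],
       [ 0.,  0., nan]])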
"""
if self._relay_station is None:
# initially 0's with nan on the diagonal
self.distances = np.zeros((self.nr_nodes, self.nr_nodes))
else:
# initially each element with tuple (0, 0), nan on the diagonal
self.distances = np.empty((self.nr_nodes, self.nr_nodes), dtype=object)
for i in range(self.nr_nodes):
for j in range(self.nr_nodes):
self.distances[i, j] = (0, 0)
np.fill_diagonal(self.distances, np.nan)
# initially 0's with np.nan on the diagonal
self.connectivity_weights = np.zeros((self.nr_nodes, self.nr_nodes))
np.fill_diagonal(self.connectivity_weights, np.nan)
def _load_data(self):
"""
Loads precomputed connectivity data such as counts, functional connectivity,
and streamline lengths between nodes from the connectivity data path
(see :py:class:`macro_eeg_model.utils.paths.Paths`).
"""
self._avg_counts = np.load(paths.connectivity_data_path / "avg_counts.npy")
self._avg_fc = np.load(paths.connectivity_data_path / "avg_fc.npy")
self._avg_lengths = np.load(paths.connectivity_data_path / "avg_lengths.npy")
def _init_data(self):
"""
Checks if the necessary structural and functional connectivity data files exist.
If the files are found, it loads them; otherwise, it triggers the data preparation
process using :py:class:`macro_eeg_model.data_prep.data_preparator.DataPreparator`
and then loads the prepared data.
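Examples
--------
A sketch of the files expected in the connectivity data path after preparation
(assuming the path is a :py:class:`pathlib.Path`).

>>> from macro_eeg_model.utils.paths import paths
>>> sorted(p.name for p in paths.connectivity_data_path.glob("avg_*.npy"))  # doctest: +SKIP
['avg_counts.npy', 'avg_fc.npy', 'avg_lengths.npy']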
"""
try:
self._load_data()
except FileNotFoundError:
data_preparator = DataPreparator()
directory_sc = "structural_connectivity_data"
directory_fc = "functional_connectivity_data"
print("Preparing Julich structural connectivity data...")
data_preparator.prep_and_save(
directory_name=directory_sc,
included_word="Lengths",
delimiter=",",
name="lengths"
)
data_preparator.prep_and_save(
directory_name=directory_sc,
included_word="Counts",
delimiter=",",
name="counts"
)
print("Preparing Julich functional connectivity data...")
data_preparator.prep_and_save(
directory_name=directory_fc,
included_word="concatenated",
delimiter=" ",
name="fc"
)
self._load_data()
print("Loaded connectivity data...")