
Self-orthogonalizing attractor neural networks emerging from the free energy principle

Simulation 3

Sequence learning

Here, we trained the network on a sequence of three handwritten digits (1, 2, 3), presented in a fixed order (1→2→3→1→2→3→...), for 2000 epochs, each epoch consisting of a single step (Figure 5A). This rapid presentation of the input sequence forced the network to model the current attractor from its own response to the previous pattern, i.e. to establish sequence attractors. The inverse temperature was set to 1 and the learning rate to 0.001 (a supplementary analysis showed considerable robustness of our results to the choice of these parameters). As shown in Figure 5B, this training regime led to an asymmetric coupling matrix (it was very close to symmetric in all previous simulations). Based on eqs. (13) and (14), we decomposed the coupling matrix into a symmetric and an antisymmetric part (Figure 5C and D). Retrieving the fixed-point attractors of the symmetric component, we obtained three attractors corresponding to the three digits (Figure 5C and E). The antisymmetric component, on the other hand, encoded the sequence dynamics: letting the network run freely (with zero bias) produced a spontaneously emerging sequence of variations of the digits 1→2→3→1→2→3→1→..., reflecting the original training order (Figure 5F). This illustrates that the proposed framework is capable of producing and handling asymmetric couplings, and can thereby learn sequences.
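For reference, the decomposition referred to above (eqs. (13) and (14), implemented verbatim in the code below) is the standard split of a square coupling matrix into its symmetric and antisymmetric parts:

J†_ij = (1/2)(J_ij + J_ji),   J‡_ij = (1/2)(J_ij − J_ji),   so that J = J† + J‡.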

Imports

from sklearn import datasets
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import pandas as pd
from simulation.network import AttractorNetwork, Langevin, relax, inverse_Langevin
from simulation.utils import fetch_digits_data, preprocess_digits_data, continous_inference_and_learning, get_deterministic_attractors
from joblib import Parallel, delayed
from copy import deepcopy
from tqdm import tqdm

Load data

digits = fetch_digits_data()
train_data, _ = preprocess_digits_data(digits)
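The two helper functions come from the paper's simulation package. For readers without it, the hypothetical stub below sketches a rough stand-in using scikit-learn's 8×8 digits (the `_stub` names and the pixel scaling are assumptions, not the package's actual preprocessing):

# Hypothetical stand-ins for fetch_digits_data / preprocess_digits_data;
# the real helpers live in simulation.utils and may differ in detail.
def fetch_digits_data_stub():
    return datasets.load_digits()  # sklearn's 8x8 handwritten digits

def preprocess_digits_data_stub(digits, train_fraction=0.8):
    X = digits.data / 16.0 - 0.5   # map raw 0..16 pixel values to [-0.5, 0.5]
    n_train = int(train_fraction * len(X))
    return X[:n_train], X[n_train:]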

Plot training data

# Take the training patterns for digits 1, 2 and 3 and scale them up,
# so that the Langevin nonlinearity drives the pixels towards saturation
data = train_data[1:4].copy()
data *= 20

fig, axes = plt.subplots(nrows=10, ncols=10, figsize=(10, 10))
axes = axes.flatten()

for i, ax in enumerate(axes):
    if i < data.shape[0]:
        # Add a little Gaussian noise, then pass each pixel through the
        # Langevin function to map it into the network's activation range
        image = data[i].reshape(8, 8) + np.random.normal(0, 0.1, (8, 8))
        for j in range(image.shape[0]):
            for k in range(image.shape[1]):
                image[j, k] = Langevin(image[j, k])

        ax.imshow(image, cmap="gray_r", interpolation="nearest",
                  vmin=-np.max(image), vmax=np.max(image))
        ax.set_axis_off()
    else:
        ax.set_visible(False)

plt.show()

# Distribution of the scaled training pixel values
sns.histplot(data.flatten())
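The Langevin function used above is, in its standard form, L(x) = coth(x) − 1/x, which saturates at ±1 for large |x|; this is why the patterns are scaled by a factor of 20 before plotting. A minimal standalone definition for reference (an assumption on our part; simulation.network.Langevin may differ in detail):

def langevin(x):
    """Standard Langevin function L(x) = coth(x) - 1/x (assumed form)."""
    x = np.asarray(x, dtype=float)
    small = np.abs(x) < 1e-4          # the closed form is unstable near zero
    safe = np.where(small, 1.0, x)    # placeholder value to avoid division by zero
    return np.where(small, x / 3.0,   # Taylor expansion: L(x) ≈ x/3 near zero
                    1.0 / np.tanh(safe) - 1.0 / safe)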

Training

Set up the network

size = data.shape[1]            # number of units (64 pixels)
num_variables = data.shape[0]   # number of training patterns (3 digits)
J = np.zeros((size, size))      # couplings are initialized at zero
nw = AttractorNetwork(J, biases=np.zeros(J.shape[0]))

Sequential training

Present the patterns in a fixed order, with extremely short epochs (a single step per pattern)

num_iterations = 2000

for i in tqdm(range(num_iterations)):
    di = i % data.shape[0]  # fixed cyclic order of presentation: 1, 2, 3, 1, 2, 3, ...
    activations, weight_change, accuracy, complexity, vfe = continous_inference_and_learning(
        nw=nw,
        data=data[di],
        inverse_temperature=1,
        learning_rate=0.001,
        num_steps=1)  # a single inference/learning step per epoch
    
sns.heatmap(nw.get_J(), cmap="coolwarm", center=0)
plt.show()


Visualize the weight matrix and its symmetric and antisymmetric components

# Symmetrize the weight matrix J
J = nw.get_J()
J_symmetrized = 0.5 * (J + J.T)  # J_ij† = (1/2) * (J_ij + J_ji)

# Display the original and symmetrized weight matrices
fig, axes = plt.subplots(1, 3, figsize=(10, 3))
sns.heatmap(J, cmap="coolwarm", center=0, ax=axes[0])
axes[0].set_title("Original Weight Matrix J")

sns.heatmap(J_symmetrized, cmap="coolwarm", center=0, ax=axes[1])
axes[1].set_title("Symmetric Component J†")


# Calculate the antisymmetric component of the weight matrix
J_asymmetric = 0.5 * (J - J.T)  # J_ij‡ = (1/2) * (J_ij - J_ji)

# Display the antisymmetric component
sns.heatmap(J_asymmetric, cmap="coolwarm", center=0, ax=axes[2])
axes[2].set_title("Antisymmetric Component J‡")
plt.tight_layout()
plt.show()



# Calculate and print the asymmetry measure
asymmetry = np.linalg.norm(J - J.T) / np.linalg.norm(J)
print(f"Asymmetry measure: {asymmetry:.6f}")
Asymmetry measure: 0.981247
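As a quick sanity check (standard linear algebra, not part of the original analysis): the symmetric and antisymmetric parts are orthogonal under the Frobenius inner product, so their squared norms add up to that of J. Since ||J − Jᵀ|| = 2||J‡||, the measured value of ≈0.98 means the antisymmetric part carries roughly (0.98/2)² ≈ 24% of the squared norm of the couplings.

# Sanity check: J decomposes exactly, and the two parts are
# Frobenius-orthogonal, so ||J||^2 = ||J†||^2 + ||J‡||^2
assert np.allclose(J, J_symmetrized + J_asymmetric)
print(np.linalg.norm(J)**2)
print(np.linalg.norm(J_symmetrized)**2 + np.linalg.norm(J_asymmetric)**2)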

Free-running (spontaneous activity)

# Let the network run freely: zero input bias, learning switched off
acts, weight_change, accuracy, complexity, vfe = continous_inference_and_learning(
    nw, np.zeros(data.shape[1]), inverse_temperature=1,
    learning_rate=0, num_steps=100)

# Plot 50 consecutive states, skipping the first 10 (transient) steps
fig, axes = plt.subplots(5, 10, figsize=(20, 10))
axes = axes.flatten()
for i, ax in enumerate(axes):
    sns.heatmap(np.array(acts[i+10]).reshape(8, 8), cmap="gray_r", ax=ax, cbar=False)
    ax.set_xticks([])
    ax.set_yticks([])
plt.tight_layout()
plt.show()
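To read off the order in which the free-running network visits the learned patterns, each state can be labelled by its most similar training digit. This is an illustrative check, not part of the original notebook; it assumes the rows of data correspond to digits 1, 2 and 3:

# Label each free-running state by its nearest training pattern
# (cosine similarity); the printed labels should cycle 1 -> 2 -> 3
A = np.array(acts)                                    # shape: (num_steps, 64)
sims = (A @ data.T) / (np.linalg.norm(A, axis=1, keepdims=True)
                       * np.linalg.norm(data, axis=1))
print(np.argmax(sims, axis=1) + 1)                    # digit labels over time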

Attractors from the symmetric component

# Build a network from the symmetric component only and retrieve its
# deterministic fixed-point attractors
nw_sym = AttractorNetwork(J_symmetrized, biases=np.zeros(nw.get_J().shape[0]))
get_deterministic_attractors(nw_sym, data, noise_levels=[0], inverse_temperature=1)
  ** Noise: 0
[array([-0.90288564, -0.90288734, -0.89395016, 0.92772164, 0.90749495, -0.90028988, -0.90216315, -0.90270948, -0.90288256, -0.90231346, -0.87466679, 0.8948815 , 0.90994825, -0.6187639 , -0.90212455, -0.90255703, -0.90307682, -0.90252808, -0.89942102, 0.90867119, 0.90704861, -0.89955458, -0.90247444, -0.90232051, -0.90299654, -0.89213355, 0.9190238 , 0.93219895, 0.90873427, -0.90150859, -0.90267223, -0.90182199, -0.90294999, -0.90271066, -0.89998355, 0.91462843, 0.90905729, -0.75843228, -0.90298625, -0.90265723, -0.90301133, -0.90503416, -0.90455038, 0.91211094, 0.93237412, 0.80662162, -0.90336047, -0.90280665, -0.90293977, -0.90259084, -0.87176695, 0.92059984, 0.91000916, -0.44331117, -0.92467333, -0.90183692, -0.90296146, -0.90220506, -0.89362384, 0.92309671, 0.90894171, 0.8974486 , -0.90722171, -0.90273355]), array([-0.90260279, -0.90234075, -0.90303629, -0.89376333, 0.90722312, 0.92418188, -0.90226095, -0.90180304, -0.90282406, -0.91785139, -0.92862477, 0.91179096, 0.90820096, 0.92968106, -0.90219291, -0.90190966, -0.90269278, -0.90237912, -0.88075182, 0.90650091, 0.85832826, 0.92502362, -0.90250502, -0.90194942, -0.90239322, -0.89711952, -0.84089627, 0.42139728, 0.90814186, 0.92154957, -0.90276228, -0.90196099, -0.90224946, -0.90213559, -0.8832218 , 0.9385866 , 0.90840627, -0.91558364, -0.90235638, -0.90207662, -0.9021762 , -0.79929213, 0.9248548 , 0.9380629 , -0.19952918, -0.91233694, -0.91872469, -0.90177328, -0.90264956, -0.90172886, 0.89974382, 0.93749983, 0.92545748, 0.90349607, -0.94236079, -0.90147602, -0.90218063, -0.90127172, -0.90442236, -0.90875701, 0.90329421, 0.90389035, -0.79742375, -0.90240095]), array([-0.89776996, -0.89737846, -0.84969839, 0.94880558, 0.90373144, -0.8945477 , -0.89677821, -0.89757324, -0.89784976, -0.56051117, 0.9263714 , 0.78135273, 0.90545583, -0.84857729, -0.89684061, -0.89737488, -0.89748397, -0.89630102, -0.89675088, 0.90384461, 0.90871205, -0.89008014, -0.8978052 , -0.89751025, -0.89756914, -0.8936341 , -0.76538435, 0.94645642, 0.90258711, -0.89686633, -0.89799551, -0.89652715, -0.89756203, -0.89693139, -0.89630859, -0.90080782, 0.90351773, 0.93540067, -0.89796881, -0.89747417, -0.8977793 , -0.91494646, -0.89130629, -0.90368371, -0.62970443, 0.93656167, -0.56908093, -0.89785701, -0.89754378, -0.89723332, 0.8407246 , -0.42193345, 0.45210824, 0.88947096, 0.82641821, -0.89713562, -0.89761694, -0.89703909, -0.84471754, 0.94637602, 0.90372611, 0.8942036 , -0.91770835, -0.89738004])]
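To confirm that the three recovered fixed points correspond to the three trained digits, each of them can be correlated with the training patterns. This is an illustrative follow-up, not part of the original notebook; it re-binds the returned list to a variable:

# Capture the fixed points and match each to its closest training digit
attractors = get_deterministic_attractors(nw_sym, data, noise_levels=[0],
                                          inverse_temperature=1)
for a in attractors:
    corrs = [np.corrcoef(a, d)[0, 1] for d in data]
    print(f"closest digit: {np.argmax(corrs) + 1} (r = {max(corrs):.3f})")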