(Yang, 2020): Dynamical system analysis for RNN


Implementation of the paper:

  • Yang G R, Wang X J. Artificial neural networks for neuroscientists: A primer[J]. Neuron, 2020, 107(6): 1048-1070.

The original implementation is based on PyTorch: https://github.com/gyyang/nn-brain/blob/master/RNN%2BDynamicalSystemAnalysis.ipynb

[161]:
import brainpy as bp
import brainpy.math as bm
import brainpy_datasets as bd

bp.math.set_platform('cpu')
[162]:
bp.__version__
[162]:
'2.4.3'
[163]:
bd.__version__
[163]:
'0.0.0.6'
[164]:
import numpy as np
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA

In this tutorial, we use supervised learning to train a recurrent neural network on a simple perceptual decision-making task, and then analyze the trained network using dynamical system analysis.

Defining a cognitive task

[165]:
dataset = bd.cognitive.RatePerceptualDecisionMaking()
task = bd.cognitive.TaskLoader(dataset, batch_size=16)
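
The dataset object also exposes the input/output dimensionality and the integration time step, which are reused below when the network is instantiated. A quick sanity check (output omitted; the exact values depend on the dataset configuration):

# inspect the task dimensions and time step used later by the network
print('number of inputs: ', dataset.num_inputs)
print('number of outputs:', dataset.num_outputs)
print('dt:               ', dataset.dt)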

Define a vanilla continuous-time recurrent network

Here we will define a continuous-time neural network but discretize it in time using the Euler method.

\begin{align}
\tau \frac{d\mathbf{r}}{dt} = -\mathbf{r}(t) + f(W_r \mathbf{r}(t) + W_x \mathbf{x}(t) + \mathbf{b}_r).
\end{align}

This continuous-time system can then be discretized using the Euler method with a time step of \(\Delta t\),

\begin{align}
\mathbf{r}(t+\Delta t) = \mathbf{r}(t) + \Delta \mathbf{r} = \mathbf{r}(t) + \frac{\Delta t}{\tau}\left[-\mathbf{r}(t) + f(W_r \mathbf{r}(t) + W_x \mathbf{x}(t) + \mathbf{b}_r)\right].
\end{align}
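
To make the update rule concrete, here is a minimal NumPy sketch of a single Euler step written directly from the equation above (illustrative only; the BrainPy class below implements the same update with trainable weights):

import numpy as np

def euler_step(r, x, w_rr, w_ir, b_r, tau=100., dt=10.):
    # one Euler step of: tau dr/dt = -r + f(W_r r + W_x x + b_r), with f = ReLU
    # tau matches the class below; dt = 10 is an illustrative value only
    alpha = dt / tau
    return r + alpha * (-r + np.maximum(r @ w_rr + x @ w_ir + b_r, 0.))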

[166]:
class RNN(bp.DynamicalSystem):
  def __init__(self,
               num_input,
               num_hidden,
               num_output,
               num_batch,
               dt=None, seed=None,
               w_ir=bp.init.KaimingNormal(scale=1.),
               w_rr=bp.init.KaimingNormal(scale=1.),
               w_ro=bp.init.KaimingNormal(scale=1.)):
    super(RNN, self).__init__()

    # parameters
    self.tau = 100
    self.num_batch = num_batch
    self.num_input = num_input
    self.num_hidden = num_hidden
    self.num_output = num_output
    if dt is None:
      self.alpha = 1
    else:
      self.alpha = dt / self.tau
    self.rng = bm.random.RandomState(seed=seed)

    # input weight
    self.w_ir = bm.TrainVar(bp.init.parameter(w_ir, (num_input, num_hidden)))

    # recurrent weight
    bound = 1 / num_hidden ** 0.5
    self.w_rr = bm.TrainVar(bp.init.parameter(w_rr, (num_hidden, num_hidden)))
    self.b_rr = bm.TrainVar(self.rng.uniform(-bound, bound, num_hidden))

    # readout weight
    self.w_ro = bm.TrainVar(bp.init.parameter(w_ro, (num_hidden, num_output)))
    self.b_ro = bm.TrainVar(self.rng.uniform(-bound, bound, num_output))

    self.reset_state(self.mode)

  def reset_state(self, batch_size):
    self.h = bp.init.variable_(bm.zeros, self.num_hidden, batch_size)
    self.o = bp.init.variable_(bm.zeros, self.num_output, batch_size)

  def cell(self, x, h):
    ins = x @ self.w_ir + h @ self.w_rr + self.b_rr
    state = h * (1 - self.alpha) + ins * self.alpha
    return bm.relu(state)

  def readout(self, h):
    return h @ self.w_ro + self.b_ro

  def update(self, x):
    self.h.value = self.cell(x, self.h)
    self.o.value = self.readout(self.h)
    return self.h.value, self.o.value

  def predict(self, xs):
    self.h[:] = 0.
    return bm.for_loop(self.update, xs)

  def loss(self, xs, ys):
    hs, os = self.predict(xs)
    os = os.reshape((-1, os.shape[-1]))
    loss = bp.losses.cross_entropy_loss(os, ys.flatten())
    return loss, os

Train the recurrent network on the decision-making task

[167]:
# Instantiate the network and print information
hidden_size = 64
with bm.environment(mode=bm.TrainingMode(batch_size=16)):
    net = RNN(num_input=dataset.num_inputs,
              num_hidden=hidden_size,
              num_output=dataset.num_outputs,
              num_batch=task.batch_size,
              dt=dataset.dt)
[168]:
# Adam optimizer
opt = bp.optim.Adam(lr=0.001, train_vars=net.train_vars().unique())

# gradient function
grad_f = bm.grad(net.loss,
                 grad_vars=net.train_vars().unique(),
                 return_value=True,
                 has_aux=True)

# training function
@bm.jit
def train(xs, ys):
  grads, l, os = grad_f(xs, ys)
  opt.update(grads)
  return l, os
[169]:
running_acc = []
running_loss = []
for i_batch in range(20):
    for X, Y in task:
        # training
        loss, outputs = train(bm.asarray(X), bm.asarray(Y))
        # Compute performance
        output_np = np.asarray(bm.argmax(outputs, axis=-1)).flatten()
        labels_np = np.asarray(Y).flatten()
        ind = labels_np > 0 # 0: fixation, 1: choice 1, 2: choice 2
        running_loss.append(loss)
        running_acc.append(np.mean(labels_np[ind] == output_np[ind]))
    print(f'Batch {i_batch + 1}, Loss {np.mean(running_loss):0.4f}, Acc {np.mean(running_acc):0.3f}')
    running_loss = []
    running_acc = []
Batch 1, Loss 0.2494, Acc 0.164
Batch 2, Loss 0.0526, Acc 0.663
Batch 3, Loss 0.0375, Acc 0.766
Batch 4, Loss 0.0314, Acc 0.775
Batch 5, Loss 0.0294, Acc 0.780
Batch 6, Loss 0.0291, Acc 0.796
Batch 7, Loss 0.0249, Acc 0.830
Batch 8, Loss 0.0251, Acc 0.812
Batch 9, Loss 0.0223, Acc 0.827
Batch 10, Loss 0.0209, Acc 0.848
Batch 11, Loss 0.0218, Acc 0.817
Batch 12, Loss 0.0220, Acc 0.822
Batch 13, Loss 0.0191, Acc 0.853
Batch 14, Loss 0.0176, Acc 0.861
Batch 15, Loss 0.0216, Acc 0.832
Batch 16, Loss 0.0177, Acc 0.882
Batch 17, Loss 0.0180, Acc 0.864
Batch 18, Loss 0.0166, Acc 0.869
Batch 19, Loss 0.0162, Acc 0.861
Batch 20, Loss 0.0160, Acc 0.871

Visualize neural activity for sample trials

We will run the network for 100 sample trials and then visualize the neural-activity trajectories in PCA space.

[170]:
num_trial = 100
task = bd.cognitive.TaskLoader(dataset, batch_size=num_trial)
inputs, trial_infos = task.get_batch()

# reset the network state to match the required batch size
net.reset_state(num_trial)

# get the RNN activity
rnn_activity, _ = net.predict(inputs)
rnn_activity = np.asarray(rnn_activity)
trial_infos = np.asarray(trial_infos)

# Concatenate activity for PCA
activity = rnn_activity.reshape(-1, hidden_size)
print('Shape of the neural activity: (Time points, Neurons): ', activity.shape)
Shape of the neural activity: (Time points, Neurons):  (2200, 64)
[171]:
pca = PCA(n_components=2)
pca.fit(activity)
[171]:
PCA(n_components=2)

Transform individual trials and visualize them in PC space, colored by the ground-truth choice. We see that the neural activity is organized along PC 1 by the stimulus ground truth.

[172]:
trial_infos.shape, inputs.shape

[172]:
((22, 100), (22, 100, 3))
[173]:
plt.rcdefaults()
fig, (ax1, ax2) = plt.subplots(1, 2, sharey=True, sharex=True, figsize=(12, 5))
for i in range(num_trial):
    activity_pc = pca.transform(rnn_activity[:, i])
    color = 'red' if trial_infos[-1, i] == 1 else 'blue'
    _ = ax1.plot(activity_pc[:, 0], activity_pc[:, 1], 'o-', color=color)
    if i < 5:
        _ = ax2.plot(activity_pc[:, 0], activity_pc[:, 1], 'o-', color=color)

ax1.set_xlabel('PC 1')
ax1.set_ylabel('PC 2')
plt.show()
[Figure: RNN activity trajectories projected onto PC 1 and PC 2; all 100 trials (left) and five example trials (right), colored by ground-truth choice.]

Search for approximate fixed points

Here we search for approximate fixed points and visualize them in the same PC space. For a generic dynamical system

\begin{align}
\frac{d\mathbf{x}}{dt} = F(\mathbf{x}),
\end{align}

we can search for fixed points by solving the optimization problem

\begin{align}
\mathrm{argmin}_{\mathbf{x}} \|F(\mathbf{x})\|^2.
\end{align}
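
For the discrete-time network trained above, a fixed point of the map \(\mathbf{h} \mapsto \mathrm{cell}(\mathbf{x}, \mathbf{h})\) under a frozen input \(\mathbf{x}\) satisfies \(\mathbf{h} = \mathrm{cell}(\mathbf{x}, \mathbf{h})\), so the objective reduces to the squared one-step speed. A minimal sketch of that objective (not the internal implementation of SlowPointFinder):

# squared one-step "speed" of a candidate hidden state h under a frozen input;
# the input [1, 0.5, 0.5] is the same one used in the next cell
x_frozen = bm.asarray([1., 0.5, 0.5])
speed = lambda h: bm.sum((net.cell(x_frozen, h) - h) ** 2)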

[174]:
f_cell = lambda h: net.cell(bm.asarray([1, 0.5, 0.5]), h)
[175]:
fp_candidates = bm.asarray(activity)
[176]:
fp_candidates.shape
[176]:
(2200, 64)
[177]:
finder = bp.analysis.SlowPointFinder(f_cell=f_cell, f_type='discrete')
finder.find_fps_with_gd_method(
    candidates=fp_candidates,
    tolerance=1e-4,
    num_batch=200,
    optimizer=bp.optim.Adam(lr=1e-3),
)
finder.filter_loss(tolerance=1e-4)
finder.keep_unique(tolerance=0.03)
finder.exclude_outliers(0.1)
fixed_points = finder.fixed_points
Optimizing with Adam(lr=Constant(lr=0.001, last_epoch=-1), beta1=0.9, beta2=0.999, eps=1e-08) to find fixed points:
    Batches 1-200 in 0.16 sec, Training loss 0.0802390501
    Batches 201-400 in 0.14 sec, Training loss 0.0552543662
    Batches 401-600 in 0.14 sec, Training loss 0.0439432897
    Batches 601-800 in 0.15 sec, Training loss 0.0374152139
    Batches 801-1000 in 0.14 sec, Training loss 0.0330325253
    Batches 1001-1200 in 0.14 sec, Training loss 0.0297763441
    Batches 1201-1400 in 0.14 sec, Training loss 0.0271856301
    Batches 1401-1600 in 0.13 sec, Training loss 0.0250399429
    Batches 1601-1800 in 0.14 sec, Training loss 0.0232114382
    Batches 1801-2000 in 0.14 sec, Training loss 0.0216213744
    Batches 2001-2200 in 0.15 sec, Training loss 0.0202192534
    Batches 2201-2400 in 0.13 sec, Training loss 0.0189735591
    Batches 2401-2600 in 0.13 sec, Training loss 0.0178660452
    Batches 2601-2800 in 0.14 sec, Training loss 0.0168767571
    Batches 2801-3000 in 0.14 sec, Training loss 0.0159866065
    Batches 3001-3200 in 0.13 sec, Training loss 0.0151826618
    Batches 3201-3400 in 0.14 sec, Training loss 0.0144545259
    Batches 3401-3600 in 0.14 sec, Training loss 0.0137864584
    Batches 3601-3800 in 0.14 sec, Training loss 0.0131686088
    Batches 3801-4000 in 0.15 sec, Training loss 0.0125987381
    Batches 4001-4200 in 0.13 sec, Training loss 0.0120685510
    Batches 4201-4400 in 0.14 sec, Training loss 0.0115718376
    Batches 4401-4600 in 0.12 sec, Training loss 0.0111064734
    Batches 4601-4800 in 0.29 sec, Training loss 0.0106634693
    Batches 4801-5000 in 0.13 sec, Training loss 0.0102427835
    Batches 5001-5200 in 0.13 sec, Training loss 0.0098414142
    Batches 5201-5400 in 0.15 sec, Training loss 0.0094555821
    Batches 5401-5600 in 0.14 sec, Training loss 0.0090830158
    Batches 5601-5800 in 0.13 sec, Training loss 0.0087208683
    Batches 5801-6000 in 0.13 sec, Training loss 0.0083650518
    Batches 6001-6200 in 0.14 sec, Training loss 0.0080158152
    Batches 6201-6400 in 0.15 sec, Training loss 0.0076781083
    Batches 6401-6600 in 0.14 sec, Training loss 0.0073548621
    Batches 6601-6800 in 0.14 sec, Training loss 0.0070449994
    Batches 6801-7000 in 0.13 sec, Training loss 0.0067473040
    Batches 7001-7200 in 0.15 sec, Training loss 0.0064625386
    Batches 7201-7400 in 0.14 sec, Training loss 0.0061870413
    Batches 7401-7600 in 0.14 sec, Training loss 0.0059184977
    Batches 7601-7800 in 0.14 sec, Training loss 0.0056603295
    Batches 7801-8000 in 0.13 sec, Training loss 0.0054131425
    Batches 8001-8200 in 0.12 sec, Training loss 0.0051786788
    Batches 8201-8400 in 0.14 sec, Training loss 0.0049580932
    Batches 8401-8600 in 0.14 sec, Training loss 0.0047499253
    Batches 8601-8800 in 0.15 sec, Training loss 0.0045528994
    Batches 8801-9000 in 0.14 sec, Training loss 0.0043659979
    Batches 9001-9200 in 0.14 sec, Training loss 0.0041924547
    Batches 9201-9400 in 0.14 sec, Training loss 0.0040282174
    Batches 9401-9600 in 0.13 sec, Training loss 0.0038750202
    Batches 9601-9800 in 0.16 sec, Training loss 0.0037310245
    Batches 9801-10000 in 0.14 sec, Training loss 0.0035944246
Excluding fixed points with squared speed above tolerance 0.0001:
    Kept 157/2200 fixed points with tolerance under 0.0001.
Excluding non-unique fixed points:
    Kept 69/157 unique fixed points with uniqueness tolerance 0.03.
Excluding outliers:
    Kept 66/69 fixed points with within outlier tolerance 0.1.

Visualize the approximate fixed points found above.

We see that they form an approximate line attractor, corresponding to PC1, along which evidence is integrated during the stimulus period.

[178]:
fixed_points.shape
[178]:
(66, 64)
[179]:
# Plot in the same space as activity
plt.figure(figsize=(10, 5))
for i in range(10):
    activity_pc = pca.transform(rnn_activity[:, i])
    color = 'red' if trial_infos[-1, i] == 1 else 'blue'
    plt.plot(activity_pc[:, 0], activity_pc[:, 1], 'o-', color=color, alpha=0.1)

# Fixed points are shown in cross
fixedpoints_pc = pca.transform(fixed_points)
plt.plot(fixedpoints_pc[:, 0], fixedpoints_pc[:, 1], 'x', label='fixed points')

plt.xlabel('PC 1')
plt.ylabel('PC 2')
plt.legend()
plt.show()
[Figure: trial trajectories in PC space with the approximate fixed points marked as crosses, tracing out an approximate line attractor along PC 1.]

Computing the Jacobian and finding the line attractor

[180]:
from jax import jacobian
[181]:
dFdh = jacobian(f_cell)(fixed_points[2])

eigval, eigvec = np.linalg.eig(bm.as_numpy(dFdh))
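
For a discrete-time map, eigenvalues with magnitude below 1 correspond to contracting directions, while a magnitude close to 1 indicates a nearly neutral (slow) direction, which is the signature of a line attractor. A quick check (the 0.99 threshold is an arbitrary choice):

# count eigenvalues whose magnitude is close to 1 (nearly neutral directions)
n_slow = np.sum(np.abs(eigval) > 0.99)
print('eigenvalues with |lambda| > 0.99:', n_slow)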
[182]:
# Plot distribution of eigenvalues in a 2-d real-imaginary plot
plt.figure()
plt.scatter(np.real(eigval), np.imag(eigval))
plt.plot([1, 1], [-1, 1], '--')
plt.xlabel('Real')
plt.ylabel('Imaginary')
plt.show()
[Figure: eigenvalue spectrum of the Jacobian in the complex plane, with a dashed vertical line at Re = 1.]