(Song, et al., 2016): Training excitatory-inhibitory recurrent network

Implementation of the paper:

  • Song, H. F. , G. R. Yang , and X. J. Wang . “Training Excitatory-Inhibitory Recurrent Neural Networks for Cognitive Tasks: A Simple and Flexible Framework.” Plos Computational Biology 12.2(2016):e1004792.

The original code is based on PyTorch (https://github.com/gyyang/nn-brain/blob/master/EI_RNN.ipynb). However, comparing with the PyTorch codes, the training on BrainPy speeds up nearly four folds.

Here we will train recurrent neural network with excitatory and inhibitory neurons on a simple perceptual decision making task.

[1]:
import brainpy as bp
import brainpy.math as bm
import brainpy_datasets as bd

bp.math.set_platform('cpu')
[2]:
import numpy as np
import matplotlib.pyplot as plt
[18]:
bp.__version__, bd.__version__
[18]:
('2.4.3', '0.0.0.6')

Defining a perceptual decision making task

[4]:
dataset = bd.cognitive.RatePerceptualDecisionMaking(dt=20.)
task = bd.cognitive.TaskLoader(dataset, batch_size=16)

Define E-I recurrent network

Here we define a E-I recurrent network, in particular, no self-connections are allowed.

[5]:
class RNN(bp.DynamicalSystem):
  r"""E-I RNN.

  The RNNs are described by the equations

  .. math::

      \begin{gathered}
      \tau \dot{\mathbf{x}}=-\mathbf{x}+W^{\mathrm{rec}} \mathbf{r}+W^{\mathrm{in}}
      \mathbf{u}+\sqrt{2 \tau \sigma_{\mathrm{rec}}^{2}} \xi \\
      \mathbf{r}=[\mathbf{x}]_{+} \\
      \mathbf{z}=W^{\text {out }} \mathbf{r}
      \end{gathered}

  In practice, the continuous-time dynamics are discretized to Euler form
  in time steps of size :math:`\Delta t` as

  .. math::

     \begin{gathered}
      \mathbf{x}_{t}=(1-\alpha) \mathbf{x}_{t-1}+\alpha\left(W^{\mathrm{rec}} \mathbf{r}_{t-1}+
      W^{\mathrm{in}} \mathbf{u}_{t}\right)+\sqrt{2 \alpha \sigma_{\mathrm{rec}}^{2}} \mathbf{N}(0,1) \\
      \mathbf{r}_{t}=\left[\mathbf{x}_{t}\right]_{+} \\
      \mathbf{z}_{t}=W^{\mathrm{out}} \mathbf{r}_{t}
      \end{gathered}

  where :math:`\alpha = \Delta t/\tau` and :math:`N(0, 1)` are normally distributed
  random numbers with zero mean and unit variance, sampled independently at every time step.
  """

  def __init__(self, num_input, num_hidden, num_output, num_batch,
               dt=None, e_ratio=0.8, sigma_rec=0., seed=None,
               w_ir=bp.init.KaimingUniform(scale=1.),
               w_rr=bp.init.KaimingUniform(scale=1.),
               w_ro=bp.init.KaimingUniform(scale=1.)):
    super().__init__()

    # parameters
    self.tau = 100
    self.num_batch = num_batch
    self.num_input = num_input
    self.num_hidden = num_hidden
    self.num_output = num_output
    self.e_size = int(num_hidden * e_ratio)
    self.i_size = num_hidden - self.e_size
    if dt is None:
      self.alpha = 1
    else:
      self.alpha = dt / self.tau
    self.sigma_rec = (2 * self.alpha) ** 0.5 * sigma_rec  # Recurrent noise
    self.rng = bm.random.RandomState(seed=seed)

    # hidden mask
    mask = np.tile([1] * self.e_size + [-1] * self.i_size, (num_hidden, 1))
    np.fill_diagonal(mask, 0)
    self.mask = bm.asarray(mask, dtype=bm.float_)

    # input weight
    self.w_ir = bm.TrainVar(bp.init.parameter(w_ir, (num_input, num_hidden)))

    # recurrent weight
    bound = 1 / num_hidden ** 0.5
    self.w_rr = bm.TrainVar(bp.init.parameter(w_rr, (num_hidden, num_hidden)))
    self.w_rr[:, :self.e_size] /= (self.e_size / self.i_size)
    self.b_rr = bm.TrainVar(self.rng.uniform(-bound, bound, num_hidden))

    # readout weight
    bound = 1 / self.e_size ** 0.5
    self.w_ro = bm.TrainVar(bp.init.parameter(w_ro, (self.e_size, num_output)))
    self.b_ro = bm.TrainVar(self.rng.uniform(-bound, bound, num_output))

    self.reset_state(self.mode)

  def reset_state(self, batch_size):
    # variables
    self.h = bp.init.variable_(bm.zeros, self.num_hidden, batch_size)
    self.o = bp.init.variable_(bm.zeros, self.num_output, batch_size)

  def cell(self, x, h):
    ins = x @ self.w_ir + h @ (bm.abs(self.w_rr) * self.mask) + self.b_rr
    state = h * (1 - self.alpha) + ins * self.alpha
    state += self.sigma_rec * self.rng.randn(self.num_hidden)
    return bm.relu(state)

  def readout(self, h):
    return h @ self.w_ro + self.b_ro

  def update(self, x):
    self.h.value = self.cell(x, self.h)
    self.o.value = self.readout(self.h[:, :self.e_size])
    return self.h.value, self.o.value

  @bm.cls_jit
  def predict(self, xs):
    self.h[:] = 0.
    return bm.for_loop(self.update, xs)

  def loss(self, xs, ys):
    hs, os = self.predict(xs)
    os = os.reshape((-1, os.shape[-1]))
    return bp.losses.cross_entropy_loss(os, ys.flatten())

Train the network on the decision making task

[6]:
# Instantiate the network and print information
hidden_size = 50
with bm.environment(mode=bm.TrainingMode(batch_size=16)):
  net = RNN(num_input=dataset.num_inputs,
            num_hidden=hidden_size,
            num_output=dataset.num_outputs,
            num_batch=task.batch_size,
            dt=dataset.dt,
            sigma_rec=0.15)
[7]:
# Adam optimizer
opt = bp.optim.Adam(lr=0.001, train_vars=net.train_vars().unique())
[8]:
# gradient function
grad_f = bm.grad(net.loss,
                 grad_vars=net.train_vars().unique(),
                 return_value=True)
[9]:
@bm.jit
def train(xs, ys):
  grads, loss = grad_f(xs, ys)
  opt.update(grads)
  return loss

The training speeds up nearly 4 times, comparing with the original PyTorch codes.

[10]:
running_loss = []
for i_batch in range(20):
  for inputs, labels in task:
    loss = train(inputs, labels)
    running_loss.append(loss)
  print(f'Batch {i_batch}, Loss {np.mean(running_loss):0.4f}')
  running_loss = []
Batch 0, Loss 1.4421
Batch 1, Loss 0.3261
Batch 2, Loss 0.2789
Batch 3, Loss 0.2650
Batch 4, Loss 0.2515
Batch 5, Loss 0.2358
Batch 6, Loss 0.2271
Batch 7, Loss 0.2200
Batch 8, Loss 0.2155
Batch 9, Loss 0.2105
Batch 10, Loss 0.2081
Batch 11, Loss 0.2018
Batch 12, Loss 0.2009
Batch 13, Loss 0.1985
Batch 14, Loss 0.1960
Batch 15, Loss 0.1939
Batch 16, Loss 0.1911
Batch 17, Loss 0.1895
Batch 18, Loss 0.1862
Batch 19, Loss 0.1843

Run the network post-training and record neural activity

[11]:
num_trial = 500
task = bd.cognitive.TaskLoader(dataset, batch_size=num_trial)
inputs, trial_infos = task.get_batch()

net.reset_state(num_trial)
rnn_activity, action_pred = net.predict(inputs)
rnn_activity = np.asarray(rnn_activity)
trial_infos = np.asarray(trial_infos)
[12]:
i_start = int(dataset.t_fixation / dataset.dt)
i_end = int((dataset.t_fixation + dataset.t_stimulus) / dataset.dt)
[13]:
choice1_ind = trial_infos[-1] == 1
choice2_ind = trial_infos[-1] == 2
# response for ground-truth 0 and 1
stim_activity = [rnn_activity[i_start: i_end, choice1_ind],
                 rnn_activity[i_start: i_end, choice2_ind],]

Plot neural activity from sample trials

[14]:
trial = 2
plt.figure(figsize=(8, 6))
_ = plt.plot(rnn_activity[:, trial, :net.e_size], color='blue', label='Excitatory')
_ = plt.plot(rnn_activity[:, trial, net.e_size:], color='red', label='Inhibitory')
plt.xlabel('Time step')
plt.ylabel('Activity')
plt.show()
../_images/recurrent_networks_Song_2016_EI_RNN_22_0.png

Compute stimulus selectivity for sorting neurons

Here for each neuron we compute its stimulus period selectivity \(d'\)

[15]:
mean_activity = []
std_activity = []
for ground_truth in [0, 1]:
  activity = np.reshape(stim_activity[ground_truth],
                        (-1, net.num_hidden))
  mean_activity.append(np.mean(activity, axis=0))
  std_activity.append(np.std(activity, axis=0))

# Compute d'
selectivity = (mean_activity[0] - mean_activity[1])
selectivity /= np.sqrt((std_activity[0] ** 2 + std_activity[1] ** 2 + 1e-7) / 2)

# Sort index for selectivity, separately for E and I
ind_sort = np.concatenate((np.argsort(selectivity[:net.e_size]),
                           np.argsort(selectivity[net.e_size:]) + net.e_size))

Plot network connectivity sorted by stimulus selectivity

[16]:
# Plot distribution of stimulus selectivity
plt.figure(figsize=(6, 4))
plt.hist(selectivity)
plt.xlabel('Selectivity')
plt.ylabel('Number of neurons')
plt.show()
../_images/recurrent_networks_Song_2016_EI_RNN_26_0.png
[17]:
W = (bm.abs(net.w_rr) * net.mask).numpy()
# Sort by selectivity
W = W[:, ind_sort][ind_sort, :]
wlim = np.max(np.abs(W))

plt.figure(figsize=(10, 10))
plt.imshow(W, cmap='bwr_r', vmin=-wlim, vmax=wlim)
plt.colorbar()
plt.xlabel('From neurons')
plt.ylabel('To neurons')
plt.title('Network connectivity')
plt.tight_layout()
plt.show()
../_images/recurrent_networks_Song_2016_EI_RNN_27_0.png