Train RNN to Solve Parametric Working Memory
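In this tutorial we build and train a recurrent neural network (RNN) with BrainPy on a parametric working memory task. In each trial, two stimuli separated by a variable delay are presented in sequence, and the network must report which of the two was stronger; to answer correctly, it has to hold the value of the first stimulus in memory throughout the delay. After training, we use principal component analysis (PCA) to visualize how the hidden dynamics implement this memory.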

[1]:
import brainpy as bp
import brainpy.math as bm
bp.math.set_platform('cpu')
[2]:
import numpy as np
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA
[3]:
# We will import the task from the neurogym library.
# Please install neurogym:
#
# https://github.com/neurogym/neurogym

import neurogym as ngym
[4]:
# Environment
task = 'DelayComparison-v0'
timing = {'delay': ('choice', [200, 400, 800, 1600, 3200]),
          'response': ('constant', 500)}
kwargs = {'dt': 100, 'timing': timing}
seq_len = 100

# Make supervised dataset
dataset = ngym.Dataset(task,
                       env_kwargs=kwargs,
                       batch_size=16,
                       seq_len=seq_len)

# A sample environment from dataset
env = dataset.env
# Visualize the environment with 2 sample trials
_ = ngym.utils.plot_env(env, num_trials=2, def_act=0, fig_kwargs={'figsize': (8, 6)})
plt.show()
[Figure: two sample trials of the DelayComparison-v0 environment]
[5]:
input_size = env.observation_space.shape[0]
output_size = env.action_space.n
batch_size = dataset.batch_size
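Before defining the network, it helps to know the layout of one batch from the dataset. The sketch below assumes neurogym's default sequence-first convention, i.e. inputs of shape (seq_len, batch_size, input_size) and labels of shape (seq_len, batch_size):

[ ]:
# Sanity-check the data layout (sequence-first by default):
inputs, labels = dataset()
print(inputs.shape)  # expected: (seq_len, batch_size, input_size)
print(labels.shape)  # expected: (seq_len, batch_size)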
[6]:
class RNN(bp.DynamicalSystem):
  def __init__(self, num_input, num_hidden, num_output, num_batch,
               w_ir=bp.init.KaimingNormal(scale=1.),
               w_rr=bp.init.KaimingNormal(scale=1.),
               w_ro=bp.init.KaimingNormal(scale=1.),
               dt=None, seed=None):
    super(RNN, self).__init__()

    # parameters
    self.tau = 100
    self.num_batch = num_batch
    self.num_input = num_input
    self.num_hidden = num_hidden
    self.num_output = num_output
    if dt is None:
      self.alpha = 1
    else:
      self.alpha = dt / self.tau
    self.rng = bm.random.RandomState(seed=seed)

    # input weight
    self.w_ir = bm.TrainVar(bp.init.parameter(w_ir, size=(num_input, num_hidden)))

    # recurrent weight
    bound = 1 / num_hidden ** 0.5
    self.w_rr = bm.TrainVar(bp.init.parameter(w_rr, size=(num_hidden, num_hidden)))
    self.b_rr = bm.TrainVar(self.rng.uniform(-bound, bound, num_hidden))

    # readout weight
    self.w_ro = bm.TrainVar(bp.init.parameter(w_ro, size=(num_hidden, num_output)))
    self.b_ro = bm.TrainVar(self.rng.uniform(-bound, bound, num_output))

    # variables
    self.h = bm.Variable(bm.zeros((num_batch, num_hidden)))
    self.o = bm.Variable(bm.zeros((num_batch, num_output)))

  def cell(self, x, h):
    ins = x @ self.w_ir + h @ self.w_rr + self.b_rr
    state = h * (1 - self.alpha) + ins * self.alpha
    return bm.relu(state)

  def readout(self, h):
    return h @ self.w_ro + self.b_ro

  def update(self, x):
    self.h.value = self.cell(x, self.h.value)
    self.o.value = self.readout(self.h.value)
    return self.h.value, self.o.value

  def predict(self, xs):
    self.h[:] = 0.
    return bm.for_loop(self.update, xs)

  def loss(self, xs, ys):
    hs, os = self.predict(xs)
    os = os.reshape((-1, os.shape[-1]))
    loss = bp.losses.cross_entropy_loss(os, ys.flatten())
    return loss, os
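Each call to cell computes a leaky mixture of the previous hidden state and the new recurrent drive, followed by a ReLU: h <- ReLU((1 - alpha) * h + alpha * (x W_ir + h W_rr + b_rr)), with alpha = dt / tau. Since the environment uses dt = 100 ms and tau = 100 ms, alpha = 1 here and the update reduces to a standard discrete-time RNN step. Below is a minimal smoke test of the update shapes (the sizes are arbitrary and purely illustrative):

[ ]:
# Hypothetical smoke test: one update step maps an input of shape
# (num_batch, num_input) to a hidden state and an output.
test_net = RNN(num_input=3, num_hidden=8, num_output=2, num_batch=4, dt=100.)
h, o = test_net.update(bm.zeros((4, 3)))
print(h.shape, o.shape)  # expected: (4, 8) (4, 2)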
[7]:
# Instantiate the network
hidden_size = 64
net = RNN(num_input=input_size,
          num_hidden=hidden_size,
          num_output=output_size,
          num_batch=batch_size,
          dt=env.dt)
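To get a feel for the model size, one can count the trainable parameters (a hypothetical addition; it assumes train_vars() returns a dict-like collection of TrainVar objects, as in BrainPy's Collector):

[ ]:
# Hypothetical: total number of trainable parameters.
n_params = sum(int(np.prod(v.shape)) for v in net.train_vars().unique().values())
print('Trainable parameters:', n_params)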
[8]:
# JIT-compile the prediction function for fast repeated evaluation
predict = bm.jit(net.predict)
[9]:
# Adam optimizer
opt = bp.optim.Adam(lr=0.001, train_vars=net.train_vars().unique())
[10]:
# Gradient function: with return_value=True and has_aux=True, a call returns
# (gradients, loss value, auxiliary outputs), matching its use in train below.
grad = bm.grad(net.loss,
               child_objs=net,
               grad_vars=net.train_vars().unique(),
               return_value=True,
               has_aux=True)
[11]:
@bm.jit
@bm.to_object(child_objs=(grad, opt))
def train(xs, ys):
  grads, loss, os = grad(xs, ys)
  opt.update(grads)
  return loss, os
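Wrapping the step with bm.to_object gathers the variables owned by the gradient function and the optimizer, so that the whole training step (forward pass, backward pass, and parameter update) can be jit-compiled as one function.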
[12]:
running_acc = 0
running_loss = 0
for i in range(2000):
  inputs, labels_np = dataset()
  inputs = bm.asarray(inputs)
  labels = bm.asarray(labels_np)
  loss, outputs = train(inputs, labels)
  running_loss += loss
  # Compute performance
  output_np = np.argmax(bm.as_numpy(outputs), axis=-1).flatten()
  labels_np = labels_np.flatten()
  ind = labels_np > 0  # Only analyze time points when target is not fixation
  running_acc += np.mean(labels_np[ind] == output_np[ind])
  if i % 100 == 99:
    running_loss /= 100
    running_acc /= 100
    print('Step {}, Loss {:0.4f}, Acc {:0.3f}'.format(i + 1, running_loss, running_acc))
    running_loss = 0
    running_acc = 0
Step 100, Loss 0.1653, Acc 0.149
Step 200, Loss 0.0321, Acc 0.674
Step 300, Loss 0.0238, Acc 0.764
Step 400, Loss 0.0187, Acc 0.825
Step 500, Loss 0.0150, Acc 0.857
Step 600, Loss 0.0124, Acc 0.884
Step 700, Loss 0.0099, Acc 0.909
Step 800, Loss 0.0110, Acc 0.894
Step 900, Loss 0.0092, Acc 0.912
Step 1000, Loss 0.0089, Acc 0.913
Step 1100, Loss 0.0080, Acc 0.926
Step 1200, Loss 0.0079, Acc 0.921
Step 1300, Loss 0.0079, Acc 0.923
Step 1400, Loss 0.0075, Acc 0.924
Step 1500, Loss 0.0080, Acc 0.921
Step 1600, Loss 0.0071, Acc 0.930
Step 1700, Loss 0.0071, Acc 0.930
Step 1800, Loss 0.0074, Acc 0.920
Step 1900, Loss 0.0069, Acc 0.928
Step 2000, Loss 0.0071, Acc 0.926
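Before analyzing the dynamics, we can check the trained network on a freshly sampled batch without updating the weights (a minimal sketch reusing the jitted predict function; as during training, only non-fixation time points are scored):

[ ]:
# Evaluate on a fresh batch; no gradient step is taken.
inputs, labels = dataset()
_, outputs = predict(bm.asarray(inputs))
pred = np.argmax(bm.as_numpy(outputs), axis=-1).flatten()
labels = labels.flatten()
ind = labels > 0  # ignore fixation time points
print('Accuracy: {:.3f}'.format(np.mean(labels[ind] == pred[ind])))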
[13]:
def run(num_trial=1):
  env.reset(no_step=True)
  activity_dict = {}
  trial_infos = {}
  for i in range(num_trial):
    env.new_trial()
    ob = env.ob
    inputs = bm.asarray(ob[:, np.newaxis, :])
    rnn_activity, action_pred = predict(inputs)
    rnn_activity = bm.as_numpy(rnn_activity)[:, 0, :]
    activity_dict[i] = rnn_activity
    trial_infos[i] = env.trial

  # Concatenate activity across trials for PCA
  activity = np.concatenate([activity_dict[i] for i in range(num_trial)], axis=0)
  print('Shape of the neural activity: (Time points, Neurons): ', activity.shape)

  # Print trial information
  for i in range(5):
    if i >= num_trial: break
    print('Trial ', i, trial_infos[i])

  # Fit PCA on the pooled activity of all trials
  pca = PCA(n_components=2)
  pca.fit(activity)

  fig, (ax1, ax2) = plt.subplots(1, 2, sharey=True, sharex=True, figsize=(6, 3))
  for i in range(num_trial):
    activity_pc = pca.transform(activity_dict[i])
    trial = trial_infos[i]
    # red: ground truth 1 (first stimulus larger); blue: ground truth 2
    color = 'red' if trial['ground_truth'] == 1 else 'blue'
    _ = ax1.plot(activity_pc[:, 0], activity_pc[:, 1], 'o-', color=color)
    if i < 3:
      _ = ax2.plot(activity_pc[:, 0], activity_pc[:, 1], 'o-', color=color)
  ax1.set_xlabel('PC 1')
  ax1.set_ylabel('PC 2')
  plt.show()
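In the figure produced by run, the left panel overlays all requested trials in PC space and the right panel repeats only the first three for readability. To see how much of the activity the two components actually capture, one can inspect the explained variance (a hypothetical addition that mirrors the collection loop inside run; explained_variance_ratio_ is a standard scikit-learn attribute):

[ ]:
# Hypothetical check: variance captured by the first two PCs.
env.reset(no_step=True)
acts = []
for _ in range(50):
  env.new_trial()
  h, _ = predict(bm.asarray(env.ob[:, np.newaxis, :]))
  acts.append(bm.as_numpy(h)[:, 0, :])
pca = PCA(n_components=2)
pca.fit(np.concatenate(acts, axis=0))
print(pca.explained_variance_ratio_)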
[14]:
run(num_trial=1)
Shape of the neural activity: (Time points, Neurons):  (32, 64)
Trial  0 {'ground_truth': 1, 'vpair': (26, 18), 'v1': 26, 'v2': 18}
[Figure: hidden-state trajectory of 1 trial projected onto PC 1 and PC 2]
[15]:
run(num_trial=20)
Shape of the neural activity: (Time points, Neurons):  (562, 64)
Trial  0 {'ground_truth': 2, 'vpair': (26, 18), 'v1': 18, 'v2': 26}
Trial  1 {'ground_truth': 1, 'vpair': (26, 18), 'v1': 26, 'v2': 18}
Trial  2 {'ground_truth': 2, 'vpair': (26, 18), 'v1': 18, 'v2': 26}
Trial  3 {'ground_truth': 2, 'vpair': (30, 22), 'v1': 22, 'v2': 30}
Trial  4 {'ground_truth': 1, 'vpair': (18, 10), 'v1': 18, 'v2': 10}
[Figure: hidden-state trajectories of 20 trials projected onto PC 1 and PC 2]
[16]:
run(num_trial=100)
Shape of the neural activity: (Time points, Neurons):  (2778, 64)
Trial  0 {'ground_truth': 2, 'vpair': (30, 22), 'v1': 22, 'v2': 30}
Trial  1 {'ground_truth': 1, 'vpair': (26, 18), 'v1': 26, 'v2': 18}
Trial  2 {'ground_truth': 1, 'vpair': (26, 18), 'v1': 26, 'v2': 18}
Trial  3 {'ground_truth': 2, 'vpair': (22, 14), 'v1': 14, 'v2': 22}
Trial  4 {'ground_truth': 1, 'vpair': (22, 14), 'v1': 22, 'v2': 14}
[Figure: hidden-state trajectories of 100 trials projected onto PC 1 and PC 2]
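If training succeeded, the red trajectories (ground truth 1, first stimulus larger) and the blue trajectories (ground truth 2, second stimulus larger) should form separated bundles in PC space, showing that the network holds a parametric memory of the first stimulus across the delay.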