(Bouchacourt & Buschman, 2019) Flexible Working Memory Model


Implementation of:

  • Bouchacourt, Flora, and Timothy J. Buschman. “A flexible model of working memory.” Neuron 103.1 (2019): 147-160.

Author:

[1]:
import matplotlib.pyplot as plt
import jax
[2]:
import brainpy as bp
import brainpy.math as bm
[3]:
bp.__version__
[3]:
'2.4.3'
[4]:
# increase this value to run multiple trials with the same network
num_trials = 1
num_item_to_load = 6  # number of items (stimulated sensory pools) to load per trial
[5]:
# Parameters for network architecture
# ------------------------------------

num_sensory_neuron = 512  # Number of recurrent neurons per sensory network
num_sensory_pool = 8  # Number of ring-like sensory networks
num_all_sensory = num_sensory_pool * num_sensory_neuron
num_all_random = 1024  # Number of neurons in the random network
fI_slope = 0.4  # slope of the non-linear f-I function
bias = 0.  # bias in the neuron firing response (cf. Burak & Fiete, 2012, p. 1, right column)
tau = 10.  # Synaptic time constant [ms]
init_range = 0.01  # Range to randomly initialize synaptic variables
[6]:
# Parameters for sensory network
# -------------------------------

k2 = 0.25  # width of negative surround suppression
k1 = 1.  # width of positive amplification
A = 2.  # amplitude of weight function
lambda_ = 0.28  # baseline of weight matrix for recurrent network
[7]:
# Parameters for interaction of
# random network <-> sensory network
# -----------------------------------

forward_balance = -1.  # if -1, perfect feed-forward balance from SN to RN
backward_balance = -1.  # if -1, perfect feedback balance from RN to SN
alpha = 2.1  # parameter used to compute the feedforward weight, before balancing
beta = 0.2  # parameter used to compute the feedback weight, before balancing
gamma = 0.35  # connection probability between sensory and random neurons (gamma in the paper)
factor = 1000  # scaling factor for the weight values
[8]:
# Parameters for stimulus
# -----------------------

simulation_time = 1100  # total simulation time [ms]
start_stimulation = 100  # [ms]
end_stimulation = 200  # [ms]
input_strength = 10  # strength of the stimulation
num_sensory_input_width = 32  # controls the width (sigma) of the Gaussian input bump
sigma = round(num_sensory_neuron / num_sensory_input_width)
three_sigma = 3 * sigma
activity_threshold = 3  # population-vector amplitude above which a pool counts as holding a memory
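With these values, $\sigma = \mathrm{round}(512 / 32) = 16$ neurons, so each stimulated pool receives input over the $\pm 3\sigma = \pm 48$ neurons around the stimulus center (97 neurons in total).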
[9]:
# Weights initialization
# ----------------------

# weight matrix within sensory network
sensory_encoding = 2. * bm.pi * bm.arange(1, num_sensory_neuron + 1) / num_sensory_neuron
diff = sensory_encoding.reshape((-1, 1)) - sensory_encoding
weight_mat_of_sensory = lambda_ + A * bm.exp(k1 * (bm.cos(diff) - 1)) - A * bm.exp(k2 * (bm.cos(diff) - 1))
diag = bm.arange(num_sensory_neuron)
weight_mat_of_sensory[diag, diag] = 0.
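The weight matrix above implements the ring-like connectivity of a single sensory network. With preferred positions $\theta_j = 2\pi j / N$ ($N$ = `num_sensory_neuron`), the connection from neuron $j$ to neuron $i$ is

$$W_{ij} = \lambda + A\,e^{k_1(\cos(\theta_i - \theta_j) - 1)} - A\,e^{k_2(\cos(\theta_i - \theta_j) - 1)}, \qquad W_{ii} = 0,$$

where the narrower $k_1$ term provides local amplification and the broader $k_2$ term provides surround suppression, so each pool can sustain a localized bump of activity.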
[10]:
# connectivity matrix between sensory and random network
conn_matrix_sensory2random = bm.random.rand(num_all_sensory, num_all_random) < gamma
[11]:
# weight matrix of sensory2random
ws = factor * alpha / conn_matrix_sensory2random.sum(axis=0)
weight_mat_sensory2random = conn_matrix_sensory2random * ws.reshape((1, -1))
ws = weight_mat_sensory2random.sum(axis=0).reshape((1, -1))
weight_mat_sensory2random += forward_balance / num_all_sensory * ws  # balance
[12]:
# weight matrix of random2sensory
ws = factor * beta / conn_matrix_sensory2random.sum(axis=1)
weight_mat_random2sensory = conn_matrix_sensory2random.T * ws.reshape((1, -1))
ws = weight_mat_random2sensory.sum(axis=0).reshape((1, -1))
weight_mat_random2sensory += backward_balance / num_all_random * ws  # balance
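In both projections, each nonzero weight is first set to `factor * alpha` (feedforward) or `factor * beta` (feedback) divided by the postsynaptic neuron's number of presynaptic partners. The last line of each cell then applies the balance: with `forward_balance = backward_balance = -1`, the column mean is subtracted,

$$\tilde{W}_{ij} = W_{ij} - \frac{1}{N_{\mathrm{pre}}} \sum_{k=1}^{N_{\mathrm{pre}}} W_{kj},$$

so every column of the balanced matrix sums to zero: the net drive a postsynaptic neuron receives when the whole presynaptic population is active is zero.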
[13]:
@jax.jit
def f(inp_ids, center):
  # Gaussian profile, rescaled so that its peak equals ``input_strength``
  inp_scale = bm.exp(-(inp_ids - center) ** 2 / 2 / sigma ** 2) / (bm.sqrt(2 * bm.pi) * sigma)
  inp_scale /= bm.max(inp_scale)
  # wrap the (1-based) neuron indices around the ring
  inp_ids = bm.remainder(inp_ids - 1, num_sensory_neuron)
  input = bm.zeros(num_sensory_neuron)
  input[inp_ids] = input_strength * inp_scale
  # subtract the mean so the total input to the pool sums to zero
  input -= bm.sum(input) / num_sensory_neuron
  return input


def get_input(center):
  # stimulate only the neurons within +/- 3 sigma of the stimulus center
  inp_ids = bm.arange(bm.asarray(center - three_sigma, dtype=bm.int32),
                      bm.asarray(center + three_sigma + 1, dtype=bm.int32),
                      1)
  return f(inp_ids, center)
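`get_input` therefore builds a mean-subtracted Gaussian bump of input: the neurons within $\pm 3\sigma$ of the stimulus center $c$ receive

$$I_i = I_{\mathrm{strength}}\, e^{-(i - c)^2 / (2\sigma^2)},$$

with the peak rescaled to `input_strength`, and the pool average is then removed so that the total input delivered to the pool sums to zero.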
[14]:
def get_activity_vector(rates):
  # Population-vector readout of each sensory pool.
  exp_stim_encoding = bm.exp(1j * sensory_encoding)
  timed_abs = bm.zeros(num_sensory_pool)
  timed_angle = bm.zeros(num_sensory_pool)
  for si in range(num_sensory_pool):
    start = si * num_sensory_neuron
    end = (si + 1) * num_sensory_neuron
    # project the pool's firing rates onto the unit circle and average
    exp_rates = bm.multiply(rates[start:end], exp_stim_encoding)
    mean_rates = bm.mean(exp_rates)
    # decoded bump position (in units of neuron index) and bump amplitude
    timed_angle[si] = bm.angle(mean_rates) * num_sensory_neuron / (2 * bm.pi)
    timed_abs[si] = bm.absolute(mean_rates)
  timed_angle[timed_angle < 0] += num_sensory_neuron
  return timed_abs, timed_angle
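`get_activity_vector` summarizes each pool $p$ with a population vector

$$z_p = \frac{1}{N} \sum_{j=1}^{N} r_j\, e^{i\theta_j},$$

whose amplitude $|z_p|$ (compared later with `activity_threshold`) indicates whether the pool holds a bump of persistent activity, and whose angle $\arg(z_p)$, converted back to a neuron index, gives the remembered position.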
[15]:
class PoissonNeuron(bp.NeuGroup):
  def __init__(self, size, **kwargs):
    super().__init__(size=size, **kwargs)

    self.s = bm.Variable(bm.zeros(self.num))
    self.r = bm.Variable(bm.zeros(self.num))
    self.input = bm.Variable(bm.zeros(self.num))
    self.spike = bm.Variable(bm.zeros(self.num, dtype=bool))
    self.rng = bm.random.RandomState()

    self.int_s = bp.odeint(lambda s, t: -s / tau, method='exp_euler')

  def update(self):
    tdi = bp.share.get_shargs()
    # exponential decay of the synaptic variable
    self.s.value = self.int_s(self.s, tdi.t, tdi.dt)
    # saturating f-I curve: map the total drive to an instantaneous firing rate
    self.r.value = 0.4 * (1. + bm.tanh(fI_slope * (self.input + self.s + bias) - 3.)) / tau
    # Poisson spiking: a spike is emitted with probability r * dt
    self.spike.value = self.rng.random(self.s.shape) < self.r * tdi.dt
    # clear the external input for the next time step
    self.input[:] = 0.

  def reset_state(self, batch_size=None):
    self.s.value = self.rng.random(self.num) * init_range
    self.r.value = 0.4 * (1. + bm.tanh(fI_slope * (bias + self.s) - 3.)) / tau
    self.input.value = bm.zeros(self.num)
    self.spike.value = bm.zeros(self.num, dtype=bool)
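Each unit is a Poisson spiking neuron: its synaptic drive decays with time constant $\tau$, its instantaneous rate follows the saturating f-I curve used in `update`,

$$\frac{ds}{dt} = -\frac{s}{\tau}, \qquad r = \frac{0.4\,\bigl(1 + \tanh(\mathrm{fI\_slope}\,(I_{\mathrm{ext}} + s + \mathrm{bias}) - 3)\bigr)}{\tau},$$

and a spike is drawn in each time step with probability $r\,\Delta t$.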
[16]:
class Sen2SenSyn(bp.dyn.SynConn):
  def __init__(self, pre, post, **kwargs):
    super().__init__(pre=pre, post=post, **kwargs)

  def update(self):
    # the recurrent weights act only within each sensory pool
    for i in range(num_sensory_pool):
      start = i * num_sensory_neuron
      end = (i + 1) * num_sensory_neuron
      self.post.s[start: end] += bm.dot(self.pre.spike[start: end],
                                        weight_mat_of_sensory)
[17]:
class OtherSyn(bp.dyn.SynConn):
  def __init__(self, pre, post, weights, **kwargs):
    super().__init__(pre=pre, post=post, **kwargs)
    self.weights = weights

  def update(self):
    self.post.s += bm.dot(self.pre.spike, self.weights)
[18]:
class Network(bp.DynSysGroup):
  def __init__(self):
    super().__init__()

    self.sensory = PoissonNeuron(num_all_sensory)
    self.random = PoissonNeuron(num_all_random)
    self.sensory2sensory = Sen2SenSyn(pre=self.sensory, post=self.sensory)
    self.random2sensory = OtherSyn(pre=self.random,
                                   post=self.sensory,
                                   weights=weight_mat_random2sensory)
    self.sensory2random = OtherSyn(pre=self.sensory,
                                   post=self.random,
                                   weights=weight_mat_sensory2random)
[19]:
for trial_idx in range(num_trials):
  # inputs
  # ------
  pools_receiving_inputs = bm.random.choice(num_sensory_pool, num_item_to_load, replace=False)
  print(f"Load {num_item_to_load} items in trial {trial_idx}.\n")

  input_center = bm.ones(num_sensory_pool) * num_sensory_neuron / 2
  inp_vector = bm.zeros((num_sensory_pool, num_sensory_neuron))
  for si in pools_receiving_inputs:
    inp_vector[si, :] = get_input(input_center[si])
  inp_vector = inp_vector.flatten()
  Iext, duration = bp.inputs.constant_input(
    [(0., start_stimulation),
     (inp_vector, end_stimulation - start_stimulation),
     (0., simulation_time - end_stimulation)]
  )

  # running
  # -------
  net = Network()
  runner = bp.DSRunner(net,
                       inputs=(net.sensory.input, Iext, 'iter'),
                       monitors={'S.r': net.sensory.r,
                                 'S.spike': net.sensory.spike,
                                 'R.spike': net.random.spike})
  runner.predict(duration, reset_state=True)

  # results
  # --------

  rate_abs, rate_angle = get_activity_vector(runner.mon['S.r'][-1] * 1e3)
  print(f"Stimulus is given in: {bm.sort(pools_receiving_inputs)}")
  print(f"Memory is found in: {bm.where(rate_abs > activity_threshold)[0]}")

  prob_maintained, prob_spurious = 0, 0
  for si in range(num_sensory_pool):
    if rate_abs[si] > activity_threshold:
      if si in pools_receiving_inputs:
        prob_maintained += 1
      else:
        prob_spurious += 1
  print(str(prob_maintained) + ' maintained memories')
  print(str(pools_receiving_inputs.shape[0] - prob_maintained) + ' forgotten memories')
  print(str(prob_spurious) + ' spurious memories\n')
  prob_maintained /= float(num_item_to_load)
  if num_item_to_load != num_sensory_pool:
    prob_spurious /= float(num_sensory_pool - num_item_to_load)

  # visualization
  # -------------
  fig, gs = bp.visualize.get_figure(6, 1, 1.5, 12)
  xlim = (0, duration)

  fig.add_subplot(gs[0:4, 0])
  bp.visualize.raster_plot(runner.mon.ts,
                           runner.mon['S.spike'],
                           ylabel='Sensory Network', xlim=xlim)
  for index_sn in range(num_sensory_pool + 1):
    plt.axhline(index_sn * num_sensory_neuron)
  plt.yticks([num_sensory_neuron * (i + 0.5) for i in range(num_sensory_pool)],
             [f'pool-{i}' for i in range(num_sensory_pool)])

  fig.add_subplot(gs[4:6, 0])
  bp.visualize.raster_plot(runner.mon.ts,
                           runner.mon['R.spike'],
                           ylabel='Random Network', xlim=xlim, show=True)
Load 6 items in trial 0.

Stimulus is given in: [1 2 3 4 5 7]
Memory is found in: [1 3 5]
3 maintained memories
3 forgotten memories
0 spurious memories

[Figure: spike rasters of the sensory network (top, pools 0-7 separated by horizontal lines) and the random network (bottom) over the 1100 ms trial.]
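To estimate the load-dependent probabilities of maintained and spurious memories, as in the paper, `num_trials` can be increased and the per-trial outcomes averaged. A minimal sketch (not part of the original notebook), assuming the two lists below are created before the trial loop and appended with `prob_maintained` and `prob_spurious` at the end of each iteration:

maintained_fracs = []  # append prob_maintained here at the end of every trial
spurious_fracs = []    # append prob_spurious here at the end of every trial

# after the trial loop:
if maintained_fracs:
  print(f'P(maintained memory) = {sum(maintained_fracs) / len(maintained_fracs):.3f}')
  print(f'P(spurious memory)   = {sum(spurious_fracs) / len(spurious_fracs):.3f}')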