(Mi et al., 2017) STP for Working Memory Capacity

Implementation of the paper:

  • Mi, Yuanyuan, Mikhail Katkov, and Misha Tsodyks. “Synaptic correlates of working memory capacity.” Neuron 93.2 (2017): 323-330.

Author:

[1]:
import brainpy as bp
import brainpy.math as bm
[2]:
import matplotlib.colors as mcolors
import matplotlib.pyplot as plt
[3]:
# simulation time step [s]
dt = 1e-4
bm.set_dt(dt=dt)
[4]:
alpha = 1.5  # gain parameter of the smooth threshold-linear rate function
J_EE = 8.  # the connection strength in each excitatory neural clusters
J_IE = 1.75  # Synaptic efficacy E → I
J_EI = 1.1  # Synaptic efficacy I → E
tau_f = 1.5  # time constant of STF  [s]
tau_d = .3  # time constant of STD  [s]  (tau_f >> tau_d gives the facilitation-dominated regime)
U = 0.3  # minimum STF value
tau = 0.008  # time constant of firing rate of the excitatory neurons [s]
tau_I = tau  # time constant of firing rate of the inhibitory neurons

Ib = 8.  # background input and external input
Iinh = 0.  # the background input of inhibitory neuron

cluster_num = 16  # the number of the clusters
[5]:
# the parameters of external input
# (stimuli are presented sequentially: each one lasts Ts_duration and is
# separated from the next by Ts_interval)

stimulus_num = 5  # how many clusters receive a transient loading stimulus
Iext_train = 225  # the strength of the external input
Ts_interval = 0.070  # the time interval between the consequent external input [s]
Ts_duration = 0.030  # the time duration of the external input [s]
duration = 2.500  # total simulation time [s]
[6]:
# the excitatory cluster model and the inhibitory pool model

class WorkingMemoryModel(bp.NeuGroup):
  """Firing-rate working-memory model with short-term synaptic plasticity.

  Implements Mi, Katkov & Tsodyks (2017, Neuron): ``cluster_num`` excitatory
  clusters, each with synaptic facilitation ``u`` and depression ``x``,
  coupled through a single shared inhibitory pool (current ``hi``).
  """

  def __init__(self, size, **kwargs):
    super(WorkingMemoryModel, self).__init__(size, **kwargs)

    # State variables (module-level constants cluster_num, U are used here).
    self.hi = bm.Variable(bm.asarray([0.]))  # inhibitory-pool synaptic current (scalar)
    self.u = bm.Variable(bm.ones(cluster_num) * U)  # STF variable, starts at baseline U
    self.x = bm.Variable(bm.ones(cluster_num))  # STD variable, starts fully recovered (= 1)
    self.h = bm.Variable(bm.zeros(cluster_num))  # excitatory synaptic current per cluster
    self.input = bm.Variable(bm.zeros(cluster_num))  # external input, cleared every step
    self.inh_r = bm.Variable(self.log(self.hi))  # inhibitory rate g(hi); nonzero even at hi=0
    self.r = bm.Variable(self.log(self.h))  # excitatory rates g(h)

    self.integral = bp.odeint(self.derivative)

  def derivative(self, u, x, h, hi, t, r, r_inh, Iext):
    # NOTE(review): argument order matters — BrainPy's odeint treats the
    # names before `t` as integrated state variables and those after as
    # extra parameters, so do not rename/reorder these arguments.
    du = (U - u) / tau_f + U * (1 - u) * r  # facilitation: relax to U, grow with activity
    dx = (1 - x) / tau_d - u * x * r  # depression: recover to 1, deplete with use
    dh = (-h + J_EE * u * x * r - J_EI * r_inh + Iext + Ib) / tau  # recurrent E, inhibition, inputs
    dhi = (-hi + J_IE * bm.sum(r) + Iinh) / tau_I  # pool driven by summed excitatory rates
    return du, dx, dh, dhi

  def log(self, h):
    # Smooth threshold-linear gain: g(h) = alpha * log(1 + exp(h / alpha)).
    return alpha * bm.log(1. + bm.exp(h / alpha))

  def update(self, _t, _dt):
    # One integration step, then refresh the rates and clear the
    # one-step external input (it is re-filled by the runner each step).
    self.u[:], self.x[:], self.h[:], self.hi[:] = self.integral(
      self.u, self.x, self.h, self.hi, _t,
      self.r, self.inh_r, self.input)
    self.r[:] = self.log(self.h)
    self.inh_r[:] = self.log(self.hi)
    self.input[:] = 0.
[7]:
# the external input: stimulus i is ON during
# [Ts_interval + i*period, Ts_interval + i*period + Ts_duration)

num_steps = int(duration / dt)
I_inputs = bm.zeros((num_steps, cluster_num))
period = Ts_interval + Ts_duration
for stim in range(stimulus_num):
  onset = Ts_interval + period * stim
  offset = onset + Ts_duration
  I_inputs[int(onset / dt): int(offset / dt), stim] = Iext_train
[8]:
# network running (monitor u, x, r, h for the plots below)

model = WorkingMemoryModel(cluster_num, monitors=['u', 'x', 'r', 'h'])
model.run(duration, inputs=['input', I_inputs, 'iter'])
[8]:
1.0426909923553467
[9]:
# visualization

# Color names reused across panels so each cluster keeps one color.
colors = list(dict(mcolors.BASE_COLORS, **mcolors.CSS4_COLORS).keys())

fig, gs = bp.visualize.get_figure(5, 1, 2, 12)

# panel 1: firing rates r of the stimulated clusters
fig.add_subplot(gs[0, 0])
for i in range(stimulus_num):
  plt.plot(model.mon.ts, model.mon.r[:, i], label='Cluster-{}'.format(i))
plt.ylabel("$r (Hz)$")
plt.legend(loc='right')

# panel 2: effective recurrent efficacy J_EE * u * x
fig.add_subplot(gs[1, 0])
hist_Jux = J_EE * model.mon.u * model.mon.x
for i in range(stimulus_num):
  plt.plot(model.mon.ts, hist_Jux[:, i])
plt.ylabel("$J_{EE}ux$")

# panel 3: facilitation variable u
fig.add_subplot(gs[2, 0])
for i in range(stimulus_num):
  plt.plot(model.mon.ts, model.mon.u[:, i], colors[i])
plt.ylabel('u')

# panel 4: depression variable x
fig.add_subplot(gs[3, 0])
for i in range(stimulus_num):
  plt.plot(model.mon.ts, model.mon.x[:, i], colors[i])
plt.ylabel('x')

# panel 5: synaptic currents h
# Bug fix: this panel previously plotted model.mon.r again, duplicating
# panel 1 while being labeled 'h'.
fig.add_subplot(gs[4, 0])
for i in range(stimulus_num):
  plt.plot(model.mon.ts, model.mon.h[:, i], colors[i])
plt.ylabel('h')
plt.xlabel('time [s]')

plt.show()
../_images/working_memory_Mi_2017_working_memory_capacity_10_0.png