(Bouchacourt & Buschman, 2019) Flexible Working Memory Model
Implementation of:
Bouchacourt, Flora, and Timothy J. Buschman. “A flexible model of working memory.” Neuron 103.1 (2019): 147-160.
Author:
Chaoming Wang (chao.brain@qq.com)
[1]:
import matplotlib.pyplot as plt
import numba
[2]:
import brainpy as bp
dt = bp.math.get_dt()
[3]:
# increase in order to run multiple trials with the same network
num_trials = 1
num_item_to_load = 6
[4]:
# Parameters for network architecture
# ------------------------------------
num_sensory_neuron = 512 # Number of recurrent neurons per sensory network
num_sensory_pool = 8 # Number of ring-like sensory networks
num_all_sensory = num_sensory_pool * num_sensory_neuron
num_all_random = 1024 # Number of neurons in the random network
fI_slope = 0.4 # slope of the non-linear f-I function
bias = 0. # bias in the neuron firing response (cf. page 1, right column of Burak & Fiete, 2012)
tau = 10. # Synaptic time constant [ms]
init_range = 0.01 # Range to randomly initialize synaptic variables
[5]:
# Parameters for sensory network
# -------------------------------
k2 = 0.25 # width of negative surround suppression
k1 = 1. # width of positive amplification
A = 2. # amplitude of weight function
lambda_ = 0.28 # baseline of weight matrix for recurrent network
[6]:
# Parameters for interaction of
# random network <-> sensory network
# -----------------------------------
forward_balance = -1. # if -1, perfect feed-forward balance from SN to RN
backward_balance = -1. # if -1, perfect feedback balance from RN to SN
alpha = 2.1 # parameter used to compute the feedforward weight, before balancing
beta = 0.2 # parameter used to compute the feedback weight, before balancing
gamma = 0.35 # connectivity (gamma in the paper)
factor = 1000 # factor for computing weights values
[7]:
# Parameters for stimulus
# -----------------------
simulation_time = 1100 # the simulation time [ms]
start_stimulation = 100 # [ms]
end_stimulation = 200 # [ms]
input_strength = 10 # strength of the stimulation
num_sensory_input_width = 32
# width (sigma) of the Gaussian input bump, in neurons
sigma = round(num_sensory_neuron / num_sensory_input_width)
three_sigma = 3 * sigma
activity_threshold = 3
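For reference, these values give sigma = round(512 / 32) = 16 neurons, so each stimulated pool receives a Gaussian bump extending roughly ±3σ = ±48 neurons around its center; activity_threshold is compared later against each pool's population-vector amplitude to decide whether a memory is still held.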
[8]:
# Weights initialization
# ----------------------
# weight matrix within sensory network
sensory_encoding = 2. * bp.math.pi * bp.math.arange(1, num_sensory_neuron + 1) / num_sensory_neuron
diff = sensory_encoding.reshape((-1, 1)) - sensory_encoding
weight_mat_of_sensory = lambda_ + A * bp.math.exp(k1 * (bp.math.cos(diff) - 1)) - \
                        A * bp.math.exp(k2 * (bp.math.cos(diff) - 1))
diag = bp.math.arange(num_sensory_neuron)
weight_mat_of_sensory[diag, diag] = 0.
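Written out, the recurrent weight within a sensory pool is a ring-shaped difference of exponentials (narrow excitation minus broader surround suppression on a constant baseline), with self-connections removed:

    W_{ij} = \lambda + A e^{k_1 (\cos(\theta_i - \theta_j) - 1)} - A e^{k_2 (\cos(\theta_i - \theta_j) - 1)}, \qquad W_{ii} = 0,

where \theta_i = 2\pi i / N is the preferred stimulus of neuron i in the pool.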
[9]:
# connectivity matrix between sensory and random network
conn_matrix_sensory2random = bp.math.random.rand(num_all_sensory, num_all_random) < gamma
[10]:
# weight matrix of sensory2random
ws = factor * alpha / conn_matrix_sensory2random.sum(axis=0)
weight_mat_sensory2random = conn_matrix_sensory2random * ws.reshape((1, -1))
ws = weight_mat_sensory2random.sum(axis=0).reshape((1, -1))
weight_mat_sensory2random += forward_balance / num_all_sensory * ws # balance
[11]:
# weight matrix of random2sensory
ws = factor * beta / conn_matrix_sensory2random.sum(axis=1)
weight_mat_random2sensory = conn_matrix_sensory2random.T * ws.reshape((1, -1))
ws = weight_mat_random2sensory.sum(axis=0).reshape((1, -1))
weight_mat_random2sensory += backward_balance / num_all_random * ws # balance
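In both directions the raw weight is factor * alpha (or factor * beta) divided by the target neuron's number of incoming connections, and the "balance" step then subtracts the column mean. With a balance factor of -1 the feed-forward matrix satisfies

    W^{ff}_{ij} \leftarrow W^{ff}_{ij} - \frac{1}{N_{\mathrm{sens}}} \sum_{k} W^{ff}_{kj},

so every column sums to zero and the net drive from the sensory networks onto each random-network neuron is balanced; the feedback matrix is treated analogously with N_{\mathrm{rand}}.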
[12]:
@numba.njit(['float64[:](float64)'])
def get_input(input_center):
    input = bp.math.zeros(num_sensory_neuron)
    inp_ids = bp.math.arange(int(round(input_center - three_sigma)),
                             int(round(input_center + three_sigma + 1)), 1)
    inp_scale = bp.math.exp(-(inp_ids - input_center) ** 2 / 2 / sigma ** 2) / (bp.math.sqrt(2 * bp.math.pi) * sigma)
    inp_scale /= bp.math.max(inp_scale)
    inp_ids = bp.math.remainder(inp_ids - 1, num_sensory_neuron)
    input[inp_ids] = input_strength * inp_scale
    input -= bp.math.sum(input) / input.shape[0]
    return input
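The stimulus for a loaded pool is a Gaussian bump, truncated at ±3σ, scaled so its peak equals input_strength, and finally mean-subtracted so the bump adds no net input to the pool:

    I_i = I_{\max} e^{-(i - c)^2 / (2\sigma^2)} - \frac{1}{N} \sum_j I_j,

where c is the bump center and indices are wrapped around the ring of N = 512 neurons.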
[13]:
def get_activity_vector(rates):
    exp_stim_encoding = bp.math.exp(1j * sensory_encoding, dtype=complex)
    timed_abs = bp.math.zeros(num_sensory_pool)
    timed_angle = bp.math.zeros(num_sensory_pool)
    for si in range(num_sensory_pool):
        start = si * num_sensory_neuron
        end = (si + 1) * num_sensory_neuron
        exp_rates = bp.math.multiply(rates[start:end], exp_stim_encoding, dtype=complex)
        mean_rates = bp.math.mean(exp_rates)
        timed_angle[si] = bp.math.angle(mean_rates) * num_sensory_neuron / (2 * bp.math.pi)
        timed_abs[si] = bp.math.absolute(mean_rates)
    timed_angle[timed_angle < 0] += num_sensory_neuron
    return timed_abs, timed_angle
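Each pool is summarized by the population vector of its firing rates: the angle of the complex mean gives the decoded stimulus (converted back into a neuron index) and its magnitude measures how concentrated the activity bump is, which is later compared against activity_threshold:

    z_p = \frac{1}{N} \sum_{j=1}^{N} r_j e^{i \theta_j}, \qquad \hat{c}_p = \frac{N}{2\pi} \arg z_p, \qquad a_p = |z_p|.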
[14]:
class PoissonNeuron(bp.NeuGroup):
    target_backend = ['numpy', 'numba']

    @bp.odeint(method='exponential_euler')
    def int_s(self, s, t):
        ds = -s / tau
        return ds

    def __init__(self, size, **kwargs):
        super(PoissonNeuron, self).__init__(size=size, **kwargs)
        self.s = bp.math.Variable(bp.math.zeros(self.num))
        self.r = bp.math.Variable(bp.math.zeros(self.num))
        self.input = bp.math.Variable(bp.math.zeros(self.num))
        self.spike = bp.math.Variable(bp.math.zeros(self.num, dtype=bool))

    def update(self, _t, _dt):
        self.s[:] = self.int_s(self.s, _t)
        self.r[:] = 0.4 * (1. + bp.math.tanh(fI_slope * (self.input + self.s + bias) - 3.)) / tau
        self.spike[:] = bp.math.random.random(self.s.shape) < self.r * dt
        self.input[:] = 0.

    def reset(self):
        self.s[:] = bp.math.random.random(self.num) * init_range
        self.r[:] = 0.4 * (1. + bp.math.tanh(fI_slope * (bias + self.s) - 3.)) / tau
        self.input[:] = bp.math.zeros(self.num)
        self.spike[:] = bp.math.zeros(self.num, dtype=bool)
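Each unit is a Poisson neuron: its synaptic drive s decays exponentially, its instantaneous rate follows a saturating tanh f-I curve, and a spike is drawn each time step with probability r * dt:

    \tau \frac{ds}{dt} = -s, \qquad r = \frac{0.4}{\tau} \left(1 + \tanh\bigl(0.4 (I_{\mathrm{ext}} + s + b) - 3\bigr)\right), \qquad P(\mathrm{spike}) = r \, \Delta t,

with fI_slope = 0.4, bias b = 0, and \tau = 10 ms, matching the parameters above.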
[15]:
sensory_net = PoissonNeuron(num_all_sensory, monitors=['r', 'spike'], name='S')
random_net = PoissonNeuron(num_all_random, monitors=['spike'], name='R')
[16]:
class Sen2SenSyn(bp.TwoEndConn):
    target_backend = ['numpy', 'numba']

    def __init__(self, pre, post, **kwargs):
        super(Sen2SenSyn, self).__init__(pre=pre, post=post, **kwargs)

    def update(self, _t, _dt):
        for i in range(num_sensory_pool):
            start = i * num_sensory_neuron
            end = (i + 1) * num_sensory_neuron
            for j in range(num_sensory_neuron):
                if self.pre.spike[start + j]:
                    self.post.s[start: end] += weight_mat_of_sensory[j]


sensory2sensory_conn = Sen2SenSyn(pre=sensory_net, post=sensory_net)
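The double loop above adds the j-th row of weight_mat_of_sensory to a pool's synaptic variable whenever neuron j of that pool spikes. As a hypothetical equivalent formulation (a sketch assuming the NumPy backend, not part of the original code), the same update can be written as a per-pool matrix-vector product:

    # Sketch: vectorized form of Sen2SenSyn.update (assumes NumPy backend)
    def update_vectorized(pre, post):
        for i in range(num_sensory_pool):
            start, end = i * num_sensory_neuron, (i + 1) * num_sensory_neuron
            spikes = pre.spike[start:end].astype(float)            # 0/1 spike vector of pool i
            post.s[start:end] += spikes @ weight_mat_of_sensory    # sums W[j] over all spiking j

Both forms produce the same update; the explicit loop is what the model actually runs.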
[17]:
class OtherSyn(bp.TwoEndConn):
    target_backend = ['numpy', 'numba']

    def __init__(self, pre, post, weights, **kwargs):
        self.weights = weights
        super(OtherSyn, self).__init__(pre=pre, post=post, **kwargs)

    def update(self, _t, _dt):
        for i in range(self.pre.spike.shape[0]):
            if self.pre.spike[i]:
                self.post.s += self.weights[i]


random2sensory_conn = OtherSyn(pre=random_net,
                               post=sensory_net,
                               weights=weight_mat_random2sensory)
sensory2random_conn = OtherSyn(pre=sensory_net,
                               post=random_net,
                               weights=weight_mat_sensory2random)
[18]:
net = bp.Network(sensory_net, random_net, sensory2sensory_conn,
                 random2sensory_conn, sensory2random_conn)
net = bp.math.jit(net)
[19]:
for trial_idx in range(num_trials):
    # inputs
    # ------
    pools_receiving_inputs = bp.math.random.choice(num_sensory_pool, num_item_to_load, replace=False)
    print(f"Load {num_item_to_load} items in trial {trial_idx}.\n")
    input_center = bp.math.ones(num_sensory_pool) * num_sensory_neuron / 2
    inp_vector = bp.math.zeros((num_sensory_pool, num_sensory_neuron))
    for si in pools_receiving_inputs:
        inp_vector[si, :] = get_input(input_center[si])
    inp_vector = inp_vector.flatten()
    Iext, duration = bp.inputs.constant_current([(0., start_stimulation),
                                                 (inp_vector, end_stimulation - start_stimulation),
                                                 (0., simulation_time - end_stimulation)])

    # running
    # -------
    sensory_net.reset()
    random_net.reset()
    net.run(duration, inputs=('S.input', Iext, 'iter'))

    # results
    # --------
    rate_abs, rate_angle = get_activity_vector(sensory_net.mon.r[-1] * 1e3)
    print(f"Stimulus is given in: {bp.math.sort(pools_receiving_inputs)}")
    print(f"Memory is found in: {bp.math.where(rate_abs > activity_threshold)[0]}")

    prob_maintained, prob_spurious = 0, 0
    for si in range(num_sensory_pool):
        if rate_abs[si] > activity_threshold:
            if si in pools_receiving_inputs:
                prob_maintained += 1
            else:
                prob_spurious += 1
    print(str(prob_maintained) + ' maintained memories')
    print(str(pools_receiving_inputs.shape[0] - prob_maintained) + ' forgotten memories')
    print(str(prob_spurious) + ' spurious memories\n')
    prob_maintained /= float(num_item_to_load)
    if num_item_to_load != num_sensory_pool:
        prob_spurious /= float(num_sensory_pool - num_item_to_load)

    # visualization
    # -------------
    fig, gs = bp.visualize.get_figure(6, 1, 1.5, 12)
    xlim = (0, duration)

    fig.add_subplot(gs[0:4, 0])
    bp.visualize.raster_plot(sensory_net.mon.ts, sensory_net.mon.spike, ylabel='Sensory Network', xlim=xlim)
    for index_sn in range(num_sensory_pool + 1):
        plt.axhline(index_sn * num_sensory_neuron)
    plt.yticks([num_sensory_neuron * (i + 0.5) for i in range(num_sensory_pool)],
               [f'pool-{i}' for i in range(num_sensory_pool)])

    fig.add_subplot(gs[4:6, 0])
    bp.visualize.raster_plot(sensory_net.mon.ts, random_net.mon.spike, ylabel='Random Network', xlim=xlim, show=True)
Load 6 items in trial 0.
Stimulus is given in: [0 1 2 3 5 7]
Memory is found in: [0 1 3]
3 maintained memories
3 forgotten memories
0 spurious memories
