(Laje & Buonomano, 2013) Simulation in RNN
[1]:
import time
[2]:
import brainpy.math as bm
import matplotlib.patches as patches
import matplotlib.pyplot as plt
import scipy.stats
from mpl_toolkits.axes_grid1 import make_axes_locatable
from scipy.io import loadmat
[3]:
__all__ = [
'fig1',
'fig2',
'fig3',
]
[4]:
dt = 1.
time2len = lambda t: int(t / dt)
[5]:
def fig1(net, verbose=True):
# Parameters
# ----------
num_rec_train = 30 # Number of learning trials for the recurrent weights
num_readout_train = 10 # Number of learning trials for the readout weights
num_perturbation = 5 # Number of perturbation trials
stimulus_amplitude = 5.0 # Amplitude of the input pulse
t_offset = time2len(200) # Time to wait before the stimulation
d_stim = time2len(50) # Duration of the stimulation
d_trajectory = time2len(2000 + 150) # Duration of the desired target_traj
t_relax = time2len(550) # Duration to relax after the target_traj
trial_duration = t_offset + d_stim + d_trajectory + t_relax # Total duration of a trial
times = bm.arange(0, trial_duration, dt)
perturbation_amplitude = 0.5 # Amplitude of the perturbation pulse
t_perturb = time2len(500) # Offset for the perturbation
d_perturb = time2len(10) # Duration of the perturbation
target_baseline = 0.2 # Baseline of the output_traj function
target_amplitude = 1. # Maximal value of the output_traj function
target_width = 30. # Width of the Gaussian
target_time = time2len(d_trajectory - 150) # Peak time within the learning interval
# Input definitions
# -----------------
# Impulse after 200 ms
impulse = bm.zeros((trial_duration, net.num_input))
impulse[t_offset:t_offset + d_stim, 0] = stimulus_amplitude
# Perturbation during the trial
perturbation = bm.zeros((trial_duration, net.num_input))
perturbation[t_offset: t_offset + d_stim, 0] = stimulus_amplitude
perturbation[t_offset + t_perturb: t_offset + t_perturb + d_perturb, 1] = perturbation_amplitude
# Target output for learning the readout weights
output_traj = bm.zeros((trial_duration, net.num_output))
output_traj[:, 0] = target_baseline + (target_amplitude - target_baseline) * \
bm.exp(-(t_offset + d_stim + target_time - times) ** 2 / target_width ** 2)
# Main procedure
# --------------
tstart = time.time()
# Initial trial to determine the innate target_traj
if verbose: print('Initial trial to determine a target_traj (without noise)')
initial_traj, initial_output, _ = net.simulate(stimulus=impulse, noise=False)
# Pre-training test trial
if verbose: print('Pre-training test trial')
pretrain_traj, pretrain_output, _ = net.simulate(stimulus=impulse, noise=True)
# Perturbation trial
if verbose: print(num_perturbation, 'perturbation trials')
perturbation_initial = []
for i in range(num_perturbation):
_, perturbation_output, _ = net.simulate(stimulus=perturbation, noise=True)
perturbation_initial.append(perturbation_output)
# 20 trials of learning for the recurrent weights
for i in range(num_rec_train):
t0 = time.time()
if verbose: print(f'Learning trial recurrent {i + 1} loss: ', end='')
_, _, loss = net.simulate(stimulus=impulse,
target_traj=initial_traj,
learn_start=t_offset + d_stim,
learn_stop=t_offset + d_stim + d_trajectory)
if verbose: print(f'{(2 * loss[0] / d_trajectory):5f}, time: {time.time() - t0} s')
# 10 trials of learning for the readout weights
for i in range(num_readout_train):
t0 = time.time()
if verbose: print(f'Learning trial readout {i + 1} loss: ', end='')
_, _, loss = net.simulate(stimulus=impulse,
target_traj=output_traj,
learn_start=t_offset + d_stim,
learn_stop=t_offset + d_stim + d_trajectory,
learn_readout=True)
if verbose: print(f'{(2 * loss[0] / d_trajectory):5f}, time: {time.time() - t0} s')
# Test trial
if verbose: print('2 test trials')
reproductions = []
final_outputs = []
for i in range(2):
reproduction, final_output, _ = net.simulate(stimulus=impulse, noise=True)
reproductions.append(reproduction)
final_outputs.append(final_output)
# Perturbation trial
if verbose: print(num_perturbation, 'perturbation trials')
perturbation_final = []
for i in range(num_perturbation):
_, perturbation_output, _ = net.simulate(stimulus=perturbation)
perturbation_final.append(perturbation_output)
if verbose: print('Simulation done in', time.time() - tstart, 'seconds.')
# Visualization
# -------------
plt.figure(figsize=(8, 12))
# innate trajectory
ax = plt.subplot2grid((4, 2), (0, 0), colspan=2)
im = ax.imshow(initial_traj[:, :100].T, aspect='auto', origin='lower')
cax = make_axes_locatable(ax).append_axes("right", size="5%", pad=0.05)
plt.colorbar(im, cax=cax)
ymin, ymax = ax.get_ylim()
ax.add_patch(patches.Rectangle((t_offset, ymin), d_stim, ymax - ymin, color='gray', alpha=0.2))
ax.set_title('Innate Trajectory')
ax.set_xlabel('Time (ms)')
ax.set_ylabel('Recurrent units')
# pre-training results
ax = plt.subplot2grid((4, 2), (1, 0))
for i in range(3):
ax.plot(times, initial_traj[:, i] + i * 2 + 1, 'b')
ax.plot(times, pretrain_traj[:, i] + i * 2 + 1, 'r')
ax.set_yticks([i * 2 + 1 for i in range(3)])
ax.set_yticklabels([0, 0, 0])
ymin, ymax = ax.get_ylim()
ax.add_patch(patches.Rectangle((t_offset, ymin), d_stim, ymax - ymin, color='gray', alpha=0.7))
ax.add_patch(patches.Rectangle((t_offset + d_stim, ymin), d_trajectory, ymax - ymin, color='gray', alpha=0.1))
ax.set_title('Pre-training')
ax.set_ylabel('Firing rate $r$ [Hz]')
ax.set_xlim(times[0], times[-1])
ax = plt.subplot2grid((4, 2), (2, 0))
ax.plot(times, initial_output[:, 0], 'b')
ax.plot(times, pretrain_output[:, 0], 'r')
ax.axhline(output_traj[0, 0], c='k')
ax.set_yticks([-2, -1, 0, 1, 2])
ax.set_ylim((-2, 2))
ymin, ymax = ax.get_ylim()
ax.add_patch(patches.Rectangle((t_offset, ymin), d_stim, ymax - ymin, color='gray', alpha=0.7))
ax.add_patch(patches.Rectangle((t_offset + d_stim, ymin), d_trajectory, ymax - ymin, color='gray', alpha=0.1))
ax.set_ylabel('Output (test)')
ax.set_xlim(times[0], times[-1])
ax = plt.subplot2grid((4, 2), (3, 0))
for i in range(num_perturbation):
ax.plot(times, perturbation_initial[i][:, 0])
ax.axhline(output_traj[0, 0], c='k')
ax.set_yticks([-2, -1, 0, 1, 2])
ax.set_ylim((-2, 2))
ymin, ymax = ax.get_ylim()
ax.add_patch(patches.Rectangle((t_offset, ymin), d_stim, ymax - ymin, color='gray', alpha=0.7))
ax.add_patch(patches.Rectangle((t_offset + t_perturb, ymin), d_perturb, ymax - ymin, color='gray', alpha=0.7))
ax.add_patch(patches.Rectangle((t_offset + d_stim, ymin), d_trajectory, ymax - ymin, color='gray', alpha=0.1))
ax.set_xlabel('Time (ms)')
ax.set_ylabel('Output (perturbed)')
ax.set_xlim(times[0], times[-1])
# post-training results
ax = plt.subplot2grid((4, 2), (1, 1))
for i in range(3):
ax.plot(times, reproductions[0][:, i] + i * 2 + 1, 'b')
ax.plot(times, reproductions[1][:, i] + i * 2 + 1, 'r')
ax.set_yticks([i * 2 + 1 for i in range(3)])
ax.set_yticklabels([0, 0, 0])
ymin, ymax = ax.get_ylim()
ax.add_patch(patches.Rectangle((t_offset, ymin), d_stim, ymax - ymin, color='gray', alpha=0.7))
ax.add_patch(patches.Rectangle((t_offset + d_stim, ymin), d_trajectory, ymax - ymin, color='gray', alpha=0.1))
ax.set_title('Post-training')
ax.set_xlim(times[0], times[-1])
ax = plt.subplot2grid((4, 2), (2, 1))
ax.plot(times, final_outputs[0][:, 0], 'b')
ax.plot(times, final_outputs[1][:, 0], 'r')
ax.axhline(output_traj[0, 0], c='k')
ax.set_yticks([-2, -1, 0, 1, 2])
ax.set_ylim((-2, 2))
ymin, ymax = ax.get_ylim()
ax.add_patch(patches.Rectangle((t_offset, ymin), d_stim, ymax - ymin, color='gray', alpha=0.7))
ax.add_patch(patches.Rectangle((t_offset + d_stim, ymin), d_trajectory, ymax - ymin, color='gray', alpha=0.1))
ax.set_xlim(times[0], times[-1])
ax = plt.subplot2grid((4, 2), (3, 1))
for i in range(num_perturbation):
ax.plot(times, perturbation_final[i][:, 0])
ax.axhline(output_traj[0, 0], c='k')
ax.set_yticks([-2, -1, 0, 1, 2])
ax.set_ylim((-2, 2))
ymin, ymax = ax.get_ylim()
ax.add_patch(patches.Rectangle((t_offset, ymin), d_stim, ymax - ymin, color='gray', alpha=0.7))
ax.add_patch(patches.Rectangle((t_offset + t_perturb, ymin), d_perturb, ymax - ymin, color='gray', alpha=0.7))
ax.add_patch(patches.Rectangle((t_offset + d_stim, ymin), d_trajectory, ymax - ymin, color='gray', alpha=0.1))
ax.set_xlabel('Time (ms)')
ax.set_xlim(times[0], times[-1])
plt.tight_layout()
plt.show()
[6]:
def fig2(net, verbose=True):
# Parameters
# ----------
num_rec_train = 30 # Number of learning trials for the recurrent weights
num_readout_train = 10 # Number of learning trials for the readout weights
num_test = 5 # Number of test trials
num_perturb = 5 # Number of perturbation trials
stimulus_amplitude = 2.0 # Amplitude of the input pulse
t_offset = time2len(200) # Time to wait before the stimulation
d_stim = time2len(50) # Duration of the stimulation
t_relax = time2len(150) # Duration to relax after the target_traj
perturbation_amplitude = 0.2 # Amplitude of the perturbation pulse
t_perturb = time2len(300) # Offset for the perturbation
d_perturb = time2len(10) # Duration of the perturbation
# Input definitions
# -----------------
# Retrieve the targets and reformat them
targets = loadmat('data/DAC_handwriting_output_targets.mat')
chaos = targets['chaos']
neuron = targets['neuron']
# Durations
_, d_chaos = chaos.shape
_, d_neuron = neuron.shape
# Impulses
impulse_chaos = bm.zeros((t_offset + d_stim + d_chaos + t_relax, net.num_input))
impulse_chaos[t_offset:t_offset + d_stim, 0] = stimulus_amplitude
impulse_neuron = bm.zeros((t_offset + d_stim + d_neuron + t_relax, net.num_input))
impulse_neuron[t_offset:t_offset + d_stim, 2] = stimulus_amplitude
# Perturbation
perturbation_chaos = bm.zeros((t_offset + d_stim + d_chaos + t_relax, net.num_input))
perturbation_chaos[t_offset:t_offset + d_stim, 0] = stimulus_amplitude
perturbation_chaos[t_offset + t_perturb: t_offset + t_perturb + d_perturb, 1] = perturbation_amplitude
perturbation_neuron = bm.zeros((t_offset + d_stim + d_neuron + t_relax, net.num_input))
perturbation_neuron[t_offset:t_offset + d_stim, 2] = stimulus_amplitude
perturbation_neuron[t_offset + t_perturb: t_offset + t_perturb + d_perturb, 3] = perturbation_amplitude
# Targets
target_chaos = bm.zeros((t_offset + d_stim + d_chaos + t_relax, net.num_output))
target_chaos[t_offset + d_stim: t_offset + d_stim + d_chaos, :] = chaos.T
target_neuron = bm.zeros((t_offset + d_stim + d_neuron + t_relax, net.num_output))
target_neuron[t_offset + d_stim: t_offset + d_stim + d_neuron, :] = neuron.T
# Main procedure
# --------------
tstart = time.time()
# Initial trial to determine the innate target_traj
if verbose: print('Initial chaos trial')
trajectory_chaos, initial_chaos_output, _ = net.simulate(stimulus=impulse_chaos, noise=False)
if verbose: print('Initial neuron trial')
trajectory_neuron, initial_neuron_output, _ = net.simulate(stimulus=impulse_neuron, noise=False)
# learning for the recurrent weights
for i in range(num_rec_train):
t0 = time.time()
if verbose: print(f'Learning recurrent {i + 1} "chaos" loss: ', end='')
_, _, loss = net.simulate(stimulus=impulse_chaos,
target_traj=trajectory_chaos,
learn_start=t_offset + d_stim,
learn_stop=t_offset + d_stim + d_chaos)
if verbose: print(f'{(2 * loss[0] / d_chaos):.5f} used {(time.time() - t0):5f} s, ', end='')
t0 = time.time()
_, _, loss = net.simulate(stimulus=impulse_neuron,
target_traj=trajectory_neuron,
learn_start=t_offset + d_stim,
learn_stop=t_offset + d_stim + d_neuron)
if verbose: print(f'"neuron" loss: {(2 * loss[0] / d_chaos):5f} used {(time.time() - t0):5f} s')
# learning for the readout weights
for i in range(num_readout_train):
t0 = time.time()
if verbose: print(f'Learning readout {i + 1} "chaos" loss: ', end='')
_, _, loss = net.simulate(stimulus=impulse_chaos,
target_traj=target_chaos,
learn_start=t_offset + d_stim,
learn_stop=t_offset + d_stim + d_chaos,
learn_readout=True)
if verbose: print(f'{(2 * loss[0] / d_chaos):.5f} used {(time.time() - t0):5f} s, ', end='')
t0 = time.time()
_, _, loss = net.simulate(stimulus=impulse_neuron,
target_traj=target_neuron,
learn_start=t_offset + d_stim,
learn_stop=t_offset + d_stim + d_neuron,
learn_readout=True)
if verbose: print(f'"neuron" loss: {(2 * loss[0] / d_chaos):5f} used {(time.time() - t0):5f} s')
# Test trials
final_output_chaos = []
final_output_neuron = []
for _ in range(num_test):
if verbose: print('Test chaos trial')
_, o, _ = net.simulate(stimulus=impulse_chaos)
final_output_chaos.append(o)
if verbose: print('Test neuron trial')
_, o, _ = net.simulate(stimulus=impulse_neuron)
final_output_neuron.append(o)
# Perturbation trials
perturbation_output_chaos = []
perturbation_output_neuron = []
for _ in range(num_perturb):
if verbose: print('Perturbation chaos trial')
_, o, _ = net.simulate(stimulus=perturbation_chaos)
perturbation_output_chaos.append(o)
if verbose: print('Perturbation neuron trial')
_, o, _ = net.simulate(stimulus=perturbation_neuron)
perturbation_output_neuron.append(o)
if verbose: print('Simulation done in', time.time() - tstart, 'seconds.')
# Visualization
# -------------
subsampling_chaos = (t_offset + d_stim + bm.linspace(0, d_chaos, 20)).astype(bm.int32)
subsampling_chaos = bm.unique(subsampling_chaos)
subsampling_neuron = (t_offset + d_stim + bm.linspace(0, d_neuron, 20)).astype(bm.int32)
subsampling_neuron = bm.unique(subsampling_neuron)
plt.figure(figsize=(12, 8))
ax = plt.subplot2grid((2, 2), (0, 0))
ax.plot(chaos[0, :], chaos[1, :], linewidth=2.)
for i in range(num_perturb):
ax.plot(final_output_chaos[i][t_offset + d_stim: t_offset + d_stim + d_chaos, 0],
final_output_chaos[i][t_offset + d_stim: t_offset + d_stim + d_chaos, 1])
ax.plot(final_output_chaos[i][subsampling_chaos, 0],
final_output_chaos[i][subsampling_chaos, 1], 'bo')
ax.set_xlabel('x')
ax.set_ylabel('Without perturbation\ny')
ax = plt.subplot2grid((2, 2), (0, 1))
ax.plot(neuron[0, :], neuron[1, :], linewidth=2.)
for i in range(num_perturb):
ax.plot(final_output_neuron[i][t_offset + d_stim: t_offset + d_stim + d_neuron, 0],
final_output_neuron[i][t_offset + d_stim: t_offset + d_stim + d_neuron, 1])
ax.plot(final_output_neuron[i][subsampling_neuron, 0],
final_output_neuron[i][subsampling_neuron, 1], 'bo')
ax.set_xlabel('x')
ax.set_ylabel('y')
ax = plt.subplot2grid((2, 2), (1, 0))
ax.plot(chaos[0, :], chaos[1, :], linewidth=2.)
for i in range(num_perturb):
ax.plot(perturbation_output_chaos[i][t_offset + d_stim: t_offset + d_stim + d_chaos, 0],
perturbation_output_chaos[i][t_offset + d_stim: t_offset + d_stim + d_chaos, 1])
ax.plot(perturbation_output_chaos[i][subsampling_chaos, 0],
perturbation_output_chaos[i][subsampling_chaos, 1], 'bo')
ax.set_xlabel('x')
ax.set_ylabel('y')
ax.set_ylabel('With perturbation\ny')
ax = plt.subplot2grid((2, 2), (1, 1))
ax.plot(neuron[0, :], neuron[1, :], linewidth=2.)
for i in range(num_perturb):
ax.plot(perturbation_output_neuron[i][t_offset + d_stim: t_offset + d_stim + d_neuron, 0],
perturbation_output_neuron[i][t_offset + d_stim: t_offset + d_stim + d_neuron, 1])
ax.plot(perturbation_output_neuron[i][subsampling_neuron, 0],
perturbation_output_neuron[i][subsampling_neuron, 1], 'bo')
ax.set_xlabel('x')
ax.set_ylabel('y')
plt.show()
[7]:
def fig3(nets, verbose=True):
# Parameters
# ----------
num_rec_train = 20 # Number of learning trials for the recurrent weights
num_readout_train = 10 # Number of learning trials for the readout weights
stimulus_amplitude = 5.0 # Amplitude of the input pulse
t_offset = time2len(200) # Time to wait before the stimulation
d_stim = time2len(50) # Duration of the stimulation
t_relax = time2len(150) # Duration to relax after the target_traj
target_baseline = 0.2 # Baseline of the output_traj function
target_amplitude = 1. # Maximal value of the output_traj function
target_width = 30. # Width of the Gaussian
# Main procedure
# --------------
# Vary the timing interval
delays = [250, 500, 1000, 2000, 3000, 4000, 5000, 6000, 7000, 8000]
# Store the Pearson correlation coefficients
pearsons = []
# Iterate over the delays
for target_time in delays:
if verbose:
print('*' * 60)
print('Learning a delay of', target_time)
print('*' * 60)
print()
d_trajectory = time2len(target_time + 150) # Duration of the desired target_traj
trial_duration = t_offset + d_stim + d_trajectory + t_relax # Total duration of a trial
pearsons.append([])
for n, net in enumerate(nets): # 10 networks per delay
if verbose: print(f'## Network {n + 1} ##', )
net.init()
# Impulse input after 200 ms
impulse = bm.zeros((trial_duration, net.num_input))
impulse[t_offset:t_offset + d_stim, 0] = stimulus_amplitude
# Target output for learning the readout weights
target = bm.zeros((trial_duration, net.num_output))
time_axis = bm.linspace(0, trial_duration, trial_duration)
target[:, 0] = target_baseline + (target_amplitude - target_baseline) * \
bm.exp(-(t_offset + d_stim + target_time - time_axis) ** 2 / target_width ** 2)
# Initial trial to determine the innate target_traj
if verbose: print('Initial trial to determine a target_traj (without noise)')
trajectory, initial_output, _ = net.simulate(stimulus=impulse, noise=False)
# 20 trials of learning for the recurrent weights
for i in range(num_rec_train):
t0 = time.time()
if verbose: print(f'Learning trial recurrent {i + 1} loss: ', end='')
_, _, loss = net.simulate(stimulus=impulse,
target_traj=trajectory,
learn_start=t_offset + d_stim,
learn_stop=t_offset + d_stim + d_trajectory)
if verbose: print(f'{(2 * loss[0] / d_trajectory):5f}, time: {time.time() - t0} s')
# 10 trials of learning for the readout weights
for i in range(num_readout_train):
t0 = time.time()
if verbose: print(f'Learning trial readout {i + 1} loss: ', end='')
_, _, loss = net.simulate(stimulus=impulse,
target_traj=target,
learn_start=t_offset + d_stim,
learn_stop=t_offset + d_stim + d_trajectory,
learn_readout=True)
if verbose: print(f'{(2 * loss[0] / d_trajectory):5f}, time: {time.time() - t0} s')
# Test trial
if verbose: print('Test trial')
reproduction, final_output, _ = net.simulate(stimulus=impulse)
# Pearson correlation coefficient
pred = final_output[t_offset + d_stim:t_offset + d_stim + d_trajectory, 0]
desired = target[t_offset + d_stim:t_offset + d_stim + d_trajectory, 0]
r, p = scipy.stats.pearsonr(desired, pred)
pearsons[-1].append(r)
# Save the results
pearsons = bm.asarray(pearsons)
# Visualization
# -------------
plt.figure(figsize=(8, 6))
correlation_mean = bm.mean(pearsons ** 2, axis=1)
correlation_std = bm.std(pearsons ** 2, axis=1)
plt.errorbar(bm.array(delays) / 1000., correlation_mean, correlation_std / bm.sqrt(10), linestyle='-', marker='^')
plt.xlim((0., 8.5))
plt.ylim((-0.1, 1.1))
plt.xlabel('Interval (s)')
plt.ylabel('Performance ($R^2$)')
plt.show()