BrainPy Examples
This repository contains examples of using BrainPy to implement various models of neurons, synapses, networks, etc. We welcome your implementations, which can be submitted through our GitHub page.
If some code fails to run, please tell us through a GitHub issue: https://github.com/PKU-NIP-Lab/BrainPyExamples/issues
If you find these examples useful for your research, please kindly cite us.
If you want to add more examples, please fork our GitHub repository: https://github.com/PKU-NIP-Lab/BrainPyExamples
(Izhikevich, 2003) Izhikevich Model
@Chaoming Wang @Xinyu Liu
1. Model Overview
The Izhikevich neuron model reproduces the spiking and bursting behavior of known types of cortical neurons. The model combines the biological plausibility of Hodgkin–Huxley-type dynamics (HH model) with the computational efficiency of integrate-and-fire neurons (LIF model).
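Bifurcation methodologies enable us to reduce many biophysically accurate Hodgkin–Huxley-type neuronal models to a two-dimensional system of ordinary differential equations of the form:
\[\frac{dv}{dt} = 0.04v^2 + 5v + 140 - u + I, \qquad \frac{du}{dt} = a(bv - u),\]
with the auxiliary after-spike resetting:
\[\text{if } v \geq 30 \text{ mV, then } \begin{cases} v \leftarrow c \\ u \leftarrow u + d \end{cases}\]
Here \(I\) denotes the synaptic or injected input current.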
\(v\) represents the membrane potential of the neuron.
\(u\) represents a membrane recovery variable, which accounts for the activation of \(\mathrm{K}^{+}\) ionic currents and inactivation of \(\mathrm{Na}^{+}\) ionic currents, and it provides negative feedback to \(v\) .
About the parameter ( a, b, c, d ) :
a: The parameter \(a\) describes the time scale of the recovery variable \(u\). Smaller values result in slower recovery. A typical value is \(a = 0.02\).
b: The parameter \(b\) describes the sensitivity of the recovery variable \(u\) to the subthreshold fluctuations of the membrane potential \(v\); depending on its value, the resting potential of the model lies between \(-70\) mV and \(-60\) mV. Greater values couple \(v\) and \(u\) more strongly, resulting in possible subthreshold oscillations and low-threshold spiking dynamics. A typical value is \(b = 0.2\).
c: The parameter \(c\) describes the after-spike reset value of the membrane potential \(v\) caused by the fast high-threshold \(\mathrm{K}^{+}\) conductances. A typical value is \(c = -65\) mV.
d: The parameter \(d\) describes the after-spike reset of the recovery variable \(u\) caused by slow high-threshold \(\mathrm{Na}^{+}\) and \(\mathrm{K}^{+}\) conductances. A typical value is \(d = 2\).
The threshold of the model neuron lies between \(-70\) mV and \(-50\) mV and is dynamic, as in biological neurons.
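The Izhikevich equations, \(dv/dt = 0.04v^2 + 5v + 140 - u + I\) and \(du/dt = a(bv - u)\) with the reset \(v \leftarrow c\), \(u \leftarrow u + d\) once \(v \geq 30\) mV, can be integrated with a simple forward-Euler scheme. The sketch below is a plain NumPy illustration of what a simulator has to do, not BrainPy's implementation (BrainPy uses its own integrators); the function name `simulate_izhikevich`, the step size `dt = 0.1` ms and the initial state \(v_0 = -65\) mV, \(u_0 = b v_0\) are our assumptions.

```python
import numpy as np


def simulate_izhikevich(a, b, c, d, I, duration=200.0, dt=0.1, v0=-65.0):
    """Forward-Euler integration of the Izhikevich model with a constant input I.

    Returns the time axis (ms), the membrane potential trace (mV) and the spike times (ms).
    """
    n_steps = int(round(duration / dt))
    ts = np.arange(n_steps) * dt
    vs = np.empty(n_steps)
    v, u = v0, b * v0                      # start on the u-nullcline: u = b * v
    spikes = []
    for i in range(n_steps):
        dv = 0.04 * v ** 2 + 5.0 * v + 140.0 - u + I
        du = a * (b * v - u)
        v, u = v + dt * dv, u + dt * du    # both updates use the old (v, u)
        if v >= 30.0:                      # spike peak reached: after-spike reset
            spikes.append(ts[i])
            v, u = c, u + d
        vs[i] = v
    return ts, vs, np.array(spikes)


# Regular-spiking parameters (a, b, c, d) = (0.02, 0.2, -65, 8).
ts, vs, spikes = simulate_izhikevich(0.02, 0.2, -65.0, 8.0, I=10.0)
print(f"{len(spikes)} spikes in 200 ms")
```

Because the trace is stored after the reset, the recorded \(v\) never shows the 30 mV peak; the same happens with a monitored `V` in BrainPy when the reset is applied within the same step.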
[1]:
import brainpy as bp
import matplotlib.pyplot as plt
Summary of the neuro-computational properties of biological spiking neurons:
The model can exhibit the firing patterns of all known types of cortical neurons, given an appropriate choice of the parameters \(a, b, c, d\) listed below.
2. Different firing patterns
The following interpretation of the most prominent features of biological spiking neurons is based on reference [1].
Neuro-computational property | a | b | c (mV) | d |
---|---|---|---|---|
Tonic Spiking | 0.02 | 0.04 | -65 | 2 |
Phasic Spiking | 0.02 | 0.25 | -65 | 6 |
Tonic Bursting | 0.02 | 0.2 | -50 | 2 |
Phasic Bursting | 0.02 | 0.25 | -55 | 0.05 |
Mixed Mode | 0.02 | 0.2 | -55 | 4 |
Spike Frequency Adaptation | 0.01 | 0.2 | -65 | 8 |
Class 1 Excitability | 0.02 | -0.1 | -55 | 6 |
Class 2 Excitability | 0.2 | 0.26 | -65 | 0 |
Spike Latency | 0.02 | 0.2 | -65 | 6 |
Subthreshold Oscillations | 0.05 | 0.26 | -60 | 0 |
Resonator | 0.1 | 0.26 | -60 | -1 |
Integrator | 0.02 | -0.1 | -55 | 6 |
Rebound Spike | 0.03 | 0.25 | -60 | 4 |
Rebound Burst | 0.03 | 0.25 | -52 | 0 |
Threshold Variability | 0.03 | 0.25 | -60 | 4 |
Bistability | 1 | 1.5 | -60 | 0 |
Depolarizing After-Potentials | 1 | 0.2 | -60 | -21 |
Accommodation | 0.02 | 1 | -55 | 4 |
Inhibition-Induced Spiking | -0.02 | -1 | -60 | 8 |
Inhibition-Induced Bursting | -0.026 | -1 | -45 | 0 |

The table above gives the values of the parameters \(a, b, c, d\) for 20 firing patterns.
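Rather than retyping the four numbers in every cell, the parameter sets can be kept in a lookup table. The sketch below uses a hypothetical dictionary `FIRING_PATTERNS` and helper `apply_pattern` (our names, not part of BrainPy) holding a few rows of the table above; it relies only on the attribute assignment `neu.a, neu.b, neu.c, neu.d = ...` that the cells below already use, and is demonstrated on a `SimpleNamespace` stand-in.

```python
from types import SimpleNamespace

# (a, b, c, d) for some firing patterns, copied from the table above.
FIRING_PATTERNS = {
    "phasic_spiking": (0.02, 0.25, -65.0, 6.0),
    "tonic_bursting": (0.02, 0.2, -50.0, 2.0),
    "phasic_bursting": (0.02, 0.25, -55.0, 0.05),
    "mixed_mode": (0.02, 0.2, -55.0, 4.0),
    "spike_frequency_adaptation": (0.01, 0.2, -65.0, 8.0),
    "class_1": (0.02, -0.1, -55.0, 6.0),
    "class_2": (0.2, 0.26, -65.0, 0.0),
}


def apply_pattern(neuron, name):
    """Assign the (a, b, c, d) of a named pattern to any object with those attributes."""
    neuron.a, neuron.b, neuron.c, neuron.d = FIRING_PATTERNS[name]
    return neuron


# Stand-in object; in the notebook this would be bp.neurons.Izhikevich(1).
neu = apply_pattern(SimpleNamespace(), "class_2")
print(neu)
```

With BrainPy, `apply_pattern(bp.neurons.Izhikevich(1), "class_2")` would play the role of the assignment line in the Class 2 cell, assuming the attributes accept plain floats as they do there.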
Tonic Spiking
While the input is on, the neuron continues to fire a train of spikes. This kind of behavior, called tonic spiking, can be observed in three types of cortical neurons: regular spiking (RS) excitatory neurons, low-threshold spiking (LTS) inhibitory neurons, and fast spiking (FS) inhibitory neurons. Continuous firing of such neurons indicates that there is a persistent input.
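A simple way to check that firing is tonic is to look at the inter-spike intervals (ISIs): a regular train has nearly constant ISIs, i.e. a coefficient of variation close to 0. BrainPy neuron groups expose a boolean `spike` variable, so adding `'spike'` to `monitors` should give a (time, neuron) boolean array; the hypothetical helper `interspike_intervals` below works on such an array and is demonstrated here on synthetic data rather than on a real run.

```python
import numpy as np


def interspike_intervals(spike, ts, neuron=0):
    """Return the spike times and inter-spike intervals (ms) of one neuron.

    spike: boolean array of shape (n_steps, n_neurons), e.g. a monitored spike train.
    ts:    time axis of shape (n_steps,).
    """
    times = ts[np.asarray(spike)[:, neuron]]
    return times, np.diff(times)


# Synthetic regular train: dt = 0.1 ms, one spike every 25 ms starting at 10 ms.
dt = 0.1
ts = np.arange(2000) * dt
spike = np.zeros((ts.size, 1), dtype=bool)
spike[100::250, 0] = True

times, isi = interspike_intervals(spike, ts)
cv = isi.std() / isi.mean()   # coefficient of variation of the ISIs
print(times, isi, cv)
```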
[2]:
neu = bp.neurons.Izhikevich(1)
neu.a, neu.b, neu.c, neu.d = 0.02, 0.40, -65.0, 2.0
current = bp.inputs.section_input(values=[0., 10.], durations=[50, 150])
runner = bp.DSRunner(neu, inputs=['input', current, 'iter'], monitors=['V', 'u'])
runner.run(duration=200.)
fig, ax1 = plt.subplots(figsize=(10, 5))
plt.title('Tonic Spiking')
ax1.plot(runner.mon.ts, runner.mon.V[:, 0], 'b', label='V')
ax1.set_xlabel('Time (ms)')
ax1.set_ylabel('Membrane potential (mV)', color='b')
ax1.set_xlim(-0.1, 200.1)
ax1.tick_params('y', colors='b')
ax2 = ax1.twinx()
ax2.plot(runner.mon.ts, current, 'r', label='Input')
ax2.set_xlabel('Time (ms)')
ax2.set_ylabel('Input (mV)', color='r')
ax2.set_ylim(0, 50)
ax2.tick_params('y', colors='r')
ax1.legend(loc=1)
ax2.legend(loc=3)
fig.tight_layout()
plt.show()

Phasic Spiking
A neuron may fire only a single spike at the onset of the input, and remain quiescent afterwards. Such a response is called phasic spiking, and it is useful for detection of the beginning of stimulation.
[3]:
neu = bp.neurons.Izhikevich(1)
neu.a, neu.b, neu.c, neu.d = 0.02, 0.25, -65.0, 6.0
current = bp.inputs.section_input(values=[0., 1.], durations=[50, 150])
runner = bp.DSRunner(neu, inputs=['input', current, 'iter'], monitors=['V', 'u'])
runner.run(duration=200.)
fig, ax1 = plt.subplots(figsize=(10, 5))
plt.title('Phasic Spiking')
ax1.plot(runner.mon.ts, runner.mon.V[:, 0], 'b', label='V')
ax1.set_xlabel('Time (ms)')
ax1.set_ylabel('Membrane potential (mV)', color='b')
ax1.set_xlim(-0.1, 200.1)
ax1.tick_params('y', colors='b')
ax2 = ax1.twinx()
ax2.plot(runner.mon.ts, current, 'r', label='Input')
ax2.set_xlabel('Time (ms)')
ax2.set_ylabel('Input (mV)', color='r')
ax2.set_ylim(0, 20)
ax2.tick_params('y', colors='r')
ax1.legend(loc=1)
ax2.legend(loc=3)
fig.tight_layout()
plt.show()

Tonic Bursting
Some neurons, such as the chattering neurons in cat neocortex, fire periodic bursts of spikes when stimulated. The interburst frequency may be as high as 50 Hz, and it is believed that such neurons contribute to the gamma-frequency oscillations in the brain.
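The interburst frequency mentioned above can be estimated by grouping spikes into bursts, i.e. runs of spikes separated by short ISIs, and measuring the rate of burst onsets. In the sketch below, the function names and the 10 ms grouping threshold `max_isi` are our illustrative choices; spike times in ms could come, for instance, from indexing `runner.mon.ts` with a monitored spike train.

```python
import numpy as np


def group_bursts(spike_times, max_isi=10.0):
    """Split sorted spike times (ms) into bursts of spikes closer than max_isi ms."""
    spike_times = np.asarray(spike_times, dtype=float)
    if spike_times.size == 0:
        return []
    breaks = np.where(np.diff(spike_times) > max_isi)[0] + 1
    return np.split(spike_times, breaks)


def interburst_frequency(spike_times, max_isi=10.0):
    """Mean rate of burst onsets in Hz (spike times given in ms)."""
    onsets = np.array([burst[0] for burst in group_bursts(spike_times, max_isi)])
    if onsets.size < 2:
        return 0.0
    return 1000.0 / np.mean(np.diff(onsets))


# Synthetic example: bursts of 3 spikes (2 ms apart) starting every 20 ms.
spikes = [t0 + k * 2.0 for t0 in (0.0, 20.0, 40.0, 60.0) for k in range(3)]
print(len(group_bursts(spikes)), interburst_frequency(spikes))
```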
[4]:
neu = bp.neurons.Izhikevich(1)
neu.a, neu.b, neu.c, neu.d = 0.02, 0.20, -50.0, 2.0
current = bp.inputs.section_input(values=[0., 15.], durations=[50, 150])
runner = bp.DSRunner(neu, inputs=['input', current, 'iter'], monitors=['V', 'u'])
runner.run(duration=200.)
fig, ax1 = plt.subplots(figsize=(10, 5))
plt.title('Tonic Bursting')
ax1.plot(runner.mon.ts, runner.mon.V[:, 0], 'b', label='V')
ax1.set_xlabel('Time (ms)')
ax1.set_ylabel('Membrane potential (mV)', color='b')
ax1.set_xlim(-0.1, 200.1)
ax1.tick_params('y', colors='b')
ax2 = ax1.twinx()
ax2.plot(runner.mon.ts, current, 'r', label='Input')
ax2.set_xlabel('Time (ms)')
ax2.set_ylabel('Input (mV)', color='r')
ax2.set_ylim(0, 50)
ax2.tick_params('y', colors='r')
ax1.legend(loc=1)
ax2.legend(loc=3)
fig.tight_layout()
plt.show()

Phasic Bursting
Similarly to the phasic spikers, some neurons are phasic bursters. Such neurons report the beginning of the stimulation by transmitting a burst.
[5]:
neu = bp.neurons.Izhikevich(1)
neu.a, neu.b, neu.c, neu.d = 0.02, 0.25, -55.0, 0.05
current = bp.inputs.section_input(values=[0., 1.], durations=[50, 150])
runner = bp.DSRunner(neu, inputs=['input', current, 'iter'], monitors=['V', 'u'])
runner.run(duration=200.)
fig, ax1 = plt.subplots(figsize=(10, 5))
plt.title('Phasic Bursting')
ax1.plot(runner.mon.ts, runner.mon.V[:, 0], 'b', label='V')
ax1.set_xlabel('Time (ms)')
ax1.set_ylabel('Membrane potential (mV)', color='b')
ax1.set_xlim(-0.1, 200.1)
ax1.tick_params('y', colors='b')
ax2 = ax1.twinx()
ax2.plot(runner.mon.ts, current, 'r', label='Input')
ax2.set_xlabel('Time (ms)')
ax2.set_ylabel('Input (mV)', color='r')
ax2.set_ylim(0, 20)
ax2.tick_params('y', colors='r')
ax1.legend(loc=1)
ax2.legend(loc=3)
fig.tight_layout()
plt.show()

Mixed Mode
Intrinsically bursting (IB) excitatory neurons in mammalian neocortex can exhibit a mixed type of spiking activity. They fire a phasic burst at the onset of stimulation and then switch to the tonic spiking mode.
[6]:
neu = bp.neurons.Izhikevich(1)
neu.a, neu.b, neu.c, neu.d = 0.02, 0.20, -55.0, 4.0
current = bp.inputs.section_input(values=[0., 10.], durations=[50, 150])
runner = bp.DSRunner(neu, inputs=['input', current, 'iter'], monitors=['V', 'u'])
runner.run(duration=200.)
fig, ax1 = plt.subplots(figsize=(10, 5))
plt.title('Mixed Mode')
ax1.plot(runner.mon.ts, runner.mon.V[:, 0], 'b', label='V')
ax1.set_xlabel('Time (ms)')
ax1.set_ylabel('Membrane potential (mV)', color='b')
ax1.set_xlim(-0.1, 200.1)
ax1.tick_params('y', colors='b')
ax2 = ax1.twinx()
ax2.plot(runner.mon.ts, current, 'r', label='Input')
ax2.set_xlabel('Time (ms)')
ax2.set_ylabel('Input (mV)', color='r')
ax2.set_ylim(0, 50)
ax2.tick_params('y', colors='r')
ax1.legend(loc=1)
ax2.legend(loc=3)
fig.tight_layout()
plt.show()

Spike Frequency Adaptation
The most common type of excitatory neuron in mammalian neocortex, namely the regular spiking (RS) cell, fires tonic spikes with decreasing frequency. That is, the frequency is relatively high at the onset of stimulation, and then it adapts. Low-threshold spiking (LTS) inhibitory neurons also have this property.
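Adaptation can be quantified by the instantaneous firing rate \(1000/\mathrm{ISI}\) (Hz, with ISIs in ms), which should fall over the course of the response. The self-contained sketch below integrates \(dv/dt = 0.04v^2 + 5v + 140 - u + I\), \(du/dt = a(bv - u)\) with forward Euler in plain Python, using the parameters and input amplitude of the cell below; the step size, the initial state \(v_0 = -65\) mV, \(u_0 = b v_0\), and the function name are our assumptions, not BrainPy's.

```python
import numpy as np


def izhikevich_spike_times(a, b, c, d, I, duration=200.0, dt=0.1, v0=-65.0):
    """Forward-Euler Izhikevich neuron under constant input I; returns spike times (ms)."""
    v, u = v0, b * v0
    spikes = []
    for i in range(int(round(duration / dt))):
        # Tuple assignment: both right-hand sides use the old (v, u).
        v, u = (v + dt * (0.04 * v * v + 5.0 * v + 140.0 - u + I),
                u + dt * a * (b * v - u))
        if v >= 30.0:
            spikes.append(i * dt)
            v, u = c, u + d        # each spike increments u, which slows later spikes
    return np.array(spikes)


spikes = izhikevich_spike_times(0.01, 0.2, -65.0, 8.0, I=30.0)
rate = 1000.0 / np.diff(spikes)    # instantaneous firing rate in Hz
print(np.round(rate, 1))
```

The decrease comes from the increment \(d = 8\) accumulating in \(u\) faster than the slow decay (\(a = 0.01\)) removes it.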
[7]:
neu = bp.neurons.Izhikevich(1)
neu.a, neu.b, neu.c, neu.d = 0.01, 0.20, -65.0, 8.0
current = bp.inputs.section_input(values=[0., 30.], durations=[50, 150])
runner = bp.DSRunner(neu, inputs=['input', current, 'iter'], monitors=['V', 'u'])
runner.run(duration=200.)
fig, ax1 = plt.subplots(figsize=(10, 5))
plt.title('Spike Frequency Adaptation')
ax1.plot(runner.mon.ts, runner.mon.V[:, 0], 'b', label='V')
ax1.set_xlabel('Time (ms)')
ax1.set_ylabel('Membrane potential (mV)', color='b')
ax1.set_xlim(-0.1, 200.1)
ax1.tick_params('y', colors='b')
ax2 = ax1.twinx()
ax2.plot(runner.mon.ts, current, 'r', label='Input')
ax2.set_xlabel('Time (ms)')
ax2.set_ylabel('Input (mV)', color='r')
ax2.set_ylim(0, 100)
ax2.tick_params('y', colors='r')
ax1.legend(loc=1)
ax2.legend(loc=3)
fig.tight_layout()
plt.show()

Class 1 Excitability
The frequency of tonic spiking of neocortical RS excitatory neurons depends on the strength of the input, and it may span the range from 2 Hz to 200 Hz, or even greater. The ability to fire low-frequency spikes when the input is weak (but superthreshold) is called Class 1 excitability.
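Class 1 excitability shows up as a frequency-current (F-I) curve that starts from low rates just above threshold and rises with the input. The sketch below estimates such a curve by counting spikes during 1 s of constant input, with the Class 1 parameters from the table above. It integrates \(dv/dt = 0.04v^2 + 5v + 140 - u + I\), \(du/dt = a(bv - u)\) with forward Euler (our assumptions: step 0.1 ms, initial state \(v_0 = -70\) mV, \(u_0 = b v_0\)), so the exact threshold may differ slightly from BrainPy's integrator.

```python
import numpy as np


def spike_count(a, b, c, d, I, duration=1000.0, dt=0.1, v0=-70.0):
    """Number of spikes of a forward-Euler Izhikevich neuron under constant input I."""
    v, u = v0, b * v0
    count = 0
    for _ in range(int(round(duration / dt))):
        v, u = (v + dt * (0.04 * v * v + 5.0 * v + 140.0 - u + I),
                u + dt * a * (b * v - u))
        if v >= 30.0:
            count += 1
            v, u = c, u + d
    return count


# Spikes per 1000 ms equals the mean rate in Hz.
currents = [0.0, 30.0, 45.0, 60.0, 80.0]
rates = [spike_count(0.02, -0.1, -55.0, 6.0, I) for I in currents]
print(dict(zip(currents, rates)))
```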
[8]:
neu = bp.neurons.Izhikevich(1)
neu.a, neu.b, neu.c, neu.d = 0.02, -0.1, -55.0, 6.0
current = bp.inputs.ramp_input(c_start=0., c_end=80., t_start=50., t_end=200., duration=250)
runner = bp.DSRunner(neu, inputs=['input', current, 'iter'], monitors=['V', 'u'])
runner.run(duration=250.)
fig, ax1 = plt.subplots(figsize=(10, 5))
plt.title('Class 1')
ax1.plot(runner.mon.ts, runner.mon.V[:, 0], 'b', label='V')
ax1.set_xlabel('Time (ms)')
ax1.set_ylabel('Membrane potential (mV)', color='b')
ax1.set_xlim(-0.1, 250.1)
ax1.tick_params('y', colors='b')
ax2 = ax1.twinx()
ax2.plot(runner.mon.ts, current, 'r', label='Input')
ax2.set_xlabel('Time (ms)')
ax2.set_ylabel('Input (mV)', color='r')
ax2.set_ylim(0, 150)
ax2.tick_params('y', colors='r')
ax1.legend(loc=1)
ax2.legend(loc=3)
fig.tight_layout()
plt.show()

Class 2 Excitability
Some neurons cannot fire low-frequency spike trains. That is, they are either quiescent or fire a train of spikes with a certain relatively large frequency, say 40 Hz. Such neurons are called Class 2 excitable.
[9]:
neu = bp.neurons.Izhikevich(1)
neu.a, neu.b, neu.c, neu.d = 0.20, 0.26, -65.0, 0.0
current = bp.inputs.ramp_input(c_start=0., c_end=10., t_start=50., t_end=200., duration=250)
runner = bp.DSRunner(neu, inputs=['input', current, 'iter'], monitors=['V', 'u'])
runner.run(duration=250.)
fig, ax1 = plt.subplots(figsize=(10, 5))
plt.title('Class 2')
ax1.plot(runner.mon.ts, runner.mon.V[:, 0], 'b', label='V')
ax1.set_xlabel('Time (ms)')
ax1.set_ylabel('Membrane potential (mV)', color='b')
ax1.set_xlim(-0.1, 250.1)
ax1.tick_params('y', colors='b')
ax2 = ax1.twinx()
ax2.plot(runner.mon.ts, current, 'r', label='Input')
ax2.set_xlabel('Time (ms)')
ax2.set_ylabel('Input (mV)', color='r')
ax2.set_ylim(0, 100)
ax2.tick_params('y', colors='r')
ax1.legend(loc=1)
ax2.legend(loc=3)
fig.tight_layout()
plt.show()

Spike Latency
Most cortical neurons fire spikes with a delay that depends on the strength of the input signal. For a relatively weak but superthreshold input, the delay, also called spike latency, can be quite large.
[10]:
neu = bp.neurons.Izhikevich(1)
neu.a, neu.b, neu.c, neu.d = 0.02, 0.20, -65.0, 6.0
current = bp.inputs.section_input(values=[0., 50., 0.], durations=[15, 1, 15])
runner = bp.DSRunner(neu, inputs=['input', current, 'iter'], monitors=['V', 'u'])
runner.run(duration=31.)
fig, ax1 = plt.subplots(figsize=(10, 5))
plt.title('Spike Latency')
ax1.plot(runner.mon.ts, runner.mon.V[:, 0], 'b', label='V')
ax1.set_xlabel('Time (ms)')
ax1.set_ylabel('Membrane potential (mV)', color='b')
ax1.set_xlim(-0.1, 24.1)
ax1.tick_params('y', colors='b')
ax2 = ax1.twinx()
ax2.plot(runner.mon.ts, current, 'r', label='Input')
ax2.set_xlabel('Time (ms)')
ax2.set_ylabel('Input (mV)', color='r')
ax2.set_ylim(0, 100)
ax2.tick_params('y', colors='r')
ax1.legend(loc=1)
ax2.legend(loc=3)
fig.tight_layout()
plt.show()

Subthreshold Oscillations
Practically every brain structure has neurons capable of exhibiting oscillatory potentials. The frequency of such oscillations plays an important role, and such neurons act as bandpass filters.
[11]:
neu = bp.neurons.Izhikevich(1)
neu.a, neu.b, neu.c, neu.d = 0.05, 0.26, -60.0, 0.0
current = bp.inputs.section_input(values=[0., 50., 0.], durations=[15, 1, 200])
runner = bp.DSRunner(neu, inputs=['input', current, 'iter'], monitors=['V', 'u'])
runner.run(duration=216.)
fig, ax1 = plt.subplots(figsize=(10, 5))
plt.title('Subthreshold Oscillation')
ax1.plot(runner.mon.ts, runner.mon.V[:, 0], 'b', label='V')
ax1.set_xlabel('Time (ms)')
ax1.set_ylabel('Membrane potential (mV)', color='b')
ax1.tick_params('y', colors='b')
ax1.set_xlim(-0.1, 216.1)
ax2 = ax1.twinx()
ax2.plot(runner.mon.ts, current, 'r', label='Input')
ax2.set_xlabel('Time (ms)')
ax2.set_ylabel('Input (mV)', color='r')
ax2.set_ylim(0, 100)
ax2.tick_params('y', colors='r')
ax1.legend(loc=1)
ax2.legend(loc=3)
fig.tight_layout()
plt.show()

Rebound Spike
When a neuron receives and is then released from an inhibitory input, it may fire a post-inhibitory (rebound) spike. This phenomenon is related to the anodal break excitation in excitable membranes. Many spiking neurons can fire in response to brief inhibitory inputs, thereby blurring the difference between excitation and inhibition.
[12]:
neu = bp.neurons.Izhikevich(1)
neu.a, neu.b, neu.c, neu.d = 0.03, 0.25, -60.0, 4.0
current = bp.inputs.section_input(values=[7., 0., 7.], durations=[10, 5, 40])
runner = bp.DSRunner(neu, inputs=['input', current, 'iter'], monitors=['V', 'u'])
runner.run(duration=55.)
fig, ax1 = plt.subplots(figsize=(10, 5))
plt.title('Rebound Spike')
ax1.plot(runner.mon.ts, runner.mon.V[:, 0], 'b', label='V')
ax1.set_xlabel('Time (ms)')
ax1.set_ylabel('Membrane potential (mV)', color='b')
ax1.set_xlim(-0.1, 55.1)
ax1.tick_params('y', colors='b')
ax2 = ax1.twinx()
ax2.plot(runner.mon.ts, current, 'r', label='Input')
ax2.set_xlabel('Time (ms)')
ax2.set_ylabel('Input (mV)', color='r')
ax2.set_ylim(0, 70)
ax2.tick_params('y', colors='r')
ax1.legend(loc=1)
ax2.legend(loc=3)
fig.tight_layout()
plt.show()

Depolarizing After-potentials
After firing a spike, the membrane potential of a neuron may exhibit a prolonged after-hyperpolarization (AHP) or a prolonged depolarized after-potential (DAP). Such DAPs can appear because of dendritic influence, because of high-threshold inward currents activated during the spike, or because of an interplay between subthreshold voltage-gated currents.
[13]:
neu = bp.neurons.Izhikevich(1)
neu.a, neu.b, neu.c, neu.d = 1.00, 0.20, -60.0, -21.0
current = bp.inputs.section_input(values=[0., 23, 0], durations=[7, 1, 50])
runner = bp.DSRunner(neu, inputs=['input', current, 'iter'], monitors=['V', 'u'])
runner.run(duration=58.)
fig, ax1 = plt.subplots(figsize=(10, 5))
plt.title('Depolarizing After-Potentials')
ax1.plot(runner.mon.ts, runner.mon.V[:, 0], 'b', label='V')
ax1.set_xlabel('Time (ms)')
ax1.set_ylabel('Membrane potential (mV)', color='b')
ax1.set_xlim(-0.1, 58.1)
ax1.tick_params('y', colors='b')
ax2 = ax1.twinx()
ax2.plot(runner.mon.ts, current, 'r', label='Input')
ax2.set_xlabel('Time (ms)')
ax2.set_ylabel('Input (mV)', color='r')
ax2.set_ylim(0, 100)
ax2.tick_params('y', colors='r')
ax1.legend(loc=1)
ax2.legend(loc=3)
fig.tight_layout()
plt.show()

Resonator
Due to the resonance phenomenon, neurons having oscillatory potentials can respond selectively to inputs whose frequency content is similar to the frequency of their subthreshold oscillations. Such neurons can implement frequency-modulated (FM) interactions and multiplexing of signals.
[14]:
neu = bp.neurons.Izhikevich(1)
neu.a, neu.b, neu.c, neu.d = 0.10, 0.26, -60.0, -1.0
current = bp.inputs.section_input(values=[-1, 0., -1, 0, -1, 0, -1, 0, -1],
durations=[10, 10, 10, 10, 100, 10, 30, 20, 30])
runner = bp.DSRunner(neu, inputs=['input', current, 'iter'], monitors=['V', 'u'])
runner.run(duration=230.)
fig, ax1 = plt.subplots(figsize=(10, 5))
plt.title('Resonator')
ax1.plot(runner.mon.ts, runner.mon.V[:, 0], 'b', label='V')
ax1.set_xlabel('Time (ms)')
ax1.set_ylabel('Membrane potential (mV)', color='b')
ax1.set_xlim(-0.1, 230.1)
ax1.tick_params('y', colors='b')
ax2 = ax1.twinx()
ax2.plot(runner.mon.ts, current, 'r', label='Input')
ax2.set_xlabel('Time (ms)')
ax2.set_ylabel('Input (mV)', color='r')
ax2.set_ylim(-2, 5)
ax2.tick_params('y', colors='r')
ax1.legend(loc=1)
ax2.legend(loc=3)
fig.tight_layout()
plt.show()

Integrator
Neurons without oscillatory potentials act as integrators. They prefer high-frequency input; the higher the frequency the more likely they fire.
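The pulse-pair input in the cell below is written out by hand as `values`/`durations` lists for `bp.inputs.section_input`. A small helper that builds these lists from pulse onset times makes it easier to vary the inter-pulse interval, which is the property an integrator is sensitive to. The name `pulse_sections` is our own, hypothetical helper, not part of BrainPy.

```python
def pulse_sections(onsets, width, amplitude, total):
    """Build (values, durations) lists of rectangular pulses for a piecewise-constant input.

    onsets: pulse start times (ms); width: pulse duration (ms);
    amplitude: pulse height; total: overall input duration (ms).
    """
    values, durations = [], []
    t = 0.0
    for onset in sorted(onsets):
        if onset < t:
            raise ValueError("pulses overlap")
        if onset > t:                      # silent gap before this pulse
            values.append(0.0)
            durations.append(onset - t)
        values.append(amplitude)
        durations.append(width)
        t = onset + width
    if total > t:                          # silent tail after the last pulse
        values.append(0.0)
        durations.append(total - t)
    return values, durations


# Reproduces the input of the Integrator cell: two pulse pairs, 1 ms pulses of amplitude 48.
values, durations = pulse_sections([19, 21, 50, 52], width=1, amplitude=48, total=109)
print(values, durations)
```

The result can be passed on as `bp.inputs.section_input(values=values, durations=durations)`; spreading one pair apart (e.g. onsets `[19, 21, 50, 58]`) is then a one-line change.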
[15]:
neu = bp.neurons.Izhikevich(1)
neu.a, neu.b, neu.c, neu.d = 0.02, -0.1, -55.0, 6.0
current = bp.inputs.section_input(values=[0, 48, 0, 48, 0, 48, 0, 48, 0],
durations=[19, 1, 1, 1, 28, 1, 1, 1, 56])
runner = bp.DSRunner(neu, inputs=['input', current, 'iter'], monitors=['V', 'u'])
runner.run(duration=109.)
fig, ax1 = plt.subplots(figsize=(10, 5))
plt.title('Integrator')
ax1.plot(runner.mon.ts, runner.mon.V[:, 0], 'b', label='V')
ax1.set_xlabel('Time (ms)')
ax1.set_ylabel('Membrane potential (mV)', color='b')
ax1.set_xlim(-0.1, 109.1)
ax1.tick_params('y', colors='b')
ax2 = ax1.twinx()
ax2.plot(runner.mon.ts, current, 'r', label='Input')
ax2.set_xlabel('Time (ms)')
ax2.set_ylabel('Input (mV)', color='r')
ax2.set_ylim(0, 200)
ax2.tick_params('y', colors='r')
ax1.legend(loc=1)
ax2.legend(loc=3)
fig.tight_layout()
plt.show()

Threshold Variability
Biological neurons have a variable threshold that depends on the prior activity of the neurons.
[16]:
neu = bp.neurons.Izhikevich(1)
neu.a, neu.b, neu.c, neu.d = 0.03, 0.25, -60.0, 4.0
current = bp.inputs.section_input(values=[0, 5, 0, -5, 0, 5, 0],
durations=[13, 3, 78, 2, 2, 3, 13])
runner = bp.DSRunner(neu, inputs=['input', current, 'iter'], monitors=['V', 'u'])
runner.run(duration=114.)
fig, ax1 = plt.subplots(figsize=(10, 5))
plt.title('Threshold Variability')
ax1.plot(runner.mon.ts, runner.mon.V[:, 0], 'b', label='V')
ax1.set_xlabel('Time (ms)')
ax1.set_ylabel('Membrane potential (mV)', color='b')
ax1.set_xlim(-0.1, 114.1)
ax1.tick_params('y', colors='b')
ax2 = ax1.twinx()
ax2.plot(runner.mon.ts, current, 'r', label='Input')
ax2.set_xlabel('Time (ms)')
ax2.set_ylabel('Input (mV)', color='r')
ax2.set_ylim(-6, 20)
ax2.tick_params('y', colors='r')
ax1.legend(loc=1)
ax2.legend(loc=3)
fig.tight_layout()
plt.show()

Bistability
Some neurons can exhibit two stable modes of operation: resting and tonic spiking (or even bursting). An excitatory or inhibitory pulse can switch between the modes, thereby creating an interesting possibility for bistability and short-term memory.
[17]:
neu = bp.neurons.Izhikevich(1)
neu.a, neu.b, neu.c, neu.d = 1.00, 1.50, -60.0, 0.0
current = bp.inputs.section_input(values=[0., 5., 0, 5, 0], durations=[10, 1, 10, 1, 10])
runner = bp.DSRunner(neu, inputs=['input', current, 'iter'], monitors=['V', 'u'])
runner.run(duration=32.)
fig, ax1 = plt.subplots(figsize=(10, 5))
plt.title('Bistability')
ax1.plot(runner.mon.ts, runner.mon.V[:, 0], 'b', label='V')
ax1.set_xlabel('Time (ms)')
ax1.set_ylabel('Membrane potential (mV)', color='b')
ax1.set_xlim(-0.1, 32.1)
ax1.tick_params('y', colors='b')
ax2 = ax1.twinx()
ax2.plot(runner.mon.ts, current, 'r', label='Input')
ax2.set_xlabel('Time (ms)')
ax2.set_ylabel('Input (mV)', color='r')
ax2.set_ylim(0, 50)
ax2.tick_params('y', colors='r')
ax1.legend(loc=1)
ax2.legend(loc=3)
fig.tight_layout()
plt.show()

3. Different types of neurons
The table below gives the values of the parameters \(a, b, c, d\) for 7 types of neurons.
Neuron | a | b | c (mV) | d |
---|---|---|---|---|
Regular Spiking (RS) | 0.02 | 0.2 | -65 | 8 |
Intrinsically Bursting (IB) | 0.02 | 0.2 | -55 | 4 |
Chattering (CH) | 0.02 | 0.2 | -50 | 2 |
Fast Spiking (FS) | 0.1 | 0.2 | -65 | 2 |
Thalamo-cortical (TC) | 0.02 | 0.25 | -65 | 0.05 |
Resonator (RZ) | 0.1 | 0.26 | -65 | 2 |
Low-threshold Spiking (LTS) | 0.02 | 0.25 | -65 | 2 |
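Reference [2] builds heterogeneous cortical populations by interpolating between these neuron types with a random variable \(r \in [0, 1)\): excitatory cells use \(a = 0.02\), \(b = 0.2\), \(c = -65 + 15r^2\), \(d = 8 - 6r^2\) (RS at \(r = 0\), CH at \(r = 1\)), and inhibitory cells use \(a = 0.02 + 0.08r\), \(b = 0.25 - 0.05r\), \(c = -65\), \(d = 2\) (LTS at \(r = 0\), FS at \(r = 1\)). The NumPy sketch below builds such per-neuron arrays; the function name and seeding are ours, and whether a BrainPy `Izhikevich` group accepts these arrays through the same attribute assignment as the cells below is an assumption to verify for your BrainPy version.

```python
import numpy as np


def random_cortical_params(n_exc, n_inh, seed=0):
    """Per-neuron (a, b, c, d) arrays: n_exc excitatory cells followed by n_inh inhibitory cells."""
    rng = np.random.default_rng(seed)
    re, ri = rng.random(n_exc), rng.random(n_inh)   # r in [0, 1) for each cell
    a = np.concatenate([np.full(n_exc, 0.02), 0.02 + 0.08 * ri])
    b = np.concatenate([np.full(n_exc, 0.2), 0.25 - 0.05 * ri])
    c = np.concatenate([-65.0 + 15.0 * re ** 2, np.full(n_inh, -65.0)])
    d = np.concatenate([8.0 - 6.0 * re ** 2, np.full(n_inh, 2.0)])
    return a, b, c, d


a, b, c, d = random_cortical_params(800, 200)
print(a.shape, c[:800].min(), c[:800].max())
```

Squaring \(r\) for the excitatory cells biases the population towards RS neurons, the most common excitatory type.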
Regular Spiking (RS)
[18]:
neu = bp.neurons.Izhikevich(1)
neu.a, neu.b, neu.c, neu.d = 0.02, 0.2, -65, 8
current = bp.inputs.section_input(values=[0., 15.], durations=[50, 250])
runner = bp.DSRunner(neu, inputs=['input', current, 'iter'], monitors=['V', 'u'])
runner.run(duration=300.)
fig, ax1 = plt.subplots(figsize=(10, 5))
plt.title('Regular Spiking')
ax1.plot(runner.mon.ts, runner.mon.V[:, 0], 'b', label='V')
ax1.set_xlabel('Time (ms)')
ax1.set_ylabel('Membrane potential (mV)', color='b')
ax1.set_xlim(-0.1, 300.1)
ax1.tick_params('y', colors='b')
ax2 = ax1.twinx()
ax2.plot(runner.mon.ts, current, 'r', label='Input')
ax2.set_xlabel('Time (ms)')
ax2.set_ylabel('Input (mV)', color='r')
ax2.set_ylim(0, 50)
ax2.tick_params('y', colors='r')
ax1.legend(loc=1)
ax2.legend(loc=3)
fig.tight_layout()
plt.show()

Intrinsically Bursting (IB)
[19]:
neu = bp.neurons.Izhikevich(1)
neu.a, neu.b, neu.c, neu.d = 0.02, 0.2, -55, 4
current = bp.inputs.section_input(values=[0., 15.], durations=[50, 250])
runner = bp.DSRunner(neu, inputs=['input', current, 'iter'], monitors=['V', 'u'])
runner.run(duration=300.)
fig, ax1 = plt.subplots(figsize=(10, 5))
plt.title('Intrinsically Bursting')
ax1.plot(runner.mon.ts, runner.mon.V[:, 0], 'b', label='V')
ax1.set_xlabel('Time (ms)')
ax1.set_ylabel('Membrane potential (mV)', color='b')
ax1.set_xlim(-0.1, 300.1)
ax1.tick_params('y', colors='b')
ax2 = ax1.twinx()
ax2.plot(runner.mon.ts, current, 'r', label='Input')
ax2.set_xlabel('Time (ms)')
ax2.set_ylabel('Input (mV)', color='r')
ax2.set_ylim(0, 50)
ax2.tick_params('y', colors='r')
ax1.legend(loc=1)
ax2.legend(loc=3)
fig.tight_layout()
plt.show()

Chattering (CH)
[20]:
neu = bp.neurons.Izhikevich(1)
neu.a, neu.b, neu.c, neu.d = 0.02, 0.2, -50, 2
current = bp.inputs.section_input(values=[0., 10.], durations=[50, 350])
runner = bp.DSRunner(neu, inputs=['input', current, 'iter'], monitors=['V', 'u'])
runner.run(duration=400.)
fig, ax1 = plt.subplots(figsize=(10, 5))
plt.title('Chattering')
ax1.plot(runner.mon.ts, runner.mon.V[:, 0], 'b', label='V')
ax1.set_xlabel('Time (ms)')
ax1.set_ylabel('Membrane potential (mV)', color='b')
ax1.set_xlim(-0.1, 400.1)
ax1.tick_params('y', colors='b')
ax2 = ax1.twinx()
ax2.plot(runner.mon.ts, current, 'r', label='Input')
ax2.set_xlabel('Time (ms)')
ax2.set_ylabel('Input (mV)', color='r')
ax2.set_ylim(0, 50)
ax2.tick_params('y', colors='r')
ax1.legend(loc=1)
ax2.legend(loc=3)
fig.tight_layout()
plt.show()

Fast Spiking (FS)
[21]:
neu = bp.neurons.Izhikevich(1)
neu.a, neu.b, neu.c, neu.d = 0.1, 0.2, -65, 2
current = bp.inputs.section_input(values=[0., 10.], durations=[50, 150])
runner = bp.DSRunner(neu, inputs=['input', current, 'iter'], monitors=['V', 'u'])
runner.run(duration=200.)
fig, ax1 = plt.subplots(figsize=(10, 5))
plt.title('Fast Spiking')
ax1.plot(runner.mon.ts, runner.mon.V[:, 0], 'b', label='V')
ax1.set_xlabel('Time (ms)')
ax1.set_ylabel('Membrane potential (mV)', color='b')
ax1.set_xlim(-0.1, 200.1)
ax1.tick_params('y', colors='b')
ax2 = ax1.twinx()
ax2.plot(runner.mon.ts, current, 'r', label='Input')
ax2.set_xlabel('Time (ms)')
ax2.set_ylabel('Input (mV)', color='r')
ax2.set_ylim(0, 50)
ax2.tick_params('y', colors='r')
ax1.legend(loc=1)
ax2.legend(loc=3)
fig.tight_layout()
plt.show()

Thalamo-cortical (TC)
[22]:
neu = bp.neurons.Izhikevich(1)
neu.a, neu.b, neu.c, neu.d = 0.02, 0.25, -65, 0.05
current = bp.inputs.section_input(values=[0., 10.], durations=[50, 100])
runner = bp.DSRunner(neu, inputs=['input', current, 'iter'], monitors=['V', 'u'])
runner.run(duration=150.)
fig, ax1 = plt.subplots(figsize=(10, 5))
plt.title('Thalamo-cortical')
ax1.plot(runner.mon.ts, runner.mon.V[:, 0], 'b', label='V')
ax1.set_xlabel('Time (ms)')
ax1.set_ylabel('Membrane potential (mV)', color='b')
ax1.set_xlim(-0.1, 150.1)
ax1.tick_params('y', colors='b')
ax2 = ax1.twinx()
ax2.plot(runner.mon.ts, current, 'r', label='Input')
ax2.set_xlabel('Time (ms)')
ax2.set_ylabel('Input (mV)', color='r')
ax2.set_ylim(0, 50)
ax2.tick_params('y', colors='r')
ax1.legend(loc=1)
ax2.legend(loc=3)
fig.tight_layout()
plt.show()

Resonator (RZ)
[23]:
neu = bp.neurons.Izhikevich(1)
neu.a, neu.b, neu.c, neu.d = 0.1, 0.26, -65, 2
current = bp.inputs.section_input(values=[0., 5.], durations=[100, 300])
runner = bp.DSRunner(neu, inputs=['input', current, 'iter'], monitors=['V', 'u'])
runner.run(duration=400.)
fig, ax1 = plt.subplots(figsize=(10, 5))
plt.title('Resonator')
ax1.plot(runner.mon.ts, runner.mon.V[:, 0], 'b', label='V')
ax1.set_xlabel('Time (ms)')
ax1.set_ylabel('Membrane potential (mV)', color='b')
ax1.set_xlim(-0.1, 400.1)
ax1.tick_params('y', colors='b')
ax2 = ax1.twinx()
ax2.plot(runner.mon.ts, current, 'r', label='Input')
ax2.set_xlabel('Time (ms)')
ax2.set_ylabel('Input (mV)', color='r')
ax2.set_ylim(0, 50)
ax2.tick_params('y', colors='r')
ax1.legend(loc=1)
ax2.legend(loc=3)
fig.tight_layout()
plt.show()

Low-threshold Spiking (LTS)
[24]:
neu = bp.neurons.Izhikevich(1)
neu.a, neu.b, neu.c, neu.d = 0.02, 0.25, -65, 2
current = bp.inputs.section_input(values=[0., 10.], durations=[50, 150])
runner = bp.DSRunner(neu, inputs=['input', current, 'iter'], monitors=['V', 'u'])
runner.run(duration=200.)
fig, ax1 = plt.subplots(figsize=(10, 5))
plt.title('Low-threshold Spiking')
ax1.plot(runner.mon.ts, runner.mon.V[:, 0], 'b', label='V')
ax1.set_xlabel('Time (ms)')
ax1.set_ylabel('Membrane potential (mV)', color='b')
ax1.set_xlim(-0.1, 200.1)
ax1.tick_params('y', colors='b')
ax2 = ax1.twinx()
ax2.plot(runner.mon.ts, current, 'r', label='Input')
ax2.set_xlabel('Time (ms)')
ax2.set_ylabel('Input (mV)', color='r')
ax2.set_ylim(0, 50)
ax2.tick_params('y', colors='r')
ax1.legend(loc=1)
ax2.legend(loc=3)
fig.tight_layout()
plt.show()

References
[1] Izhikevich, E. M. Which model to use for cortical spiking neurons? IEEE Transactions on Neural Networks, 2004, 15(5): 1063-1070.
[2] Izhikevich, E. M. Simple model of spiking neurons. IEEE Transactions on Neural Networks, 2003, 14(6): 1569-1572.
(Brette, Romain. 2004) LIF phase locking
Implementation of the paper:
Brette, Romain. “Dynamics of one-dimensional spiking neuron models.” Journal of mathematical biology 48.1 (2004): 38-56.
Author:
Chaoming Wang (chao.brain@qq.com)
[1]:
import brainpy as bp
import brainpy.math as bm
[2]:
import matplotlib.pyplot as plt
[3]:
# set parameters
num = 2000
tau = 100. # ms
Vth = 1. # mV
Vr = 0. # mV
inputs = bm.linspace(2., 4., num)
[4]:
class LIF(bp.NeuGroup):
    def __init__(self, size, **kwargs):
        super(LIF, self).__init__(size, **kwargs)
        self.V = bm.Variable(bm.zeros(size))
        self.spike = bm.Variable(bm.zeros(size, dtype=bool))
        self.integral = bp.odeint(self.derivative)

    def derivative(self, V, t):
        # dV/dt = (-V + I + 2 sin(2*pi*t / tau)) / tau, one constant input I per neuron
        return (-V + inputs + 2 * bm.sin(2 * bm.pi * t / tau)) / tau

    def update(self, tdi):
        V = self.integral(self.V, tdi.t, tdi.dt)
        self.spike.value = V >= Vth
        # reset the neurons that crossed the threshold
        self.V.value = bm.where(self.spike > 0., Vr, V)
[5]:
group = LIF(num)
runner = bp.DSRunner(group, monitors=['spike'])
[6]:
t = runner.run(duration=5 * 1000.)
indices, times = bp.measure.raster_plot(runner.mon.spike, runner.mon.ts)
# plt.plot((times % tau) / tau, inputs[indices], ',')
spike_phases = (times % tau) / tau
params = inputs[indices]
plt.scatter(x=spike_phases, y=params, c=spike_phases,
marker=',', s=0.1, cmap="coolwarm")
plt.xlabel('Spike phase')
plt.ylabel('Parameter (input)')
plt.show()
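Without the sinusoidal term, the LIF equation used above can be solved exactly: starting from \(V_r\), \(V(t) = I + (V_r - I)e^{-t/\tau}\), so the neuron reaches \(V_{th}\) after \(T = \tau \ln\frac{I - V_r}{I - V_{th}}\), which is finite only for \(I > V_{th}\). Comparing this natural period with the 100 ms period of the forcing term \(2\sin(2\pi t/\tau)\) gives a rough intuition for why inputs in \([2, 4]\) fire at different phases; it is a back-of-the-envelope check, not the analysis carried out in the paper. The helper name `lif_period` is ours.

```python
import math


def lif_period(I, tau=100.0, Vth=1.0, Vr=0.0):
    """Unforced firing period (ms) of tau dV/dt = -V + I with threshold Vth and reset Vr."""
    if I <= Vth:
        return math.inf            # V saturates at I and never reaches threshold
    return tau * math.log((I - Vr) / (I - Vth))


forcing_period = 100.0             # the sinusoid above has period tau = 100 ms
for I in (2.0, 3.0, 4.0):
    T = lif_period(I)
    print(f"I = {I}: T = {T:.1f} ms, forcing/natural period = {forcing_period / T:.2f}")
```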

(Gerstner, 2005): Adaptive Exponential Integrate-and-Fire model
The Adaptive Exponential Integrate-and-Fire (AdEx) neuron model is a spiking model that describes single-neuron behavior and can generate many kinds of firing patterns by tuning its parameters.
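To make the role of each keyword argument concrete, here is a minimal forward-Euler sketch of an AdEx neuron in plain Python. We assume the form \(\tau\,dV/dt = -(V - V_{rest}) + \Delta_T e^{(V - V_T)/\Delta_T} - Rw + RI\), \(\tau_w\,dw/dt = a(V - V_{rest}) - w\), with the reset \(V \leftarrow V_{reset}\), \(w \leftarrow w + b\) once \(V \geq V_{th}\); this matches the keyword names passed to `bp.neurons.AdExIF` in the cells below, but it is our reading and should be checked against BrainPy's source. The function name `adex_spike_times` is hypothetical.

```python
import math


def adex_spike_times(I, a, b, R, delta_T, tau, tau_w, V_reset, V_rest, V_th, V_T,
                     duration=500.0, dt=0.01):
    """Forward-Euler AdEx neuron under constant input I; returns spike times (ms).

    Assumed dynamics (see the lead-in):
      tau   dV/dt = -(V - V_rest) + delta_T * exp((V - V_T) / delta_T) - R*w + R*I
      tau_w dw/dt = a * (V - V_rest) - w
      if V >= V_th: V <- V_reset, w <- w + b
    """
    V, w = V_rest, 0.0
    spikes = []
    for i in range(int(round(duration / dt))):
        dV = (-(V - V_rest) + delta_T * math.exp((V - V_T) / delta_T) - R * w + R * I) / tau
        dw = (a * (V - V_rest) - w) / tau_w
        V, w = V + dt * dV, w + dt * dw
        if V >= V_th:
            spikes.append(i * dt)
            V, w = V_reset, w + b      # spike-triggered adaptation
    return spikes


# "Tonic" parameters from the first cell below.
tonic = adex_spike_times(I=65.0, a=0.0, b=60.0, R=0.5, delta_T=2.0, tau=20.0, tau_w=30.0,
                         V_reset=-55.0, V_rest=-70.0, V_th=-30.0, V_T=-50.0)
print(len(tonic), "spikes in 500 ms")
```

Running the other parameter sets below through this sketch is a quick way to compare it with BrainPy's output; differences in integration scheme or update order may shift individual spike times.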
[1]:
import brainpy as bp
bp.math.set_dt(0.01)
bp.math.enable_x64()
bp.math.set_platform('cpu')
Tonic
[2]:
group = bp.neurons.AdExIF(size=1, a=0., b=60., R=.5, delta_T=2., tau=20., tau_w=30.,
V_reset=-55., V_rest=-70, V_th=-30, V_T=-50)
runner = bp.DSRunner(group, monitors=['V', 'w'], inputs=('input', 65.))
runner.run(500.)
fig, gs = bp.visualize.get_figure(2, 1, 3, 8)
fig.add_subplot(gs[0, 0])
bp.visualize.line_plot(runner.mon.ts, runner.mon.V, ylabel='V', title='tonic')
fig.add_subplot(gs[1, 0])
bp.visualize.line_plot(runner.mon.ts, runner.mon.w, ylabel='w', show=True)

Adapting
[3]:
group = bp.neurons.AdExIF(size=1, a=0., b=5., R=.5, tau=20., tau_w=100., delta_T=2.,
V_reset=-55., V_rest=-70, V_th=-30, V_T=-50)
runner = bp.DSRunner(group, monitors=['V', 'w'], inputs=('input', 65.))
runner.run(200.)
fig, gs = bp.visualize.get_figure(2, 1, 3, 8)
fig.add_subplot(gs[0, 0])
bp.visualize.line_plot(runner.mon.ts, runner.mon.V, ylabel='V', title='adapting')
fig.add_subplot(gs[1, 0])
bp.visualize.line_plot(runner.mon.ts, runner.mon.w, ylabel='w', show=True)

Init Bursting
[4]:
group = bp.neurons.AdExIF(size=1, a=.5, b=7., R=.5, tau=5., tau_w=100., delta_T=2.,
V_reset=-51, V_rest=-70, V_th=-30, V_T=-50)
runner = bp.DSRunner(group, monitors=['V', 'w'], inputs=('input', 65.))
runner.run(300.)
fig, gs = bp.visualize.get_figure(2, 1, 3, 8)
fig.add_subplot(gs[0, 0])
bp.visualize.line_plot(runner.mon.ts, runner.mon.V, ylabel='V', ylim=(-55., -35.), title='init_bursting')
fig.add_subplot(gs[1, 0])
bp.visualize.line_plot(runner.mon.ts, runner.mon.w, ylabel='w', show=True)

Bursting
[5]:
group = bp.neurons.AdExIF(size=1, a=-0.5, b=7., R=.5, delta_T=2., tau=5, tau_w=100,
V_reset=-46, V_rest=-70, V_th=-30, V_T=-50)
runner = bp.DSRunner(group, monitors=['V', 'w'], inputs=('input', 65.))
runner.run(500.)
fig, gs = bp.visualize.get_figure(2, 1, 3, 8)
fig.add_subplot(gs[0, 0])
bp.visualize.line_plot(runner.mon.ts, runner.mon.V, ylabel='V', ylim=(-60., -35.), title='bursting')
fig.add_subplot(gs[1, 0])
bp.visualize.line_plot(runner.mon.ts, runner.mon.w, ylabel='w', show=True)

Transient
[6]:
group = bp.neurons.AdExIF(size=1, a=1., b=10., R=.5, tau=10, tau_w=100, delta_T=2.,
V_reset=-60, V_rest=-70, V_th=-30, V_T=-50)
runner = bp.DSRunner(group, monitors=['V', 'w'], inputs=('input', 55.))
runner.run(500.)
fig, gs = bp.visualize.get_figure(2, 1, 3, 8)
fig.add_subplot(gs[0, 0])
bp.visualize.line_plot(runner.mon.ts, runner.mon.V, ylabel='V', ylim=(-60., -35.), title='transient')
fig.add_subplot(gs[1, 0])
bp.visualize.line_plot(runner.mon.ts, runner.mon.w, ylabel='w', show=True)

Delayed
[7]:
group = bp.neurons.AdExIF(size=1, a=-1., b=10., R=.5, delta_T=2., tau=5., tau_w=100.,
V_reset=-60, V_rest=-70, V_th=-30, V_T=-50)
runner = bp.DSRunner(group, monitors=['V', 'w'], inputs=('input', 20.))
runner.run(500.)
fig, gs = bp.visualize.get_figure(2, 1, 3, 8)
fig.add_subplot(gs[0, 0])
bp.visualize.line_plot(runner.mon.ts, runner.mon.V, ylabel='V', ylim=(-60., -35.), title='delayed')
fig.add_subplot(gs[1, 0])
bp.visualize.line_plot(runner.mon.ts, runner.mon.w, ylabel='w', show=True)

(Mihalaş & Niebur, 2009) Generalized integrate-and-fire model
Implementation of the paper: Mihalaş, Ştefan, and Ernst Niebur. “A generalized linear integrate-and-fire neural model produces diverse spiking behaviors.” Neural computation 21.3 (2009): 704-718.
[1]:
import matplotlib.pyplot as plt
import brainpy as bp
Model Overview
The generalized integrate-and-fire (GIF) model is a spiking neuron model that describes single-neuron behavior and can generate most known firing patterns by tuning its parameters.
It originates from the leaky integrate-and-fire (LIF) model, but differs from it by including internal currents \(I_j\) in its equations.
A GIF neuron fires when \(V\) reaches \(V_{th}\), after which its variables are reset:
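For reference, here are the model equations as we read them in Mihalaş & Niebur (2009), written with the parameter names that bp.neurons.GIF appears to use (`a`, `b`, `V_th_inf`, `V_th_reset`, `A1`, `A2`, plus what we assume are `k1`, `k2`, `R1`, `R2`). The system has internal currents that decay exponentially, a leaky membrane, an adaptive threshold, and a reset rule that is applied when \(V > V_{th}\):

```latex
\begin{aligned}
\frac{dI_j}{dt} &= -k_j I_j, \qquad j = 1, 2 \\
\tau \frac{dV}{dt} &= -(V - V_{rest}) + R\Big(\sum_j I_j + I_{ext}\Big) \\
\frac{dV_{th}}{dt} &= a\,(V - V_{rest}) - b\,(V_{th} - V_{th,\infty}) \\
\text{if } V > V_{th}: \quad I_j &\leftarrow R_j I_j + A_j, \qquad V \leftarrow V_{reset}, \qquad V_{th} \leftarrow \max(V_{th,reset},\, V_{th})
\end{aligned}
```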
Different firing patterns
This arbitrary number of internal currents \(I_j\) can be interpreted as currents produced by ion-channel dynamics, and it gives the GIF model the flexibility to generate various firing patterns.
With appropriate parameters, we can reproduce most single-neuron firing patterns. In the original paper (Mihalaş & Niebur, 2009), the authors used two internal currents \(I_1\) and \(I_2\).
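To make the role of each term concrete, here is a minimal forward-Euler sketch of a single GIF neuron in plain NumPy. It is an illustration, not BrainPy's implementation: the function name `simulate_gif` is hypothetical, and the default values are our own illustrative choices (a tonic-spiking setting with \(A_1 = A_2 = 0\)), not verified against the defaults of bp.neurons.GIF.

```python
import numpy as np


def simulate_gif(I_ext, dt=0.1, V_rest=-70., V_reset=-70., V_th_inf=-50.,
                 V_th_reset=-60., R=20., tau=20., a=0., b=0.01,
                 k1=0.2, k2=0.02, R1=0., R2=1., A1=0., A2=0.):
    """Forward-Euler sketch of one GIF neuron (times in ms).

    I_ext is a 1-D array with one input value per time step.
    Returns the membrane potential, the threshold and the spike times.
    """
    V, V_th, I1, I2 = V_rest, V_th_inf, 0., 0.
    Vs, Vths, spikes = [], [], []
    for step, I in enumerate(I_ext):
        # all derivatives are evaluated at the current state (explicit Euler)
        dI1 = -k1 * I1
        dI2 = -k2 * I2
        dV = (-(V - V_rest) + R * (I + I1 + I2)) / tau
        dVth = a * (V - V_rest) - b * (V_th - V_th_inf)
        I1, I2 = I1 + dt * dI1, I2 + dt * dI2
        V, V_th = V + dt * dV, V_th + dt * dVth
        if V > V_th:
            # spike: update internal currents, reset V, keep V_th >= V_th_reset
            spikes.append((step + 1) * dt)
            I1, I2 = R1 * I1 + A1, R2 * I2 + A2
            V = V_reset
            V_th = max(V_th_reset, V_th)
        Vs.append(V)
        Vths.append(V_th)
    return np.array(Vs), np.array(Vths), spikes


# 200 ms of constant input 1.5, as in the "Tonic Spiking" cell below
V, V_th, spikes = simulate_gif(np.full(2000, 1.5))
print(len(spikes), spikes[:3])
```

With these values the steady-state potential \(V_{rest} + R I = -40\) mV lies above the threshold of \(-50\) mV, so the neuron fires regularly; nonzero `A1`/`A2` add spike-triggered currents, which is how the bursting cells below differ from the spiking ones.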
[2]:
def run(model, duration, I_ext):
    runner = bp.DSRunner(model,
                         inputs=('input', I_ext, 'iter'),
                         monitors=['V', 'V_th'])
    runner.run(duration)
    ts = runner.mon.ts
    fig, gs = bp.visualize.get_figure(1, 1, 4, 8)
    # membrane potential and adaptive threshold on the left axis
    ax1 = fig.add_subplot(gs[0, 0])
    ax1.plot(ts, runner.mon.V[:, 0], label='V')
    ax1.plot(ts, runner.mon.V_th[:, 0], label='V_th')
    ax1.set_xlabel('Time (ms)')
    ax1.set_ylabel('Membrane potential')
    ax1.set_xlim(-0.1, ts[-1] + 0.1)
    ax1.legend()
    # external input on a twin axis on the right
    ax2 = ax1.twinx()
    ax2.plot(ts, I_ext, color='turquoise', label='input')
    ax2.set_xlabel('Time (ms)')
    ax2.set_ylabel('External input')
    ax2.set_xlim(-0.1, ts[-1] + 0.1)
    ax2.set_ylim(-5., 20.)
    ax2.legend(loc='lower left')
    plt.show()
Below, we simulate single GIF neurons to reproduce 20 different spiking patterns. Each plot is preceded by the name of the pattern it shows.
Tonic Spiking
[3]:
Iext, duration = bp.inputs.constant_input([(1.5, 200.)])
neu = bp.neurons.GIF(1)
run(neu, duration, Iext)

Class 1 Excitability
[4]:
Iext, duration = bp.inputs.constant_input([(1. + 1e-6, 500.)])
neu = bp.neurons.GIF(1)
run(neu, duration, Iext)

Spike Frequency Adaptation
[5]:
Iext, duration = bp.inputs.constant_input([(2., 200.)])
neu = bp.neurons.GIF(1, a=0.005)
run(neu, duration, Iext)

Phasic Spiking
[6]:
Iext, duration = bp.inputs.constant_input([(1.5, 500.)])
neu = bp.neurons.GIF(1, a=0.005)
run(neu, duration, Iext)

Accommodation
[7]:
Iext, duration = bp.inputs.constant_input([(1.5, 100.),
(0, 500.),
(0.5, 100.),
(1., 100.),
(1.5, 100.),
(0., 100.)])
neu = bp.neurons.GIF(1, a=0.005)
run(neu, duration, Iext)

Threshold Variability
[8]:
Iext, duration = bp.inputs.constant_input([(1.5, 20.),
(0., 180.),
(-1.5, 20.),
(0., 20.),
(1.5, 20.),
(0., 140.)])
neu = bp.neurons.GIF(1, a=0.005)
run(neu, duration, Iext)

Rebound Spiking
[9]:
Iext, duration = bp.inputs.constant_input([(0, 50.), (-3.5, 750.), (0., 200.)])
neu = bp.neurons.GIF(1, a=0.005)
run(neu, duration, Iext)

Class 2 Excitability
[10]:
Iext, duration = bp.inputs.constant_input([(2 * (1. + 1e-6), 200.)])
neu = bp.neurons.GIF(1, a=0.005)
neu.V_th[:] = -30.
run(neu, duration, Iext)

Integrator
[11]:
Iext, duration = bp.inputs.constant_input([(1.5, 20.),
(0., 10.),
(1.5, 20.),
(0., 250.),
(1.5, 20.),
(0., 30.),
(1.5, 20.),
(0., 30.)])
neu = bp.neurons.GIF(1, a=0.005)
run(neu, duration, Iext)

Input Bistability
[12]:
Iext, duration = bp.inputs.constant_input([(1.5, 100.),
(1.7, 400.),
(1.5, 100.),
(1.7, 400.)])
neu = bp.neurons.GIF(1, a=0.005)
run(neu, duration, Iext)

Hyperpolarization-induced Spiking
[13]:
Iext, duration = bp.inputs.constant_input([(-1., 400.)])
neu = bp.neurons.GIF(1, V_th_reset=-60., V_th_inf=-120.)
neu.V_th[:] = -50.
run(neu, duration, Iext)

Hyperpolarization-induced Bursting
[14]:
Iext, duration = bp.inputs.constant_input([(-1., 400.)])
neu = bp.neurons.GIF(1, V_th_reset=-60., V_th_inf=-120., A1=10.,
A2=-0.6)
neu.V_th[:] = -50.
run(neu, duration, Iext)

Tonic Bursting
[15]:
Iext, duration = bp.inputs.constant_input([(2., 500.)])
neu = bp.neurons.GIF(1, a=0.005, A1=10., A2=-0.6)
run(neu, duration, Iext)

Phasic Bursting
[16]:
Iext, duration = bp.inputs.constant_input([(1.5, 500.)])
neu = bp.neurons.GIF(1, a=0.005, A1=10., A2=-0.6)
run(neu, duration, Iext)

Rebound Bursting
[17]:
Iext, duration = bp.inputs.constant_input([(0, 100.), (-3.5, 500.), (0., 400.)])
neu = bp.neurons.GIF(1, a=0.005, A1=10., A2=-0.6)
run(neu, duration, Iext)

Mixed Mode
[18]:
Iext, duration = bp.inputs.constant_input([(2., 500.)])
neu = bp.neurons.GIF(1, a=0.005, A1=5., A2=-0.3)
run(neu, duration, Iext)

Afterpotentials
[19]:
Iext, duration = bp.inputs.constant_input([(2., 15.), (0, 185.)])
neu = bp.neurons.GIF(1, a=0.005, A1=5., A2=-0.3)
run(neu, duration, Iext)

Basal Bistability
[20]:
Iext, duration = bp.inputs.constant_input([(5., 10.), (0., 90.), (5., 10.), (0., 90.)])
neu = bp.neurons.GIF(1, A1=8., A2=-0.1)
run(neu, duration, Iext)

Preferred Frequency
[21]:
Iext, duration = bp.inputs.constant_input([(5., 10.),
(0., 10.),
(4., 10.),
(0., 370.),
(5., 10.),
(0., 90.),
(4., 10.),
(0., 290.)])
neu = bp.neurons.GIF(1, a=0.005, A1=-3., A2=0.5)
run(neu, duration, Iext)

Spike Latency
[22]:
Iext, duration = bp.inputs.constant_input([(8., 2.), (0, 48.)])
neu = bp.neurons.GIF(1, a=-0.08)
run(neu, duration, Iext)

(Jansen & Rit, 1995): Jansen-Rit Model
The Jansen-Rit model is a neural mass model of the dynamic interactions between 3 populations:
pyramidal cells (PCs)
excitatory interneurons (EINs)
inhibitory interneurons (IINs)
Originally, the model has been developed to describe the waxing-and-waning of EEG activity in the alpha frequency range (8-12 Hz) in the visual cortex [1]. In the past years, however, it has been used as a generic model to describe the macroscopic electrophysiological activity within a cortical column [2].
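By using the linearity of the convolution operation, the dynamic interactions between PCs, EINs and IINs can be expressed via 6 coupled ordinary differential equations. They are built from two operators: a sigmoid that converts the average membrane potential of a population into its average firing rate, and a linear synaptic kernel that converts an incoming firing rate into an average postsynaptic potential: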
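Written out in the form that the JansenRitModel class below integrates (our reconstruction from that code, with \(\tau_e = 1/a\) and \(\tau_i = 1/b\) in the notation of Jansen & Rit, 1995), the two operators and the six ODEs are given below. In this notation, \(y_1\), \(y_2\) and \(y_0\) play the roles of \(V_{pce}\), \(V_{pci}\) and \(V_{in}\), respectively, and \(y_3, y_4, y_5\) are their time derivatives:

```latex
\begin{aligned}
S(v) &= \frac{v_{max}}{1 + e^{\,r (v_0 - v)}}, \qquad
h_e(t) = \frac{A}{\tau_e}\, t\, e^{-t/\tau_e}, \qquad
h_i(t) = \frac{B}{\tau_i}\, t\, e^{-t/\tau_i} \\
\dot y_0 &= y_3, \qquad \dot y_3 = \frac{1}{\tau_e}\left(A\, S(y_1 - y_2) - 2 y_3 - \frac{y_0}{\tau_e}\right) \\
\dot y_1 &= y_4, \qquad \dot y_4 = \frac{1}{\tau_e}\left(A\,\big[p + C_2\, S(C_1 y_0)\big] - 2 y_4 - \frac{y_1}{\tau_e}\right) \\
\dot y_2 &= y_5, \qquad \dot y_5 = \frac{1}{\tau_i}\left(B\, C_4\, S(C_3 y_0) - 2 y_5 - \frac{y_2}{\tau_i}\right)
\end{aligned}
```

Each pair \((y_k, y_{k+3})\) is a second-order linear ODE whose impulse response is \(h_e\) or \(h_i\). This is how the convolution with the synaptic kernel becomes a set of ODEs.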
where \(V_{pce}\), \(V_{pci}\), \(V_{in}\) are used to represent the average membrane potential deflection caused by the excitatory synapses at the PC population, the inhibitory synapses at the PC population, and the excitatory synapses at both interneuron populations, respectively.
[1] B.H. Jansen & V.G. Rit (1995) Electroencephalogram and visual evoked potential generation in a mathematical model of coupled cortical columns. Biological Cybernetics, 73(4): 357-366.
[2] A. Spiegler, S.J. Kiebel, F.M. Atay, T.R. Knösche (2010) Bifurcation analysis of neural mass models: Impact of extrinsic inputs and dendritic time constants. NeuroImage, 52(3): 1041-1058, https://doi.org/10.1016/j.neuroimage.2009.12.081.
[1]:
import brainpy as bp
import brainpy.math as bm
[2]:
class JansenRitModel(bp.DynamicalSystem):
    def __init__(self, num, C=135., method='exp_auto'):
        super(JansenRitModel, self).__init__()
        self.num = num
        # parameters #
        self.v_max = 5.  # maximum firing rate
        self.v0 = 6.  # firing threshold
        self.r = 0.56  # slope of the sigmoid
        # other parameters
        self.A = 3.25
        self.B = 22.
        self.a = 100.
        self.tau_e = 0.01  # second
        self.tau_i = 0.02  # second
        self.b = 50.
        self.e0 = 2.5
        # The connectivity constants
        self.C1 = C
        self.C2 = 0.8 * C
        self.C3 = 0.25 * C
        self.C4 = 0.25 * C
        # variables #
        # y0, y1 and y2 are average postsynaptic potentials (mV): y0 is the output
        # of the pyramidal cells onto both interneuron populations, y1 and y2 are
        # the excitatory and inhibitory potentials at the pyramidal cells.
        # y3, y4 and y5 are their time derivatives.
        self.y0 = bm.Variable(bm.zeros(self.num))
        self.y1 = bm.Variable(bm.zeros(self.num))
        self.y2 = bm.Variable(bm.zeros(self.num))
        self.y3 = bm.Variable(bm.zeros(self.num))
        self.y4 = bm.Variable(bm.zeros(self.num))
        self.y5 = bm.Variable(bm.zeros(self.num))
        self.p = bm.Variable(bm.ones(self.num) * 220.)
        # integral function
        self.derivative = bp.JointEq([self.dy0, self.dy1, self.dy2, self.dy3, self.dy4, self.dy5])
        self.integral = bp.odeint(self.derivative, method=method)
    def sigmoid(self, x):
        return self.v_max / (1. + bm.exp(self.r * (self.v0 - x)))
    def dy0(self, y0, t, y3): return y3
    def dy1(self, y1, t, y4): return y4
    def dy2(self, y2, t, y5): return y5
    def dy3(self, y3, t, y0, y1, y2):
        return (self.A * self.sigmoid(y1 - y2) - 2 * y3 - y0 / self.tau_e) / self.tau_e
    def dy4(self, y4, t, y0, y1, p):
        return (self.A * (p + self.C2 * self.sigmoid(self.C1 * y0)) - 2 * y4 - y1 / self.tau_e) / self.tau_e
    def dy5(self, y5, t, y0, y2):
        return (self.B * self.C4 * self.sigmoid(self.C3 * y0) - 2 * y5 - y2 / self.tau_i) / self.tau_i
    def update(self, tdi):
        self.y0.value, self.y1.value, self.y2.value, self.y3.value, self.y4.value, self.y5.value = \
            self.integral(self.y0, self.y1, self.y2, self.y3, self.y4, self.y5, tdi.t, p=self.p, dt=tdi.dt)
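The method pairs dy0/dy3, dy1/dy4 and dy2/dy5 above should each reproduce the PSP kernel \(h(t) = A a t e^{-a t}\) (with \(a = 1/\tau_e\)) as their impulse response. The following standalone NumPy check of that claim is independent of BrainPy. The function names are ours, and forward Euler is used only for simplicity:

```python
import numpy as np


def psp_kernel(t, A=3.25, a=100.):
    """Analytic Jansen-Rit PSP kernel h(t) = A * a * t * exp(-a * t), t in seconds."""
    return A * a * t * np.exp(-a * t)


def integrate_impulse(A=3.25, a=100., dt=1e-6, T=0.05):
    """Forward-Euler integration of y' = z, z' = A*a*S(t) - 2*a*z - a**2 * y
    for a unit impulse S(t) = delta(t).

    The impulse only sets z(0) = A * a; afterwards the system evolves freely.
    A=3.25 and a=100 (= 1/tau_e) follow the JansenRitModel class above.
    """
    n = int(round(T / dt))
    y, z = 0., A * a
    ys = np.empty(n + 1)
    ys[0] = y
    for i in range(n):
        y, z = y + dt * z, z + dt * (-2. * a * z - a * a * y)
        ys[i + 1] = y
    return np.arange(n + 1) * dt, ys


ts, ys = integrate_impulse()
print(np.max(np.abs(ys - psp_kernel(ts))))  # discretization error only
print(ts[np.argmax(ys)])                    # peak near 1/a = 0.01 s
```

The kernel peaks at \(t = \tau_e\) with height \(A/e\). This is why \(A\) and \(\tau_e\) are often described as the amplitude and the time constant of the excitatory PSP.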
[3]:
def simulation(duration=5.):
    dt = 0.1 / 1e3
    # random input uniformly distributed between 120 and 320 pulses per second
    all_ps = bm.random.uniform(120, 320, size=(int(duration / dt), 1))
    # six independent columns, one per value of the connectivity constant C
    jrm = JansenRitModel(num=6, C=bm.array([68., 128., 135., 270., 675., 1350.]))
    runner = bp.DSRunner(jrm,
                         monitors=['y0', 'y1', 'y2', 'y3', 'y4', 'y5'],
                         inputs=['p', all_ps, 'iter', '='],
                         dt=dt)
    runner.run(duration)
    # discard the first 2 s as a transient
    start, end = int(2 / dt), int(duration / dt)
    fig, gs = bp.visualize.get_figure(6, 3, 2, 3)
    for i in range(6):
        fig.add_subplot(gs[i, 0])
        title = 'E' if i == 0 else None
        xlabel = 'time [s]' if i == 5 else None
        bp.visualize.line_plot(runner.mon.ts[start: end], runner.mon.y1[start: end, i],
                               title=title, xlabel=xlabel, ylabel='mV')
        fig.add_subplot(gs[i, 1])
        title = 'P' if i == 0 else None
        bp.visualize.line_plot(runner.mon.ts[start: end], runner.mon.y0[start: end, i],
                               title=title, xlabel=xlabel)
        fig.add_subplot(gs[i, 2])
        title = 'I' if i == 0 else None
        bp.visualize.line_plot(runner.mon.ts[start: end], runner.mon.y2[start: end, i],
                               title=title, show=i == 5, xlabel=xlabel)
[4]:
simulation()

(Teka et al., 2018): Fractional-order Izhikevich neuron model
Implementation of the model:
Teka, Wondimu W., Ranjit Kumar Upadhyay, and Argha Mondal. “Spiking and bursting patterns of fractional-order Izhikevich model.” Communications in Nonlinear Science and Numerical Simulation 56 (2018): 161-176.
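In fractional-order models, the time derivative is replaced by a fractional derivative of order \(\alpha\), so each step depends on the whole past trajectory. The `num_memory` argument used below presumably bounds how many past steps are kept. As an illustration only (this is not necessarily the scheme BrainPy uses internally), here is an explicit Grünwald-Letnikov scheme with the short-memory principle, applied to the toy relaxation equation \(D^\alpha x = -\lambda x\). The function names are hypothetical:

```python
import numpy as np


def gl_coefficients(alpha, n):
    """Grünwald-Letnikov weights c_j = (-1)**j * binom(alpha, j) for j = 0..n."""
    c = np.empty(n + 1)
    c[0] = 1.
    for j in range(1, n + 1):
        c[j] = (1. - (1. + alpha) / j) * c[j - 1]
    return c


def fractional_relaxation(alpha, x0=1., lam=1., h=0.01, T=5., num_memory=None):
    """Explicit GL scheme for the Caputo problem D^alpha x = -lam * x, x(0) = x0.

    The scheme is applied to z = x - x0 (so z(0) = 0 and the GL sum approximates
    the Caputo derivative). Only the last `num_memory` steps enter the history
    sum (short-memory principle); None keeps the full history.
    """
    n_steps = int(round(T / h))
    L = n_steps if num_memory is None else num_memory
    c = gl_coefficients(alpha, L)
    z = np.zeros(n_steps + 1)
    for n in range(1, n_steps + 1):
        m = min(n, L)
        # sum_{j=1}^{m} c_j * z_{n-j}
        history = np.dot(c[1:m + 1], z[n - 1::-1][:m])
        z[n] = h ** alpha * (-lam * (z[n - 1] + x0)) - history
    return x0 + z


x_int = fractional_relaxation(1.0)
x_frac = fractional_relaxation(0.8)
print(x_int[-1], x_frac[-1])
```

With `alpha=1` the weights reduce to \(c_1 = -1\) and \(c_j = 0\) for \(j \ge 2\), so the scheme is plain forward Euler. With `alpha < 1`, the decay develops a slow, power-law tail. This kind of history dependence is what changes the firing patterns as \(\alpha\) decreases below.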
[1]:
import brainpy as bp
import matplotlib.pyplot as plt
[2]:
def run_model(dt=0.1, duration=500, alpha=1.0):
    inputs, length = bp.inputs.section_input([0, 10], [50, duration],
                                             dt=dt, return_length=True)
    neuron = bp.neurons.FractionalIzhikevich(1, num_memory=int(length / dt), alpha=alpha)
    runner = bp.DSRunner(neuron,
                         monitors=['V'],
                         inputs=['input', inputs, 'iter'],
                         dt=dt)
    runner.run(length)
    plt.plot(runner.mon.ts, runner.mon.V.flatten())
    plt.xlabel('Time [ms]')
    plt.ylabel('Potential [mV]')
    plt.title(r'$\alpha$=' + str(alpha))
    plt.show()
Regular spiking
[3]:
run_model(dt=0.1, duration=500, alpha=1.0)

Intrinsically bursting
[4]:
run_model(dt=0.1, duration=500, alpha=0.87)

Mixed Mode (Irregular)
[5]:
run_model(dt=0.1, duration=500, alpha=0.86)

Chattering
[6]:
run_model(dt=0.1, duration=500, alpha=0.8)

Bursting
[7]:
run_model(dt=0.1, duration=1000, alpha=0.7)

Bursting with longer bursts
[8]:
run_model(dt=0.1, duration=1000, alpha=0.5)

Fast spiking
[9]:
run_model(dt=0.1, duration=1000, alpha=0.3)

(Mondal et al., 2019): Fractional-order FitzHugh-Rinzel bursting neuron model
Implementation of the paper:
Mondal, A., Sharma, S.K., Upadhyay, R.K. et al. Firing activities of a fractional-order FitzHugh-Rinzel bursting neuron model and its coupled dynamics. Sci Rep 9, 15721 (2019). https://doi.org/10.1038/s41598-019-52061-4
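For orientation, the fractional FitzHugh-Rinzel system as we understand it from Mondal et al. (2019) is shown below. Here \(D^\alpha\) is a fractional (Caputo-type) derivative of order \(0 < \alpha \le 1\), and \(\alpha = 1\) recovers the ordinary FitzHugh-Rinzel model. We assume that the names `a`, `b`, `c`, `d`, `delta` and `mu` in `neuron_pars` below map onto the symbols as follows, but we have not checked this against the bp.neurons.FractionalFHR source:

```latex
\begin{aligned}
D^{\alpha} V &= V - \frac{V^3}{3} - w + y + I_{ext} \\
D^{\alpha} w &= \delta\,(a + V - b\,w) \\
D^{\alpha} y &= \mu\,(c - V - d\,y)
\end{aligned}
```

The slow variable \(y\) (with small \(\mu\)) modulates the fast \(V\)-\(w\) subsystem, which is what produces bursting.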
[1]:
import brainpy as bp
import matplotlib.pyplot as plt
[2]:
def run_model(model, inputs, length):
    runner = bp.DSRunner(model, monitors=['V'], inputs=inputs)
    runner.run(length)
    return runner
Parameter set 1
[3]:
dt = 0.1
Iext = 0.3125
duration = 4000
neuron_pars = dict(a=0.7, b=0.8, c=-0.775, d=1., delta=0.08, mu=0.0001,
w_initializer=bp.init.Constant(-0.1),
y_initializer=bp.init.Constant(0.1))
[4]:
alphas = [1.0, 0.98, 0.95]
plt.figure(figsize=(9, 10))
for i, alpha in enumerate(alphas):
    neuron = bp.neurons.FractionalFHR(1,
                                      alpha=alpha,
                                      num_memory=4000,
                                      **neuron_pars)
    runner = run_model(neuron, inputs=['input', Iext], length=duration)
    plt.subplot(len(alphas), 1, i+1)
    plt.plot(runner.mon.ts, runner.mon.V[:, 0])
    plt.title(r'$\alpha$=' + str(alphas[i]))
    plt.ylabel('V')
plt.xlabel('Time [ms]')
plt.show()

Parameter set 2
[5]:
Iext = 0.4
duration = 3500
neuron_pars = dict(a=0.7, b=0.8, c=-0.775, d=1., delta=0.08, mu=0.0001)
[6]:
alphas = [1.0, 0.92, 0.85, 0.68]
plt.figure(figsize=(12, 10))
for i, alpha in enumerate(alphas):
    neuron = bp.neurons.FractionalFHR(1,
                                      alpha=alpha,
                                      num_memory=4000,
                                      **neuron_pars)
    runner = run_model(neuron, inputs=['input', Iext], length=duration)
    plt.subplot(len(alphas), 1, i+1)
    plt.plot(runner.mon.ts, runner.mon.V[:, 0])
    plt.title(r'$\alpha$=' + str(alphas[i]))
    plt.ylabel('V')
plt.xlabel('Time [ms]')
plt.show()

Parameter set 3
[7]:
Iext = 3
duration = 500
neuron_pars = dict(a=0.7, b=0.8, c=-0.775, d=1., delta=0.08, mu=0.18)
[8]:
alphas = [1.0, 0.99, 0.97, 0.95]
plt.figure(figsize=(12, 10))
for i, alpha in enumerate(alphas):
    neuron = bp.neurons.FractionalFHR(1,
                                      alpha=alpha,
                                      **neuron_pars)
    runner = run_model(neuron, inputs=['input', Iext], length=duration)
    plt.subplot(len(alphas), 1, i+1)
    plt.plot(runner.mon.ts, runner.mon.V[:, 0])
    plt.title(r'$\alpha$=' + str(alphas[i]))
    plt.ylabel('V')
plt.xlabel('Time [ms]')
plt.show()

Parameter set 4
[9]:
Iext = 0.3125
duration = 3500
neuron_pars = dict(a=0.7, b=0.8, c=1.3, d=1., delta=0.08, mu=0.0001)
[10]:
alphas = [1.0, 0.85, 0.80]
plt.figure(figsize=(9, 10))
for i, alpha in enumerate(alphas):
    neuron = bp.neurons.FractionalFHR(1,
                                      alpha=alpha,
                                      num_memory=3000,
                                      **neuron_pars)
    runner = run_model(neuron, inputs=['input', Iext], length=duration)
    plt.subplot(len(alphas), 1, i+1)
    plt.plot(runner.mon.ts, runner.mon.V[:, 0])
    plt.title(r'$\alpha$=' + str(alphas[i]))
    plt.ylabel('V')
plt.xlabel('Time [ms]')
plt.show()

Parameter set 5
[11]:
Iext = 0.3125
duration = 2500
neuron_pars = dict(a=0.7, b=0.8, c=-0.908, d=1., delta=0.08, mu=0.002)
[12]:
alphas = [1.0, 0.98, 0.95]
plt.figure(figsize=(9, 10))
for i, alpha in enumerate(alphas):
    neuron = bp.neurons.FractionalFHR(1, alpha=alpha, **neuron_pars)
    runner = run_model(neuron, inputs=['input', Iext], length=duration)
    plt.subplot(len(alphas), 1, i+1)
    plt.plot(runner.mon.ts, runner.mon.V[:, 0])
    plt.title(r'$\alpha$=' + str(alphas[i]))
    plt.ylabel('V')
plt.xlabel('Time [ms]')
plt.show()

(Si Wu, 2008): Continuous-attractor Neural Network 1D
Here we show the implementation of the paper:
Si Wu, Kosuke Hamaguchi, and Shun-ichi Amari. “Dynamics and computation of continuous attractors.” Neural computation 20.4 (2008): 994-1025.
Author:
Chaoming Wang (chao.brain@qq.com)
The mathematical equation of the Continuous-Attractor Neural Network (CANN) is given by:
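Written out, as we read Wu, Hamaguchi & Amari (2008) and the CANN1D code below (which uses the same symbols `tau`, `k`, `a`, `A`, `J0` and `rho`), the equations are as follows. Here \(u\) is the synaptic input, \(r\) the firing rate, and \(z(t)\) the stimulus position, and \(|x - x'|\) is measured on the circle \([-\pi, \pi)\):

```latex
\begin{aligned}
\tau \frac{\partial u(x,t)}{\partial t} &= -u(x,t) + \rho \int J(x,x')\, r(x',t)\, dx' + I^{ext}(x,t) \\
r(x,t) &= \frac{u(x,t)^2}{1 + k\rho \int u(x',t)^2\, dx'} \\
J(x,x') &= \frac{J_0}{\sqrt{2\pi}\,a} \exp\left(-\frac{|x-x'|^2}{2a^2}\right) \\
I^{ext}(x,t) &= A \exp\left(-\frac{|x-z(t)|^2}{4a^2}\right)
\end{aligned}
```

With neurons spaced \(1/\rho\) apart, \(\rho\int \cdot\, dx'\) becomes a plain sum over neurons. This sum is what `derivative()` computes with `bm.dot` and `bm.sum`.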
[7]:
import brainpy as bp
import brainpy.math as bm
[8]:
class CANN1D(bp.dyn.NeuGroup):
    def __init__(self, num, tau=1., k=8.1, a=0.5, A=10., J0=4.,
                 z_min=-bm.pi, z_max=bm.pi, **kwargs):
        super(CANN1D, self).__init__(size=num, **kwargs)
        # parameters
        self.tau = tau  # The synaptic time constant
        self.k = k  # Degree of the rescaled inhibition
        self.a = a  # Half-width of the range of excitatory connections
        self.A = A  # Magnitude of the external input
        self.J0 = J0  # maximum connection value
        # feature space
        self.z_min = z_min
        self.z_max = z_max
        self.z_range = z_max - z_min
        self.x = bm.linspace(z_min, z_max, num)  # The encoded feature values
        self.rho = num / self.z_range  # The neural density
        self.dx = self.z_range / num  # The stimulus density
        # variables
        self.u = bm.Variable(bm.zeros(num))
        self.input = bm.Variable(bm.zeros(num))
        # The connection matrix
        self.conn_mat = self.make_conn(self.x)
        # function
        self.integral = bp.odeint(self.derivative)
    def derivative(self, u, t, Iext):
        r1 = bm.square(u)
        r2 = 1.0 + self.k * bm.sum(r1)
        r = r1 / r2
        Irec = bm.dot(self.conn_mat, r)
        du = (-u + Irec + Iext) / self.tau
        return du
    def dist(self, d):
        d = bm.remainder(d, self.z_range)
        d = bm.where(d > 0.5 * self.z_range, d - self.z_range, d)
        return d
    def make_conn(self, x):
        assert bm.ndim(x) == 1
        x_left = bm.reshape(x, (-1, 1))
        x_right = bm.repeat(x.reshape((1, -1)), len(x), axis=0)
        d = self.dist(x_left - x_right)
        Jxx = self.J0 * bm.exp(-0.5 * bm.square(d / self.a)) / \
              (bm.sqrt(2 * bm.pi) * self.a)
        return Jxx
    def get_stimulus_by_pos(self, pos):
        return self.A * bm.exp(-0.25 * bm.square(self.dist(self.x - pos) / self.a))
    def update(self, tdi):
        self.u.value = self.integral(self.u, tdi.t, self.input, tdi.dt)
        self.input[:] = 0.
[9]:
cann = CANN1D(num=512, k=0.1)
Population coding
[10]:
I1 = cann.get_stimulus_by_pos(0.)
Iext, duration = bp.inputs.section_input(values=[0., I1, 0.],
durations=[1., 8., 8.],
return_length=True)
runner = bp.DSRunner(cann,
inputs=['input', Iext, 'iter'],
monitors=['u'])
runner.run(duration)
bp.visualize.animate_1D(
dynamical_vars=[{'ys': runner.mon.u, 'xs': cann.x, 'legend': 'u'},
{'ys': Iext, 'xs': cann.x, 'legend': 'Iext'}],
frame_step=1,
frame_delay=100,
show=True,
# save_path='../../images/cann-encoding.gif'
)
Template matching
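The CANN can perform efficient population decoding by template matching. In the run below, a bump is first formed at position 0.5, and the network is then driven by a noisy stimulus centered at 0.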
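To read out where the bump sits, one common choice is the population vector on the circle. This readout is our addition and is not part of the original example. The helper name `decode_bump` is hypothetical. In the notebook, it could be applied to `bm.as_numpy(runner.mon.u[-1])` and `bm.as_numpy(cann.x)`:

```python
import numpy as np


def decode_bump(u, x):
    """Population-vector estimate of the bump centre.

    Assumes the feature values x cover one period of a circle of length 2*pi,
    as in CANN1D (z_min=-pi, z_max=pi). Negative activity is clipped to zero.
    Returns an angle in [-pi, pi].
    """
    r = np.maximum(np.asarray(u, dtype=float), 0.)
    return float(np.angle(np.sum(r * np.exp(1j * np.asarray(x)))))


# a synthetic bump centred at 0.5, with periodic distances as in CANN1D.dist
x = np.linspace(-np.pi, np.pi, 512, endpoint=False)
d = np.remainder(x - 0.5 + np.pi, 2 * np.pi) - np.pi
u = np.exp(-0.25 * (d / 0.5) ** 2)
print(decode_bump(u, x))
```

Because it uses angles, this estimate handles bumps that straddle the \(\pm\pi\) boundary. A naive weighted mean of `x` would not.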
[11]:
cann.k = 8.1
dur1, dur2, dur3 = 10., 30., 0.
num1 = int(dur1 / bm.get_dt())
num2 = int(dur2 / bm.get_dt())
num3 = int(dur3 / bm.get_dt())
Iext = bm.zeros((num1 + num2 + num3,) + cann.size)
Iext[:num1] = cann.get_stimulus_by_pos(0.5)
Iext[num1:num1 + num2] = cann.get_stimulus_by_pos(0.)
Iext[num1:num1 + num2] += 0.1 * cann.A * bm.random.randn(num2, *cann.size)
runner = bp.dyn.DSRunner(cann,
inputs=('input', Iext, 'iter'),
monitors=['u'])
runner.run(dur1 + dur2 + dur3)
bp.visualize.animate_1D(
dynamical_vars=[{'ys': runner.mon.u, 'xs': cann.x, 'legend': 'u'},
{'ys': Iext, 'xs': cann.x, 'legend': 'Iext'}],
frame_step=5,
frame_delay=50,
show=True,
# save_path='../../images/cann-decoding.gif'
)
Smooth tracking
The CANN can smoothly track a moving stimulus.
[12]:
dur1, dur2, dur3 = 20., 20., 20.
num1 = int(dur1 / bm.get_dt())
num2 = int(dur2 / bm.get_dt())
num3 = int(dur3 / bm.get_dt())
position = bm.zeros(num1 + num2 + num3)
position[num1: num1 + num2] = bm.linspace(0., 12., num2)
position[num1 + num2:] = 12.
position = position.reshape((-1, 1))
Iext = cann.get_stimulus_by_pos(position)
runner = bp.dyn.DSRunner(cann,
inputs=('input', Iext, 'iter'),
monitors=['u'])
runner.run(dur1 + dur2 + dur3)
bp.visualize.animate_1D(
dynamical_vars=[{'ys': runner.mon.u, 'xs': cann.x, 'legend': 'u'},
{'ys': Iext, 'xs': cann.x, 'legend': 'Iext'}],
frame_step=5,
frame_delay=50,
show=True,
# save_path='../../images/cann-tracking.gif'
)
(Si Wu, 2008): Continuous-attractor Neural Network 2D
Implementation of the paper: Si Wu, Kosuke Hamaguchi, and Shun-ichi Amari. “Dynamics and computation of continuous attractors.” Neural computation 20.4 (2008): 994-1025.
The mathematical equation of the Continuous-Attractor Neural Network (CANN) is given by:
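In the discretized form that CANN2D.update below implements (read off from that code), each neuron \(i\) sits at a grid point \(\mathbf{x}_i \in [-\pi, \pi)^2\), and distances wrap around in both dimensions:

```latex
\begin{aligned}
u_i &\leftarrow u_i + \frac{\Delta t}{\tau}\Big(-u_i + \sum_j J_{ij}\, r_j + I^{ext}_i\Big), \qquad
r_i = \frac{u_i^2}{1 + k \sum_j u_j^2} \\
J_{ij} &= \frac{J_0}{\sqrt{2\pi}\,a} \exp\left(-\frac{\|\mathbf{x}_i - \mathbf{x}_j\|^2}{2a^2}\right), \qquad
\|\mathbf{x}_i - \mathbf{x}_j\|^2 = \sum_{l=1}^{2} \min\big(|x_{il} - x_{jl}|,\; 2\pi - |x_{il} - x_{jl}|\big)^2
\end{aligned}
```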
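The model is integrated with the forward Euler method, and the recurrent interaction is computed as a dense matrix product over the flattened 2D grid. It can run on CPU, GPU and TPU devices.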
[1]:
import matplotlib.pyplot as plt
import jax
[2]:
import brainpy as bp
import brainpy.math as bm
[3]:
class CANN2D(bp.dyn.NeuGroup):
    def __init__(self, length, tau=1., k=8.1, a=0.5, A=10., J0=4.,
                 z_min=-bm.pi, z_max=bm.pi, name=None):
        super(CANN2D, self).__init__(size=(length, length), name=name)
        # parameters
        self.length = length
        self.tau = tau  # The synaptic time constant
        self.k = k  # Degree of the rescaled inhibition
        self.a = a  # Half-width of the range of excitatory connections
        self.A = A  # Magnitude of the external input
        self.J0 = J0  # maximum connection value
        # feature space
        self.z_min = z_min
        self.z_max = z_max
        self.z_range = z_max - z_min
        self.x = bm.linspace(z_min, z_max, length)  # The encoded feature values
        self.rho = length / self.z_range  # The neural density
        self.dx = self.z_range / length  # The stimulus density
        # The connections
        self.conn_mat = self.make_conn()
        # variables
        self.r = bm.Variable(bm.zeros((length, length)))
        self.u = bm.Variable(bm.zeros((length, length)))
        self.input = bm.Variable(bm.zeros((length, length)))
    def show_conn(self):
        plt.imshow(bm.as_numpy(self.conn_mat))
        plt.colorbar()
        plt.show()
    def dist(self, d):
        v_size = bm.asarray([self.z_range, self.z_range])
        return bm.where(d > v_size / 2, v_size - d, d)
    def make_conn(self):
        x1, x2 = bm.meshgrid(self.x, self.x)
        value = bm.stack([x1.flatten(), x2.flatten()]).T
        @jax.vmap
        def get_J(v):
            d = self.dist(bm.abs(v - value))
            d = bm.linalg.norm(d, axis=1)
            Jxx = self.J0 * bm.exp(-0.5 * bm.square(d / self.a)) / (bm.sqrt(2 * bm.pi) * self.a)
            return Jxx
        return get_J(value)
    def get_stimulus_by_pos(self, pos):
        assert bm.size(pos) == 2
        x1, x2 = bm.meshgrid(self.x, self.x)
        value = bm.stack([x1.flatten(), x2.flatten()]).T
        d = self.dist(bm.abs(bm.asarray(pos) - value))
        d = bm.linalg.norm(d, axis=1)
        d = d.reshape((self.length, self.length))
        return self.A * bm.exp(-0.25 * bm.square(d / self.a))
    def update(self, tdi):
        r1 = bm.square(self.u)
        r2 = 1.0 + self.k * bm.sum(r1)
        self.r.value = r1 / r2
        interaction = (self.r.flatten() @ self.conn_mat).reshape((self.length, self.length))
        self.u.value = self.u + (-self.u + self.input + interaction) / self.tau * tdi.dt
        self.input[:] = 0.
[4]:
cann = CANN2D(length=100, k=0.1)
[5]:
cann.show_conn()

[6]:
# encoding
Iext, length = bp.inputs.section_input(
values=[cann.get_stimulus_by_pos([0., 0.]), 0.],
durations=[10., 20.],
return_length=True
)
runner = bp.dyn.DSRunner(cann,
inputs=['input', Iext, 'iter'],
monitors=['r'],
dyn_vars=cann.vars())
runner.run(length)
[7]:
bp.visualize.animate_2D(values=runner.mon.r.reshape((-1, cann.num)),
net_size=(cann.length, cann.length))
[8]:
# tracking
length = 20
positions = bp.inputs.ramp_input(-bm.pi, bm.pi, duration=length, t_start=0)
positions = bm.stack([positions, positions]).T
Iext = jax.vmap(cann.get_stimulus_by_pos)(positions)
runner = bp.dyn.DSRunner(cann,
inputs=['input', Iext, 'iter'],
monitors=['r'],
dyn_vars=cann.vars())
runner.run(length)
[9]:
bp.visualize.animate_2D(values=runner.mon.r.reshape((-1, cann.num)),
net_size=(cann.length, cann.length))
CANN 1D Oscillatory Tracking
Implementation of the paper:
Si Wu, Kosuke Hamaguchi, and Shun-ichi Amari. “Dynamics and computation of continuous attractors.” Neural computation 20.4 (2008): 994-1025.
Mi, Y., Fung, C. C., Wong, M. K. Y., & Wu, S. (2014). Spike frequency adaptation implements anticipative tracking in continuous attractor neural networks. Advances in neural information processing systems, 1(January), 505.
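Compared with the plain CANN, this version adds an adaptation variable \(v\) with time constant \(\tau_v\) (`tau_v`) and strength \(m\) (`m`), the spike frequency adaptation of Mi et al. (2014). Read off from the update method below, which uses forward Euler, the dynamics are:

```latex
\begin{aligned}
\tau \frac{du_i}{dt} &= -u_i + \sum_j J_{ij}\, r_j + I^{ext}_i - v_i \\
\tau_v \frac{dv_i}{dt} &= -v_i + m\, u_i \\
r_i &= \frac{u_i^2}{1 + k \sum_j u_j^2}
\end{aligned}
```

Because \(v\) subtracts a slowly accumulated copy of \(u\), the bump is pushed away from where it has recently been. This is the mechanism behind the oscillatory and anticipative tracking shown below.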
[1]:
import brainpy as bp
import brainpy.math as bm
bm.set_platform('cpu')
[2]:
class CANN1D(bp.NeuGroup):
    def __init__(self, num, tau=1., tau_v=50., k=1., a=0.3, A=0.2, J0=1.,
                 z_min=-bm.pi, z_max=bm.pi, m=0.3):
        super(CANN1D, self).__init__(size=num)
        # parameters
        self.tau = tau  # The synaptic time constant
        self.tau_v = tau_v  # The time constant of the adaptation variable v
        self.k = k  # Degree of the rescaled inhibition
        self.a = a  # Half-width of the range of excitatory connections
        self.A = A  # Magnitude of the external input
        self.J0 = J0  # maximum connection value
        self.m = m  # Strength of the adaptation
        # feature space
        self.z_min = z_min
        self.z_max = z_max
        self.z_range = z_max - z_min
        self.x = bm.linspace(z_min, z_max, num)  # The encoded feature values
        self.rho = num / self.z_range  # The neural density
        self.dx = self.z_range / num  # The stimulus density
        # The connection matrix
        self.conn_mat = self.make_conn()
        # variables
        self.r = bm.Variable(bm.zeros(num))
        self.u = bm.Variable(bm.zeros(num))
        self.v = bm.Variable(bm.zeros(num))
        self.input = bm.Variable(bm.zeros(num))
    def dist(self, d):
        d = bm.remainder(d, self.z_range)
        d = bm.where(d > 0.5 * self.z_range, d - self.z_range, d)
        return d
    def make_conn(self):
        x_left = bm.reshape(self.x, (-1, 1))
        x_right = bm.repeat(self.x.reshape((1, -1)), len(self.x), axis=0)
        d = self.dist(x_left - x_right)
        conn = self.J0 * bm.exp(-0.5 * bm.square(d / self.a)) / (bm.sqrt(2 * bm.pi) * self.a)
        return conn
    def get_stimulus_by_pos(self, pos):
        return self.A * bm.exp(-0.25 * bm.square(self.dist(self.x - pos) / self.a))
    def update(self, tdi):
        r1 = bm.square(self.u)
        r2 = 1.0 + self.k * bm.sum(r1)
        self.r.value = r1 / r2
        Irec = bm.dot(self.conn_mat, self.r)
        # forward-Euler steps of the synaptic input u and the adaptation v
        self.u.value = self.u + (-self.u + Irec + self.input - self.v) / self.tau * tdi.dt
        self.v.value = self.v + (-self.v + self.m * self.u) / self.tau_v * tdi.dt
        self.input[:] = 0.
[3]:
cann = CANN1D(num=512)
[4]:
dur1, dur2, dur3 = 100., 2000., 500.
num1 = int(dur1 / bm.get_dt())
num2 = int(dur2 / bm.get_dt())
num3 = int(dur3 / bm.get_dt())
position = bm.zeros(num1 + num2 + num3)
final_pos = cann.a / cann.tau_v * 0.6 * dur2
position[num1: num1 + num2] = bm.linspace(0., final_pos, num2)
position[num1 + num2:] = final_pos
position = position.reshape((-1, 1))
Iext = cann.get_stimulus_by_pos(position)
runner = bp.DSRunner(cann,
inputs=('input', Iext, 'iter'),
monitors=['u', 'v'])
runner.run(dur1 + dur2 + dur3)
_ = bp.visualize.animate_1D(
dynamical_vars=[
{'ys': runner.mon.u, 'xs': cann.x, 'legend': 'u'},
{'ys': runner.mon.v, 'xs': cann.x, 'legend': 'v'},
{'ys': Iext, 'xs': cann.x, 'legend': 'Iext'}
],
frame_step=30,
frame_delay=5,
show=True
)
(Wang, 2002) Decision making spiking model
Implementation of the paper: Wang, Xiao-Jing. “Probabilistic decision making by slow reverberation in cortical circuits.” Neuron 36.5 (2002): 955-968.
Author: Chaoming Wang (chao.brain@qq.com)
Please refer to github example folder:
(Wong & Wang, 2006) Decision making rate model
Please refer to Analysis of A Decision Making Model.
(Vreeswijk & Sompolinsky, 1996) E/I balanced network
Overviews
Van Vreeswijk and Sompolinsky proposed the E/I balanced network in 1996 to explain temporally irregular spiking patterns. They suggested that this temporal variability may originate from the balance between excitatory and inhibitory inputs.
There are \(N_E\) excitatory neurons and \(N_I\) inhibitory neurons.
An important feature of the network is its random and sparse connectivity: each neuron receives on average \(K\) connections, with \(1 \ll K \ll N_E\).
Implementations
[1]:
import brainpy as bp
import brainpy.math as bm
bm.set_platform('cpu')
The membrane potential of neuron \(i\) evolves as:
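A hedged reconstruction, consistent with the code below, is given next. The code uses LIF neurons with \(V_{rest} = -52\) mV, \(V_{th} = -50\) mV, \(V_{reset} = -60\) mV and \(\tau = 10\) ms, a constant bias \(I_b\), and exponential synapses with \(\tau_s = 2\) ms. Here \(\mathcal{E}_i\) and \(\mathcal{I}_i\) are the excitatory and inhibitory neurons presynaptic to \(i\), \(t_j^k\) is the \(k\)-th spike of neuron \(j\), \(\Theta\) is the Heaviside step function, and \(V_i\) is reset to \(V_{reset}\) whenever it reaches \(V_{th}\). The code folds the minus sign of the inhibitory term into the weight by passing `g_max=JI` with `JI < 0`:

```latex
\begin{aligned}
\tau \frac{dV_i}{dt} &= -(V_i - V_{rest}) + I_i^{net}(t) + I_b \\
I_i^{net}(t) &= J_E \sum_{j \in \mathcal{E}_i} s_j(t) - J_I \sum_{j \in \mathcal{I}_i} s_j(t), \qquad
s_j(t) = \sum_{k} e^{-(t - t_j^k)/\tau_s}\, \Theta(t - t_j^k)
\end{aligned}
```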
where \(I_i^{net}(t)\) is the synaptic current, i.e. the summed input from the excitatory and inhibitory neurons presynaptic to neuron \(i\).
The coupling strengths are \(J_E = \frac{1}{\sqrt{pN_E}}\) and \(J_I = \frac{1}{\sqrt{pN_I}}\), where \(p\) is the connection probability.
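The \(1/\sqrt{K}\) scaling of the coupling strengths (with \(K = pN\)) is what makes the balanced state possible. The following minimal NumPy illustration rests on simplifying assumptions of our own: independent Bernoulli inputs in a time bin, and a hypothetical helper `input_stats`. It shows that the mean of the summed input grows like \(\sqrt{K}\), while its fluctuations stay of order one:

```python
import numpy as np


def input_stats(K, rate=0.1, trials=20000, seed=0):
    """Mean and standard deviation of a summed input with weight J = 1/sqrt(K).

    Each of the K presynaptic neurons is independently active with
    probability `rate` in a time bin (a simplifying assumption).
    """
    rng = np.random.default_rng(seed)
    n_active = rng.binomial(K, rate, size=trials)  # active inputs per trial
    total = n_active / np.sqrt(K)                  # each contributes J = 1/sqrt(K)
    return total.mean(), total.std()


for K in (100, 10000):
    mean, std = input_stats(K)
    print(f"K={K:5d}: mean={mean:.2f}, std={std:.3f}")
```

In theory, the mean is \(\text{rate}\cdot\sqrt{K}\) and the standard deviation is \(\sqrt{\text{rate}(1-\text{rate})} = 0.3\) for any \(K\). With large \(K\), the excitatory and inhibitory means are each large. They must therefore nearly cancel for the neuron to stay near threshold, and the remaining input is dominated by \(O(1)\) fluctuations, which produce irregular spiking.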
The membrane dynamics above are those of a leaky integrate-and-fire neuron, so we model them with bp.neurons.LIF.
\(I_i^{net}(t)\) is a synaptic current with single-exponential decay, which we model with bp.synapses.Exponential.
Network
Let’s create two LIF groups, with \(N_E\) excitatory and \(N_I\) inhibitory neurons, and connect them with bp.conn.FixedProb(prob) to obtain random, sparse connectivity.
[2]:
class EINet(bp.Network):
    def __init__(self, num_exc, num_inh, prob, JE, JI):
        # neurons
        pars = dict(V_rest=-52., V_th=-50., V_reset=-60., tau=10., tau_ref=0.,
                    V_initializer=bp.init.Normal(-60., 10.))
        E = bp.neurons.LIF(num_exc, **pars)
        I = bp.neurons.LIF(num_inh, **pars)
        # synapses
        E2E = bp.synapses.Exponential(E, E, bp.conn.FixedProb(prob), g_max=JE, tau=2.)
        E2I = bp.synapses.Exponential(E, I, bp.conn.FixedProb(prob), g_max=JE, tau=2.)
        I2E = bp.synapses.Exponential(I, E, bp.conn.FixedProb(prob), g_max=JI, tau=2.)
        I2I = bp.synapses.Exponential(I, I, bp.conn.FixedProb(prob), g_max=JI, tau=2.)
        super(EINet, self).__init__(E2E, E2I, I2E, I2I, E=E, I=I)
[3]:
num_exc = 500
num_inh = 500
prob = 0.1
Ib = 3.
JE = 1 / bp.math.sqrt(prob * num_exc)
JI = -1 / bp.math.sqrt(prob * num_inh)
[4]:
net = EINet(num_exc, num_inh, prob=prob, JE=JE, JI=JI)
runner = bp.DSRunner(net,
monitors=['E.spike'],
inputs=[('E.input', Ib), ('I.input', Ib)])
t = runner.run(1000.)
Visualization
[5]:
import matplotlib.pyplot as plt
fig, gs = bp.visualize.get_figure(4, 1, 2, 10)
fig.add_subplot(gs[:3, 0])
bp.visualize.raster_plot(runner.mon.ts, runner.mon['E.spike'], xlim=(50, 950))
fig.add_subplot(gs[3, 0])
rates = bp.measure.firing_rate(runner.mon['E.spike'], 5.)
plt.plot(runner.mon.ts, rates)
plt.xlim(50, 950)
plt.show()

Reference
[1] Van Vreeswijk, Carl, and Haim Sompolinsky. “Chaos in neuronal networks with balanced excitatory and inhibitory activity.” Science 274.5293 (1996): 1724-1726.
(Brette et al., 2007) COBA
Implementation of the paper:
Brette, R., Rudolph, M., Carnevale, T., Hines, M., Beeman, D., Bower, J. M., et al. (2007), Simulation of networks of spiking neurons: a review of tools and strategies., J. Comput. Neurosci., 23, 3, 349–98
which is based on the balanced network proposed by:
Vogels, T. P. and Abbott, L. F. (2005), Signal propagation and logic gating in networks of integrate-and-fire neurons., J. Neurosci., 25, 46, 10786–95
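In the conductance-based (COBA) setup, each synapse injects a current proportional to the distance from its reversal potential. As we understand bp.synouts.COBA, it adds \(g(t)\,(E - V)\) to the LIF input, where \(g\) decays exponentially and jumps by the synaptic weight at every presynaptic spike. With the values used in the code below (\(\tau = 20\) ms, \(V_{rest} = -60\) mV, \(E_e = 0\) mV, \(E_i = -80\) mV, \(\tau_e = 5\) ms, \(\tau_i = 10\) ms, \(w_e = 0.6\), \(w_i = 6.7\)), the dynamics are:

```latex
\begin{aligned}
\tau \frac{dV}{dt} &= -(V - V_{rest}) + g_e(t)\,(E_e - V) + g_i(t)\,(E_i - V) + I_{ext} \\
\tau_e \frac{dg_e}{dt} &= -g_e, \qquad \tau_i \frac{dg_i}{dt} = -g_i, \qquad
g_e \leftarrow g_e + w_e, \;\; g_i \leftarrow g_i + w_i \;\text{ at presynaptic spikes}
\end{aligned}
```

Unlike the current-based (CUBA) model, the effect of an inhibitory spike here shrinks as \(V\) approaches \(-80\) mV, and the effective membrane time constant drops when many synapses are active.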
Authors:
[1]:
import brainpy as bp
bp.math.set_platform('cpu')
Version 1
[2]:
class EINet(bp.Network):
    def __init__(self, scale=1.0, method='exp_auto'):
        # network size
        num_exc = int(3200 * scale)
        num_inh = int(800 * scale)
        # neurons
        pars = dict(V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5.,
                    V_initializer=bp.init.Normal(-55., 2.))
        E = bp.neurons.LIF(num_exc, **pars, method=method)
        I = bp.neurons.LIF(num_inh, **pars, method=method)
        # synapses
        we = 0.6 / scale  # excitatory synaptic weight (conductance)
        wi = 6.7 / scale  # inhibitory synaptic weight (conductance)
        E2E = bp.synapses.Exponential(E, E, bp.conn.FixedProb(prob=0.02),
                                      g_max=we, tau=5., method=method,
                                      output=bp.synouts.COBA(E=0.))
        E2I = bp.synapses.Exponential(E, I, bp.conn.FixedProb(prob=0.02),
                                      g_max=we, tau=5., method=method,
                                      output=bp.synouts.COBA(E=0.))
        I2E = bp.synapses.Exponential(I, E, bp.conn.FixedProb(prob=0.02),
                                      g_max=wi, tau=10., method=method,
                                      output=bp.synouts.COBA(E=-80.))
        I2I = bp.synapses.Exponential(I, I, bp.conn.FixedProb(prob=0.02),
                                      g_max=wi, tau=10., method=method,
                                      output=bp.synouts.COBA(E=-80.))
        super(EINet, self).__init__(E2E, E2I, I2E, I2I, E=E, I=I)
[3]:
# network
net = EINet()
[4]:
# simulation
runner = bp.DSRunner(
net,
monitors=['E.spike'],
inputs=[('E.input', 20.), ('I.input', 20.)]
)
runner.run(100.)
[5]:
bp.visualize.raster_plot(runner.mon.ts, runner.mon['E.spike'], show=True)

Version 2
[6]:
class EINet_V2(bp.Network):
    def __init__(self, scale=1.0, method='exp_auto'):
        super(EINet_V2, self).__init__()
        # network size
        num_exc = int(3200 * scale)
        num_inh = int(800 * scale)
        # neurons
        self.N = bp.neurons.LIF(num_exc + num_inh,
                                V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5.,
                                method=method, V_initializer=bp.initialize.Normal(-55., 2.))
        # synapses
        we = 0.6 / scale  # excitatory synaptic conductance
        wi = 6.7 / scale  # inhibitory synaptic conductance
        self.Esyn = bp.synapses.Exponential(pre=self.N[:num_exc],
                                            post=self.N,
                                            conn=bp.connect.FixedProb(0.02),
                                            g_max=we, tau=5.,
                                            output=bp.synouts.COBA(E=0.),
                                            method=method)
        self.Isyn = bp.synapses.Exponential(pre=self.N[num_exc:],
                                            post=self.N,
                                            conn=bp.connect.FixedProb(0.02),
                                            g_max=wi, tau=10.,
                                            output=bp.synouts.COBA(E=-80.),
                                            method=method)
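Both versions connect the populations with FixedProb(0.02). Read as an independent connection test for each pre/post pair with probability 0.02, this gives each neuron about 0.02 × 3200 = 64 excitatory and 0.02 × 800 = 16 inhibitory inputs at scale = 1.

The NumPy sketch below illustrates that expectation with a dense Bernoulli mask on a smaller network (scale = 0.25). The helper fixed_prob_mask is hypothetical and is not BrainPy's sampler, which may, for example, treat self-connections differently.

```python
import numpy as np

def fixed_prob_mask(num_pre, num_post, prob, seed=0):
    """Hypothetical dense mask: entry (i, j) is True if pre i projects to post j."""
    rng = np.random.default_rng(seed)
    return rng.random((num_pre, num_post)) < prob

scale = 0.25                      # smaller network keeps the dense mask cheap
num_exc, num_inh = int(3200 * scale), int(800 * scale)
prob = 0.02
mask_e = fixed_prob_mask(num_exc, num_exc + num_inh, prob, seed=0)
mask_i = fixed_prob_mask(num_inh, num_exc + num_inh, prob, seed=1)

# empirical vs. expected number of inputs per postsynaptic neuron
print(mask_e.sum(axis=0).mean(), prob * num_exc)
print(mask_i.sum(axis=0).mean(), prob * num_inh)
```

This also explains the division of the weights by scale. The mean summed input scales as prob × num_pre × w = 0.02 × (3200 · scale) × (0.6 / scale), which does not depend on scale. Resizing the network therefore leaves the average drive roughly unchanged.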
[7]:
net = EINet_V2(scale=1., method='exp_auto')
# simulation
runner = bp.DSRunner(
net,
monitors={'spikes': net.N.spike},
inputs=[(net.N.input, 20.)]
)
runner.run(100.)
# visualization
bp.visualize.raster_plot(runner.mon.ts, runner.mon['spikes'], show=True)

(Brette, et al., 2007) CUBA
Implementation of the paper:
Brette, R., Rudolph, M., Carnevale, T., Hines, M., Beeman, D., Bower, J. M., et al. (2007). Simulation of networks of spiking neurons: a review of tools and strategies. J. Comput. Neurosci., 23(3), 349–398.
which is based on the balanced network proposed by:
Vogels, T. P. and Abbott, L. F. (2005). Signal propagation and logic gating in networks of integrate-and-fire neurons. J. Neurosci., 25(46), 10786–10795.
Authors:
[1]:
import brainpy as bp
bp.math.set_platform('cpu')
[2]:
class CUBA(bp.Network):
    def __init__(self, scale=1.0, method='exp_auto'):
        # network size
        num_exc = int(3200 * scale)
        num_inh = int(800 * scale)
        # neurons
        pars = dict(V_rest=-49, V_th=-50., V_reset=-60, tau=20., tau_ref=5.,
                    V_initializer=bp.init.Normal(-55., 2.))
        E = bp.neurons.LIF(num_exc, **pars, method=method)
        I = bp.neurons.LIF(num_inh, **pars, method=method)
        # synapses
        we = 1.62 / scale  # excitatory synaptic weight (voltage)
        wi = -9.0 / scale  # inhibitory synaptic weight
        E2E = bp.synapses.Exponential(E, E, bp.conn.FixedProb(0.02),
                                      g_max=we, tau=5., method=method)
        E2I = bp.synapses.Exponential(E, I, bp.conn.FixedProb(0.02),
                                      g_max=we, tau=5., method=method)
        I2E = bp.synapses.Exponential(I, E, bp.conn.FixedProb(0.02),
                                      g_max=wi, tau=10., method=method)
        I2I = bp.synapses.Exponential(I, I, bp.conn.FixedProb(0.02),
                                      g_max=wi, tau=10., method=method)
        super(CUBA, self).__init__(E2E, E2I, I2E, I2I, E=E, I=I)
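In the CUBA variant the synaptic variable is added directly to the input current. For that reason the inhibitory weight has to be negative (wi = -9.0 / scale) to hyperpolarize.

Note also that V_rest = -49 lies above V_th = -50. Even without synaptic input, an isolated neuron drifts past threshold and fires periodically, with the rate limited by the refractory period.

The pure-Python sketch below checks this with a forward-Euler LIF using the same parameters. It assumes the update rule tau * dV/dt = -(V - V_rest) + I and clamps V at V_reset during refractoriness. Both are assumptions about bp.neurons.LIF rather than its exact implementation.

```python
def simulate_lif(I_ext, T=100., dt=0.1, V_rest=-49., V_th=-50., V_reset=-60.,
                 tau=20., tau_ref=5., V0=-55.):
    """Forward-Euler LIF: tau * dV/dt = -(V - V_rest) + I_ext.

    Returns the list of spike times (ms). During the refractory period the
    membrane potential is held at V_reset (an assumption of this sketch).
    """
    V, t_last, spikes = V0, -1e9, []
    for step in range(int(T / dt)):
        t = step * dt
        if t - t_last < tau_ref:           # still refractory: skip integration
            continue
        V += dt / tau * (-(V - V_rest) + I_ext)
        if V >= V_th:                      # threshold crossing -> spike and reset
            spikes.append(t)
            V, t_last = V_reset, t
    return spikes

no_input = simulate_lif(I_ext=0.)     # fires because V_rest > V_th
with_input = simulate_lif(I_ext=20.)  # the constant input used in the runner below
print(len(no_input) > 0, len(with_input) > len(no_input))  # True True
```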
[3]:
# network
net = CUBA()
[4]:
# simulation
runner = bp.DSRunner(net,
monitors=['E.spike'],
inputs=[('E.input', 20.), ('I.input', 20.)])
t = runner.run(100.)
[5]:
bp.visualize.raster_plot(runner.mon.ts, runner.mon['E.spike'], show=True)

(Brette, et al., 2007) COBA-HH
Implementation of the paper:
Brette, R., Rudolph, M., Carnevale, T., Hines, M., Beeman, D., Bower, J. M., et al. (2007). Simulation of networks of spiking neurons: a review of tools and strategies. J. Comput. Neurosci., 23(3), 349–398.
which is based on the balanced network proposed by:
Vogels, T. P. and Abbott, L. F. (2005). Signal propagation and logic gating in networks of integrate-and-fire neurons. J. Neurosci., 25(46), 10786–10795.
Authors:
[1]:
import brainpy as bp
import brainpy.math as bm
bp.math.set_platform('cpu')
Parameters
[2]:
num_exc = 3200
num_inh = 800
Cm = 200  # Membrane Capacitance [pF]
gl = 10.  # Leak Conductance [nS]
g_Na = 20. * 1000  # Na Conductance [nS]
g_Kd = 6. * 1000  # K Conductance [nS]
El = -60.  # Resting Potential [mV]
ENa = 50.  # reversal potential (Sodium) [mV]
EK = -90.  # reversal potential (Potassium) [mV]
VT = -63.  # threshold-adjusting offset used in the gating rate functions [mV]
V_th = -20.  # spike-detection threshold [mV]
# Time constants
taue = 5.  # Excitatory synaptic time constant [ms]
taui = 10.  # Inhibitory synaptic time constant [ms]
# Reversal potentials
Ee = 0.  # Excitatory reversal potential [mV]
Ei = -80.  # Inhibitory reversal potential [mV]
# excitatory synaptic weight
we = 6.  # excitatory synaptic conductance [nS]
# inhibitory synaptic weight
wi = 67.  # inhibitory synaptic conductance [nS]
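With these units (capacitance in pF, conductances in nS), the passive membrane time constant is Cm / gl = 200 pF / 10 nS = 20 ms. That equals the tau = 20. used by the LIF networks above. The peak Na and K conductances are 2000 and 600 times the leak conductance. A quick arithmetic check:

```python
Cm = 200.           # membrane capacitance [pF]
gl = 10.            # leak conductance [nS]
g_Na = 20. * 1000   # Na conductance [nS]
g_Kd = 6. * 1000    # K conductance [nS]

tau_m = Cm / gl     # pF / nS = 1e-12 F / 1e-9 S = 1e-3 s, i.e. ms
print(tau_m)                  # 20.0
print(g_Na / gl, g_Kd / gl)   # 2000.0 600.0
```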
Implementation 1
[3]:
class HH(bp.NeuGroup):
    def __init__(self, size, method='exp_auto'):
        super(HH, self).__init__(size)
        # variables
        self.V = bm.Variable(El + (bm.random.randn(self.num) * 5 - 5))
        self.m = bm.Variable(bm.zeros(self.num))
        self.n = bm.Variable(bm.zeros(self.num))
        self.h = bm.Variable(bm.zeros(self.num))
        self.spike = bm.Variable(bm.zeros(self.num, dtype=bool))
        self.input = bm.Variable(bm.zeros(size))
        def dV(V, t, m, h, n, Isyn):
            gna = g_Na * (m * m * m) * h
            gkd = g_Kd * (n * n * n * n)
            dVdt = (-gl * (V - El) - gna * (V - ENa) - gkd * (V - EK) + Isyn) / Cm
            return dVdt
        def dm(m, t, V):
            m_alpha = 0.32 * (13 - V + VT) / (bm.exp((13 - V + VT) / 4) - 1.)
            m_beta = 0.28 * (V - VT - 40) / (bm.exp((V - VT - 40) / 5) - 1)
            dmdt = (m_alpha * (1 - m) - m_beta * m)
            return dmdt
        def dh(h, t, V):
            h_alpha = 0.128 * bm.exp((17 - V + VT) / 18)
            h_beta = 4. / (1 + bm.exp(-(V - VT - 40) / 5))
            dhdt = (h_alpha * (1 - h) - h_beta * h)
            return dhdt
        def dn(n, t, V):
            c = 15 - V + VT
            n_alpha = 0.032 * c / (bm.exp(c / 5) - 1.)
            n_beta = .5 * bm.exp((10 - V + VT) / 40)
            dndt = (n_alpha * (1 - n) - n_beta * n)
            return dndt
        # functions
        self.integral = bp.odeint(bp.JointEq([dV, dm, dh, dn]), method=method)
    def update(self, tdi):
        V, m, h, n = self.integral(self.V, self.m, self.h, self.n, tdi.t, Isyn=self.input, dt=tdi.dt)
        self.spike.value = bm.logical_and(self.V < V_th, V >= V_th)
        self.m.value = m
        self.h.value = h
        self.n.value = n
        self.V.value = V
        self.input[:] = 0.
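The rate functions m_alpha, m_beta and n_alpha all have the form x / (exp(x / k) - 1), up to a constant factor. That expression is 0/0 at x = 0, for example m_alpha at V = VT + 13. The limit there is finite (it equals k), and landing exactly on that value in floating point is unlikely, so the class above works in practice.

If you want an explicit guard, a common approach is to switch to a Taylor expansion near zero. The NumPy helper below is a hypothetical sketch of that idea and is not part of BrainPy.

```python
import numpy as np

def x_over_expm1(x, k, eps=1e-6):
    """Evaluate x / (exp(x / k) - 1), including the removable point x = 0.

    For |x| < eps the first-order expansion k - x / 2 is returned instead.
    """
    x = np.asarray(x, dtype=float)
    safe_x = np.where(np.abs(x) < eps, 1.0, x)      # keeps the unused branch finite
    exact = safe_x / np.expm1(safe_x / k)           # expm1 is accurate for small args
    return np.where(np.abs(x) < eps, k - x / 2.0, exact)

VT = -63.

def m_alpha(V):
    # same rate as in the HH class: 0.32 * x / (exp(x / 4) - 1) with x = 13 - V + VT
    return 0.32 * x_over_expm1(13. - V + VT, 4.)

print(m_alpha(VT + 13.))  # 1.28, the finite limit 0.32 * 4
```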
[4]:
class ExpCOBA(bp.TwoEndConn):
    def __init__(self, pre, post, conn, g_max=1., delay=0., tau=8.0, E=0.,
                 method='exp_auto'):
        super(ExpCOBA, self).__init__(pre=pre, post=post, conn=conn)
        self.check_pre_attrs('spike')
        self.check_post_attrs('input', 'V')
        # parameters
        self.E = E
        self.tau = tau
        self.delay = delay  # stored, but not used in update()
        self.g_max = g_max
        self.pre2post = self.conn.require('pre2post')
        # variables
        self.g = bm.Variable(bm.zeros(self.post.num))
        # function
        self.integral = bp.odeint(lambda g, t: -g / self.tau, method=method)
    def update(self, tdi):
        self.g.value = self.integral(self.g, tdi.t, dt=tdi.dt)
        post_sps = bm.pre2post_event_sum(self.pre.spike, self.pre2post, self.post.num, self.g_max)
        self.g.value += post_sps
        self.post.input += self.g * (self.E - self.post.V)
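bm.pre2post_event_sum accumulates g_max into each postsynaptic neuron once per spiking presynaptic partner.

The NumPy sketch below reproduces that behaviour. It assumes the pre2post structure is a CSR pair (indices, indptr), where the targets of presynaptic neuron i are indices[indptr[i]:indptr[i + 1]]. The function event_sum_csr is hypothetical and only explains the operation; it does not replace BrainPy's event-driven kernel.

```python
import numpy as np

def event_sum_csr(pre_spike, indices, indptr, num_post, g_max):
    """Hypothetical event-driven pre->post sum over a CSR connection."""
    out = np.zeros(num_post)
    for i in np.flatnonzero(pre_spike):          # only spiking neurons do any work
        targets = indices[indptr[i]:indptr[i + 1]]
        np.add.at(out, targets, g_max)           # unbuffered add handles repeats
    return out

# 3 presynaptic -> 4 postsynaptic neurons:
# pre 0 -> {0, 2}, pre 1 -> {2, 3}, pre 2 -> {1}
indices = np.array([0, 2, 2, 3, 1])
indptr = np.array([0, 2, 4, 5])
spikes = np.array([True, True, False])
print(event_sum_csr(spikes, indices, indptr, num_post=4, g_max=0.5))
# [0.5 0.  1.  0.5]
```

In ExpCOBA.update this sum is added to self.g right after the decay step. A spike therefore raises the conductance within the same time step.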
[5]:
class COBAHH(bp.Network):
    def __init__(self, scale=1., method='exp_auto'):
        num_exc = int(3200 * scale)
        num_inh = int(800 * scale)
        E = HH(num_exc, method=method)
        I = HH(num_inh, method=method)
        E2E = ExpCOBA(pre=E, post=E, conn=bp.conn.FixedProb(prob=0.02),
                      E=Ee, g_max=we / scale, tau=taue, method=method)
        E2I = ExpCOBA(pre=E, post=I, conn=bp.conn.FixedProb(prob=0.02),
                      E=Ee, g_max=we / scale, tau=taue, method=method)
        I2E = ExpCOBA(pre=I, post=E, conn=bp.conn.FixedProb(prob=0.02),
                      E=Ei, g_max=wi / scale, tau=taui, method=method)
        I2I = ExpCOBA(pre=I, post=I, conn=bp.conn.FixedProb(prob=0.02),
                      E=Ei, g_max=wi / scale, tau=taui, method=method)
        super(COBAHH, self).__init__(E2E, E2I, I2I, I2E, E=E, I=I)
[6]:
net = COBAHH()
[7]:
runner = bp.DSRunner(net, monitors=['E.spike'])
t = runner.run(100.)
[8]:
bp.visualize.raster_plot(runner.mon.ts, runner.mon['E.spike'], show=True)

Implementation 2
[9]:
class HH2(bp.CondNeuGroup):
    def __init__(self, size):
        super(HH2, self).__init__(size)
        self.INa = bp.channels.INa_TM1991(size, g_max=100., V_sh=-63.)
        self.IK = bp.channels.IK_TM1991(size, g_max=30., V_sh=-63.)
        self.IL = bp.channels.IL(size, E=-60., g_max=0.05)
[10]:
class EINet_v2(bp.Network):
    def __init__(self, scale=1.):
        super(EINet_v2, self).__init__()
        prob = 0.02
        self.num_exc = int(3200 * scale)
        self.num_inh = int(800 * scale)
        self.N = HH2(self.num_exc + self.num_inh)
        self.Esyn = bp.synapses.Exponential(self.N[:self.num_exc],
                                            self.N,
                                            bp.conn.FixedProb(prob),
                                            g_max=0.03 / scale, tau=5,
                                            output=bp.synouts.COBA(E=0.))
        self.Isyn = bp.synapses.Exponential(self.N[self.num_exc:],
                                            self.N,
                                            bp.conn.FixedProb(prob),
                                            g_max=0.335 / scale, tau=10.,
                                            output=bp.synouts.COBA(E=-80))
[11]:
net = EINet_v2(scale=1)
runner = bp.DSRunner(net, monitors={'spikes': net.N.spike})
runner.run(100.)
bp.visualize.raster_plot(runner.mon.ts, runner.mon['spikes'], show=True)

(Tian, et al., 2020) E/I Net for fast response
Implementation of the paper:
Tian, Gengshuo, et al. “Excitation-Inhibition Balanced Neural Networks for Fast Signal Detection.” Frontiers in Computational Neuroscience 14 (2020): 79.
Author: Chaoming Wang
[ ]:
import brainpy as bp
import brainpy.math as bm
bp.math.set_platform('cpu')
[2]:
# set parameters
num = 10000
num_inh = int(num * 0.2)
num_exc = num - num_inh
prob = 0.25
tau_E = 15.
tau_I = 10.
V_reset = 0.
V_threshold = 15.
f_E = 3.
f_I = 2.
mu_f = 0.1
tau_Es = 6.
tau_Is = 5.
JEE = 0.25
JEI = -1.
JIE = 0.4
JII = -1.
[3]:
class LIF(bp.NeuGroup):
    def __init__(self, size, tau, **kwargs):
        super(LIF, self).__init__(size, **kwargs)
        # parameters
        self.tau = tau
        # variables
        self.V = bp.math.Variable(bp.math.zeros(size))
        self.spike = bp.math.Variable(bp.math.zeros(size, dtype=bool))
        self.input = bp.math.Variable(bp.math.zeros(size))
        # integral
        self.integral = bp.odeint(lambda V, t, Isyn: (-V + Isyn) / self.tau)
    def update(self, tdi):
        V = self.integral(self.V, tdi.t, self.input, tdi.dt)
        self.spike.value = V >= V_threshold
        self.V.value = bm.where(self.spike, V_reset, V)
        self.input[:] = 0.
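The runner below drives the E cells with f_E * sqrt(num) * mu_f = 3 × 100 × 0.1 = 30 and the I cells with f_I * sqrt(num) * mu_f = 2 × 100 × 0.1 = 20. Both values are above V_threshold = 15.

For this dimensionless LIF (tau * dV/dt = -V + I) with a constant input I above the threshold θ, the time to climb from the reset value 0 to threshold is tau * ln(I / (I - θ)). The sketch below checks those numbers and ignores all recurrent input. The helper time_to_threshold is hypothetical.

```python
import math

num, mu_f = 10000, 0.1
f_E, f_I = 3., 2.
tau_E, tau_I = 15., 10.
V_reset, V_threshold = 0., 15.

def time_to_threshold(I, tau, V0=V_reset, theta=V_threshold):
    """Time for tau * dV/dt = -V + I to go from V0 to theta (requires I > theta)."""
    return tau * math.log((I - V0) / (I - theta))

I_E = f_E * math.sqrt(num) * mu_f   # external drive to E cells
I_I = f_I * math.sqrt(num) * mu_f   # external drive to I cells
print(I_E, I_I)
print(time_to_threshold(I_E, tau_E), time_to_threshold(I_I, tau_I))  # 15 ln 2 and 10 ln 4
```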
[4]:
class EINet(bp.Network):
    def __init__(self):
        # neurons
        E = LIF(num_exc, tau=tau_E)
        I = LIF(num_inh, tau=tau_I)
        E.V[:] = bm.random.random(num_exc) * (V_threshold - V_reset) + V_reset
        I.V[:] = bm.random.random(num_inh) * (V_threshold - V_reset) + V_reset
        # synapses
        E2I = bp.synapses.Exponential(pre=E, post=I, conn=bp.conn.FixedProb(prob), tau=tau_Es, g_max=JIE)
        E2E = bp.synapses.Exponential(pre=E, post=E, conn=bp.conn.FixedProb(prob), tau=tau_Es, g_max=JEE)
        I2I = bp.synapses.Exponential(pre=I, post=I, conn=bp.conn.FixedProb(prob), tau=tau_Is, g_max=JII)
        I2E = bp.synapses.Exponential(pre=I, post=E, conn=bp.conn.FixedProb(prob), tau=tau_Is, g_max=JEI)
        super(EINet, self).__init__(E2E, E2I, I2E, I2I, E=E, I=I)
[5]:
net = EINet()
[6]:
runner = bp.DSRunner(net,
monitors=['E.spike', 'I.spike'],
inputs=[('E.input', f_E * bm.sqrt(num) * mu_f),
('I.input', f_I * bm.sqrt(num) * mu_f)])
t = runner.run(100.)
[7]:
# visualization
fig, gs = bp.visualize.get_figure(5, 1, 1.5, 10)
fig.add_subplot(gs[:3, 0])
bp.visualize.raster_plot(runner.mon.ts, runner.mon['E.spike'], xlim=(0, 100))
fig.add_subplot(gs[3:, 0])
bp.visualize.raster_plot(runner.mon.ts, runner.mon['I.spike'], xlim=(0, 100), show=True)

Classify MNIST dataset by a fully connected LIF layer
This tutorial shows how to train a simple SNN with an encoder and the surrogate gradient method to classify MNIST.
Convolutional SNN to Classify Fashion-MNIST
In this tutorial, we will build a convolutional spiking neural network to classify the Fashion-MNIST dataset.
(2022, NeurIPS): Online Training Through Time for Spiking Neural Networks
Implementation of the paper:
Xiao, M., Meng, Q., Zhang, Z., He, D., & Lin, Z. (2022). Online Training Through Time for Spiking Neural Networks. ArXiv, abs/2210.04195.
(2019, Zenke, F.): SNN Surrogate Gradient Learning
Training a spiking neural network with surrogate gradient learning.
(2019, Zenke, F.): SNN Surrogate Gradient Learning to Classify Fashion-MNIST
Training a spiking neural network on a simple vision dataset.
(2021, Raminmh): Liquid time-constant Networks
Training a liquid time-constant network with BPTT.
Predicting Mackey-Glass time series
[1]:
import brainpy as bp
import brainpy.math as bm
import brainpy_datasets as bp_data
# bm.set_platform('cpu')
bm.set_environment(mode=bm.batching_mode, x64=True)
[2]:
import numpy as np
import matplotlib.pyplot as plt
Dataset
[3]:
def plot_mackey_glass_series(ts, x_series, x_tau_series, num_sample):
    plt.figure(figsize=(13, 5))
    plt.subplot(121)
    plt.title(f"Time series - {num_sample} timesteps")
    plt.plot(ts[:num_sample], x_series[:num_sample], lw=2, color="lightgrey", zorder=0)
    plt.scatter(ts[:num_sample], x_series[:num_sample], c=ts[:num_sample], cmap="viridis", s=6)
    plt.xlabel("$t$")
    plt.ylabel("$P(t)$")
    ax = plt.subplot(122)
    ax.margins(0.05)
    plt.title(f"Phase diagram: $P(t) = f(P(t-\\tau))$")
    plt.plot(x_tau_series[: num_sample], x_series[: num_sample], lw=1, color="lightgrey", zorder=0)
    plt.scatter(x_tau_series[:num_sample], x_series[: num_sample], lw=0.5, c=ts[:num_sample], cmap="viridis", s=6)
    plt.xlabel("$P(t-\\tau)$")
    plt.ylabel("$P(t)$")
    cbar = plt.colorbar()
    # cbar.ax.set_ylabel('$t$', rotation=270)
    cbar.ax.set_ylabel('$t$')
    plt.tight_layout()
    plt.show()
[4]:
dt = 0.1
mg_data = bp_data.chaos.MackeyGlassEq(25000, dt=dt, tau=17, beta=0.2, gamma=0.1, n=10, inits=1.2, seed=123)
plot_mackey_glass_series(mg_data.ts, mg_data.xs, mg_data.ys, num_sample=int(1000 / dt))

[5]:
forecast = int(10 / dt)  # predict 10 time units (100 raw steps) ahead
train_length = int(20000 / dt)
sample_rate = int(1 / dt)
X_train = mg_data.xs[:train_length:sample_rate]
Y_train = mg_data.xs[forecast: train_length + forecast: sample_rate]
X_test = mg_data.xs[train_length: -forecast: sample_rate]
Y_test = mg_data.xs[train_length + forecast::sample_rate]
X_train = np.expand_dims(X_train, 0)
Y_train = np.expand_dims(Y_train, 0)
X_test = np.expand_dims(X_test, 0)
Y_test = np.expand_dims(Y_test, 0)
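The slicing above builds input/target pairs in two steps. It shifts the series by forecast raw steps, then keeps every sample_rate-th point. So Y_train[0, i] equals mg_data.xs[i * sample_rate + forecast], and with dt = 0.1 the horizon is 10 time units, or 10 samples of the subsampled series.

The toy NumPy check below uses an integer ramp in place of the Mackey-Glass data, so the alignment is visible directly.

```python
import numpy as np

dt = 0.1
forecast = int(10 / dt)      # 100 raw steps
sample_rate = int(1 / dt)    # keep every 10th raw step
series = np.arange(5000)     # toy stand-in: value == raw time index
train_length = 3000

X_train = series[:train_length:sample_rate]
Y_train = series[forecast: train_length + forecast: sample_rate]

print(X_train[:3], Y_train[:3])                 # [ 0 10 20] [100 110 120]
print((Y_train - X_train)[0] // sample_rate)    # horizon in samples: 10
```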
[6]:
sample = 300
fig = plt.figure(figsize=(15, 5))
plt.plot(X_train.flatten()[:sample], label="Training data")
plt.plot(Y_train.flatten()[:sample], label="True prediction")
plt.legend()
plt.show()

Model
[7]:
class ESN(bp.DynamicalSystem):
    def __init__(self, num_in, num_hidden, num_out):
        super(ESN, self).__init__()
        self.r = bp.layers.Reservoir(num_in, num_hidden,
                                     Wrec_initializer=bp.init.KaimingNormal())
        self.o = bp.layers.Dense(num_hidden, num_out, mode=bm.training_mode)
    def update(self, sha, x):
        return self.o(sha, self.r(sha, x))
model = ESN(1, 100, 1)
[8]:
runner = bp.DSTrainer(model)
out = runner.predict(bm.asarray(X_train))
out.shape
[8]:
(1, 20000, 1)
Training
[9]:
trainer = bp.RidgeTrainer(model, alpha=1e-6)
[10]:
_ = trainer.fit([bm.asarray(X_train),
bm.asarray(Y_train)])
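bp.RidgeTrainer fits the readout offline in closed form rather than by gradient descent. The sketch below assumes the standard ridge solution W = (HᵀH + αI)⁻¹ HᵀY, where H stacks the hidden states over time. BrainPy's handling of the bias term and of the batch axis may differ.

The function ridge_readout is hypothetical. It uses random stand-in data so the recovered weights can be checked.

```python
import numpy as np

def ridge_readout(H, Y, alpha):
    """Closed-form ridge regression: argmin_W ||H W - Y||^2 + alpha ||W||^2."""
    n_features = H.shape[1]
    A = H.T @ H + alpha * np.eye(n_features)
    return np.linalg.solve(A, H.T @ Y)      # solve() instead of an explicit inverse

rng = np.random.default_rng(0)
H = rng.normal(size=(2000, 100))            # stand-in for 2000 steps of 100 hidden units
W_true = rng.normal(size=(100, 1))
Y = H @ W_true + 0.01 * rng.normal(size=(2000, 1))

W = ridge_readout(H, Y, alpha=1e-6)
print(np.max(np.abs(W - W_true)) < 1e-2)    # True
```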
Prediction
[11]:
ys_predict = trainer.predict(bm.asarray(X_train), reset_state=True)
start, end = 100, 600
plt.figure(figsize=(15, 7))
plt.subplot(211)
plt.plot(bm.arange(end - start).to_numpy(),
bm.as_numpy(ys_predict)[0, start:end, 0],
lw=3,
label="ESN prediction")
plt.plot(bm.arange(end - start).to_numpy(),
Y_train[0, start:end, 0],
linestyle="--",
lw=2,
label="True value")
plt.legend()
plt.show()

[12]:
ys_predict = trainer.predict(bm.asarray(X_test), reset_state=True)
start, end = 100, 600
plt.figure(figsize=(15, 7))
plt.subplot(211)
plt.plot(bm.arange(end - start).to_numpy(),
bm.as_numpy(ys_predict)[0, start:end, 0],
lw=3,
label="ESN prediction")
plt.plot(bm.arange(end - start).to_numpy(),
Y_test[0, start:end, 0],
linestyle="--",
lw=2,
label="True value")
plt.title(f'Mean Square Error: {bp.losses.mean_squared_error(bm.as_numpy(ys_predict), Y_test)}')
plt.legend()
plt.show()

(Gauthier, et al., 2021): Next generation reservoir computing
Implementation of the paper:
Gauthier, D.J., Bollt, E., Griffith, A. et al. Next generation reservoir computing. Nat Commun 12, 5564 (2021). https://doi.org/10.1038/s41467-021-25801-2
[1]:
import matplotlib.pyplot as plt
import numpy as np
import brainpy as bp
import brainpy.math as bm
import brainpy_datasets as bd
bm.set_environment(bm.batching_mode, x64=True)
[2]:
def plot_weights(Wout, coefs, bias=None):
    Wout = np.asarray(Wout)
    if bias is not None:
        bias = np.asarray(bias)
        Wout = np.concatenate([bias.reshape((1, 3)), Wout], axis=0)
        coefs.insert(0, 'bias')
    x_Wout, y_Wout, z_Wout = Wout[:, 0], Wout[:, 1], Wout[:, 2]
    fig = plt.figure(figsize=(10, 10))
    ax = fig.add_subplot(131)
    ax.grid(axis="y")
    ax.set_xlabel("$[W_{out}]_x$")
    ax.set_ylabel("Features")
    ax.set_yticks(np.arange(len(coefs)))
    ax.set_yticklabels(coefs)
    ax.barh(np.arange(x_Wout.size), x_Wout)
    ax1 = fig.add_subplot(132)
    ax1.grid(axis="y")
    ax1.set_yticks(np.arange(len(coefs)))
    ax1.set_xlabel("$[W_{out}]_y$")
    ax1.barh(np.arange(y_Wout.size), y_Wout)
    ax2 = fig.add_subplot(133)
    ax2.grid(axis="y")
    ax2.set_yticks(np.arange(len(coefs)))
    ax2.set_xlabel("$[W_{out}]_z$")
    ax2.barh(np.arange(z_Wout.size), z_Wout)
    plt.show()
Forecasting Lorenz63 strange attractor
[3]:
def get_subset(data, start, end):
    res = {'x': data.xs[start: end],
           'y': data.ys[start: end],
           'z': data.zs[start: end]}
    res = bm.hstack([res['x'], res['y'], res['z']])
    return res.reshape((1, ) + res.shape)
[4]:
def plot_lorenz(ground_truth, predictions):
    fig = plt.figure(figsize=(15, 10))
    ax = fig.add_subplot(121, projection='3d')
    ax.set_title("Generated attractor")
    ax.set_xlabel("$x$")
    ax.set_ylabel("$y$")
    ax.set_zlabel("$z$")
    ax.grid(False)
    ax.plot(predictions[:, 0], predictions[:, 1], predictions[:, 2])
    ax2 = fig.add_subplot(122, projection='3d')
    ax2.set_title("Real attractor")
    ax2.grid(False)
    ax2.plot(ground_truth[:, 0], ground_truth[:, 1], ground_truth[:, 2])
    plt.show()
[5]:
dt = 0.01
t_warmup = 5. # ms
t_train = 10. # ms
t_test = 120. # ms
num_warmup = int(t_warmup / dt) # warm up NVAR
num_train = int(t_train / dt)
num_test = int(t_test / dt)
Datasets
[6]:
lorenz_series = bd.chaos.LorenzEq(t_warmup + t_train + t_test,
dt=dt,
inits={'x': 17.67715816276679,
'y': 12.931379185960404,
'z': 43.91404334248268})
X_warmup = get_subset(lorenz_series, 0, num_warmup - 1)
Y_warmup = get_subset(lorenz_series, 1, num_warmup)
X_train = get_subset(lorenz_series, num_warmup - 1, num_warmup + num_train - 1)
# Target: Lorenz[t] - Lorenz[t - 1]
dX_train = get_subset(lorenz_series, num_warmup, num_warmup + num_train) - X_train
X_test = get_subset(lorenz_series,
num_warmup + num_train - 1,
num_warmup + num_train + num_test - 1)
Y_test = get_subset(lorenz_series,
num_warmup + num_train,
num_warmup + num_train + num_test)
Model
[7]:
class NGRC(bp.DynamicalSystem):
    def __init__(self, num_in):
        super(NGRC, self).__init__()
        self.r = bp.layers.NVAR(num_in, delay=2, order=2, constant=True,)
        self.di = bp.layers.Dense(self.r.num_out, num_in, b_initializer=None, mode=bm.training_mode)
    def update(self, sha, x):
        dx = self.di(sha, self.r(sha, x))
        return x + dx
[8]:
model = NGRC(3)
print(model.r.get_feature_names())
['1', 'x0(t)', 'x1(t)', 'x2(t)', 'x0(t-1)', 'x1(t-1)', 'x2(t-1)', 'x0(t)^2', 'x0(t) x1(t)', 'x0(t) x2(t)', 'x0(t) x0(t-1)', 'x0(t) x1(t-1)', 'x0(t) x2(t-1)', 'x1(t)^2', 'x1(t) x2(t)', 'x1(t) x0(t-1)', 'x1(t) x1(t-1)', 'x1(t) x2(t-1)', 'x2(t)^2', 'x2(t) x0(t-1)', 'x2(t) x1(t-1)', 'x2(t) x2(t-1)', 'x0(t-1)^2', 'x0(t-1) x1(t-1)', 'x0(t-1) x2(t-1)', 'x1(t-1)^2', 'x1(t-1) x2(t-1)', 'x2(t-1)^2']
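The 28 names printed above follow a simple recipe:

- one constant term,
- 3 × 2 = 6 linear terms (each variable at t and t-1),
- every unordered product of two linear terms, which gives C(7, 2) = 21 quadratic terms.

The pure-Python sketch below regenerates the same list with itertools.combinations_with_replacement. It is reconstructed from the printed output and is not BrainPy's NVAR code. The stride argument treats stride as the spacing between delayed taps (as in the later NVAR(..., stride=5)), which is an assumption about that parameter.

```python
from itertools import combinations_with_replacement

def nvar_feature_names(num_in, delay, order=2, stride=1, constant=True):
    """Hypothetical reconstruction of NVAR feature names (linear + quadratic only)."""
    assert order == 2, "this sketch only builds linear and quadratic terms"
    linear = []
    for d in range(delay):
        lag = d * stride                               # taps at t, t-stride, t-2*stride, ...
        suffix = 't' if lag == 0 else f't-{lag}'
        linear += [f'x{i}({suffix})' for i in range(num_in)]
    quadratic = [f'{a}^2' if a == b else f'{a} {b}'
                 for a, b in combinations_with_replacement(linear, 2)]
    return (['1'] if constant else []) + linear + quadratic

names = nvar_feature_names(3, delay=2, order=2, constant=True)
print(len(names))  # 28
```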
Training
[9]:
# ridge-regression trainer
trainer = bp.RidgeTrainer(model, alpha=2.5e-6)
# warm-up: run the NVAR over the first samples to fill its delays
outputs = trainer.predict(X_warmup)
print('Warmup NMS: ', bp.losses.mean_squared_error(outputs, Y_warmup))
# training
trainer.fit([X_train, dX_train])
plot_weights(model.di.W, model.r.get_feature_names(for_plot=True), model.di.b)
Warmup NMS: 107865.37236001804

Prediction
[10]:
model = bm.jit(model)
outputs = [model(dict(), X_test[:, 0])]
for i in range(1, X_test.shape[1]):
    outputs.append(model(dict(), outputs[i - 1]))
outputs = bm.asarray(outputs)
print('Prediction NMS: ', bp.losses.mean_squared_error(outputs, Y_test))
plot_lorenz(bm.as_numpy(Y_test).squeeze(), bm.as_numpy(outputs).squeeze())
Prediction NMS: 148.18244150701176
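The prediction loop above runs the NGRC in closed loop. The model returns x + dx, and each output is fed back as the next input, so errors can accumulate over the test window.

The sketch below shows the generic pattern. A hypothetical increment model stands in for the trained one: the exact step of a small 2-D rotation, chosen only so the final state can be checked analytically.

```python
import numpy as np

def rollout(step_fn, x0, num_steps):
    """Closed-loop forecast: each prediction becomes the next input."""
    outputs = [step_fn(x0)]
    for _ in range(1, num_steps):
        outputs.append(step_fn(outputs[-1]))
    return np.stack(outputs)                # shape (num_steps, dim)

theta = 0.01
R = np.array([[np.cos(theta), -np.sin(theta)],
              [np.sin(theta),  np.cos(theta)]])
dx_model = lambda x: (R - np.eye(2)) @ x    # hypothetical learned increment
step = lambda x: x + dx_model(x)            # same "x + dx" form as NGRC.update

traj = rollout(step, np.array([1.0, 0.0]), num_steps=100)
print(traj.shape)                           # (100, 2)
```

After 100 steps the state has rotated by 100 × 0.01 = 1 rad, so traj[-1] ≈ [cos 1, sin 1].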

Forecasting the double-scroll system
[11]:
def plot_double_scroll(ground_truth, predictions):
    fig = plt.figure(figsize=(15, 10))
    ax = fig.add_subplot(121, projection='3d')
    ax.set_title("Generated attractor")
    ax.set_xlabel("$x$")
    ax.set_ylabel("$y$")
    ax.set_zlabel("$z$")
    ax.grid(False)
    ax.plot(predictions[:, 0], predictions[:, 1], predictions[:, 2])
    ax2 = fig.add_subplot(122, projection='3d')
    ax2.set_title("Real attractor")
    ax2.grid(False)
    ax2.plot(ground_truth[:, 0], ground_truth[:, 1], ground_truth[:, 2])
    plt.show()
[12]:
dt = 0.02
t_warmup = 10. # ms
t_train = 100. # ms
t_test = 800. # ms
num_warmup = int(t_warmup / dt) # warm up NVAR
num_train = int(t_train / dt)
num_test = int(t_test / dt)
Datasets
[13]:
data_series = bd.chaos.DoubleScrollEq(t_warmup + t_train + t_test, dt=dt)
X_warmup = get_subset(data_series, 0, num_warmup - 1)
Y_warmup = get_subset(data_series, 1, num_warmup)
X_train = get_subset(data_series, num_warmup - 1, num_warmup + num_train - 1)
# Target: DoubleScroll[t] - DoubleScroll[t - 1]
dX_train = get_subset(data_series, num_warmup, num_warmup + num_train) - X_train
X_test = get_subset(data_series,
num_warmup + num_train - 1,
num_warmup + num_train + num_test - 1)
Y_test = get_subset(data_series,
num_warmup + num_train,
num_warmup + num_train + num_test)
Model
[14]:
class NGRC(bp.DynamicalSystem):
    def __init__(self, num_in):
        super(NGRC, self).__init__()
        self.r = bp.layers.NVAR(num_in, delay=2, order=3)
        self.di = bp.layers.Dense(self.r.num_out, num_in, mode=bm.training_mode)
    def update(self, sha, x):
        di = self.di(sha, self.r(sha, x))
        return x + di
model = NGRC(3)
Training
[15]:
# ridge-regression trainer
trainer = bp.RidgeTrainer(model, alpha=1e-5, jit=True)
[16]:
# warm-up: run the NVAR over the first samples to fill its delays
outputs = trainer.predict(X_warmup)
print('Warmup NMS: ', bp.losses.mean_squared_error(outputs, Y_warmup))
# training
trainer.fit([X_train, dX_train])
plot_weights(model.di.W, model.r.get_feature_names(for_plot=True), model.di.b)
Warmup NMS: 3.1532995057589828

Prediction
[17]:
model = bm.jit(model)
outputs = [model(dict(), X_test[:, 0])]
for i in range(1, X_test.shape[1]):
    outputs.append(model(dict(), outputs[i - 1]))
outputs = bm.asarray(outputs).squeeze()
print('Prediction NMS: ', bp.losses.mean_squared_error(outputs, Y_test))
plot_double_scroll(Y_test.numpy().squeeze(), outputs.numpy())
Prediction NMS: 1.4541218384321954

Inferring dynamics of Lorenz63 strange attractor
[18]:
def get_subset(data, start, end):
    res = {'x': data.xs[start: end],
           'y': data.ys[start: end],
           'z': data.zs[start: end]}
    X = bm.hstack([res['x'], res['y']])
    X = X.reshape((1,) + X.shape)
    Y = res['z']
    Y = Y.reshape((1, ) + Y.shape)
    return X, Y
[19]:
def plot_lorenz2(x, y, true_z, predict_z, linewidth=.8):
    fig1 = plt.figure()
    fig1.set_figheight(8)
    fig1.set_figwidth(12)
    t_all = t_warmup + t_train + t_test
    ts = np.arange(0, t_all, dt)
    h = 240
    w = 2
    # top left of grid is 0,0
    axs1 = plt.subplot2grid(shape=(h, w), loc=(0, 0), colspan=2, rowspan=30)
    axs2 = plt.subplot2grid(shape=(h, w), loc=(36, 0), colspan=2, rowspan=30)
    axs3 = plt.subplot2grid(shape=(h, w), loc=(72, 0), colspan=2, rowspan=30)
    axs4 = plt.subplot2grid(shape=(h, w), loc=(132, 0), colspan=2, rowspan=30)
    axs5 = plt.subplot2grid(shape=(h, w), loc=(168, 0), colspan=2, rowspan=30)
    axs6 = plt.subplot2grid(shape=(h, w), loc=(204, 0), colspan=2, rowspan=30)
    # training phase x
    axs1.set_title('training phase')
    axs1.plot(ts[num_warmup:num_warmup + num_train],
              x[num_warmup:num_warmup + num_train],
              color='b', linewidth=linewidth)
    axs1.set_ylabel('x')
    axs1.axes.xaxis.set_ticklabels([])
    axs1.axes.set_xbound(t_warmup - .08, t_warmup + t_train + .05)
    axs1.axes.set_ybound(-21., 21.)
    axs1.text(-.14, .9, 'a)', ha='left', va='bottom', transform=axs1.transAxes)
    # training phase y
    axs2.plot(ts[num_warmup:num_warmup + num_train],
              y[num_warmup:num_warmup + num_train],
              color='b', linewidth=linewidth)
    axs2.set_ylabel('y')
    axs2.axes.xaxis.set_ticklabels([])
    axs2.axes.set_xbound(t_warmup - .08, t_warmup + t_train + .05)
    axs2.axes.set_ybound(-26., 26.)
    axs2.text(-.14, .9, 'b)', ha='left', va='bottom', transform=axs2.transAxes)
    # training phase z
    axs3.plot(ts[num_warmup:num_warmup + num_train],
              true_z[num_warmup:num_warmup + num_train],
              color='b', linewidth=linewidth)
    axs3.plot(ts[num_warmup:num_warmup + num_train],
              predict_z[num_warmup:num_warmup + num_train],
              color='r', linewidth=linewidth)
    axs3.set_ylabel('z')
    axs3.set_xlabel('time')
    axs3.axes.set_xbound(t_warmup - .08, t_warmup + t_train + .05)
    axs3.axes.set_ybound(3., 48.)
    axs3.text(-.14, .9, 'c)', ha='left', va='bottom', transform=axs3.transAxes)
    # testing phase x
    axs4.set_title('testing phase')
    axs4.plot(ts[num_warmup + num_train:num_warmup + num_train + num_test],
              x[num_warmup + num_train:num_warmup + num_train + num_test],
              color='b', linewidth=linewidth)
    axs4.set_ylabel('x')
    axs4.axes.xaxis.set_ticklabels([])
    axs4.axes.set_ybound(-21., 21.)
    axs4.axes.set_xbound(t_warmup + t_train - .5, t_all + .5)
    axs4.text(-.14, .9, 'd)', ha='left', va='bottom', transform=axs4.transAxes)
    # testing phase y
    axs5.plot(ts[num_warmup + num_train:num_warmup + num_train + num_test],
              y[num_warmup + num_train:num_warmup + num_train + num_test],
              color='b', linewidth=linewidth)
    axs5.set_ylabel('y')
    axs5.axes.xaxis.set_ticklabels([])
    axs5.axes.set_ybound(-26., 26.)
    axs5.axes.set_xbound(t_warmup + t_train - .5, t_all + .5)
    axs5.text(-.14, .9, 'e)', ha='left', va='bottom', transform=axs5.transAxes)
    # testing phase z
    axs6.plot(ts[num_warmup + num_train:num_warmup + num_train + num_test],
              true_z[num_warmup + num_train:num_warmup + num_train + num_test],
              color='b', linewidth=linewidth)
    axs6.plot(ts[num_warmup + num_train:num_warmup + num_train + num_test],
              predict_z[num_warmup + num_train:num_warmup + num_train + num_test],
              color='r', linewidth=linewidth)
    axs6.set_ylabel('z')
    axs6.set_xlabel('time')
    axs6.axes.set_ybound(3., 48.)
    axs6.axes.set_xbound(t_warmup + t_train - .5, t_all + .5)
    axs6.text(-.14, .9, 'f)', ha='left', va='bottom', transform=axs6.transAxes)
    plt.show()
[20]:
dt = 0.02
t_warmup = 10. # ms
t_train = 20. # ms
t_test = 50. # ms
num_warmup = int(t_warmup / dt) # warm up NVAR
num_train = int(t_train / dt)
num_test = int(t_test / dt)
Datasets
[21]:
lorenz_series = bd.chaos.LorenzEq(t_warmup + t_train + t_test,
dt=dt,
inits={'x': 17.67715816276679,
'y': 12.931379185960404,
'z': 43.91404334248268})
X_warmup, Y_warmup = get_subset(lorenz_series, 0, num_warmup)
X_train, Y_train = get_subset(lorenz_series, num_warmup, num_warmup + num_train)
X_test, Y_test = get_subset(lorenz_series, 0, num_warmup + num_train + num_test)
Model
[22]:
class NGRC(bp.DynamicalSystem):
    def __init__(self, num_in):
        super(NGRC, self).__init__()
        self.r = bp.layers.NVAR(num_in, delay=4, order=2, stride=5)
        self.o = bp.layers.Dense(self.r.num_out, 1, mode=bm.training_mode)
    def update(self, sha, x):
        return self.o(sha, self.r(sha, x))
model = NGRC(2)
Training
[23]:
trainer = bp.RidgeTrainer(model, alpha=0.05)
# warm-up
outputs = trainer.predict(X_warmup)
print('Warmup NMS: ', bp.losses.mean_squared_error(outputs, Y_warmup))
# training
_ = trainer.fit([X_train, Y_train])
Warmup NMS: 7268.475225524938
Prediction
[24]:
outputs = trainer.predict(X_test, reset_state=True)
print('Prediction NMS: ', bp.losses.mean_squared_error(outputs, Y_test))
Prediction NMS: 590.5516565428906
[25]:
plot_lorenz2(x=bm.as_numpy(lorenz_series.xs.flatten()),
y=bm.as_numpy(lorenz_series.ys.flatten()),
true_z=bm.as_numpy(lorenz_series.zs.flatten()),
predict_z=bm.as_numpy(outputs.flatten()))

(Sussillo & Abbott, 2009) FORCE Learning
Implementation of the paper:
Sussillo, David, and Larry F. Abbott. “Generating coherent patterns of activity from chaotic neural networks.” Neuron 63, no. 4 (2009): 544-557.
[1]:
import brainpy as bp
import brainpy.math as bm
bm.set_platform('cpu')
[2]:
import numpy as np
import matplotlib.pyplot as plt
[3]:
class EchoStateNet(bp.DynamicalSystem):
    r"""The continuous-time Echo State Network.
    .. math::
        \tau \frac{dh}{dt} = -h + W_{ir} * x + W_{rr} * r + W_{or} * o \\
        r = \tanh(h) \\
        o = W_{ro} * r
    """
    def __init__(self, num_input, num_hidden, num_output,
                 tau=1.0, dt=0.1, g=1.8, alpha=1.0, **kwargs):
        super(EchoStateNet, self).__init__(**kwargs)
        # parameters
        self.num_input = num_input
        self.num_hidden = num_hidden
        self.num_output = num_output
        self.tau = tau
        self.dt = dt
        self.g = g
        self.alpha = alpha
        # weights
        self.w_ir = bm.random.normal(size=(num_input, num_hidden)) / bm.sqrt(num_input)
        self.w_rr = g * bm.random.normal(size=(num_hidden, num_hidden)) / bm.sqrt(num_hidden)
        self.w_or = bm.random.normal(size=(num_output, num_hidden))
        w_ro = bm.random.normal(size=(num_hidden, num_output)) / bm.sqrt(num_hidden)
        self.w_ro = bm.Variable(w_ro)
        # variables
        self.h = bm.Variable(bm.random.normal(size=num_hidden) * 0.5)  # hidden
        self.r = bm.Variable(bm.tanh(self.h))  # firing rate
        self.o = bm.Variable(bm.dot(self.r, w_ro))  # output unit
        self.P = bm.Variable(bm.eye(num_hidden) * self.alpha)  # inverse correlation matrix
    def update(self, x):
        # update the hidden and output state
        dhdt = -self.h + bm.dot(x, self.w_ir)
        dhdt += bm.dot(self.r, self.w_rr)
        dhdt += bm.dot(self.o, self.w_or)
        self.h += self.dt / self.tau * dhdt
        self.r.value = bm.tanh(self.h)
        self.o.value = bm.dot(self.r, self.w_ro)
    def rls(self, target):
        # update the inverse correlation matrix
        k = bm.expand_dims(bm.dot(self.P, self.r), axis=1)  # (num_hidden, 1)
        hPh = bm.dot(self.r.T, k)  # (1,)
        c = 1.0 / (1.0 + hPh)  # (1,)
        self.P -= bm.dot(k * c, k.T)  # (num_hidden, num_hidden)
        # update the output weights
        e = bm.atleast_2d(self.o - target)  # (1, num_output)
        dw = bm.dot(-c * k, e)  # (num_hidden, num_output)
        self.w_ro += dw
    def simulate(self, xs):
        f = bm.make_loop(self.update, dyn_vars=[self.h, self.r, self.o], out_vars=[self.r, self.o])
        return f(xs)
    def train(self, xs, targets):
        def _f(x):
            input, target = x
            self.update(input)
            self.rls(target)
        f = bm.make_loop(_f, dyn_vars=self.vars(), out_vars=[self.r, self.o])
        return f([xs, targets])
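rls is the recursive-least-squares update used by FORCE learning. P tracks a running estimate of the inverse correlation matrix of the rates r, and the readout moves against the current error along the direction c * P r.

A useful property follows from the algebra: right after the update, the error on the same input shrinks by exactly the factor c = 1 / (1 + rᵀ P r).

The NumPy sketch below mirrors the arithmetic of EchoStateNet.rls with plain arrays and checks that property for a single output. The function rls_step is a hypothetical illustration, not a replacement for the method.

```python
import numpy as np

def rls_step(P, w_ro, r, target):
    """One FORCE/RLS update with the same arithmetic as EchoStateNet.rls."""
    k = (P @ r)[:, None]                   # (num_hidden, 1)
    c = 1.0 / (1.0 + r @ k)                # (1,)
    P = P - (k * c) @ k.T                  # rank-one (Sherman-Morrison) update of P
    e = np.atleast_2d(r @ w_ro - target)   # (1, num_output), error before the update
    w_ro = w_ro + (-c * k) @ e             # (num_hidden, num_output)
    return P, w_ro, c

rng = np.random.default_rng(0)
num_hidden, num_output = 50, 1
P = np.eye(num_hidden)                     # corresponds to alpha = 1
w_ro = rng.normal(size=(num_hidden, num_output)) / np.sqrt(num_hidden)
r = np.tanh(rng.normal(size=num_hidden))
target = np.array([0.3])

err_before = r @ w_ro - target
P, w_ro, c = rls_step(P, w_ro, r, target)
err_after = r @ w_ro - target
print(np.allclose(err_after, c * err_before))  # True
```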
[4]:
def print_force(ts, rates, outs, targets, duration, ntoplot=10):
    """Plot activations and outputs for the Echo state network."""
    plt.figure(figsize=(16, 16))
    plt.subplot(321)
    plt.plot(ts, targets + 2 * np.arange(0, targets.shape[1]), 'g')
    plt.plot(ts, outs + 2 * np.arange(0, outs.shape[1]), 'r')
    plt.xlim((0, duration))
    plt.title('Target (green), Output (red)')
    plt.xlabel('Time')
    plt.ylabel('Dimension')
    plt.subplot(122)
    plt.imshow(rates.T, interpolation=None)
    plt.title('Hidden activations of ESN')
    plt.xlabel('Time')
    plt.ylabel('Dimension')
    plt.subplot(323)
    plt.plot(ts, rates[:, 0:ntoplot] + 2 * np.arange(0, ntoplot), 'b')
    plt.xlim((0, duration))
    plt.title('%d hidden activations of ESN' % (ntoplot))
    plt.xlabel('Time')
    plt.ylabel('Dimension')
    plt.subplot(325)
    plt.plot(ts, np.sqrt(np.square(outs - targets)), 'c')
    plt.xlim((0, duration))
    plt.title('Absolute error per output dimension')
    plt.xlabel('Time')
    plt.ylabel('Error')
    plt.tight_layout()
    plt.show()
[5]:
def plot_params(net):
    """Plot some of the parameters associated with the ESN."""
    assert isinstance(net, EchoStateNet)
    plt.figure(figsize=(16, 10))
    plt.subplot(221)
    plt.imshow(bm.as_numpy(net.w_rr + net.w_ro @ net.w_or), interpolation=None)
    plt.colorbar()
    plt.title('Effective matrix - W_rr + W_ro * W_or')
    plt.subplot(222)
    plt.imshow(bm.as_numpy(net.w_ro), interpolation=None)
    plt.colorbar()
    plt.title('Readout weights - W_ro')
    x_circ = np.linspace(-1, 1, 1000)
    y_circ = np.sqrt(1 - x_circ ** 2)
    evals, _ = np.linalg.eig(bm.as_numpy(net.w_rr))
    plt.subplot(223)
    plt.plot(np.real(evals), np.imag(evals), 'o')
    plt.plot(x_circ, y_circ, 'k')
    plt.plot(x_circ, -y_circ, 'k')
    plt.axis('equal')
    plt.title('Eigenvalues of W_rr')
    evals, _ = np.linalg.eig(bm.as_numpy(net.w_rr + net.w_ro @ net.w_or))
    plt.subplot(224)
    plt.plot(np.real(evals), np.imag(evals), 'o', color='orange')
    plt.plot(x_circ, y_circ, 'k')
    plt.plot(x_circ, -y_circ, 'k')
    plt.axis('equal')
    plt.title('Eigenvalues of W_rr + W_ro * W_or')
    plt.tight_layout()
    plt.show()
[6]:
dt = 0.1
T = 30
times = bm.arange(0, T, dt)
xs = bm.zeros((times.shape[0], 1))
Generate target data by running an ESN and using its first num_output hidden units as the targets for the FORCE-trained network.
[7]:
esn1 = EchoStateNet(num_input=1, num_hidden=500, num_output=20, dt=dt, g=1.8)
rs, ys = esn1.simulate(xs)
targets = rs[:, 0: esn1.num_output] # This will be the training data for the trained ESN
plt.plot(times, targets + 2 * np.arange(0, esn1.num_output), 'g')
plt.xlim((0, T))
plt.ylabel('Dimensions')
plt.xlabel('Time')
plt.show()

Untrained ESN.
[8]:
esn2 = EchoStateNet(num_input=1, num_hidden=500, num_output=20, dt=dt, g=1.5)
rs, ys = esn2.simulate(xs) # the untrained ESN
print_force(times, rates=rs, outs=ys, targets=targets, duration=T, ntoplot=10)

Trained ESN.
[9]:
esn3 = EchoStateNet(num_input=1, num_hidden=500, num_output=20, dt=dt, g=1.5, alpha=1.)
rs, ys = esn3.train(xs=xs, targets=targets) # train once
print_force(times, rates=rs, outs=ys, targets=targets, duration=T, ntoplot=10)

[10]:
plot_params(esn3)

(Sherman & Rinzel, 1992) Gap junction leads to anti-synchronization
Implementation of the paper:
Sherman, A., & Rinzel, J. (1992). Rhythmogenic effects of weak electrotonic coupling in neuronal models. Proceedings of the National Academy of Sciences, 89(6), 2471-2474.
Author: Chaoming Wang
[1]:
import brainpy as bp
import brainpy.math as bm
bp.math.set_platform('cpu')
[2]:
import matplotlib.pyplot as plt
Fig 1: weakly coupled cells can oscillate antiphase
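The “square-wave burster” model, for cell \(i\) coupled to cell \(j\), is given by:
\[ \tau \frac{dV_i}{dt} = -I_{in}(V_i) - I_{out}(V_i, n_i) - I_S(V_i, S) - g_c (V_i - V_j) + I_i, \qquad \tau \frac{dn_i}{dt} = \lambda \left( n_\infty(V_i) - n_i \right) \]
where
\[ I_{in}(V) = g_{Ca} \, m_\infty(V) (V - V_{Ca}), \qquad I_{out}(V, n) = g_K \, n \, (V - V_K), \qquad I_S(V, S) = g_S \, S \, (V - V_K) \]
and
\[ m_\infty(V) = \frac{1}{1 + e^{(V_m - V)/\theta_m}}, \qquad n_\infty(V) = \frac{1}{1 + e^{(V_n - V)/\theta_n}}. \]
Here \(S\) is a constant parameter.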
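For reference, the right-hand side of these equations for two coupled cells can be written directly in NumPy. This is a sketch that is not part of the original notebook: the names boltzmann and vector_field are hypothetical, the parameters are copied from the parameter cell below, and S is passed in as a constant. It makes one property explicit: the junctional current is antisymmetric, so it vanishes whenever the two voltages are equal.

```python
import numpy as np

# Parameters copied from the notebook's parameter cell (Fig. 1 values).
lambda_, tau = 0.8, 20.0
V_m, theta_m, V_n, theta_n = -20.0, 12.0, -17.0, 5.6
V_ca, V_K = 25.0, -75.0
g_ca, g_K, g_s = 3.6, 10.0, 4.0


def boltzmann(V, V_half, slope):
    """Steady-state function 1 / (1 + exp((V_half - V) / slope))."""
    return 1.0 / (1.0 + np.exp((V_half - V) / slope))


def vector_field(V, n, S, g_c, I):
    """Return (dV/dt, dn/dt, I_j) for two gap-junction coupled cells.

    V, n and I are length-2 arrays (one entry per cell); S is a constant.
    """
    I_in = g_ca * boltzmann(V, V_m, theta_m) * (V - V_ca)
    I_out = g_K * n * (V - V_K)
    I_s = g_s * S * (V - V_K)
    # Junctional current g_c * (V_i - V_j); V[::-1] swaps the two cells.
    I_j = g_c * (V - V[::-1])
    dV = (-I_in - I_out - I_s - I_j + I) / tau
    dn = lambda_ * (boltzmann(V, V_n, theta_n) - n) / tau
    return dV, dn, I_j
```

Because the coupling term drops out when \(V_1 = V_2\), an in-phase solution of the coupled pair is simply the single-cell solution; moving away from it requires breaking the symmetry, which is why the protocol below perturbs one of the cells.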
At t = 0.5 s the junctional-coupling conductance, \(g_c\), is raised to 0.08, and a small symmetry-breaking perturbation (0.3 mV) is applied to one of the cells. This destabilizes the single-cell oscillation and leads to an antiphase oscillation. At t = 5.5 s the single-cell behavior is restored by increasing \(g_c\) to 0.24; alternatively, one could set \(g_c\) to 0, but then the two cells would not be in-phase.
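This protocol is a piecewise-constant schedule. bp.inputs.section_input builds such an array from (value, duration) pairs; the minimal NumPy sketch below (hypothetical helper piecewise_constant, scalar values only, dt assumed to be BrainPy's default of 0.1 ms) shows the kind of array that, as we understand it, the runner steps through one entry per time step when an input is given as ('gc', gc, 'iter', '=').

```python
import numpy as np


def piecewise_constant(values, durations, dt):
    """Hypothetical helper: hold each scalar value for duration / dt steps."""
    pieces = [np.full(int(round(d / dt)), v, dtype=float)
              for v, d in zip(values, durations)]
    return np.concatenate(pieces)


dt = 0.1  # ms, assumed to match BrainPy's default time step
# g_c: 0 for 0.5 s, 0.08 for 5 s, then 0.24 for 1.5 s (the Fig. 1 protocol)
gc = piecewise_constant([0.0, 0.08, 0.24], [500, 5000, 1500], dt)
```

With dt = 0.1 ms the schedule has 70000 entries: indices 0 to 4999 hold 0, indices 5000 to 54999 hold 0.08, and the remainder hold 0.24.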
[3]:
lambda_ = 0.8
V_m = -20
theta_m = 12
V_n = -17
theta_n = 5.6
V_ca = 25
V_K = -75
tau = 20
g_ca = 3.6
g_K = 10
g_s = 4
[4]:
class Model1(bp.DynamicalSystem):
    def __init__(self, method='exp_auto'):
        super(Model1, self).__init__()
        # parameters
        self.gc = bm.Variable(bm.zeros(1))
        self.I = bm.Variable(bm.zeros(2))
        self.S = 0.15
        # variables
        self.V = bm.Variable(bm.zeros(2))
        self.n = bm.Variable(bm.zeros(2))
        # integral
        self.integral = bp.odeint(bp.JointEq([self.dV, self.dn]), method=method)
    def dV(self, V, t, n):
        I_in = g_ca / (1 + bp.math.exp((V_m - V) / theta_m)) * (V - V_ca)
        I_out = g_K * n * (V - V_K)
        Is = g_s * self.S * (V - V_K)
        Ij = self.gc * bm.array([V[0] - V[1], V[1] - V[0]])
        dV = (- I_in - I_out - Is - Ij + self.I) / tau
        return dV
    def dn(self, n, t, V):
        n_inf = 1 / (1 + bp.math.exp((V_n - V) / theta_n))
        dn = lambda_ * (n_inf - n) / tau
        return dn
    def update(self, tdi):
        V, n = self.integral(self.V, self.n, tdi.t, tdi.dt)
        self.V.value = V
        self.n.value = n
[5]:
def run_and_plot1(model, duration, inputs=None, plot_duration=None):
    runner = bp.dyn.DSRunner(model, inputs=inputs, monitors=['V', 'n', 'gc', 'I'])
    runner.run(duration)
    fig, gs = bp.visualize.get_figure(5, 1, 2, 12)
    plot_duration = (0, duration) if plot_duration is None else plot_duration
    fig.add_subplot(gs[0:3, 0])
    plt.plot(runner.mon.ts, runner.mon.V)
    plt.ylabel('V [mV]')
    plt.xlim(*plot_duration)
    fig.add_subplot(gs[3, 0])
    plt.plot(runner.mon.ts, runner.mon.gc)
    plt.ylabel(r'$g_c$')
    plt.xlim(*plot_duration)
    fig.add_subplot(gs[4, 0])
    plt.plot(runner.mon.ts, runner.mon.I[:, 0])
    plt.ylabel(r'$I_0$')
    plt.xlim(*plot_duration)
    plt.xlabel('Time [ms]')
    plt.show()
[6]:
model = Model1()
model.S = 0.15
model.V[:] = -55.
model.n[:] = 1 / (1 + bm.exp((V_n - model.V) / theta_n))
[7]:
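# g_c: 0 for 0.5 s, raised to 0.08 for 5 s, then 0.24 (as described above)
gc = bp.inputs.section_input(values=[0., 0.08, 0.24], durations=[500, 5000, 1500])
# small symmetry-breaking input to cell 0 from t = 0.5 s on
Is = bp.inputs.section_input(values=[0., bm.array([0.3, 0.])], durations=[500., 6500.])
run_and_plot1(model, duration=7000, inputs=[('gc', gc, 'iter', '='),
                                            ('I', Is, 'iter', '=')])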

Fig 2: weak coupling can convert excitable cells into spikers
Cells are initially uncoupled and at rest, but one cell has a current of strength 1.0 injected for 0.5 s, resulting in two spikes. Spiking ends when the current stimulus is removed. The unstimulated cell remains at rest. At t = 2 s, \(g_c\) is increased to 0.04. This alone does not move the cells away from rest, but the system is now bistable: the rest state coexists with an antiphase oscillation. A second, identical current stimulus brings both cells close enough to the oscillatory solution that they continue to oscillate after the stimulus ends.
[8]:
model = Model1()
model.S = 0.177
model.V[:] = -62.69
model.n[:] = 1 / (1 + bm.exp((V_n - model.V) / theta_n))
[9]:
gc = bp.inputs.section_input(values=[0, 0.04], durations=[2000, 2500])
Is = bp.inputs.section_input(values=[bm.array([1., 0.]), 0., bm.array([1., 0.]), 0.],
durations=[500, 2000, 500, 1500])
run_and_plot1(model, 4500, inputs=[('gc', gc, 'iter', '='),
('I', Is, 'iter', '=')])

Fig 3: weak coupling can increase the period of bursting
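We consider cells with endogenous bursting properties. Now \(S\) is a slow dynamic variable, satisfying
\[ \tau_S \frac{dS}{dt} = S_\infty(V) - S, \qquad S_\infty(V) = \frac{1}{1 + e^{(V_S - V)/\theta_S}}, \]
with \(\tau_S \gg \tau\).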
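The initialisation cells in this section set V from S with model.V[:] = V_S - theta_S * bm.log(1 / model.S - 1). That expression is the inverse of \(S_\infty\), so each cell starts at the voltage where the chosen S is already at its steady state. A small NumPy check of that inversion (function names hypothetical; V_S = -38 mV and θ_S = 10 mV are the values set in the next cell):

```python
import numpy as np

V_S, theta_S = -38.0, 10.0  # mV, values from the notebook's parameter cell


def S_inf(V):
    """Steady-state value of the slow variable S at voltage V."""
    return 1.0 / (1.0 + np.exp((V_S - V) / theta_S))


def V_at_S(S):
    """Inverse of S_inf: the voltage at which S is at steady state."""
    return V_S - theta_S * np.log(1.0 / S - 1.0)


V0 = V_at_S(0.172)  # initial voltage used with model.S[:] = 0.172
```

For S = 0.172 this gives V0 of about -53.7 mV, and S_inf(V0) returns 0.172 again.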
[10]:
tau_S = 35 * 1e3 # ms
V_S = -38 # mV
theta_S = 10 # mV
[11]:
class Model2(bp.DynamicalSystem):
    def __init__(self, method='exp_auto'):
        super(Model2, self).__init__()
        # parameters
        self.lambda_ = 0.1
        self.gc = bm.Variable(bm.zeros(1))
        self.I = bm.Variable(bm.zeros(2))
        # variables
        self.V = bm.Variable(bm.zeros(2))
        self.n = bm.Variable(bm.zeros(2))
        self.S = bm.Variable(bm.zeros(2))
        # integral
        self.integral = bp.odeint(bp.JointEq([self.dV, self.dn, self.dS]), method=method)
    def dV(self, V, t, n, S):
        I_in = g_ca / (1 + bm.exp((V_m - V) / theta_m)) * (V - V_ca)
        I_out = g_K * n * (V - V_K)
        Is = g_s * S * (V - V_K)
        Ij = self.gc * bm.array([V[0] - V[1], V[1] - V[0]])
        dV = (- I_in - I_out - Is - Ij + self.I) / tau
        return dV
    def dn(self, n, t, V):
        n_inf = 1 / (1 + bm.exp((V_n - V) / theta_n))
        dn = self.lambda_ * (n_inf - n) / tau
        return dn
    def dS(self, S, t, V):
        S_inf = 1 / (1 + bm.exp((V_S - V) / theta_S))
        dS = (S_inf - S) / tau_S
        return dS
    def update(self, tdi):
        V, n, S = self.integral(self.V, self.n, self.S, tdi.t, dt=tdi.dt)
        self.V.value = V
        self.n.value = n
        self.S.value = S
[12]:
def run_and_plot2(model, duration, inputs=None, plot_duration=None):
    runner = bp.dyn.DSRunner(model, inputs=inputs, monitors=['V', 'S'])
    runner.run(duration)
    fig, gs = bp.visualize.get_figure(5, 1, 2, 12)
    plot_duration = (0, duration) if plot_duration is None else plot_duration
    fig.add_subplot(gs[0:3, 0])
    plt.plot(runner.mon.ts, runner.mon.V)
    plt.ylabel('V [mV]')
    plt.xlim(*plot_duration)
    fig.add_subplot(gs[3:, 0])
    plt.plot(runner.mon.ts, runner.mon.S)
    plt.ylabel('S')
    plt.xlim(*plot_duration)
    plt.xlabel('Time [ms]')
    plt.show()
With \(\lambda = 0.9\), an isolated cell alternates periodically between a depolarized spiking phase and a hyperpolarized silent phase.
[13]:
model = Model2()
model.lambda_ = 0.9
model.S[:] = 0.172
model.V[:] = V_S - theta_S * bm.log(1 / model.S - 1)
model.n[:] = 1 / (1 + bm.exp((V_n - model.V) / theta_n))
model.gc[:] = 0.
model.I[:] = 0.
[14]:
run_and_plot2(model, 50 * 1e3)

When two identical bursters are coupled with \(g_c = 0.06\) and started in-phase, they initially follow the single-cell bursting solution. This behavior is unstable, however: during the second burst a new stable pattern emerges, with smaller-amplitude, higher-frequency antiphase spikes.
[15]:
model = Model2(method='exp_auto')
model.lambda_ = 0.9
model.S[:] = 0.172
model.V[:] = V_S - theta_S * bm.log(1 / model.S - 1)
model.n[:] = 1 / (1 + bm.exp((V_n - model.V) / theta_n))
model.gc[:] = 0.06
model.I[:] = 0.
[16]:
run_and_plot2(model, 50 * 1e3)

[17]:
model = Model2(method='exp_auto')
model.lambda_ = 0.9
model.S[:] = 0.172
model.V[:] = V_S - theta_S * bm.log(1 / model.S - 1)
model.n[:] = 1 / (1 + bm.exp((V_n - model.V) / theta_n))
model.gc[:] = 0.06
model.I[:] = 0.
run_and_plot2(model, 4 * 1e3)

Fig 4: weak coupling can convert spikers to bursters
Parameters are the same as in Fig. 3, except \(\lambda = 0.8\), which results in repetitive spiking (beating) instead of bursting; oscillations in \(S\) are nearly abolished. Two identical cells are started from identical initial conditions. At t = 20 s, \(g_c\) is increased to 0.04 and a small symmetry-breaking perturbation (0.3 mV) is applied to one cell. After a brief transient, the two cells begin to burst in-phase but with antiphase spikes, as in Fig. 3.
[18]:
model = Model2()
model.lambda_ = 0.8
model.S[:] = 0.172
model.V[:] = V_S - theta_S * bm.log(1 / model.S - 1)
model.n[:] = 1 / (1 + bm.exp((V_n - model.V) / theta_n))
model.gc[:] = 0.  # cells start uncoupled; g_c is set by the input schedule below
model.I[:] = 0.
[19]:
gc = bp.inputs.section_input(values=[0., 0.04], durations=[20 * 1e3, 30 * 1e3])
Is = bp.inputs.section_input(values=[0., bp.math.array([0.3, 0.])], durations=[20 * 1e3, 30 * 1e3])
run_and_plot2(model, 50 * 1e3, inputs=[('gc', gc, 'iter', '='),
('I', Is, 'iter', '=')])

(Fazli & Bertram, 2022): Electrically Coupled Bursting Pituitary Cells
Implementation of the paper:
Fazli, Mehran, and Richard Bertram. “Network Properties of Electrically Coupled Bursting Pituitary Cells.” Frontiers in Endocrinology 13 (2022).
[1]:
import brainpy as bp
import brainpy.math as bm
[2]:
class PituitaryCell(bp.NeuGroup):
    def __init__(self, size, name=None):
        super(PituitaryCell, self).__init__(size, name=name)
        # parameter values
        self.vn = -5
        self.kc = 0.12
        self.ff = 0.005
        self.vca = 60
        self.vk = -75
        self.vl = -50.0
        self.gk = 2.5
        self.cm = 5
        self.gbk = 1
        self.gca = 2.1
        self.gsk = 2
        self.vm = -20
        self.vb = -5
        self.sn = 10
        self.sm = 12
        self.sbk = 2
        self.taun = 30
        self.taubk = 5
        self.ks = 0.4
        self.alpha = 0.0015
        self.gl = 0.2
        # variables
        self.V = bm.Variable(bm.random.random(self.num) * -90 + 20)
        self.n = bm.Variable(bm.random.random(self.num) / 2)
        self.b = bm.Variable(bm.random.random(self.num) / 2)
        self.c = bm.Variable(bm.random.random(self.num))
        self.input = bm.Variable(self.num)
        # integrators
        self.integral = bp.odeint(bp.JointEq(self.dV, self.dn, self.dc, self.db), method='exp_euler')
    def dn(self, n, t, V):
        ninf = 1 / (1 + bm.exp((self.vn - V) / self.sn))
        return (ninf - n) / self.taun
    def db(self, b, t, V):
        bkinf = 1 / (1 + bm.exp((self.vb - V) / self.sbk))
        return (bkinf - b) / self.taubk
    def dc(self, c, t, V):
        minf = 1 / (1 + bm.exp((self.vm - V) / self.sm))
        ica = self.gca * minf * (V - self.vca)
        return -self.ff * (self.alpha * ica + self.kc * c)
    def dV(self, V, t, n, b, c):
        minf = 1 / (1 + bm.exp((self.vm - V) / self.sm))
        cinf = c ** 2 / (c ** 2 + self.ks * self.ks)
        ica = self.gca * minf * (V - self.vca)
        isk = self.gsk * cinf * (V - self.vk)
        ibk = self.gbk * b * (V - self.vk)
        ikdr = self.gk * n * (V - self.vk)
        il = self.gl * (V - self.vl)
        return -(ica + isk + ibk + ikdr + il + self.input) / self.cm
    def update(self, tdi, x=None):
        V, n, c, b = self.integral(self.V.value, self.n.value, self.c.value, self.b.value, tdi.t, tdi.dt)
        self.V.value = V
        self.n.value = n
        self.c.value = c
        self.b.value = b
    def clear_input(self):
        self.input.value = bm.zeros_like(self.input)
[3]:
class PituitaryNetwork(bp.Network):
    def __init__(self, num, gc):
        super(PituitaryNetwork, self).__init__()
        self.N = PituitaryCell(num)
        self.gj = bp.synapses.GapJunction(self.N, self.N, bp.conn.All2All(include_self=False), g_max=gc)
[4]:
net = PituitaryNetwork(2, 0.002)
runner = bp.DSRunner(net, monitors={'V': net.N.V}, dt=0.5)
runner.run(10 * 1e3)
fig, gs = bp.visualize.get_figure(1, 1, 6, 10)
fig.add_subplot(gs[0, 0])
bp.visualize.line_plot(runner.mon.ts, runner.mon.V, plot_ids=(0, 1), show=True)

(Wang & Buzsáki, 1996) Gamma Oscillation
Here we show the implementation of the gamma oscillation model proposed by Xiao-Jing Wang and György Buzsáki (1996). They demonstrated that GABA\(_A\) synaptic transmission provides a suitable mechanism for synchronized gamma oscillations in a network of fast-spiking interneurons.
Let’s first import BrainPy and set the integration time step.
[1]:
import brainpy as bp
import brainpy.math as bm
bp.math.set_dt(0.05)
The network is constructed with Hodgkin–Huxley (HH) type neurons and GABA\(_A\) synapses.
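The dynamics of the HH type neurons are given by:
\[ C \frac{dV}{dt} = -I_{Na} - I_K - I_L + I(t) \]
where \(I(t)\) is the injected current, the leak current is \(I_L = g_L (V - E_L)\), and the transient sodium current is
\[ I_{Na} = g_{Na} \, m_{\infty}^3 \, h \, (V - E_{Na}), \]
where the activation variable \(m\) is assumed fast and substituted by its steady-state function \(m_{\infty} = \alpha_m / (\alpha_m + \beta_m)\), with \(\alpha_m(V) = -0.1 (V + 35) / (\exp(-0.1 (V + 35)) - 1)\) and \(\beta_m(V) = 4 \exp(-(V + 60)/18)\). The inactivation variable \(h\) obeys first-order kinetics:
\[ \frac{dh}{dt} = \phi \left[ \alpha_h (1 - h) - \beta_h h \right], \qquad \alpha_h(V) = 0.07 \exp\left(-\frac{V + 58}{20}\right), \qquad \beta_h(V) = \frac{1}{\exp(-0.1 (V + 28)) + 1}. \]
The delayed rectifier potassium current is \(I_K = g_K n^4 (V - E_K)\), where the activation variable \(n\) also obeys first-order kinetics:
\[ \frac{dn}{dt} = \phi \left[ \alpha_n (1 - n) - \beta_n n \right], \qquad \alpha_n(V) = \frac{-0.01 (V + 34)}{\exp(-0.1 (V + 34)) - 1}, \qquad \beta_n(V) = 0.125 \exp\left(-\frac{V + 44}{80}\right). \]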
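The rate functions can be checked numerically before building the network. The NumPy sketch below is not part of the original example: it copies the expressions from the HH class and computes the steady states \(x_\infty = \alpha_x / (\alpha_x + \beta_x)\) and the time constants \(1 / (\phi (\alpha_x + \beta_x))\) for h and n. Note that \(\alpha_m\) and \(\alpha_n\) take the form 0/0 at exactly V = -35 mV and V = -34 mV (a removable singularity) because the class evaluates these formulas directly, so the sketch uses a grid offset by half a millivolt that avoids those points.

```python
import numpy as np


def rates(V):
    """Rate functions (1/ms) copied from the HH class, V in mV."""
    a_m = -0.1 * (V + 35) / (np.exp(-0.1 * (V + 35)) - 1)
    b_m = 4 * np.exp(-(V + 60) / 18)
    a_h = 0.07 * np.exp(-(V + 58) / 20)
    b_h = 1 / (np.exp(-0.1 * (V + 28)) + 1)
    a_n = -0.01 * (V + 34) / (np.exp(-0.1 * (V + 34)) - 1)
    b_n = 0.125 * np.exp(-(V + 44) / 80)
    return a_m, b_m, a_h, b_h, a_n, b_n


phi = 5.0                           # temperature factor used by the HH class
V = np.arange(-90.5, 20.0, 1.0)     # half-integer grid avoids -35 and -34 mV
a_m, b_m, a_h, b_h, a_n, b_n = rates(V)
m_inf = a_m / (a_m + b_m)
h_inf = a_h / (a_h + b_h)
n_inf = a_n / (a_n + b_n)
tau_h = 1 / (phi * (a_h + b_h))     # ms
tau_n = 1 / (phi * (a_n + b_n))     # ms
```

As expected for activation and inactivation gates, m_inf and n_inf increase with V while h_inf decreases; phi shortens both gating time constants by the same factor.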
[2]:
class HH(bp.NeuGroup):
    def __init__(self, size, ENa=55., EK=-90., EL=-65, C=1.0, gNa=35.,
                 gK=9., gL=0.1, V_th=20., phi=5.0, method='exp_auto'):
        super(HH, self).__init__(size=size)
        # parameters
        self.ENa = ENa
        self.EK = EK
        self.EL = EL
        self.C = C
        self.gNa = gNa
        self.gK = gK
        self.gL = gL
        self.V_th = V_th
        self.phi = phi
        # variables
        self.V = bm.Variable(bm.ones(size) * -65.)
        self.h = bm.Variable(bm.ones(size) * 0.6)
        self.n = bm.Variable(bm.ones(size) * 0.32)
        self.spike = bm.Variable(bm.zeros(size, dtype=bool))
        self.input = bm.Variable(bm.zeros(size))
        self.t_last_spike = bm.Variable(bm.ones(size) * -1e7)
        # integral
        self.integral = bp.odeint(bp.JointEq([self.dV, self.dh, self.dn]), method=method)
    def dh(self, h, t, V):
        alpha = 0.07 * bm.exp(-(V + 58) / 20)
        beta = 1 / (bm.exp(-0.1 * (V + 28)) + 1)
        dhdt = alpha * (1 - h) - beta * h
        return self.phi * dhdt
    def dn(self, n, t, V):
        alpha = -0.01 * (V + 34) / (bm.exp(-0.1 * (V + 34)) - 1)
        beta = 0.125 * bm.exp(-(V + 44) / 80)
        dndt = alpha * (1 - n) - beta * n
        return self.phi * dndt
    def dV(self, V, t, h, n, Iext):
        m_alpha = -0.1 * (V + 35) / (bm.exp(-0.1 * (V + 35)) - 1)
        m_beta = 4 * bm.exp(-(V + 60) / 18)
        m = m_alpha / (m_alpha + m_beta)
        INa = self.gNa * m ** 3 * h * (V - self.ENa)
        IK = self.gK * n ** 4 * (V - self.EK)
        IL = self.gL * (V - self.EL)
        dVdt = (- INa - IK - IL + Iext) / self.C
        return dVdt
    def update(self, tdi):
        V, h, n = self.integral(self.V, self.h, self.n, tdi.t, self.input, tdi.dt)
        self.spike.value = bm.logical_and(self.V < self.V_th, V >= self.V_th)
        self.t_last_spike.value = bm.where(self.spike, tdi.t, self.t_last_spike)
        self.V.value = V
        self.h.value = h
        self.n.value = n
        self.input[:] = 0.
Let’s simulate a network of 100 neurons receiving a constant input current (1 \(\mu\)A/cm\(^2\)).
[3]:
num = 100
neu = HH(num)
neu.V[:] = -70. + bm.random.normal(size=num) * 20
syn = bp.synapses.GABAa(pre=neu, post=neu, conn=bp.connect.All2All(include_self=False))
syn.g_max = 0.1 / num
net = bp.Network(neu=neu, syn=syn)
runner = bp.DSRunner(net, monitors=['neu.spike', 'neu.V'], inputs=['neu.input', 1.])
runner.run(duration=500.)
[4]:
fig, gs = bp.visualize.get_figure(2, 1, 3, 8)
fig.add_subplot(gs[0, 0])
bp.visualize.line_plot(runner.mon.ts, runner.mon['neu.V'], ylabel='Membrane potential (N0)')
fig.add_subplot(gs[1, 0])
bp.visualize.raster_plot(runner.mon.ts, runner.mon['neu.spike'], show=True)

The simulation shows that cells starting from random, asynchronous initial conditions quickly synchronize: their spike times become in-phase within 5-6 oscillatory cycles.
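To put a number on this synchronization, Wang & Buzsáki (1996) used a pairwise coherence measure. As we read the paper: each spike train is binned with width \(\tau\), \(X_i(l) = 1\) if neuron \(i\) fired in bin \(l\) (else 0), \(\kappa_{ij}(\tau) = \sum_l X_i(l) X_j(l) / \sqrt{\sum_l X_i(l) \sum_l X_j(l)}\), and \(\kappa\) is the average over all pairs. The function below is a hypothetical NumPy sketch of that definition, not a BrainPy API; it expects a boolean raster of shape (time steps, neurons) such as np.asarray(runner.mon['neu.spike']) and needs at least two neurons.

```python
import numpy as np


def coherence_kappa(spikes, dt, bin_size):
    """Pair-averaged spike coherence kappa for a (n_steps, n_neurons) raster.

    dt and bin_size are in ms; trailing steps that do not fill a bin are dropped.
    """
    spikes = np.asarray(spikes, dtype=bool)
    n_steps, n_neurons = spikes.shape
    steps_per_bin = max(1, int(round(bin_size / dt)))
    n_bins = n_steps // steps_per_bin
    # X_i(l) = 1 if neuron i fired at least once in bin l
    binned = (spikes[:n_bins * steps_per_bin]
              .reshape(n_bins, steps_per_bin, n_neurons)
              .any(axis=1)
              .astype(float))
    counts = binned.sum(axis=0)               # sum_l X_i(l)
    overlap = binned.T @ binned               # sum_l X_i(l) X_j(l)
    norm = np.sqrt(np.outer(counts, counts))
    with np.errstate(divide='ignore', invalid='ignore'):
        kappa = np.where(norm > 0, overlap / norm, 0.0)
    pairs = np.triu_indices(n_neurons, k=1)   # distinct pairs i < j
    return float(kappa[pairs].mean())
```

Usage on the run above would look like coherence_kappa(runner.mon['neu.spike'], dt=0.05, bin_size=2.0), where the bin width is a free choice: κ equals 1 for perfectly synchronous firing and approaches 0 when spikes never share a bin.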
Reference:
Wang, Xiao-Jing, and György Buzsáki. “Gamma oscillation by synaptic inhibition in a hippocampal interneuronal network model.” Journal of Neuroscience 16.20 (1996): 6402-6413.
(Brunel & Hakim, 1999) Fast Global Oscillation
Implementation of the paper:
Brunel, Nicolas, and Vincent Hakim. “Fast global oscillations in networks of integrate-and-fire neurons with low firing rates.” Neural Computation 11.7 (1999): 1621-1671.
Author: Chaoming Wang
[1]:
import brainpy as bp
import brainpy.math as bm
bm.set_platform('cpu')
[2]:
Vr = 10. # mV
theta = 20. # mV
tau = 20. # ms
delta = 2. # ms
taurefr = 2. # ms
duration = 100. # ms
J = .1 # mV
muext = 25. # mV
sigmaext = 1. # mV
C = 1000
N = 5000
sparseness = float(C) / N
[3]:
class LIF(bp.NeuGroup):
    def __init__(self, size, **kwargs):
        super(LIF, self).__init__(size, **kwargs)
        # variables
        self.V = bm.Variable(bm.ones(self.num) * Vr)
        self.t_last_spike = bm.Variable(-1e7 * bm.ones(self.num))
        self.spike = bm.Variable(bm.zeros(self.num, dtype=bool))
        self.refractory = bm.Variable(bm.zeros(self.num, dtype=bool))
        # integration functions
        fv = lambda V, t: (-V + muext) / tau
        gv = lambda V, t: sigmaext / bm.sqrt(tau)
        self.int_v = bp.sdeint(f=fv, g=gv)
    def update(self, tdi):
        V = self.int_v(self.V, tdi.t, tdi.dt)
        in_ref = (tdi.t - self.t_last_spike) < taurefr
        V = bm.where(in_ref, self.V, V)
        spike = V >= theta
        self.spike.value = spike
        self.V.value = bm.where(spike, Vr, V)
        self.t_last_spike.value = bm.where(spike, tdi.t, self.t_last_spike)
        self.refractory.value = bm.logical_or(in_ref, spike)
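The drift fv and diffusion gv above describe \(dV = (\mu_{ext} - V)/\tau \, dt + \sigma_{ext}/\sqrt{\tau} \, dW\), i.e. \(\tau \, dV/dt = -V + \mu_{ext} + \sigma_{ext} \sqrt{\tau} \, \xi(t)\). Without threshold and reset this is an Ornstein-Uhlenbeck process with stationary mean \(\mu_{ext}\) and standard deviation \(\sigma_{ext}/\sqrt{2}\). The sketch below checks this with a plain Euler-Maruyama loop in NumPy; it is an illustration under our own assumptions (dt = 0.1 ms, threshold removed), not necessarily the scheme bp.sdeint uses.

```python
import numpy as np

# Values from the notebook's parameter cell.
tau, muext, sigmaext = 20.0, 25.0, 1.0   # ms, mV, mV
Vr = 10.0                                # mV, used as the initial condition
dt = 0.1                                 # ms, assumed time step

rng = np.random.default_rng(0)
n_trials, n_steps = 2000, 2000           # 2000 independent neurons, 200 ms
V = np.full(n_trials, Vr)
for _ in range(n_steps):
    drift = (-V + muext) / tau           # fv(V, t) in the LIF class
    diffusion = sigmaext / np.sqrt(tau)  # gv(V, t) in the LIF class
    V = V + drift * dt + diffusion * np.sqrt(dt) * rng.standard_normal(n_trials)
```

After 10 membrane time constants the samples have mean close to 25 mV and standard deviation close to 0.71 mV. Since \(\mu_{ext}\) = 25 mV lies above \(\theta\) = 20 mV, an isolated neuron would be driven above threshold on average, so the low firing rates of the network presumably depend on the recurrent inhibition (g_max = -J) delivered with delay \(\delta\).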
[4]:
group = LIF(N)
syn = bp.synapses.Delta(group, group,
conn=bp.conn.FixedProb(sparseness),
delay_step=int(delta / bm.get_dt()),
post_ref_key='refractory',
output=bp.synouts.CUBA(target_var='V'),
g_max=-J)
net = bp.Network(syn, group=group)
[5]:
runner = bp.DSRunner(net, monitors=['group.spike'])
runner.run(duration)
bp.visualize.raster_plot(runner.mon.ts, runner.mon['group.spike'],
xlim=(0, duration), show=True)

(Diesmann et al., 1999) Synfire Chains
Implementation of the paper:
Diesmann, Markus, Marc-Oliver Gewaltig, and Ad Aertsen. “Stable propagation of synchronous spiking in cortical neural networks.” Nature 402.6761 (1999): 529-533.
Author: Chaoming Wang
[1]:
import brainpy as bp
import brainpy.math as bm
bp.math.set_platform('cpu')
[2]:
duration = 100. # ms
# Neuron model parameters
Vr = -70. # mV
Vt = -55. # mV
tau_m = 10. # ms
tau_ref = 1. # ms
tau_psp = 0.325 # ms
weight = 4.86 # mV
noise = 39.24 # mV
# Neuron groups
n_groups = 10
group_size = 100
spike_sigma = 1.
# Synapse parameter
delay = 5.0 # ms
[3]:
# neuron model
# ------------
class Groups(bp.NeuGroup):
    def __init__(self, size, **kwargs):
        super(Groups, self).__init__(size, **kwargs)
        self.V = bm.Variable(Vr + bm.random.random(self.num) * (Vt - Vr))
        self.x = bm.Variable(bm.zeros(self.num))
        self.y = bm.Variable(bm.zeros(self.num))
        self.spike = bm.Variable(bm.zeros(self.num, dtype=bool))
        self.refractory = bm.Variable(bm.zeros(self.num, dtype=bool))
        self.t_last_spike = bm.Variable(bm.ones(self.num) * -1e7)
        # integral functions
        self.int_V = bp.odeint(lambda V, t, x: (-(V - Vr) + x) / tau_m)
        self.int_x = bp.odeint(lambda x, t, y: (-x + y) / tau_psp)
        self.int_y = bp.sdeint(f=lambda y, t: -y / tau_psp + 25.27,
                               g=lambda y, t: noise)
    def update(self, tdi):
        self.x[:] = self.int_x(self.x, tdi.t, self.y, tdi.dt)
        self.y[:] = self.int_y(self.y, tdi.t, tdi.dt)
        in_ref = (tdi.t - self.t_last_spike) < tau_ref
        V = self.int_V(self.V, tdi.t, self.x, tdi.dt)
        V = bm.where(in_ref, self.V, V)
        self.spike.value = V >= Vt
        self.t_last_spike.value = bm.where(self.spike, tdi.t, self.t_last_spike)
        self.V.value = bm.where(self.spike, Vr, V)
        self.refractory.value = bm.logical_or(in_ref, self.spike)
[4]:
# synaptic model
# ---------------
class SynBetweenGroups(bp.TwoEndConn):
    def __init__(self, group, ext_group, **kwargs):
        super(SynBetweenGroups, self).__init__(group, group, **kwargs)
        self.group = group
        self.ext = ext_group
        # variables
        self.delay_step = int(delay/bm.get_dt())
        self.g = bm.LengthDelay(bm.zeros(self.group.num), self.delay_step)
    def update(self, tdi):
        # synapse model between external and group 1
        g = bm.zeros(self.group.num)
        g[:group_size] = weight * self.ext.spike.sum()
        # feed-forward connection
        for i in range(1, n_groups):
            s1 = (i - 1) * group_size
            s2 = i * group_size
            s3 = (i + 1) * group_size
            g[s2: s3] = weight * self.group.spike[s1: s2].sum()
        # delay push
        self.g.update(g)
        # delay pull
        self.group.y += self.g(self.delay_step)
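In SynBetweenGroups.update, every neuron of the first group receives weight times the number of external spikes in the current step, and every neuron of group i (i ≥ 1) receives weight times the spike count of group i-1; this drive is then delayed by delay_step steps and added to y. The NumPy sketch below (hypothetical functions, not part of the original example) restates the loop and an equivalent loop-free version that makes the block structure of the chain explicit, and checks that they agree.

```python
import numpy as np

n_groups, group_size, weight = 10, 100, 4.86   # values from the parameter cell


def feedforward_loop(spikes, n_ext):
    """Re-statement of the loop in SynBetweenGroups.update (NumPy version)."""
    g = np.zeros(n_groups * group_size)
    g[:group_size] = weight * n_ext
    for i in range(1, n_groups):
        s1, s2, s3 = (i - 1) * group_size, i * group_size, (i + 1) * group_size
        g[s2:s3] = weight * spikes[s1:s2].sum()
    return g


def feedforward_vectorized(spikes, n_ext):
    """Group i receives weight * (spike count of group i - 1); group 0 the external count."""
    counts = spikes.reshape(n_groups, group_size).sum(axis=1)
    drive = weight * np.concatenate([[n_ext], counts[:-1]])
    return np.repeat(drive, group_size)
```

The spikes of the last group are never used, because the chain has no group after it.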
[5]:
# network running
# ---------------
def run_network(spike_num=48):
    bm.random.seed(123)
    times = bm.random.randn(spike_num) * spike_sigma + 20
    ext_group = bp.neurons.SpikeTimeGroup(spike_num, times=times.value, indices=bm.arange(spike_num).value)
    group = Groups(size=n_groups * group_size)
    syn_conn = SynBetweenGroups(group, ext_group)
    net = bp.Network(ext_group=ext_group, syn_conn=syn_conn, group=group)
    # simulation
    runner = bp.DSRunner(net,
                         monitors=['group.spike'],
                         dyn_vars=net.vars() + dict(rng=bm.random.DEFAULT))
    runner.run(duration)
    # visualization
    bp.visualize.raster_plot(runner.mon.ts, runner.mon['group.spike'],
                             xlim=(0, duration), show=True)
[6]:
run_network(spike_num=51)

With 44 external input spikes, the synchronous excitation disperses and eventually dies out.
[7]:
run_network(spike_num=44)

(Li et al., 2017): Unified Thalamus Oscillation Model
Implementation of the model:
Li, Guoshi, Craig S. Henriquez, and Flavio Fröhlich. “Unified thalamic model generates multiple distinct oscillations with state-dependent entrainment by stimulation.” PLoS Computational Biology 13.10 (2017): e1005797.
[1]:
from typing import Dict
import matplotlib.pyplot as plt
import numpy as np
import brainpy as bp
import brainpy.math as bm
from brainpy import channels, synapses, synouts, synplast
HTC neuron
[2]:
class HTC(bp.CondNeuGroup):
    def __init__(self, size, gKL=0.01, V_initializer=bp.init.OneInit(-65.), ):
        gL = 0.01 if size == 1 else bp.init.Uniform(0.0075, 0.0125)
        IL = channels.IL(size, g_max=gL, E=-70)
        IKL = channels.IKL(size, g_max=gKL)
        INa = channels.INa_Ba2002(size, V_sh=-30)
        IDR = channels.IKDR_Ba2002(size, V_sh=-30., phi=0.25)
        Ih = channels.Ih_HM1992(size, g_max=0.01, E=-43)
        ICaL = channels.ICaL_IS2008(size, g_max=0.5)
        IAHP = channels.IAHP_De1994(size, g_max=0.3, E=-90.)
        ICaN = channels.ICaN_IS2008(size, g_max=0.5)
        ICaT = channels.ICaT_HM1992(size, g_max=2.1)
        ICaHT = channels.ICaHT_HM1992(size, g_max=3.0)
        Ca = channels.CalciumDetailed(size, C_rest=5e-5, tau=10., d=0.5, ICaL=ICaL,
                                      IAHP=IAHP, ICaN=ICaN, ICaT=ICaT, ICaHT=ICaHT)
        super(HTC, self).__init__(size, A=2.9e-4, V_initializer=V_initializer, V_th=20.,
                                  IL=IL, IKL=IKL, INa=INa, IDR=IDR, Ih=Ih, Ca=Ca)
RTC neuron
[3]:
class RTC(bp.CondNeuGroup):
    def __init__(self, size, gKL=0.01, V_initializer=bp.init.OneInit(-65.), ):
        gL = 0.01 if size == 1 else bp.init.Uniform(0.0075, 0.0125)
        IL = channels.IL(size, g_max=gL, E=-70)
        IKL = channels.IKL(size, g_max=gKL)
        INa = channels.INa_Ba2002(size, V_sh=-40)
        IDR = channels.IKDR_Ba2002(size, V_sh=-40, phi=0.25)
        Ih = channels.Ih_HM1992(size, g_max=0.01, E=-43)
        ICaL = channels.ICaL_IS2008(size, g_max=0.3)
        IAHP = channels.IAHP_De1994(size, g_max=0.1, E=-90.)
        ICaN = channels.ICaN_IS2008(size, g_max=0.6)
        ICaT = channels.ICaT_HM1992(size, g_max=2.1)
        ICaHT = channels.ICaHT_HM1992(size, g_max=0.6)
        Ca = channels.CalciumDetailed(size, C_rest=5e-5, tau=10., d=0.5, ICaL=ICaL,
                                      IAHP=IAHP, ICaN=ICaN, ICaT=ICaT, ICaHT=ICaHT)
        super(RTC, self).__init__(size, A=2.9e-4, V_initializer=V_initializer, V_th=20.,
                                  IL=IL, IKL=IKL, INa=INa, IDR=IDR, Ih=Ih, Ca=Ca)
IN neuron
[4]:
class IN(bp.CondNeuGroup):
    def __init__(self, size, gKL=0.01, V_initializer=bp.init.OneInit(-70.), ):
        gL = 0.01 if size == 1 else bp.init.Uniform(0.0075, 0.0125)
        IL = channels.IL(size, g_max=gL, E=-60)
        IKL = channels.IKL(size, g_max=gKL)
        INa = channels.INa_Ba2002(size, V_sh=-30)
        IDR = channels.IKDR_Ba2002(size, V_sh=-30, phi=0.25)
        Ih = channels.Ih_HM1992(size, g_max=0.05, E=-43)
        IAHP = channels.IAHP_De1994(size, g_max=0.2, E=-90.)
        ICaN = channels.ICaN_IS2008(size, g_max=0.1)
        ICaHT = channels.ICaHT_HM1992(size, g_max=2.5)
        Ca = channels.CalciumDetailed(size, C_rest=5e-5, tau=10., d=0.5,
                                      IAHP=IAHP, ICaN=ICaN, ICaHT=ICaHT)
        super(IN, self).__init__(size, A=1.7e-4, V_initializer=V_initializer, V_th=20.,
                                 IL=IL, IKL=IKL, INa=INa, IDR=IDR, Ih=Ih, Ca=Ca)
TRN neuron
[5]:
class TRN(bp.CondNeuGroup):
    def __init__(self, size, gKL=0.01, V_initializer=bp.init.OneInit(-70.), ):
        gL = 0.01 if size == 1 else bp.init.Uniform(0.0075, 0.0125)
        IL = channels.IL(size, g_max=gL, E=-60)
        IKL = channels.IKL(size, g_max=gKL)
        INa = channels.INa_Ba2002(size, V_sh=-40)
        IDR = channels.IKDR_Ba2002(size, V_sh=-40)
        IAHP = channels.IAHP_De1994(size, g_max=0.2, E=-90.)
        ICaN = channels.ICaN_IS2008(size, g_max=0.2)
        ICaT = channels.ICaT_HP1992(size, g_max=1.3)
        Ca = channels.CalciumDetailed(size, C_rest=5e-5, tau=100., d=0.5,
                                      IAHP=IAHP, ICaN=ICaN, ICaT=ICaT)
        super(TRN, self).__init__(size, A=1.43e-4,
                                  V_initializer=V_initializer, V_th=20.,
                                  IL=IL, IKL=IKL, INa=INa, IDR=IDR, Ca=Ca)
Thalamus network
[6]:
class MgBlock(bp.SynOut):
    def __init__(self, E=0.):
        super(MgBlock, self).__init__()
        self.E = E
    def filter(self, g):
        V = self.master.post.V.value
        return g * (self.E - V) / (1 + bm.exp(-(V + 25) / 12.5))
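MgBlock turns a conductance g into a current by multiplying it with the driving force (E - V) and a sigmoidal voltage-dependent factor. In the network below, the NMDA-type connections (HTC2IN_nmda, HTC2RE_nmda, RTC2RE_nmda) are synapses.AMPA instances with a slow beta and output=MgBlock(). A short NumPy evaluation of that factor alone (function name hypothetical) shows how strongly it gates the current at hyperpolarized potentials:

```python
import numpy as np


def mg_block(V):
    """Voltage-dependent factor used in MgBlock.filter, V in mV."""
    return 1.0 / (1.0 + np.exp(-(V + 25.0) / 12.5))


V = np.array([-80.0, -60.0, -25.0, 0.0])
B = mg_block(V)  # roughly [0.012, 0.057, 0.5, 0.88]
```

The factor is exactly 0.5 at -25 mV, so these synapses pass only about 1% of their conductance near -80 mV and become effective mainly when the postsynaptic cell is already depolarized.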
[7]:
class Thalamus(bp.Network):
    def __init__(
        self,
        g_input: Dict[str, float],
        g_KL: Dict[str, float],
        HTC_V_init=bp.init.OneInit(-65.),
        RTC_V_init=bp.init.OneInit(-65.),
        IN_V_init=bp.init.OneInit(-70.),
        RE_V_init=bp.init.OneInit(-70.),
    ):
        super(Thalamus, self).__init__()
        # populations (each one uses its own voltage initializer)
        self.HTC = HTC(size=(7, 7), gKL=g_KL['TC'], V_initializer=HTC_V_init)
        self.RTC = RTC(size=(12, 12), gKL=g_KL['TC'], V_initializer=RTC_V_init)
        self.RE = TRN(size=(10, 10), gKL=g_KL['RE'], V_initializer=RE_V_init)
        self.IN = IN(size=(8, 8), gKL=g_KL['IN'], V_initializer=IN_V_init)
        # noises
        self.poisson_HTC = bp.neurons.PoissonGroup(self.HTC.size, freqs=100)
        self.poisson_RTC = bp.neurons.PoissonGroup(self.RTC.size, freqs=100)
        self.poisson_IN = bp.neurons.PoissonGroup(self.IN.size, freqs=100)
        self.poisson_RE = bp.neurons.PoissonGroup(self.RE.size, freqs=100)
        self.noise2HTC = synapses.Exponential(self.poisson_HTC, self.HTC, bp.conn.One2One(),
            output=synouts.COBA(E=0.), tau=5.,
            g_max=g_input['TC'])
        self.noise2RTC = synapses.Exponential(self.poisson_RTC, self.RTC, bp.conn.One2One(),
            output=synouts.COBA(E=0.), tau=5.,
            g_max=g_input['TC'])
        self.noise2IN = synapses.Exponential(self.poisson_IN, self.IN, bp.conn.One2One(),
            output=synouts.COBA(E=0.), tau=5.,
            g_max=g_input['IN'])
        self.noise2RE = synapses.Exponential(self.poisson_RE, self.RE, bp.conn.One2One(),
            output=synouts.COBA(E=0.), tau=5.,
            g_max=g_input['RE'])
        # HTC cells were connected with gap junctions
        self.gj_HTC = synapses.GapJunction(self.HTC, self.HTC,
            bp.conn.ProbDist(dist=2., prob=0.3, ),
            comp_method='sparse',
            g_max=1e-2)
        # HTC provides feedforward excitation to INs
        self.HTC2IN_ampa = synapses.AMPA(self.HTC, self.IN, bp.conn.FixedProb(0.3),
            delay_step=int(2 / bm.get_dt()),
            stp=synplast.STD(tau=700, U=0.07),
            alpha=0.94,
            beta=0.18,
            g_max=6e-3)
        # NMDA-type synapse: AMPA-style kinetics with a slow beta and the MgBlock output
        self.HTC2IN_nmda = synapses.AMPA(self.HTC, self.IN, bp.conn.FixedProb(0.3),
            delay_step=int(2 / bm.get_dt()),
            stp=synplast.STD(tau=700, U=0.07),
            output=MgBlock(),
            alpha=1.,
            beta=0.0067,
            g_max=3e-3)
        # INs delivered feedforward inhibition to RTC cells
        self.IN2RTC = synapses.GABAa(self.IN, self.RTC, bp.conn.FixedProb(0.3),
            delay_step=int(2 / bm.get_dt()),
            stp=synplast.STD(tau=700, U=0.07),
            output=synouts.COBA(E=-80),
            alpha=10.5,
            beta=0.166,
            g_max=3e-3)
        # 20% RTC cells electrically connected with HTC cells
        self.gj_RTC2HTC = synapses.GapJunction(self.RTC, self.HTC,
            bp.conn.ProbDist(dist=2., prob=0.3, pre_ratio=0.2),
            comp_method='sparse',
            g_max=1 / 300)
        # Both HTC and RTC cells sent glutamatergic synapses to RE neurons, while
        # receiving GABAergic feedback inhibition from the RE population
        self.HTC2RE_ampa = synapses.AMPA(self.HTC, self.RE, bp.conn.FixedProb(0.2),
            delay_step=int(2 / bm.get_dt()),
            stp=synplast.STD(tau=700, U=0.07),
            alpha=0.94,
            beta=0.18,
            g_max=4e-3)
        self.RTC2RE_ampa = synapses.AMPA(self.RTC, self.RE, bp.conn.FixedProb(0.2),
            delay_step=int(2 / bm.get_dt()),
            stp=synplast.STD(tau=700, U=0.07),
            alpha=0.94,
            beta=0.18,
            g_max=4e-3)
        self.HTC2RE_nmda = synapses.AMPA(self.HTC, self.RE, bp.conn.FixedProb(0.2),
            delay_step=int(2 / bm.get_dt()),
            stp=synplast.STD(tau=700, U=0.07),
            output=MgBlock(),
            alpha=1.,
            beta=0.0067,
            g_max=2e-3)
        self.RTC2RE_nmda = synapses.AMPA(self.RTC, self.RE, bp.conn.FixedProb(0.2),
            delay_step=int(2 / bm.get_dt()),
            stp=synplast.STD(tau=700, U=0.07),
            output=MgBlock(),
            alpha=1.,
            beta=0.0067,
            g_max=2e-3)
        self.RE2HTC = synapses.GABAa(self.RE, self.HTC, bp.conn.FixedProb(0.2),
            delay_step=int(2 / bm.get_dt()),
            stp=synplast.STD(tau=700, U=0.07),
            output=synouts.COBA(E=-80),
            alpha=10.5,
            beta=0.166,
            g_max=3e-3)
        self.RE2RTC = synapses.GABAa(self.RE, self.RTC, bp.conn.FixedProb(0.2),
            delay_step=int(2 / bm.get_dt()),
            stp=synplast.STD(tau=700, U=0.07),
            output=synouts.COBA(E=-80),
            alpha=10.5,
            beta=0.166,
            g_max=3e-3)
        # RE neurons were connected with both gap junctions and GABAergic synapses
        self.gj_RE = synapses.GapJunction(self.RE, self.RE,
            bp.conn.ProbDist(dist=2., prob=0.3, pre_ratio=0.2),
            comp_method='sparse',
            g_max=1 / 300)
        self.RE2RE = synapses.GABAa(self.RE, self.RE, bp.conn.FixedProb(0.2),
            delay_step=int(2 / bm.get_dt()),
            stp=synplast.STD(tau=700, U=0.07),
            output=synouts.COBA(E=-70),
            alpha=10.5, beta=0.166,
            g_max=1e-3)
        # 10% RE neurons project GABAergic synapses to local interneurons
        # probability (0.05) was used for the RE->IN synapses according to experimental data
        self.RE2IN = synapses.GABAa(self.RE, self.IN, bp.conn.FixedProb(0.05, pre_ratio=0.1),
            delay_step=int(2 / bm.get_dt()),
            stp=synplast.STD(tau=700, U=0.07),
            output=synouts.COBA(E=-80),
            alpha=10.5, beta=0.166,
            g_max=1e-3, )
Simulation
[8]:
states = {
'delta': dict(g_input={'IN': 1e-4, 'RE': 1e-4, 'TC': 1e-4},
g_KL={'TC': 0.035, 'RE': 0.03, 'IN': 0.01}),
'spindle': dict(g_input={'IN': 3e-4, 'RE': 3e-4, 'TC': 3e-4},
g_KL={'TC': 0.01, 'RE': 0.02, 'IN': 0.015}),
'alpha': dict(g_input={'IN': 1.5e-3, 'RE': 1.5e-3, 'TC': 1.5e-3},
g_KL={'TC': 0., 'RE': 0.01, 'IN': 0.02}),
'gamma': dict(g_input={'IN': 1.5e-3, 'RE': 1.5e-3, 'TC': 1.7e-2},
g_KL={'TC': 0., 'RE': 0.01, 'IN': 0.02}),
}
[9]:
def rhythm_const_input(amp, freq, length, duration, t_start=0., t_end=None, dt=None):
    if t_end is None: t_end = duration
    if length > duration:
        raise ValueError(f'Expected length <= duration, while we got {length} > {duration}')
    sec_length = 1e3 / freq
    values, durations = [0.], [t_start]
    for t in np.arange(t_start, t_end, sec_length):
        values.append(amp)
        if t + length <= t_end:
            durations.append(length)
            values.append(0.)
            if t + sec_length <= t_end:
                durations.append(sec_length - length)
            else:
                durations.append(t_end - t - length)
        else:
            durations.append(t_end - t)
    values.append(0.)
    durations.append(duration - t_end)
    return bp.inputs.section_input(values=values, durations=durations, dt=dt, )
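rhythm_const_input builds a pulse train from (value, duration) pairs: zero until t_start, then one pulse of length ms at amplitude amp every 1000 / freq ms, truncated at t_end, then zero until duration. The same schedule can be written as a dense array with integer step arithmetic. The sketch below is a hypothetical NumPy cross-check, assuming all times are integer multiples of dt; we expect it to match the output of rhythm_const_input up to rounding at segment boundaries, but it does not call BrainPy.

```python
import numpy as np


def pulse_train(amp, freq, length, duration, t_start=0., t_end=None, dt=0.1):
    """Hypothetical dense version of the rhythm_const_input schedule.

    Times in ms, freq in Hz; t_start, t_end, duration, length and
    1e3 / freq are assumed to be integer multiples of dt.
    """
    if t_end is None:
        t_end = duration
    n_steps = int(round(duration / dt))
    k = np.arange(n_steps)                     # integer step index
    start, end = int(round(t_start / dt)), int(round(t_end / dt))
    period = int(round(1e3 / freq / dt))       # steps between pulse onsets
    width = int(round(length / dt))            # steps per pulse
    on = (k >= start) & (k < end) & ((k - start) % period < width)
    return np.where(on, amp, 0.0)


# The stimulus used in try_network: 4 Hz, 10 ms pulses between 1 s and 2 s.
I = pulse_train(2e-4, freq=4., length=10., duration=3e3, t_start=1e3, t_end=2e3)
```

For the try_network settings this gives four pulses (at 1000, 1250, 1500 and 1750 ms), each 100 steps long at dt = 0.1 ms.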
[10]:
def try_trn_neuron():
    trn = TRN(1)
    I, length = bp.inputs.section_input(values=[0, -0.05, 0],
                                        durations=[100, 100, 500],
                                        return_length=True,
                                        dt=0.01)
    runner = bp.DSRunner(trn,
                         monitors=['V'],
                         inputs=['input', I, 'iter'],
                         dt=0.01)
    runner.run(length)
    bp.visualize.line_plot(runner.mon.ts, runner.mon.V, show=True)
[11]:
try_trn_neuron()

[12]:
def try_network(state='delta'):
    duration = 3e3
    net = Thalamus(
        IN_V_init=bp.init.OneInit(-70.),
        RE_V_init=bp.init.OneInit(-70.),
        HTC_V_init=bp.init.OneInit(-80.),
        RTC_V_init=bp.init.OneInit(-80.),
        **states[state],
    )
    net.reset()
    currents = rhythm_const_input(2e-4, freq=4., length=10.,
                                  duration=duration,
                                  t_end=2e3, t_start=1e3)
    plt.plot(currents)
    plt.title('Current')
    plt.show()
    runner = bp.DSRunner(
        net,
        monitors=['HTC.spike', 'RTC.spike', 'RE.spike', 'IN.spike',
                  'HTC.V', 'RTC.V', 'RE.V', 'IN.V', ],
        inputs=[('HTC.input', currents, 'iter'),
                ('RTC.input', currents, 'iter'),
                ('IN.input', currents, 'iter')],
    )
    runner.run(duration)
    fig, gs = bp.visualize.get_figure(4, 2, 2, 5)
    fig.add_subplot(gs[0, 0])
    bp.visualize.line_plot(runner.mon.ts, runner.mon.get('HTC.V'), ylabel='HTC', xlim=(0, duration))
    fig.add_subplot(gs[1, 0])
    bp.visualize.line_plot(runner.mon.ts, runner.mon.get('RTC.V'), ylabel='RTC', xlim=(0, duration))
    fig.add_subplot(gs[2, 0])
    bp.visualize.line_plot(runner.mon.ts, runner.mon.get('IN.V'), ylabel='IN', xlim=(0, duration))
    fig.add_subplot(gs[3, 0])
    bp.visualize.line_plot(runner.mon.ts, runner.mon.get('RE.V'), ylabel='RE', xlim=(0, duration))
    fig.add_subplot(gs[0, 1])
    bp.visualize.raster_plot(runner.mon.ts, runner.mon.get('HTC.spike'), xlim=(0, duration))
    fig.add_subplot(gs[1, 1])
    bp.visualize.raster_plot(runner.mon.ts, runner.mon.get('RTC.spike'), xlim=(0, duration))
    fig.add_subplot(gs[2, 1])
    bp.visualize.raster_plot(runner.mon.ts, runner.mon.get('IN.spike'), xlim=(0, duration))
    fig.add_subplot(gs[3, 1])
    bp.visualize.raster_plot(runner.mon.ts, runner.mon.get('RE.spike'), xlim=(0, duration))
    plt.show()
[13]:
try_network()


(Susin & Destexhe, 2021): Asynchronous Network
Implementation of the paper:
Susin, Eduarda, and Alain Destexhe. “Integration, coincidence detection and resonance in networks of spiking neurons expressing gamma oscillations and asynchronous states.” PLoS Computational Biology 17.9 (2021): e1009416.
[1]:
import numpy as np
import matplotlib.pyplot as plt
from scipy.signal import kaiserord, lfilter, firwin, hilbert
import brainpy as bp
import brainpy.math as bm
[2]:
# Table 1: specific neuron model parameters
RS_par = dict(Vth=-40, delta=2., tau_ref=5., tau_w=500, a=4, b=20, C=150, gL=10, EL=-65, V_reset=-65,
              E_e=0., E_i=-80.)
FS_par = dict(Vth=-47.5, delta=0.5, tau_ref=5., tau_w=500, a=0, b=0, C=150, gL=10, EL=-65, V_reset=-65,
              E_e=0., E_i=-80.)
Ch_par = dict(Vth=-47.5, delta=0.5, tau_ref=1., tau_w=50, a=80, b=150, C=150, gL=10, EL=-58, V_reset=-65,
              E_e=0., E_i=-80.)
[3]:
class AdEx(bp.NeuGroup):
    def __init__(
        self,
        size,
        # neuronal parameters
        Vth=-40, delta=2., tau_ref=5., tau_w=500, a=4, b=20, C=150,
        gL=10, EL=-65, V_reset=-65, V_sp_th=-30.,
        # synaptic parameters
        tau_e=1.5, tau_i=7.5, E_e=0., E_i=-80.,
        # other parameters
        name=None, method='exp_euler',
        V_initializer=bp.init.Uniform(-65, -50),
        w_initializer=bp.init.Constant(0.),
    ):
        super(AdEx, self).__init__(size=size, name=name)
        # neuronal parameters
        self.Vth = Vth
        self.delta = delta
        self.tau_ref = tau_ref
        self.tau_w = tau_w
        self.a = a
        self.b = b
        self.C = C
        self.gL = gL
        self.EL = EL
        self.V_reset = V_reset
        self.V_sp_th = V_sp_th
        # synaptic parameters
        self.tau_e = tau_e
        self.tau_i = tau_i
        self.E_e = E_e
        self.E_i = E_i
        # neuronal variables
        self.V = bp.init.variable_(V_initializer, self.num)
        self.w = bp.init.variable_(w_initializer, self.num)
        self.spike = bm.Variable(self.num, dtype=bool)
        self.refractory = bm.Variable(self.num, dtype=bool)
        self.t_last_spike = bm.Variable(bm.ones(self.num) * -1e8)
        # synaptic variables (excitatory and inhibitory conductances)
        self.ge = bm.Variable(self.num)
        self.gi = bm.Variable(self.num)
        # integrator for the joint (V, w, ge, gi) system
        self.integral = bp.odeint(bp.JointEq(self.dV, self.dw, self.dge, self.dgi), method=method)
    def dge(self, ge, t):
        return -ge / self.tau_e
    def dgi(self, gi, t):
        return -gi / self.tau_i
    def dV(self, V, t, w, ge, gi, Iext=None):
        I = ge * (self.E_e - V) + gi * (self.E_i - V)
        if Iext is not None: I += Iext
        dVdt = (self.gL * self.delta * bm.exp((V - self.Vth) / self.delta)
                - w + self.gL * (self.EL - V) + I) / self.C
        return dVdt
    def dw(self, w, t, V):
        dwdt = (self.a * (V - self.EL) - w) / self.tau_w
        return dwdt
    def update(self, tdi, x=None):
        V, w, ge, gi = self.integral(self.V.value, self.w.value, self.ge.value, self.gi.value,
                                     tdi.t, Iext=x, dt=tdi.dt)
        # hold V during the refractory period (w, ge and gi keep evolving)
        refractory = (tdi.t - self.t_last_spike) <= self.tau_ref
        V = bm.where(refractory, self.V.value, V)
        # spike detection, reset of V and spike-triggered adaptation increment b
        spike = V >= self.V_sp_th
        self.V.value = bm.where(spike, self.V_reset, V)
        self.w.value = bm.where(spike, w + self.b, w)
        self.ge.value = ge
        self.gi.value = gi
        self.spike.value = spike
        self.refractory.value = bm.logical_or(refractory, spike)
        self.t_last_spike.value = bm.where(spike, tdi.t, self.t_last_spike)
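To make the equations in `dV` and `dw` concrete, here is a minimal single-neuron sketch in plain NumPy. It assumes forward-Euler integration with dt = 0.1 ms (the class above uses BrainPy's `exp_euler` instead), a constant injected current `I_ext` in pA, and no synaptic conductances; `simulate_adex` is a hypothetical helper, not part of BrainPy. The parameters are the RS values from Table 1 with the `Vth=-50`, `V_sp_th=-40` overrides applied in the network classes, and the refractory handling mirrors `update`: V is held while w keeps integrating.

```python
import numpy as np

# RS parameters from Table 1 (pF, nS, mV, ms, pA) with the network overrides Vth=-50, V_sp_th=-40
C, gL, EL, Vth, delta = 150., 10., -65., -50., 2.
a, b, tau_w = 4., 20., 500.
V_reset, V_sp_th, tau_ref = -65., -40., 5.


def simulate_adex(I_ext, duration=500., dt=0.1):
    """Forward-Euler sketch of one AdEx neuron driven by a constant current I_ext (pA)."""
    V, w, t_last = EL, 0., -1e8
    spike_times, w_trace = [], []
    for step in range(int(duration / dt)):
        t = step * dt
        dVdt = (gL * delta * np.exp((V - Vth) / delta) - w + gL * (EL - V) + I_ext) / C
        dwdt = (a * (V - EL) - w) / tau_w
        V_new, w = V + dt * dVdt, w + dt * dwdt
        # as in AdEx.update: hold V during the refractory period, keep integrating w
        V = V if t - t_last <= tau_ref else V_new
        if V >= V_sp_th:  # spike: reset V and increment the adaptation current by b
            V, w, t_last = V_reset, w + b, t
            spike_times.append(t)
        w_trace.append(w)
    return np.array(spike_times), np.array(w_trace)


spikes_on, w_on = simulate_adex(500.)
spikes_off, _ = simulate_adex(0.)
isi = np.diff(spikes_on)
print(len(spikes_on), isi[0], isi[-1])  # inter-spike intervals lengthen as w builds up
print(len(spikes_off))                  # no spikes without input
```

With 500 pA the neuron fires repeatedly and its inter-spike intervals grow as each spike adds `b` to `w` (spike-frequency adaptation); without input it settles just above EL.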
[4]:
class AINet(bp.Network):
    def __init__(self, ext_varied_rates, ext_weight=1., method='exp_euler', dt=bm.get_dt()):
        super(AINet, self).__init__()
        self.num_exc = 20000
        self.num_inh = 5000
        self.exc_syn_tau = 5.  # ms
        self.inh_syn_tau = 5.  # ms
        self.exc_syn_weight = 1.  # nS
        self.inh_syn_weight = 5.  # nS
        self.num_delay_step = int(1.5 / dt)  # 1.5 ms synaptic delay, in time steps
        self.ext_varied_rates = ext_varied_rates
        # neuronal populations
        RS_par_ = RS_par.copy()
        FS_par_ = FS_par.copy()
        RS_par_.update(Vth=-50, V_sp_th=-40)
        FS_par_.update(Vth=-50, V_sp_th=-40)
        self.rs_pop = AdEx(self.num_exc, tau_e=self.exc_syn_tau, tau_i=self.inh_syn_tau, method=method, **RS_par_)
        self.fs_pop = AdEx(self.num_inh, tau_e=self.exc_syn_tau, tau_i=self.inh_syn_tau, method=method, **FS_par_)
        # freqs is a 1-element Variable shared by all Poisson units; change_freq overwrites it every step
        self.ext_pop = bp.neurons.PoissonGroup(self.num_exc, freqs=bm.Variable(1))
        # Poisson inputs
        self.ext_to_FS = bp.synapses.Delta(
            self.ext_pop, self.fs_pop, bp.conn.FixedProb(0.02),
            output=bp.synouts.CUBA(target_var='ge'),
            g_max=ext_weight)
        self.ext_to_RS = bp.synapses.Delta(
            self.ext_pop, self.rs_pop, bp.conn.FixedProb(0.02),
            output=bp.synouts.CUBA(target_var='ge'),
            g_max=ext_weight)
        # synaptic projections
        self.RS_to_FS = bp.synapses.Delta(
            self.rs_pop, self.fs_pop, bp.conn.FixedProb(0.02),
            output=bp.synouts.CUBA(target_var='ge'),
            g_max=self.exc_syn_weight,
            delay_step=self.num_delay_step)
        self.RS_to_RS = bp.synapses.Delta(
            self.rs_pop, self.rs_pop, bp.conn.FixedProb(0.02),
            output=bp.synouts.CUBA(target_var='ge'),
            g_max=self.exc_syn_weight,
            delay_step=self.num_delay_step)
        self.FS_to_RS = bp.synapses.Delta(
            self.fs_pop, self.rs_pop, bp.conn.FixedProb(0.02),
            output=bp.synouts.CUBA(target_var='gi'),
            g_max=self.inh_syn_weight,
            delay_step=self.num_delay_step)
        self.FS_to_FS = bp.synapses.Delta(
            self.fs_pop, self.fs_pop, bp.conn.FixedProb(0.02),
            output=bp.synouts.CUBA(target_var='gi'),
            g_max=self.inh_syn_weight,
            delay_step=self.num_delay_step)
    def change_freq(self, tdi):
        # called at every step (inputs=net.change_freq): set the current external rate in Hz
        self.ext_pop.freqs[0] = self.ext_varied_rates[tdi.i]
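Every projection in `AINet` uses `bp.conn.FixedProb(0.02)`. Assuming this connector links each presynaptic/postsynaptic pair independently with probability p, the number of inputs per target neuron is binomial with mean p·N_pre: about 400 from the 20000-neuron RS and Poisson populations and about 100 from the 5000-neuron FS population. The NumPy sketch below samples such a Bernoulli mask for a small set of target neurons to check this; it does not call BrainPy, and the 2 Hz rate is simply the value passed to `get_inputs` in `simulate_ai_net`.

```python
import numpy as np

rng = np.random.default_rng(seed=0)
p = 0.02
num_post = 100  # only a few target neurons, to keep the mask small


def sample_in_degree(num_pre):
    """Presynaptic partners per target neuron under independent Bernoulli(p) wiring."""
    mask = rng.random((num_pre, num_post)) < p
    return mask.sum(axis=0)


deg_rs = sample_in_degree(20000)  # RS -> target (the Poisson group also has 20000 units)
deg_fs = sample_in_degree(5000)   # FS -> target
print(deg_rs.mean(), p * 20000)   # empirical mean vs. expected 400
print(deg_fs.mean(), p * 5000)    # empirical mean vs. expected 100

# Each neuron then receives external Poisson events at roughly p * N_ext * rate
rate_hz = 2.0
print(p * 20000 * rate_hz, "external input events per second per neuron")
```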
[5]:
def get_inputs(c_low, c_high, t_transition, t_min_plato, t_max_plato, t_gap, t_total, dt=None):
    # Piecewise-linear rate schedule, repeated until t_total: a baseline at c_low for t_gap,
    # a ramp up over t_transition, a plateau at c_high lasting a random duration drawn
    # from [t_min_plato, t_max_plato], then a ramp down over t_transition.
    dt = bm.get_dt() if dt is None else dt
    t = 0
    num_gap = int(t_gap / dt)
    num_total = int(t_total / dt)
    num_transition = int(t_transition / dt)
    inputs = []
    ramp_up = np.linspace(c_low, c_high, num_transition)
    ramp_down = np.linspace(c_high, c_low, num_transition)
    plato_base = np.ones(num_gap) * c_low
    while t < num_total:
        # draw a scalar (no size=1) so that int() is not applied to a 1-element array
        num_plato = int(np.random.uniform(low=t_min_plato, high=t_max_plato) / dt)
        inputs.extend([plato_base, ramp_up, np.ones(num_plato) * c_high, ramp_down])
        t += (num_gap + num_transition + num_plato + num_transition)
    return bm.asarray(np.concatenate(inputs)[:num_total])
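The schedule built by `get_inputs` repeats four segments: a baseline at `c_low`, a linear ramp up, a plateau at `c_high`, and a ramp down. The sketch below is a hypothetical deterministic variant (`rate_schedule` is not in the original code) that fixes the plateau length instead of drawing it at random and takes every duration directly in time steps, so the segment layout is easy to check. Note that `simulate_ai_net` calls `get_inputs` with `c_low = c_high = 2.`, which yields a constant 2 Hz drive.

```python
import numpy as np


def rate_schedule(c_low, c_high, n_transition, n_plateau, n_gap, n_total):
    """Hypothetical fixed-plateau variant of get_inputs; all lengths are in time steps."""
    cycle = np.concatenate([
        np.full(n_gap, c_low),                     # baseline
        np.linspace(c_low, c_high, n_transition),  # ramp up
        np.full(n_plateau, c_high),                # plateau
        np.linspace(c_high, c_low, n_transition),  # ramp down
    ])
    n_cycles = -(-n_total // len(cycle))           # ceiling division
    return np.tile(cycle, n_cycles)[:n_total]


rates = rate_schedule(1., 2., n_transition=5, n_plateau=10, n_gap=20, n_total=100)
print(rates[:20].max(), rates[25:35].min(), len(rates))  # 1.0 2.0 100
```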
[6]:
def signal_phase_by_Hilbert(signal, signal_time, low_cut, high_cut, sampling_space):
    # sampling_space: sampling interval, in seconds (no units)
    # signal_time: in seconds (no units)
    # low_cut: lower edge of the band to filter, in Hz (no units)
    # high_cut: upper edge of the band to filter, in Hz (no units)
    signal = signal - np.mean(signal)
    width = 5.0  # The desired width in Hz of the transition from pass to stop
    ripple_db = 60.0  # The desired attenuation in the stop band, in dB.
    sampling_rate = 1. / sampling_space
    Nyquist = sampling_rate / 2.
    num_taps, beta = kaiserord(ripple_db, width / Nyquist)
    if num_taps % 2 == 0:
        num_taps = num_taps + 1  # Numtaps must be odd
    # cutoffs are normalized by the Nyquist frequency, hence fs=2.0 (replaces the deprecated nyq=1.0)
    taps = firwin(num_taps, [low_cut / Nyquist, high_cut / Nyquist], window=('kaiser', beta), fs=2.0,
                  pass_zero=False, scale=True)
    filtered_signal = lfilter(taps, 1.0, signal)
    delay = 0.5 * (num_taps - 1) / sampling_rate  # group delay of the linear-phase filter, corrected to get zero phase
    delay_index = int(np.floor(delay * sampling_rate))
    filtered_signal = filtered_signal[num_taps - 1:]  # taking out the "corrupted" signal
    # correcting the delay and taking out the "corrupted" signal part
    filtered_time = signal_time[num_taps - 1:] - delay
    cutted_signal = signal[(num_taps - 1 - delay_index): (len(signal) - (num_taps - 1 - delay_index))]
    # --------------------------------------------------------------------------
    # The Hilbert transform is very slow when the signal has odd length.
    # If the length is odd, append one sample to every vector related to the
    # filtered signal: a zero for the signals and the next time point for the time axis.
    if len(filtered_signal) % 2 != 0:  # If the length is odd
        filtered_signal = np.append(filtered_signal, 0.)
        filtered_time = np.append(filtered_time, filtered_time[0] + len(filtered_time) * sampling_space)
        cutted_signal = np.append(cutted_signal, 0.)
    # --------------------------------------------------------------------------
    ht_filtered_signal = hilbert(filtered_signal)
    envelope = np.abs(ht_filtered_signal)
    phase = np.angle(ht_filtered_signal)  # The phase is between -pi and pi in radians
    return filtered_time, filtered_signal, cutted_signal, envelope, phase
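`signal_phase_by_Hilbert` takes the envelope and phase of the band-passed LFP as the modulus and angle of the analytic signal returned by `scipy.signal.hilbert`. The NumPy sketch below rebuilds that analytic signal with the FFT (keep the DC and Nyquist bins, double the positive frequencies, zero the negative ones, which is the construction SciPy documents for `hilbert`) and applies it to a synthetic 40 Hz carrier with a slow 2 Hz amplitude modulation. The test signal and the helper name `analytic_signal` are illustrative assumptions; because the 1 s window holds a whole number of periods of every component, the envelope should match the modulation up to floating-point error.

```python
import numpy as np


def analytic_signal(x):
    """FFT-based analytic signal: keep DC (and Nyquist), double positive, zero negative frequencies."""
    n = len(x)
    h = np.zeros(n)
    h[0] = 1.0
    if n % 2 == 0:
        h[n // 2] = 1.0
        h[1:n // 2] = 2.0
    else:
        h[1:(n + 1) // 2] = 2.0
    return np.fft.ifft(np.fft.fft(x) * h)


fs = 10_000.0                                 # 0.1 ms sampling, e.g. a dt of 0.1 ms
t = np.arange(10_000) / fs                    # 1 s of samples
amp = 1.0 + 0.5 * np.sin(2 * np.pi * 2 * t)   # slow 2 Hz amplitude modulation
x = amp * np.cos(2 * np.pi * 40 * t)          # 40 Hz carrier, inside the 30-50 Hz band
z = analytic_signal(x)
envelope, phase = np.abs(z), np.angle(z)      # what the function above returns as envelope / phase
print(np.max(np.abs(envelope - amp)))         # tiny: the envelope recovers the modulation
```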
[7]:
def visualize_simulation_results(times, spikes, example_potentials, varied_rates,
                                 xlim=None, t_lfp_start=None, t_lfp_end=None, filename=None):
    fig, gs = bp.visualize.get_figure(7, 1, 1, 12)
    # 1. input firing rate
    ax = fig.add_subplot(gs[0])
    plt.plot(times, varied_rates)
    if xlim is None:
        xlim = (0, times[-1])
    ax.set_xlim(*xlim)
    ax.set_xticks([])
    ax.set_ylabel('External\nRate (Hz)')
    # 2. raster plot of all populations, each in its own band of neuron indices
    ax = fig.add_subplot(gs[1: 3])
    i = 0
    y_ticks = ([], [])
    for key, (sp_matrix, sp_type) in spikes.items():
        iis, sps = np.where(sp_matrix)
        tts = times[iis]
        plt.plot(tts, sps + i, '.', markersize=1, label=key)
        y_ticks[0].append(i + sp_matrix.shape[1] / 2)
        y_ticks[1].append(key)
        i += sp_matrix.shape[1]
    ax.set_xlim(*xlim)
    ax.set_xlabel('')
    ax.set_ylabel('Neuron Index')
    ax.set_xticks([])
    ax.set_yticks(*y_ticks)
    # ax.legend()
    # 3. example membrane potential of neuron 0 (spike steps drawn at 0 mV)
    ax = fig.add_subplot(gs[3: 5])
    for key, potential in example_potentials.items():
        vs = np.where(spikes[key][0][:, 0], 0, potential)
        plt.plot(times, vs, label=key)
    ax.set_xlim(*xlim)
    ax.set_xticks([])
    ax.set_ylabel('V (mV)')
    ax.legend()
    # 4. LFP
    ax = fig.add_subplot(gs[5:7])
    ax.set_xlim(*xlim)
    t1 = int(t_lfp_start / bm.get_dt()) if t_lfp_start is not None else 0
    t2 = int(t_lfp_end / bm.get_dt()) if t_lfp_end is not None else len(times)
    times = times[t1: t2]
    lfp = 0
    for sp_matrix, sp_type in spikes.values():
        lfp += bp.measure.unitary_LFP(times, sp_matrix[t1: t2], sp_type)
    phase_ts, filtered, cutted, envelope, _ = signal_phase_by_Hilbert(
        bm.as_numpy(lfp), times * 1e-3, 30, 50, bm.get_dt() * 1e-3)
    plt.plot(phase_ts * 1e3, cutted, color='k', label='Raw LFP')
    plt.plot(phase_ts * 1e3, filtered, color='orange', label="Filtered LFP (30-50 Hz)")
    plt.plot(phase_ts * 1e3, envelope, color='purple', label="Hilbert Envelope")
    plt.legend(loc='best')
    plt.xlabel('Time (ms)')
    # save or show
    if filename:
        plt.savefig(filename, dpi=500)
    plt.show()
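In panel 2 each population's boolean spike matrix (time steps × neurons) is drawn in its own band of neuron indices: `np.where` gives the (step, neuron) pairs, the running offset `i` shifts each population above the previous one, and a tick is placed at the middle of each band. The helper below isolates that bookkeeping without Matplotlib; `stack_raster` is a hypothetical name, and the tiny spike matrices exist only for the check.

```python
import numpy as np


def stack_raster(spike_matrices, times):
    """Hypothetical helper: (time, y) spike coordinates with each population in its own index band."""
    offset = 0
    coords, ticks = {}, []
    for key, sp_matrix in spike_matrices.items():
        step_idx, neuron_idx = np.where(sp_matrix)            # rows are time steps, columns are neurons
        coords[key] = (times[step_idx], neuron_idx + offset)
        ticks.append((offset + sp_matrix.shape[1] / 2, key))  # label at the middle of the band
        offset += sp_matrix.shape[1]
    return coords, ticks


times = np.array([0.0, 0.1, 0.2])
fs_spikes = np.zeros((3, 2), dtype=bool)
fs_spikes[1, 0] = True                    # FS neuron 0 spikes at t = 0.1
rs_spikes = np.zeros((3, 3), dtype=bool)
rs_spikes[2, 1] = True                    # RS neuron 1 spikes at t = 0.2
coords, ticks = stack_raster({'FS': fs_spikes, 'RS': rs_spikes}, times)
print(coords['RS'], ticks)                # RS neuron 1 is drawn at y = 2 + 1 = 3
```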
[8]:
def simulate_ai_net():
    duration = 2e3
    varied_rates = get_inputs(2., 2., 50., 150, 600, 1e3, duration)
    net = AINet(varied_rates, ext_weight=1.)
    runner = bp.DSRunner(
        net,
        inputs=net.change_freq,
        monitors={'FS.V0': lambda tdi: net.fs_pop.V[0],
                  'RS.V0': lambda tdi: net.rs_pop.V[0],
                  'FS.spike': lambda tdi: net.fs_pop.spike,
                  'RS.spike': lambda tdi: net.rs_pop.spike}
    )
    runner.run(duration)
    visualize_simulation_results(
        times=runner.mon.ts,
        spikes={'FS': (runner.mon['FS.spike'], 'inh'),
                'RS': (runner.mon['RS.spike'], 'exc')},
        example_potentials={'FS': runner.mon['FS.V0'],
                            'RS': runner.mon['RS.V0']},
        varied_rates=varied_rates.to_numpy())
[9]:
simulate_ai_net()

(Susin & Destexhe, 2021): CHING Network for Generating Gamma Oscillation
Implementation of the paper:
Susin, Eduarda, and Alain Destexhe. “Integration, coincidence detection and resonance in networks of spiking neurons expressing gamma oscillations and asynchronous states.” PLoS Computational Biology 17.9 (2021): e1009416.
[1]:
import numpy as np
import matplotlib.pyplot as plt
from scipy.signal import kaiserord, lfilter, firwin, hilbert
import brainpy as bp
import brainpy.math as bm
[2]:
# Table 1: specific neuron model parameters
RS_par = dict(Vth=-40, delta=2., tau_ref=5., tau_w=500, a=4, b=20, C=150, gL=10, EL=-65, V_reset=-65,
              E_e=0., E_i=-80.)
FS_par = dict(Vth=-47.5, delta=0.5, tau_ref=5., tau_w=500, a=0, b=0, C=150, gL=10, EL=-65, V_reset=-65,
              E_e=0., E_i=-80.)
Ch_par = dict(Vth=-47.5, delta=0.5, tau_ref=1., tau_w=50, a=80, b=150, C=150, gL=10, EL=-58, V_reset=-65,
              E_e=0., E_i=-80.)
[3]:
class AdEx(bp.NeuGroup):
    def __init__(
        self,
        size,
        # neuronal parameters
        Vth=-40, delta=2., tau_ref=5., tau_w=500, a=4, b=20, C=150,
        gL=10, EL=-65, V_reset=-65, V_sp_th=-30.,
        # synaptic parameters
        tau_e=1.5, tau_i=7.5, E_e=0., E_i=-80.,
        # other parameters
        name=None, method='exp_euler',
        V_initializer=bp.init.Uniform(-65, -50),
        w_initializer=bp.init.Constant(0.),
    ):
        super(AdEx, self).__init__(size=size, name=name)
        # neuronal parameters
        self.Vth = Vth
        self.delta = delta
        self.tau_ref = tau_ref
        self.tau_w = tau_w
        self.a = a
        self.b = b
        self.C = C
        self.gL = gL
        self.EL = EL
        self.V_reset = V_reset
        self.V_sp_th = V_sp_th
        # synaptic parameters
        self.tau_e = tau_e
        self.tau_i = tau_i
        self.E_e = E_e
        self.E_i = E_i
        # neuronal variables
        self.V = bp.init.variable_(V_initializer, self.num)
        self.w = bp.init.variable_(w_initializer, self.num)
        self.spike = bm.Variable(self.num, dtype=bool)
        self.refractory = bm.Variable(self.num, dtype=bool)
        self.t_last_spike = bm.Variable(bm.ones(self.num) * -1e8)
        # synaptic variables (excitatory and inhibitory conductances)
        self.ge = bm.Variable(self.num)
        self.gi = bm.Variable(self.num)
        # integrator for the joint (V, w, ge, gi) system
        self.integral = bp.odeint(bp.JointEq(self.dV, self.dw, self.dge, self.dgi), method=method)
    def dge(self, ge, t):
        return -ge / self.tau_e
    def dgi(self, gi, t):
        return -gi / self.tau_i
    def dV(self, V, t, w, ge, gi, Iext=None):
        I = ge * (self.E_e - V) + gi * (self.E_i - V)
        if Iext is not None: I += Iext
        dVdt = (self.gL * self.delta * bm.exp((V - self.Vth) / self.delta)
                - w + self.gL * (self.EL - V) + I) / self.C
        return dVdt
    def dw(self, w, t, V):
        dwdt = (self.a * (V - self.EL) - w) / self.tau_w
        return dwdt
    def update(self, tdi, x=None):
        V, w, ge, gi = self.integral(self.V.value, self.w.value, self.ge.value, self.gi.value,
                                     tdi.t, Iext=x, dt=tdi.dt)
        # hold V during the refractory period (w, ge and gi keep evolving)
        refractory = (tdi.t - self.t_last_spike) <= self.tau_ref
        V = bm.where(refractory, self.V.value, V)
        # spike detection, reset of V and spike-triggered adaptation increment b
        spike = V >= self.V_sp_th
        self.V.value = bm.where(spike, self.V_reset, V)
        self.w.value = bm.where(spike, w + self.b, w)
        self.ge.value = ge
        self.gi.value = gi
        self.spike.value = spike
        self.refractory.value = bm.logical_or(refractory, spike)
        self.t_last_spike.value = bm.where(spike, tdi.t, self.t_last_spike)
[4]:
class CHINGNet(bp.Network):
    def __init__(self, ext_varied_rates, method='exp_euler', dt=bm.get_dt()):
        super(CHINGNet, self).__init__()
        self.num_rs = 19000
        self.num_fs = 5000
        self.num_ch = 1000
        self.exc_syn_tau = 5.  # ms
        self.inh_syn_tau = 5.  # ms
        self.exc_syn_weight = 1.  # nS
        self.inh_syn_weight1 = 7.  # nS
        self.inh_syn_weight2 = 5.  # nS
        self.ext_weight1 = 1.  # nS
        self.ext_weight2 = 0.75  # nS
        self.num_delay_step = int(1.5 / dt)  # 1.5 ms synaptic delay, in time steps
        self.ext_varied_rates = ext_varied_rates
        # neuronal populations
        RS_par_ = RS_par.copy()
        FS_par_ = FS_par.copy()
        Ch_par_ = Ch_par.copy()
        RS_par_.update(Vth=-50, V_sp_th=-40)
        FS_par_.update(Vth=-50, V_sp_th=-40)
        Ch_par_.update(Vth=-50, V_sp_th=-40)
        self.rs_pop = AdEx(self.num_rs, tau_e=self.exc_syn_tau, tau_i=self.inh_syn_tau, method=method, **RS_par_)
        self.fs_pop = AdEx(self.num_fs, tau_e=self.exc_syn_tau, tau_i=self.inh_syn_tau, method=method, **FS_par_)
        self.ch_pop = AdEx(self.num_ch, tau_e=self.exc_syn_tau, tau_i=self.inh_syn_tau, method=method, **Ch_par_)
        # freqs is a 1-element Variable shared by all Poisson units; change_freq overwrites it every step
        self.ext_pop = bp.neurons.PoissonGroup(self.num_rs, freqs=bm.Variable(1))
        # Poisson inputs
        self.ext_to_FS = bp.synapses.Delta(
            self.ext_pop, self.fs_pop, bp.conn.FixedProb(0.02),
            output=bp.synouts.CUBA(target_var='ge'),
            g_max=self.ext_weight2)
        self.ext_to_RS = bp.synapses.Delta(
            self.ext_pop, self.rs_pop, bp.conn.FixedProb(0.02),
            output=bp.synouts.CUBA(target_var='ge'),
            g_max=self.ext_weight1)
        self.ext_to_CH = bp.synapses.Delta(
            self.ext_pop, self.ch_pop, bp.conn.FixedProb(0.02),
            output=bp.synouts.CUBA(target_var='ge'),
            g_max=self.ext_weight1)
        # synaptic projections
        self.RS_to_FS = bp.synapses.Delta(
            self.rs_pop, self.fs_pop, bp.conn.FixedProb(0.02),
            output=bp.synouts.CUBA(target_var='ge'),
            g_max=self.exc_syn_weight,
            delay_step=self.num_delay_step)
        self.RS_to_RS = bp.synapses.Delta(
            self.rs_pop, self.rs_pop, bp.conn.FixedProb(0.02),
            output=bp.synouts.CUBA(target_var='ge'),
            g_max=self.exc_syn_weight,
            delay_step=self.num_delay_step)
        self.RS_to_Ch = bp.synapses.Delta(
            self.rs_pop, self.ch_pop, bp.conn.FixedProb(0.02),
            output=bp.synouts.CUBA(target_var='ge'),
            g_max=self.exc_syn_weight,
            delay_step=self.num_delay_step)
        self.FS_to_RS = bp.synapses.Delta(
            self.fs_pop, self.rs_pop, bp.conn.FixedProb(0.02),
            output=bp.synouts.CUBA(target_var='gi'),
            g_max=self.inh_syn_weight1,
            delay_step=self.num_delay_step)
        self.FS_to_FS = bp.synapses.Delta(
            self.fs_pop, self.fs_pop, bp.conn.FixedProb(0.02),
            output=bp.synouts.CUBA(target_var='gi'),
            g_max=self.inh_syn_weight2,
            delay_step=self.num_delay_step)
        self.FS_to_Ch = bp.synapses.Delta(
            self.fs_pop, self.ch_pop, bp.conn.FixedProb(0.02),
            output=bp.synouts.CUBA(target_var='gi'),
            g_max=self.inh_syn_weight1,
            delay_step=self.num_delay_step)
        self.Ch_to_RS = bp.synapses.Delta(
            self.ch_pop, self.rs_pop, bp.conn.FixedProb(0.02),
            output=bp.synouts.CUBA(target_var='ge'),
            g_max=self.exc_syn_weight,
            delay_step=self.num_delay_step)
        self.Ch_to_FS = bp.synapses.Delta(
            self.ch_pop, self.fs_pop, bp.conn.FixedProb(0.02),
            output=bp.synouts.CUBA(target_var='ge'),
            g_max=self.exc_syn_weight,
            delay_step=self.num_delay_step)
        self.Ch_to_Ch = bp.synapses.Delta(
            self.ch_pop, self.ch_pop, bp.conn.FixedProb(0.02),
            output=bp.synouts.CUBA(target_var='ge'),
            g_max=self.exc_syn_weight,
            delay_step=self.num_delay_step)
    def change_freq(self, tdi):
        # called at every step (inputs=net.change_freq): set the current external rate in Hz
        self.ext_pop.freqs[0] = self.ext_varied_rates[tdi.i]
[5]:
def get_inputs(c_low, c_high, t_transition, t_min_plato, t_max_plato, t_gap, t_total, dt=None):
    # Piecewise-linear rate schedule, repeated until t_total: a baseline at c_low for t_gap,
    # a ramp up over t_transition, a plateau at c_high lasting a random duration drawn
    # from [t_min_plato, t_max_plato], then a ramp down over t_transition.
    dt = bm.get_dt() if dt is None else dt
    t = 0
    num_gap = int(t_gap / dt)
    num_total = int(t_total / dt)
    num_transition = int(t_transition / dt)
    inputs = []
    ramp_up = np.linspace(c_low, c_high, num_transition)
    ramp_down = np.linspace(c_high, c_low, num_transition)
    plato_base = np.ones(num_gap) * c_low
    while t < num_total:
        # draw a scalar (no size=1) so that int() is not applied to a 1-element array
        num_plato = int(np.random.uniform(low=t_min_plato, high=t_max_plato) / dt)
        inputs.extend([plato_base, ramp_up, np.ones(num_plato) * c_high, ramp_down])
        t += (num_gap + num_transition + num_plato + num_transition)
    return bm.asarray(np.concatenate(inputs)[:num_total])
[6]:
def signal_phase_by_Hilbert(signal, signal_time, low_cut, high_cut, sampling_space):
    # sampling_space: sampling interval, in seconds (no units)
    # signal_time: in seconds (no units)
    # low_cut: lower edge of the band to filter, in Hz (no units)
    # high_cut: upper edge of the band to filter, in Hz (no units)
    signal = signal - np.mean(signal)
    width = 5.0  # The desired width in Hz of the transition from pass to stop
    ripple_db = 60.0  # The desired attenuation in the stop band, in dB.
    sampling_rate = 1. / sampling_space
    Nyquist = sampling_rate / 2.
    num_taps, beta = kaiserord(ripple_db, width / Nyquist)
    if num_taps % 2 == 0:
        num_taps = num_taps + 1  # Numtaps must be odd
    # cutoffs are normalized by the Nyquist frequency, hence fs=2.0 (replaces the deprecated nyq=1.0)
    taps = firwin(num_taps, [low_cut / Nyquist, high_cut / Nyquist], window=('kaiser', beta), fs=2.0,
                  pass_zero=False, scale=True)
    filtered_signal = lfilter(taps, 1.0, signal)
    delay = 0.5 * (num_taps - 1) / sampling_rate  # group delay of the linear-phase filter, corrected to get zero phase
    delay_index = int(np.floor(delay * sampling_rate))
    filtered_signal = filtered_signal[num_taps - 1:]  # taking out the "corrupted" signal
    # correcting the delay and taking out the "corrupted" signal part
    filtered_time = signal_time[num_taps - 1:] - delay
    cutted_signal = signal[(num_taps - 1 - delay_index): (len(signal) - (num_taps - 1 - delay_index))]
    # --------------------------------------------------------------------------
    # The Hilbert transform is very slow when the signal has odd length.
    # If the length is odd, append one sample to every vector related to the
    # filtered signal: a zero for the signals and the next time point for the time axis.
    if len(filtered_signal) % 2 != 0:  # If the length is odd
        filtered_signal = np.append(filtered_signal, 0.)
        filtered_time = np.append(filtered_time, filtered_time[0] + len(filtered_time) * sampling_space)
        cutted_signal = np.append(cutted_signal, 0.)
    # --------------------------------------------------------------------------
    ht_filtered_signal = hilbert(filtered_signal)
    envelope = np.abs(ht_filtered_signal)
    phase = np.angle(ht_filtered_signal)  # The phase is between -pi and pi in radians
    return filtered_time, filtered_signal, cutted_signal, envelope, phase
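The delay correction in `signal_phase_by_Hilbert` relies on a property of symmetric (linear-phase) FIR filters: applied causally with `lfilter`, an odd-length filter with `num_taps` coefficients delays every frequency component by exactly `(num_taps - 1) / 2` samples. The NumPy sketch below checks this with a unit impulse. It uses `np.convolve` truncated to the input length, which gives the same output as `lfilter(taps, 1.0, x)`, and a Hann-shaped kernel chosen only for illustration in place of the Kaiser-window band-pass designed by `firwin`; `causal_fir` is a hypothetical helper name.

```python
import numpy as np


def causal_fir(taps, x):
    """y[n] = sum_k taps[k] * x[n - k]; same output as scipy.signal.lfilter(taps, 1.0, x)."""
    return np.convolve(x, taps)[:len(x)]


num_taps = 21                          # odd length, as signal_phase_by_Hilbert enforces
taps = np.hanning(num_taps + 2)[1:-1]  # symmetric kernel (the zero end points are dropped)
taps /= taps.sum()
x = np.zeros(200)
x[100] = 1.0                           # unit impulse at sample 100
y = causal_fir(taps, x)
delay = (num_taps - 1) // 2            # 10 samples of group delay
print(np.argmax(y))                    # 110: the impulse comes out 10 samples late
print(np.argmax(y[delay:]))            # 100: dropping the first `delay` samples re-aligns it
```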
[7]:
def visualize_simulation_results(times, spikes, example_potentials, varied_rates,
                                 xlim=None, t_lfp_start=None, t_lfp_end=None, filename=None):
    fig, gs = bp.visualize.get_figure(7, 1, 1, 12)
    # 1. input firing rate
    ax = fig.add_subplot(gs[0])
    plt.plot(times, varied_rates)
    if xlim is None:
        xlim = (0, times[-1])
    ax.set_xlim(*xlim)
    ax.set_xticks([])
    ax.set_ylabel('External\nRate (Hz)')
    # 2. raster plot of all populations, each in its own band of neuron indices
    ax = fig.add_subplot(gs[1: 3])
    i = 0
    y_ticks = ([], [])
    for key, (sp_matrix, sp_type) in spikes.items():
        iis, sps = np.where(sp_matrix)
        tts = times[iis]
        plt.plot(tts, sps + i, '.', markersize=1, label=key)
        y_ticks[0].append(i + sp_matrix.shape[1] / 2)
        y_ticks[1].append(key)
        i += sp_matrix.shape[1]
    ax.set_xlim(*xlim)
    ax.set_xlabel('')
    ax.set_ylabel('Neuron Index')
    ax.set_xticks([])
    ax.set_yticks(*y_ticks)
    # ax.legend()
    # 3. example membrane potential of neuron 0 (spike steps drawn at 0 mV)
    ax = fig.add_subplot(gs[3: 5])
    for key, potential in example_potentials.items():
        vs = np.where(spikes[key][0][:, 0], 0, potential)
        plt.plot(times, vs, label=key)
    ax.set_xlim(*xlim)
    ax.set_xticks([])
    ax.set_ylabel('V (mV)')
    ax.legend()
    # 4. LFP
    ax = fig.add_subplot(gs[5:7])
    ax.set_xlim(*xlim)
    t1 = int(t_lfp_start / bm.get_dt()) if t_lfp_start is not None else 0
    t2 = int(t_lfp_end / bm.get_dt()) if t_lfp_end is not None else len(times)
    times = times[t1: t2]
    lfp = 0
    for sp_matrix, sp_type in spikes.values():
        lfp += bp.measure.unitary_LFP(times, sp_matrix[t1: t2], sp_type)
    phase_ts, filtered, cutted, envelope, _ = signal_phase_by_Hilbert(
        bm.as_numpy(lfp), times * 1e-3, 30, 50, bm.get_dt() * 1e-3)
    plt.plot(phase_ts * 1e3, cutted, color='k', label='Raw LFP')
    plt.plot(phase_ts * 1e3, filtered, color='orange', label="Filtered LFP (30-50 Hz)")
    plt.plot(phase_ts * 1e3, envelope, color='purple', label="Hilbert Envelope")
    plt.legend(loc='best')
    plt.xlabel('Time (ms)')
    # save or show
    if filename:
        plt.savefig(filename, dpi=500)
    plt.show()
[8]:
def simulate_ching_net():
    duration = 6e3
    varied_rates = get_inputs(1., 2., 50., 150, 600, 1e3, duration)
    net = CHINGNet(varied_rates)
    runner = bp.DSRunner(
        net,
        inputs=net.change_freq,
        monitors={'FS.V0': lambda tdi: net.fs_pop.V[0],
                  'CH.V0': lambda tdi: net.ch_pop.V[0],
                  'RS.V0': lambda tdi: net.rs_pop.V[0],
                  'FS.spike': lambda tdi: net.fs_pop.spike,
                  'CH.spike': lambda tdi: net.ch_pop.spike,
                  'RS.spike': lambda tdi: net.rs_pop.spike}
    )
    runner.run(duration)
    visualize_simulation_results(
        times=runner.mon.ts,
        spikes={'FS': (runner.mon['FS.spike'], 'inh'),
                'CH': (runner.mon['CH.spike'], 'exc'),
                'RS': (runner.mon['RS.spike'], 'exc')},
        example_potentials={'FS': runner.mon['FS.V0'],
                            'CH': runner.mon['CH.V0'],
                            'RS': runner.mon['RS.V0']},
        varied_rates=varied_rates.to_numpy(),
        xlim=(2e3, 3.4e3), t_lfp_start=1e3, t_lfp_end=5e3)
[9]:
simulate_ching_net()

(Susin & Destexhe, 2021): ING Network for Generating Gamma Oscillation
Implementation of the paper:
Susin, Eduarda, and Alain Destexhe. “Integration, coincidence detection and resonance in networks of spiking neurons expressing gamma oscillations and asynchronous states.” PLoS Computational Biology 17.9 (2021): e1009416.
[1]:
import numpy as np
import matplotlib.pyplot as plt
from scipy.signal import kaiserord, lfilter, firwin, hilbert
import brainpy as bp
import brainpy.math as bm
[2]:
# Table 1: specific neuron model parameters
RS_par = dict(Vth=-40, delta=2., tau_ref=5., tau_w=500, a=4, b=20, C=150, gL=10, EL=-65, V_reset=-65,
              E_e=0., E_i=-80.)
FS_par = dict(Vth=-47.5, delta=0.5, tau_ref=5., tau_w=500, a=0, b=0, C=150, gL=10, EL=-65, V_reset=-65,
              E_e=0., E_i=-80.)
Ch_par = dict(Vth=-47.5, delta=0.5, tau_ref=1., tau_w=50, a=80, b=150, C=150, gL=10, EL=-58, V_reset=-65,
              E_e=0., E_i=-80.)
[3]:
class AdEx(bp.NeuGroup):
    def __init__(
        self,
        size,
        # neuronal parameters
        Vth=-40, delta=2., tau_ref=5., tau_w=500, a=4, b=20, C=150,
        gL=10, EL=-65, V_reset=-65, V_sp_th=-30.,
        # synaptic parameters
        tau_e=1.5, tau_i=7.5, E_e=0., E_i=-80.,
        # other parameters
        name=None, method='exp_euler',
        V_initializer=bp.init.Uniform(-65, -50),
        w_initializer=bp.init.Constant(0.),
    ):
        super(AdEx, self).__init__(size=size, name=name)
        # neuronal parameters
        self.Vth = Vth
        self.delta = delta
        self.tau_ref = tau_ref
        self.tau_w = tau_w
        self.a = a
        self.b = b
        self.C = C
        self.gL = gL
        self.EL = EL
        self.V_reset = V_reset
        self.V_sp_th = V_sp_th
        # synaptic parameters
        self.tau_e = tau_e
        self.tau_i = tau_i
        self.E_e = E_e
        self.E_i = E_i
        # neuronal variables
        self.V = bp.init.variable_(V_initializer, self.num)
        self.w = bp.init.variable_(w_initializer, self.num)
        self.spike = bm.Variable(self.num, dtype=bool)
        self.refractory = bm.Variable(self.num, dtype=bool)
        self.t_last_spike = bm.Variable(bm.ones(self.num) * -1e8)
        # synaptic variables (excitatory and inhibitory conductances)
        self.ge = bm.Variable(self.num)
        self.gi = bm.Variable(self.num)
        # integrator for the joint (V, w, ge, gi) system
        self.integral = bp.odeint(bp.JointEq(self.dV, self.dw, self.dge, self.dgi), method=method)
    def dge(self, ge, t):
        return -ge / self.tau_e
    def dgi(self, gi, t):
        return -gi / self.tau_i
    def dV(self, V, t, w, ge, gi, Iext=None):
        I = ge * (self.E_e - V) + gi * (self.E_i - V)
        if Iext is not None: I += Iext
        dVdt = (self.gL * self.delta * bm.exp((V - self.Vth) / self.delta)
                - w + self.gL * (self.EL - V) + I) / self.C
        return dVdt
    def dw(self, w, t, V):
        dwdt = (self.a * (V - self.EL) - w) / self.tau_w
        return dwdt
    def update(self, tdi, x=None):
        V, w, ge, gi = self.integral(self.V.value, self.w.value, self.ge.value, self.gi.value,
                                     tdi.t, Iext=x, dt=tdi.dt)
        # hold V during the refractory period (w, ge and gi keep evolving)
        refractory = (tdi.t - self.t_last_spike) <= self.tau_ref
        V = bm.where(refractory, self.V.value, V)
        # spike detection, reset of V and spike-triggered adaptation increment b
        spike = V >= self.V_sp_th
        self.V.value = bm.where(spike, self.V_reset, V)
        self.w.value = bm.where(spike, w + self.b, w)
        self.ge.value = ge
        self.gi.value = gi
        self.spike.value = spike
        self.refractory.value = bm.logical_or(refractory, spike)
        self.t_last_spike.value = bm.where(spike, tdi.t, self.t_last_spike)
[4]:
class INGNet(bp.Network):
    def __init__(self, ext_varied_rates, ext_weight=0.9, method='exp_euler', dt=bm.get_dt()):
        super(INGNet, self).__init__()
        self.num_rs = 20000
        self.num_fs = 4000
        self.num_fs2 = 1000
        self.exc_syn_tau = 5.  # ms
        self.inh_syn_tau = 5.  # ms
        self.exc_syn_weight = 1.  # nS
        self.inh_syn_weight = 5.  # nS
        self.num_delay_step = int(1.5 / dt)  # 1.5 ms synaptic delay, in time steps
        self.ext_varied_rates = ext_varied_rates
        # neuronal populations
        RS_par_ = RS_par.copy()
        FS_par_ = FS_par.copy()
        FS2_par_ = FS_par.copy()
        RS_par_.update(Vth=-50, V_sp_th=-40)
        FS_par_.update(Vth=-50, V_sp_th=-40)
        FS2_par_.update(Vth=-50, V_sp_th=-40)
        self.rs_pop = AdEx(self.num_rs, tau_e=self.exc_syn_tau, tau_i=self.inh_syn_tau, method=method, **RS_par_)
        self.fs_pop = AdEx(self.num_fs, tau_e=self.exc_syn_tau, tau_i=self.inh_syn_tau, method=method, **FS_par_)
        self.fs2_pop = AdEx(self.num_fs2, tau_e=self.exc_syn_tau, tau_i=self.inh_syn_tau, method=method, **FS2_par_)
        # freqs is a 1-element Variable shared by all Poisson units; change_freq overwrites it every step
        self.ext_pop = bp.neurons.PoissonGroup(self.num_rs, freqs=bm.Variable(1))
        # Poisson inputs
        self.ext_to_FS = bp.synapses.Delta(
            self.ext_pop, self.fs_pop, bp.conn.FixedProb(0.02),
            output=bp.synouts.CUBA(target_var='ge'),
            g_max=ext_weight)
        self.ext_to_RS = bp.synapses.Delta(
            self.ext_pop, self.rs_pop, bp.conn.FixedProb(0.02),
            output=bp.synouts.CUBA(target_var='ge'),
            g_max=ext_weight)
        self.ext_to_FS2 = bp.synapses.Delta(
            self.ext_pop, self.fs2_pop, bp.conn.FixedProb(0.02),
            output=bp.synouts.CUBA(target_var='ge'),
            g_max=ext_weight)
        # synaptic projections
        self.RS_to_FS = bp.synapses.Delta(
            self.rs_pop, self.fs_pop, bp.conn.FixedProb(0.02),
            output=bp.synouts.CUBA(target_var='ge'),
            g_max=self.exc_syn_weight,
            delay_step=self.num_delay_step)
        self.RS_to_RS = bp.synapses.Delta(
            self.rs_pop, self.rs_pop, bp.conn.FixedProb(0.02),
            output=bp.synouts.CUBA(target_var='ge'),
            g_max=self.exc_syn_weight,
            delay_step=self.num_delay_step)
        self.RS_to_FS2 = bp.synapses.Delta(
            self.rs_pop, self.fs2_pop, bp.conn.FixedProb(0.15),
            output=bp.synouts.CUBA(target_var='ge'),
            g_max=self.exc_syn_weight,
            delay_step=self.num_delay_step)
        self.FS_to_RS = bp.synapses.Delta(
            self.fs_pop, self.rs_pop, bp.conn.FixedProb(0.02),
            output=bp.synouts.CUBA(target_var='gi'),
            g_max=self.inh_syn_weight,
            delay_step=self.num_delay_step)
        self.FS_to_FS = bp.synapses.Delta(
            self.fs_pop, self.fs_pop, bp.conn.FixedProb(0.02),
            output=bp.synouts.CUBA(target_var='gi'),
            g_max=self.inh_syn_weight,
            delay_step=self.num_delay_step)
        self.FS_to_FS2 = bp.synapses.Delta(
            self.fs_pop, self.fs2_pop, bp.conn.FixedProb(0.03),
            output=bp.synouts.CUBA(target_var='gi'),
            g_max=self.inh_syn_weight,
            delay_step=self.num_delay_step)
        self.FS2_to_RS = bp.synapses.Delta(
            self.fs2_pop, self.rs_pop, bp.conn.FixedProb(0.15),
            output=bp.synouts.CUBA(target_var='gi'),
            g_max=self.exc_syn_weight,
            delay_step=self.num_delay_step)
        self.FS2_to_FS = bp.synapses.Delta(
            self.fs2_pop, self.fs_pop, bp.conn.FixedProb(0.15),
            output=bp.synouts.CUBA(target_var='gi'),
            g_max=self.exc_syn_weight,
            delay_step=self.num_delay_step)
        self.FS2_to_FS2 = bp.synapses.Delta(
            self.fs2_pop, self.fs2_pop, bp.conn.FixedProb(0.6),
            output=bp.synouts.CUBA(target_var='gi'),
            g_max=self.exc_syn_weight,
            delay_step=self.num_delay_step)
    def change_freq(self, tdi):
        # called at every step (inputs=net.change_freq): set the current external rate in Hz
        self.ext_pop.freqs[0] = self.ext_varied_rates[tdi.i]
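Unlike the rest of the network, the projections involving the second fast-spiking population FS2 use larger connection probabilities (0.03 to 0.6). The short calculation below tabulates the expected in-degree of every projection in `INGNet`; it is a sketch that assumes `FixedProb(p)` gives each target neuron on average p·N_pre inputs, and the dictionaries simply restate the population sizes and probabilities from the class ('ext' is the Poisson group).

```python
# Population sizes and connection probabilities restated from INGNet
sizes = {'ext': 20000, 'RS': 20000, 'FS': 4000, 'FS2': 1000}
prob = {
    ('ext', 'RS'): 0.02, ('ext', 'FS'): 0.02, ('ext', 'FS2'): 0.02,
    ('RS', 'RS'): 0.02, ('RS', 'FS'): 0.02, ('RS', 'FS2'): 0.15,
    ('FS', 'RS'): 0.02, ('FS', 'FS'): 0.02, ('FS', 'FS2'): 0.03,
    ('FS2', 'RS'): 0.15, ('FS2', 'FS'): 0.15, ('FS2', 'FS2'): 0.6,
}

# Expected number of presynaptic partners of one target neuron, per projection
in_degree = {(pre, post): p * sizes[pre] for (pre, post), p in prob.items()}

for post in ('RS', 'FS', 'FS2'):
    inputs = {pre: round(in_degree[(pre, post)]) for pre in sizes if (pre, post) in in_degree}
    print(post, inputs)
```

For example, each FS2 cell is expected to receive about 3000 RS inputs, versus about 400 for an RS or FS cell, and about 600 inputs from the other FS2 cells.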
[5]:
def get_inputs(c_low, c_high, t_transition, t_min_plato, t_max_plato, t_gap, t_total, dt=None):
    # Piecewise-linear rate schedule, repeated until t_total: a baseline at c_low for t_gap,
    # a ramp up over t_transition, a plateau at c_high lasting a random duration drawn
    # from [t_min_plato, t_max_plato], then a ramp down over t_transition.
    dt = bm.get_dt() if dt is None else dt
    t = 0
    num_gap = int(t_gap / dt)
    num_total = int(t_total / dt)
    num_transition = int(t_transition / dt)
    inputs = []
    ramp_up = np.linspace(c_low, c_high, num_transition)
    ramp_down = np.linspace(c_high, c_low, num_transition)
    plato_base = np.ones(num_gap) * c_low
    while t < num_total:
        # draw a scalar (no size=1) so that int() is not applied to a 1-element array
        num_plato = int(np.random.uniform(low=t_min_plato, high=t_max_plato) / dt)
        inputs.extend([plato_base, ramp_up, np.ones(num_plato) * c_high, ramp_down])
        t += (num_gap + num_transition + num_plato + num_transition)
    return bm.asarray(np.concatenate(inputs)[:num_total])
[6]:
def signal_phase_by_Hilbert(signal, signal_time, low_cut, high_cut, sampling_space):
    # sampling_space: sampling interval, in seconds (no units)
    # signal_time: in seconds (no units)
    # low_cut: lower edge of the band to filter, in Hz (no units)
    # high_cut: upper edge of the band to filter, in Hz (no units)
    signal = signal - np.mean(signal)
    width = 5.0  # The desired width in Hz of the transition from pass to stop
    ripple_db = 60.0  # The desired attenuation in the stop band, in dB.
    sampling_rate = 1. / sampling_space
    Nyquist = sampling_rate / 2.
    num_taps, beta = kaiserord(ripple_db, width / Nyquist)
    if num_taps % 2 == 0:
        num_taps = num_taps + 1  # Numtaps must be odd
    # cutoffs are normalized by the Nyquist frequency, hence fs=2.0 (replaces the deprecated nyq=1.0)
    taps = firwin(num_taps, [low_cut / Nyquist, high_cut / Nyquist], window=('kaiser', beta), fs=2.0,
                  pass_zero=False, scale=True)
    filtered_signal = lfilter(taps, 1.0, signal)
    delay = 0.5 * (num_taps - 1) / sampling_rate  # group delay of the linear-phase filter, corrected to get zero phase
    delay_index = int(np.floor(delay * sampling_rate))
    filtered_signal = filtered_signal[num_taps - 1:]  # taking out the "corrupted" signal
    # correcting the delay and taking out the "corrupted" signal part
    filtered_time = signal_time[num_taps - 1:] - delay
    cutted_signal = signal[(num_taps - 1 - delay_index): (len(signal) - (num_taps - 1 - delay_index))]
    # --------------------------------------------------------------------------
    # The Hilbert transform is very slow when the signal has odd length.
    # If the length is odd, append one sample to every vector related to the
    # filtered signal: a zero for the signals and the next time point for the time axis.
    if len(filtered_signal) % 2 != 0:  # If the length is odd
        filtered_signal = np.append(filtered_signal, 0.)
        filtered_time = np.append(filtered_time, filtered_time[0] + len(filtered_time) * sampling_space)
        cutted_signal = np.append(cutted_signal, 0.)
    # --------------------------------------------------------------------------
    ht_filtered_signal = hilbert(filtered_signal)
    envelope = np.abs(ht_filtered_signal)
    phase = np.angle(ht_filtered_signal)  # The phase is between -pi and pi in radians
    return filtered_time, filtered_signal, cutted_signal, envelope, phase
[7]:
def visualize_simulation_results(times, spikes, example_potentials, varied_rates,
xlim=None, t_lfp_start=None, t_lfp_end=None, filename=None):
fig, gs = bp.visualize.get_figure(7, 1, 1, 12)
# 1. input firing rate
ax = fig.add_subplot(gs[0])
plt.plot(times, varied_rates)
if xlim is None:
xlim = (0, times[-1])
ax.set_xlim(*xlim)
ax.set_xticks([])
ax.set_ylabel('External\nRate (Hz)')
# 2. inhibitory cell rater plot
ax = fig.add_subplot(gs[1: 3])
i = 0
y_ticks = ([], [])
for key, (sp_matrix, sp_type) in spikes.items():
iis, sps = np.where(sp_matrix)
tts = times[iis]
plt.plot(tts, sps + i, '.', markersize=1, label=key)
y_ticks[0].append(i + sp_matrix.shape[1] / 2)
y_ticks[1].append(key)
i += sp_matrix.shape[1]
ax.set_xlim(*xlim)
ax.set_xlabel('')
ax.set_ylabel('Neuron Index')
ax.set_xticks([])
ax.set_yticks(*y_ticks)
# ax.legend()
# 3. example membrane potential
ax = fig.add_subplot(gs[3: 5])
for key, potential in example_potentials.items():
vs = np.where(spikes[key][0][:, 0], 0, potential)
plt.plot(times, vs, label=key)
ax.set_xlim(*xlim)
ax.set_xticks([])
ax.set_ylabel('V (mV)')
ax.legend()
# 4. LFP
ax = fig.add_subplot(gs[5:7])
ax.set_xlim(*xlim)
t1 = int(t_lfp_start / bm.get_dt()) if t_lfp_start is not None else 0
t2 = int(t_lfp_end / bm.get_dt()) if t_lfp_end is not None else len(times)
times = times[t1: t2]
lfp = 0
for sp_matrix, sp_type in spikes.values():
lfp += bp.measure.unitary_LFP(times, sp_matrix[t1: t2], sp_type)
phase_ts, filtered, cutted, envelope, _ = signal_phase_by_Hilbert(bm.as_numpy(lfp), times * 1e-3, 30, 50,
bm.get_dt() * 1e-3)
plt.plot(phase_ts * 1e3, cutted, color='k', label='Raw LFP')
plt.plot(phase_ts * 1e3, filtered, color='orange', label="Filtered LFP (30-50 Hz)")
plt.plot(phase_ts * 1e3, envelope, color='purple', label="Hilbert Envelope")
plt.legend(loc='best')
plt.xlabel('Time (ms)')
# save or show
if filename:
plt.savefig(filename, dpi=500)
plt.show()
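The raster panel stacks every population on one axis by shifting each population's neuron indices by the number of neurons already drawn, and places each y-tick label in the middle of its block. A NumPy-only sketch of that bookkeeping (`stack_raster` is a hypothetical helper, not part of BrainPy) makes the offsets explicit; its outputs can be passed to `plt.plot(ts, rows, '.')` and `ax.set_yticks(tick_pos, tick_labels)`.

```python
import numpy as np


def stack_raster(times, spike_matrices):
    """Flatten several (time x neuron) boolean spike matrices into one raster.

    Returns spike times, stacked row indices, y-tick positions and y-tick labels.
    """
    all_times, all_rows, tick_pos, tick_labels = [], [], [], []
    offset = 0
    for name, spikes in spike_matrices.items():
        t_idx, n_idx = np.nonzero(spikes)              # time step and neuron index of each spike
        all_times.append(times[t_idx])
        all_rows.append(n_idx + offset)                # shift above the previous populations
        tick_pos.append(offset + spikes.shape[1] / 2)  # label at the middle of the block
        tick_labels.append(name)
        offset += spikes.shape[1]
    return np.concatenate(all_times), np.concatenate(all_rows), tick_pos, tick_labels


times = np.arange(4) * 0.1  # ms
fs_spikes = np.zeros((4, 2), dtype=bool)
fs_spikes[1, 0] = True
rs_spikes = np.zeros((4, 3), dtype=bool)
rs_spikes[3, 2] = True
ts, rows, tick_pos, tick_labels = stack_raster(times, {'FS': fs_spikes, 'RS': rs_spikes})
# ts is about [0.1, 0.3], rows is [0, 4], tick_pos is [1.0, 3.5]
```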
[8]:
def simulate_ing_net():
duration = 6e3
varied_rates = get_inputs(2., 3., 50., 350, 600, 1e3, duration)
net = INGNet(varied_rates, ext_weight=0.9)
runner = bp.DSRunner(
net,
inputs=net.change_freq,
monitors={'FS.V0': lambda tdi: net.fs_pop.V[0],
'FS2.V0': lambda tdi: net.fs2_pop.V[0],
'RS.V0': lambda tdi: net.rs_pop.V[0],
'FS.spike': lambda tdi: net.fs_pop.spike,
'FS2.spike': lambda tdi: net.fs2_pop.spike,
'RS.spike': lambda tdi: net.rs_pop.spike}
)
runner.run(duration)
visualize_simulation_results(times=runner.mon.ts,
spikes={'FS': (runner.mon['FS.spike'], 'inh'),
'FS2': (runner.mon['FS2.spike'], 'inh'),
'RS': (runner.mon['RS.spike'], 'exc')},
example_potentials={'FS': runner.mon['FS.V0'],
'FS2': runner.mon['FS2.V0'],
'RS': runner.mon['RS.V0']},
varied_rates=varied_rates.to_numpy(),
xlim=(2e3, 3.4e3), t_lfp_start=1e3, t_lfp_end=5e3)
[9]:
simulate_ing_net()

(Susin & Destexhe, 2021): PING Network for Generating Gamma Oscillation
Implementation of the paper:
Susin, Eduarda, and Alain Destexhe. “Integration, coincidence detection and resonance in networks of spiking neurons expressing gamma oscillations and asynchronous states.” PLoS Computational Biology 17.9 (2021): e1009416.
[1]:
import numpy as np
import matplotlib.pyplot as plt
from scipy.signal import kaiserord, lfilter, firwin, hilbert
import brainpy as bp
import brainpy.math as bm
[2]:
# Table 1: specific neuron model parameters
RS_par = dict(Vth=-40, delta=2., tau_ref=5., tau_w=500, a=4, b=20, C=150, gL=10, EL=-65, V_reset=-65,
E_e=0., E_i=-80.)
FS_par = dict(Vth=-47.5, delta=0.5, tau_ref=5., tau_w=500, a=0, b=0, C=150, gL=10, EL=-65, V_reset=-65,
E_e=0., E_i=-80.)
Ch_par = dict(Vth=-47.5, delta=0.5, tau_ref=1., tau_w=50, a=80, b=150, C=150, gL=10, EL=-58, V_reset=-65,
E_e=0., E_i=-80.)
[3]:
class AdEx(bp.NeuGroup):
def __init__(
self,
size,
# neuronal parameters
Vth=-40, delta=2., tau_ref=5., tau_w=500, a=4, b=20, C=150,
gL=10, EL=-65, V_reset=-65, V_sp_th=-30.,
# synaptic parameters
tau_e=1.5, tau_i=7.5, E_e=0., E_i=-80.,
# other parameters
name=None, method='exp_euler',
V_initializer=bp.init.Uniform(-65, -50),
w_initializer=bp.init.Constant(0.),
):
super(AdEx, self).__init__(size=size, name=name)
# neuronal parameters
self.Vth = Vth
self.delta = delta
self.tau_ref = tau_ref
self.tau_w = tau_w
self.a = a
self.b = b
self.C = C
self.gL = gL
self.EL = EL
self.V_reset = V_reset
self.V_sp_th = V_sp_th
# synaptic parameters
self.tau_e = tau_e
self.tau_i = tau_i
self.E_e = E_e
self.E_i = E_i
# neuronal variables
self.V = bp.init.variable_(V_initializer, self.num)
self.w = bp.init.variable_(w_initializer, self.num)
self.spike = bm.Variable(self.num, dtype=bool)
self.refractory = bm.Variable(self.num, dtype=bool)
self.t_last_spike = bm.Variable(bm.ones(self.num) * -1e8)
# synaptic variables
self.ge = bm.Variable(self.num)
self.gi = bm.Variable(self.num)
# integrator for the joint (V, w, ge, gi) system
self.integral = bp.odeint(bp.JointEq(self.dV, self.dw, self.dge, self.dgi), method=method)
def dge(self, ge, t):
return -ge / self.tau_e
def dgi(self, gi, t):
return -gi / self.tau_i
def dV(self, V, t, w, ge, gi, Iext=None):
I = ge * (self.E_e - V) + gi * (self.E_i - V)
if Iext is not None: I += Iext
dVdt = (self.gL * self.delta * bm.exp((V - self.Vth) / self.delta)
- w + self.gL * (self.EL - V) + I) / self.C
return dVdt
def dw(self, w, t, V):
dwdt = (self.a * (V - self.EL) - w) / self.tau_w
return dwdt
def update(self, tdi, x=None):
V, w, ge, gi = self.integral(self.V.value, self.w.value, self.ge.value, self.gi.value,
tdi.t, Iext=x, dt=tdi.dt)
refractory = (tdi.t - self.t_last_spike) <= self.tau_ref
V = bm.where(refractory, self.V.value, V)
spike = V >= self.V_sp_th
self.V.value = bm.where(spike, self.V_reset, V)
self.w.value = bm.where(spike, w + self.b, w)
self.ge.value = ge
self.gi.value = gi
self.spike.value = spike
self.refractory.value = bm.logical_or(refractory, spike)
self.t_last_spike.value = bm.where(spike, tdi.t, self.t_last_spike)
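Read off the right-hand sides of `dV`, `dw`, `dge` and `dgi`, the AdEx class integrates the system below. When V reaches `V_sp_th`, V is reset to `V_reset` and w jumps by b; for `tau_ref` after the last spike, V is held at its previous value. Note that `FS_par` sets a = b = 0, so FS cells have no adaptation current.

```latex
\begin{aligned}
C\,\frac{dV}{dt} &= g_L\,\Delta\,\exp\!\left(\frac{V - V_{th}}{\Delta}\right) - w + g_L\,(E_L - V) + g_e\,(E_e - V) + g_i\,(E_i - V) + I_{ext}\\
\tau_w\,\frac{dw}{dt} &= a\,(V - E_L) - w\\
\frac{dg_e}{dt} &= -\frac{g_e}{\tau_e}, \qquad \frac{dg_i}{dt} = -\frac{g_i}{\tau_i}\\
\text{if } V \ge V_{sp,th}:&\quad V \leftarrow V_{reset}, \qquad w \leftarrow w + b
\end{aligned}
```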
[4]:
class PINGNet(bp.Network):
def __init__(self, ext_varied_rates, ext_weight=4., method='exp_euler', dt=bm.get_dt()):
super(PINGNet, self).__init__()
self.num_exc = 20000
self.num_inh = 5000
self.exc_syn_tau = 1. # ms
self.inh_syn_tau = 7.5 # ms
self.exc_syn_weight = 5. # nS
self.inh_syn_weight = 3.34 # nS
self.num_delay_step = int(1.5 / dt)
self.ext_varied_rates = ext_varied_rates
# neuronal populations
RS_par_ = RS_par.copy()
FS_par_ = FS_par.copy()
RS_par_.update(Vth=-50, V_sp_th=-40)
FS_par_.update(Vth=-50, V_sp_th=-40)
self.rs_pop = AdEx(self.num_exc, tau_e=self.exc_syn_tau, tau_i=self.inh_syn_tau, method=method, **RS_par_)
self.fs_pop = AdEx(self.num_inh, tau_e=self.exc_syn_tau, tau_i=self.inh_syn_tau, method=method, **FS_par_)
self.ext_pop = bp.neurons.PoissonGroup(self.num_exc, freqs=bm.Variable(1))
# Poisson inputs
self.ext_to_FS = bp.synapses.Delta(self.ext_pop, self.fs_pop, bp.conn.FixedProb(0.02),
output=bp.synouts.CUBA(target_var='ge'),
g_max=ext_weight)
self.ext_to_RS = bp.synapses.Delta(self.ext_pop, self.rs_pop, bp.conn.FixedProb(0.02),
output=bp.synouts.CUBA(target_var='ge'),
g_max=ext_weight)
# synaptic projections
self.RS_to_FS = bp.synapses.Delta(self.rs_pop, self.fs_pop, bp.conn.FixedProb(0.02),
output=bp.synouts.CUBA(target_var='ge'),
g_max=self.exc_syn_weight,
delay_step=self.num_delay_step)
self.RS_to_RS = bp.synapses.Delta(self.rs_pop, self.rs_pop, bp.conn.FixedProb(0.02),
output=bp.synouts.CUBA(target_var='ge'),
g_max=self.exc_syn_weight,
delay_step=self.num_delay_step)
self.FS_to_RS = bp.synapses.Delta(self.fs_pop, self.rs_pop, bp.conn.FixedProb(0.02),
output=bp.synouts.CUBA(target_var='gi'),
g_max=self.inh_syn_weight,
delay_step=self.num_delay_step)
self.FS_to_FS = bp.synapses.Delta(self.fs_pop, self.fs_pop, bp.conn.FixedProb(0.02),
output=bp.synouts.CUBA(target_var='gi'),
g_max=self.inh_syn_weight,
delay_step=self.num_delay_step)
def change_freq(self, tdi):
self.ext_pop.freqs[0] = self.ext_varied_rates[tdi.i]
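Each projection in `PINGNet` pairs a `FixedProb(0.02)` connectivity with a delta synapse that adds `g_max` to the target conductance (`ge` or `gi`) after `num_delay_step` steps. The NumPy sketch below illustrates that idea under our own assumptions: a boolean connection matrix and a ring buffer of past spike vectors. `fixed_prob_conn` and `delayed_delta_increments` are hypothetical helpers and do not reproduce BrainPy's internal implementation.

```python
import numpy as np


def fixed_prob_conn(n_pre, n_post, prob, rng):
    """Boolean connectivity where each pre -> post pair exists with probability `prob`."""
    return rng.random((n_pre, n_post)) < prob


def delayed_delta_increments(pre_spikes, conn, g_max, delay_steps):
    """Conductance jumps delivered to each postsynaptic neuron at every step.

    pre_spikes: (T, n_pre) bool spike trains; conn: (n_pre, n_post) bool.
    A ring buffer of length delay_steps + 1 holds recent spike vectors,
    so a spike emitted at step t arrives at step t + delay_steps.
    """
    n_steps, n_pre = pre_spikes.shape
    size = delay_steps + 1
    buffer = np.zeros((size, n_pre), dtype=bool)
    increments = np.zeros((n_steps, conn.shape[1]))
    for t in range(n_steps):
        buffer[t % size] = pre_spikes[t]            # write the newest spikes
        delayed = buffer[(t - delay_steps) % size]  # spikes from delay_steps ago (zeros at first)
        increments[t] = g_max * (delayed.astype(float) @ conn)
    return increments


# One presynaptic spike at step 2, delay of 15 steps (1.5 ms at dt = 0.1 ms)
spikes = np.zeros((30, 1), dtype=bool)
spikes[2, 0] = True
conn = np.ones((1, 3), dtype=bool)
inc = delayed_delta_increments(spikes, conn, g_max=5.0, delay_steps=15)
# inc[17] is [5., 5., 5.] and every other row is zero

rs_to_fs = fixed_prob_conn(200, 50, 0.02, np.random.default_rng(0))  # about 2% of pairs connected
```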
[5]:
def get_inputs(c_low, c_high, t_transition, t_min_plato, t_max_plato, t_gap, t_total, dt=None):
dt = bm.get_dt() if dt is None else dt
t = 0
num_gap = int(t_gap / dt)
num_total = int(t_total / dt)
num_transition = int(t_transition / dt)
inputs = []
ramp_up = np.linspace(c_low, c_high, num_transition)
ramp_down = np.linspace(c_high, c_low, num_transition)
plato_base = np.ones(num_gap) * c_low
while t < num_total:
num_plato = int(np.random.uniform(low=t_min_plato, high=t_max_plato) / dt)  # draw a scalar plateau length
inputs.extend([plato_base, ramp_up, np.ones(num_plato) * c_high, ramp_down])
t += (num_gap + num_transition + num_plato + num_transition)
return bm.asarray(np.concatenate(inputs)[:num_total])
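`get_inputs` builds a repeating protocol: a baseline gap at `c_low`, a linear ramp up, a plateau at `c_high` whose length is drawn uniformly between `t_min_plato` and `t_max_plato`, and a ramp back down, truncated to `t_total`. A NumPy-only sketch of the same protocol with an explicit random generator (`ramp_plateau_rates` is our own name, and passing `rng` is our assumption to make runs reproducible):

```python
import numpy as np


def ramp_plateau_rates(c_low, c_high, t_transition, t_min_plateau, t_max_plateau,
                       t_gap, t_total, dt=0.1, rng=None):
    """Piecewise-linear rate protocol: gap, ramp up, random plateau, ramp down, repeated."""
    rng = np.random.default_rng() if rng is None else rng
    num_gap = int(t_gap / dt)
    num_total = int(t_total / dt)
    num_transition = int(t_transition / dt)
    ramp_up = np.linspace(c_low, c_high, num_transition)
    ramp_down = ramp_up[::-1]
    segments, n = [], 0
    while n < num_total:
        num_plateau = int(rng.uniform(t_min_plateau, t_max_plateau) / dt)  # scalar draw
        segments += [np.full(num_gap, c_low), ramp_up, np.full(num_plateau, c_high), ramp_down]
        n += num_gap + 2 * num_transition + num_plateau
    return np.concatenate(segments)[:num_total]


rates = ramp_plateau_rates(2., 3., 50., 150, 600, 1e3, 6e3, dt=0.1, rng=np.random.default_rng(0))
```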
[6]:
def signal_phase_by_Hilbert(signal, signal_time, low_cut, high_cut, sampling_space):
# sampling_space: in seconds (no units)
# signal_time: in seconds (no units)
# low_cut: in Hz (no units)(band to filter)
# high_cut: in Hz (no units)(band to filter)
signal = signal - np.mean(signal)
width = 5.0 # The desired width in Hz of the transition from pass to stop
ripple_db = 60.0 # The desired attenuation in the stop band, in dB.
sampling_rate = 1. / sampling_space
Nyquist = sampling_rate / 2.
num_taps, beta = kaiserord(ripple_db, width / Nyquist)
if num_taps % 2 == 0:
num_taps = num_taps + 1 # Numtaps must be odd
taps = firwin(num_taps, [low_cut / Nyquist, high_cut / Nyquist], window=('kaiser', beta), fs=2.0,
pass_zero=False, scale=True)  # fs=2.0 means the cutoffs are given as fractions of the Nyquist rate
filtered_signal = lfilter(taps, 1.0, signal)
delay = 0.5 * (num_taps - 1) / sampling_rate  # group delay (s) of the linear-phase FIR filter, for zero-phase correction
delay_index = (num_taps - 1) // 2  # the same delay in samples (num_taps is odd)
filtered_signal = filtered_signal[num_taps - 1:] # taking out the "corrupted" signal
# correcting the delay and taking out the "corrupted" signal part
filtered_time = signal_time[num_taps - 1:] - delay
cutted_signal = signal[(num_taps - 1 - delay_index): (len(signal) - (num_taps - 1 - delay_index))]
# --------------------------------------------------------------------------
# The Hilbert transform can be very slow when the signal has an odd length.
# If the length is odd, append one sample to the end of every vector
# related to the filtered signal:
if len(filtered_signal) % 2 != 0:  # if the length is odd
filtered_signal = np.append(filtered_signal, 0.)
filtered_time = np.append(filtered_time, filtered_time[-1] + sampling_space)  # next time stamp
cutted_signal = np.append(cutted_signal, 0.)
# --------------------------------------------------------------------------
ht_filtered_signal = hilbert(filtered_signal)
envelope = np.abs(ht_filtered_signal)
phase = np.angle(ht_filtered_signal) # The phase is between -pi and pi in radians
return filtered_time, filtered_signal, cutted_signal, envelope, phase
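The filter in `signal_phase_by_Hilbert` is a Kaiser-window band-pass FIR: `kaiserord` picks the number of taps and beta from the stop-band attenuation (60 dB) and the transition width (5 Hz), and an odd tap count gives a linear-phase filter whose delay is exactly (num_taps - 1) / 2 samples, which is what the delay correction removes. The NumPy-only sketch below reproduces that design with the windowed-sinc formula; `kaiser_bandpass` and `gain` are our own helpers, meant as an illustration rather than a drop-in replacement for `kaiserord`/`firwin`.

```python
import numpy as np


def kaiser_bandpass(low_cut, high_cut, fs, width=5.0, ripple_db=60.0):
    """Windowed-sinc band-pass FIR with a Kaiser window; returns (taps, delay in samples).

    Order and beta follow the Kaiser design formulas (beta formula valid for ripple_db > 50).
    """
    nyquist = fs / 2.0
    # Kaiser order estimate; the transition width is normalised to the Nyquist rate
    num_taps = int(np.ceil((ripple_db - 7.95) / 2.285 / (np.pi * width / nyquist) + 1))
    if num_taps % 2 == 0:
        num_taps += 1                              # odd length -> integer group delay
    beta = 0.1102 * (ripple_db - 8.7)
    n = np.arange(num_taps) - (num_taps - 1) / 2.0
    f1, f2 = low_cut / fs, high_cut / fs           # cutoffs in cycles per sample
    ideal = 2 * f2 * np.sinc(2 * f2 * n) - 2 * f1 * np.sinc(2 * f1 * n)  # ideal band-pass
    taps = ideal * np.kaiser(num_taps, beta)
    # scale for unit gain at the band centre (similar in spirit to firwin(scale=True))
    taps /= abs(np.sum(taps * np.exp(-2j * np.pi * (f1 + f2) / 2 * n)))
    return taps, (num_taps - 1) // 2


def gain(taps, freq, fs):
    """Magnitude of the filter's frequency response at `freq` Hz."""
    k = np.arange(taps.size)
    return abs(np.sum(taps * np.exp(-2j * np.pi * freq / fs * k)))


taps, delay = kaiser_bandpass(30.0, 50.0, fs=1000.0)
# gain(taps, 40.0, 1000.0) is 1; gains at 10 Hz and 100 Hz are close to 0
```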
[7]:
def visualize_simulation_results(times, spikes, example_potentials, varied_rates,
xlim=None, t_lfp_start=None, t_lfp_end=None, filename=None):
fig, gs = bp.visualize.get_figure(7, 1, 1, 12)
# 1. input firing rate
ax = fig.add_subplot(gs[0])
plt.plot(times, varied_rates)
if xlim is None:
xlim = (0, times[-1])
ax.set_xlim(*xlim)
ax.set_xticks([])
ax.set_ylabel('External\nRate (Hz)')
# 2. raster plot of all populations
ax = fig.add_subplot(gs[1: 3])
i = 0
y_ticks = ([], [])
for key, (sp_matrix, sp_type) in spikes.items():
iis, sps = np.where(sp_matrix)
tts = times[iis]
plt.plot(tts, sps + i, '.', markersize=1, label=key)
y_ticks[0].append(i + sp_matrix.shape[1] / 2)
y_ticks[1].append(key)
i += sp_matrix.shape[1]
ax.set_xlim(*xlim)
ax.set_xlabel('')
ax.set_ylabel('Neuron Index')
ax.set_xticks([])
ax.set_yticks(*y_ticks)
# ax.legend()
# 3. example membrane potential
ax = fig.add_subplot(gs[3: 5])
for key, potential in example_potentials.items():
vs = np.where(spikes[key][0][:, 0], 0, potential)
plt.plot(times, vs, label=key)
ax.set_xlim(*xlim)
ax.set_xticks([])
ax.set_ylabel('V (mV)')
ax.legend()
# 4. LFP
ax = fig.add_subplot(gs[5:7])
ax.set_xlim(*xlim)
t1 = int(t_lfp_start / bm.get_dt()) if t_lfp_start is not None else 0
t2 = int(t_lfp_end / bm.get_dt()) if t_lfp_end is not None else len(times)
times = times[t1: t2]
lfp = 0
for sp_matrix, sp_type in spikes.values():
lfp += bp.measure.unitary_LFP(times, sp_matrix[t1: t2], sp_type)
phase_ts, filtered, cutted, envelope, _ = signal_phase_by_Hilbert(bm.as_numpy(lfp), times * 1e-3, 30, 50,
bm.get_dt() * 1e-3)
plt.plot(phase_ts * 1e3, cutted, color='k', label='Raw LFP')
plt.plot(phase_ts * 1e3, filtered, color='orange', label="Filtered LFP (30-50 Hz)")
plt.plot(phase_ts * 1e3, envelope, color='purple', label="Hilbert Envelope")
plt.legend(loc='best')
plt.xlabel('Time (ms)')
# save or show
if filename:
plt.savefig(filename, dpi=500)
plt.show()
[8]:
def simulate_ping_net():
duration = 6e3
varied_rates = get_inputs(2., 3., 50., 150, 600, 1e3, duration)
net = PINGNet(varied_rates, ext_weight=4.)
runner = bp.DSRunner(
net,
inputs=net.change_freq,
monitors={'FS.V0': lambda tdi: net.fs_pop.V[0],
'RS.V0': lambda tdi: net.rs_pop.V[0],
'FS.spike': lambda tdi: net.fs_pop.spike,
'RS.spike': lambda tdi: net.rs_pop.spike}
)
runner.run(duration)
visualize_simulation_results(times=runner.mon.ts,
spikes={'FS': (runner.mon['FS.spike'], 'inh'),
'RS': (runner.mon['RS.spike'], 'exc')},
example_potentials={'FS': runner.mon['FS.V0'],
'RS': runner.mon['RS.V0']},
varied_rates=varied_rates.to_numpy(),
xlim=(2e3, 3.4e3), t_lfp_start=1e3, t_lfp_end=5e3)
[9]:
simulate_ping_net()

(Joglekar et al., 2018): Inter-areal Balanced Amplification Figure 1
Implementation of Figure 1 of:
Joglekar, Madhura R., et al. “Inter-areal balanced amplification enhances signal propagation in a large-scale circuit model of the primate cortex.” Neuron 98.1 (2018): 222-234.
[1]:
import brainpy as bp
import brainpy.math as bm
from jax import vmap, jit
import numpy as np
import matplotlib.pyplot as plt
[2]:
wIE = 4 + 2.0 / 7.0 # synaptic weight E to I
wII = wIE * 1.1 # synaptic weight I to I
[3]:
class LocalCircuit(bp.DynamicalSystem):
r"""The model is given by:
.. math::
\tau_E \frac{dv_E}{dt} = -v_E + a1 * [I_E]_+ + a2 * [I_I]_+ \\
\tau_I \frac{dv_I}{dt} = -v_I + a3 * [I_E]_+ + a4 * [I_I]_+
where :math:`[I_E]_+=max(I_E, 0)`. :math:`v_E` and :math:`v_I` denote the firing rates
of the excitatory and inhibitory populations respectively, :math:`\tau_E` and
:math:`\tau_I` are the corresponding intrinsic time constants.
"""
def __init__(self, wEE, wEI, tau_e=0.02, tau_i=0.02):
super(LocalCircuit, self).__init__()
# parameters
self.gc = bm.asarray([[wEE, -wEI],
[wIE, -wII]])
self.tau = bm.asarray([tau_e, tau_i]) # time constant [s]
# variables
self.state = bm.Variable(bm.asarray([1., 0.]))
def update(self, tdi):
self.state += (-self.state + self.gc @ self.state) / self.tau * tdi.dt
self.state.value = bm.maximum(self.state, 0.)
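`LocalCircuit.update` is a forward-Euler step of tau dr/dt = -r + W r with the rates clipped at zero, starting from r = [1, 0]. A NumPy sketch of the same integration (`local_circuit_rates` is our own helper) reproduces the weak versus strong local balanced amplification comparison of Figure 1B: with the strong-LBA weights the excitatory rate rises much higher before inhibition pulls it back down.

```python
import numpy as np

W_IE = 4 + 2.0 / 7.0   # E -> I weight, as in cell [2]
W_II = W_IE * 1.1      # I -> I weight


def local_circuit_rates(w_ee, w_ei, duration=0.6, dt=1e-4, tau=0.02):
    """Forward-Euler integration of tau * dr/dt = -r + W r, with rates clipped at 0."""
    W = np.array([[w_ee, -w_ei],
                  [W_IE, -W_II]])
    r = np.array([1.0, 0.0])                 # initial (E, I) rates, as in LocalCircuit
    history = np.empty((int(round(duration / dt)), 2))
    for k in range(history.shape[0]):
        r = np.maximum(r + (-r + W @ r) / tau * dt, 0.0)
        history[k] = r
    return history


weak = local_circuit_rates(4.45, 4.7)    # weak LBA weights from Figure 1B
strong = local_circuit_rates(6.0, 6.7)   # strong LBA weights from Figure 1B
peak_weak, peak_strong = weak[:, 0].max(), strong[:, 0].max()  # strong gives the larger transient
```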
[4]:
def simulate(wEE, wEI, duration, dt=0.0001, numpy_mon_after_run=True):
model = LocalCircuit(wEE=wEE, wEI=wEI)
runner = bp.DSRunner(model, monitors=['state'], dt=dt,
numpy_mon_after_run=numpy_mon_after_run,
progress_bar=False)
runner.run(duration)
return runner.mon.state
[5]:
@jit
@vmap
def get_max_amplitude(wEE, wEI):
states = simulate(wEE, wEI, duration=2., dt=0.0001, numpy_mon_after_run=False)
return states[:, 0].max()
[6]:
@jit
@vmap
def get_eigen_value(wEE, wEI):
A = bm.array([[wEE, -wEI], [wIE, -wII]])
w, _ = bm.linalg.eig(A)
return w.real.max()
[7]:
# =================== Figure 1B in the paper ==================================#
length, step = 0.6, 0.0001
wEEweak = 4.45 # synaptic weight E to E weak LBA
wEIweak = 4.7 # synaptic weight I to E weak LBA
wEEstrong = 6 # synaptic weight E to E strong LBA
wEIstrong = 6.7 # synaptic weight I to E strong LBA
weak = simulate(wEEweak, wEIweak, duration=length, dt=step)
strong = simulate(wEEstrong, wEIstrong, duration=length, dt=step)
fig, gs = bp.visualize.get_figure(1, 1, 5, 8)
ax = fig.add_subplot(gs[0, 0])
ax.plot(np.arange(step, length, step), weak[:, 0], 'g')
ax.plot(np.arange(step, length, step), strong[:, 0], 'm')
ax.set_ylabel('Excitatory rate (Hz)', fontsize='x-large')
ax.set_xlabel('Time (s)', fontsize='x-large')
ax.set_ylim([0, 6])
ax.set_xlim([0, length])
ax.set_yticks([0, 2, 4, 6])
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)
ax.legend(['weak LBA', 'strong LBA'], prop={'size': 15}, frameon=False)
plt.show()

[8]:
# =================== Figure 1C in the paper ==================================#
all_wEE = bm.arange(4, 6.5, .025)
all_wEI = bm.arange(4.5, 7, .025)
shape = (len(all_wEE), len(all_wEI))
max_amplitude = bm.ones(shape) * -5
all_wEE, all_wEI = bm.meshgrid(all_wEE, all_wEI)
# select all parameter pairs that lead to a stable model
max_eigen_values = get_eigen_value(all_wEE.flatten(), all_wEI.flatten())
selected_ids = bm.where(max_eigen_values.reshape(shape) < 1)
# get maximum amplitude of each stable model
num = 100
for i in range(0, selected_ids[0].size, num):
ids = (selected_ids[0][i: i + num], selected_ids[1][i: i + num])
max_amps = get_max_amplitude(all_wEE[ids], all_wEI[ids])
max_amplitude[ids] = max_amps
fig, gs = bp.visualize.get_figure(1, 1, 5, 8)
ax = fig.add_subplot(gs[0, 0])
X, Y = bm.as_numpy(all_wEE), bm.as_numpy(all_wEI)
levels = np.linspace(-5, 8, 20)
plt.contourf(X, Y, max_amplitude.numpy(), levels=levels, cmap='hot')
plt.contourf(X, Y, max_amplitude.numpy(), levels=np.linspace(-5, 0), colors='silver')
plt.plot(wEEstrong, wEIstrong, 'mx')
plt.plot(wEEweak, wEIweak, 'gx')
plt.ylabel('Local I to E coupling', fontsize=15)
plt.xlabel('Local E to E coupling', fontsize=15)
plt.show()
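The stability test `max_eigen_values < 1` follows from linearising tau dr/dt = -r + W r: the system matrix is (W - I) / tau, so the fixed point is stable when every eigenvalue of W has real part below 1. NumPy's `eigvals` accepts stacks of matrices, so the same grid test can be vectorised without `vmap`; `max_real_eigenvalue` below is our own helper.

```python
import numpy as np

W_IE = 4 + 2.0 / 7.0
W_II = W_IE * 1.1


def max_real_eigenvalue(w_ee, w_ei):
    """Largest real eigenvalue part of [[wEE, -wEI], [wIE, -wII]] for arrays of (wEE, wEI)."""
    w_ee, w_ei = np.broadcast_arrays(np.asarray(w_ee, float), np.asarray(w_ei, float))
    mats = np.empty(w_ee.shape + (2, 2))
    mats[..., 0, 0] = w_ee
    mats[..., 0, 1] = -w_ei
    mats[..., 1, 0] = W_IE
    mats[..., 1, 1] = -W_II
    return np.linalg.eigvals(mats).real.max(axis=-1)  # eigvals works on stacks of matrices


wEE_grid, wEI_grid = np.meshgrid(np.arange(4, 6.5, .025), np.arange(4.5, 7, .025))
stable = max_real_eigenvalue(wEE_grid, wEI_grid) < 1   # same criterion as the cell above
```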

(Joglekar et al., 2018): Inter-areal Balanced Amplification Figure 2
Implementation of Figure 2 of:
Joglekar, Madhura R., et al. “Inter-areal balanced amplification enhances signal propagation in a large-scale circuit model of the primate cortex.” Neuron 98.1 (2018): 222-234.
[1]:
import brainpy as bp
import brainpy.math as bm
import matplotlib.pyplot as plt
import numpy as np
from jax import vmap, jit
from scipy import io as sio
from functools import partial
[2]:
bm.set_dt(dt=2e-4)
[3]:
class ThreshLinearModel(bp.NeuGroup):
def __init__(
self,
size, hier, fln, inp_idx, inp_data,
eta=.68, betaE=.066, betaI=.351, tauE=2e-2, tauI=1e-2,
omegaEE=24.3, omegaEI=19.7, omegaIE=12.2, omegaII=12.5,
muIE=25.3, muEE=28., noiseE=None, noiseI=None, seed=None,
desired_ss=None, name=None
):
super(ThreshLinearModel, self).__init__(size, name=name)
# parameters
self.hier = hier
self.fln = fln
self.eta = bp.init.parameter(eta, self.num, False)
self.betaE = bp.init.parameter(betaE, self.num, False)
self.betaI = bp.init.parameter(betaI, self.num, False)
self.tauE = bp.init.parameter(tauE, self.num, False)
self.tauI = bp.init.parameter(tauI, self.num, False)
self.omegaEE = bp.init.parameter(omegaEE, self.num, False)
self.omegaEI = bp.init.parameter(omegaEI, self.num, False)
self.omegaIE = bp.init.parameter(omegaIE, self.num, False)
self.omegaII = bp.init.parameter(omegaII, self.num, False)
self.muIE = bp.init.parameter(muIE, self.num, False)
self.muEE = bp.init.parameter(muEE, self.num, False)
self.desired_ss = desired_ss
self.seed = seed
self.noiseE = bp.init.parameter(noiseE, self.num, True)
self.noiseI = bp.init.parameter(noiseI, self.num, True)
self.inp_idx, self.inp_data = inp_idx, inp_data
# Synaptic weights for intra-areal connections
self.wEE_intra = betaE * omegaEE * (1 + eta * hier)
self.wIE_intra = betaI * omegaIE * (1 + eta * hier)
self.wEI_intra = -betaE * omegaEI
self.wII_intra = -betaI * omegaII
# Synaptic weights for inter-areal connections
self.wEE_inter = bm.asarray(fln.T * (betaE * muEE * (1 + eta * hier))).T
self.wIE_inter = bm.asarray(fln.T * (betaI * muIE * (1 + eta * hier))).T
# Variables
self.re = bm.Variable(bm.zeros(self.num))
self.ri = bm.Variable(bm.zeros(self.num))
if not (self.noiseE is None and self.noiseI is None):
self.rng = bm.random.RandomState(seed)
else:
self.rng = None
# get background input
if desired_ss is None:
self.bgE = bm.zeros(self.num)
self.bgI = bm.zeros(self.num)
else:
if len(desired_ss) != 2:
raise ValueError
if len(desired_ss[0]) != self.num:
raise ValueError
if len(desired_ss[1]) != self.num:
raise ValueError
self.bgE, self.bgI = self.get_background_current(*desired_ss)
def get_background_current(self, ssE, ssI):
# total weights
wEEaux = bm.diag(-1 + self.wEE_intra) + self.wEE_inter
wEIaux = self.wEI_intra * bm.eye(self.num)
wIEaux = bm.diag(self.wIE_intra) + self.wIE_inter
wIIaux = (-1 + self.wII_intra) * bm.eye(self.num)
# temp matrices to create matrix A
A1 = bm.concatenate((wEEaux, wEIaux), axis=1)
A2 = bm.concatenate((wIEaux, wIIaux), axis=1)
A = bm.concatenate([A1, A2])
ss = bm.concatenate((ssE, ssI))
cur = -bm.dot(A, ss)
self.re.value, self.ri.value = ssE, ssI
# state = bm.linalg.lstsq(-A, cur, rcond=None)[0]
# self.re.value, self.ri.value = bm.split(state, 2)
return bm.split(cur, 2)
def reset(self):
if self.desired_ss is None:
self.re[:] = 0.
self.ri[:] = 0.
else:
self.re.value = self.desired_ss[0]
self.ri.value = self.desired_ss[1]
if self.rng is not None:
self.rng.seed(self.seed)
def update(self, tdi):
# E population
Ie = bm.dot(self.wEE_inter, self.re) + self.wEE_intra * self.re
Ie += self.wEI_intra * self.ri + self.bgE
if self.noiseE is not None:
Ie += self.noiseE * self.rng.randn(self.num) / bm.sqrt(tdi.dt)
Ie[self.inp_idx] += self.inp_data[bm.asarray(tdi.t / tdi.dt, dtype=int).value]
self.re.value = bm.maximum(self.re + (-self.re + bm.maximum(Ie, 10.)) / self.tauE * tdi.dt, 0)
# I population
Ii = bm.dot(self.wIE_inter, self.re) + self.wIE_intra * self.re
Ii += self.wII_intra * self.ri + self.bgI
if self.noiseI is not None:
Ii += self.noiseI * self.rng.randn(self.num) / bm.sqrt(tdi.dt)
self.ri.value = bm.maximum(self.ri + (-self.ri + bm.maximum(Ii, 35.)) / self.tauI * tdi.dt, 0)
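`get_background_current` chooses the background inputs so that the desired rates r* = (ssE, ssI) are a fixed point of the noise-free, input-free dynamics: setting -r* + W r* + bg = 0 gives bg = -(W - 1) r*, where W collects the intra- and inter-areal weights built in `__init__`. The NumPy sketch below checks this on a small random network; the FLN matrix and hierarchy values are hypothetical placeholders (the notebook loads the real ones from `subgraphData.mat`), and the helper names are our own.

```python
import numpy as np


def build_weights(fln, hier, eta=.68, beta_e=.066, beta_i=.351, omega_ee=24.3,
                  omega_ei=19.7, omega_ie=12.2, omega_ii=12.5, mu_ie=25.3, mu_ee=28.):
    """Intra- and inter-areal weights, following ThreshLinearModel.__init__."""
    scale = 1 + eta * hier
    return dict(
        ee_intra=beta_e * omega_ee * scale,
        ie_intra=beta_i * omega_ie * scale,
        ei_intra=-beta_e * omega_ei,
        ii_intra=-beta_i * omega_ii,
        ee_inter=fln * (beta_e * mu_ee * scale)[:, None],  # row i scaled by target area i
        ie_inter=fln * (beta_i * mu_ie * scale)[:, None],
    )


def background_current(w, ss_e, ss_i):
    """Solve 0 = -r* + W r* + bg for bg, as in get_background_current."""
    n = ss_e.size
    A = np.block([
        [np.diag(-1 + w['ee_intra']) + w['ee_inter'], w['ei_intra'] * np.eye(n)],
        [np.diag(w['ie_intra']) + w['ie_inter'], (-1 + w['ii_intra']) * np.eye(n)],
    ])
    bg = -A @ np.concatenate([ss_e, ss_i])
    return bg[:n], bg[n:]


def euler_step(w, re, ri, bg_e, bg_i, tau_e=2e-2, tau_i=1e-2, dt=2e-4):
    """One noise-free, input-free step of ThreshLinearModel.update (floors on Ie/Ii omitted)."""
    ie = w['ee_inter'] @ re + w['ee_intra'] * re + w['ei_intra'] * ri + bg_e
    ii = w['ie_inter'] @ re + w['ie_intra'] * re + w['ii_intra'] * ri + bg_i
    return re + (-re + ie) / tau_e * dt, ri + (-ri + ii) / tau_i * dt


rng = np.random.default_rng(0)
n_area = 5
fln = rng.random((n_area, n_area)) * 0.1   # hypothetical FLN matrix, for illustration only
np.fill_diagonal(fln, 0.0)
hier = np.linspace(0.0, 1.0, n_area)       # hypothetical normalised hierarchy values
w = build_weights(fln, hier)
ss_e, ss_i = np.full(n_area, 10.0), np.full(n_area, 35.0)
bg_e, bg_i = background_current(w, ss_e, ss_i)
re, ri = euler_step(w, ss_e, ss_i, bg_e, bg_i)   # stays at (10, 35) up to rounding
```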
[4]:
def simulate(num_node, muEE, fln, hier, input4v1, duration):
model = ThreshLinearModel(int(num_node),
hier=hier, fln=fln,
inp_idx=0, inp_data=input4v1, muEE=muEE,
desired_ss=(bm.ones(num_node) * 10, bm.ones(num_node) * 35))
runner = bp.dyn.DSRunner(model, monitors=['re'], progress_bar=False, numpy_mon_after_run=False)
runner.run(duration)
return runner.mon.ts, runner.mon.re
[5]:
def show_firing_rates(ax, hist_t, hist_re, show_duration=None, title=None):
hist_t = bm.as_numpy(hist_t)
hist_re = bm.as_numpy(hist_re)
if show_duration is None:
i_start, i_end = (1.75, 5.)
else:
i_start, i_end = show_duration
i_start = round(i_start / bm.get_dt())
i_end = round(i_end / bm.get_dt())
# visualization
rateV1 = np.maximum(1e-2, hist_re[i_start:i_end, 0] - hist_re[i_start, 0])
rate24 = np.maximum(1e-2, hist_re[i_start:i_end, -1] - hist_re[i_start, -1])
ax.semilogy(hist_t[i_start:i_end], rateV1, 'dodgerblue')
ax.semilogy(hist_t[i_start:i_end], rate24, 'forestgreen')
ax.set_ylim([1e-2, 1e2 + 100])
# ax.set_xlim([-0.25, 2.25])
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)
ax.set_ylabel('Change in firing rate (Hz)', fontsize='large')
ax.set_xlabel('Time (s)', fontsize='large')
if title:
ax.set_title(title)
[6]:
def show_maximum_rate(ax, muEE_range, peak_rates, title=''):
ax.semilogy(muEE_range[4], np.squeeze(peak_rates[4]), 'cornflowerblue', marker="o",
markersize=12, markerfacecolor='w')
ax.semilogy(muEE_range, np.squeeze(peak_rates), 'cornflowerblue', marker=".", markersize=10)
ax.semilogy(muEE_range[3], np.squeeze(peak_rates[3]), 'cornflowerblue', marker="x", markersize=15)
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)
ax.set_ylim([1e-6, 1e3])
ax.set_ylabel('Maximum rate in 24c (Hz)', fontsize='large')
ax.set_xlabel('Global E to E coupling', fontsize='large')
if title:
ax.set_title(title)
[7]:
def figure2B():
# data
data = sio.loadmat('Joglekar_2018_data/subgraphData.mat')
num_node = data['nNodes'][0, 0]
hier = bm.asarray(data['hierVals'].squeeze() / max(data['hierVals'])) # normalize hierarchical position
fln = bm.asarray(data['flnMat'])
# inputs
ampl = 21.8 * 1.9
inputs, duration = bp.inputs.section_input([0, ampl, 0], [2., 0.25, 7.75], return_length=True)
# Fig 2B
ax = plt.subplot(1, 2, 1)
times, res = simulate(int(num_node), fln=fln, hier=hier, muEE=34, input4v1=inputs, duration=duration)
show_firing_rates(ax, hist_t=times, hist_re=res, show_duration=(1.75, 5.), title=r'$\mu$EE=34')
ax = plt.subplot(1, 2, 2)
times, res = simulate(int(num_node), fln=fln, hier=hier, muEE=36, input4v1=inputs, duration=duration)
show_firing_rates(ax, hist_t=times, hist_re=res, show_duration=(1.75, 5.), title=r'$\mu$EE=36')
plt.show()
[8]:
figure2B()

[9]:
def figure2C():
# data
data = sio.loadmat('Joglekar_2018_data/subgraphData.mat')
hier = bm.asarray(data['hierVals'].squeeze() / max(data['hierVals'])) # normalize hierarchical position
fln = bm.asarray(data['flnMat'])
# inputs
ampl = 21.8 * 1.9
inputs, duration = bp.inputs.section_input([0, ampl, 0], [2., 0.25, 7.75], return_length=True)
all_muEE = bm.arange(28, 44, 2)
i_start, i_end = (round(1.75 / bm.get_dt()), round(5. / bm.get_dt()))
@jit
@partial(vmap, in_axes=(0, None))
def peak_firing_rate(muEE, fln):
_, res = simulate(fln.shape[0], fln=fln, hier=hier, muEE=muEE,
input4v1=inputs, duration=duration)
return (res[i_start:i_end, -1] - res[i_start, -1]).max()
# with feedback
ax = plt.subplot(1, 2, 1)
area2peak = peak_firing_rate(all_muEE, fln)
area2peak = bm.where(area2peak > 500, 500, area2peak)
show_maximum_rate(ax, all_muEE.to_numpy(), area2peak.to_numpy(), title='With feedback')
# without feedback
ax = plt.subplot(1, 2, 2)
area2peak = peak_firing_rate(all_muEE, bm.tril(fln))
area2peak = bm.where(area2peak > 500, 500, area2peak)
show_maximum_rate(ax, all_muEE.to_numpy(), area2peak.to_numpy(), title='Without feedback')
plt.show()
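The "without feedback" panel passes `bm.tril(fln)` to the model. Since the input current is `wEE_inter @ re`, row i of the FLN matrix is the target area and column j the source, so keeping only the lower triangle (j <= i) removes every projection from a higher-index area. This reads as "feedforward only" under the assumption that areas are ordered from low to high hierarchy (V1 at index 0 and 24c last, as the input index and the plotted areas suggest). A toy illustration with a hypothetical 3-area matrix:

```python
import numpy as np

# Hypothetical 3-area FLN matrix: fln[i, j] is the projection from source area j to target area i,
# with areas assumed to be ordered from low (index 0) to high hierarchy.
fln = np.array([[0.0, 0.3, 0.1],
                [0.5, 0.0, 0.2],
                [0.2, 0.4, 0.0]])
feedforward_only = np.tril(fln)     # keep j <= i: inputs from lower (or the same) areas
feedback = fln - feedforward_only   # the removed entries: inputs from higher areas
```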
[10]:
figure2C()

[11]:
sStrong = 1000
sWeak = 100
def show_multiple_area_rates(axes, times, rates, plot_duration, color='green'):
t_start, t_end = plot_duration
i_start, i_end = round(t_start / bm.get_dt()), round(t_end / bm.get_dt())
areas2plot = [0, 2, 5, 7, 8, 12, 16, 18, 28]
for i, j in enumerate(areas2plot):
ax = axes[i]
ax.plot(times[i_start:i_end] - 100, rates[i_start:i_end, j] - rates[i_start, j], color)
ax.set_xlim([-98.25, -96])
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)
ax.spines['bottom'].set_visible(False)
plt.setp(ax.get_xticklabels(), visible=False)
ax.tick_params(axis='both', which='both', length=0)
peak = (rates[i_start:i_end, j] - rates[i_start, j]).max()
if j == 0:
ax.set_ylim([0, 140])
ax.set_yticks([round(sWeak * peak, 1) / sWeak])
ax.set_title('Weak GBA')
else:
ax.set_ylim([0, 1.2 * peak])
ax.set_yticks([round(sWeak * peak, 1) / sWeak])
if i == 4:
ax.set_ylabel('Change in firing rate (Hz)', fontsize='large')
[12]:
def figure3BD():
# data
data = sio.loadmat('Joglekar_2018_data/subgraphData.mat')
hier = bm.asarray(data['hierVals'].squeeze() / max(data['hierVals']))
fln = bm.asarray(data['flnMat'])
num_node = fln.shape[0]
fig, axes = plt.subplots(9, 2)
# weak GBA
muEE = 33.7
omegaEI = 19.7
ampl = 22.05 * 1.9
inputs, duration = bp.inputs.section_input([0, ampl, 0], [2., 0.25, 7.75], return_length=True)
model = ThreshLinearModel(fln.shape[0],
hier=hier, fln=fln, inp_idx=0, inp_data=inputs,
muEE=muEE, omegaEI=omegaEI,
desired_ss=(bm.ones(num_node) * 10, bm.ones(num_node) * 35))
runner1 = bp.dyn.DSRunner(model, monitors=['re'], )
runner1.run(duration)
show_multiple_area_rates(times=runner1.mon.ts,
rates=runner1.mon.re,
plot_duration=(1.25, 5.),
axes=[axis[0] for axis in axes])
# strong GBA
muEE = 51.5
omegaEI = 25.2
ampl = 11.54 * 1.9
inputs, duration = bp.inputs.section_input([0, ampl, 0], [2., 0.25, 7.75], return_length=True)
model = ThreshLinearModel(fln.shape[0],
hier=hier, fln=fln, inp_idx=0, inp_data=inputs,
muEE=muEE, omegaEI=omegaEI,
desired_ss=(bm.ones(num_node) * 10, bm.ones(num_node) * 35))
runner2 = bp.dyn.DSRunner(model, monitors=['re'], )
runner2.run(duration)
show_multiple_area_rates(times=runner1.mon.ts,
rates=runner1.mon.re,
plot_duration=(1.75, 5.),
axes=[axis[1] for axis in axes],
color='green')
show_multiple_area_rates(times=runner2.mon.ts,
rates=runner2.mon.re,
plot_duration=(1.75, 5.),
axes=[axis[1] for axis in axes],
color='purple')
plt.tight_layout()
plt.show()
[13]:
figure3BD()

[14]:
def figure3E():
# data
data = sio.loadmat('Joglekar_2018_data/subgraphData.mat')
hier = bm.asarray(data['hierVals'].squeeze() / max(data['hierVals']))
fln = bm.asarray(data['flnMat'])
num_node = fln.shape[0]
i_start, i_end = round(1.75 / bm.get_dt()), round(5. / bm.get_dt())
# weak GBA
muEE = 33.7
omegaEI = 19.7
ampl = 22.05 * 1.9
inputs, duration = bp.inputs.section_input([0, ampl, 0], [2., 0.25, 7.75], return_length=True)
model = ThreshLinearModel(fln.shape[0], hier=hier, fln=fln, inp_idx=0, inp_data=inputs, muEE=muEE,
omegaEI=omegaEI, desired_ss=(bm.ones(num_node) * 10, bm.ones(num_node) * 35))
runner1 = bp.dyn.DSRunner(model, monitors=['re'], )
runner1.run(duration)
peak1 = (runner1.mon.re[i_start: i_end] - runner1.mon.re[i_start]).max(axis=0)
# strong GBA
muEE = 51.5
omegaEI = 25.2
ampl = 11.54 * 1.9
inputs, duration = bp.inputs.section_input([0, ampl, 0], [2., 0.25, 7.75], return_length=True)
model = ThreshLinearModel(fln.shape[0], hier=hier, fln=fln, inp_idx=0, inp_data=inputs, muEE=muEE,
omegaEI=omegaEI, desired_ss=(bm.ones(num_node) * 10, bm.ones(num_node) * 35))
runner2 = bp.dyn.DSRunner(model, monitors=['re'], )
runner2.run(duration)
peak2 = (runner2.mon.re[i_start: i_end] - runner2.mon.re[i_start]).max(axis=0)
# visualization
fig, ax = plt.subplots()
ax.semilogy(np.arange(0, num_node), 100 * peak1 / peak1[0], 'green', marker=".", markersize=5)
ax.semilogy(np.arange(0, num_node), 100 * peak2 / peak2[0], 'purple', marker=".", markersize=5)
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)
ax.set_ylim([1e-4, 1e2])
ax.set_xlim([0, num_node])
ax.legend(['weak GBA', 'strong GBA'], prop={'size': 10}, loc='upper right',
bbox_to_anchor=(1.0, 1.2), frameon=False)
ax.set_xticks(np.arange(0, num_node))
ax.set_xticklabels(data['areaList'].squeeze(), rotation='vertical', fontsize=10)
plt.show()
[15]:
figure3E()

[16]:
def figure3F():
# data
data = sio.loadmat('Joglekar_2018_data/subgraphData.mat')
hier = bm.asarray(data['hierVals'].squeeze() / max(data['hierVals']))
fln = bm.asarray(data['flnMat'])
num_node = fln.shape[0]
i_start, i_end = round(1.75 / bm.get_dt()), round(3.5 / bm.get_dt())
# input
ampl = 22.05 * 1.9
inputs, duration = bp.inputs.section_input([0, ampl, 0], [2., 0.25, 7.75], return_length=True)
@partial(vmap, in_axes=(0, None))
def maximum_rate(muEE, omegaEI=None):
if omegaEI is None: omegaEI = 19.7 + (muEE - 33.7) * 55 / 178
model = ThreshLinearModel(num_node, hier=hier, fln=fln, inp_idx=0, inp_data=inputs, muEE=muEE,
omegaEI=omegaEI, desired_ss=(bm.ones(num_node) * 10, bm.ones(num_node) * 35))
runner = bp.dyn.DSRunner(model, monitors=['re'], progress_bar=False, numpy_mon_after_run=False)
runner.run(duration)
return (runner.mon.re[i_start: i_end, -1] - runner.mon.re[i_start, -1]).max()
# visualization
muEErange = bm.arange(20, 52, 2)
peaks_with_gba = maximum_rate(muEErange, None)
peaks_without_gba = maximum_rate(muEErange, 19.7)
peaks_with_gba = bm.where(peaks_with_gba > 500, 500, peaks_with_gba)
peaks_without_gba = bm.where(peaks_without_gba > 500, 500, peaks_without_gba)
fig, ax = plt.subplots()
ax.semilogy(muEErange, peaks_without_gba.to_numpy(), 'cornflowerblue', marker=".", markersize=5)
ax.semilogy(muEErange, peaks_with_gba.to_numpy(), 'black', marker=".", markersize=5)
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)
ax.set_ylim([1e-8, 1e4])
ax.set_xlim([20, 50])
ax.set_ylabel('Maximum rate in 24c (Hz)', fontsize='large')
ax.set_xlabel('Global E to E coupling', fontsize='large')
plt.show()
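In `maximum_rate`, the default `omegaEI` grows linearly with `muEE`. Since 55/178 = 5.5/17.8, this is the straight line through the weak-GBA pair (muEE, omegaEI) = (33.7, 19.7) and the strong-GBA pair (51.5, 25.2) used in `figure3BD` and `figure3E`, so local inhibition is scaled up together with the global excitatory coupling:

```latex
\omega_{EI}(\mu_{EE}) = 19.7 + (\mu_{EE} - 33.7)\,\frac{55}{178}
= 19.7 + (\mu_{EE} - 33.7)\,\frac{25.2 - 19.7}{51.5 - 33.7},
\qquad \omega_{EI}(51.5) = 19.7 + 17.8 \cdot \frac{5.5}{17.8} = 25.2
```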
[17]:
figure3F()

(Joglekar et al., 2018): Inter-areal Balanced Amplification Figure 5
Implementation of Figure 5 of:
Joglekar, Madhura R., et al. “Inter-areal balanced amplification enhances signal propagation in a large-scale circuit model of the primate cortex.” Neuron 98.1 (2018): 222-234.
[1]:
import brainpy as bp
import brainpy.math as bm
from brainpy.dyn import neurons
import matplotlib.pyplot as plt
import numpy as np
from jax import vmap
from scipy.io import loadmat
[2]:
# This model should be run on a GPU device
bm.set_platform('gpu')
[3]:
class MultiAreaNet(bp.Network):
def __init__(
self, hier, conn, delay_mat, muIE=0.0475, muEE=.0375, wII=.075,
wEE=.01, wIE=.075, wEI=.0375, extE=15.4, extI=14.0, alpha=4., seed=None,
):
super(MultiAreaNet, self).__init__()
# data
self.hier = hier
self.conn = conn
self.delay_mat = delay_mat
# parameters
self.muIE = muIE
self.muEE = muEE
self.wII = wII
self.wEE = wEE
self.wIE = wIE
self.wEI = wEI
self.extE = extE
self.extI = extI
self.alpha = alpha
num_area = hier.size
self.num_area = num_area
rng = bm.random.RandomState(seed)
# neuron models
self.E = neurons.LIF((num_area, 1600), V_th=-50., V_reset=-60.,
V_rest=-70., tau=20., tau_ref=2.,
noise=3. / bm.sqrt(20.),
V_initializer=bp.init.Uniform(-70., -50.),
method='euler',
keep_size=True)
self.I = neurons.LIF((num_area, 400), V_th=-50., V_reset=-60.,
V_rest=-70., tau=10., tau_ref=2., noise=3. / bm.sqrt(10.),
V_initializer=bp.init.Uniform(-70., -50.),
method='euler',
keep_size=True)
# delays
self.intra_delay_step = int(2. / bm.get_dt())
self.E_delay_steps = bm.asarray(delay_mat.T / bm.get_dt(), dtype=int)
bm.fill_diagonal(self.E_delay_steps, self.intra_delay_step)
self.Edelay = bm.LengthDelay(self.E.spike, delay_len=int(self.E_delay_steps.max()))
self.Idelay = bm.LengthDelay(self.I.spike, delay_len=self.intra_delay_step)
# synapse model
syn_fun = lambda pre_spike, weight, conn_mat: weight * (pre_spike @ conn_mat)
self.f_E_current = vmap(syn_fun)
self.f_I_current = vmap(syn_fun, in_axes=(0, None, 0))
# synapses from I
self.intra_I2E_conn = rng.random((num_area, 400, 1600)) < 0.1
self.intra_I2I_conn = rng.random((num_area, 400, 400)) < 0.1
self.intra_I2E_weight = -wEI
self.intra_I2I_weight = -wII
# synapses from E
self.E2E_conns = [rng.random((num_area, 1600, 1600)) < 0.1 for _ in range(num_area)]
self.E2I_conns = [rng.random((num_area, 1600, 400)) < 0.1 for _ in range(num_area)]
self.E2E_weights = (1 + alpha * hier) * muEE * conn.T # inter-area connections
bm.fill_diagonal(self.E2E_weights, (1 + alpha * hier) * wEE) # intra-area connections
self.E2I_weights = (1 + alpha * hier) * muIE * conn.T # inter-area connections
bm.fill_diagonal(self.E2I_weights, (1 + alpha * hier) * wIE) # intra-area connections
def update(self, tdi, v1_input):
self.E.input[0] += v1_input
self.E.input += self.extE
self.I.input += self.extI
E_not_ref = bm.logical_not(self.E.refractory)
I_not_ref = bm.logical_not(self.I.refractory)
# synapses from E
for i in range(self.num_area):
delayed_E_spikes = self.Edelay(self.E_delay_steps[i], i).astype(float)
current = self.f_E_current(delayed_E_spikes, self.E2E_weights[i], self.E2E_conns[i])
self.E.V += current * E_not_ref # E2E
current = self.f_E_current(delayed_E_spikes, self.E2I_weights[i], self.E2I_conns[i])
self.I.V += current * I_not_ref # E2I
# synapses from I
delayed_I_spikes = self.Idelay(self.intra_delay_step).astype(float)
current = self.f_I_current(delayed_I_spikes, self.intra_I2E_weight, self.intra_I2E_conn)
self.E.V += current * E_not_ref # I2E
current = self.f_I_current(delayed_I_spikes, self.intra_I2I_weight, self.intra_I2I_conn)
self.I.V += current * I_not_ref # I2I
# updates
self.Edelay.update(self.E.spike)
self.Idelay.update(self.I.spike)
self.E.update(tdi)
self.I.update(tdi)
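You can check the hierarchy scaling in `MultiAreaNet.__init__` on a toy example. The NumPy sketch below rebuilds `E2E_weights` using made-up hierarchy and FLN values, not the paper's data.

- Broadcasting `(1 + alpha * hier)` against `conn.T` scales column `j` by the hierarchy of area `j`.
- Entry `[i, j]` is therefore the weight from source area `i` onto target area `j`.
- The diagonal is then overwritten with the intra-area weight.
- The function name `hierarchy_scaled_weights` is hypothetical.

```python
import numpy as np


def hierarchy_scaled_weights(hier, conn, mu, w_local, alpha=4.0):
    """Mirror of the E2E/E2I weight construction in MultiAreaNet (hypothetical helper).

    hier: (n,) normalized hierarchy values, one per area.
    conn: (n, n) FLN matrix, conn[i, j] = strength from area j to area i.
    Returns W with W[src, tgt] = weight from source area src onto target area tgt.
    """
    scale = 1.0 + alpha * np.asarray(hier)      # (n,), broadcast over columns
    W = scale * mu * np.asarray(conn).T         # W[src, tgt] = scale[tgt] * mu * conn[tgt, src]
    np.fill_diagonal(W, scale * w_local)        # intra-area weights on the diagonal
    return W


# Toy example with 3 areas (made-up numbers)
hier = np.array([0.0, 0.5, 1.0])
conn = np.array([[0.0, 0.2, 0.1],
                 [0.3, 0.0, 0.4],
                 [0.5, 0.6, 0.0]])
W = hierarchy_scaled_weights(hier, conn, mu=0.0375, w_local=0.01)
print(W.round(5))
```

Areas higher in the hierarchy receive stronger input, both from other areas and from themselves.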
[4]:
def raster_plot(xValues, yValues, duration):
ticks = np.round(np.arange(0, 29) + 0.5, 2)
areas = ['V1', 'V2', 'V4', 'DP', 'MT', '8m', '5', '8l', 'TEO', '2', 'F1',
'STPc', '7A', '46d', '10', '9/46v', '9/46d', 'F5', 'TEpd', 'PBr',
'7m', '7B', 'F2', 'STPi', 'PROm', 'F7', '8B', 'STPr', '24c']
N = len(ticks)
plt.figure(figsize=(8, 6))
plt.plot(xValues, yValues / (4 * 400), '.', markersize=1)
plt.plot([0, duration], np.arange(N + 1).repeat(2).reshape(-1, 2).T, 'k-')
plt.ylabel('Area')
plt.yticks(np.arange(N))
plt.xlabel('Time [ms]')
plt.ylim(0, N)
plt.yticks(ticks, areas)
plt.xlim(0, duration)
plt.tight_layout()
plt.show()
[5]:
# hierarchy values
hierVals = loadmat('Joglekar_2018_data/hierValspython.mat')
hierValsnew = hierVals['hierVals'].flatten()
hier = bm.asarray(hierValsnew / max(hierValsnew)) # hierarchy normalized.
# fraction of labeled neurons
flnMatp = loadmat('Joglekar_2018_data/efelenMatpython.mat')
conn = bm.asarray(flnMatp['flnMatpython'].squeeze()) # FLN values; conn[i, j] is the strength from area j to area i
# Distance
speed = 3.5 # axonal conduction velocity
distMatp = loadmat('Joglekar_2018_data/subgraphWiring29.mat')
distMat = distMatp['wiring'].squeeze() # inter-areal wiring distances
delayMat = bm.asarray(distMat / speed)
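The cell above turns wiring distances into conduction delays (`delayMat = distMat / speed`). `MultiAreaNet` later does two more things with them:

- It transposes the delays and truncates them to integer step counts.
- It overrides the diagonal with a 2 ms intra-area delay.

The NumPy sketch below reproduces that conversion. `dt` is an explicit parameter here (BrainPy's default is 0.1 ms). The distances are toy values, and the helper name `delay_steps` is hypothetical.

```python
import numpy as np


def delay_steps(dist, speed=3.5, dt=0.1, intra_delay=2.0):
    """Convert an inter-areal distance matrix into integer delay steps (hypothetical helper).

    Follows MultiAreaNet: delay = distance / speed, transposed, truncated
    to whole steps, with a fixed intra-area delay on the diagonal.
    """
    delay = np.asarray(dist, dtype=float) / speed     # conduction delays
    steps = (delay.T / dt).astype(int)                # truncation, like dtype=int
    np.fill_diagonal(steps, int(intra_delay / dt))    # local (intra-area) delay
    return steps


dist = np.array([[0.0, 7.0], [14.0, 0.0]])  # toy distances, not the paper's data
steps = delay_steps(dist, dt=0.25)           # dt chosen so the division is exact
print(steps)
```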
[6]:
pars = dict(extE=14.2, extI=14.7, wII=.075, wEE=.01, wIE=.075, wEI=.0375, muEE=.0375, muIE=0.0475)
inps = dict(value=15, duration=150)
[7]:
inputs, length = bp.inputs.section_input(values=[0, inps['value'], 0.],
durations=[300., inps['duration'], 500],
return_length=True)
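`bp.inputs.section_input` concatenates constant segments: 0 for 300 ms, the stimulus for 150 ms, then 0 for 500 ms. That makes `length` equal to 950 ms.

The NumPy sketch below builds an equivalent current. The helper `piecewise_constant` is a hypothetical stand-in, not BrainPy's implementation. It assumes each segment spans a whole number of steps, and the rounding is our choice.

```python
import numpy as np


def piecewise_constant(values, durations, dt=0.1):
    """Hypothetical stand-in for section_input: one constant value per segment."""
    segments = [np.full(int(round(d / dt)), v, dtype=float)
                for v, d in zip(values, durations)]
    return np.concatenate(segments), float(sum(durations))


current, length = piecewise_constant([0., 15., 0.], [300., 150., 500.])
print(current.shape, length)
```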
[8]:
net = MultiAreaNet(hier, conn, delayMat, **pars)
runner = bp.DSRunner(net, fun_monitors={'E.spike': lambda tdi: net.E.spike.flatten()})
runner.run(length, inputs=inputs)
[9]:
times, indices = np.where(runner.mon['E.spike'])
times = runner.mon.ts[times]
raster_plot(times, indices, length)

Simulating 1-million-neuron networks with 1GB GPU memory
[1]:
import brainpy as bp
import brainpy.math as bm
import jax
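from packaging.version import Version
assert Version(bp.__version__) >= Version('2.3.8') # compare versions numerically, not as strings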
[2]:
bm.set(dt=0.4)
[3]:
_default_g_max = dict(type='homo', value=1., prob=0.1, seed=123)
_default_uniform = dict(type='uniform', w_low=0.1, w_high=1., prob=0.1, seed=123)
_default_normal = dict(type='normal', w_mu=0.1, w_sigma=1., prob=0.1, seed=123)
[4]:
class Exponential(bp.TwoEndConnNS):
def __init__(
self,
pre: bp.NeuGroup,
post: bp.NeuGroup,
output: bp.SynOut = bp.synouts.CUBA(),
g_max_par=_default_g_max,
delay_step=None,
tau=8.0,
method: str = 'exp_auto',
name: str = None,
mode: bm.Mode = None,
):
super().__init__(pre, post, None, output=output, name=name, mode=mode)
self.tau = tau
self.g_max_par = g_max_par
self.g = bp.init.variable_(bm.zeros, self.post.num, self.mode)
self.delay_step = self.register_delay(f"{self.pre.name}.spike", delay_step, self.pre.spike)
self.integral = bp.odeint(lambda g, t: -g / self.tau, method=method)
def reset_state(self, batch_size=None):
self.g.value = bp.init.variable_(bm.zeros, self.post.num, batch_size)
def update(self):
t = bp.share.load('t')
dt = bp.share.load('dt')
pre_spike = self.get_delay_data(f"{self.pre.name}.spike", self.delay_step)
if self.g_max_par['type'] == 'homo':
f = lambda s: bm.event_matvec_prob_conn_homo_weight(s,
self.g_max_par['value'],
conn_prob=self.g_max_par['prob'],
shape=(self.pre.num, self.post.num),
seed=self.g_max_par['seed'],
transpose=True)
elif self.g_max_par['type'] == 'uniform':
f = lambda s: bm.event_matvec_prob_conn_uniform_weight(s,
w_low=self.g_max_par['w_low'],
w_high=self.g_max_par['w_high'],
conn_prob=self.g_max_par['prob'],
shape=(self.pre.num, self.post.num),
seed=self.g_max_par['seed'],
transpose=True)
elif self.g_max_par['type'] == 'normal':
f = lambda s: bm.event_matvec_prob_conn_normal_weight(s,
w_mu=self.g_max_par['w_mu'],
w_sigma=self.g_max_par['w_sigma'],
conn_prob=self.g_max_par['prob'],
shape=(self.pre.num, self.post.num),
seed=self.g_max_par['seed'],
transpose=True)
else:
raise ValueError(f"Unknown g_max_par type: {self.g_max_par['type']}")
if isinstance(self.mode, bm.BatchingMode):
f = jax.vmap(f)
post_vs = f(pre_spike)
self.g.value = self.integral(self.g.value, t, dt) + post_vs
return self.output(self.g)
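`Exponential` never stores a connection matrix. The `bm.event_matvec_prob_conn_*` operators regenerate the random connectivity from `seed` and `conn_prob` each time they are called, and only rows of presynaptic neurons that spiked contribute.

The NumPy sketch below illustrates that idea for the homogeneous-weight case. It is not BrainPy's kernel, whose random-number scheme differs, and the per-row seeding is our assumption. It also shows the conductance update, where `exp(-dt / tau)` is the exact solution of `dg/dt = -g / tau` over one step. All function names are hypothetical.

```python
import numpy as np


def row_targets(pre, num_post, prob, seed):
    """Regenerate the postsynaptic targets of one presynaptic neuron.

    Seeding on (seed, pre) makes each row reproducible without storing it.
    """
    rng = np.random.default_rng([seed, int(pre)])
    return np.nonzero(rng.random(num_post) < prob)[0]


def event_matvec_homo(spikes, weight, prob, num_post, seed):
    """Postsynaptic input sum_i spikes[i] * W[i, :], with W generated on the fly."""
    out = np.zeros(num_post)
    for pre in np.nonzero(spikes)[0]:           # event-driven: only spiking rows
        out[row_targets(pre, num_post, prob, seed)] += weight
    return out


def dense_matrix(num_pre, num_post, weight, prob, seed):
    """Reference: the same connectivity materialized as a dense matrix."""
    W = np.zeros((num_pre, num_post))
    for pre in range(num_pre):
        W[pre, row_targets(pre, num_post, prob, seed)] = weight
    return W


def exp_synapse_step(g, spikes, dt, tau, **conn):
    """One step of an exponential synapse: exact decay plus new input."""
    return g * np.exp(-dt / tau) + event_matvec_homo(spikes, **conn)


spikes = np.random.default_rng(0).random(200) < 0.05
conn = dict(weight=0.6, prob=0.02, num_post=300, seed=123)
jit_input = event_matvec_homo(spikes, **conn)
dense_input = spikes.astype(float) @ dense_matrix(200, 300, 0.6, 0.02, 123)
print(np.allclose(jit_input, dense_input))
g = exp_synapse_step(np.ones(300), spikes, dt=0.4, tau=5.0, **conn)
```

Regenerating rows trades extra computation for memory: the cost scales with the number of spikes rather than with the size of the connection matrix.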
[5]:
class EINet(bp.Network):
def __init__(self, scale=1.0, method='exp_auto'):
super().__init__()
num_exc = int(3200 * scale)
num_inh = int(800 * scale)
pars = dict(V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5., V_initializer=bp.init.Normal(-55., 2.))
self.N = bp.neurons.LIF(num_exc + num_inh, **pars, method=method)
self.E = Exponential(self.N[:num_exc], self.N,
g_max_par=dict(type='homo', value=0.6 / scale, prob=0.02, seed=123),
tau=5., method=method, output=bp.synouts.COBA(E=0.))
self.I = Exponential(self.N[num_exc:], self.N,
g_max_par=dict(type='homo', value=6.7 / scale, prob=0.02, seed=12345),
tau=10., method=method, output=bp.synouts.COBA(E=-80.))
[6]:
duration = 1e2
net = EINet(scale=250)
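With `scale=250`, the network has (3200 + 800) × 250 = 1,000,000 LIF neurons. The back-of-the-envelope count below shows why the connectivity is generated on the fly rather than stored. It assumes one 4-byte integer index per stored synapse, a hypothetical storage format chosen only for illustration.

```python
scale = 250
num_neurons = int(3200 * scale) + int(800 * scale)
prob = 0.02                                        # both projections use prob=0.02
# E rows and I rows together cover all N presynaptic neurons, each projecting onto all N
num_synapses = num_neurons * num_neurons * prob
bytes_per_index = 4                                # assumed int32 index per synapse
gb = num_synapses * bytes_per_index / 1e9
print(f"{num_neurons:,} neurons, {num_synapses:.1e} synapses, ~{gb:.0f} GB of indices")
```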
[7]:
runner = bp.DSRunner(
net,
monitors=['N.spike'],
inputs=('N.input', 20.),
memory_efficient=True
)
runner.run(duration)
[8]:
bp.visualize.raster_plot(runner.mon.ts, runner.mon['N.spike'], show=True)

Integrator RNN Model
In this notebook, we train a vanilla RNN to integrate white noise. On its own, this example is useful for understanding how RNN training works.
[1]:
from functools import partial
import matplotlib.pyplot as plt
import brainpy as bp
import brainpy.math as bm
bm.set_environment(bm.training_mode)
Parameters
[2]:
dt = 0.04
num_step = int(1.0 / dt)
num_batch = 128
Data
[3]:
@partial(bm.jit,
dyn_vars=bp.TensorCollector({'a': bm.random.DEFAULT}),
static_argnames=['batch_size'])
def build_inputs_and_targets(mean=0.025, scale=0.01, batch_size=10):
# Create the white noise input
sample = bm.random.normal(size=(batch_size, 1, 1))
bias = mean * 2.0 * (sample - 0.5)
samples = bm.random.normal(size=(batch_size, num_step, 1))
noise_t = scale / dt ** 0.5 * samples
inputs = bias + noise_t
targets = bm.cumsum(inputs, axis=1)
return inputs, targets
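The target is simply the running sum of the input. The network must therefore integrate a per-trial constant bias plus white noise whose per-step standard deviation is `scale / sqrt(dt)`.

Below is a NumPy version of the same generator, handy for inspecting the data outside JIT. It mirrors the BrainPy cell, including drawing `sample` from a normal distribution. The function name is hypothetical.

```python
import numpy as np


def build_inputs_and_targets_np(rng, num_step, dt, mean=0.025, scale=0.01, batch_size=10):
    """NumPy mirror of build_inputs_and_targets (same formulas, no JIT)."""
    sample = rng.normal(size=(batch_size, 1, 1))
    bias = mean * 2.0 * (sample - 0.5)                   # one constant bias per trial
    noise_t = scale / dt ** 0.5 * rng.normal(size=(batch_size, num_step, 1))
    inputs = bias + noise_t
    targets = np.cumsum(inputs, axis=1)                  # running sum of the input
    return inputs, targets


dt = 0.04
inputs, targets = build_inputs_and_targets_np(np.random.default_rng(0), int(1.0 / dt), dt)
print(inputs.shape, targets.shape)
```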
[4]:
def train_data():
for _ in range(100):
yield build_inputs_and_targets(batch_size=num_batch)
Model
[5]:
class RNN(bp.DynamicalSystem):
def __init__(self, num_in, num_hidden):
super(RNN, self).__init__()
self.rnn = bp.layers.RNNCell(num_in, num_hidden, train_state=True)
self.out = bp.layers.Dense(num_hidden, 1)
def update(self, sha, x):
return self.out(sha, self.rnn(sha, x))
model = RNN(1, 100)
Training
[6]:
# define loss function
def loss(predictions, targets, l2_reg=2e-4):
mse = bp.losses.mean_squared_error(predictions, targets)
l2 = l2_reg * bp.losses.l2_norm(model.train_vars().unique().dict()) ** 2
return mse + l2
# define optimizer
lr = bp.optim.ExponentialDecay(lr=0.025, decay_steps=1, decay_rate=0.99975)
opt = bp.optim.Adam(lr=lr, eps=1e-1)
# create a trainer
trainer = bp.BPTT(model, loss_fun=loss, optimizer=opt)
trainer.fit(train_data,
num_epoch=30,
num_report=200)
Train 200 steps, use 2.0416 s, loss 0.43417808413505554
Train 400 steps, use 1.1465 s, loss 0.027894824743270874
Train 600 steps, use 1.1387 s, loss 0.02194945327937603
Train 800 steps, use 4.5845 s, loss 0.021538913249969482
Train 1000 steps, use 5.0299 s, loss 0.02128899097442627
Train 1200 steps, use 4.9541 s, loss 0.02115160971879959
Train 1400 steps, use 5.0622 s, loss 0.021017059683799744
Train 1600 steps, use 5.0935 s, loss 0.020916711539030075
Train 1800 steps, use 4.9851 s, loss 0.020782889798283577
Train 2000 steps, use 4.8506 s, loss 0.020689304918050766
Train 2200 steps, use 5.1480 s, loss 0.020607156679034233
Train 2400 steps, use 5.0867 s, loss 0.020528702065348625
Train 2600 steps, use 4.9022 s, loss 0.02044598013162613
Train 2800 steps, use 4.8018 s, loss 0.020371917635202408
Train 3000 steps, use 4.9188 s, loss 0.020304910838603973
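You can work out the learning rate at the end of the 3000 training steps above directly. This assumes two things:

- `ExponentialDecay` follows the usual schedule `lr * decay_rate ** (step / decay_steps)`.
- The schedule advances once per batch.

```python
lr0, decay_rate, decay_steps = 0.025, 0.99975, 1
num_steps = 30 * 100            # num_epoch * batches yielded by train_data()
lr_final = lr0 * decay_rate ** (num_steps / decay_steps)
print(f"{lr_final:.4f}")
```

Under these assumptions, the learning rate drops to a little under half of its initial value by the end of training.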
[7]:
plt.plot(bm.as_numpy(trainer.get_hist_metric()))
plt.show()

Testing
[8]:
model.reset_state(1)
x, y = build_inputs_and_targets(batch_size=1)
predicts = trainer.predict(x)
plt.figure(figsize=(8, 2))
plt.plot(bm.as_numpy(y[0]).flatten(), label='Ground Truth')
plt.plot(bm.as_numpy(predicts[0]).flatten(), label='Prediction')
plt.legend()
plt.show()

Train RNN to Solve Parametric Working Memory
[1]:
import brainpy as bp
import brainpy.math as bm
bp.math.set_platform('cpu')
[2]:
import numpy as np
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA
[3]:
# We will import the task from the neurogym library.
# Please install neurogym:
#
# https://github.com/neurogym/neurogym
import neurogym as ngym
[4]:
# Environment
task = 'DelayComparison-v0'
timing = {'delay': ('choice', [200, 400, 800, 1600, 3200]),
'response': ('constant', 500)}
kwargs = {'dt': 100, 'timing': timing}
seq_len = 100
# Make supervised dataset
dataset = ngym.Dataset(task,
env_kwargs=kwargs,
batch_size=16,
seq_len=seq_len)
# A sample environment from dataset
env = dataset.env
# Visualize the environment with 2 sample trials
_ = ngym.utils.plot_env(env, num_trials=2, def_act=0, fig_kwargs={'figsize': (8, 6)})
plt.show()

[5]:
input_size = env.observation_space.shape[0]
output_size = env.action_space.n
batch_size = dataset.batch_size
[6]:
class RNN(bp.DynamicalSystem):
def __init__(self, num_input, num_hidden, num_output, num_batch,
w_ir=bp.init.KaimingNormal(scale=1.),
w_rr=bp.init.KaimingNormal(scale=1.),
w_ro=bp.init.KaimingNormal(scale=1.),
dt=None, seed=None):
super(RNN, self).__init__()
# parameters
self.tau = 100
self.num_batch = num_batch
self.num_input = num_input
self.num_hidden = num_hidden
self.num_output = num_output
if dt is None:
self.alpha = 1
else:
self.alpha = dt / self.tau
self.rng = bm.random.RandomState(seed=seed)
# input weight
self.w_ir = bm.TrainVar(bp.init.parameter(w_ir, size=(num_input, num_hidden)))
# recurrent weight
bound = 1 / num_hidden ** 0.5
self.w_rr = bm.TrainVar(bp.init.parameter(w_rr, size=(num_hidden, num_hidden)))
self.b_rr = bm.TrainVar(self.rng.uniform(-bound, bound, num_hidden))
# readout weight
self.w_ro = bm.TrainVar(bp.init.parameter(w_ro, size=(num_hidden, num_output)))
self.b_ro = bm.TrainVar(self.rng.uniform(-bound, bound, num_output))
# variables
self.h = bm.Variable(bm.zeros((num_batch, num_hidden)))
self.o = bm.Variable(bm.zeros((num_batch, num_output)))
def cell(self, x, h):
ins = x @ self.w_ir + h @ self.w_rr + self.b_rr
state = h * (1 - self.alpha) + ins * self.alpha
return bm.relu(state)
def readout(self, h):
return h @ self.w_ro + self.b_ro
def update(self, x):
self.h.value = self.cell(x, self.h.value)
self.o.value = self.readout(self.h.value)
return self.h.value, self.o.value
def predict(self, xs):
self.h[:] = 0.
return bm.for_loop(self.update, xs)
def loss(self, xs, ys):
hs, os = self.predict(xs)
os = os.reshape((-1, os.shape[-1]))
loss = bp.losses.cross_entropy_loss(os, ys.flatten())
return loss, os
[7]:
# Instantiate the network and print information
hidden_size = 64
net = RNN(num_input=input_size,
num_hidden=hidden_size,
num_output=output_size,
num_batch=batch_size,
dt=env.dt)
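Here `dt=env.dt` is 100 ms and `tau` is 100 ms, so `alpha = dt / tau = 1`. The leak term `h * (1 - alpha)` in `cell` then vanishes, and each step reduces to `relu(x @ w_ir + h @ w_rr + b_rr)`.

The NumPy check below confirms that reduction with random weights. The sizes are toy values.

```python
import numpy as np


def leaky_relu_rnn_cell(x, h, w_ir, w_rr, b_rr, alpha):
    """Same update as RNN.cell in the cell above."""
    ins = x @ w_ir + h @ w_rr + b_rr
    return np.maximum(h * (1 - alpha) + ins * alpha, 0.0)


rng = np.random.default_rng(0)
x, h = rng.normal(size=(16, 3)), rng.normal(size=(16, 64))   # toy sizes
w_ir, w_rr, b_rr = rng.normal(size=(3, 64)), rng.normal(size=(64, 64)), rng.normal(size=64)
alpha = 100 / 100   # env.dt / tau
h_new = leaky_relu_rnn_cell(x, h, w_ir, w_rr, b_rr, alpha)
print(np.allclose(h_new, np.maximum(x @ w_ir + h @ w_rr + b_rr, 0.0)))
```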
[8]:
predict = bm.jit(net.predict)
[9]:
# Adam optimizer
opt = bp.optim.Adam(lr=0.001, train_vars=net.train_vars().unique())
[10]:
# gradient function
grad = bm.grad(net.loss,
child_objs=net,
grad_vars=net.train_vars().unique(),
return_value=True,
has_aux=True)
[11]:
@bm.jit
@bm.to_object(child_objs=(grad, opt))
def train(xs, ys):
grads, loss, os = grad(xs, ys)
opt.update(grads)
return loss, os
[12]:
running_acc = 0
running_loss = 0
for i in range(2000):
inputs, labels_np = dataset()
inputs = bm.asarray(inputs)
labels = bm.asarray(labels_np)
loss, outputs = train(inputs, labels)
running_loss += loss
# Compute performance
output_np = np.argmax(bm.as_numpy(outputs), axis=-1).flatten()
labels_np = labels_np.flatten()
ind = labels_np > 0 # Only analyze time points when target is not fixation
running_acc += np.mean(labels_np[ind] == output_np[ind])
if i % 100 == 99:
running_loss /= 100
running_acc /= 100
print('Step {}, Loss {:0.4f}, Acc {:0.3f}'.format(i + 1, running_loss, running_acc))
running_loss = 0
running_acc = 0
Step 100, Loss 0.1653, Acc 0.149
Step 200, Loss 0.0321, Acc 0.674
Step 300, Loss 0.0238, Acc 0.764
Step 400, Loss 0.0187, Acc 0.825
Step 500, Loss 0.0150, Acc 0.857
Step 600, Loss 0.0124, Acc 0.884
Step 700, Loss 0.0099, Acc 0.909
Step 800, Loss 0.0110, Acc 0.894
Step 900, Loss 0.0092, Acc 0.912
Step 1000, Loss 0.0089, Acc 0.913
Step 1100, Loss 0.0080, Acc 0.926
Step 1200, Loss 0.0079, Acc 0.921
Step 1300, Loss 0.0079, Acc 0.923
Step 1400, Loss 0.0075, Acc 0.924
Step 1500, Loss 0.0080, Acc 0.921
Step 1600, Loss 0.0071, Acc 0.930
Step 1700, Loss 0.0071, Acc 0.930
Step 1800, Loss 0.0074, Acc 0.920
Step 1900, Loss 0.0069, Acc 0.928
Step 2000, Loss 0.0071, Acc 0.926
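Accuracy in the loop above counts only time steps whose label is not fixation (`labels > 0`). Here is a small worked example with made-up labels:

```python
import numpy as np

labels = np.array([0, 0, 0, 1, 1, 2, 2])     # 0 = fixation (hypothetical trial)
outputs = np.array([0, 1, 0, 1, 2, 2, 2])    # argmax of the network output
ind = labels > 0                             # ignore fixation steps
acc = np.mean(labels[ind] == outputs[ind])
print(acc)
```

The mistake at the second (fixation) step is ignored. Three of the four remaining steps are correct.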
[13]:
def run(num_trial=1):
env.reset(no_step=True)
perf = 0
activity_dict = {}
trial_infos = {}
for i in range(num_trial):
env.new_trial()
ob, gt = env.ob, env.gt
inputs = bm.asarray(ob[:, np.newaxis, :])
rnn_activity, action_pred = predict(inputs)
rnn_activity = bm.as_numpy(rnn_activity)[:, 0, :]
activity_dict[i] = rnn_activity
trial_infos[i] = env.trial
# Concatenate activity for PCA
activity = np.concatenate(list(activity_dict[i] for i in range(num_trial)), axis=0)
print('Shape of the neural activity: (Time points, Neurons): ', activity.shape)
# Print trial informations
for i in range(5):
if i >= num_trial: break
print('Trial ', i, trial_infos[i])
pca = PCA(n_components=2)
pca.fit(activity)
# print('Shape of the projected activity: (Time points, PCs): ', activity_pc.shape)
fig, (ax1, ax2) = plt.subplots(1, 2, sharey=True, sharex=True, figsize=(6, 3))
for i in range(num_trial):
activity_pc = pca.transform(activity_dict[i])
trial = trial_infos[i]
color = 'red' if trial['ground_truth'] == 1 else 'blue' # DelayComparison labels are 1 or 2
_ = ax1.plot(activity_pc[:, 0], activity_pc[:, 1], 'o-', color=color)
if i < 3:
_ = ax2.plot(activity_pc[:, 0], activity_pc[:, 1], 'o-', color=color)
ax1.set_xlabel('PC 1')
ax1.set_ylabel('PC 2')
plt.show()
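`PCA(n_components=2)` centers the pooled activity and projects it onto its two leading principal axes. The NumPy sketch below does the same computation via SVD. Random data stands in for `activity`, the helper names are hypothetical, and axis signs may differ from scikit-learn's.

```python
import numpy as np


def pca_2d(activity):
    """Fit on pooled activity (time points x neurons); return mean, top-2 axes, variances."""
    mean = activity.mean(axis=0)
    _, s, vt = np.linalg.svd(activity - mean, full_matrices=False)
    return mean, vt[:2], s[:2] ** 2 / (len(activity) - 1)


def project(trial_activity, mean, axes):
    """Equivalent of pca.transform for one trial."""
    return (trial_activity - mean) @ axes.T


rng = np.random.default_rng(0)
activity = rng.normal(size=(562, 64))       # same shape as the 20-trial run
mean, axes, var = pca_2d(activity)
pc = project(activity[:30], mean, axes)
print(pc.shape, var[0] >= var[1])
```

Fitting once on the pooled activity and then projecting each trial separately, as `run` does, puts all trajectories in a shared coordinate system.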
[14]:
run(num_trial=1)
Shape of the neural activity: (Time points, Neurons): (32, 64)
Trial 0 {'ground_truth': 1, 'vpair': (26, 18), 'v1': 26, 'v2': 18}

[15]:
run(num_trial=20)
Shape of the neural activity: (Time points, Neurons): (562, 64)
Trial 0 {'ground_truth': 2, 'vpair': (26, 18), 'v1': 18, 'v2': 26}
Trial 1 {'ground_truth': 1, 'vpair': (26, 18), 'v1': 26, 'v2': 18}
Trial 2 {'ground_truth': 2, 'vpair': (26, 18), 'v1': 18, 'v2': 26}
Trial 3 {'ground_truth': 2, 'vpair': (30, 22), 'v1': 22, 'v2': 30}
Trial 4 {'ground_truth': 1, 'vpair': (18, 10), 'v1': 18, 'v2': 10}

[16]:
run(num_trial=100)
Shape of the neural activity: (Time points, Neurons): (2778, 64)
Trial 0 {'ground_truth': 2, 'vpair': (30, 22), 'v1': 22, 'v2': 30}
Trial 1 {'ground_truth': 1, 'vpair': (26, 18), 'v1': 26, 'v2': 18}
Trial 2 {'ground_truth': 1, 'vpair': (26, 18), 'v1': 26, 'v2': 18}
Trial 3 {'ground_truth': 2, 'vpair': (22, 14), 'v1': 14, 'v2': 22}
Trial 4 {'ground_truth': 1, 'vpair': (22, 14), 'v1': 22, 'v2': 14}

(Song, et al., 2016): Training excitatory-inhibitory recurrent network
Implementation of the paper:
Song, H. F., G. R. Yang, and X. J. Wang. “Training Excitatory-Inhibitory Recurrent Neural Networks for Cognitive Tasks: A Simple and Flexible Framework.” PLOS Computational Biology 12.2 (2016): e1004792.
The original code is based on PyTorch (https://github.com/gyyang/nn-brain/blob/master/EI_RNN.ipynb). Compared with the PyTorch code, training with BrainPy is nearly four times faster.
Here we will train a recurrent neural network with excitatory and inhibitory neurons on a simple perceptual decision-making task.
[1]:
import brainpy as bp
import brainpy.math as bm
bp.math.set_platform('cpu')
[2]:
import numpy as np
import matplotlib.pyplot as plt
Defining a perceptual decision making task
[3]:
# We will import the task from the neurogym library.
# Please install neurogym:
#
# https://github.com/neurogym/neurogym
import neurogym as ngym
[4]:
# Environment
task = 'PerceptualDecisionMaking-v0'
timing = {
'fixation': ('choice', (50, 100, 200, 400)),
'stimulus': ('choice', (100, 200, 400, 800)),
}
kwargs = {'dt': 20, 'timing': timing}
seq_len = 100
# Make supervised dataset
dataset = ngym.Dataset(task,
env_kwargs=kwargs,
batch_size=16,
seq_len=seq_len)
# A sample environment from dataset
env = dataset.env
# Visualize the environment with 2 sample trials
_ = ngym.utils.plot_env(env, num_trials=2, fig_kwargs={'figsize': (10, 6)})
plt.show()

[5]:
input_size = env.observation_space.shape[0]
output_size = env.action_space.n
batch_size = dataset.batch_size
print(f'Input size = {input_size}')
print(f'Output size = {output_size}')
print(f'Batch size = {batch_size}')
Input size = 3
Output size = 3
Batch size = 16
Define E-I recurrent network
Here we define an E-I recurrent network in which no self-connections are allowed.
[6]:
class RNN(bp.DynamicalSystem):
r"""E-I RNN.
The RNNs are described by the equations
.. math::
\begin{gathered}
\tau \dot{\mathbf{x}}=-\mathbf{x}+W^{\mathrm{rec}} \mathbf{r}+W^{\mathrm{in}}
\mathbf{u}+\sqrt{2 \tau \sigma_{\mathrm{rec}}^{2}} \xi \\
\mathbf{r}=[\mathbf{x}]_{+} \\
\mathbf{z}=W^{\text {out }} \mathbf{r}
\end{gathered}
In practice, the continuous-time dynamics are discretized to Euler form
in time steps of size :math:`\Delta t` as
.. math::
\begin{gathered}
\mathbf{x}_{t}=(1-\alpha) \mathbf{x}_{t-1}+\alpha\left(W^{\mathrm{rec}} \mathbf{r}_{t-1}+
W^{\mathrm{in}} \mathbf{u}_{t}\right)+\sqrt{2 \alpha \sigma_{\mathrm{rec}}^{2}} \mathbf{N}(0,1) \\
\mathbf{r}_{t}=\left[\mathbf{x}_{t}\right]_{+} \\
\mathbf{z}_{t}=W^{\mathrm{out}} \mathbf{r}_{t}
\end{gathered}
where :math:`\alpha = \Delta t/\tau` and :math:`N(0, 1)` are normally distributed
random numbers with zero mean and unit variance, sampled independently at every time step.
"""
def __init__(self, num_input, num_hidden, num_output, num_batch,
dt=None, e_ratio=0.8, sigma_rec=0., seed=None,
w_ir=bp.init.KaimingUniform(scale=1.),
w_rr=bp.init.KaimingUniform(scale=1.),
w_ro=bp.init.KaimingUniform(scale=1.)):
super(RNN, self).__init__()
# parameters
self.tau = 100
self.num_batch = num_batch
self.num_input = num_input
self.num_hidden = num_hidden
self.num_output = num_output
self.e_size = int(num_hidden * e_ratio)
self.i_size = num_hidden - self.e_size
if dt is None:
self.alpha = 1
else:
self.alpha = dt / self.tau
self.sigma_rec = (2 * self.alpha) ** 0.5 * sigma_rec # Recurrent noise
self.rng = bm.random.RandomState(seed=seed)
# hidden mask
mask = np.tile([1] * self.e_size + [-1] * self.i_size, (num_hidden, 1))
np.fill_diagonal(mask, 0)
self.mask = bm.asarray(mask, dtype=bm.float_)
# input weight
self.w_ir = bm.TrainVar(bp.init.parameter(w_ir, (num_input, num_hidden)))
# recurrent weight
bound = 1 / num_hidden ** 0.5
self.w_rr = bm.TrainVar(bp.init.parameter(w_rr, (num_hidden, num_hidden)))
self.w_rr[:, :self.e_size] /= (self.e_size / self.i_size)
self.b_rr = bm.TrainVar(self.rng.uniform(-bound, bound, num_hidden))
# readout weight
bound = 1 / self.e_size ** 0.5
self.w_ro = bm.TrainVar(bp.init.parameter(w_ro, (self.e_size, num_output)))
self.b_ro = bm.TrainVar(self.rng.uniform(-bound, bound, num_output))
# variables
self.h = bm.Variable(bm.zeros((num_batch, num_hidden)))
self.o = bm.Variable(bm.zeros((num_batch, num_output)))
def cell(self, x, h):
ins = x @ self.w_ir + h @ (bm.abs(self.w_rr) * self.mask).T + self.b_rr # w_rr is (post, pre); transpose so each neuron's outgoing weights share one sign (Dale's law)
state = h * (1 - self.alpha) + ins * self.alpha
state += self.sigma_rec * self.rng.randn(*state.shape) # independent noise for every unit and batch element
return bm.relu(state)
def readout(self, h):
return h @ self.w_ro + self.b_ro
def update(self, x):
self.h.value = self.cell(x, self.h)
self.o.value = self.readout(self.h[:, :self.e_size])
return self.h.value, self.o.value
def predict(self, xs):
self.h[:] = 0.
return bm.for_loop(self.update, xs)
def loss(self, xs, ys):
hs, os = self.predict(xs)
os = os.reshape((-1, os.shape[-1]))
return bp.losses.cross_entropy_loss(os, ys.flatten())
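The effective recurrent matrix in `cell` is `abs(w_rr) * mask`. The mask has `+1` in the excitatory columns, `-1` in the inhibitory columns, and zeros on the diagonal.

Consider a `(post, pre)` layout, which is the convention of PyTorch `nn.Linear` weights. There, columns index presynaptic neurons, so every neuron's outgoing weights share one sign (Dale's law), and the recurrent input is `r @ W_eff.T`.

The NumPy sketch below builds the mask that way and checks the sign constraint. The helper name `dale_mask` is hypothetical.

```python
import numpy as np


def dale_mask(num_hidden, e_ratio=0.8):
    """Sign mask as in the RNN above: +1 for E columns, -1 for I columns, 0 on the diagonal."""
    e_size = int(num_hidden * e_ratio)
    mask = np.tile([1.0] * e_size + [-1.0] * (num_hidden - e_size), (num_hidden, 1))
    np.fill_diagonal(mask, 0.0)
    return mask, e_size


num_hidden = 50
mask, e_size = dale_mask(num_hidden)
w_rr = np.random.default_rng(0).normal(size=(num_hidden, num_hidden))
w_eff = np.abs(w_rr) * mask       # layout (post, pre): column j = outgoing weights of neuron j
r = np.random.default_rng(1).random((16, num_hidden))   # non-negative rates
recurrent_input = r @ w_eff.T     # input to neuron i: sum_j w_eff[i, j] * r[:, j]
print((w_eff[:, :e_size] >= 0).all(), (w_eff[:, e_size:] <= 0).all())
```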
Train the network on the decision making task
[7]:
# Instantiate the network and print information
hidden_size = 50
net = RNN(num_input=input_size,
num_hidden=hidden_size,
num_output=output_size,
num_batch=batch_size,
dt=env.dt,
sigma_rec=0.15)
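With `dt=env.dt`, which is 20 ms, and `tau` of 100 ms, the Euler discretization in the docstring gives `alpha = 0.2`. The per-step noise standard deviation `sqrt(2 * alpha) * sigma_rec` for `sigma_rec=0.15` is then about 0.095:

```python
tau, dt, sigma_rec = 100, 20, 0.15
alpha = dt / tau                                  # 0.2
noise_std = (2 * alpha) ** 0.5 * sigma_rec        # value stored in net.sigma_rec
print(round(alpha, 3), round(noise_std, 4))
```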
[8]:
# Adam optimizer
opt = bp.optim.Adam(lr=0.001, train_vars=net.train_vars().unique())
[9]:
# gradient function
grad_f = bm.grad(net.loss,
child_objs=net,
grad_vars=net.train_vars().unique(),
return_value=True)
[10]:
@bm.jit(child_objs=(net, opt))
def train(xs, ys):
grads, loss = grad_f(xs, ys)
opt.update(grads)
return loss
Training is nearly 4 times faster than with the original PyTorch code.
[11]:
running_loss = 0
print_step = 200
for i in range(5000):
inputs, labels = dataset()
inputs = bm.asarray(inputs)
labels = bm.asarray(labels)
loss = train(inputs, labels)
running_loss += loss
if i % print_step == (print_step - 1):
running_loss /= print_step
print('Step {}, Loss {:0.4f}'.format(i + 1, running_loss))
running_loss = 0
Step 200, Loss 0.6556
Step 400, Loss 0.4587
Step 600, Loss 0.4140
Step 800, Loss 0.3671
Step 1000, Loss 0.3321
Step 1200, Loss 0.3048
Step 1400, Loss 0.2851
Step 1600, Loss 0.2638
Step 1800, Loss 0.2431
Step 2000, Loss 0.2230
Step 2200, Loss 0.2083
Step 2400, Loss 0.1932
Step 2600, Loss 0.1787
Step 2800, Loss 0.1673
Step 3000, Loss 0.1595
Step 3200, Loss 0.1457
Step 3400, Loss 0.1398
Step 3600, Loss 0.1335
Step 3800, Loss 0.1252
Step 4000, Loss 0.1204
Step 4200, Loss 0.1151
Step 4400, Loss 0.1099
Step 4600, Loss 0.1075
Step 4800, Loss 0.1027
Step 5000, Loss 0.0976
Run the network post-training and record neural activity
[12]:
predict = bm.jit(net.predict, dyn_vars=net.vars())
[13]:
env.reset(no_step=True)
env.timing.update({'fixation': ('constant', 500), 'stimulus': ('constant', 500)})
perf = 0
num_trial = 500
activity_dict = {}
trial_infos = {}
stim_activity = [[], []] # response for ground-truth 0 and 1
for i in range(num_trial):
env.new_trial()
ob, gt = env.ob, env.gt
inputs = bm.asarray(ob[:, np.newaxis, :])
rnn_activity, action_pred = predict(inputs)
# Compute performance
action_pred = bm.as_numpy(action_pred)
choice = np.argmax(action_pred[-1, 0, :])
correct = choice == gt[-1]
# Log trial info
trial_info = env.trial
trial_info.update({'correct': correct, 'choice': choice})
trial_infos[i] = trial_info
# Log stimulus period activity
rnn_activity = bm.as_numpy(rnn_activity)[:, 0, :]
activity_dict[i] = rnn_activity
# Compute stimulus selectivity for all units
# Compute each neuron's response in trials where ground_truth=0 and 1 respectively
rnn_activity = rnn_activity[env.start_ind['stimulus']: env.end_ind['stimulus']]
stim_activity[env.trial['ground_truth']].append(rnn_activity)
print('Average performance', np.mean([val['correct'] for val in trial_infos.values()]))
Average performance 0.81
Plot neural activity from sample trials
[14]:
trial = 2
plt.figure(figsize=(8, 6))
_ = plt.plot(activity_dict[trial][:, :net.e_size], color='blue', label='Excitatory')
_ = plt.plot(activity_dict[trial][:, net.e_size:], color='red', label='Inhibitory')
plt.xlabel('Time step')
plt.ylabel('Activity')
plt.show()

Compute stimulus selectivity for sorting neurons
Here, for each neuron, we compute its stimulus-period selectivity \(d'\).
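The next cell computes this quantity for neuron \(i\). Here \(\mu\) and \(\sigma\) are the mean and standard deviation of the neuron's stimulus-period activity in trials with ground truth 0 and 1. The small constant \(\epsilon = 10^{-7}\) only guards against division by zero.

\[
d'_i = \frac{\mu_i^{(0)} - \mu_i^{(1)}}{\sqrt{\left( \big(\sigma_i^{(0)}\big)^2 + \big(\sigma_i^{(1)}\big)^2 + \epsilon \right) / 2}}
\]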
[15]:
mean_activity = []
std_activity = []
for ground_truth in [0, 1]:
activity = np.concatenate(stim_activity[ground_truth], axis=0)
mean_activity.append(np.mean(activity, axis=0))
std_activity.append(np.std(activity, axis=0))
# Compute d'
selectivity = (mean_activity[0] - mean_activity[1])
selectivity /= np.sqrt((std_activity[0] ** 2 + std_activity[1] ** 2 + 1e-7) / 2)
# Sort index for selectivity, separately for E and I
ind_sort = np.concatenate((np.argsort(selectivity[:net.e_size]),
np.argsort(selectivity[net.e_size:]) + net.e_size))
Plot network connectivity sorted by stimulus selectivity
[16]:
# Plot distribution of stimulus selectivity
plt.figure(figsize=(6, 4))
plt.hist(selectivity)
plt.xlabel('Selectivity')
plt.ylabel('Number of neurons')
plt.show()

[17]:
W = (bm.abs(net.w_rr) * net.mask).numpy()
# Sort by selectivity
W = W[:, ind_sort][ind_sort, :]
wlim = np.max(np.abs(W))
plt.figure(figsize=(10, 10))
plt.imshow(W, cmap='bwr_r', vmin=-wlim, vmax=wlim)
plt.colorbar()
plt.xlabel('From neurons')
plt.ylabel('To neurons')
plt.title('Network connectivity')
plt.tight_layout()
plt.show()

(Masse, et al., 2019): RNN with STP for Working Memory
Re-implementation of the paper with BrainPy:
Masse, Nicolas Y., Guangyu R. Yang, H. Francis Song, Xiao-Jing Wang, and David J. Freedman. “Circuit mechanisms for the maintenance and manipulation of information in working memory.” Nature Neuroscience 22, no. 7 (2019): 1159-1167.
Thanks to the authors' original GitHub code: https://github.com/nmasse/Short-term-plasticity-RNN
For the implementation of the Task class, please refer to Masse_2019_STP_RNN_tasks.py.
For the analysis methods, please refer to the original repository: https://github.com/nmasse/Short-term-plasticity-RNN/blob/master/analysis.py
[11]:
import brainpy as bp
import brainpy.math as bm
bp.math.set_platform('cpu')
[12]:
import os
import math
import pickle
import numpy as np
from Masse_2019_STP_RNN_tasks import Task
[13]:
# Time parameters
dt = 100 # ms
dt_sec = dt / 1000
time_constant = 100 # ms
alpha = dt / time_constant
[14]:
# Loss parameters
spike_regularization = 'L2' # 'L1' or 'L2'
spike_cost = 2e-2
weight_cost = 0.
clip_max_grad_val = 0.1
[15]:
# Training specs
batch_size = 1024
learning_rate = 2e-2
[16]:
def initialize(shape, prob, size):
w = bm.random.gamma(shape, size=size)
w *= (bm.random.random(size) < prob)
return bm.asarray(w, dtype=bm.float32)
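`initialize` draws Gamma-distributed magnitudes and keeps each entry with probability `prob`. The result is sparse, non-negative initial weights.

The NumPy sketch below follows the same recipe and checks that the fraction of non-zero entries is close to `prob`. The function name is hypothetical.

```python
import numpy as np


def initialize_np(shape, prob, size, rng):
    """NumPy mirror of initialize(): Gamma(shape) magnitudes times a Bernoulli(prob) mask."""
    w = rng.gamma(shape, size=size)
    w *= rng.random(size) < prob
    return w.astype(np.float32)


w = initialize_np(0.1, 0.2, (100, 100), np.random.default_rng(0))
frac = float((w > 0).mean())
print(w.min() >= 0, round(frac, 2))
```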
Model
[17]:
class Model(bp.DynamicalSystem):
def __init__(self, task, num_hidden=100, name=None):
super(Model, self).__init__(name=name)
assert isinstance(task, Task)
self.task = task
# Network configuration
self.exc_inh_prop = 0.8 # proportion of excitatory neurons
self.conn_prob = 0.2
# Network shape
self.num_output = task.num_output
self.num_hidden = num_hidden
self.num_input = task.num_input
# EI
self.num_exc = int(self.num_hidden * self.exc_inh_prop)
self.num_inh = self.num_hidden - self.num_exc
self.EI_list = bm.ones(self.num_hidden)
self.EI_list[self.num_exc:] = -1.
self.EI_matrix = bm.diag(self.EI_list)
self.inh_index = bm.arange(self.num_exc, self.num_hidden)
# Input and noise
self.noise_rnn = math.sqrt(2 * alpha) * 0.5
# Synaptic plasticity specs
self.tau_fast = 200 # ms
self.tau_slow = 1500 # ms
self.U_stf = 0.15
self.U_std = 0.45
# Initial hidden values
self.init_h = bm.TrainVar(bm.ones((batch_size, self.num_hidden)) * 0.1)
self.h = bm.Variable(bm.ones((batch_size, self.num_hidden)) * 0.1)
# Input/recurrent/output weights
# 1. w_ir (input => recurrent)
prob = self.conn_prob * task.num_receptive_fields
self.w_ir = bm.TrainVar(initialize(0.2, prob, (self.num_input, self.num_hidden)))
self.w_ir_mask = bm.ones((self.num_input, self.num_hidden))
if task.trial_type == 'location_DMS':
self.w_ir_mask *= 0.
target_ind = [range(0, self.num_hidden, 3), range(1, self.num_hidden, 3), range(2, self.num_hidden, 3)]
for n in range(self.num_input):
u = int(n // (self.num_input / 3))
self.w_ir_mask[n, target_ind[u]] = 1.
self.w_ir *= self.w_ir_mask # keep only the input connections allowed by the mask
# 2. w_rr (recurrent => recurrent)
self.w_rr = bm.TrainVar(initialize(0.1, self.conn_prob, (self.num_hidden, self.num_hidden)))
self.w_rr[:, self.num_exc:] = initialize(0.2, self.conn_prob, (self.num_hidden, self.num_inh))
self.w_rr[self.num_exc:, :] = initialize(0.2, self.conn_prob, (self.num_inh, self.num_hidden))
self.w_rr_mask = bm.ones((self.num_hidden, self.num_hidden)) - bm.eye(self.num_hidden)
self.w_rr *= self.w_rr_mask # remove self-connections
self.b_rr = bm.TrainVar(bm.zeros((1, self.num_hidden)))
# 3. w_ro (recurrent => output)
self.w_ro = bm.TrainVar(initialize(0.1, self.conn_prob, (self.num_hidden, self.num_output)))
self.w_ro_mask = bm.ones((self.num_hidden, self.num_output))
self.w_ro_mask[self.num_exc:, :] = 0.
self.w_ro *= self.w_ro_mask # remove inhibitory-to-output connections
# 4. b_ro (bias)
self.b_ro = bm.TrainVar(bm.zeros((1, self.num_output)))
# Synaptic variables
# - The first row (first half of the neurons) holds facilitating synapses
# - The second row (second half of the neurons) holds depressing synapses
alpha_stf = bm.ones((2, int(self.num_hidden / 2)))
alpha_stf[0] = dt / self.tau_slow
alpha_stf[1] = dt / self.tau_fast
alpha_std = bm.ones((2, int(self.num_hidden / 2)))
alpha_std[0] = dt / self.tau_fast
alpha_std[1] = dt / self.tau_slow
U = bm.ones((2, int(self.num_hidden / 2)))
U[0] = 0.15
U[1] = 0.45
u = bm.ones((batch_size, 2, int(self.num_hidden / 2))) * 0.3
u[:, 0] = 0.15
u[:, 1] = 0.45
# - final
self.alpha_stf = alpha_stf.reshape((1, -1))
self.alpha_std = alpha_std.reshape((1, -1))
self.U = U.reshape((1, -1))
self.u = bm.Variable(u.reshape((batch_size, -1)))
self.x = bm.Variable(bm.ones((batch_size, self.num_hidden)))
self.y = bm.Variable(bm.ones((batch_size, self.num_output)))
self.y_hist = bm.Variable(bm.zeros((task.num_steps, batch_size, task.num_output)))
# Loss
self.loss = bm.Variable(bm.zeros(1))
self.perf_loss = bm.Variable(bm.zeros(1))
self.spike_loss = bm.Variable(bm.zeros(1))
self.weight_loss = bm.Variable(bm.zeros(1))
def reset_state(self, batch_size):
u = bm.ones((batch_size, 2, int(self.num_hidden / 2))) * 0.3
u[:, 0] = 0.15
u[:, 1] = 0.45
self.u.value = u.reshape((batch_size, -1))
self.x.value = bm.ones((batch_size, self.num_hidden))
self.loss[:] = 0.
self.perf_loss[:] = 0.
self.spike_loss[:] = 0.
self.weight_loss[:] = 0.
def update(self, input):
# update STP variables
self.x += (self.alpha_std * (1 - self.x) - dt_sec * self.u * self.x * self.h)
self.u += (self.alpha_stf * (self.U - self.u) + dt_sec * self.U * (1 - self.u) * self.h)
self.x.value = bm.minimum(1., bm.relu(self.x))
self.u.value = bm.minimum(1., bm.relu(self.u))
h_post = self.u * self.x * self.h
# Update the hidden state. Only use excitatory projections from input layer to RNN
# All input and RNN activity will be non-negative
state = alpha * (input @ bm.relu(self.w_ir) + h_post @ self.w_rr + self.b_rr)
state += bm.random.normal(0, self.noise_rnn, self.h.shape)
self.h.value = bm.relu(state) + self.h * (1 - alpha)
self.y.value = self.h @ bm.relu(self.w_ro) + self.b_ro
def predict(self, inputs):
self.h[:] = self.init_h
scan = bm.make_loop(body_fun=self.update,
dyn_vars=[self.x, self.u, self.h, self.y],
out_vars=[self.y, self.h])
logits, hist_h = scan(inputs)
self.y_hist[:] = logits
return logits, hist_h
def loss_func(self, inputs, targets, mask):
logits, hist_h = self.predict(inputs)
# Calculate the performance loss
perf_loss = bp.losses.cross_entropy_loss(logits, targets, reduction='none') * mask
self.perf_loss[:] = bm.mean(perf_loss)
# L1/L2 penalty term on hidden state activity to encourage low spike rate solutions
n = 2 if spike_regularization == 'L2' else 1
self.spike_loss[:] = bm.mean(hist_h ** n)
self.weight_loss[:] = bm.mean(bm.relu(self.w_rr) ** n)
# final loss
self.loss[:] = self.perf_loss + spike_cost * self.spike_loss + weight_cost * self.weight_loss
return self.loss.mean()
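The short-term plasticity (STP) part of `Model.update` is easier to follow in isolation. The sketch below is a minimal NumPy re-implementation of just the `x`/`u` update and the resulting `h_post`, using the same Euler update and clipping to [0, 1]. The helper name `stp_step` is hypothetical. The constants (10 ms step, `tau_fast = 200` ms, `tau_slow = 1500` ms, constant drive of 20) are assumptions chosen for illustration. They are not necessarily the values defined earlier in this notebook.

```python
import numpy as np


def stp_step(x, u, h, U, alpha_std, alpha_stf, dt_sec):
    """One Euler step of the STP variables, following the equations in Model.update.

    x: fraction of available synaptic resources (depression)
    u: utilization / release probability (facilitation)
    h: presynaptic activity
    """
    x_new = x + alpha_std * (1.0 - x) - dt_sec * u * x * h
    u_new = u + alpha_stf * (U - u) + dt_sec * U * (1.0 - u) * h
    # Clip to [0, 1], like bm.minimum(1., bm.relu(.)) in the model
    x_new = np.clip(x_new, 0.0, 1.0)
    u_new = np.clip(u_new, 0.0, 1.0)
    # Effective presynaptic drive that is multiplied by the recurrent weights
    h_post = u_new * x_new * h
    return x_new, u_new, h_post


# Assumed constants (illustrative only): 10 ms step, tau_fast = 200 ms, tau_slow = 1500 ms
dt, tau_fast, tau_slow = 10.0, 200.0, 1500.0
dt_sec = dt / 1000.0
# Index 0: facilitating synapse, index 1: depressing synapse (same layout as the model)
U = np.array([0.15, 0.45])
alpha_std = np.array([dt / tau_fast, dt / tau_slow])
alpha_stf = np.array([dt / tau_slow, dt / tau_fast])

x = np.ones(2)
u = U.copy()
h = np.full(2, 20.0)  # constant presynaptic drive
for _ in range(50):
    x, u, h_post = stp_step(x, u, h, U, alpha_std, alpha_stf, dt_sec)
print("x =", x, "u =", u, "h_post =", h_post)
```

With sustained input, the facilitating synapse's `u` grows well above its baseline `U`. The depressing synapse, which recovers its resources slowly, ends up with far fewer available resources `x`.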
Analysis
[18]:
def get_perf(target, output, mask):
    """Calculate task accuracy by comparing the actual network output to the desired output.

    Only time points when the test stimulus is on are examined, i.e. when target[:, :, 0] == 0.
    """
    target = target.numpy()
    output = output.numpy()
    mask = mask.numpy()
    mask_full = mask > 0
    mask_test = mask_full * (target[:, :, 0] == 0)
    mask_non_match = mask_full * (target[:, :, 1] == 1)
    mask_match = mask_full * (target[:, :, 2] == 1)
    target_max = np.argmax(target, axis=2)
    output_max = np.argmax(output, axis=2)
    match = target_max == output_max
    accuracy = np.sum(match * mask_test) / np.sum(mask_test)
    acc_non_match = np.sum(match * np.squeeze(mask_non_match)) / np.sum(mask_non_match)
    acc_match = np.sum(match * np.squeeze(mask_match)) / np.sum(mask_match)
    return accuracy, acc_non_match, acc_match
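To see what the masking in `get_perf` does, here is a tiny NumPy-only version of its overall-accuracy computation, run on hand-made data. The function name `masked_accuracy` and the toy arrays are hypothetical. The arrays are time-major, `(time, batch, 3)`, with channel 0 taken as fixation and channels 1 and 2 as non-match and match, as `get_perf` assumes.

```python
import numpy as np


def masked_accuracy(target, output, mask):
    """Accuracy over time points where the mask is on and the fixation target is off."""
    mask_test = (mask > 0) & (target[:, :, 0] == 0)
    match = np.argmax(target, axis=2) == np.argmax(output, axis=2)
    return np.sum(match & mask_test) / np.sum(mask_test)


# 4 time steps, batch of 1: two fixation steps, then two "match" steps
target = np.array([[[1, 0, 0]], [[1, 0, 0]], [[0, 0, 1]], [[0, 0, 1]]], dtype=float)
output = np.array([[[0.9, 0.05, 0.05]],  # fixation step: ignored
                   [[0.1, 0.8, 0.1]],    # fixation step: ignored even though wrong
                   [[0.1, 0.2, 0.7]],    # test step: correct
                   [[0.2, 0.6, 0.2]]])   # test step: wrong
mask = np.ones((4, 1))
print(masked_accuracy(target, output, mask))  # 0.5
```

Only the two test-period steps count, so the wrong answer during fixation does not lower the score.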
Training
[19]:
def trial(task_name, save_fn=None, num_iterations=2000, iter_between_outputs=5):
    task = Task(task_name, dt=dt, tau=time_constant, batch_size=batch_size)
    # trial_info = task.generate_trial(set_rule=None)
    # task.plot_neural_input(trial_info)
    model = Model(task)
    opt = bp.optim.Adam(learning_rate, train_vars=model.train_vars())
    grad_f = bm.grad(model.loss_func,
                     dyn_vars=model.vars(),
                     grad_vars=model.train_vars(),
                     return_value=True)

    @bm.jit
    @bm.to_object(child_objs=(model, opt))
    def train_op(x, y, mask):
        grads, _ = grad_f(x, y, mask)
        capped_gs = dict()
        for key, grad in grads.items():
            # zero the gradients of masked-out connections, then clip
            if 'w_rr' in key: grad *= model.w_rr_mask
            elif 'w_ro' in key: grad *= model.w_ro_mask
            elif 'w_ri' in key: grad *= model.w_ir_mask
            capped_gs[key] = bm.clip_by_norm(grad, clip_max_grad_val)
        opt.update(grads=capped_gs)

    # keep track of the model performance across training
    model_performance = {'accuracy': [], 'loss': [], 'perf_loss': [],
                         'spike_loss': [], 'weight_loss': [], 'iteration': []}
    for i in range(num_iterations):
        model.reset_state(batch_size)
        # generate a batch of `batch_size` trials
        trial_info = task.generate_trial(set_rule=None)
        inputs = bm.array(trial_info['neural_input'], dtype=bm.float32)
        targets = bm.array(trial_info['desired_output'], dtype=bm.float32)
        mask = bm.array(trial_info['train_mask'], dtype=bm.float32)
        # Run the model
        train_op(inputs, targets, mask)
        # get metrics; store plain floats, since the loss Variables are overwritten every iteration
        accuracy, _, _ = get_perf(targets, model.y_hist, mask)
        model_performance['accuracy'].append(accuracy)
        model_performance['loss'].append(float(model.loss.value[0]))
        model_performance['perf_loss'].append(float(model.perf_loss.value[0]))
        model_performance['spike_loss'].append(float(model.spike_loss.value[0]))
        model_performance['weight_loss'].append(float(model.weight_loss.value[0]))
        model_performance['iteration'].append(i)
        # Print the model performance to screen
        if i % iter_between_outputs == 0:
            print(task_name +
                  f' Iter {i:4d}' +
                  f' | Accuracy {accuracy:0.4f}' +
                  f' | Perf loss {model.perf_loss[0]:0.4f}' +
                  f' | Spike loss {model.spike_loss[0]:0.4f}' +
                  f' | Weight loss {model.weight_loss[0]:0.4f}' +
                  f' | Mean activity {bm.mean(model.h):0.4f}')

    if save_fn:
        save_dir = os.path.dirname(save_fn)
        if save_dir and not os.path.exists(save_dir):
            os.makedirs(save_dir)
        # Save model and results
        weights = model.train_vars().unique().dict()
        results = {'weights': weights, 'parameters': {}}
        for k, v in model_performance.items():
            results[k] = v
        with open(save_fn, 'wb') as f:
            pickle.dump(results, f)
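Inside `train_op`, each gradient is first multiplied by its connectivity mask and then clipped. The NumPy sketch below illustrates that order of operations under one stated assumption: that `bm.clip_by_norm` rescales a tensor whenever its L2 norm exceeds the threshold. The `clip_by_norm` and `mask_and_clip` helpers are hypothetical re-implementations, not BrainPy's functions, and they look masks up by exact key rather than by substring.

```python
import numpy as np


def clip_by_norm(grad, max_norm):
    """Rescale grad so that its L2 norm is at most max_norm (assumed behaviour)."""
    norm = np.linalg.norm(grad)
    if norm > max_norm:
        return grad * (max_norm / norm)
    return grad


def mask_and_clip(grads, masks, max_norm):
    """Zero out gradients of pruned connections, then clip each gradient separately."""
    capped = {}
    for key, grad in grads.items():
        if key in masks:
            grad = grad * masks[key]  # e.g. no self-connections in w_rr
        capped[key] = clip_by_norm(grad, max_norm)
    return capped


num_hidden = 3
w_rr_mask = np.ones((num_hidden, num_hidden)) - np.eye(num_hidden)
grads = {'w_rr': np.full((num_hidden, num_hidden), 2.0),
         'b_rr': np.array([0.1, 0.2, 0.3])}
capped = mask_and_clip(grads, {'w_rr': w_rr_mask}, max_norm=1.0)
print(capped['w_rr'])
```

Masking before clipping means the norm is computed only over connections that are allowed to change. A gradient that is already small, such as `b_rr` here, passes through unchanged.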
[20]:
trial('DMS')
DMS Iter 0 | Accuracy 0.5211 | Perf loss 2.7038 | Spike loss 7.8653 | Weight loss 0.0315 | Mean activity 3.3380
DMS Iter 5 | Accuracy 0.4184 | Perf loss 0.4879 | Spike loss 5.5533 | Weight loss 0.0315 | Mean activity 1.6523
DMS Iter 10 | Accuracy 0.4473 | Perf loss 0.2892 | Spike loss 4.3892 | Weight loss 0.0312 | Mean activity 1.2179
DMS Iter 15 | Accuracy 0.4654 | Perf loss 0.2386 | Spike loss 2.8176 | Weight loss 0.0308 | Mean activity 1.0014
DMS Iter 20 | Accuracy 0.5000 | Perf loss 0.2102 | Spike loss 2.4600 | Weight loss 0.0309 | Mean activity 0.8940
DMS Iter 25 | Accuracy 0.4783 | Perf loss 0.2018 | Spike loss 2.2587 | Weight loss 0.0311 | Mean activity 0.8501
DMS Iter 30 | Accuracy 0.5119 | Perf loss 0.1815 | Spike loss 2.0880 | Weight loss 0.0313 | Mean activity 0.7721
DMS Iter 35 | Accuracy 0.5340 | Perf loss 0.1697 | Spike loss 1.7911 | Weight loss 0.0318 | Mean activity 0.7324
DMS Iter 40 | Accuracy 0.5469 | Perf loss 0.1640 | Spike loss 1.7178 | Weight loss 0.0324 | Mean activity 0.6913
DMS Iter 45 | Accuracy 0.5764 | Perf loss 0.1524 | Spike loss 1.6305 | Weight loss 0.0332 | Mean activity 0.6684
DMS Iter 50 | Accuracy 0.6229 | Perf loss 0.1436 | Spike loss 1.5921 | Weight loss 0.0339 | Mean activity 0.6497
DMS Iter 55 | Accuracy 0.6301 | Perf loss 0.1382 | Spike loss 1.5635 | Weight loss 0.0346 | Mean activity 0.6360
DMS Iter 60 | Accuracy 0.6254 | Perf loss 0.1392 | Spike loss 1.4905 | Weight loss 0.0353 | Mean activity 0.6193
DMS Iter 65 | Accuracy 0.6494 | Perf loss 0.1325 | Spike loss 1.4990 | Weight loss 0.0360 | Mean activity 0.6177
DMS Iter 70 | Accuracy 0.6910 | Perf loss 0.1232 | Spike loss 1.4614 | Weight loss 0.0365 | Mean activity 0.6192
DMS Iter 75 | Accuracy 0.6977 | Perf loss 0.1206 | Spike loss 1.4657 | Weight loss 0.0372 | Mean activity 0.6092
DMS Iter 80 | Accuracy 0.7135 | Perf loss 0.1178 | Spike loss 1.4238 | Weight loss 0.0377 | Mean activity 0.6133
DMS Iter 85 | Accuracy 0.7240 | Perf loss 0.1133 | Spike loss 1.4442 | Weight loss 0.0383 | Mean activity 0.6201
DMS Iter 90 | Accuracy 0.7273 | Perf loss 0.1093 | Spike loss 1.4717 | Weight loss 0.0390 | Mean activity 0.6170
DMS Iter 95 | Accuracy 0.7404 | Perf loss 0.1055 | Spike loss 1.4073 | Weight loss 0.0395 | Mean activity 0.6214
DMS Iter 100 | Accuracy 0.7670 | Perf loss 0.0974 | Spike loss 1.4685 | Weight loss 0.0403 | Mean activity 0.6321
DMS Iter 105 | Accuracy 0.7639 | Perf loss 0.0996 | Spike loss 1.4525 | Weight loss 0.0410 | Mean activity 0.6321
DMS Iter 110 | Accuracy 0.7947 | Perf loss 0.0879 | Spike loss 1.4711 | Weight loss 0.0416 | Mean activity 0.6412
DMS Iter 115 | Accuracy 0.7881 | Perf loss 0.0893 | Spike loss 1.4721 | Weight loss 0.0426 | Mean activity 0.6400
DMS Iter 120 | Accuracy 0.7902 | Perf loss 0.0853 | Spike loss 1.4998 | Weight loss 0.0433 | Mean activity 0.6495
DMS Iter 125 | Accuracy 0.8094 | Perf loss 0.0775 | Spike loss 1.4819 | Weight loss 0.0440 | Mean activity 0.6519
DMS Iter 130 | Accuracy 0.8127 | Perf loss 0.0780 | Spike loss 1.4739 | Weight loss 0.0446 | Mean activity 0.6535
DMS Iter 135 | Accuracy 0.8295 | Perf loss 0.0725 | Spike loss 1.4526 | Weight loss 0.0452 | Mean activity 0.6481
DMS Iter 140 | Accuracy 0.8303 | Perf loss 0.0702 | Spike loss 1.4571 | Weight loss 0.0458 | Mean activity 0.6508
DMS Iter 145 | Accuracy 0.8248 | Perf loss 0.0719 | Spike loss 1.4244 | Weight loss 0.0463 | Mean activity 0.6533
DMS Iter 150 | Accuracy 0.8318 | Perf loss 0.0707 | Spike loss 1.4362 | Weight loss 0.0468 | Mean activity 0.6520
DMS Iter 155 | Accuracy 0.8227 | Perf loss 0.0718 | Spike loss 1.4525 | Weight loss 0.0476 | Mean activity 0.6625
DMS Iter 160 | Accuracy 0.8293 | Perf loss 0.0687 | Spike loss 1.4380 | Weight loss 0.0481 | Mean activity 0.6525
DMS Iter 165 | Accuracy 0.8318 | Perf loss 0.0672 | Spike loss 1.4201 | Weight loss 0.0484 | Mean activity 0.6538
DMS Iter 170 | Accuracy 0.8379 | Perf loss 0.0639 | Spike loss 1.4193 | Weight loss 0.0488 | Mean activity 0.6541
DMS Iter 175 | Accuracy 0.8297 | Perf loss 0.0673 | Spike loss 1.4213 | Weight loss 0.0494 | Mean activity 0.6478
DMS Iter 180 | Accuracy 0.8422 | Perf loss 0.0619 | Spike loss 1.4273 | Weight loss 0.0497 | Mean activity 0.6586
DMS Iter 185 | Accuracy 0.8475 | Perf loss 0.0614 | Spike loss 1.3814 | Weight loss 0.0500 | Mean activity 0.6567
DMS Iter 190 | Accuracy 0.8451 | Perf loss 0.0628 | Spike loss 1.3811 | Weight loss 0.0505 | Mean activity 0.6565
DMS Iter 195 | Accuracy 0.8518 | Perf loss 0.0601 | Spike loss 1.3828 | Weight loss 0.0511 | Mean activity 0.6549
DMS Iter 200 | Accuracy 0.8529 | Perf loss 0.0569 | Spike loss 1.3763 | Weight loss 0.0515 | Mean activity 0.6515
DMS Iter 205 | Accuracy 0.8508 | Perf loss 0.0577 | Spike loss 1.3716 | Weight loss 0.0520 | Mean activity 0.6461
DMS Iter 210 | Accuracy 0.8514 | Perf loss 0.0579 | Spike loss 1.3755 | Weight loss 0.0524 | Mean activity 0.6428
DMS Iter 215 | Accuracy 0.8512 | Perf loss 0.0573 | Spike loss 1.3652 | Weight loss 0.0529 | Mean activity 0.6437
DMS Iter 220 | Accuracy 0.8521 | Perf loss 0.0559 | Spike loss 1.3662 | Weight loss 0.0533 | Mean activity 0.6355
DMS Iter 225 | Accuracy 0.8609 | Perf loss 0.0540 | Spike loss 1.3805 | Weight loss 0.0535 | Mean activity 0.6425
DMS Iter 230 | Accuracy 0.8682 | Perf loss 0.0520 | Spike loss 1.3702 | Weight loss 0.0541 | Mean activity 0.6354
DMS Iter 235 | Accuracy 0.8682 | Perf loss 0.0522 | Spike loss 1.3818 | Weight loss 0.0544 | Mean activity 0.6353
DMS Iter 240 | Accuracy 0.8629 | Perf loss 0.0529 | Spike loss 1.3609 | Weight loss 0.0548 | Mean activity 0.6275
DMS Iter 245 | Accuracy 0.8494 | Perf loss 0.0576 | Spike loss 1.3546 | Weight loss 0.0555 | Mean activity 0.6225
DMS Iter 250 | Accuracy 0.8623 | Perf loss 0.0535 | Spike loss 1.3358 | Weight loss 0.0557 | Mean activity 0.6133
DMS Iter 255 | Accuracy 0.8684 | Perf loss 0.0512 | Spike loss 1.3440 | Weight loss 0.0562 | Mean activity 0.6131
DMS Iter 260 | Accuracy 0.8705 | Perf loss 0.0520 | Spike loss 1.3385 | Weight loss 0.0566 | Mean activity 0.6083
DMS Iter 265 | Accuracy 0.8678 | Perf loss 0.0519 | Spike loss 1.3221 | Weight loss 0.0567 | Mean activity 0.6024
DMS Iter 270 | Accuracy 0.8740 | Perf loss 0.0492 | Spike loss 1.3103 | Weight loss 0.0571 | Mean activity 0.5915
DMS Iter 275 | Accuracy 0.8762 | Perf loss 0.0465 | Spike loss 1.2993 | Weight loss 0.0574 | Mean activity 0.5922
DMS Iter 280 | Accuracy 0.8723 | Perf loss 0.0489 | Spike loss 1.2954 | Weight loss 0.0578 | Mean activity 0.5907
DMS Iter 285 | Accuracy 0.8660 | Perf loss 0.0495 | Spike loss 1.3016 | Weight loss 0.0583 | Mean activity 0.5868
DMS Iter 290 | Accuracy 0.8783 | Perf loss 0.0478 | Spike loss 1.2908 | Weight loss 0.0589 | Mean activity 0.5819
DMS Iter 295 | Accuracy 0.8711 | Perf loss 0.0487 | Spike loss 1.2881 | Weight loss 0.0591 | Mean activity 0.5714
DMS Iter 300 | Accuracy 0.8666 | Perf loss 0.0502 | Spike loss 1.2648 | Weight loss 0.0594 | Mean activity 0.5663
DMS Iter 305 | Accuracy 0.8715 | Perf loss 0.0476 | Spike loss 1.2757 | Weight loss 0.0599 | Mean activity 0.5588
DMS Iter 310 | Accuracy 0.8713 | Perf loss 0.0486 | Spike loss 1.2846 | Weight loss 0.0603 | Mean activity 0.5546
DMS Iter 315 | Accuracy 0.8686 | Perf loss 0.0498 | Spike loss 1.2558 | Weight loss 0.0604 | Mean activity 0.5455
DMS Iter 320 | Accuracy 0.8760 | Perf loss 0.0485 | Spike loss 1.2669 | Weight loss 0.0610 | Mean activity 0.5511
DMS Iter 325 | Accuracy 0.8746 | Perf loss 0.0500 | Spike loss 1.2692 | Weight loss 0.0613 | Mean activity 0.5438
DMS Iter 330 | Accuracy 0.8809 | Perf loss 0.0441 | Spike loss 1.2730 | Weight loss 0.0615 | Mean activity 0.5410
DMS Iter 335 | Accuracy 0.8809 | Perf loss 0.0454 | Spike loss 1.2474 | Weight loss 0.0619 | Mean activity 0.5316
DMS Iter 340 | Accuracy 0.8803 | Perf loss 0.0442 | Spike loss 1.2417 | Weight loss 0.0623 | Mean activity 0.5207
DMS Iter 345 | Accuracy 0.8844 | Perf loss 0.0444 | Spike loss 1.2522 | Weight loss 0.0626 | Mean activity 0.5163
DMS Iter 350 | Accuracy 0.8924 | Perf loss 0.0418 | Spike loss 1.2379 | Weight loss 0.0631 | Mean activity 0.5123
DMS Iter 355 | Accuracy 0.8795 | Perf loss 0.0457 | Spike loss 1.1955 | Weight loss 0.0632 | Mean activity 0.5204
DMS Iter 360 | Accuracy 0.8844 | Perf loss 0.0438 | Spike loss 1.2291 | Weight loss 0.0637 | Mean activity 0.5052
DMS Iter 365 | Accuracy 0.8867 | Perf loss 0.0446 | Spike loss 1.2282 | Weight loss 0.0642 | Mean activity 0.5023
DMS Iter 370 | Accuracy 0.8916 | Perf loss 0.0419 | Spike loss 1.2145 | Weight loss 0.0645 | Mean activity 0.4940
DMS Iter 375 | Accuracy 0.8951 | Perf loss 0.0407 | Spike loss 1.2218 | Weight loss 0.0651 | Mean activity 0.4938
DMS Iter 380 | Accuracy 0.8928 | Perf loss 0.0445 | Spike loss 1.2002 | Weight loss 0.0654 | Mean activity 0.4889
DMS Iter 385 | Accuracy 0.8953 | Perf loss 0.0413 | Spike loss 1.2121 | Weight loss 0.0656 | Mean activity 0.4712
DMS Iter 390 | Accuracy 0.8969 | Perf loss 0.0414 | Spike loss 1.2223 | Weight loss 0.0663 | Mean activity 0.4747
DMS Iter 395 | Accuracy 0.8936 | Perf loss 0.0415 | Spike loss 1.1969 | Weight loss 0.0664 | Mean activity 0.4663
DMS Iter 400 | Accuracy 0.8949 | Perf loss 0.0417 | Spike loss 1.1999 | Weight loss 0.0667 | Mean activity 0.4605
DMS Iter 405 | Accuracy 0.8977 | Perf loss 0.0417 | Spike loss 1.2007 | Weight loss 0.0666 | Mean activity 0.4563
DMS Iter 410 | Accuracy 0.8984 | Perf loss 0.0400 | Spike loss 1.1890 | Weight loss 0.0669 | Mean activity 0.4540
DMS Iter 415 | Accuracy 0.8967 | Perf loss 0.0408 | Spike loss 1.1682 | Weight loss 0.0672 | Mean activity 0.4449
DMS Iter 420 | Accuracy 0.8969 | Perf loss 0.0408 | Spike loss 1.1809 | Weight loss 0.0681 | Mean activity 0.4277
DMS Iter 425 | Accuracy 0.8932 | Perf loss 0.0405 | Spike loss 1.1743 | Weight loss 0.0683 | Mean activity 0.4309
DMS Iter 430 | Accuracy 0.8910 | Perf loss 0.0416 | Spike loss 1.1341 | Weight loss 0.0684 | Mean activity 0.4317
DMS Iter 435 | Accuracy 0.9074 | Perf loss 0.0386 | Spike loss 1.1491 | Weight loss 0.0686 | Mean activity 0.4341
DMS Iter 440 | Accuracy 0.8957 | Perf loss 0.0398 | Spike loss 1.1909 | Weight loss 0.0693 | Mean activity 0.4319
DMS Iter 445 | Accuracy 0.8938 | Perf loss 0.0407 | Spike loss 1.1691 | Weight loss 0.0694 | Mean activity 0.4270
DMS Iter 450 | Accuracy 0.9043 | Perf loss 0.0377 | Spike loss 1.1948 | Weight loss 0.0698 | Mean activity 0.4231
DMS Iter 455 | Accuracy 0.8969 | Perf loss 0.0411 | Spike loss 1.1601 | Weight loss 0.0700 | Mean activity 0.4285
DMS Iter 460 | Accuracy 0.9043 | Perf loss 0.0369 | Spike loss 1.1469 | Weight loss 0.0705 | Mean activity 0.4100
DMS Iter 465 | Accuracy 0.9037 | Perf loss 0.0378 | Spike loss 1.1253 | Weight loss 0.0711 | Mean activity 0.4051
DMS Iter 470 | Accuracy 0.9025 | Perf loss 0.0377 | Spike loss 1.1440 | Weight loss 0.0718 | Mean activity 0.4085
DMS Iter 475 | Accuracy 0.9043 | Perf loss 0.0376 | Spike loss 1.1172 | Weight loss 0.0720 | Mean activity 0.4043
DMS Iter 480 | Accuracy 0.9043 | Perf loss 0.0368 | Spike loss 1.1174 | Weight loss 0.0725 | Mean activity 0.3943
DMS Iter 485 | Accuracy 0.9133 | Perf loss 0.0339 | Spike loss 1.1094 | Weight loss 0.0729 | Mean activity 0.4056
DMS Iter 490 | Accuracy 0.9012 | Perf loss 0.0386 | Spike loss 1.1034 | Weight loss 0.0733 | Mean activity 0.3849
DMS Iter 495 | Accuracy 0.9037 | Perf loss 0.0377 | Spike loss 1.1030 | Weight loss 0.0735 | Mean activity 0.3882
DMS Iter 500 | Accuracy 0.9082 | Perf loss 0.0361 | Spike loss 1.0884 | Weight loss 0.0740 | Mean activity 0.3877
DMS Iter 505 | Accuracy 0.9098 | Perf loss 0.0352 | Spike loss 1.0988 | Weight loss 0.0743 | Mean activity 0.3768
DMS Iter 510 | Accuracy 0.9088 | Perf loss 0.0348 | Spike loss 1.0985 | Weight loss 0.0744 | Mean activity 0.3860
DMS Iter 515 | Accuracy 0.9053 | Perf loss 0.0370 | Spike loss 1.0939 | Weight loss 0.0746 | Mean activity 0.3713
DMS Iter 520 | Accuracy 0.9172 | Perf loss 0.0343 | Spike loss 1.1245 | Weight loss 0.0750 | Mean activity 0.3869
DMS Iter 525 | Accuracy 0.9039 | Perf loss 0.0362 | Spike loss 1.1007 | Weight loss 0.0754 | Mean activity 0.3688
DMS Iter 530 | Accuracy 0.9111 | Perf loss 0.0355 | Spike loss 1.0753 | Weight loss 0.0756 | Mean activity 0.3803
DMS Iter 535 | Accuracy 0.9064 | Perf loss 0.0361 | Spike loss 1.1138 | Weight loss 0.0761 | Mean activity 0.3709
DMS Iter 540 | Accuracy 0.9160 | Perf loss 0.0343 | Spike loss 1.1043 | Weight loss 0.0764 | Mean activity 0.3653
DMS Iter 545 | Accuracy 0.9195 | Perf loss 0.0323 | Spike loss 1.1468 | Weight loss 0.0772 | Mean activity 0.3768
DMS Iter 550 | Accuracy 0.9115 | Perf loss 0.0336 | Spike loss 1.1138 | Weight loss 0.0775 | Mean activity 0.3667
DMS Iter 555 | Accuracy 0.9125 | Perf loss 0.0343 | Spike loss 1.0900 | Weight loss 0.0775 | Mean activity 0.3603
DMS Iter 560 | Accuracy 0.9193 | Perf loss 0.0323 | Spike loss 1.0804 | Weight loss 0.0775 | Mean activity 0.3686
DMS Iter 565 | Accuracy 0.9209 | Perf loss 0.0317 | Spike loss 1.0787 | Weight loss 0.0778 | Mean activity 0.3531
DMS Iter 570 | Accuracy 0.9199 | Perf loss 0.0322 | Spike loss 1.0756 | Weight loss 0.0781 | Mean activity 0.3547
DMS Iter 575 | Accuracy 0.9135 | Perf loss 0.0340 | Spike loss 1.0826 | Weight loss 0.0782 | Mean activity 0.3441
DMS Iter 580 | Accuracy 0.9203 | Perf loss 0.0333 | Spike loss 1.0866 | Weight loss 0.0788 | Mean activity 0.3429
DMS Iter 585 | Accuracy 0.9219 | Perf loss 0.0335 | Spike loss 1.1002 | Weight loss 0.0795 | Mean activity 0.3590
DMS Iter 590 | Accuracy 0.9244 | Perf loss 0.0302 | Spike loss 1.0977 | Weight loss 0.0799 | Mean activity 0.3601
DMS Iter 595 | Accuracy 0.9254 | Perf loss 0.0317 | Spike loss 1.0713 | Weight loss 0.0804 | Mean activity 0.3472
DMS Iter 600 | Accuracy 0.9271 | Perf loss 0.0307 | Spike loss 1.0641 | Weight loss 0.0804 | Mean activity 0.3495
DMS Iter 605 | Accuracy 0.9217 | Perf loss 0.0315 | Spike loss 1.0845 | Weight loss 0.0808 | Mean activity 0.3475
DMS Iter 610 | Accuracy 0.9273 | Perf loss 0.0308 | Spike loss 1.0668 | Weight loss 0.0809 | Mean activity 0.3464
DMS Iter 615 | Accuracy 0.9299 | Perf loss 0.0301 | Spike loss 1.0581 | Weight loss 0.0810 | Mean activity 0.3463
DMS Iter 620 | Accuracy 0.9311 | Perf loss 0.0292 | Spike loss 1.0662 | Weight loss 0.0813 | Mean activity 0.3406
DMS Iter 625 | Accuracy 0.9271 | Perf loss 0.0301 | Spike loss 1.0467 | Weight loss 0.0815 | Mean activity 0.3372
DMS Iter 630 | Accuracy 0.9264 | Perf loss 0.0307 | Spike loss 1.0596 | Weight loss 0.0816 | Mean activity 0.3505
DMS Iter 635 | Accuracy 0.9287 | Perf loss 0.0316 | Spike loss 1.0671 | Weight loss 0.0818 | Mean activity 0.3264
DMS Iter 640 | Accuracy 0.9295 | Perf loss 0.0289 | Spike loss 1.0595 | Weight loss 0.0822 | Mean activity 0.3410
DMS Iter 645 | Accuracy 0.9375 | Perf loss 0.0290 | Spike loss 1.0762 | Weight loss 0.0826 | Mean activity 0.3229
DMS Iter 650 | Accuracy 0.9303 | Perf loss 0.0300 | Spike loss 1.0576 | Weight loss 0.0822 | Mean activity 0.3397
DMS Iter 655 | Accuracy 0.9340 | Perf loss 0.0278 | Spike loss 1.0979 | Weight loss 0.0827 | Mean activity 0.3331
DMS Iter 660 | Accuracy 0.9328 | Perf loss 0.0274 | Spike loss 1.0629 | Weight loss 0.0830 | Mean activity 0.3175
DMS Iter 665 | Accuracy 0.9381 | Perf loss 0.0261 | Spike loss 1.0790 | Weight loss 0.0833 | Mean activity 0.3261
DMS Iter 670 | Accuracy 0.9350 | Perf loss 0.0271 | Spike loss 1.0582 | Weight loss 0.0834 | Mean activity 0.3257
DMS Iter 675 | Accuracy 0.9350 | Perf loss 0.0263 | Spike loss 1.0180 | Weight loss 0.0833 | Mean activity 0.3198
DMS Iter 680 | Accuracy 0.9412 | Perf loss 0.0250 | Spike loss 1.0213 | Weight loss 0.0836 | Mean activity 0.3153
DMS Iter 685 | Accuracy 0.9391 | Perf loss 0.0272 | Spike loss 1.0157 | Weight loss 0.0838 | Mean activity 0.3290
DMS Iter 690 | Accuracy 0.9355 | Perf loss 0.0280 | Spike loss 1.0102 | Weight loss 0.0841 | Mean activity 0.3112
DMS Iter 695 | Accuracy 0.9357 | Perf loss 0.0272 | Spike loss 1.0224 | Weight loss 0.0847 | Mean activity 0.3169
DMS Iter 700 | Accuracy 0.9373 | Perf loss 0.0267 | Spike loss 1.0023 | Weight loss 0.0850 | Mean activity 0.3212
DMS Iter 705 | Accuracy 0.9395 | Perf loss 0.0256 | Spike loss 1.0093 | Weight loss 0.0854 | Mean activity 0.3088
DMS Iter 710 | Accuracy 0.9389 | Perf loss 0.0270 | Spike loss 0.9992 | Weight loss 0.0852 | Mean activity 0.3150
DMS Iter 715 | Accuracy 0.9396 | Perf loss 0.0248 | Spike loss 1.0078 | Weight loss 0.0856 | Mean activity 0.3174
DMS Iter 720 | Accuracy 0.9404 | Perf loss 0.0256 | Spike loss 1.0017 | Weight loss 0.0862 | Mean activity 0.3120
DMS Iter 725 | Accuracy 0.9459 | Perf loss 0.0250 | Spike loss 1.0026 | Weight loss 0.0867 | Mean activity 0.3089
DMS Iter 730 | Accuracy 0.9354 | Perf loss 0.0262 | Spike loss 1.0120 | Weight loss 0.0868 | Mean activity 0.3107
DMS Iter 735 | Accuracy 0.9387 | Perf loss 0.0261 | Spike loss 0.9864 | Weight loss 0.0867 | Mean activity 0.2896
DMS Iter 740 | Accuracy 0.9406 | Perf loss 0.0270 | Spike loss 1.0056 | Weight loss 0.0873 | Mean activity 0.3082
DMS Iter 745 | Accuracy 0.9443 | Perf loss 0.0236 | Spike loss 0.9958 | Weight loss 0.0875 | Mean activity 0.3037
DMS Iter 750 | Accuracy 0.9434 | Perf loss 0.0238 | Spike loss 1.0074 | Weight loss 0.0875 | Mean activity 0.2983
DMS Iter 755 | Accuracy 0.9479 | Perf loss 0.0227 | Spike loss 0.9761 | Weight loss 0.0874 | Mean activity 0.2986
DMS Iter 760 | Accuracy 0.9424 | Perf loss 0.0240 | Spike loss 0.9849 | Weight loss 0.0879 | Mean activity 0.3040
DMS Iter 765 | Accuracy 0.9471 | Perf loss 0.0233 | Spike loss 0.9560 | Weight loss 0.0883 | Mean activity 0.2987
DMS Iter 770 | Accuracy 0.9453 | Perf loss 0.0247 | Spike loss 0.9739 | Weight loss 0.0887 | Mean activity 0.2964
DMS Iter 775 | Accuracy 0.9445 | Perf loss 0.0240 | Spike loss 1.0132 | Weight loss 0.0887 | Mean activity 0.3047
DMS Iter 780 | Accuracy 0.9447 | Perf loss 0.0230 | Spike loss 1.0303 | Weight loss 0.0893 | Mean activity 0.3056
DMS Iter 785 | Accuracy 0.9451 | Perf loss 0.0232 | Spike loss 1.0083 | Weight loss 0.0894 | Mean activity 0.2938
DMS Iter 790 | Accuracy 0.9514 | Perf loss 0.0232 | Spike loss 0.9913 | Weight loss 0.0896 | Mean activity 0.2990
DMS Iter 795 | Accuracy 0.9465 | Perf loss 0.0242 | Spike loss 1.0077 | Weight loss 0.0900 | Mean activity 0.2884
DMS Iter 800 | Accuracy 0.9465 | Perf loss 0.0230 | Spike loss 0.9780 | Weight loss 0.0897 | Mean activity 0.2900
DMS Iter 805 | Accuracy 0.9422 | Perf loss 0.0254 | Spike loss 0.9774 | Weight loss 0.0902 | Mean activity 0.2843
DMS Iter 810 | Accuracy 0.9461 | Perf loss 0.0255 | Spike loss 0.9957 | Weight loss 0.0903 | Mean activity 0.2958
DMS Iter 815 | Accuracy 0.9488 | Perf loss 0.0239 | Spike loss 0.9833 | Weight loss 0.0907 | Mean activity 0.2843
DMS Iter 820 | Accuracy 0.9557 | Perf loss 0.0214 | Spike loss 1.0125 | Weight loss 0.0907 | Mean activity 0.2788
DMS Iter 825 | Accuracy 0.9504 | Perf loss 0.0232 | Spike loss 0.9782 | Weight loss 0.0907 | Mean activity 0.2872
DMS Iter 830 | Accuracy 0.9484 | Perf loss 0.0241 | Spike loss 1.0000 | Weight loss 0.0909 | Mean activity 0.2910
DMS Iter 835 | Accuracy 0.9479 | Perf loss 0.0232 | Spike loss 1.0095 | Weight loss 0.0911 | Mean activity 0.2851
DMS Iter 840 | Accuracy 0.9563 | Perf loss 0.0207 | Spike loss 1.0116 | Weight loss 0.0911 | Mean activity 0.2792
DMS Iter 845 | Accuracy 0.9521 | Perf loss 0.0226 | Spike loss 0.9752 | Weight loss 0.0912 | Mean activity 0.2781
DMS Iter 850 | Accuracy 0.9367 | Perf loss 0.0310 | Spike loss 0.9490 | Weight loss 0.0914 | Mean activity 0.2766
DMS Iter 855 | Accuracy 0.9434 | Perf loss 0.0261 | Spike loss 1.0368 | Weight loss 0.0929 | Mean activity 0.2730
DMS Iter 860 | Accuracy 0.9529 | Perf loss 0.0222 | Spike loss 1.0724 | Weight loss 0.0933 | Mean activity 0.2844
DMS Iter 865 | Accuracy 0.9504 | Perf loss 0.0234 | Spike loss 1.0631 | Weight loss 0.0938 | Mean activity 0.2806
DMS Iter 870 | Accuracy 0.9473 | Perf loss 0.0234 | Spike loss 1.0414 | Weight loss 0.0940 | Mean activity 0.2765
DMS Iter 875 | Accuracy 0.9490 | Perf loss 0.0230 | Spike loss 1.0663 | Weight loss 0.0946 | Mean activity 0.2675
DMS Iter 880 | Accuracy 0.9527 | Perf loss 0.0210 | Spike loss 1.0612 | Weight loss 0.0950 | Mean activity 0.2682
DMS Iter 885 | Accuracy 0.9563 | Perf loss 0.0204 | Spike loss 1.0489 | Weight loss 0.0951 | Mean activity 0.2637
DMS Iter 890 | Accuracy 0.9564 | Perf loss 0.0208 | Spike loss 1.0035 | Weight loss 0.0946 | Mean activity 0.2539
DMS Iter 895 | Accuracy 0.9547 | Perf loss 0.0207 | Spike loss 1.0087 | Weight loss 0.0948 | Mean activity 0.2553
DMS Iter 900 | Accuracy 0.9477 | Perf loss 0.0230 | Spike loss 0.9910 | Weight loss 0.0949 | Mean activity 0.2516
DMS Iter 905 | Accuracy 0.9557 | Perf loss 0.0221 | Spike loss 0.9653 | Weight loss 0.0950 | Mean activity 0.2548
DMS Iter 910 | Accuracy 0.9541 | Perf loss 0.0216 | Spike loss 0.9534 | Weight loss 0.0953 | Mean activity 0.2559
DMS Iter 915 | Accuracy 0.9557 | Perf loss 0.0194 | Spike loss 0.9345 | Weight loss 0.0950 | Mean activity 0.2515
DMS Iter 920 | Accuracy 0.9486 | Perf loss 0.0233 | Spike loss 0.9277 | Weight loss 0.0951 | Mean activity 0.2611
DMS Iter 925 | Accuracy 0.9553 | Perf loss 0.0231 | Spike loss 0.9238 | Weight loss 0.0951 | Mean activity 0.2570
DMS Iter 930 | Accuracy 0.9574 | Perf loss 0.0205 | Spike loss 0.9232 | Weight loss 0.0952 | Mean activity 0.2584
DMS Iter 935 | Accuracy 0.9545 | Perf loss 0.0219 | Spike loss 0.9053 | Weight loss 0.0956 | Mean activity 0.2516
DMS Iter 940 | Accuracy 0.9518 | Perf loss 0.0212 | Spike loss 0.9380 | Weight loss 0.0963 | Mean activity 0.2494
DMS Iter 945 | Accuracy 0.9557 | Perf loss 0.0200 | Spike loss 0.9615 | Weight loss 0.0968 | Mean activity 0.2467
DMS Iter 950 | Accuracy 0.9488 | Perf loss 0.0231 | Spike loss 0.9416 | Weight loss 0.0972 | Mean activity 0.2489
DMS Iter 955 | Accuracy 0.9553 | Perf loss 0.0201 | Spike loss 0.9724 | Weight loss 0.0975 | Mean activity 0.2519
DMS Iter 960 | Accuracy 0.9572 | Perf loss 0.0198 | Spike loss 0.9679 | Weight loss 0.0974 | Mean activity 0.2542
DMS Iter 965 | Accuracy 0.9578 | Perf loss 0.0188 | Spike loss 0.9389 | Weight loss 0.0972 | Mean activity 0.2496
DMS Iter 970 | Accuracy 0.9563 | Perf loss 0.0207 | Spike loss 0.9153 | Weight loss 0.0973 | Mean activity 0.2544
DMS Iter 975 | Accuracy 0.9596 | Perf loss 0.0205 | Spike loss 0.9059 | Weight loss 0.0972 | Mean activity 0.2454
DMS Iter 980 | Accuracy 0.9549 | Perf loss 0.0203 | Spike loss 0.9092 | Weight loss 0.0971 | Mean activity 0.2442
DMS Iter 985 | Accuracy 0.9523 | Perf loss 0.0236 | Spike loss 0.9461 | Weight loss 0.0975 | Mean activity 0.2440
DMS Iter 990 | Accuracy 0.9543 | Perf loss 0.0211 | Spike loss 0.9405 | Weight loss 0.0982 | Mean activity 0.2470
DMS Iter 995 | Accuracy 0.9605 | Perf loss 0.0198 | Spike loss 0.9788 | Weight loss 0.0992 | Mean activity 0.2469
DMS Iter 1000 | Accuracy 0.9555 | Perf loss 0.0207 | Spike loss 0.9792 | Weight loss 0.1000 | Mean activity 0.2461
DMS Iter 1005 | Accuracy 0.9541 | Perf loss 0.0207 | Spike loss 0.9978 | Weight loss 0.1002 | Mean activity 0.2378
DMS Iter 1010 | Accuracy 0.9580 | Perf loss 0.0211 | Spike loss 0.9724 | Weight loss 0.1003 | Mean activity 0.2392
DMS Iter 1015 | Accuracy 0.9502 | Perf loss 0.0214 | Spike loss 0.9557 | Weight loss 0.1001 | Mean activity 0.2403
DMS Iter 1020 | Accuracy 0.9551 | Perf loss 0.0217 | Spike loss 0.9422 | Weight loss 0.0999 | Mean activity 0.2300
DMS Iter 1025 | Accuracy 0.9590 | Perf loss 0.0203 | Spike loss 0.9418 | Weight loss 0.1002 | Mean activity 0.2368
DMS Iter 1030 | Accuracy 0.9574 | Perf loss 0.0206 | Spike loss 0.9781 | Weight loss 0.1001 | Mean activity 0.2349
DMS Iter 1035 | Accuracy 0.9627 | Perf loss 0.0194 | Spike loss 0.9598 | Weight loss 0.0998 | Mean activity 0.2411
DMS Iter 1040 | Accuracy 0.9592 | Perf loss 0.0208 | Spike loss 0.9538 | Weight loss 0.0998 | Mean activity 0.2443
DMS Iter 1045 | Accuracy 0.9578 | Perf loss 0.0201 | Spike loss 0.9230 | Weight loss 0.0996 | Mean activity 0.2371
DMS Iter 1050 | Accuracy 0.9525 | Perf loss 0.0219 | Spike loss 0.9439 | Weight loss 0.1003 | Mean activity 0.2424
DMS Iter 1055 | Accuracy 0.9578 | Perf loss 0.0204 | Spike loss 1.0266 | Weight loss 0.1015 | Mean activity 0.2359
DMS Iter 1060 | Accuracy 0.9410 | Perf loss 0.0260 | Spike loss 1.0010 | Weight loss 0.1021 | Mean activity 0.2344
DMS Iter 1065 | Accuracy 0.9602 | Perf loss 0.0201 | Spike loss 1.0120 | Weight loss 0.1031 | Mean activity 0.2487
DMS Iter 1070 | Accuracy 0.9568 | Perf loss 0.0196 | Spike loss 1.0056 | Weight loss 0.1035 | Mean activity 0.2442
DMS Iter 1075 | Accuracy 0.9563 | Perf loss 0.0197 | Spike loss 0.9897 | Weight loss 0.1033 | Mean activity 0.2477
DMS Iter 1080 | Accuracy 0.9588 | Perf loss 0.0201 | Spike loss 0.9577 | Weight loss 0.1032 | Mean activity 0.2462
DMS Iter 1085 | Accuracy 0.9516 | Perf loss 0.0208 | Spike loss 0.9455 | Weight loss 0.1034 | Mean activity 0.2504
DMS Iter 1090 | Accuracy 0.9605 | Perf loss 0.0182 | Spike loss 0.9264 | Weight loss 0.1029 | Mean activity 0.2421
DMS Iter 1095 | Accuracy 0.9641 | Perf loss 0.0180 | Spike loss 0.9008 | Weight loss 0.1024 | Mean activity 0.2502
DMS Iter 1100 | Accuracy 0.9576 | Perf loss 0.0197 | Spike loss 0.8982 | Weight loss 0.1028 | Mean activity 0.2418
DMS Iter 1105 | Accuracy 0.9551 | Perf loss 0.0217 | Spike loss 0.9124 | Weight loss 0.1040 | Mean activity 0.2322
DMS Iter 1110 | Accuracy 0.9627 | Perf loss 0.0177 | Spike loss 0.9525 | Weight loss 0.1051 | Mean activity 0.2316
DMS Iter 1115 | Accuracy 0.9596 | Perf loss 0.0185 | Spike loss 0.9270 | Weight loss 0.1055 | Mean activity 0.2248
DMS Iter 1120 | Accuracy 0.9602 | Perf loss 0.0183 | Spike loss 0.9226 | Weight loss 0.1058 | Mean activity 0.2201
DMS Iter 1125 | Accuracy 0.9607 | Perf loss 0.0206 | Spike loss 0.9030 | Weight loss 0.1056 | Mean activity 0.2226
DMS Iter 1130 | Accuracy 0.9588 | Perf loss 0.0176 | Spike loss 0.9071 | Weight loss 0.1052 | Mean activity 0.2286
DMS Iter 1135 | Accuracy 0.9594 | Perf loss 0.0185 | Spike loss 0.9155 | Weight loss 0.1056 | Mean activity 0.2299
DMS Iter 1140 | Accuracy 0.9609 | Perf loss 0.0191 | Spike loss 0.9045 | Weight loss 0.1057 | Mean activity 0.2285
DMS Iter 1145 | Accuracy 0.9643 | Perf loss 0.0176 | Spike loss 0.8940 | Weight loss 0.1052 | Mean activity 0.2287
DMS Iter 1150 | Accuracy 0.9543 | Perf loss 0.0206 | Spike loss 0.8978 | Weight loss 0.1052 | Mean activity 0.2285
DMS Iter 1155 | Accuracy 0.9594 | Perf loss 0.0186 | Spike loss 0.9168 | Weight loss 0.1056 | Mean activity 0.2298
DMS Iter 1160 | Accuracy 0.9566 | Perf loss 0.0184 | Spike loss 0.9188 | Weight loss 0.1057 | Mean activity 0.2252
DMS Iter 1165 | Accuracy 0.9641 | Perf loss 0.0170 | Spike loss 0.9011 | Weight loss 0.1057 | Mean activity 0.2345
DMS Iter 1170 | Accuracy 0.9596 | Perf loss 0.0192 | Spike loss 0.8788 | Weight loss 0.1055 | Mean activity 0.2340
DMS Iter 1175 | Accuracy 0.9568 | Perf loss 0.0186 | Spike loss 0.8635 | Weight loss 0.1054 | Mean activity 0.2299
DMS Iter 1180 | Accuracy 0.9611 | Perf loss 0.0177 | Spike loss 0.8675 | Weight loss 0.1056 | Mean activity 0.2248
DMS Iter 1185 | Accuracy 0.9625 | Perf loss 0.0167 | Spike loss 0.8573 | Weight loss 0.1056 | Mean activity 0.2284
DMS Iter 1190 | Accuracy 0.9563 | Perf loss 0.0189 | Spike loss 0.8538 | Weight loss 0.1055 | Mean activity 0.2284
DMS Iter 1195 | Accuracy 0.9641 | Perf loss 0.0178 | Spike loss 0.8566 | Weight loss 0.1054 | Mean activity 0.2249
DMS Iter 1200 | Accuracy 0.9580 | Perf loss 0.0188 | Spike loss 0.8522 | Weight loss 0.1053 | Mean activity 0.2254
DMS Iter 1205 | Accuracy 0.9604 | Perf loss 0.0184 | Spike loss 0.8308 | Weight loss 0.1052 | Mean activity 0.2275
DMS Iter 1210 | Accuracy 0.9629 | Perf loss 0.0186 | Spike loss 0.8230 | Weight loss 0.1053 | Mean activity 0.2253
DMS Iter 1215 | Accuracy 0.9633 | Perf loss 0.0176 | Spike loss 0.8264 | Weight loss 0.1052 | Mean activity 0.2207
DMS Iter 1220 | Accuracy 0.9584 | Perf loss 0.0182 | Spike loss 0.8192 | Weight loss 0.1052 | Mean activity 0.2272
DMS Iter 1225 | Accuracy 0.9641 | Perf loss 0.0170 | Spike loss 0.8148 | Weight loss 0.1053 | Mean activity 0.2230
DMS Iter 1230 | Accuracy 0.9621 | Perf loss 0.0179 | Spike loss 0.8174 | Weight loss 0.1053 | Mean activity 0.2225
DMS Iter 1235 | Accuracy 0.9588 | Perf loss 0.0189 | Spike loss 0.8043 | Weight loss 0.1053 | Mean activity 0.2274
DMS Iter 1240 | Accuracy 0.9631 | Perf loss 0.0177 | Spike loss 0.8178 | Weight loss 0.1057 | Mean activity 0.2176
DMS Iter 1245 | Accuracy 0.9637 | Perf loss 0.0166 | Spike loss 0.8162 | Weight loss 0.1059 | Mean activity 0.2195
DMS Iter 1250 | Accuracy 0.9652 | Perf loss 0.0165 | Spike loss 0.8167 | Weight loss 0.1062 | Mean activity 0.2236
DMS Iter 1255 | Accuracy 0.9590 | Perf loss 0.0190 | Spike loss 0.8093 | Weight loss 0.1065 | Mean activity 0.2233
DMS Iter 1260 | Accuracy 0.9586 | Perf loss 0.0203 | Spike loss 0.8348 | Weight loss 0.1070 | Mean activity 0.2247
DMS Iter 1265 | Accuracy 0.9621 | Perf loss 0.0187 | Spike loss 0.8456 | Weight loss 0.1071 | Mean activity 0.2237
DMS Iter 1270 | Accuracy 0.9582 | Perf loss 0.0199 | Spike loss 0.8415 | Weight loss 0.1074 | Mean activity 0.2230
DMS Iter 1275 | Accuracy 0.9602 | Perf loss 0.0180 | Spike loss 0.8544 | Weight loss 0.1078 | Mean activity 0.2218
DMS Iter 1280 | Accuracy 0.9646 | Perf loss 0.0171 | Spike loss 0.8512 | Weight loss 0.1082 | Mean activity 0.2227
DMS Iter 1285 | Accuracy 0.9566 | Perf loss 0.0218 | Spike loss 0.8360 | Weight loss 0.1081 | Mean activity 0.2254
DMS Iter 1290 | Accuracy 0.9613 | Perf loss 0.0180 | Spike loss 0.8803 | Weight loss 0.1081 | Mean activity 0.2248
DMS Iter 1295 | Accuracy 0.9654 | Perf loss 0.0172 | Spike loss 0.8958 | Weight loss 0.1080 | Mean activity 0.2216
DMS Iter 1300 | Accuracy 0.9609 | Perf loss 0.0178 | Spike loss 0.8779 | Weight loss 0.1077 | Mean activity 0.2268
DMS Iter 1305 | Accuracy 0.9631 | Perf loss 0.0170 | Spike loss 0.8719 | Weight loss 0.1079 | Mean activity 0.2253
DMS Iter 1310 | Accuracy 0.9631 | Perf loss 0.0166 | Spike loss 0.8505 | Weight loss 0.1080 | Mean activity 0.2237
DMS Iter 1315 | Accuracy 0.9652 | Perf loss 0.0168 | Spike loss 0.8249 | Weight loss 0.1082 | Mean activity 0.2165
DMS Iter 1320 | Accuracy 0.9637 | Perf loss 0.0168 | Spike loss 0.8397 | Weight loss 0.1081 | Mean activity 0.2241
DMS Iter 1325 | Accuracy 0.9645 | Perf loss 0.0164 | Spike loss 0.8215 | Weight loss 0.1082 | Mean activity 0.2139
DMS Iter 1330 | Accuracy 0.9619 | Perf loss 0.0181 | Spike loss 0.7980 | Weight loss 0.1082 | Mean activity 0.2095
DMS Iter 1335 | Accuracy 0.9648 | Perf loss 0.0157 | Spike loss 0.7928 | Weight loss 0.1080 | Mean activity 0.2122
DMS Iter 1340 | Accuracy 0.9600 | Perf loss 0.0188 | Spike loss 0.8036 | Weight loss 0.1082 | Mean activity 0.2132
DMS Iter 1345 | Accuracy 0.9596 | Perf loss 0.0176 | Spike loss 0.8062 | Weight loss 0.1082 | Mean activity 0.2136
DMS Iter 1350 | Accuracy 0.9588 | Perf loss 0.0217 | Spike loss 0.8128 | Weight loss 0.1081 | Mean activity 0.2039
DMS Iter 1355 | Accuracy 0.9605 | Perf loss 0.0230 | Spike loss 0.7970 | Weight loss 0.1081 | Mean activity 0.2139
DMS Iter 1360 | Accuracy 0.9605 | Perf loss 0.0243 | Spike loss 0.8114 | Weight loss 0.1088 | Mean activity 0.2005
DMS Iter 1365 | Accuracy 0.9475 | Perf loss 0.0305 | Spike loss 0.8264 | Weight loss 0.1094 | Mean activity 0.2037
DMS Iter 1370 | Accuracy 0.9498 | Perf loss 0.0268 | Spike loss 0.9131 | Weight loss 0.1122 | Mean activity 0.2222
DMS Iter 1375 | Accuracy 0.9627 | Perf loss 0.0213 | Spike loss 0.9920 | Weight loss 0.1148 | Mean activity 0.2216
DMS Iter 1380 | Accuracy 0.9596 | Perf loss 0.0233 | Spike loss 0.9783 | Weight loss 0.1151 | Mean activity 0.2213
DMS Iter 1385 | Accuracy 0.9600 | Perf loss 0.0208 | Spike loss 0.9487 | Weight loss 0.1151 | Mean activity 0.2169
DMS Iter 1390 | Accuracy 0.9598 | Perf loss 0.0205 | Spike loss 0.9231 | Weight loss 0.1150 | Mean activity 0.2160
DMS Iter 1395 | Accuracy 0.9627 | Perf loss 0.0188 | Spike loss 0.8971 | Weight loss 0.1146 | Mean activity 0.2105
DMS Iter 1400 | Accuracy 0.9645 | Perf loss 0.0171 | Spike loss 0.8679 | Weight loss 0.1142 | Mean activity 0.2143
DMS Iter 1405 | Accuracy 0.9613 | Perf loss 0.0194 | Spike loss 0.8436 | Weight loss 0.1140 | Mean activity 0.2101
DMS Iter 1410 | Accuracy 0.9596 | Perf loss 0.0201 | Spike loss 0.8148 | Weight loss 0.1135 | Mean activity 0.1999
DMS Iter 1415 | Accuracy 0.9576 | Perf loss 0.0204 | Spike loss 0.8096 | Weight loss 0.1127 | Mean activity 0.2089
DMS Iter 1420 | Accuracy 0.9631 | Perf loss 0.0182 | Spike loss 0.8169 | Weight loss 0.1127 | Mean activity 0.2145
DMS Iter 1425 | Accuracy 0.9645 | Perf loss 0.0180 | Spike loss 0.8101 | Weight loss 0.1126 | Mean activity 0.2188
DMS Iter 1430 | Accuracy 0.9635 | Perf loss 0.0187 | Spike loss 0.8028 | Weight loss 0.1127 | Mean activity 0.2143
DMS Iter 1435 | Accuracy 0.9596 | Perf loss 0.0191 | Spike loss 0.8156 | Weight loss 0.1135 | Mean activity 0.2054
DMS Iter 1440 | Accuracy 0.9648 | Perf loss 0.0187 | Spike loss 0.8341 | Weight loss 0.1143 | Mean activity 0.2148
DMS Iter 1445 | Accuracy 0.9689 | Perf loss 0.0192 | Spike loss 0.8380 | Weight loss 0.1146 | Mean activity 0.2115
DMS Iter 1450 | Accuracy 0.9600 | Perf loss 0.0202 | Spike loss 0.8435 | Weight loss 0.1150 | Mean activity 0.2156
DMS Iter 1455 | Accuracy 0.9578 | Perf loss 0.0209 | Spike loss 0.8924 | Weight loss 0.1161 | Mean activity 0.2215
DMS Iter 1460 | Accuracy 0.9623 | Perf loss 0.0182 | Spike loss 0.8875 | Weight loss 0.1166 | Mean activity 0.2158
DMS Iter 1465 | Accuracy 0.9580 | Perf loss 0.0208 | Spike loss 0.9091 | Weight loss 0.1172 | Mean activity 0.2109
DMS Iter 1470 | Accuracy 0.9588 | Perf loss 0.0179 | Spike loss 0.9209 | Weight loss 0.1175 | Mean activity 0.2048
DMS Iter 1475 | Accuracy 0.9678 | Perf loss 0.0167 | Spike loss 0.9268 | Weight loss 0.1177 | Mean activity 0.2073
DMS Iter 1480 | Accuracy 0.9641 | Perf loss 0.0176 | Spike loss 0.9199 | Weight loss 0.1178 | Mean activity 0.2083
DMS Iter 1485 | Accuracy 0.9639 | Perf loss 0.0173 | Spike loss 0.8908 | Weight loss 0.1178 | Mean activity 0.2164
DMS Iter 1490 | Accuracy 0.9578 | Perf loss 0.0196 | Spike loss 0.8575 | Weight loss 0.1176 | Mean activity 0.2109
DMS Iter 1495 | Accuracy 0.9639 | Perf loss 0.0165 | Spike loss 0.8525 | Weight loss 0.1177 | Mean activity 0.2117
DMS Iter 1500 | Accuracy 0.9676 | Perf loss 0.0158 | Spike loss 0.8359 | Weight loss 0.1171 | Mean activity 0.2040
DMS Iter 1505 | Accuracy 0.9613 | Perf loss 0.0182 | Spike loss 0.8243 | Weight loss 0.1170 | Mean activity 0.2112
DMS Iter 1510 | Accuracy 0.9592 | Perf loss 0.0174 | Spike loss 0.8195 | Weight loss 0.1173 | Mean activity 0.2075
DMS Iter 1515 | Accuracy 0.9639 | Perf loss 0.0171 | Spike loss 0.8180 | Weight loss 0.1177 | Mean activity 0.2046
DMS Iter 1520 | Accuracy 0.9666 | Perf loss 0.0169 | Spike loss 0.8242 | Weight loss 0.1178 | Mean activity 0.2081
DMS Iter 1525 | Accuracy 0.9609 | Perf loss 0.0172 | Spike loss 0.8321 | Weight loss 0.1184 | Mean activity 0.2012
DMS Iter 1530 | Accuracy 0.9617 | Perf loss 0.0176 | Spike loss 0.8266 | Weight loss 0.1186 | Mean activity 0.2057
DMS Iter 1535 | Accuracy 0.9668 | Perf loss 0.0162 | Spike loss 0.8740 | Weight loss 0.1194 | Mean activity 0.2095
DMS Iter 1540 | Accuracy 0.9607 | Perf loss 0.0216 | Spike loss 0.8984 | Weight loss 0.1201 | Mean activity 0.2094
DMS Iter 1545 | Accuracy 0.9602 | Perf loss 0.0186 | Spike loss 0.8961 | Weight loss 0.1209 | Mean activity 0.2056
DMS Iter 1550 | Accuracy 0.9600 | Perf loss 0.0254 | Spike loss 0.8728 | Weight loss 0.1216 | Mean activity 0.2052
DMS Iter 1555 | Accuracy 0.9598 | Perf loss 0.0212 | Spike loss 0.8902 | Weight loss 0.1228 | Mean activity 0.2036
DMS Iter 1560 | Accuracy 0.9582 | Perf loss 0.0221 | Spike loss 0.9120 | Weight loss 0.1241 | Mean activity 0.2131
DMS Iter 1565 | Accuracy 0.9535 | Perf loss 0.0230 | Spike loss 0.8988 | Weight loss 0.1242 | Mean activity 0.2163
DMS Iter 1570 | Accuracy 0.9563 | Perf loss 0.0234 | Spike loss 0.9593 | Weight loss 0.1249 | Mean activity 0.2230
DMS Iter 1575 | Accuracy 0.9551 | Perf loss 0.0220 | Spike loss 1.0544 | Weight loss 0.1264 | Mean activity 0.2237
DMS Iter 1580 | Accuracy 0.9486 | Perf loss 0.0342 | Spike loss 1.0935 | Weight loss 0.1277 | Mean activity 0.2217
DMS Iter 1585 | Accuracy 0.9621 | Perf loss 0.0195 | Spike loss 1.1112 | Weight loss 0.1280 | Mean activity 0.2241
DMS Iter 1590 | Accuracy 0.9615 | Perf loss 0.0198 | Spike loss 1.0739 | Weight loss 0.1275 | Mean activity 0.2263
DMS Iter 1595 | Accuracy 0.9623 | Perf loss 0.0200 | Spike loss 1.0368 | Weight loss 0.1270 | Mean activity 0.2160
DMS Iter 1600 | Accuracy 0.9563 | Perf loss 0.0193 | Spike loss 1.0115 | Weight loss 0.1268 | Mean activity 0.2078
DMS Iter 1605 | Accuracy 0.9631 | Perf loss 0.0181 | Spike loss 0.9856 | Weight loss 0.1264 | Mean activity 0.2050
DMS Iter 1610 | Accuracy 0.9633 | Perf loss 0.0179 | Spike loss 0.9625 | Weight loss 0.1259 | Mean activity 0.2071
DMS Iter 1615 | Accuracy 0.9633 | Perf loss 0.0175 | Spike loss 0.9268 | Weight loss 0.1253 | Mean activity 0.2140
DMS Iter 1620 | Accuracy 0.9588 | Perf loss 0.0214 | Spike loss 0.9385 | Weight loss 0.1259 | Mean activity 0.2098
DMS Iter 1625 | Accuracy 0.9676 | Perf loss 0.0171 | Spike loss 1.0057 | Weight loss 0.1272 | Mean activity 0.2155
DMS Iter 1630 | Accuracy 0.9580 | Perf loss 0.0209 | Spike loss 1.0043 | Weight loss 0.1276 | Mean activity 0.2070
DMS Iter 1635 | Accuracy 0.9615 | Perf loss 0.0177 | Spike loss 1.0109 | Weight loss 0.1278 | Mean activity 0.2061
DMS Iter 1640 | Accuracy 0.9613 | Perf loss 0.0197 | Spike loss 1.0107 | Weight loss 0.1283 | Mean activity 0.2072
DMS Iter 1645 | Accuracy 0.9680 | Perf loss 0.0160 | Spike loss 1.0006 | Weight loss 0.1282 | Mean activity 0.2114
DMS Iter 1650 | Accuracy 0.9627 | Perf loss 0.0167 | Spike loss 0.9525 | Weight loss 0.1277 | Mean activity 0.2022
DMS Iter 1655 | Accuracy 0.9602 | Perf loss 0.0190 | Spike loss 0.9282 | Weight loss 0.1274 | Mean activity 0.2034
DMS Iter 1660 | Accuracy 0.9660 | Perf loss 0.0161 | Spike loss 0.9054 | Weight loss 0.1271 | Mean activity 0.2124
DMS Iter 1665 | Accuracy 0.9680 | Perf loss 0.0150 | Spike loss 0.8919 | Weight loss 0.1267 | Mean activity 0.2075
DMS Iter 1670 | Accuracy 0.9701 | Perf loss 0.0154 | Spike loss 0.8714 | Weight loss 0.1264 | Mean activity 0.2108
DMS Iter 1675 | Accuracy 0.9650 | Perf loss 0.0182 | Spike loss 0.8894 | Weight loss 0.1261 | Mean activity 0.2144
DMS Iter 1680 | Accuracy 0.9682 | Perf loss 0.0172 | Spike loss 0.8823 | Weight loss 0.1261 | Mean activity 0.2129
DMS Iter 1685 | Accuracy 0.9670 | Perf loss 0.0164 | Spike loss 0.9009 | Weight loss 0.1269 | Mean activity 0.2084
DMS Iter 1690 | Accuracy 0.9664 | Perf loss 0.0156 | Spike loss 0.9136 | Weight loss 0.1277 | Mean activity 0.2111
DMS Iter 1695 | Accuracy 0.9609 | Perf loss 0.0198 | Spike loss 0.9154 | Weight loss 0.1278 | Mean activity 0.2100
DMS Iter 1700 | Accuracy 0.9666 | Perf loss 0.0149 | Spike loss 0.9216 | Weight loss 0.1275 | Mean activity 0.2079
DMS Iter 1705 | Accuracy 0.9674 | Perf loss 0.0154 | Spike loss 0.9027 | Weight loss 0.1271 | Mean activity 0.2043
DMS Iter 1710 | Accuracy 0.9639 | Perf loss 0.0181 | Spike loss 0.8879 | Weight loss 0.1270 | Mean activity 0.2094
DMS Iter 1715 | Accuracy 0.9625 | Perf loss 0.0167 | Spike loss 0.8706 | Weight loss 0.1265 | Mean activity 0.2025
DMS Iter 1720 | Accuracy 0.9629 | Perf loss 0.0181 | Spike loss 0.8676 | Weight loss 0.1265 | Mean activity 0.1984
DMS Iter 1725 | Accuracy 0.9652 | Perf loss 0.0159 | Spike loss 0.8568 | Weight loss 0.1265 | Mean activity 0.1971
DMS Iter 1730 | Accuracy 0.9686 | Perf loss 0.0148 | Spike loss 0.8430 | Weight loss 0.1263 | Mean activity 0.1966
DMS Iter 1735 | Accuracy 0.9646 | Perf loss 0.0195 | Spike loss 0.8329 | Weight loss 0.1260 | Mean activity 0.1966
DMS Iter 1740 | Accuracy 0.9662 | Perf loss 0.0163 | Spike loss 0.8172 | Weight loss 0.1259 | Mean activity 0.1943
DMS Iter 1745 | Accuracy 0.9639 | Perf loss 0.0163 | Spike loss 0.8191 | Weight loss 0.1260 | Mean activity 0.1986
DMS Iter 1750 | Accuracy 0.9688 | Perf loss 0.0151 | Spike loss 0.8169 | Weight loss 0.1259 | Mean activity 0.1961
DMS Iter 1755 | Accuracy 0.9703 | Perf loss 0.0144 | Spike loss 0.8091 | Weight loss 0.1258 | Mean activity 0.1986
DMS Iter 1760 | Accuracy 0.9592 | Perf loss 0.0193 | Spike loss 0.8076 | Weight loss 0.1257 | Mean activity 0.1952
DMS Iter 1765 | Accuracy 0.9645 | Perf loss 0.0172 | Spike loss 0.7946 | Weight loss 0.1256 | Mean activity 0.1965
DMS Iter 1770 | Accuracy 0.9625 | Perf loss 0.0171 | Spike loss 0.7859 | Weight loss 0.1256 | Mean activity 0.1962
DMS Iter 1775 | Accuracy 0.9639 | Perf loss 0.0183 | Spike loss 0.7740 | Weight loss 0.1258 | Mean activity 0.1944
DMS Iter 1780 | Accuracy 0.9662 | Perf loss 0.0156 | Spike loss 0.7814 | Weight loss 0.1262 | Mean activity 0.1977
DMS Iter 1785 | Accuracy 0.9625 | Perf loss 0.0162 | Spike loss 0.8024 | Weight loss 0.1267 | Mean activity 0.1944
DMS Iter 1790 | Accuracy 0.9719 | Perf loss 0.0138 | Spike loss 0.8086 | Weight loss 0.1273 | Mean activity 0.1952
DMS Iter 1795 | Accuracy 0.9672 | Perf loss 0.0151 | Spike loss 0.8010 | Weight loss 0.1276 | Mean activity 0.1906
DMS Iter 1800 | Accuracy 0.9680 | Perf loss 0.0148 | Spike loss 0.7924 | Weight loss 0.1277 | Mean activity 0.1936
DMS Iter 1805 | Accuracy 0.9699 | Perf loss 0.0137 | Spike loss 0.7918 | Weight loss 0.1275 | Mean activity 0.1901
DMS Iter 1810 | Accuracy 0.9672 | Perf loss 0.0152 | Spike loss 0.7841 | Weight loss 0.1272 | Mean activity 0.1907
DMS Iter 1815 | Accuracy 0.9684 | Perf loss 0.0147 | Spike loss 0.7764 | Weight loss 0.1266 | Mean activity 0.1949
DMS Iter 1820 | Accuracy 0.9641 | Perf loss 0.0170 | Spike loss 0.7545 | Weight loss 0.1262 | Mean activity 0.1955
DMS Iter 1825 | Accuracy 0.9654 | Perf loss 0.0175 | Spike loss 0.7483 | Weight loss 0.1261 | Mean activity 0.1876
DMS Iter 1830 | Accuracy 0.9680 | Perf loss 0.0169 | Spike loss 0.7479 | Weight loss 0.1263 | Mean activity 0.1868
DMS Iter 1835 | Accuracy 0.9631 | Perf loss 0.0177 | Spike loss 0.7709 | Weight loss 0.1277 | Mean activity 0.1898
DMS Iter 1840 | Accuracy 0.9662 | Perf loss 0.0155 | Spike loss 0.8514 | Weight loss 0.1294 | Mean activity 0.1925
DMS Iter 1845 | Accuracy 0.9658 | Perf loss 0.0191 | Spike loss 0.8877 | Weight loss 0.1302 | Mean activity 0.1952
DMS Iter 1850 | Accuracy 0.9572 | Perf loss 0.0198 | Spike loss 0.8889 | Weight loss 0.1305 | Mean activity 0.1948
DMS Iter 1855 | Accuracy 0.9645 | Perf loss 0.0170 | Spike loss 0.8962 | Weight loss 0.1304 | Mean activity 0.1962
DMS Iter 1860 | Accuracy 0.9641 | Perf loss 0.0163 | Spike loss 0.9061 | Weight loss 0.1308 | Mean activity 0.1955
DMS Iter 1865 | Accuracy 0.9652 | Perf loss 0.0161 | Spike loss 0.9028 | Weight loss 0.1311 | Mean activity 0.1890
DMS Iter 1870 | Accuracy 0.9613 | Perf loss 0.0167 | Spike loss 0.8960 | Weight loss 0.1311 | Mean activity 0.1936
DMS Iter 1875 | Accuracy 0.9631 | Perf loss 0.0168 | Spike loss 0.8865 | Weight loss 0.1312 | Mean activity 0.1906
DMS Iter 1880 | Accuracy 0.9688 | Perf loss 0.0154 | Spike loss 0.8781 | Weight loss 0.1314 | Mean activity 0.1909
DMS Iter 1885 | Accuracy 0.9648 | Perf loss 0.0166 | Spike loss 0.8425 | Weight loss 0.1312 | Mean activity 0.1904
DMS Iter 1890 | Accuracy 0.9715 | Perf loss 0.0149 | Spike loss 0.8410 | Weight loss 0.1314 | Mean activity 0.1846
DMS Iter 1895 | Accuracy 0.9643 | Perf loss 0.0153 | Spike loss 0.8335 | Weight loss 0.1312 | Mean activity 0.1834
DMS Iter 1900 | Accuracy 0.9684 | Perf loss 0.0146 | Spike loss 0.8169 | Weight loss 0.1310 | Mean activity 0.1786
DMS Iter 1905 | Accuracy 0.9646 | Perf loss 0.0166 | Spike loss 0.8116 | Weight loss 0.1310 | Mean activity 0.1794
DMS Iter 1910 | Accuracy 0.9715 | Perf loss 0.0149 | Spike loss 0.8409 | Weight loss 0.1313 | Mean activity 0.1780
DMS Iter 1915 | Accuracy 0.9611 | Perf loss 0.0183 | Spike loss 0.8462 | Weight loss 0.1313 | Mean activity 0.1881
DMS Iter 1920 | Accuracy 0.9627 | Perf loss 0.0215 | Spike loss 0.8177 | Weight loss 0.1311 | Mean activity 0.1832
DMS Iter 1925 | Accuracy 0.9631 | Perf loss 0.0203 | Spike loss 0.8071 | Weight loss 0.1312 | Mean activity 0.1819
DMS Iter 1930 | Accuracy 0.9697 | Perf loss 0.0135 | Spike loss 0.7942 | Weight loss 0.1309 | Mean activity 0.1781
DMS Iter 1935 | Accuracy 0.9674 | Perf loss 0.0156 | Spike loss 0.7923 | Weight loss 0.1307 | Mean activity 0.1771
DMS Iter 1940 | Accuracy 0.9699 | Perf loss 0.0130 | Spike loss 0.7848 | Weight loss 0.1304 | Mean activity 0.1755
DMS Iter 1945 | Accuracy 0.9660 | Perf loss 0.0151 | Spike loss 0.7687 | Weight loss 0.1304 | Mean activity 0.1756
DMS Iter 1950 | Accuracy 0.9717 | Perf loss 0.0134 | Spike loss 0.7760 | Weight loss 0.1305 | Mean activity 0.1797
DMS Iter 1955 | Accuracy 0.9658 | Perf loss 0.0149 | Spike loss 0.7787 | Weight loss 0.1305 | Mean activity 0.1841
DMS Iter 1960 | Accuracy 0.9670 | Perf loss 0.0184 | Spike loss 0.7772 | Weight loss 0.1306 | Mean activity 0.1826
DMS Iter 1965 | Accuracy 0.9641 | Perf loss 0.0193 | Spike loss 0.7668 | Weight loss 0.1308 | Mean activity 0.1853
DMS Iter 1970 | Accuracy 0.9609 | Perf loss 0.0163 | Spike loss 0.7810 | Weight loss 0.1319 | Mean activity 0.1779
DMS Iter 1975 | Accuracy 0.9707 | Perf loss 0.0163 | Spike loss 0.7959 | Weight loss 0.1327 | Mean activity 0.1697
DMS Iter 1980 | Accuracy 0.9623 | Perf loss 0.0163 | Spike loss 0.8059 | Weight loss 0.1326 | Mean activity 0.1691
DMS Iter 1985 | Accuracy 0.9680 | Perf loss 0.0154 | Spike loss 0.7976 | Weight loss 0.1324 | Mean activity 0.1717
DMS Iter 1990 | Accuracy 0.9697 | Perf loss 0.0140 | Spike loss 0.7889 | Weight loss 0.1322 | Mean activity 0.1747
DMS Iter 1995 | Accuracy 0.9635 | Perf loss 0.0159 | Spike loss 0.7953 | Weight loss 0.1324 | Mean activity 0.1737
(Yang, 2020): Dynamical system analysis for RNN
Implementation of the paper:
Yang G R, Wang X J. Artificial neural networks for neuroscientists: A primer[J]. Neuron, 2020, 107(6): 1048-1070.
The original implementation is based on PyTorch: https://github.com/gyyang/nn-brain/blob/master/RNN%2BDynamicalSystemAnalysis.ipynb
[1]:
import brainpy as bp
import brainpy.math as bm
bp.math.set_platform('cpu')
[2]:
import numpy as np
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA
In this tutorial, we will use supervised learning to train a recurrent neural network on a simple perceptual decision making task, and analyze the trained network using dynamical system analysis.
Defining a cognitive task
[3]:
# We will import the task from the neurogym library.
# Please install neurogym:
#
# https://github.com/neurogym/neurogym
import neurogym as ngym
[4]:
# Environment
task = 'PerceptualDecisionMaking-v0'
kwargs = {'dt': 100}
seq_len = 100
# Make supervised dataset
dataset = ngym.Dataset(task, env_kwargs=kwargs, batch_size=16,
seq_len=seq_len)
# A sample environment from dataset
env = dataset.env
# Visualize the environment with 2 sample trials
_ = ngym.utils.plot_env(env, num_trials=2, fig_kwargs={'figsize': (8, 6)})

[5]:
input_size = env.observation_space.shape[0]
output_size = env.action_space.n
batch_size = dataset.batch_size
Define a vanilla continuous-time recurrent network
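Here we define a continuous-time neural network,
\(\tau \frac{d\mathbf{r}}{dt} = -\mathbf{r}(t) + f(W_r \mathbf{r}(t) + W_x \mathbf{x}(t) + \mathbf{b}_r)\),
and discretize it in time with the Euler method, using a time step of \(\Delta t\):
\(\mathbf{r}(t+\Delta t) = \mathbf{r}(t) + \frac{\Delta t}{\tau}\left[-\mathbf{r}(t) + f(W_r \mathbf{r}(t) + W_x \mathbf{x}(t) + \mathbf{b}_r)\right]\),
that is, \(\mathbf{r}(t+\Delta t) = (1-\alpha)\mathbf{r}(t) + \alpha f(W_r \mathbf{r}(t) + W_x \mathbf{x}(t) + \mathbf{b}_r)\) with \(\alpha = \Delta t/\tau\). The `cell` method below applies the ReLU \(f\) after the leaky mixing, computing \(\mathbf{r}(t+\Delta t) = f\big((1-\alpha)\mathbf{r}(t) + \alpha(W_x \mathbf{x}(t) + W_r \mathbf{r}(t) + \mathbf{b}_r)\big)\).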
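The sketch below uses NumPy to apply this Euler scheme to a single unit. Two assumptions keep it simple: \(f\) is linear and the drive \(I\) is constant, so the exact solution \(r(t) = I + (r_0 - I)e^{-t/\tau}\) is available for comparison. The helpers `euler_leaky` and `exact_leaky` are illustrative only. The error shrinks as \(\Delta t\) decreases. At the other extreme, \(\alpha = 1\) makes each step discard the previous state entirely. That extreme is the one used in this notebook: the environment's `dt=100` together with `self.tau = 100` in the class below gives `alpha = 1`, so the network effectively behaves as a standard discrete-time ReLU RNN.

```python
import numpy as np


def euler_leaky(r0, drive, tau, dt, t_end):
    """Euler-integrate tau * dr/dt = -r + drive from r0 up to t_end."""
    alpha = dt / tau                      # same role as RNN.alpha
    n_steps = int(round(t_end / dt))
    r = r0
    for _ in range(n_steps):
        # r(t + dt) = (1 - alpha) * r(t) + alpha * drive
        r = (1.0 - alpha) * r + alpha * drive
    return r


def exact_leaky(r0, drive, tau, t):
    """Closed-form solution of the same linear ODE."""
    return drive + (r0 - drive) * np.exp(-t / tau)


tau, t_end, r0, drive = 100.0, 300.0, 0.0, 1.0
exact = exact_leaky(r0, drive, tau, t_end)
for dt in (1.0, 10.0, 100.0):
    approx = euler_leaky(r0, drive, tau, dt, t_end)
    print(f"dt={dt:5.1f}  alpha={dt / tau:.2f}  error={abs(approx - exact):.4f}")
```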
[6]:
class RNN(bp.DynamicalSystem):
    def __init__(self,
                 num_input,
                 num_hidden,
                 num_output,
                 num_batch,
                 dt=None, seed=None,
                 w_ir=bp.init.KaimingNormal(scale=1.),
                 w_rr=bp.init.KaimingNormal(scale=1.),
                 w_ro=bp.init.KaimingNormal(scale=1.)):
        super(RNN, self).__init__()
        # parameters
        self.tau = 100
        self.num_batch = num_batch
        self.num_input = num_input
        self.num_hidden = num_hidden
        self.num_output = num_output
        if dt is None:
            self.alpha = 1
        else:
            self.alpha = dt / self.tau
        self.rng = bm.random.RandomState(seed=seed)
        # input weight
        self.w_ir = bm.TrainVar(bp.init.parameter(w_ir, (num_input, num_hidden)))
        # recurrent weight
        bound = 1 / num_hidden ** 0.5
        self.w_rr = bm.TrainVar(bp.init.parameter(w_rr, (num_hidden, num_hidden)))
        self.b_rr = bm.TrainVar(self.rng.uniform(-bound, bound, num_hidden))
        # readout weight
        self.w_ro = bm.TrainVar(bp.init.parameter(w_ro, (num_hidden, num_output)))
        self.b_ro = bm.TrainVar(self.rng.uniform(-bound, bound, num_output))
        # variables
        self.h = bm.Variable(bm.zeros((num_batch, num_hidden)))
        self.o = bm.Variable(bm.zeros((num_batch, num_output)))
    def cell(self, x, h):
        # Euler step of the leaky dynamics (alpha = dt / tau), followed by ReLU
        ins = x @ self.w_ir + h @ self.w_rr + self.b_rr
        state = h * (1 - self.alpha) + ins * self.alpha
        return bm.relu(state)
    def readout(self, h):
        return h @ self.w_ro + self.b_ro
    def update(self, x):
        self.h.value = self.cell(x, self.h)
        self.o.value = self.readout(self.h)
        return self.h.value, self.o.value
    def predict(self, xs):
        # reset the hidden state, then loop `update` over the time axis of xs
        self.h[:] = 0.
        return bm.for_loop(self.update, xs, dyn_vars=[self.h, self.o])
    def loss(self, xs, ys):
        hs, os = self.predict(xs)
        # merge the time and batch axes before computing the cross-entropy
        os = os.reshape((-1, os.shape[-1]))
        loss = bp.losses.cross_entropy_loss(os, ys.flatten())
        return loss, os
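In `loss`, the outputs of shape `(seq_len, batch, num_output)` are reshaped to `(seq_len * batch, num_output)` and the labels of shape `(seq_len, batch)` are flattened. Both operations use row-major order, so row `k` of the reshaped outputs lines up with entry `k` of the flattened labels. The NumPy sketch below checks that alignment on random arrays and computes a mean softmax cross-entropy by hand. `softmax_cross_entropy` is a hypothetical helper. Whether `bp.losses.cross_entropy_loss` reduces by mean is an assumption here and has not been checked against BrainPy's source.

```python
import numpy as np


def softmax_cross_entropy(logits, labels):
    """Mean cross-entropy between integer labels and unnormalized logits."""
    # subtract the row max for numerical stability
    shifted = logits - logits.max(axis=-1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))
    return -log_probs[np.arange(len(labels)), labels].mean()


rng = np.random.default_rng(0)
seq_len, batch, n_out = 5, 4, 3
os_ = rng.normal(size=(seq_len, batch, n_out))      # time-major network outputs
ys = rng.integers(0, n_out, size=(seq_len, batch))  # integer targets

flat_os = os_.reshape((-1, n_out))   # (seq_len * batch, n_out)
flat_ys = ys.flatten()               # (seq_len * batch,)

# entry (t, b) of the original arrays lands at row t * batch + b of both
t, b = 2, 3
assert np.array_equal(flat_os[t * batch + b], os_[t, b])
assert flat_ys[t * batch + b] == ys[t, b]

print(softmax_cross_entropy(flat_os, flat_ys))
```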
Train the recurrent network on the decision-making task
[7]:
# Instantiate the network inside a training environment
hidden_size = 64
with bm.training_environment():
    net = RNN(num_input=input_size,
              num_hidden=hidden_size,
              num_output=output_size,
              num_batch=batch_size,
              dt=env.dt)
[8]:
# Adam optimizer
opt = bp.optim.Adam(lr=0.001, train_vars=net.train_vars().unique())
# gradient function
grad_f = bm.grad(net.loss,
                 child_objs=net,
                 grad_vars=net.train_vars().unique(),
                 return_value=True,
                 has_aux=True)
# training function
@bm.jit
@bm.to_object(child_objs=(grad_f, opt))
def train(xs, ys):
    grads, loss, os = grad_f(xs, ys)
    opt.update(grads)
    return loss, os
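The `train` function applies one Adam step per batch. For reference, the sketch below implements the standard Adam update (Kingma & Ba, 2015) in NumPy. It uses the default hyperparameters printed in the fixed-point finder log further down (`beta1=0.9`, `beta2=0.999`, `eps=1e-08`). `adam_step` is a hypothetical helper, and BrainPy's internal bookkeeping (for example, exactly where `eps` enters) may differ. The demo minimizes \((x-3)^2\) with a larger learning rate than the notebook's `0.001` so that it converges quickly.

```python
import numpy as np


def adam_step(param, grad, m, v, t, lr=1e-3, b1=0.9, b2=0.999, eps=1e-8):
    """One Adam update; t is the 1-based step count."""
    m = b1 * m + (1 - b1) * grad            # first-moment (mean) estimate
    v = b2 * v + (1 - b2) * grad ** 2       # second-moment estimate
    m_hat = m / (1 - b1 ** t)               # bias corrections
    v_hat = v / (1 - b2 ** t)
    param = param - lr * m_hat / (np.sqrt(v_hat) + eps)
    return param, m, v


x, m, v = np.array(0.0), 0.0, 0.0
for t in range(1, 3001):
    g = 2 * (x - 3.0)                       # gradient of (x - 3)^2
    x, m, v = adam_step(x, g, m, v, t, lr=0.01)
print(float(x))
```

Because of the bias correction, the very first step has magnitude close to `lr` regardless of the gradient's scale.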
[9]:
running_acc = 0
running_loss = 0
for i in range(1500):
    inputs, labels_np = dataset()
    inputs = bm.asarray(inputs)
    labels = bm.asarray(labels_np)
    loss, outputs = train(inputs, labels)
    running_loss += loss
    # Compute performance
    output_np = np.argmax(bm.as_numpy(outputs), axis=-1).flatten()
    labels_np = labels_np.flatten()
    ind = labels_np > 0  # Only analyze time points when target is not fixation
    running_acc += np.mean(labels_np[ind] == output_np[ind])
    if i % 100 == 99:
        running_loss /= 100
        running_acc /= 100
        print('Step {}, Loss {:0.4f}, Acc {:0.3f}'.format(i + 1, running_loss, running_acc))
        running_loss = 0
        running_acc = 0
Step 100, Loss 0.0994, Acc 0.458
Step 200, Loss 0.0201, Acc 0.868
Step 300, Loss 0.0155, Acc 0.876
Step 400, Loss 0.0136, Acc 0.883
Step 500, Loss 0.0129, Acc 0.878
Step 600, Loss 0.0127, Acc 0.881
Step 700, Loss 0.0122, Acc 0.884
Step 800, Loss 0.0123, Acc 0.884
Step 900, Loss 0.0116, Acc 0.885
Step 1000, Loss 0.0116, Acc 0.885
Step 1100, Loss 0.0111, Acc 0.888
Step 1200, Loss 0.0111, Acc 0.885
Step 1300, Loss 0.0108, Acc 0.891
Step 1400, Loss 0.0111, Acc 0.885
Step 1500, Loss 0.0106, Acc 0.891
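The accuracy printed above counts only time steps whose target is not fixation (label `0`). Both loss and accuracy are averaged over blocks of 100 training steps. The sketch below reproduces that masked metric on toy arrays that stand in for network outputs. `decision_accuracy` is a hypothetical helper.

```python
import numpy as np


def decision_accuracy(logits, labels, fixation_label=0):
    """Accuracy over time steps whose target is not the fixation action."""
    pred = np.argmax(logits, axis=-1).flatten()
    labels = np.asarray(labels).flatten()
    # for non-negative labels this matches the notebook's `labels_np > 0`
    mask = labels != fixation_label
    return np.mean(pred[mask] == labels[mask])


# Toy example: 4 time steps x 1 trial, 3 actions (0 = fixate, 1/2 = choices).
labels = np.array([[0], [0], [1], [2]])
logits = np.array([[[9, 0, 0]],   # fixation, predicted 0 (ignored)
                   [[0, 9, 0]],   # fixation, predicted 1 (ignored)
                   [[0, 9, 0]],   # target 1, predicted 1 (correct)
                   [[0, 9, 0]]])  # target 2, predicted 1 (wrong)
print(decision_accuracy(logits, labels))  # 0.5
```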
Visualize neural activity in sample trials
We will run the network for 100 sample trials, then visualize the neural activity trajectories in PCA space.
[10]:
env.reset(no_step=True)
perf = 0
num_trial = 100
activity_dict = {}
trial_infos = {}
for i in range(num_trial):
    env.new_trial()
    ob, gt = env.ob, env.gt
    inputs = bm.asarray(ob[:, np.newaxis, :])
    rnn_activity, action_pred = net.predict(inputs)
    rnn_activity = bm.as_numpy(rnn_activity)[:, 0, :]
    activity_dict[i] = rnn_activity
    trial_infos[i] = env.trial
# Concatenate activity for PCA
activity = np.concatenate(list(activity_dict[i] for i in range(num_trial)), axis=0)
print('Shape of the neural activity: (Time points, Neurons): ', activity.shape)
# Print trial information
for i in range(5):
    print('Trial ', i, trial_infos[i])
Shape of the neural activity: (Time points, Neurons): (2200, 64)
Trial 0 {'ground_truth': 0, 'coh': 25.6}
Trial 1 {'ground_truth': 1, 'coh': 0.0}
Trial 2 {'ground_truth': 0, 'coh': 6.4}
Trial 3 {'ground_truth': 0, 'coh': 6.4}
Trial 4 {'ground_truth': 1, 'coh': 0.0}
[11]:
pca = PCA(n_components=2)
pca.fit(activity)
[11]:
PCA(n_components=2)
Transform individual trials and visualize them in PC space, colored by ground truth. The neural activity is organized by stimulus ground truth along PC1.
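`pca.fit` estimates the mean of `activity` and its leading principal axes. `pca.transform` centers each trial with that mean and then projects it onto those axes. The NumPy sketch below performs the same projection via the SVD, using random data as a stand-in for `activity`. The helper names are hypothetical, and scikit-learn may flip the sign of individual components relative to this version.

```python
import numpy as np


def fit_pca(data, n_components):
    """Return the mean and the top principal axes (as rows) of data."""
    mean = data.mean(axis=0)
    # rows of vt are orthonormal directions sorted by explained variance
    _, _, vt = np.linalg.svd(data - mean, full_matrices=False)
    return mean, vt[:n_components]


def transform_pca(data, mean, components):
    """Center data with the fitted mean and project onto the principal axes."""
    return (data - mean) @ components.T


rng = np.random.default_rng(0)
# 500 "time points" x 64 "neurons" with decreasing variance across neurons
activity_demo = rng.normal(size=(500, 64)) * np.linspace(3.0, 0.1, 64)
mean, comps = fit_pca(activity_demo, n_components=2)
pcs = transform_pca(activity_demo, mean, comps)
print(pcs.shape)            # (500, 2)
print(pcs.var(axis=0))      # PC 1 variance >= PC 2 variance
```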
[12]:
plt.rcdefaults()
fig, (ax1, ax2) = plt.subplots(1, 2, sharey=True, sharex=True, figsize=(12, 5))
for i in range(num_trial):
    activity_pc = pca.transform(activity_dict[i])
    trial = trial_infos[i]
    color = 'red' if trial['ground_truth'] == 0 else 'blue'
    # left panel: all trials
    _ = ax1.plot(activity_pc[:, 0], activity_pc[:, 1], 'o-', color=color)
    # right panel: first five trials only
    if i < 5:
        _ = ax2.plot(activity_pc[:, 0], activity_pc[:, 1], 'o-', color=color)
ax1.set_xlabel('PC 1')
ax1.set_ylabel('PC 2')
plt.show()

Search for approximate fixed points
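Here we search for approximate fixed points and visualize them in the same PC space. In a generic continuous-time dynamical system,
\(\frac{d\mathbf{h}}{dt} = F(\mathbf{h})\),
we can search for fixed points by solving the optimization
\(\min_{\mathbf{h}} \frac{1}{2}\lVert F(\mathbf{h}) \rVert^2\).
For the discrete-time update used here, \(\mathbf{h}_{t+1} = f(\mathbf{h}_t)\), the analogous objective is \(\min_{\mathbf{h}} \frac{1}{2}\lVert f(\mathbf{h}) - \mathbf{h} \rVert^2\). The finder below is told that the cell is such a map via `f_type='discrete'`.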
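The sketch below shows this optimization in NumPy for a small discrete-time map \(f(\mathbf{h}) = \tanh(W\mathbf{h} + \mathbf{b})\). It minimizes \(q(\mathbf{h}) = \frac{1}{2}\lVert f(\mathbf{h}) - \mathbf{h}\rVert^2\) by plain gradient descent from several random starting points, using the gradient \((J_f - I)^\top (f(\mathbf{h}) - \mathbf{h})\). The tanh map, the weight scale, and the helper `find_fixed_point` are assumptions made for illustration. `bp.analysis.SlowPointFinder`, used below, works on the trained ReLU cell with Adam and batched candidates, and its internals may differ.

```python
import numpy as np


def f_map(h, W, b):
    """One step of a toy discrete-time network."""
    return np.tanh(W @ h + b)


def find_fixed_point(h0, W, b, lr=0.1, n_steps=5000):
    """Minimize q(h) = 0.5 * ||f(h) - h||^2 by gradient descent."""
    h = h0.copy()
    eye = np.eye(len(h))
    for _ in range(n_steps):
        fh = f_map(h, W, b)
        r = fh - h                               # residual ("speed")
        jac = (1.0 - fh ** 2)[:, None] * W       # Jacobian of f at h
        grad = (jac - eye).T @ r                 # gradient of q
        h -= lr * grad
    q = 0.5 * np.sum((f_map(h, W, b) - h) ** 2)
    return h, q


rng = np.random.default_rng(1)
n = 8
A = rng.normal(size=(n, n))
# spectral norm 0.5 makes f a contraction, so it has exactly one fixed point
W = 0.5 * A / np.linalg.norm(A, 2)
b = 0.1 * rng.normal(size=n)
for _ in range(3):
    h_star, q = find_fixed_point(rng.normal(size=n), W, b)
    print(f"q = {q:.2e}")
```

All starting points should converge to the same point here. In a trained RNN the map is not a contraction, so different candidates can settle on different fixed or slow points. This is why the notebook filters and deduplicates the results afterwards.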
[13]:
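# constant input: fixation cue on (1) and equal, zero-coherence evidence (0.5, 0.5) on both stimulus channels
f_cell = lambda h: net.cell(bm.asarray([1, 0.5, 0.5]), h)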
[14]:
fp_candidates = bm.vstack([activity_dict[i] for i in range(num_trial)])
fp_candidates.shape
[14]:
(2200, 64)
[15]:
finder = bp.analysis.SlowPointFinder(f_cell=f_cell, f_type='discrete')
finder.find_fps_with_gd_method(
candidates=fp_candidates,
tolerance=1e-5, num_batch=200,
optimizer=bp.optim.Adam(lr=bp.optim.ExponentialDecay(0.01, 1, 0.9999)),
)
finder.filter_loss(tolerance=1e-5)
finder.keep_unique(tolerance=0.03)
finder.exclude_outliers(0.1)
fixed_points = finder.fixed_points
Optimizing with Adam(lr=ExponentialDecay(0.01, decay_steps=1, decay_rate=0.9999), last_call=-1), beta1=0.9, beta2=0.999, eps=1e-08) to find fixed points:
Batches 1-200 in 0.74 sec, Training loss 0.0012789472
Batches 201-400 in 0.69 sec, Training loss 0.0005247793
Batches 401-600 in 0.68 sec, Training loss 0.0003411663
Batches 601-800 in 0.69 sec, Training loss 0.0002612533
Batches 801-1000 in 0.69 sec, Training loss 0.0002112046
Batches 1001-1200 in 0.73 sec, Training loss 0.0001740992
Batches 1201-1400 in 0.70 sec, Training loss 0.0001455246
Batches 1401-1600 in 0.74 sec, Training loss 0.0001233126
Batches 1601-1800 in 0.70 sec, Training loss 0.0001058684
Batches 1801-2000 in 0.77 sec, Training loss 0.0000921902
Batches 2001-2200 in 0.72 sec, Training loss 0.0000814899
Batches 2201-2400 in 0.57 sec, Training loss 0.0000731918
Batches 2401-2600 in 0.75 sec, Training loss 0.0000667152
Batches 2601-2800 in 0.74 sec, Training loss 0.0000616670
Batches 2801-3000 in 0.70 sec, Training loss 0.0000576638
Batches 3001-3200 in 0.56 sec, Training loss 0.0000543614
Batches 3201-3400 in 0.60 sec, Training loss 0.0000515535
Batches 3401-3600 in 0.61 sec, Training loss 0.0000491180
Batches 3601-3800 in 0.71 sec, Training loss 0.0000469084
Batches 3801-4000 in 0.75 sec, Training loss 0.0000448209
Batches 4001-4200 in 0.67 sec, Training loss 0.0000428211
Batches 4201-4400 in 0.73 sec, Training loss 0.0000408868
Batches 4401-4600 in 0.71 sec, Training loss 0.0000390110
Batches 4601-4800 in 0.77 sec, Training loss 0.0000371877
Batches 4801-5000 in 0.58 sec, Training loss 0.0000354225
Batches 5001-5200 in 0.58 sec, Training loss 0.0000337130
Batches 5201-5400 in 0.69 sec, Training loss 0.0000320665
Batches 5401-5600 in 0.65 sec, Training loss 0.0000304907
Batches 5601-5800 in 0.76 sec, Training loss 0.0000289806
Batches 5801-6000 in 0.64 sec, Training loss 0.0000275472
Batches 6001-6200 in 0.55 sec, Training loss 0.0000261840
Batches 6201-6400 in 0.60 sec, Training loss 0.0000248957
Batches 6401-6600 in 0.68 sec, Training loss 0.0000236831
Batches 6601-6800 in 0.67 sec, Training loss 0.0000225569
Batches 6801-7000 in 0.71 sec, Training loss 0.0000215175
Batches 7001-7200 in 0.66 sec, Training loss 0.0000205547
Batches 7201-7400 in 0.70 sec, Training loss 0.0000196768
Batches 7401-7600 in 0.68 sec, Training loss 0.0000188892
Batches 7601-7800 in 0.71 sec, Training loss 0.0000181973
Batches 7801-8000 in 0.67 sec, Training loss 0.0000175776
Batches 8001-8200 in 0.72 sec, Training loss 0.0000170112
Batches 8201-8400 in 0.75 sec, Training loss 0.0000165215
Batches 8401-8600 in 0.54 sec, Training loss 0.0000160938
Batches 8601-8800 in 0.60 sec, Training loss 0.0000157249
Batches 8801-9000 in 0.52 sec, Training loss 0.0000154070
Batches 9001-9200 in 0.68 sec, Training loss 0.0000151405
Batches 9201-9400 in 0.54 sec, Training loss 0.0000149151
Batches 9401-9600 in 0.53 sec, Training loss 0.0000147287
Batches 9601-9800 in 0.53 sec, Training loss 0.0000145786
Batches 9801-10000 in 0.53 sec, Training loss 0.0000144657
Excluding fixed points with squared speed above tolerance 1e-05:
Kept 1157/2200 fixed points with tolerance under 1e-05.
Excluding non-unique fixed points:
Kept 10/1157 unique fixed points with uniqueness tolerance 0.03.
Excluding outliers:
Kept 8/10 fixed points with within outlier tolerance 0.1.
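The first two post-processing steps in the log above can be mimicked in a few lines. This NumPy sketch is an assumption about what `filter_loss` and `keep_unique` do conceptually. First it drops candidates whose squared speed exceeds a tolerance. Then it greedily keeps a candidate only if it lies farther than a distance tolerance from every candidate already kept. BrainPy's actual distance measure and ordering may differ, and the outlier step is not reproduced. The helpers and the toy data are hypothetical.

```python
import numpy as np


def filter_by_speed(points, speeds, tol):
    """Keep candidates whose squared speed is below tol."""
    return points[speeds < tol]


def greedy_unique(points, tol):
    """Keep a point only if it is farther than tol from every point kept so far."""
    kept = []
    for p in points:
        if all(np.linalg.norm(p - k) > tol for k in kept):
            kept.append(p)
    return np.array(kept)


rng = np.random.default_rng(0)
# 30 slightly perturbed copies of each of 3 underlying points in 2-D
centers = np.array([[0.0, 0.0], [1.0, 0.0], [2.0, 0.0]])
points = np.repeat(centers, 30, axis=0) + 0.001 * rng.normal(size=(90, 2))
speeds = rng.uniform(0.0, 2e-5, size=90)   # made-up squared speeds

slow = filter_by_speed(points, speeds, tol=1e-5)
unique = greedy_unique(slow, tol=0.03)
print(len(slow), len(unique))
```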
Visualize the approximate fixed points found.
They lie approximately along a line attractor aligned with PC1, along which evidence is integrated during the stimulus period.
[16]:
# Plot in the same space as activity
plt.figure(figsize=(10, 5))
for i in range(10):
    activity_pc = pca.transform(activity_dict[i])
    trial = trial_infos[i]
    color = 'red' if trial['ground_truth'] == 0 else 'blue'
    plt.plot(activity_pc[:, 0], activity_pc[:, 1], 'o-', color=color, alpha=0.1)
# Fixed points are shown as crosses
fixedpoints_pc = pca.transform(fixed_points)
plt.plot(fixedpoints_pc[:, 0], fixedpoints_pc[:, 1], 'x', label='fixed points')
plt.xlabel('PC 1')
plt.ylabel('PC 2')
plt.legend()
plt.show()

Computing the Jacobian and finding the line attractor
[17]:
from jax import jacobian
[19]:
dFdh = jacobian(f_cell)(fixed_points[2])
eigval, eigvec = np.linalg.eig(bm.as_numpy(dFdh))
[20]:
# Plot distribution of eigenvalues in a 2-d real-imaginary plot
plt.figure()
plt.scatter(np.real(eigval), np.imag(eigval))
plt.plot([1, 1], [-1, 1], '--')  # Re = 1: eigenvalues near this line mark slow, nearly neutral directions of the map
plt.xlabel('Real')
plt.ylabel('Imaginary')
plt.show()
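For a discrete-time map \(\mathbf{h} \mapsto f(\mathbf{h})\), a fixed point is locally stable along directions whose Jacobian eigenvalues have magnitude below 1. An eigenvalue close to 1 marks a slow, nearly neutral direction, so a line attractor appears as one eigenvalue near the dashed line at Re = 1 with the rest inside the unit circle. The NumPy sketch below illustrates this with a hypothetical 3-D linear map whose eigenvalues are 1, 0.5 and 0.2. It computes a central-difference Jacobian, which plays the role that `jax.jacobian` plays above. It then shows that iterating the map keeps only the component along the eigenvalue-1 direction.

```python
import numpy as np


def numerical_jacobian(f, x, eps=1e-6):
    """Central-difference Jacobian of f: R^n -> R^n at x."""
    n = len(x)
    jac = np.zeros((n, n))
    for j in range(n):
        step = np.zeros(n)
        step[j] = eps
        jac[:, j] = (f(x + step) - f(x - step)) / (2 * eps)
    return jac


rng = np.random.default_rng(0)
# A = V diag(1, 0.5, 0.2) V^-1 maps every multiple of V[:, 0] to itself: a line of fixed points
V = rng.normal(size=(3, 3))
A = V @ np.diag([1.0, 0.5, 0.2]) @ np.linalg.inv(V)


def f_linear(h):
    return A @ h


jac = numerical_jacobian(f_linear, np.zeros(3))
eigval, eigvec = np.linalg.eig(jac)
print(np.sort(eigval.real))          # approximately [0.2, 0.5, 1.0]

# Start off the line: the components along the 0.5 and 0.2 modes decay away.
h = V @ np.array([1.0, 1.0, 1.0])
for _ in range(50):
    h = f_linear(h)
print(np.allclose(h, V[:, 0], atol=1e-6))   # True
```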

(Bellec, et al., 2020): eprop for Evidence Accumulation Task
Implementation of the paper:
Bellec, G., Scherr, F., Subramoney, A., Hajek, E., Salaj, D., Legenstein, R., & Maass, W. (2020). A solution to the learning dilemma for recurrent networks of spiking neurons. Nature Communications, 11(1), 1-15.
[1]:
import matplotlib.pyplot as plt
import numpy as np
import brainpy as bp
import brainpy.math as bm
from jax.lax import stop_gradient
from matplotlib import patches
bm.set_environment(bm.training_mode, dt=1.)
[2]:
# training parameters
n_batch = 128 # batch size
# neuron model and simulation parameters
reg_f = 1. # regularization coefficient for firing rate
reg_rate = 10 # target firing rate for regularization [Hz]
# Experiment parameters
t_cue_spacing = 150 # distance between two consecutive cues in ms
# Frequencies
input_f0 = 40. / 1000.  # Poisson firing rate of input neurons in kHz (spikes per ms)
regularization_f0 = reg_rate / 1000.  # mean target network firing rate in kHz
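Because the global time step is `dt=1.` ms, rates in Hz are divided by 1000 to obtain a per-step spike probability, so 40 Hz becomes 0.04 spikes per millisecond. The NumPy sketch below draws Bernoulli spikes with that probability and checks the empirical rate. Bernoulli sampling per step is a discrete-time approximation of a Poisson process and is assumed here only for illustration.

```python
import numpy as np

dt_ms = 1.0                          # matches bm.set_environment(..., dt=1.)
rate_hz = 40.0
p_spike = rate_hz / 1000. * dt_ms    # per-step probability, same value as input_f0

rng = np.random.default_rng(0)
n_steps, n_neurons = 1000, 1000      # 1 s of activity for 1000 input neurons
spikes = rng.random((n_steps, n_neurons)) < p_spike

# spikes per step divided by the step length in seconds gives Hz
empirical_hz = spikes.mean() / (dt_ms / 1000.)
print(p_spike, round(empirical_hz, 1))
```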
[3]:
class EligSNN(bp.Network):
    def __init__(self, num_in, num_rec, num_out, eprop=True, tau_a=2e3, tau_v=2e1):
        super(EligSNN, self).__init__()
        # parameters
        self.num_in = num_in
        self.num_rec = num_rec
        self.num_out = num_out
        self.eprop = eprop
        # neurons
        self.i = bp.neurons.InputGroup(num_in)
        self.o = bp.neurons.LeakyIntegrator(num_out, tau=20, mode=bm.training_mode)
        n_regular = int(num_rec / 2)
        n_adaptive = num_rec - n_regular
        beta1 = bm.exp(- bm.get_dt() / tau_a)
        beta2 = 1.7 * (1 - beta1) / (1 - bm.exp(-1 / tau_v))
        beta = bm.concatenate([bm.ones(n_regular), bm.ones(n_adaptive) * beta2])
        self.r = bp.neurons.ALIFBellec2020(
            num_rec,
            V_rest=0.,
            tau_ref=5.,
            V_th=0.6,
            tau_a=tau_a,
            tau=tau_v,
            beta=beta,
            V_initializer=bp.init.ZeroInit(),
            a_initializer=bp.init.ZeroInit(),
            mode=bm.training_mode, eprop=eprop
        )
        # synapses
        self.i2r = bp.layers.Dense(num_in, num_rec,
                                   W_initializer=bp.init.KaimingNormal(),
                                   b_initializer=None)
        self.i2r.W *= tau_v
        self.r2r = bp.layers.Dense(num_rec, num_rec,
                                   W_initializer=bp.init.KaimingNormal(),
                                   b_initializer=None)
        self.r2r.W *= tau_v
        self.r2o = bp.layers.Dense(num_rec, num_out,
                                   W_initializer=bp.init.KaimingNormal(),
                                   b_initializer=None)
    def update(self, shared, x):
        self.r.input += self.i2r(shared, x)
        z = stop_gradient(self.r.spike.value) if self.eprop else self.r.spike.value
        self.r.input += self.r2r(shared, z)
        self.r(shared)
        self.o.input += self.r2o(shared, self.r.spike.value)
        self.o(shared)
        return self.o.V.value
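In EligSNN, beta1 = exp(-dt / tau_a) is the per-step decay factor of the adaptation variable. beta2 is the threshold-adaptation coefficient given to the adaptive half of the recurrent neurons.

With dt = 1 ms (set by bm.set_environment) and the defaults tau_a = 2000 ms and tau_v = 20 ms, the same arithmetic can be checked in plain Python. This only reproduces the constants computed in __init__, not the ALIFBellec2020 dynamics.

```python
import math

dt, tau_a, tau_v = 1.0, 2e3, 2e1   # dt from bm.set_environment, defaults of EligSNN
beta1 = math.exp(-dt / tau_a)                            # adaptation decay factor per step
beta2 = 1.7 * (1 - beta1) / (1 - math.exp(-1 / tau_v))   # threshold-adaptation coefficient
print(f"beta1 = {beta1:.6f}, beta2 = {beta2:.5f}")
```

This gives beta1 ≈ 0.9995 and beta2 ≈ 0.0174. The adaptation variable therefore decays with a time constant of about 2 s, much slower than the 20 ms membrane time constant.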
[4]:
net = EligSNN(num_in=40, num_rec=100, num_out=2, eprop=False)
WARNING:jax._src.lib.xla_bridge:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
[5]:
@bp.tools.numba_jit
def generate_click_task_data(batch_size, seq_len, n_neuron, recall_duration, prob, f0=0.5,
                             n_cues=7, t_cue=100, t_interval=150, n_input_symbols=4):
    n_channel = n_neuron // n_input_symbols
    # assign input spike probabilities
    probs = np.where(np.random.random((batch_size, 1)) < 0.5, prob, 1 - prob)
    # for each example in batch, draw which cues are going to be active (left or right)
    cue_assignments = np.asarray(np.random.random(n_cues) > probs, dtype=np.int_)
    # generate input nums - 0: left, 1: right, 2:recall, 3:background noise
    input_nums = 3 * np.ones((batch_size, seq_len), dtype=np.int_)
    input_nums[:, :n_cues] = cue_assignments
    input_nums[:, -1] = 2
    # generate input spikes
    input_spike_prob = np.zeros((batch_size, seq_len, n_neuron))
    d_silence = t_interval - t_cue
    for b in range(batch_size):
        for k in range(n_cues):
            # input channels only fire when they are selected (left or right)
            c = cue_assignments[b, k]
            # time step at which the k-th cue starts
            i_seq = d_silence + k * t_interval
            i_neu = c * n_channel
            input_spike_prob[b, i_seq:i_seq + t_cue, i_neu:i_neu + n_channel] = f0
    # recall cue
    input_spike_prob[:, -recall_duration:, 2 * n_channel:3 * n_channel] = f0
    # background noise
    input_spike_prob[:, :, 3 * n_channel:] = f0 / 4.
    input_spikes = input_spike_prob > np.random.rand(*input_spike_prob.shape)
    # generate targets
    target_mask = np.zeros((batch_size, seq_len), dtype=np.bool_)
    target_mask[:, -1] = True
    target_nums = (np.sum(cue_assignments, axis=1) > n_cues / 2).astype(np.int_)
    return input_spikes, input_nums, target_nums, target_mask
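Each trial's label is a majority vote over its n_cues cues: target_nums is 1 when more than half of the cues are "right" (cue value 1). The sketch below isolates that labeling step in plain numpy, with a fixed seed and a hypothetical helper name.

There is one deliberate difference from the source. The sketch draws an independent uniform number for every (trial, cue) pair. generate_click_task_data instead draws a single length-n_cues vector and compares it against every trial's threshold through broadcasting, so trials that share the same threshold also share the same cue pattern within a batch.

```python
import numpy as np

def draw_cues_and_labels(batch_size, n_cues=7, prob=0.3, seed=0):
    """Hypothetical helper: cue assignments and majority-vote labels."""
    rng = np.random.default_rng(seed)
    # per-trial threshold: either prob or 1 - prob, with equal chance
    probs = np.where(rng.random((batch_size, 1)) < 0.5, prob, 1 - prob)
    # cue = 1 ("right") when the uniform draw exceeds the trial's threshold
    cues = (rng.random((batch_size, n_cues)) > probs).astype(int)
    # majority vote over the cues gives the target class
    labels = (cues.sum(axis=1) > n_cues / 2).astype(int)
    return cues, labels

cues, labels = draw_cues_and_labels(batch_size=4)
print(cues)
print(labels)
```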
[6]:
def get_data(batch_size, n_in, t_interval, f0):
    # used for obtaining a new randomly generated batch of examples
    def generate_data():
        seq_len = int(t_interval * 7 + 1200)
        for _ in range(10):
            spk_data, _, target_data, _ = generate_click_task_data(
                batch_size=batch_size, seq_len=seq_len, n_neuron=n_in, recall_duration=150,
                prob=0.3, t_cue=100, n_cues=7, t_interval=t_interval, f0=f0, n_input_symbols=4
            )
            yield spk_data, target_data
    return generate_data
[7]:
def loss_fun(predicts, targets):
    predicts, mon = predicts
    # we only use network output at the end for classification
    output_logits = predicts[:, -t_cue_spacing:]
    # Define the accuracy
    y_predict = bm.argmax(bm.mean(output_logits, axis=1), axis=1)
    accuracy = bm.equal(targets, y_predict).astype(bm.dftype()).mean()
    # loss function
    tiled_targets = bm.tile(bm.expand_dims(targets, 1), (1, t_cue_spacing))
    loss_cls = bm.mean(bp.losses.cross_entropy_loss(output_logits, tiled_targets))
    # Firing-rate regularization: penalize the squared deviation of each recurrent
    # neuron's mean firing rate (spikes per ms) from the target regularization_f0.
    av = bm.mean(mon['r.spike'], axis=(0, 1)) / bm.get_dt()
    loss_reg_f = bm.sum(bm.square(av - regularization_f0) * reg_f)
    # Aggregate the losses
    loss = loss_reg_f + loss_cls
    loss_res = {'loss': loss, 'loss reg': loss_reg_f, 'accuracy': accuracy}
    return bm.as_jax(loss), loss_res
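The firing-rate penalty in loss_fun works in three steps:

1. It averages the monitored spikes over batch and time.
2. It divides by dt to get each recurrent neuron's rate in spikes per ms (kHz).
3. It sums the squared deviations from regularization_f0 = reg_rate / 1000.

Below is a numpy sketch of just that term (the helper name is hypothetical). It uses a toy spike train whose first neuron fires exactly at the 10 Hz target.

```python
import numpy as np

def rate_regularization(spikes, dt, target_rate, coef):
    """spikes: 0/1 array (batch, time, neurons); dt in ms; target_rate in kHz."""
    rate = spikes.mean(axis=(0, 1)) / dt   # mean rate of each neuron [spikes/ms = kHz]
    return np.sum(coef * (rate - target_rate) ** 2)

# toy check with dt = 1 ms: neuron 0 fires every 100th step (10 Hz), neuron 1 never fires
spikes = np.zeros((2, 1000, 2))
spikes[:, ::100, 0] = 1.0
loss = rate_regularization(spikes, dt=1.0, target_rate=10 / 1000, coef=1.0)
```

Only the silent neuron contributes, giving coef * (0 - 0.01)**2 = 1e-4.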
Training
[8]:
# Training
trainer = bp.BPTT(
net,
loss_fun,
loss_has_aux=True,
optimizer=bp.optim.Adam(lr=0.01),
monitors={'r.spike': net.r.spike},
)
trainer.fit(get_data(n_batch, n_in=net.num_in, t_interval=t_cue_spacing, f0=input_f0),
num_epoch=40,
num_report=10)
Train 10 steps, use 9.8286 s, loss 0.7034175992012024, accuracy 0.4281249940395355, loss reg 0.004963034763932228
Train 20 steps, use 7.3619 s, loss 0.7023941874504089, accuracy 0.4046874940395355, loss reg 0.004842016380280256
Train 30 steps, use 7.3519 s, loss 0.6943514943122864, accuracy 0.596875011920929, loss reg 0.004735519178211689
Train 40 steps, use 7.3225 s, loss 0.6893576383590698, accuracy 0.5882812738418579, loss reg 0.004621113184839487
Train 50 steps, use 7.4316 s, loss 0.7021943926811218, accuracy 0.45703125, loss reg 0.0045189810916781425
Train 60 steps, use 7.3455 s, loss 0.7054350972175598, accuracy 0.3843750059604645, loss reg 0.004483774304389954
Train 70 steps, use 7.5527 s, loss 0.6898735761642456, accuracy 0.56640625, loss reg 0.0044373138807713985
Train 80 steps, use 7.3871 s, loss 0.69098299741745, accuracy 0.5992187857627869, loss reg 0.004469707142561674
Train 90 steps, use 7.3317 s, loss 0.6886122822761536, accuracy 0.6968750357627869, loss reg 0.004408594686537981
Train 100 steps, use 7.4113 s, loss 0.6826426386833191, accuracy 0.749218761920929, loss reg 0.004439009819179773
Train 110 steps, use 7.4938 s, loss 0.6881702542304993, accuracy 0.633593738079071, loss reg 0.004656294826418161
Train 120 steps, use 7.2987 s, loss 0.6720143556594849, accuracy 0.772656261920929, loss reg 0.004621779080480337
Train 130 steps, use 7.4011 s, loss 0.6499411463737488, accuracy 0.8023437857627869, loss reg 0.004800095688551664
Train 140 steps, use 7.3568 s, loss 0.6571835875511169, accuracy 0.710156261920929, loss reg 0.005033737048506737
Train 150 steps, use 7.1787 s, loss 0.6523110866546631, accuracy 0.7250000238418579, loss reg 0.00559642817825079
Train 160 steps, use 7.1760 s, loss 0.608909010887146, accuracy 0.821093738079071, loss reg 0.005754651036113501
Train 170 steps, use 7.2558 s, loss 0.5620784163475037, accuracy 0.844531238079071, loss reg 0.006214872468262911
Train 180 steps, use 7.2844 s, loss 0.5986811518669128, accuracy 0.750781238079071, loss reg 0.006925530731678009
Train 190 steps, use 7.4182 s, loss 0.544775664806366, accuracy 0.848437488079071, loss reg 0.006775358226150274
Train 200 steps, use 7.4347 s, loss 0.5496039390563965, accuracy 0.831250011920929, loss reg 0.007397319655865431
Train 210 steps, use 7.4629 s, loss 0.5447431206703186, accuracy 0.813281238079071, loss reg 0.006942986976355314
Train 220 steps, use 7.3833 s, loss 0.5015143752098083, accuracy 0.85546875, loss reg 0.0072592394426465034
Train 230 steps, use 7.4328 s, loss 0.5421426296234131, accuracy 0.854687511920929, loss reg 0.0077950432896614075
Train 240 steps, use 7.4438 s, loss 0.4893417954444885, accuracy 0.875781238079071, loss reg 0.007711453828960657
Train 250 steps, use 7.3671 s, loss 0.48076897859573364, accuracy 0.8203125, loss reg 0.006535724736750126
Train 260 steps, use 7.3650 s, loss 0.46686863899230957, accuracy 0.8617187738418579, loss reg 0.007533709984272718
Train 270 steps, use 7.4364 s, loss 0.4155255854129791, accuracy 0.9156250357627869, loss reg 0.007653679233044386
Train 280 steps, use 7.5231 s, loss 0.5252839922904968, accuracy 0.8070312738418579, loss reg 0.0074622235260903835
Train 290 steps, use 7.4474 s, loss 0.4552551209926605, accuracy 0.840624988079071, loss reg 0.007330414839088917
Train 300 steps, use 7.3283 s, loss 0.4508514404296875, accuracy 0.8617187738418579, loss reg 0.007133393082767725
Train 310 steps, use 7.3453 s, loss 0.38369470834732056, accuracy 0.925000011920929, loss reg 0.007317659445106983
Train 320 steps, use 7.2690 s, loss 0.4067922532558441, accuracy 0.9125000238418579, loss reg 0.00806522835046053
Train 330 steps, use 7.4205 s, loss 0.4162019193172455, accuracy 0.8843750357627869, loss reg 0.0077808513306081295
Train 340 steps, use 7.3557 s, loss 0.42762160301208496, accuracy 0.8695312738418579, loss reg 0.00763324648141861
Train 350 steps, use 7.4115 s, loss 0.38524919748306274, accuracy 0.8984375, loss reg 0.00784334447234869
Train 360 steps, use 7.3625 s, loss 0.36755821108818054, accuracy 0.905468761920929, loss reg 0.007520874030888081
Train 370 steps, use 7.3553 s, loss 0.4653354585170746, accuracy 0.839062511920929, loss reg 0.007807845715433359
Train 380 steps, use 7.4220 s, loss 0.46386781334877014, accuracy 0.828906238079071, loss reg 0.007937172427773476
Train 390 steps, use 7.3668 s, loss 0.5748793482780457, accuracy 0.75, loss reg 0.007791445590555668
Train 400 steps, use 7.2759 s, loss 0.3976801037788391, accuracy 0.8812500238418579, loss reg 0.007725016679614782
Visualization
[9]:
# visualization
dataset, _ = next(get_data(20, n_in=net.num_in, t_interval=t_cue_spacing, f0=input_f0)())
runner = bp.DSTrainer(net, monitors={'spike': net.r.spike})
outs = runner.predict(dataset, reset_state=True)
for i in range(10):
    fig, gs = bp.visualize.get_figure(3, 1, 2., 6.)
    ax_inp = fig.add_subplot(gs[0, 0])
    ax_rec = fig.add_subplot(gs[1, 0])
    ax_out = fig.add_subplot(gs[2, 0])
    data = dataset[i]
    # insert empty row
    n_channel = data.shape[1] // 4
    zero_fill = np.zeros((data.shape[0], int(n_channel / 2)))
    data = np.concatenate((data[:, 3 * n_channel:], zero_fill,
                           data[:, 2 * n_channel:3 * n_channel], zero_fill,
                           data[:, :n_channel], zero_fill,
                           data[:, n_channel:2 * n_channel]), axis=1)
    ax_inp.set_yticklabels([])
    ax_inp.add_patch(patches.Rectangle((0, 2 * n_channel + 2 * int(n_channel / 2)),
                                       data.shape[0], n_channel,
                                       facecolor="red", alpha=0.1))
    ax_inp.add_patch(patches.Rectangle((0, 3 * n_channel + 3 * int(n_channel / 2)),
                                       data.shape[0], n_channel,
                                       facecolor="blue", alpha=0.1))
    bp.visualize.raster_plot(runner.mon.ts, data, ax=ax_inp, marker='|')
    ax_inp.set_ylabel('Input Activity')
    ax_inp.set_xticklabels([])
    ax_inp.set_xticks([])
    # spiking activity
    bp.visualize.raster_plot(runner.mon.ts, runner.mon['spike'][i], ax=ax_rec, marker='|')
    ax_rec.set_ylabel('Spiking Activity')
    ax_rec.set_xticklabels([])
    ax_rec.set_xticks([])
    # decision activity
    ax_out.set_yticks([0, 0.5, 1])
    ax_out.set_ylabel('Output Activity')
    ax_out.plot(runner.mon.ts, outs[i, :, 0], label='Readout 0', alpha=0.7)
    ax_out.plot(runner.mon.ts, outs[i, :, 1], label='Readout 1', alpha=0.7)
    ax_out.set_xticklabels([])
    ax_out.set_xticks([])
    ax_out.set_xlabel('Time [ms]')
    plt.legend()
    plt.show()










(Bouchacourt & Buschman, 2019) Flexible Working Memory Model
Implementation of:
Bouchacourt, Flora, and Timothy J. Buschman. “A flexible model of working memory.” Neuron 103.1 (2019): 147-160.
Author:
Chaoming Wang (chao.brain@qq.com)
[1]:
import matplotlib.pyplot as plt
[2]:
import brainpy as bp
import brainpy.math as bm
[3]:
# increase in order to run multiple trials with the same network
num_trials = 1
num_item_to_load = 6
[4]:
# Parameters for network architecture
# ------------------------------------
num_sensory_neuron = 512 # Number of recurrent neurons per sensory network
num_sensory_pool = 8 # Number of ring-like sensory networks
num_all_sensory = num_sensory_pool * num_sensory_neuron
num_all_random = 1024 # Number of neurons in the random network
fI_slope = 0.4 # slope of the non-linear f-I function
bias = 0. # bias in the neuron firing response (cf page 1 right column of Burak, Fiete 2012)
tau = 10. # Synaptic time constant [ms]
init_range = 0.01 # Range to randomly initialize synaptic variables
[5]:
# Parameters for sensory network
# -------------------------------
k2 = 0.25 # width of negative surround suppression
k1 = 1. # width of positive amplification
A = 2. # amplitude of weight function
lambda_ = 0.28 # baseline of weight matrix for recurrent network
[6]:
# Parameters for interaction of
# random network <-> sensory network
# -----------------------------------
forward_balance = -1. # if -1, perfect feed-forward balance from SN to RN
backward_balance = -1. # if -1, perfect feedback balance from RN to SN
alpha = 2.1 # parameter used to compute the feedforward weight, before balancing
beta = 0.2 # parameter used to compute the feedback weight, before balancing
gamma = 0.35 # connectivity (gamma in the paper)
factor = 1000 # factor for computing weights values
[7]:
# Parameters for stimulus
# -----------------------
simulation_time = 1100 # the simulation time [ms]
start_stimulation = 100 # [ms]
end_stimulation = 200 # [ms]
input_strength = 10 # strength of the stimulation
num_sensory_input_width = 32
# the width for input stimulation of the gaussian distribution
sigma = round(num_sensory_neuron / num_sensory_input_width)
three_sigma = 3 * sigma
activity_threshold = 3
[8]:
# Weights initialization
# ----------------------
# weight matrix within sensory network
sensory_encoding = 2. * bm.pi * bm.arange(1, num_sensory_neuron + 1) / num_sensory_neuron
diff = sensory_encoding.reshape((-1, 1)) - sensory_encoding
weight_mat_of_sensory = lambda_ + A * bm.exp(k1 * (bm.cos(diff) - 1)) - A * bm.exp(k2 * (bm.cos(diff) - 1))
diag = bm.arange(num_sensory_neuron)
weight_mat_of_sensory[diag, diag] = 0.
WARNING:jax._src.lib.xla_bridge:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
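The ring weight function is lambda_ + A exp(k1 (cos Δ - 1)) - A exp(k2 (cos Δ - 1)), where Δ is the difference between preferred angles. Because k1 > k2 and cos Δ - 1 ≤ 0, the k1 term never exceeds the k2 term. Every off-diagonal weight therefore lies below lambda_, and weights turn negative once the angular distance is large enough. The matrix is also symmetric and circulant, meaning each row is a rotation of the previous one.

Below is a small numpy sketch that rebuilds the same matrix for a hypothetical 64-neuron ring and checks these properties.

```python
import numpy as np

def ring_weights(n, lambda_=0.28, A=2.0, k1=1.0, k2=0.25):
    """Same weight function as weight_mat_of_sensory, built with numpy."""
    theta = 2.0 * np.pi * np.arange(1, n + 1) / n   # preferred angles
    diff = theta[:, None] - theta[None, :]
    w = lambda_ + A * np.exp(k1 * (np.cos(diff) - 1)) - A * np.exp(k2 * (np.cos(diff) - 1))
    np.fill_diagonal(w, 0.0)                        # no self-connections
    return w

w = ring_weights(64)
off_diagonal = w[~np.eye(64, dtype=bool)]
print(off_diagonal.max(), off_diagonal.min())
```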
[9]:
# connectivity matrix between sensory and random network
conn_matrix_sensory2random = bm.random.rand(num_all_sensory, num_all_random) < gamma
[10]:
# weight matrix of sensory2random
ws = factor * alpha / conn_matrix_sensory2random.sum(axis=0)
weight_mat_sensory2random = conn_matrix_sensory2random * ws.reshape((1, -1))
ws = weight_mat_sensory2random.sum(axis=0).reshape((1, -1))
weight_mat_sensory2random += forward_balance / num_all_sensory * ws # balance
[11]:
# weight matrix of random2sensory
ws = factor * beta / conn_matrix_sensory2random.sum(axis=1)
weight_mat_random2sensory = conn_matrix_sensory2random.T * ws.reshape((1, -1))
ws = weight_mat_random2sensory.sum(axis=0).reshape((1, -1))
weight_mat_random2sensory += backward_balance / num_all_random * ws # balance
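Before balancing, every random neuron receives a total weight of factor * alpha, because each of its inputs is scaled by the inverse of its in-degree.

With forward_balance = -1, the balancing step then subtracts from every entry of a column the mean of that column (the column sum divided by the number of presynaptic neurons). Each postsynaptic neuron's incoming weights therefore sum to zero, and spatially uniform presynaptic activity produces no net drive. The random2sensory matrix is balanced the same way with backward_balance.

Below is a numpy sketch with hypothetical small sizes.

```python
import numpy as np

rng = np.random.default_rng(0)
n_pre, n_post, gamma = 200, 50, 0.35   # hypothetical small network
factor, alpha, balance = 1000, 2.1, -1.0

conn = rng.random((n_pre, n_post)) < gamma        # boolean connectivity
w = conn * (factor * alpha / conn.sum(axis=0))    # each input weight = factor*alpha / in-degree
col_sum = w.sum(axis=0, keepdims=True)            # total input weight per postsynaptic neuron
w_bal = w + balance / n_pre * col_sum             # subtract each column's mean weight
print(w.sum(axis=0)[:3], w_bal.sum(axis=0)[:3])
```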
[12]:
@bm.jit
def f(inp_ids, center):
    inp_scale = bm.exp(-(inp_ids - center) ** 2 / 2 / sigma ** 2) / (bm.sqrt(2 * bm.pi) * sigma)
    inp_scale /= bm.max(inp_scale)
    inp_ids = bm.remainder(inp_ids - 1, num_sensory_neuron)
    input = bm.zeros(num_sensory_neuron)
    input[inp_ids] = input_strength * inp_scale
    input -= bm.sum(input) / num_sensory_neuron
    return input
def get_input(center):
    inp_ids = bm.arange(bm.asarray(center - three_sigma, dtype=bm.int32),
                        bm.asarray(center + three_sigma + 1, dtype=bm.int32),
                        1)
    return f(inp_ids, center)
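get_input builds the stimulus for one pool in three steps:

1. It places a Gaussian bump of width sigma over the neurons within three sigma of center.
2. It wraps the (1-based) ids around the ring with remainder.
3. It subtracts the mean, so the stimulus has zero net input to the pool.

The 1/(sqrt(2 pi) sigma) prefactor in f cancels once inp_scale is divided by its maximum, so the numpy sketch below omits it. The helper name is hypothetical and the defaults come from the parameter cells.

```python
import numpy as np

def ring_gaussian_input(center, n=512, sigma=16, strength=10.0):
    """Hypothetical numpy counterpart of get_input for one pool."""
    # 1-based neuron ids within +-3 sigma of the center, as in get_input
    ids = np.arange(int(center - 3 * sigma), int(center + 3 * sigma) + 1)
    scale = np.exp(-(ids - center) ** 2 / (2 * sigma ** 2))
    scale /= scale.max()                                # peak value becomes 1
    inp = np.zeros(n)
    inp[np.remainder(ids - 1, n)] = strength * scale    # wrap around the ring
    inp -= inp.sum() / n                                # zero-mean input across the pool
    return inp

inp = ring_gaussian_input(center=256)
inp_wrapped = ring_gaussian_input(center=10)           # bump that crosses the ring boundary
```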
[13]:
def get_activity_vector(rates):
    exp_stim_encoding = bm.exp(1j * sensory_encoding)
    timed_abs = bm.zeros(num_sensory_pool)
    timed_angle = bm.zeros(num_sensory_pool)
    for si in range(num_sensory_pool):
        start = si * num_sensory_neuron
        end = (si + 1) * num_sensory_neuron
        exp_rates = bm.multiply(rates[start:end], exp_stim_encoding)
        mean_rates = bm.mean(exp_rates)
        timed_angle[si] = bm.angle(mean_rates) * num_sensory_neuron / (2 * bm.pi)
        timed_abs[si] = bm.absolute(mean_rates)
    timed_angle[timed_angle < 0] += num_sensory_neuron
    return timed_abs, timed_angle
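get_activity_vector decodes each pool with a population vector, the mean of r_j exp(i theta_j). Its magnitude is the bump strength, which the trial loop compares against activity_threshold. Its phase is the bump location, converted back to neuron units. Because sensory_encoding starts at 2 pi / N, a location of k corresponds to the 0-based neuron index k - 1.

Below is a numpy sketch (the helper name is hypothetical) applied to synthetic symmetric bumps.

```python
import numpy as np

def decode_bump(rates, n):
    """Hypothetical helper mirroring get_activity_vector for a single pool."""
    # preferred angle of neuron j (0-based) is 2*pi*(j + 1)/n, as in sensory_encoding
    encoding = 2.0 * np.pi * np.arange(1, n + 1) / n
    z = np.mean(rates * np.exp(1j * encoding))   # complex population vector
    position = np.angle(z) * n / (2 * np.pi)     # phase mapped back to neuron units
    if position < 0:
        position += n
    return np.abs(z), position

n = 512
encoding = 2.0 * np.pi * np.arange(1, n + 1) / n
strength, position = decode_bump(np.exp(4.0 * np.cos(encoding - encoding[100])), n)
strength2, position2 = decode_bump(np.exp(4.0 * np.cos(encoding - encoding[400])), n)
print(position, position2)   # bumps centred on 0-based neurons 100 and 400
```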
[14]:
class PoissonNeuron(bp.NeuGroup):
    def __init__(self, size, **kwargs):
        super(PoissonNeuron, self).__init__(size=size, **kwargs)
        self.s = bm.Variable(bm.zeros(self.num))
        self.r = bm.Variable(bm.zeros(self.num))
        self.input = bm.Variable(bm.zeros(self.num))
        self.spike = bm.Variable(bm.zeros(self.num, dtype=bool))
        self.rng = bm.random.RandomState()
        self.int_s = bp.odeint(lambda s, t: -s / tau, method='exp_euler')
    def update(self, tdi):
        self.s.value = self.int_s(self.s, tdi.t, tdi.dt)
        self.r.value = 0.4 * (1. + bm.tanh(fI_slope * (self.input + self.s + bias) - 3.)) / tau
        self.spike.value = self.rng.random(self.s.shape) < self.r * tdi.dt
        self.input[:] = 0.
    def reset_state(self, batch_size=None):
        self.s.value = self.rng.random(self.num) * init_range
        self.r.value = 0.4 * (1. + bm.tanh(fI_slope * (bias + self.s) - 3.)) / tau
        self.input.value = bm.zeros(self.num)
        self.spike.value = bm.zeros(self.num, dtype=bool)
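PoissonNeuron maps its total input through a saturating f-I curve, 0.4 (1 + tanh(fI_slope (I + bias) - 3)) / tau. This rate is at most 0.8 / tau = 0.08 spikes per ms (80 Hz) and roughly 0.2 Hz at zero input.

In each step the neuron then spikes with probability r * dt. This Bernoulli draw approximates a Poisson process when r * dt is small. Because rates are stored in spikes per ms, they are multiplied by 1e3 before decoding.

Below is a numpy sketch, assuming BrainPy's default step dt = 0.1 ms.

```python
import numpy as np

fI_slope, bias, tau = 0.4, 0.0, 10.0   # values from the parameter cell above

def firing_rate(total_input):
    """Saturating f-I curve used in PoissonNeuron.update; units: spikes per ms."""
    return 0.4 * (1.0 + np.tanh(fI_slope * (total_input + bias) - 3.0)) / tau

rng = np.random.default_rng(0)
dt = 0.1                                      # ms (assumed default BrainPy time step)
r = firing_rate(np.full(1_000_000, 100.0))    # strongly driven neurons, near saturation
spikes = rng.random(r.shape) < r * dt         # one Bernoulli draw per neuron per step
print(firing_rate(0.0) * 1e3, "Hz at zero input;", spikes.mean(), "spikes per step")
```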
[15]:
class Sen2SenSyn(bp.SynConn):
    def __init__(self, pre, post, **kwargs):
        super(Sen2SenSyn, self).__init__(pre=pre, post=post, **kwargs)
    def update(self, tdi):
        for i in range(num_sensory_pool):
            start = i * num_sensory_neuron
            end = (i + 1) * num_sensory_neuron
            self.post.s[start: end] += bm.dot(self.pre.spike[start: end],
                                              weight_mat_of_sensory)
[16]:
class OtherSyn(bp.SynConn):
    def __init__(self, pre, post, weights, **kwargs):
        super(OtherSyn, self).__init__(pre=pre, post=post, **kwargs)
        self.weights = weights
    def update(self, tdi):
        self.post.s += bm.dot(self.pre.spike, self.weights)
[17]:
class Network(bp.Network):
    def __init__(self):
        super(Network, self).__init__()
        self.sensory = PoissonNeuron(num_all_sensory)
        self.random = PoissonNeuron(num_all_random)
        self.sensory2sensory = Sen2SenSyn(pre=self.sensory, post=self.sensory)
        self.random2sensory = OtherSyn(pre=self.random,
                                       post=self.sensory,
                                       weights=weight_mat_random2sensory)
        self.sensory2random = OtherSyn(pre=self.sensory,
                                       post=self.random,
                                       weights=weight_mat_sensory2random)
[18]:
for trial_idx in range(num_trials):
    # inputs
    # ------
    pools_receiving_inputs = bm.random.choice(num_sensory_pool, num_item_to_load, replace=False)
    print(f"Load {num_item_to_load} items in trial {trial_idx}.\n")
    input_center = bm.ones(num_sensory_pool) * num_sensory_neuron / 2
    inp_vector = bm.zeros((num_sensory_pool, num_sensory_neuron))
    for si in pools_receiving_inputs:
        inp_vector[si, :] = get_input(input_center[si])
    inp_vector = inp_vector.flatten()
    Iext, duration = bp.inputs.constant_input(
        [(0., start_stimulation),
         (inp_vector, end_stimulation - start_stimulation),
         (0., simulation_time - end_stimulation)]
    )
    # running
    # -------
    net = Network()
    runner = bp.dyn.DSRunner(net,
                             inputs=(net.sensory.input, Iext, 'iter'),
                             monitors={'S.r': net.sensory.r,
                                       'S.spike': net.sensory.spike,
                                       'R.spike': net.random.spike})
    runner.predict(duration, reset_state=True)
    # results
    # --------
    rate_abs, rate_angle = get_activity_vector(runner.mon['S.r'][-1] * 1e3)
    print(f"Stimulus is given in: {bm.sort(pools_receiving_inputs)}")
    print(f"Memory is found in: {bm.where(rate_abs > activity_threshold)[0]}")
    prob_maintained, prob_spurious = 0, 0
    for si in range(num_sensory_pool):
        if rate_abs[si] > activity_threshold:
            if si in pools_receiving_inputs:
                prob_maintained += 1
            else:
                prob_spurious += 1
    print(str(prob_maintained) + ' maintained memories')
    print(str(pools_receiving_inputs.shape[0] - prob_maintained) + ' forgotten memories')
    print(str(prob_spurious) + ' spurious memories\n')
    prob_maintained /= float(num_item_to_load)
    if num_item_to_load != num_sensory_pool:
        prob_spurious /= float(num_sensory_pool - num_item_to_load)
    # visualization
    # -------------
    fig, gs = bp.visualize.get_figure(6, 1, 1.5, 12)
    xlim = (0, duration)
    fig.add_subplot(gs[0:4, 0])
    bp.visualize.raster_plot(runner.mon.ts,
                             runner.mon['S.spike'],
                             ylabel='Sensory Network', xlim=xlim)
    for index_sn in range(num_sensory_pool + 1):
        plt.axhline(index_sn * num_sensory_neuron)
    plt.yticks([num_sensory_neuron * (i + 0.5) for i in range(num_sensory_pool)],
               [f'pool-{i}' for i in range(num_sensory_pool)])
    fig.add_subplot(gs[4:6, 0])
    bp.visualize.raster_plot(runner.mon.ts,
                             runner.mon['R.spike'],
                             ylabel='Random Network', xlim=xlim, show=True)
Load 6 items in trial 0.
Stimulus is given in: [0 1 2 4 5 7]
Memory is found in: [0 7]
2 maintained memories
4 forgotten memories
0 spurious memories
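The maintained/forgotten/spurious bookkeeping in the trial loop can also be written with set operations. This pure-Python sketch (the helper name is hypothetical) reproduces the counts printed for the trial above.

```python
def memory_outcome(stimulated, remembered, num_pools=8):
    """Hypothetical helper: count maintained, forgotten and spurious memories."""
    stimulated, remembered = set(stimulated), set(remembered)
    maintained = len(stimulated & remembered)   # stimulated pools that kept a bump
    forgotten = len(stimulated - remembered)    # stimulated pools that lost it
    spurious = len(remembered - stimulated)     # bumps in pools that never got input
    p_maintained = maintained / len(stimulated)
    n_free = num_pools - len(stimulated)
    p_spurious = spurious / n_free if n_free else 0.0
    return maintained, forgotten, spurious, p_maintained, p_spurious

# pools taken from the trial log above
print(memory_outcome([0, 1, 2, 4, 5, 7], [0, 7]))
```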

(Mi, et al., 2017) STP for Working Memory Capacity
Implementation of the paper:
Mi, Yuanyuan, Mikhail Katkov, and Misha Tsodyks. “Synaptic correlates of working memory capacity.” Neuron 93.2 (2017): 323-330.
[1]:
import matplotlib.colors as mcolors
import matplotlib.pyplot as plt
[2]:
import brainpy as bp
import brainpy.math as bm
bm.set_platform('cpu')
bm.enable_x64()
[3]:
alpha = 1.5
J_EE = 8. # the connection strength within each excitatory neural cluster
J_IE = 1.75 # Synaptic efficacy E → I
J_EI = 1.1 # Synaptic efficacy I → E
tau_f = 1.5 # time constant of STF [s]
tau_d = .3 # time constant of STD [s]
U = 0.3 # minimum STF value
tau = 0.008 # time constant of firing rate of the excitatory neurons [s]
tau_I = tau # time constant of firing rate of the inhibitory neurons
Ib = 8. # background input to the excitatory clusters
Iinh = 0. # the background input of the inhibitory neurons
cluster_num = 16 # the number of the clusters
[4]:
# the parameters of external input
stimulus_num = 5
Iext_train = 225 # the strength of the external input
Ts_interval = 0.070 # the time interval between consecutive external inputs [s]
Ts_duration = 0.030 # the time duration of the external input [s]
duration = 2.500 # [s]
[5]:
# the excitatory cluster model and the inhibitory pool model
class WorkingMemoryModel(bp.DynamicalSystem):
    def __init__(self, **kwargs):
        super(WorkingMemoryModel, self).__init__(**kwargs)
        # variables
        self.u = bm.Variable(bm.ones(cluster_num) * U)
        self.x = bm.Variable(bm.ones(cluster_num))
        self.h = bm.Variable(bm.zeros(cluster_num))
        self.r = bm.Variable(self.log(self.h))
        self.input = bm.Variable(bm.zeros(cluster_num))
        self.inh_h = bm.Variable(bm.zeros(1))
        self.inh_r = bm.Variable(self.log(self.inh_h))
    def log(self, h):
        # return alpha * bm.log(1. + bm.exp(h / alpha))
        return alpha * bm.log1p(bm.exp(h / alpha))
    def update(self, tdi):
        uxr = self.u * self.x * self.r
        du = (U - self.u) / tau_f + U * (1 - self.u) * self.r
        dx = (1 - self.x) / tau_d - uxr
        dh = (-self.h + J_EE * uxr - J_EI * self.inh_r + self.input + Ib) / tau
        dhi = (-self.inh_h + J_IE * bm.sum(self.r) + Iinh) / tau_I
        self.u += du * tdi.dt
        self.x += dx * tdi.dt
        self.h += dh * tdi.dt
        self.inh_h += dhi * tdi.dt
        self.r[:] = self.log(self.h)
        self.inh_r[:] = self.log(self.inh_h)
        self.input[:] = 0.
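The u and x equations in WorkingMemoryModel are the short-term plasticity variables: u is facilitation and x is depression. The recurrent efficacy of each cluster is J_EE u x.

For a cluster clamped at a constant rate r, setting du/dt = dx/dt = 0 gives u* = U (1/tau_f + r) / (1/tau_f + U r) and x* = 1 / (1 + u* r tau_d).

The pure-Python sketch below Euler-integrates the same update rule at an arbitrary 10 Hz rate and compares the result with these expressions. It leaves out h, r and the inhibitory pool.

```python
U, tau_f, tau_d = 0.3, 1.5, 0.3   # values from the parameter cell above [s]

def simulate_stp(rate, duration=10.0, dt=1e-4):
    """Euler-integrate u and x for one cluster at a constant firing rate [Hz]."""
    u, x = U, 1.0
    for _ in range(int(duration / dt)):
        du = (U - u) / tau_f + U * (1 - u) * rate
        dx = (1 - x) / tau_d - u * x * rate   # uses the same (old) u as du, as in update()
        u += du * dt
        x += dx * dt
    return u, x

def stp_steady_state(rate):
    """Solve du/dt = 0 and dx/dt = 0 for a constant rate."""
    u = U * (1 / tau_f + rate) / (1 / tau_f + U * rate)
    x = 1 / (1 + u * rate * tau_d)
    return u, x

u_sim, x_sim = simulate_stp(rate=10.0)
u_ss, x_ss = stp_steady_state(rate=10.0)
print(u_sim, u_ss, x_sim, x_ss)
```

At 10 Hz, u is facilitated well above U while x is strongly depleted. The product u x, which sets the recurrent efficacy, is what the J_EE u x panel below tracks.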
[6]:
dt = 0.0001 # [s]
# the external input
I_inputs = bm.zeros((int(duration / dt), cluster_num))
for i in range(stimulus_num):
    t_start = (Ts_interval + Ts_duration) * i + Ts_interval
    t_end = t_start + Ts_duration
    idx_start, idx_end = int(t_start / dt), int(t_end / dt)
    I_inputs[idx_start: idx_end, i] = Iext_train
[7]:
# running
runner = bp.DSRunner(WorkingMemoryModel(),
                     inputs=['input', I_inputs, 'iter'],
                     monitors=['u', 'x', 'r', 'h'],
                     dt=dt)
runner(duration)
[8]:
# visualization
colors = list(dict(mcolors.BASE_COLORS, **mcolors.CSS4_COLORS).keys())
fig, gs = bp.visualize.get_figure(5, 1, 2, 12)
fig.add_subplot(gs[0, 0])
for i in range(stimulus_num):
    plt.plot(runner.mon.ts, runner.mon.r[:, i], label='Cluster-{}'.format(i))
plt.ylabel("$r (Hz)$")
plt.legend(loc='right')
fig.add_subplot(gs[1, 0])
hist_Jux = J_EE * runner.mon.u * runner.mon.x
for i in range(stimulus_num):
    plt.plot(runner.mon.ts, hist_Jux[:, i])
plt.ylabel("$J_{EE}ux$")
fig.add_subplot(gs[2, 0])
for i in range(stimulus_num):
    plt.plot(runner.mon.ts, runner.mon.u[:, i], colors[i])
plt.ylabel('u')
fig.add_subplot(gs[3, 0])
for i in range(stimulus_num):
    plt.plot(runner.mon.ts, runner.mon.x[:, i], colors[i])
plt.ylabel('x')
fig.add_subplot(gs[4, 0])
for i in range(stimulus_num):
    # plot the monitored h, matching the y-label of this panel
    plt.plot(runner.mon.ts, runner.mon.h[:, i], colors[i])
plt.ylabel('h')
plt.xlabel('time [s]')
plt.show()

[1D] Simple systems
[1]:
import brainpy as bp
bp.math.enable_x64()
# bp.math.set_platform('cpu')
Phase plane
Here we show the bifurcation analysis of a 1D system, using a dummy test neuronal model.
First, let’s define the model.
[2]:
@bp.odeint
def int_x(x, t, Iext):
    dx = x ** 3 - x + Iext
    return dx
[3]:
analyzer = bp.analysis.PhasePlane1D(int_x,
                                    target_vars={'x': [-10, 10]},
                                    pars_update={'Iext': 0.})
analyzer.plot_vector_field()
analyzer.plot_fixed_point(show=True)
WARNING:jax._src.lib.xla_bridge:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
I am creating the vector field ...
I am searching fixed points ...
Fixed point #1 at x=-1.0000000000000002 is a unstable point.
Fixed point #2 at x=-7.771561172376096e-16 is a stable point.
Fixed point #3 at x=1.0000000000000002 is a unstable point.
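For a 1D system dx/dt = f(x), a fixed point is stable when f'(x) < 0. With f(x) = x^3 - x + Iext and Iext = 0, the derivative is f'(x) = 3x^2 - 1. It is negative at x = 0 and positive at x = ±1, which matches the classification printed above.

Below is a numpy sketch (the helper name is hypothetical) that finds and classifies the fixed points directly.

```python
import numpy as np

def fixed_points_1d(iext):
    """Real roots of x**3 - x + iext = 0, labeled by the sign of f'(x) = 3x**2 - 1."""
    roots = np.roots([1.0, 0.0, -1.0, iext])
    roots = np.sort(roots[np.abs(roots.imag) < 1e-9].real)
    return [(float(x), 'stable' if 3 * x ** 2 - 1 < 0 else 'unstable') for x in roots]

fps = fixed_points_1d(0.0)
print(fps)
```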

Codimension 1
Then, create a bifurcation analyzer with bp.analysis.Bifurcation1D.
[4]:
an = bp.analysis.Bifurcation1D(
    int_x,
    target_pars={'Iext': [-0.5, 0.5]},
    target_vars={"x": [-2, 2]},
    resolutions={'Iext': 0.001}
)
an.plot_bifurcation(show=True)
I am making bifurcation analysis ...
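The saddle-node points of dx/dt = x^3 - x + Iext are where f(x) = 0 and f'(x) = 3x^2 - 1 = 0 hold together. This happens at x = ±1/sqrt(3) and Iext = ±2 / (3 sqrt(3)) ≈ ±0.385, inside the scanned range [-0.5, 0.5]. Between the two folds the system has three fixed points; outside them it has one.

Below is a numpy sketch (the helper name is hypothetical) that counts real fixed points on either side of the folds.

```python
import numpy as np

def count_fixed_points(iext, tol=1e-9):
    """Number of real roots of x**3 - x + iext = 0."""
    roots = np.roots([1.0, 0.0, -1.0, iext])
    return int(np.sum(np.abs(roots.imag) < tol))

i_fold = 2 / (3 * np.sqrt(3))   # saddle-node (fold) value of Iext
counts = {i: count_fixed_points(i) for i in (-0.45, -0.3, 0.0, 0.3, 0.45)}
print(i_fold, counts)
```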

Codimension 2
Here we define the following 1D model for codimension 2 bifurcation testing.
[5]:
@bp.odeint
def int_x(x, t, mu, lambda_):
    dxdt = mu + lambda_ * x - x ** 3
    return dxdt
[6]:
analyzer = bp.analysis.Bifurcation1D(
    int_x,
    target_pars={'lambda_': [-1, 4], 'mu': [-4, 4], },
    target_vars={'x': [-3, 3]},
    resolutions={'lambda_': 0.1, 'mu': 0.1}
)
analyzer.plot_bifurcation(show=True)
I am making bifurcation analysis ...
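For dx/dt = mu + lambda_ x - x^3, eliminating x from f = 0 and f' = lambda_ - 3x^2 = 0 gives two fold curves, mu = ±2 (lambda_ / 3)^(3/2). They meet at the cusp mu = lambda_ = 0.

Inside the wedge |mu| < 2 (lambda_ / 3)^(3/2) there are three fixed points (bistability). Elsewhere, including all lambda_ < 0, there is one.

Below is a numpy sketch checking this at lambda_ = 3, where the folds sit at mu = ±2. The helper names are hypothetical.

```python
import numpy as np

def fold_mu(lam):
    """Boundary of the bistable region, mu = +-fold_mu(lam), for lam >= 0."""
    return 2.0 * (lam / 3.0) ** 1.5

def num_fixed_points(mu, lam, tol=1e-9):
    """Number of real roots of mu + lam * x - x**3 = 0."""
    roots = np.roots([-1.0, 0.0, lam, mu])
    return int(np.sum(np.abs(roots.imag) < tol))

checks = [(mu, 3.0, num_fixed_points(mu, 3.0)) for mu in (-3.0, -1.0, 1.0, 3.0)]
print(fold_mu(3.0), checks)
```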

[2D] NaK model analysis
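Here we show the neurodynamical analysis of a two-dimensional system, using the \(I_{\rm{Na,p}}+I_K\) model as an example.
The dynamical system is given by:
\[C\dot{V} = I_{ext} - g_L (V - E_L) - g_{\rm{Na}}\, m_\infty(V)\,(V - E_{\rm{Na}}) - g_K\, n\,(V - E_K), \qquad \dot{n} = \frac{n_\infty(V) - n}{\tau},\]
where
\[m_\infty(V) = \frac{1}{1 + \exp\{(V_{1/2,m} - V)/k_m\}}, \qquad n_\infty(V) = \frac{1}{1 + \exp\{(V_{1/2,n} - V)/k_n\}}.\]
This model specifies a leak current \(I_L\), a persistent sodium current \(I_{\rm{Na, p}}\) with instantaneous activation kinetics, and a relatively slower persistent potassium current \(I_K\) with either a high or a low threshold (the two choices result in fundamentally different dynamics).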
[1]:
import brainpy as bp
import brainpy.math as bm
bm.set_platform('cpu')
bm.set_dt(dt=0.02)
bm.enable_x64()
[2]:
C = 1
E_L = -78 # different from high-threshold model
g_L = 8
g_Na = 20
g_K = 10
E_K = -90
E_Na = 60
Vm_half = -20
k_m = 15
Vn_half = -45 # different from high-threshold model
k_n = 5
tau = 1
@bp.odeint(method='exp_auto')
def int_V(V, t, n, Iext):
    m_inf = 1 / (1 + bm.exp((Vm_half - V) / k_m))
    I_leak = g_L * (V - E_L)
    I_Na = g_Na * m_inf * (V - E_Na)
    I_K = g_K * n * (V - E_K)
    dvdt = (-I_leak - I_Na - I_K + Iext) / C
    return dvdt
@bp.odeint(method='exp_auto')
def int_n(n, t, V):
    n_inf = 1 / (1 + bm.exp((Vn_half - V) / k_n))
    dndt = (n_inf - n) / tau
    return dndt
Phase plane analysis
[3]:
analyzer = bp.analysis.PhasePlane2D(
    model=[int_n, int_V],
    target_vars={'n': [0., 1.], 'V': [-90, 20]},
    pars_update={'Iext': 50.},
    resolutions={'n': 0.01, 'V': 0.1}
)
analyzer.plot_nullcline()
analyzer.plot_vector_field()
analyzer.plot_fixed_point()
analyzer.plot_trajectory({'n': [0.2, 0.4], 'V': [-10, -80]},
                         duration=100., show=True)
I am computing fx-nullcline ...
I am evaluating fx-nullcline by optimization ...
I am computing fy-nullcline ...
I am evaluating fy-nullcline by optimization ...
I am creating the vector field ...
I am searching fixed points ...
I am trying to find fixed points by optimization ...
There are 680 candidates
I am trying to filter out duplicate fixed points ...
Found 1 fixed points.
#1 n=0.210529358619151, V=-51.60868761096133 is a unstable focus.
I am plotting the trajectory ...
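Because dn/dt = 0 forces n = n_inf(V), every fixed point of this model satisfies I_inf(V) = Iext, where I_inf(V) = g_L (V - E_L) + g_Na m_inf(V) (V - E_Na) + g_K n_inf(V) (V - E_K) is the steady-state current.

The pure-Python sketch below (helper names are hypothetical) solves I_inf(V) = 50 by bisection. It assumes a single crossing in [-90, 20] mV, consistent with the single fixed point reported above, and should land on that fixed point.

```python
import math

# parameters of the low-threshold I_Na,p + I_K model defined above
g_L, g_Na, g_K = 8.0, 20.0, 10.0
E_L, E_Na, E_K = -78.0, 60.0, -90.0
Vm_half, k_m, Vn_half, k_n = -20.0, 15.0, -45.0, 5.0

def m_inf(V):
    return 1.0 / (1.0 + math.exp((Vm_half - V) / k_m))

def n_inf(V):
    return 1.0 / (1.0 + math.exp((Vn_half - V) / k_n))

def steady_state_current(V):
    """Total ionic current with n clamped at n_inf(V)."""
    return (g_L * (V - E_L) + g_Na * m_inf(V) * (V - E_Na)
            + g_K * n_inf(V) * (V - E_K))

def find_fixed_point(Iext, lo=-90.0, hi=20.0, iters=100):
    """Bisection on I_inf(V) - Iext; assumes one sign change in [lo, hi]."""
    def f(V):
        return steady_state_current(V) - Iext
    for _ in range(iters):
        mid = 0.5 * (lo + hi)
        if f(lo) * f(mid) <= 0:
            hi = mid
        else:
            lo = mid
    return 0.5 * (lo + hi)

V_star = find_fixed_point(50.0)
print(V_star, n_inf(V_star))
```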

Codimension 1 bifurcation analysis
Here we show the codimension 1 bifurcation analysis of the \(I_{\rm{Na,p}}+I_K\) model, in which \(I_{ext}\) is varied in [0, 50].
[4]:
analyzer = bp.analysis.Bifurcation2D(
    model=[int_V, int_n],
    target_vars={"V": [-90., 20.], 'n': [0., 1.]},
    target_pars={'Iext': [0, 50.]},
    resolutions={'Iext': 0.1},
)
analyzer.plot_bifurcation(num_rank=30)
analyzer.plot_limit_cycle_by_sim(duration=100., show=True)
I am making bifurcation analysis ...
I am filtering out fixed point candidates with auxiliary function ...
I am trying to find fixed points by optimization ...
There are 15000 candidates
I am trying to filter out duplicate fixed points ...
Found 325 fixed points.
I am plotting the limit cycle ...


Reference
Izhikevich, Eugene M. Dynamical systems in neuroscience (Chapter 4). MIT Press, 2007.
[2D] Wilson-Cowan model
[1]:
import brainpy as bp
import brainpy.math as bm
# bp.math.set_platform('cpu')
bp.math.enable_x64()
[2]:
class WilsonCowanModel(bp.DynamicalSystem):
    def __init__(self, num, method='exp_auto'):
        super(WilsonCowanModel, self).__init__()
        # Connection weights
        self.wEE = 12
        self.wEI = 4
        self.wIE = 13
        self.wII = 11
        # Refractory parameter
        self.r = 1
        # Excitatory parameters
        self.E_tau = 1 # Timescale of excitatory population
        self.E_a = 1.2 # Gain of excitatory population
        self.E_theta = 2.8 # Threshold of excitatory population
        # Inhibitory parameters
        self.I_tau = 1 # Timescale of inhibitory population
        self.I_a = 1 # Gain of inhibitory population
        self.I_theta = 4 # Threshold of inhibitory population
        # variables
        self.i = bm.Variable(bm.ones(num))
        self.e = bm.Variable(bm.ones(num))
        self.Iext = bm.Variable(bm.zeros(num))
        # functions
        def F(x, a, theta):
            return 1 / (1 + bm.exp(-a * (x - theta))) - 1 / (1 + bm.exp(a * theta))
        def de(e, t, i, Iext=0.):
            x = self.wEE * e - self.wEI * i + Iext
            return (-e + (1 - self.r * e) * F(x, self.E_a, self.E_theta)) / self.E_tau
        def di(i, t, e):
            x = self.wIE * e - self.wII * i
            return (-i + (1 - self.r * i) * F(x, self.I_a, self.I_theta)) / self.I_tau
        self.int_e = bp.odeint(de, method=method)
        self.int_i = bp.odeint(di, method=method)
    def update(self, tdi):
        self.e.value = self.int_e(self.e, tdi.t, self.i, self.Iext, tdi.dt)
        self.i.value = self.int_i(self.i, tdi.t, self.e, tdi.dt)
        self.Iext[:] = 0.
Simulation
[3]:
model = WilsonCowanModel(2)
model.e[:] = [-0.1, 0.4]
model.i[:] = [0.5, 0.6]
WARNING:jax._src.lib.xla_bridge:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
[4]:
# simulation
runner = bp.DSRunner(model, monitors=['e', 'i'])
runner.run(100)
[5]:
fig, gs = bp.visualize.get_figure(2, 1, 3, 8)
fig.add_subplot(gs[0, 0])
bp.visualize.line_plot(runner.mon.ts, runner.mon.e, plot_ids=[0], legend='e', linestyle='--')
bp.visualize.line_plot(runner.mon.ts, runner.mon.i, plot_ids=[0], legend='i', linestyle='--')
fig.add_subplot(gs[1, 0])
bp.visualize.line_plot(runner.mon.ts, runner.mon.e, plot_ids=[1], legend='e')
bp.visualize.line_plot(runner.mon.ts, runner.mon.i, plot_ids=[1], legend='i', show=True)

Phase Plane Analysis
[6]:
# phase plane analysis
pp = bp.analysis.PhasePlane2D(
    model,
    target_vars={'e': [-0.2, 1.], 'i': [-0.2, 1.]},
    resolutions=0.001,
)
pp.plot_vector_field()
pp.plot_nullcline(coords={'i': 'i-e'})
pp.plot_fixed_point()
pp.plot_trajectory(initials={'i': [0.5, 0.6], 'e': [-0.1, 0.4]},
                   duration=10, dt=0.1)
pp.show_figure()
I am creating the vector field ...
I am computing fx-nullcline ...
I am evaluating fx-nullcline by optimization ...
I am computing fy-nullcline ...
I am evaluating fy-nullcline by optimization ...
I am searching fixed points ...
I am trying to find fixed points by optimization ...
There are 1966 candidates
I am trying to filter out duplicate fixed points ...
Found 3 fixed points.
#1 e=-1.6868322611378924e-07, i=-3.4570638908889285e-08 is a stable node.
#2 e=0.17803581298567323, i=0.06279970624572738 is a saddle node.
#3 e=0.46246485137398247, i=0.24336408752379407 is a stable node.
I am plotting the trajectory ...
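As a quick consistency check, the right-hand sides of the Wilson-Cowan equations should vanish at the reported fixed points. F is shifted so that F(0) = 0, which makes (e, i) = (0, 0) an exact fixed point. Fixed point #1 above is that point, found to optimizer precision.

Below is a pure-Python sketch that re-implements the model equations with the parameters from WilsonCowanModel, taking Iext = 0.

```python
import math

# Wilson-Cowan parameters from WilsonCowanModel
wEE, wEI, wIE, wII, r = 12.0, 4.0, 13.0, 11.0, 1.0
E_tau, E_a, E_theta = 1.0, 1.2, 2.8
I_tau, I_a, I_theta = 1.0, 1.0, 4.0

def F(x, a, theta):
    """Shifted sigmoid with F(0) = 0."""
    return 1 / (1 + math.exp(-a * (x - theta))) - 1 / (1 + math.exp(a * theta))

def rhs(e, i, Iext=0.0):
    """Time derivatives (de/dt, di/dt) of the two populations."""
    de = (-e + (1 - r * e) * F(wEE * e - wEI * i + Iext, E_a, E_theta)) / E_tau
    di = (-i + (1 - r * i) * F(wIE * e - wII * i, I_a, I_theta)) / I_tau
    return de, di

# fixed points #1 and #3 reported by PhasePlane2D above
for e, i in [(-1.6868322611378924e-07, -3.4570638908889285e-08),
             (0.46246485137398247, 0.24336408752379407)]:
    print(rhs(e, i))
```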

[2D] Decision Making Model with SlowPointFinder
[1]:
import brainpy as bp
import brainpy.math as bm
bp.math.enable_x64()
# bp.math.set_platform('cpu')
[2]:
# parameters
gamma = 0.641 # Saturation factor for gating variable
tau = 0.06 # Synaptic time constant [sec]
a = 270.
b = 108.
d = 0.154
[3]:
JE = 0.3725 # self-coupling strength [nA]
JI = -0.1137 # cross-coupling strength [nA]
JAext = 0.00117 # Stimulus input strength [nA]
[4]:
mu = 20. # Stimulus firing rate [spikes/sec]
coh = 0.5 # Stimulus coherence (as a fraction; 0.5 corresponds to 50%)
Ib1 = 0.3297
Ib2 = 0.3297
[5]:
@bp.odeint
def int_s1(s1, t, s2, coh=0.5, mu=20.):
    I1 = JE * s1 + JI * s2 + Ib1 + JAext * mu * (1. + coh)
    r1 = (a * I1 - b) / (1. - bm.exp(-d * (a * I1 - b)))
    return - s1 / tau + (1. - s1) * gamma * r1
[6]:
@bp.odeint
def int_s2(s2, t, s1, coh=0.5, mu=20.):
    I2 = JE * s2 + JI * s1 + Ib2 + JAext * mu * (1. - coh)
    r2 = (a * I2 - b) / (1. - bm.exp(-d * (a * I2 - b)))
    return - s2 / tau + (1. - s2) * gamma * r2
[7]:
def cell(s):
    ds1 = int_s1.f(s[0], 0., s[1])
    ds2 = int_s2.f(s[1], 0., s[0])
    return bm.asarray([ds1, ds2])
[8]:
finder = bp.analysis.SlowPointFinder(f_cell=cell, f_type='continuous')
finder.find_fps_with_gd_method(
    candidates=bm.random.random((1000, 2)),
    tolerance=1e-5, num_batch=200,
    optimizer=bp.optim.Adam(lr=bp.optim.ExponentialDecay(0.01, 1, 0.9999)),
)
finder.filter_loss(1e-5)
finder.keep_unique()
print('fixed_points: ', finder.fixed_points)
print('losses:', finder.losses)
Optimizing with Adam(lr=ExponentialDecay(0.01, decay_steps=1, decay_rate=0.9999), last_call=-1), beta1=0.9, beta2=0.999, eps=1e-08) to find fixed points:
Batches 1-200 in 0.53 sec, Training loss 0.0209022794
Batches 201-400 in 0.39 sec, Training loss 0.0034664037
Batches 401-600 in 0.40 sec, Training loss 0.0009952186
Batches 601-800 in 1.03 sec, Training loss 0.0003619036
Batches 801-1000 in 0.95 sec, Training loss 0.0001457642
Batches 1001-1200 in 0.42 sec, Training loss 0.0000614027
Batches 1201-1400 in 0.39 sec, Training loss 0.0000262476
Batches 1401-1600 in 0.49 sec, Training loss 0.0000111700
Batches 1601-1800 in 1.29 sec, Training loss 0.0000046670
Stop optimization as mean training loss 0.0000046670 is below tolerance 0.0000100000.
Excluding fixed points with squared speed above tolerance 1e-05:
Kept 957/1000 fixed points with tolerance under 1e-05.
Excluding non-unique fixed points:
Kept 3/957 unique fixed points with uniqueness tolerance 0.025.
fixed_points: [[0.70045189 0.00486431]
[0.01394651 0.65738904]
[0.28276323 0.40635171]]
losses: [2.16451334e-30 1.51498296e-28 3.08194113e-16]
[9]:
jac = finder.compute_jacobians(finder.fixed_points, plot=True, num_col=1)

[2D] Decision Making Model with Low-dimensional Analyzer
Please refer to Analysis of A Decision Making Model.
[3D] Hindmarsh Rose Model
[1]:
import brainpy as bp
# bp.math.set_platform('cpu')
bp.math.enable_x64()
[2]:
import matplotlib.pyplot as plt
import numpy as np
[3]:
class HindmarshRose(bp.DynamicalSystem):
    def __init__(self, method='exp_auto'):
        super(HindmarshRose, self).__init__()
        # parameters
        self.a = 1.
        self.b = 2.5
        self.c = 1.
        self.d = 5.
        self.s = 4.
        self.x_r = -1.6
        self.r = 0.001
        # variables
        self.x = bp.math.Variable(bp.math.ones(1))
        self.y = bp.math.Variable(bp.math.ones(1))
        self.z = bp.math.Variable(bp.math.ones(1))
        self.I = bp.math.Variable(bp.math.zeros(1))
        # integral functions
        def dx(x, t, y, z, Isyn):
            return y - self.a * x ** 3 + self.b * x * x - z + Isyn
        def dy(y, t, x):
            return self.c - self.d * x * x - y
        def dz(z, t, x):
            return self.r * (self.s * (x - self.x_r) - z)
        self.int_x = bp.odeint(f=dx, method=method)
        self.int_y = bp.odeint(f=dy, method=method)
        self.int_z = bp.odeint(f=dz, method=method)
    def update(self, tdi):
        self.x.value = self.int_x(self.x, tdi.t, self.y, self.z, self.I, tdi.dt)
        self.y.value = self.int_y(self.y, tdi.t, self.x, tdi.dt)
        self.z.value = self.int_z(self.z, tdi.t, self.x, tdi.dt)
        # clear the external input after each step
        self.I[:] = 0.
Simulation
[4]:
model = HindmarshRose()
runner = bp.DSRunner(model, monitors=['x', 'y', 'z'], inputs=['I', 1.5])
runner.run(2000.)
bp.visualize.line_plot(runner.mon.ts, runner.mon.x, legend='x', show=True)

Bifurcation analysis
[5]:
analyzer = bp.analysis.FastSlow2D(
[model.int_x, model.int_y, model.int_z],
fast_vars={'x': [-3, 2], 'y': [-20., 3.]},
slow_vars={'z': [-0.5, 3.]},
pars_update={'Isyn': 1.5},
resolutions={'z': 0.01},
# options={bp.analysis.C.y_by_x_in_fy: lambda x: model.c - model.d * x * x}
)
analyzer.plot_bifurcation(num_rank=20)
analyzer.plot_trajectory({'x': [1.], 'y': [1.], 'z': [1.]},
duration=1700,
plot_durations=[360, 1680])
analyzer.show_figure()
I am making bifurcation analysis ...
I am filtering out fixed point candidates with auxiliary function ...
I am trying to find fixed points by optimization ...
There are 7000 candidates
I am trying to filter out duplicate fixed points ...
Found 789 fixed points.
I am plotting the trajectory ...


Phase plane analysis
[6]:
for z in np.arange(0., 2.5, 0.3):
    analyzer = bp.analysis.PhasePlane2D(
        [model.int_x, model.int_y],
        target_vars={'x': [-3, 2], 'y': [-20., 3.]},
        pars_update={'Isyn': 1.5, 'z': z},
        resolutions={'x': 0.01, 'y': 0.01},
    )
    analyzer.plot_nullcline()
    analyzer.plot_vector_field()
    fps = analyzer.plot_fixed_point(with_return=True)
    # start the trajectory next to the last fixed point (columns: x, y)
    analyzer.plot_trajectory({'x': [fps[-1, 0] + 0.1], 'y': [fps[-1, 1] + 0.1]},
                             duration=500, plot_durations=[400, 500])
    plt.title(f'z={z:.2f}')
    plt.show()
I am computing fx-nullcline ...
I am evaluating fx-nullcline by optimization ...
I am computing fy-nullcline ...
I am evaluating fy-nullcline by optimization ...
I am creating the vector field ...
I am searching fixed points ...
I am trying to find fixed points by optimization ...
There are 2111 candidates
I am trying to filter out duplicate fixed points ...
Found 1 fixed points.
#1 x=0.862288349294799, y=-2.7177058646654935 is a unstable focus.
I am plotting the trajectory ...

I am computing fx-nullcline ...
I am evaluating fx-nullcline by optimization ...
I am computing fy-nullcline ...
I am evaluating fy-nullcline by optimization ...
I am creating the vector field ...
I am searching fixed points ...
I am trying to find fixed points by optimization ...
There are 2141 candidates
I am trying to filter out duplicate fixed points ...
Found 3 fixed points.
#1 x=-1.872654283356628, y=-16.5341715842593 is a stable node.
#2 x=-1.4420391946597404, y=-9.397382585390593 is a saddle node.
#3 x=0.8146858103927409, y=-2.318565022506274 is a unstable focus.
I am plotting the trajectory ...

I am computing fx-nullcline ...
I am evaluating fx-nullcline by optimization ...
I am computing fy-nullcline ...
I am evaluating fy-nullcline by optimization ...
I am creating the vector field ...
I am searching fixed points ...
I am trying to find fixed points by optimization ...
There are 2171 candidates
I am trying to filter out duplicate fixed points ...
Found 3 fixed points.
#1 x=-2.0462132406600912, y=-19.934943131130087 is a stable node.
#2 x=-1.2168555131297534, y=-6.40368673337723 is a saddle node.
#3 x=0.7630687306850773, y=-1.911369548008847 is a unstable focus.
I am plotting the trajectory ...

I am computing fx-nullcline ...
I am evaluating fx-nullcline by optimization ...
I am computing fy-nullcline ...
I am evaluating fy-nullcline by optimization ...
I am creating the vector field ...
I am searching fixed points ...
I am trying to find fixed points by optimization ...
There are 2201 candidates
I am trying to filter out duplicate fixed points ...
Found 3 fixed points.
#1 x=-2.1556926349477843, y=-22.235053679257277 is a stable node.
#2 x=-1.05070800710585, y=-4.519936589150076 is a saddle node.
#3 x=0.7064006730222961, y=-1.4950094398974472 is a unstable focus.
I am plotting the trajectory ...

I am computing fx-nullcline ...
I am evaluating fx-nullcline by optimization ...
I am computing fy-nullcline ...
I am evaluating fy-nullcline by optimization ...
I am creating the vector field ...
I am searching fixed points ...
I am trying to find fixed points by optimization ...
There are 2231 candidates
I am trying to filter out duplicate fixed points ...
Found 3 fixed points.
#1 x=-2.2411861683124883, y=-24.1145772063309 is a stable node.
#2 x=-0.901932179236277, y=-3.067408474244581 is a saddle node.
#3 x=0.6431188901157783, y=-1.0680095250383574 is a unstable focus.
I am plotting the trajectory ...

I am computing fx-nullcline ...
I am evaluating fx-nullcline by optimization ...
I am computing fy-nullcline ...
I am evaluating fy-nullcline by optimization ...
I am creating the vector field ...
I am searching fixed points ...
I am trying to find fixed points by optimization ...
There are 2261 candidates
I am trying to filter out duplicate fixed points ...
Found 3 fixed points.
#1 x=-2.3130990884027134, y=-25.75213708580098 is a stable node.
#2 x=-0.7575691876779478, y=-1.8695551041910168 is a saddle node.
#3 x=0.5706681801629564, y=-0.6283107114652488 is a unstable focus.
I am plotting the trajectory ...

I am computing fx-nullcline ...
I am evaluating fx-nullcline by optimization ...
I am computing fy-nullcline ...
I am evaluating fy-nullcline by optimization ...
I am creating the vector field ...
I am searching fixed points ...
I am trying to find fixed points by optimization ...
There are 2291 candidates
I am trying to filter out duplicate fixed points ...
Found 3 fixed points.
#1 x=-2.3760052650455847, y=-27.227005096799942 is a stable node.
#2 x=-0.6083084538224832, y=-0.8501958918559419 is a saddle node.
#3 x=0.48431473687387644, y=-0.17280300724589612 is a unstable focus.
I am plotting the trajectory ...

I am computing fx-nullcline ...
I am evaluating fx-nullcline by optimization ...
I am computing fy-nullcline ...
I am evaluating fy-nullcline by optimization ...
I am creating the vector field ...
I am searching fixed points ...
I am trying to find fixed points by optimization ...
There are 2321 candidates
I am trying to filter out duplicate fixed points ...
Found 3 fixed points.
#1 x=-2.4323928622096647, y=-28.582675180900974 is a stable node.
#2 x=-0.4407309093835196, y=0.028781414591295275 is a saddle node.
#3 x=0.37312367222430437, y=0.30389369711344255 is a unstable focus.
I am plotting the trajectory ...

I am computing fx-nullcline ...
I am evaluating fx-nullcline by optimization ...
I am computing fy-nullcline ...
I am evaluating fy-nullcline by optimization ...
I am creating the vector field ...
I am searching fixed points ...
I am trying to find fixed points by optimization ...
There are 2351 candidates
I am trying to filter out duplicate fixed points ...
Found 3 fixed points.
#1 x=-2.4837904744530963, y=-29.84607557713767 is a stable node.
#2 x=-0.20892015454771234, y=0.7817619225980599 is a saddle node.
#3 x=0.19271039914787816, y=0.8143135053948518 is a stable focus.
I am plotting the trajectory ...

Continuous-attractor Neural Network
[1]:
import matplotlib.pyplot as plt
import numpy as np
from sklearn.decomposition import PCA
[2]:
import brainpy as bp
import brainpy.math as bm
bm.set_platform('cpu')
Model
[3]:
class CANN1D(bp.NeuGroup):
    def __init__(self, num, tau=1., k=8.1, a=0.5, A=10., J0=4.,
                 z_min=-bm.pi, z_max=bm.pi, **kwargs):
        super(CANN1D, self).__init__(size=num, **kwargs)
        # parameters
        self.tau = tau  # The synaptic time constant
        self.k = k  # Degree of the rescaled inhibition
        self.a = a  # Half-width of the range of excitatory connections
        self.A = A  # Magnitude of the external input
        self.J0 = J0  # maximum connection value
        # feature space
        self.z_min = z_min
        self.z_max = z_max
        self.z_range = z_max - z_min
        self.x = bm.linspace(z_min, z_max, num)  # The encoded feature values
        self.rho = num / self.z_range  # The neural density
        self.dx = self.z_range / num  # The spacing between neighboring neurons
        # variables
        self.u = bm.Variable(bm.zeros(num))
        self.input = bm.Variable(bm.zeros(num))
        # The connection matrix
        self.conn_mat = self.make_conn(self.x)
        # function
        self.integral = bp.odeint(self.derivative)
    def derivative(self, u, t, Iext):
        r1 = bm.square(u)
        r2 = 1.0 + self.k * bm.sum(r1)
        r = r1 / r2
        Irec = bm.dot(self.conn_mat, r)
        du = (-u + Irec + Iext) / self.tau
        return du
    def dist(self, d):
        # wrap a difference onto the ring, into (-z_range/2, z_range/2]
        d = bm.remainder(d, self.z_range)
        d = bm.where(d > 0.5 * self.z_range, d - self.z_range, d)
        return d
    def make_conn(self, x):
        assert bm.ndim(x) == 1
        x_left = bm.reshape(x, (-1, 1))
        x_right = bm.repeat(x.reshape((1, -1)), len(x), axis=0)
        d = self.dist(x_left - x_right)
        Jxx = self.J0 * bm.exp(-0.5 * bm.square(d / self.a)) / (bm.sqrt(2 * bm.pi) * self.a)
        return Jxx
    def get_stimulus_by_pos(self, pos):
        return self.A * bm.exp(-0.25 * bm.square(self.dist(self.x - pos) / self.a))
    def update(self, tdi):
        # advance u by one step of size dt, then clear the input
        self.u.value = self.integral(self.u, tdi.t, self.input, tdi.dt)
        self.input[:] = 0.
Find fixed points
[4]:
model = CANN1D(num=512, k=0.1, A=30, a=0.5)
[5]:
candidates = model.get_stimulus_by_pos(bm.arange(-bm.pi, bm.pi, 0.005).reshape((-1, 1)))
finder = bp.analysis.SlowPointFinder(f_cell=model, target_vars={'u': model.u})
finder.find_fps_with_gd_method(
candidates={'u': candidates},
tolerance=1e-6,
num_batch=200,
optimizer=bp.optim.Adam(lr=bp.optim.ExponentialDecay(0.1, 2, 0.999)),
)
finder.filter_loss(1e-7)
finder.keep_unique()
# finder.exclude_outliers(tolerance=1e1)
print('Losses of fixed points:')
print(finder.losses)
Optimizing with Adam(lr=ExponentialDecay(0.1, decay_steps=2, decay_rate=0.999), last_call=-1), beta1=0.9, beta2=0.999, eps=1e-08) to find fixed points:
Batches 1-200 in 5.00 sec, Training loss 0.0000000000
Stop optimization as mean training loss 0.0000000000 is below tolerance 0.0000010000.
Excluding fixed points with squared speed above tolerance 1e-07:
Kept 1257/1257 fixed points with tolerance under 1e-07.
Excluding non-unique fixed points:
Kept 1257/1257 unique fixed points with uniqueness tolerance 0.025.
Losses of fixed points:
[0. 0. 0. ... 0. 0. 0.]
Visualize fixed points
[6]:
pca = PCA(2)
fp_pcs = pca.fit_transform(finder.fixed_points['u'])
plt.plot(fp_pcs[:, 0], fp_pcs[:, 1], 'x', label='fixed points')
plt.xlabel('PC 1')
plt.ylabel('PC 2')
plt.title('Fixed points PCA')
plt.legend()
plt.show()

[7]:
fps = finder.fixed_points['u']
plot_ids = (10, 100, 200, 300,)
plot_ids = np.asarray(plot_ids)
for i in plot_ids:
    plt.plot(model.x, fps[i], label=f'FP-{i}')
plt.legend()
plt.show()

Verify the stability of the fixed points
[8]:
from jax.tree_util import tree_map
_ = finder.compute_jacobians(
tree_map(lambda x: x[plot_ids], finder._fixed_points),
plot=True
)

Gap junction-coupled FitzHugh-Nagumo Model
[1]:
import brainpy as bp
import brainpy.math as bm
# bp.math.set_platform('cpu')
bp.math.enable_x64()
[2]:
class GJCoupledFHN(bp.DynamicalSystem):
    def __init__(self, num=2, method='exp_auto'):
        super(GJCoupledFHN, self).__init__()
        # parameters
        self.num = num
        self.a = 0.7
        self.b = 0.8
        self.tau = 12.5
        self.gjw = 0.0001
        # variables
        self.V = bm.Variable(bm.random.uniform(-2, 2, num))
        self.w = bm.Variable(bm.random.uniform(-2, 2, num))
        self.Iext = bm.Variable(bm.zeros(num))
        # functions
        self.int_V = bp.odeint(self.dV, method=method)
        self.int_w = bp.odeint(self.dw, method=method)
    def dV(self, V, t, w, Iext=0.):
        # gap-junction current on neuron j: gjw * sum_i (V_i - V_j)
        gj = (V.reshape((-1, 1)) - V).sum(axis=0) * self.gjw
        dV = V - V * V * V / 3 - w + Iext + gj
        return dV
    def dw(self, w, t, V):
        dw = (V + self.a - self.b * w) / self.tau
        return dw
    def update(self, tdi):
        self.V.value = self.int_V(self.V, tdi.t, self.w, self.Iext, tdi.dt)
        self.w.value = self.int_w(self.w, tdi.t, self.V, tdi.dt)
[3]:
def analyze_net(num=2, gjw=0.01, Iext=bm.asarray([0., 0.6])):
    assert isinstance(Iext, (int, float)) or (len(Iext) == num)
    model = GJCoupledFHN(num)
    model.gjw = gjw
    model.Iext[:] = Iext
    # simulation
    runner = bp.DSRunner(model, monitors=['V'])
    runner.run(300.)
    bp.visualize.line_plot(runner.mon.ts, runner.mon.V, legend='V',
                           plot_ids=list(range(model.num)), show=True)
    # analysis
    finder = bp.analysis.SlowPointFinder(f_cell=model,
                                         target_vars={'V': model.V, 'w': model.w})
    finder.find_fps_with_gd_method(
        candidates={'V': bm.random.normal(0., 2., (1000, model.num)),
                    'w': bm.random.normal(0., 2., (1000, model.num))},
        tolerance=1e-5,
        num_batch=200,
        optimizer=bp.optim.Adam(lr=bp.optim.ExponentialDecay(0.05, 1, 0.9999)),
    )
    finder.filter_loss(1e-8)
    finder.keep_unique()
    print('fixed_points: ', finder.fixed_points)
    print('losses:', finder.losses)
    _ = finder.compute_jacobians(finder.fixed_points, plot=True)
4D system
[4]:
analyze_net(num=2, gjw=0.1, Iext=bm.asarray([0., 0.6]))

Optimizing with Adam(lr=ExponentialDecay(0.05, decay_steps=1, decay_rate=0.9999), last_call=-1), beta1=0.9, beta2=0.999, eps=1e-08) to find fixed points:
Batches 1-200 in 0.47 sec, Training loss 0.0002525318
Batches 201-400 in 1.82 sec, Training loss 0.0002021251
Batches 401-600 in 1.06 sec, Training loss 0.0001682558
Batches 601-800 in 0.64 sec, Training loss 0.0001426594
Batches 801-1000 in 0.88 sec, Training loss 0.0001222235
Batches 1001-1200 in 1.86 sec, Training loss 0.0001054805
Batches 1201-1400 in 0.51 sec, Training loss 0.0000914966
Batches 1401-1600 in 0.60 sec, Training loss 0.0000796636
Batches 1601-1800 in 0.75 sec, Training loss 0.0000695541
Batches 1801-2000 in 2.01 sec, Training loss 0.0000608892
Batches 2001-2200 in 0.63 sec, Training loss 0.0000534558
Batches 2201-2400 in 0.63 sec, Training loss 0.0000470366
Batches 2401-2600 in 1.92 sec, Training loss 0.0000414673
Batches 2601-2800 in 0.86 sec, Training loss 0.0000366226
Batches 2801-3000 in 0.44 sec, Training loss 0.0000323996
Batches 3001-3200 in 0.47 sec, Training loss 0.0000287081
Batches 3201-3400 in 0.65 sec, Training loss 0.0000254832
Batches 3401-3600 in 0.53 sec, Training loss 0.0000226843
Batches 3601-3800 in 0.80 sec, Training loss 0.0000202542
Batches 3801-4000 in 1.89 sec, Training loss 0.0000181205
Batches 4001-4200 in 0.64 sec, Training loss 0.0000162401
Batches 4201-4400 in 0.52 sec, Training loss 0.0000145707
Batches 4401-4600 in 2.00 sec, Training loss 0.0000130935
Batches 4601-4800 in 0.67 sec, Training loss 0.0000118004
Batches 4801-5000 in 0.59 sec, Training loss 0.0000106516
Batches 5001-5200 in 0.86 sec, Training loss 0.0000096305
Stop optimization as mean training loss 0.0000096305 is below tolerance 0.0000100000.
Excluding fixed points with squared speed above tolerance 1e-08:
Kept 306/1000 fixed points with tolerance under 1e-08.
Excluding non-unique fixed points:
Kept 1/306 unique fixed points with uniqueness tolerance 0.025.
fixed_points: {'V': array([[-1.1731516 , -0.73801606]]), 'w': array([[-0.59144118, -0.04753834]])}
losses: [6.8918834e-15]

[5]:
analyze_net(num=2, gjw=0.1, Iext=bm.asarray([0., 0.1]))

Optimizing with Adam(lr=ExponentialDecay(0.05, decay_steps=1, decay_rate=0.9999), last_call=-1), beta1=0.9, beta2=0.999, eps=1e-08) to find fixed points:
Batches 1-200 in 1.39 sec, Training loss 0.0002524889
Batches 201-400 in 0.47 sec, Training loss 0.0002039485
Batches 401-600 in 0.39 sec, Training loss 0.0001722068
Batches 601-800 in 0.49 sec, Training loss 0.0001481642
Batches 801-1000 in 0.67 sec, Training loss 0.0001288039
Batches 1001-1200 in 0.69 sec, Training loss 0.0001127157
Batches 1201-1400 in 2.42 sec, Training loss 0.0000990425
Batches 1401-1600 in 0.71 sec, Training loss 0.0000872609
Batches 1601-1800 in 0.45 sec, Training loss 0.0000770819
Batches 1801-2000 in 1.56 sec, Training loss 0.0000682273
Batches 2001-2200 in 1.07 sec, Training loss 0.0000605319
Batches 2201-2400 in 0.59 sec, Training loss 0.0000538263
Batches 2401-2600 in 0.38 sec, Training loss 0.0000479664
Batches 2601-2800 in 1.37 sec, Training loss 0.0000428171
Batches 2801-3000 in 0.77 sec, Training loss 0.0000382962
Batches 3001-3200 in 0.53 sec, Training loss 0.0000343242
Batches 3201-3400 in 0.39 sec, Training loss 0.0000307939
Batches 3401-3600 in 1.54 sec, Training loss 0.0000276377
Batches 3601-3800 in 0.71 sec, Training loss 0.0000248339
Batches 3801-4000 in 0.41 sec, Training loss 0.0000223625
Batches 4001-4200 in 0.46 sec, Training loss 0.0000202066
Batches 4201-4400 in 0.58 sec, Training loss 0.0000182917
Batches 4401-4600 in 0.47 sec, Training loss 0.0000165570
Batches 4601-4800 in 1.13 sec, Training loss 0.0000150186
Batches 4801-5000 in 0.96 sec, Training loss 0.0000136606
Batches 5001-5200 in 0.50 sec, Training loss 0.0000124335
Batches 5201-5400 in 0.35 sec, Training loss 0.0000113094
Batches 5401-5600 in 1.40 sec, Training loss 0.0000103006
Batches 5601-5800 in 0.83 sec, Training loss 0.0000093835
Stop optimization as mean training loss 0.0000093835 is below tolerance 0.0000100000.
Excluding fixed points with squared speed above tolerance 1e-08:
Kept 485/1000 fixed points with tolerance under 1e-08.
Excluding non-unique fixed points:
Kept 1/485 unique fixed points with uniqueness tolerance 0.025.
fixed_points: {'V': array([[-1.19613916, -1.14106972]]), 'w': array([[-0.62017396, -0.55133715]])}
losses: [2.01213341e-24]

8D system
[6]:
analyze_net(num=4, gjw=0.1, Iext=bm.asarray([0., 0., 0., 0.6]))

Optimizing with Adam(lr=ExponentialDecay(0.05, decay_steps=1, decay_rate=0.9999), last_call=-1), beta1=0.9, beta2=0.999, eps=1e-08) to find fixed points:
Batches 1-200 in 0.58 sec, Training loss 0.0002563626
Batches 201-400 in 0.63 sec, Training loss 0.0002068401
Batches 401-600 in 1.46 sec, Training loss 0.0001746516
Batches 601-800 in 0.53 sec, Training loss 0.0001505581
Batches 801-1000 in 0.34 sec, Training loss 0.0001312724
Batches 1001-1200 in 0.39 sec, Training loss 0.0001152650
Batches 1201-1400 in 0.57 sec, Training loss 0.0001017105
Batches 1401-1600 in 0.36 sec, Training loss 0.0000900637
Batches 1601-1800 in 1.23 sec, Training loss 0.0000799450
Batches 1801-2000 in 0.90 sec, Training loss 0.0000710771
Batches 2001-2200 in 0.58 sec, Training loss 0.0000632785
Batches 2201-2400 in 0.39 sec, Training loss 0.0000564160
Batches 2401-2600 in 1.61 sec, Training loss 0.0000503661
Batches 2601-2800 in 0.55 sec, Training loss 0.0000449990
Batches 2801-3000 in 0.51 sec, Training loss 0.0000402453
Batches 3001-3200 in 0.72 sec, Training loss 0.0000360206
Batches 3201-3400 in 1.17 sec, Training loss 0.0000322757
Batches 3401-3600 in 0.47 sec, Training loss 0.0000289431
Batches 3601-3800 in 0.44 sec, Training loss 0.0000259914
Batches 3801-4000 in 0.49 sec, Training loss 0.0000233745
Batches 4001-4200 in 1.22 sec, Training loss 0.0000210534
Batches 4201-4400 in 0.56 sec, Training loss 0.0000189919
Batches 4401-4600 in 0.30 sec, Training loss 0.0000171629
Batches 4601-4800 in 0.33 sec, Training loss 0.0000155418
Batches 4801-5000 in 0.46 sec, Training loss 0.0000141007
Batches 5001-5200 in 0.32 sec, Training loss 0.0000128158
Batches 5201-5400 in 0.51 sec, Training loss 0.0000116671
Batches 5401-5600 in 1.17 sec, Training loss 0.0000106446
Batches 5601-5800 in 0.57 sec, Training loss 0.0000097291
Stop optimization as mean training loss 0.0000097291 is below tolerance 0.0000100000.
Excluding fixed points with squared speed above tolerance 1e-08:
Kept 266/1000 fixed points with tolerance under 1e-08.
Excluding non-unique fixed points:
Kept 1/266 unique fixed points with uniqueness tolerance 0.025.
fixed_points: {'V': array([[-1.17905465, -1.19006061, -1.17904012, -0.81757202]]), 'w': array([[-0.59764936, -0.58889147, -0.59766085, -0.14496394]])}
losses: [5.78896779e-09]

Hénon map
The Hénon map is a discrete-time dynamical system and one of the most studied examples of a dynamical system that exhibits chaotic behavior. The Hénon map takes a point \((x_n, y_n)\) in the plane and maps it to a new point
\[ x_{n+1} = 1 - a x_n^2 + y_n, \qquad y_{n+1} = b x_n. \]
[1]:
import brainpy as bp
import brainpy.math as bm
import matplotlib.pyplot as plt
[2]:
class HenonMap(bp.DynamicalSystem):
    """Hénon map."""
    def __init__(self, num, a=1.4, b=0.3):
        super(HenonMap, self).__init__()
        # parameters
        self.a = a
        self.b = b
        self.num = num
        # variables
        self.x = bm.Variable(bm.zeros(num))
        self.y = bm.Variable(bm.zeros(num))
    def update(self, tdi):
        # both new values are computed from the old (x_n, y_n)
        x_new = 1 - self.a * self.x * self.x + self.y
        self.y.value = self.b * self.x
        self.x.value = x_new
[3]:
map = HenonMap(4)
map.a = bm.asarray([0.5, 1.0, 1.4, 2.0])
[4]:
runner = bp.DSRunner(map, monitors=['x', 'y'], dt=1.)
runner.run(10000)
[5]:
fig, gs = bp.visualize.get_figure(4, 1, 4, 6)
for i in range(map.num):
    fig.add_subplot(gs[i, 0])
    plt.plot(runner.mon.x[:, i], runner.mon.y[:, i], '.k')
    plt.xlabel('x')
    plt.ylabel('y')
    plt.title(f'a={map.a[i]}')
plt.show()

The strange attractor illustrated above is obtained for \(a=1.4\) and \(b=0.3\).
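The fixed points of the map at these classic parameters can also be worked out by hand. Setting \(x_{n+1} = x_n = x^*\) and \(y_{n+1} = y_n = y^*\) gives \(a x^{*2} + (1 - b) x^* - 1 = 0\) and \(y^* = b x^*\). The plain NumPy sketch below (independent of BrainPy; the helper names are illustrative) solves this quadratic and inspects the eigenvalues of the Jacobian \(\begin{pmatrix} -2ax & 1 \\ b & 0 \end{pmatrix}\). Its determinant is \(-b\), so each iteration scales areas by \(|b|\); this is why the two cells below, with \(b = 0.9991\) and \(b = -0.9999\), are nearly area-preserving.

```python
import numpy as np


def henon(x, y, a=1.4, b=0.3):
    """One iteration of the Henon map (same update as HenonMap.update)."""
    return 1.0 - a * x * x + y, b * x


def henon_fixed_points(a=1.4, b=0.3):
    """Solve a*x^2 + (1 - b)*x - 1 = 0 for x*, then set y* = b*x*."""
    disc = np.sqrt((1.0 - b) ** 2 + 4.0 * a)
    xs = [(-(1.0 - b) + disc) / (2.0 * a), (-(1.0 - b) - disc) / (2.0 * a)]
    return [(x, b * x) for x in xs]


def henon_jacobian(x, a=1.4, b=0.3):
    """Jacobian of (x, y) -> (1 - a x^2 + y, b x); its determinant is -b."""
    return np.array([[-2.0 * a * x, 1.0],
                     [b, 0.0]])


for x_star, y_star in henon_fixed_points():
    moduli = np.sort(np.abs(np.linalg.eigvals(henon_jacobian(x_star))))
    print(f"x*={x_star:.4f}, y*={y_star:.4f}, |eigenvalues|={moduli}")
```

For \(a = 1.4\), \(b = 0.3\) both fixed points have one eigenvalue inside and one outside the unit circle, i.e. both are saddles.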
[6]:
map = HenonMap(1)
map.a = 0.2
map.b = 0.9991
runner = bp.DSRunner(map, monitors=['x', 'y'], dt=1.)
runner.run(10000)
plt.plot(runner.mon.x[:, 0], runner.mon.y[:, 0], ',k')
plt.xlabel('x')
plt.ylabel('y')
plt.xlim([-5, 5])
plt.ylim([-5, 5])
plt.show()

[7]:
map = HenonMap(1)
map.a = 0.2
map.b = -0.9999
runner = bp.DSRunner(map, monitors=['x', 'y'], dt=1.)
runner.run(10000)
plt.plot(runner.mon.x[:, 0], runner.mon.y[:, 0], ',k')
plt.xlabel('x')
plt.ylabel('y')
plt.show()

Logistic map
The logistic map is a one-dimensional discrete-time dynamical system defined by the equation
\[ x_{n+1} = \mu x_n (1 - x_n). \]
For an initial value \(0\leq x_{0}\leq1\) this map generates a sequence of values \(x_{0},x_{1},\ldots,x_{n},x_{n+1},\ldots\) The growth parameter is chosen to be
\[ 0 < \mu \leq 4, \]
which implies that for all \(n\) the state variable remains bounded in the unit interval. Despite its simplicity, this famous dynamical system can exhibit remarkably rich dynamics.
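Two of these statements are easy to check numerically. The map sends \([0, 1]\) into \([0, \mu/4] \subseteq [0, 1]\) whenever \(0 \leq \mu \leq 4\), and its nonzero fixed point \(x^* = 1 - 1/\mu\) is stable only while \(|f'(x^*)| = |2 - \mu| < 1\), i.e. for \(1 < \mu < 3\). Past \(\mu = 3\) a period-2 orbit takes over, the first step of the period-doubling cascade seen in the diagram below. The plain-Python sketch here is independent of BrainPy, and `logistic_orbit` is an illustrative name:

```python
import numpy as np


def logistic_orbit(mu, x0=0.2, n=1100):
    """Iterate x_{n+1} = mu * x_n * (1 - x_n) and return all n + 1 states."""
    xs = np.empty(n + 1)
    xs[0] = x0
    for i in range(n):
        xs[i + 1] = mu * xs[i] * (1.0 - xs[i])
    return xs


# Boundedness: for 0 <= mu <= 4 an orbit started in [0, 1] stays in [0, 1].
for mu in (0.5, 2.0, 3.5, 4.0):
    orbit = logistic_orbit(mu)
    print(f"mu={mu}: min={orbit.min():.4f}, max={orbit.max():.4f}")

# mu = 2.8 lies in (1, 3): the orbit settles on x* = 1 - 1/mu.
print("mu=2.8 tail:", logistic_orbit(2.8)[-3:], "x* =", 1 - 1 / 2.8)

# mu = 3.2 lies past 3: the orbit alternates between two values (period 2).
print("mu=3.2 tail:", logistic_orbit(3.2)[-4:])
```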
[6]:
import brainpy as bp
import brainpy.math as bm
import matplotlib.pyplot as plt
[7]:
class LogisticMap(bp.NeuGroup):
    def __init__(self, num, mu0=2., mu1=4.):
        super(LogisticMap, self).__init__(num)
        self.mu = bm.linspace(mu0, mu1, num)  # one growth parameter per unit
        self.x = bm.Variable(bm.ones(num) * 0.2)
    def update(self, tdi):
        self.x.value = self.mu * self.x * (1 - self.x)
[8]:
map = LogisticMap(10000, 2, 4)
[9]:
runner = bp.DSRunner(map, monitors=['x'], dt=1.)
runner.run(1100)
[10]:
plt.plot(map.mu, runner.mon.x[1000:].T, ',k', alpha=0.25)
plt.xlabel('mu')
plt.ylabel('x')
plt.show()

Lorenz system
The Lorenz system, originally intended as a simplified model of atmospheric convection, has instead become a standard example of sensitive dependence on initial conditions; that is, tiny differences in the starting condition for the system rapidly become magnified. The system also exhibits what is known as the “Lorenz attractor”, that is, the collection of trajectories for different starting points tends to approach a peculiar butterfly-shaped region.
The Lorenz system includes three ordinary differential equations:
\[ \frac{dx}{dt} = \sigma (y - x) \]
\[ \frac{dy}{dt} = x (\rho - z) - y \]
\[ \frac{dz}{dt} = x y - \beta z \]
where the parameters \(\beta\), \(\rho\), and \(\sigma\) are usually assumed to be positive. The classic case uses the parameter values \(\beta = 8/3\), \(\rho = 28\), and \(\sigma = 10\).
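The sensitive dependence on initial conditions described above can be seen in a few lines of NumPy. This sketch uses a hand-written classical fourth-order Runge-Kutta step (not BrainPy's `exp_auto` integrator used in the cells that follow). It integrates two copies of the system with the classic parameters, one starting at \((8, 1, 1)\) as in the runner below and one shifted by \(10^{-9}\) in \(x\), and records their distance. The function names are illustrative.

```python
import numpy as np

SIGMA, BETA, RHO = 10.0, 8.0 / 3.0, 28.0  # classic parameter values


def lorenz(state):
    """Right-hand side of the Lorenz equations."""
    x, y, z = state
    return np.array([SIGMA * (y - x), x * (RHO - z) - y, x * y - BETA * z])


def rk4_step(state, dt):
    """One classical fourth-order Runge-Kutta step."""
    k1 = lorenz(state)
    k2 = lorenz(state + 0.5 * dt * k1)
    k3 = lorenz(state + 0.5 * dt * k2)
    k4 = lorenz(state + dt * k3)
    return state + dt / 6.0 * (k1 + 2.0 * k2 + 2.0 * k3 + k4)


def separation_history(x0, eps=1e-9, dt=0.01, steps=5000):
    """Distance between two trajectories whose initial x differs by eps."""
    a = np.array(x0, dtype=float)
    b = a + np.array([eps, 0.0, 0.0])
    seps = np.empty(steps + 1)
    seps[0] = np.linalg.norm(a - b)
    for i in range(steps):
        a = rk4_step(a, dt)
        b = rk4_step(b, dt)
        seps[i + 1] = np.linalg.norm(a - b)
    return seps


seps = separation_history([8.0, 1.0, 1.0])
print(f"initial separation: {seps[0]:.1e}, largest separation: {seps.max():.1f}")
```

Over 50 time units the separation grows by many orders of magnitude, until it is comparable to the size of the attractor itself.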
[1]:
import brainpy as bp
import matplotlib.pyplot as plt
[2]:
sigma = 10
beta = 8 / 3
rho = 28
[3]:
dx = lambda x, t, y: sigma * (y - x)
dy = lambda y, t, x, z: x * (rho - z) - y
dz = lambda z, t, x, y: x * y - beta * z
[4]:
integral = bp.odeint(bp.JointEq([dx, dy, dz]), method='exp_auto')
[5]:
runner = bp.IntegratorRunner(integral,
monitors=['x', 'y', 'z'],
inits=dict(x=8, y=1, z=1),
dt=0.01)
runner.run(100)
[6]:
fig = plt.figure()
fig.add_subplot(111, projection='3d')
plt.plot(runner.mon.x[100:, 0], runner.mon.y[100:, 0], runner.mon.z[100:, 0])
plt.show()

Mackey-Glass equation
The Mackey-Glass equation is the nonlinear time delay differential equation
\[ \frac{dx}{dt} = \beta \frac{x_{\tau}}{1 + x_{\tau}^{n}} - \gamma x, \]
where \(\beta, \gamma, \tau, n\) are real numbers, and \(x_{\tau}\) represents the value of the variable \(x\) at time \((t-\tau)\). Depending on the values of the parameters, this equation displays a range of periodic and chaotic dynamics.
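What sets this model apart is the delay term: every step needs the value of \(x\) from \(\tau\) time units earlier, so a simulation must keep a buffer of the last \(\tau / dt\) states. In the notebook below this role is played by `bm.LengthDelay`. The pure-Python sketch here shows the same bookkeeping with a fixed-length `collections.deque` and a forward-Euler step. The constant initial history \(x_0 = 1.2\) and the Euler scheme are simplifying assumptions; the notebook uses a randomized history around 1.2 and the `exp_auto` integrator. The function name is illustrative.

```python
from collections import deque

import numpy as np


def mackey_glass_euler(beta=0.2, gamma=0.1, tau=17.0, n=10.0,
                       dt=0.05, duration=1000.0, x0=1.2):
    """Forward-Euler integration of dx/dt = beta*x_tau/(1 + x_tau**n) - gamma*x.

    Returns (xs, x_taus): xs[i] is x at time (i + 1) * dt and x_taus[i] is the
    delayed value x(i * dt - tau) that was used to compute it.
    """
    delay_steps = int(round(tau / dt))
    # Buffer holds x(t - tau), ..., x(t - dt); the history before t = 0 is constant.
    history = deque([x0] * delay_steps, maxlen=delay_steps)
    x = x0
    num_steps = int(round(duration / dt))
    xs = np.empty(num_steps)
    x_taus = np.empty(num_steps)
    for i in range(num_steps):
        x_tau = history[0]   # oldest entry is x(t - tau)
        history.append(x)    # store x(t); maxlen drops x(t - tau)
        x = x + dt * (beta * x_tau / (1.0 + x_tau ** n) - gamma * x)
        xs[i] = x
        x_taus[i] = x_tau
    return xs, x_taus


xs, x_taus = mackey_glass_euler()
print(f"x stays in [{xs.min():.3f}, {xs.max():.3f}]")
```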
[1]:
import brainpy as bp
import brainpy.math as bm
import matplotlib.pyplot as plt
[2]:
bm.set_dt(0.05)
[3]:
class MackeyGlassEq(bp.NeuGroup):
    def __init__(self, num, beta=2., gamma=1., tau=2., n=9.65):
        super(MackeyGlassEq, self).__init__(num)
        # parameters
        self.beta = beta
        self.gamma = gamma
        self.tau = tau
        self.n = n
        self.delay_len = int(tau/bm.get_dt())  # delay length in time steps
        # variables
        self.x = bm.Variable(bm.zeros(num))
        self.x_delay = bm.LengthDelay(
            self.x,
            delay_len=self.delay_len,
            initial_delay_data=lambda sh, dtype: 1.2+0.2*(bm.random.random(sh)-0.5)
        )
        self.x_oldest = bm.Variable(self.x_delay(self.delay_len))
        # functions
        self.integral = bp.odeint(lambda x, t, x_tau: self.beta * x_tau / (1 + x_tau ** self.n) - self.gamma * x,
                                  method='exp_auto')
    def update(self, tdi):
        self.x.value = self.integral(self.x.value, tdi.t, self.x_oldest.value, tdi.dt)
        self.x_delay.update(self.x.value)
        self.x_oldest.value = self.x_delay(self.delay_len)
[4]:
eq = MackeyGlassEq(1, beta=0.2, gamma=0.1, tau=17, n=10)
# eq = MackeyGlassEq(1, )
[5]:
runner = bp.DSRunner(eq, monitors=['x', 'x_oldest'])
runner.run(1000)
[6]:
plt.plot(runner.mon.x[1000:, 0], runner.mon.x_oldest[1000:, 0])
plt.show()

Multiscroll chaotic attractor
Multiscroll attractors, also called n-scroll attractors, contain multiple scrolls in a single attractor. Examples include the Chen attractor, the Lu Chen attractor, the modified Chen chaotic attractor, the PWL Duffing attractor, the Rabinovich-Fabrikant attractor, and the modified Chua chaotic attractor.
[1]:
import brainpy as bp
import brainpy.math as bm
import matplotlib.pyplot as plt
[2]:
def run_and_visualize(runner, duration=100, dim=3, args=None):
    assert dim in [3, 2]
    if args is None:
        runner.run(duration)
    else:
        runner.run(duration, args=args)
    # the first 100 recorded samples are skipped as a transient
    if dim == 3:
        fig = plt.figure()
        ax = fig.add_subplot(111, projection='3d')
        for i in range(runner.mon.x.shape[1]):
            plt.plot(runner.mon.x[100:, i], runner.mon.y[100:, i], runner.mon.z[100:, i])
        ax.set_xlabel('x')
        ax.set_ylabel('y')
        ax.set_zlabel('z')
    else:
        for i in range(runner.mon.x.shape[1]):
            plt.plot(runner.mon.x[100:, i], runner.mon.y[100:, i])
        plt.xlabel('x')
        plt.ylabel('y')
    plt.show()
Chen attractor
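The Chen system is defined as follows [1]:
\[ \frac{dx}{dt} = a (y - x), \qquad \frac{dy}{dt} = (c - a) x - x z + c y, \qquad \frac{dz}{dt} = x y - b z. \]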
[3]:
@bp.odeint(method='euler')
def chen_system(x, y, z, t, a=40, b=3, c=28):
    dx = a * (y - x)
    dy = (c - a) * x - x * z + c * y
    dz = x * y - b * z
    return dx, dy, dz
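Before running the system, its equilibria can be worked out by hand. Setting the right-hand side of `chen_system` to zero gives \(y = x\), \(z = x^2 / b\), and \(x(2c - a - z) = 0\). So besides the origin there are two symmetric equilibria \((\pm\sqrt{b(2c-a)}, \pm\sqrt{b(2c-a)}, 2c-a)\) whenever \(b(2c - a) > 0\), and the two scrolls of the attractor plotted in the next cell wind around this pair. The NumPy sketch below, which is independent of BrainPy and uses illustrative helper names, checks this for the default parameters \(a = 40\), \(b = 3\), \(c = 28\).

```python
import numpy as np


def chen_rhs(state, a=40.0, b=3.0, c=28.0):
    """Right-hand side of the Chen system (same equations as chen_system)."""
    x, y, z = state
    return np.array([a * (y - x), (c - a) * x - x * z + c * y, x * y - b * z])


def chen_equilibria(a=40.0, b=3.0, c=28.0):
    """Origin plus the symmetric pair (+-r, +-r, 2c - a), r = sqrt(b(2c - a))."""
    points = [np.zeros(3)]
    if b * (2 * c - a) > 0:
        r = np.sqrt(b * (2 * c - a))
        points.append(np.array([r, r, 2 * c - a]))
        points.append(np.array([-r, -r, 2 * c - a]))
    return points


for p in chen_equilibria():
    print(p, "residual:", np.linalg.norm(chen_rhs(p)))
```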
[4]:
runner = bp.IntegratorRunner(
chen_system,
monitors=['x', 'y', 'z'],
inits=dict(x=-0.1, y=0.5, z=-0.6),
dt=0.001
)
run_and_visualize(runner, 100)

Lu Chen attractor
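An extended Chen system with multiple scrolls was proposed by Jinhu Lu (吕金虎) and Guanrong Chen [2]:
\[ \frac{dx}{dt} = a (y - x), \qquad \frac{dy}{dt} = x - x z + c y + u, \qquad \frac{dz}{dt} = x y - b z. \]
When \(u \le -11\), the Lu Chen attractor is a left-scroll chaotic attractor;
when \(u\) lies between \(-10\) and \(10\), it is a twisted (braid-shaped) attractor;
when \(u \ge 11\), it is a right-scroll chaotic attractor.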
[5]:
@bp.odeint(method='rk4')
def lu_chen_system(x, y, z, t, a=36, c=20, b=3, u=-15.15):
    dx = a * (y - x)
    dy = x - x * z + c * y + u
    dz = x * y - b * z
    return dx, dy, dz
[6]:
runner = bp.IntegratorRunner(
lu_chen_system,
monitors=['x', 'y', 'z'],
inits=dict(x=0.1, y=0.3, z=-0.6),
dt=0.002
)
run_and_visualize(runner, 100, args=dict(u=-15.15),)
runner = bp.IntegratorRunner(
lu_chen_system,
monitors=['x', 'y', 'z'],
inits=dict(x=0.1, y=0.3, z=-0.6),
dt=0.002
)
run_and_visualize(runner, 100, args=dict(u=-8),)


[7]:
runner = bp.IntegratorRunner(
lu_chen_system,
monitors=['x', 'y', 'z'],
inits=dict(x=0.1, y=0.3, z=-0.6),
dt=0.002
)
run_and_visualize(runner, 100, args=dict(u=11),)

[8]:
runner = bp.IntegratorRunner(
lu_chen_system,
monitors=['x', 'y', 'z'],
inits=dict(x=0.1, y=0.3, z=-0.6),
dt=0.002
)
run_and_visualize(runner, 100, args=dict(u=12),)

Modified Lu Chen attractor
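System equations:
\[ \frac{dx}{dt} = a (y - x), \qquad \frac{dy}{dt} = (c - a) x - x f + c y, \qquad \frac{dz}{dt} = x y - b z, \]
where
\[ f = d_0 z + d_1 z(t - \tau) - d_2 \sin\big(z(t - \tau)\big). \]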
[9]:
class ModifiedLuChenSystem(bp.dyn.NeuGroup):
    def __init__(self, num, a=35, b=3, c=28, d0=1, d1=1, d2=0., tau=.2, dt=0.1):
        super(ModifiedLuChenSystem, self).__init__(num)
        # parameters
        self.a = a
        self.b = b
        self.c = c
        self.d0 = d0
        self.d1 = d1
        self.d2 = d2
        self.tau = tau
        self.delay_len = int(tau/dt)  # dt should match the runner's dt
        # variables
        self.z = bm.Variable(bm.ones(num) * 14)
        self.z_delay = bm.LengthDelay(self.z,
                                      delay_len=self.delay_len,
                                      initial_delay_data=14.)
        self.x = bm.Variable(bm.ones(num))
        self.y = bm.Variable(bm.ones(num))
        # functions
        def derivative(x, y, z, t, z_delay):
            dx = self.a * (y - x)
            f = self.d0 * z + self.d1 * z_delay - self.d2 * bm.sin(z_delay)
            dy = (self.c - self.a) * x - x * f + self.c * y
            dz = x * y - self.b * z
            return dx, dy, dz
        self.integral = bp.odeint(derivative, method='rk4')
    def update(self, tdi):
        self.x.value, self.y.value, self.z.value = self.integral(
            self.x.value, self.y.value, self.z.value, tdi.t,
            self.z_delay(self.delay_len), tdi.dt
        )
        self.z_delay.update(self.z)
[10]:
runner = bp.DSRunner(ModifiedLuChenSystem(1, dt=0.001),
monitors=['x', 'y', 'z'],
dt=0.001)
run_and_visualize(runner, 50)

Chua’s system
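The classic Chua’s system is described by the following dimensionless equations [3]
\[ \frac{dx}{dt} = \alpha \big(y - x - f(x)\big), \qquad \frac{dy}{dt} = x - y + z, \qquad \frac{dz}{dt} = -\beta y - \gamma z, \]
where \(\alpha\), \(\beta\), and \(\gamma\) are constant parameters and
\[ f(x) = b x + \frac{1}{2}(a - b)\big(|x + 1| - |x - 1|\big), \]
with \(a,b\) being the slopes of the inner and outer segments of \(f(x)\). Note that the parameter \(\gamma\) is usually ignored; that is, \(\gamma=0\).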
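The piecewise-linear \(f(x)\) can be sanity-checked directly. For \(|x| < 1\) the term \(|x + 1| - |x - 1|\) equals \(2x\), so \(f(x) = a x\); for \(|x| > 1\) it is the constant \(\pm 2\), so \(f\) has slope \(b\). Taking \(\gamma = 0\), the last two equations give \(y = 0\) and \(z = -x\) at an equilibrium, and the first then requires \(x = -f(x)\). This yields the origin and the symmetric pair \(x = \pm (b - a)/(1 + b)\), valid when that value exceeds 1. The NumPy sketch below, with illustrative function names, checks both slopes and these equilibria for the default parameters of `chua_system`.

```python
import numpy as np


def chua_f(x, a=-1.197, b=-0.6464):
    """Piecewise-linear Chua nonlinearity: slope a on |x| < 1, slope b on |x| > 1."""
    return b * x + 0.5 * (a - b) * (np.abs(x + 1) - np.abs(x - 1))


def chua_rhs(state, alpha=10.0, beta=14.514, gamma=0.0, a=-1.197, b=-0.6464):
    """Right-hand side alpha*(y - x - f(x)), x - y + z, -beta*y - gamma*z."""
    x, y, z = state
    return np.array([alpha * (y - x - chua_f(x, a, b)),
                     x - y + z,
                     -beta * y - gamma * z])


a, b = -1.197, -0.6464
inner_slope = chua_f(0.5, a, b) - chua_f(-0.5, a, b)  # interval of length 1
outer_slope = chua_f(3.0, a, b) - chua_f(2.0, a, b)   # interval of length 1
print("inner slope:", inner_slope, "outer slope:", outer_slope)

# Equilibria for gamma = 0: the origin and (+-x_e, 0, -+x_e).
x_e = (b - a) / (1 + b)
for p in (np.zeros(3), np.array([x_e, 0.0, -x_e]), np.array([-x_e, 0.0, x_e])):
    print(p, "residual:", np.linalg.norm(chua_rhs(p)))
```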
[11]:
@bp.odeint(method='rk4')
def chua_system(x, y, z, t, alpha=10, beta=14.514, gamma=0, a=-1.197, b=-0.6464):
    # piecewise-linear nonlinearity: slope a for |x| < 1, slope b for |x| > 1
    fx = b * x + 0.5 * (a - b) * (bm.abs(x + 1) - bm.abs(x - 1))
    dx = alpha * (y - x) - alpha * fx
    dy = x - y + z
    dz = -beta * y - gamma * z
    return dx, dy, dz
[12]:
runner = bp.IntegratorRunner(
chua_system,
monitors=['x', 'y', 'z'],
inits=dict(x=0.001, y=0, z=0.),
dt=0.002
)
run_and_visualize(runner, 100)

[13]:
runner = bp.IntegratorRunner(
chua_system,
monitors=['x', 'y', 'z'],
inits=dict(x=9.4287, y=-0.5945, z=-13.4705),
dt=0.002
)
run_and_visualize(runner, 100, args=dict(alpha=8.8, beta=12.0732, gamma=0.0052, a=-0.1768, b=-1.1468),)

[14]:
runner = bp.IntegratorRunner(
chua_system,
monitors=['x', 'y', 'z'],
inits=dict(x=[-6.0489, 6.0489],
y=[0.0839, -0.0839],
z=[8.7739, -8.7739]),
dt=0.002
)
run_and_visualize(runner, 100, args=dict(alpha=8.4562, beta=12.0732, gamma=0.0052, a=-0.1768, b=-1.1468),)

Modified Chua chaotic attractor
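In 2001, Tang et al. proposed a modified Chua chaotic system:
\[\dot{x} = \alpha\,(y - f(x)), \quad \dot{y} = x - y + z, \quad \dot{z} = -\beta y,\]
where
\[f(x) = -b \sin\left(\frac{\pi x}{2a} + d\right).\]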
[15]:
@bp.odeint(method='rk4')
def modified_chua_system(x, y, z, t, alpha=10.82, beta=14.286, a=1.3, b=.11, d=0):
    dx = alpha * (y + b * bm.sin(bm.pi * x / 2 / a + d))
    dy = x - y + z
    dz = -beta * y
    return dx, dy, dz
[16]:
runner = bp.IntegratorRunner(
modified_chua_system,
monitors=['x', 'y', 'z'],
inits=dict(x=1, y=1, z=0.),
dt=0.01
)
run_and_visualize(runner, 1000)

PWL Duffing chaotic attractor
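Aziz-Alaoui investigated the piecewise-linear (PWL) Duffing equation in 2000:
\[\dot{x} = y, \quad \dot{y} = -m_1 x - \tfrac{1}{2}(m_0 - m_1)\left(|x + 1| - |x - 1|\right) - e y + \gamma \cos(\omega t),\]
where, in the code below, the forcing amplitude is parameterized as \(\gamma = 0.14 + i/20\).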
[17]:
@bp.odeint(method='rk4')
def PWL_duffing_eq(x, y, t, e=0.25, m0=-0.0845, m1=0.66, omega=1, i=-14):
    gamma = 0.14 + i / 20  # forcing amplitude
    dx = y
    dy = -m1 * x - (0.5 * (m0 - m1)) * (abs(x + 1) - abs(x - 1)) - e * y + gamma * bm.cos(omega * t)
    return dx, dy
[18]:
runner = bp.IntegratorRunner(
PWL_duffing_eq,
monitors=['x', 'y'],
inits=dict(x=0, y=0),
dt=0.01
)
run_and_visualize(runner, 1000, dim=2)
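Unlike the other chaotic systems above, the PWL Duffing equation is non-autonomous: the forcing \(\gamma\cos(\omega t)\) depends explicitly on \(t\). An RK4 integrator must therefore evaluate the right-hand side at \(t\), \(t + dt/2\), and \(t + dt\) within each step. The sketch below is a generic classical RK4 step in pure Python; `rk4_step` and `pwl_duffing` are our names. We assume, without having checked BrainPy's source, that `method='rk4'` uses the same classical tableau.

```python
import math


def rk4_step(f, state, t, dt):
    """One classical RK4 step for d(state)/dt = f(*state, t)."""
    k1 = f(*state, t)
    k2 = f(*[s + 0.5 * dt * k for s, k in zip(state, k1)], t + 0.5 * dt)
    k3 = f(*[s + 0.5 * dt * k for s, k in zip(state, k2)], t + 0.5 * dt)
    k4 = f(*[s + dt * k for s, k in zip(state, k3)], t + dt)
    return [s + dt / 6.0 * (p + 2 * q + 2 * r + w)
            for s, p, q, r, w in zip(state, k1, k2, k3, k4)]


def pwl_duffing(x, y, t, e=0.25, m0=-0.0845, m1=0.66, omega=1.0, i=-14):
    """Same right-hand side as PWL_duffing_eq above, using math instead of bm."""
    gamma = 0.14 + i / 20
    dy = (-m1 * x - 0.5 * (m0 - m1) * (abs(x + 1) - abs(x - 1))
          - e * y + gamma * math.cos(omega * t))
    return y, dy


# usage: integrate from (x, y) = (0, 0) with dt = 0.01, as in cell [18]
dt = 0.01
state = [0.0, 0.0]
for k in range(10000):
    state = rk4_step(pwl_duffing, state, k * dt, dt)
```

Passing the stage times explicitly is the only change relative to an autonomous RK4 step. For a purely time-dependent right-hand side, the scheme reduces to Simpson's rule.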

Modified Lorenz chaotic system
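Miranda & Stone proposed a modified Lorenz system:
\[\dot{x} = \tfrac{1}{3}\left(-(a+1)x + a - c + zy\right) + \frac{(1-a)(x^2 - y^2) + 2(a + c - z)xy}{3\sqrt{x^2 + y^2}},\]
\[\dot{y} = \tfrac{1}{3}\left((c - a - z)x - (a+1)y\right) + \frac{2(a-1)xy + (a + c - z)(x^2 - y^2)}{3\sqrt{x^2 + y^2}},\]
\[\dot{z} = \tfrac{1}{2}\left(3x^2 y - y^3\right) - bz.\]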
[19]:
@bp.odeint(method='euler')
def modified_Lorenz(x, y, z, t, a=10, b=8 / 3, c=137 / 5):
    temp = 3 * bm.sqrt(x * x + y * y)
    dx = (-(a + 1) * x + a - c + z * y) / 3 + ((1 - a) * (x * x - y * y) + (2 * (a + c - z)) * x * y) / temp
    dy = ((c - a - z) * x - (a + 1) * y) / 3 + (2 * (a - 1) * x * y + (a + c - z) * (x * x - y * y)) / temp
    dz = (3 * x * x * y - y * y * y) / 2 - b * z
    return dx, dy, dz
[20]:
runner = bp.IntegratorRunner(
modified_Lorenz,
monitors=['x', 'y', 'z'],
inits=dict(x=-8, y=4, z=10),
dt=0.001
)
run_and_visualize(runner, 100, dim=3)
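Both \(\dot{x}\) and \(\dot{y}\) contain the factor \(1/(3\sqrt{x^2+y^2})\), so the vector field is undefined on the \(z\)-axis (\(x = y = 0\)). Trajectories should therefore not be started there. Cell [19] integrates the system with `method='euler'`, i.e. the explicit (forward) Euler update \(u_{n+1} = u_n + dt\,F(u_n, t_n)\). Below is a pure-Python sketch of the right-hand side and one such step; the function names are ours.

```python
import math


def modified_lorenz_rhs(x, y, z, t, a=10.0, b=8 / 3, c=137 / 5):
    """Same right-hand side as modified_Lorenz above, with math instead of bm."""
    temp = 3 * math.sqrt(x * x + y * y)   # zero on the z-axis -> ZeroDivisionError
    dx = (-(a + 1) * x + a - c + z * y) / 3 + ((1 - a) * (x * x - y * y) + 2 * (a + c - z) * x * y) / temp
    dy = ((c - a - z) * x - (a + 1) * y) / 3 + (2 * (a - 1) * x * y + (a + c - z) * (x * x - y * y)) / temp
    dz = (3 * x * x * y - y * y * y) / 2 - b * z
    return dx, dy, dz


def euler_step(f, state, t, dt):
    """Explicit Euler: new_state = state + dt * f(*state, t)."""
    return [s + dt * d for s, d in zip(state, f(*state, t))]


# one step from the initial condition used in cell [20]
state = euler_step(modified_lorenz_rhs, [-8.0, 4.0, 10.0], 0.0, 0.001)
```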

References
[1] Chen, Guanrong; Ueta, Tetsushi (July 1999). “Yet Another Chaotic Attractor”. International Journal of Bifurcation and Chaos. 9 (7): 1465–1466.
[2] Chen, Guanrong; Lu, Jinhu (2006). “Generating Multiscroll Chaotic Attractors: Theories, Methods and Applications”. International Journal of Bifurcation and Chaos. 16 (4): 775–858. Bibcode:2006IJBC…16..775L. doi:10.1142/s0218127406015179.
[3] L. Fortuna, M. Frasca, and M. G. Xibilia, “Chua’s circuit implementations: yesterday, today and tomorrow,” World Scientific, 2009.
Rabinovich-Fabrikant equations
The Rabinovich–Fabrikant equations are a set of three coupled ordinary differential equations exhibiting chaotic behaviour for certain values of the parameters. They are named after Mikhail Rabinovich and Anatoly Fabrikant, who described them in 1979. The equations are
\[\dot{x} = y(z - 1 + x^2) + \gamma x, \quad \dot{y} = x(3z + 1 - x^2) + \gamma y, \quad \dot{z} = -2z(\alpha + xy),\]
where \(\alpha, \gamma\) are constants that control the evolution of the system.
https://zh.wikipedia.org/wiki/拉比诺维奇-法布里康特方程
https://en.wikipedia.org/wiki/Rabinovich%E2%80%93Fabrikant_equations
[1]:
import brainpy as bp
import matplotlib.pyplot as plt
[2]:
@bp.odeint(method='rk4')
def rf_eqs(x, y, z, t, alpha=1.1, gamma=0.803):
    dx = y * (z - 1 + x * x) + gamma * x
    dy = x * (3 * z + 1 - x * x) + gamma * y
    dz = -2 * z * (alpha + x * y)
    return dx, dy, dz
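Two structural properties can be read off `rf_eqs` directly. First, \(\dot{z} = -2z(\alpha + xy)\) vanishes whenever \(z = 0\), so the plane \(z = 0\) is invariant. Second, the equations are unchanged under \((x, y, z) \to (-x, -y, z)\), so every trajectory has a mirror image. The pure-Python check below tests both; the function `rf_rhs` is ours and uses the same formulas and defaults as `rf_eqs`.

```python
def rf_rhs(x, y, z, t=0.0, alpha=1.1, gamma=0.803):
    """Rabinovich-Fabrikant right-hand side, same formulas as rf_eqs."""
    dx = y * (z - 1 + x * x) + gamma * x
    dy = x * (3 * z + 1 - x * x) + gamma * y
    dz = -2 * z * (alpha + x * y)
    return dx, dy, dz


# on the plane z = 0 the z-derivative vanishes
print(rf_rhs(0.7, -1.2, 0.0)[2])
# mirror symmetry: (x, y, z) -> (-x, -y, z) flips dx and dy, keeps dz
print(rf_rhs(0.7, -1.2, 0.5), rf_rhs(-0.7, 1.2, 0.5))
```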
[3]:
def run_and_visualize(runner, duration=100, dim=3, args=None):
    assert dim in [3, 2]
    if args is None:
        runner.run(duration)
    else:
        runner.run(duration, args=args)
    # skip the first 100 recorded samples (initial transient)
    if dim == 3:
        fig = plt.figure()
        ax = fig.add_subplot(111, projection='3d')
        for i in range(runner.mon.x.shape[1]):
            ax.plot(runner.mon.x[100:, i], runner.mon.y[100:, i], runner.mon.z[100:, i])
        ax.set_xlabel('x')
        ax.set_ylabel('y')
        ax.set_zlabel('z')
    else:
        for i in range(runner.mon.x.shape[1]):
            plt.plot(runner.mon.x[100:, i], runner.mon.y[100:, i])
        plt.xlabel('x')
        plt.ylabel('y')
    plt.show()
[4]:
runner = bp.IntegratorRunner(
rf_eqs,
monitors=['x', 'y', 'z'],
inits=dict(x=-1, y=0, z=0.5),
dt=0.001
)
run_and_visualize(runner, 100, args=dict(alpha=1.1, gamma=0.87),)

[5]:
runner = bp.IntegratorRunner(
rf_eqs,
monitors=['x', 'y', 'z'],
inits=dict(x=-1, y=0, z=0.5),
dt=0.001
)
run_and_visualize(runner, 100, args=dict(alpha=0.98, gamma=0.1),)

[6]:
runner = bp.IntegratorRunner(
rf_eqs,
monitors=['x', 'y', 'z'],
inits=dict(x=-1, y=0, z=0.5),
dt=0.001
)
run_and_visualize(runner, 300, args=dict(alpha=0.14, gamma=0.1),)

[7]:
runner = bp.IntegratorRunner(
rf_eqs,
monitors=['x', 'y', 'z'],
inits=dict(x=-1, y=0, z=0.5),
dt=0.001
)
run_and_visualize(runner, 50, dim=2, args=dict(alpha=0.05, gamma=0.1),)

[8]:
runner = bp.IntegratorRunner(
rf_eqs,
monitors=['x', 'y', 'z'],
inits=dict(x=-1, y=0, z=0.5),
dt=0.001
)
run_and_visualize(runner, 100, dim=2, args=dict(alpha=0.25, gamma=0.1),)

[9]:
runner = bp.IntegratorRunner(
rf_eqs,
monitors=['x', 'y', 'z'],
inits=dict(x=-1, y=0, z=0.5),
dt=0.001
)
run_and_visualize(runner, 100, dim=2, args=dict(alpha=1.1, gamma=0.86666666666666666667),)
runner = bp.IntegratorRunner(
rf_eqs,
monitors=['x', 'y', 'z'],
inits=dict(x=-1, y=0, z=0.5),
dt=0.001
)
run_and_visualize(runner, 100, dim=3, args=dict(alpha=1.1, gamma=0.86666666666666666667),)


[10]:
runner = bp.IntegratorRunner(
rf_eqs,
monitors=['x', 'y', 'z'],
inits=dict(x=0.1, y=-0.1, z=0.1),
dt=0.001
)
run_and_visualize(runner, 60, dim=3, args=dict(alpha=0.05, gamma=0.1),)

Fractional-order Chaos Gallery
[1]:
import brainpy as bp
import matplotlib.pyplot as plt
[2]:
def plot_runner_3d(runner):
    fig = plt.figure(figsize=(8, 6))
    ax = fig.add_subplot(111, projection='3d')
    ax.plot(runner.mon.x.flatten(),
            runner.mon.y.flatten(),
            runner.mon.z.flatten())
    ax.set_xlabel('x')
    ax.set_ylabel('y')
    ax.set_zlabel('z')
    plt.show()
[3]:
def plot_runner_2d(runner):
    plt.figure(figsize=(12, 6))
    # left panel: x-y projection
    plt.subplot(121)
    plt.plot(runner.mon.x.flatten(), runner.mon.y.flatten())
    plt.xlabel('x')
    plt.ylabel('y')
    # right panel: x-z projection
    plt.subplot(122)
    plt.plot(runner.mon.x.flatten(), runner.mon.z.flatten())
    plt.xlabel('x')
    plt.ylabel('z')
    plt.show()
Fractional-order Chua system
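The fractional-order Chua’s system is given by [1]:
\[D^{\alpha_1} x = a\,(y - x - f(x)), \quad D^{\alpha_2} y = x - y + z, \quad D^{\alpha_3} z = -b y - c z,\]
where \(f(x) = m_1 x + \tfrac{1}{2}(m_0 - m_1)\left(|x + 1| - |x - 1|\right)\) and \(D^{\alpha_i}\) is the fractional derivative of order \(\alpha_i\).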
[4]:
a, b, c = 10.725, 10.593, 0.268
m0, m1 = -1.1726, -0.7872
def chua_system(x, y, z, t):
    f = m1 * x + 0.5 * (m0 - m1) * (abs(x + 1) - abs(x - 1))
    dx = a * (y - x - f)
    dy = x - y + z
    dz = -b * y - c * z
    return dx, dy, dz
[5]:
dt = 0.001
duration = 200
inits = [0.2, -0.1, 0.1]
integrator = bp.fde.GLShortMemory(chua_system,
alpha=[0.93, 0.99, 0.92],
num_memory=1000,
inits=inits)
runner = bp.IntegratorRunner(integrator,
monitors=list('xyz'),
inits=inits,
dt=dt)
runner.run(duration)
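As its name suggests, `bp.fde.GLShortMemory` uses a Grünwald–Letnikov (GL) discretization of the fractional derivative and keeps only the last `num_memory` steps of history (the "short-memory principle"). The pure-Python sketch below shows the textbook explicit GL scheme \(x_k = h^{\alpha} f(x_{k-1}, t_{k-1}) - \sum_{j=1}^{\min(k, L)} w_j\, x_{k-j}\). It uses binomial weights \(w_0 = 1\), \(w_j = (1 - (1 + \alpha)/j)\, w_{j-1}\) and memory length \(L\). This is an illustration under those assumptions, not a transcription of BrainPy's implementation, whose indexing and handling of initial conditions may differ. The function names are ours.

```python
from collections import deque


def gl_weights(alpha, n):
    """Grünwald-Letnikov weights w_j = (-1)^j * binom(alpha, j), j = 0..n."""
    w = [1.0]
    for j in range(1, n + 1):
        w.append(w[-1] * (1.0 - (1.0 + alpha) / j))
    return w


def gl_short_memory(f, x0, alphas, h, n_steps, num_memory):
    """Explicit GL scheme keeping only the last `num_memory` states.

    f(*state, t) returns the right-hand side; alphas holds one order per variable.
    """
    weights = [gl_weights(a, num_memory) for a in alphas]
    history = deque([list(x0)], maxlen=num_memory)   # newest state last
    trajectory = [list(x0)]
    for k in range(1, n_steps + 1):
        rhs = f(*history[-1], (k - 1) * h)
        x_new = []
        for i, a in enumerate(alphas):
            w = weights[i]
            # history[-j] is the state j steps back, j = 1..min(k, num_memory)
            memory = sum(w[j] * history[-j][i] for j in range(1, len(history) + 1))
            x_new.append(h ** a * rhs[i] - memory)
        history.append(x_new)
        trajectory.append(x_new)
    return trajectory


# for alpha = 1 the weights are [1, -1, 0, 0, ...] and the scheme reduces to
# explicit Euler, so dx/dt = -x with h = 0.1 gives x_k = 0.9 ** k
traj = gl_short_memory(lambda x, t: (-x,), [1.0], [1.0], 0.1, 10, 1000)
```

With \(\alpha < 1\), all weights \(w_j\) for \(j \ge 1\) are nonzero. Every step then depends on the whole stored history, which is why truncating it to `num_memory` entries matters for cost.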
[6]:
plot_runner_3d(runner)

[7]:
plot_runner_2d(runner)

Fractional-order Qi System
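The fractional-order Qi chaotic system is given by [2]:
\[D^{\alpha} x = a(y - x) + yz, \quad D^{\alpha} y = cx - y - xz, \quad D^{\alpha} z = xy - bz.\]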
[8]:
a, b, c = 35, 8 / 3, 80
def qi_system(x, y, z, t):
    dx = -a * x + a * y + y * z
    dy = c * x - y - x * z
    dz = -b * z + x * y
    return dx, dy, dz
[9]:
dt = 0.001
duration = 200
inits = [0.1, 0.2, 0.3]
integrator = bp.fde.GLShortMemory(qi_system,
alpha=0.98,
num_memory=1000,
inits=inits)
runner = bp.IntegratorRunner(integrator,
monitors=list('xyz'),
inits=inits,
dt=dt)
runner.run(duration)
[10]:
plot_runner_3d(runner)

[11]:
plot_runner_2d(runner)

Fractional-order Lorenz System
The fractional-order Lorenz system is given by [3]:
\[D^{\alpha} x = a(y - x), \quad D^{\alpha} y = x(b - z) - y, \quad D^{\alpha} z = xy - cz.\]
[12]:
a, b, c = 10, 28, 8 / 3
def lorenz_system(x, y, z, t):
    dx = a * (y - x)
    dy = x * (b - z) - y
    dz = x * y - c * z
    return dx, dy, dz
[13]:
dt = 0.001
duration = 50
inits = [1, 2, 3]
integrator = bp.fde.GLShortMemory(lorenz_system,
alpha=0.99, # fractional order
num_memory=1000,
inits=inits)
runner = bp.IntegratorRunner(integrator,
monitors=list('xyz'),
inits=inits,
dt=dt)
runner.run(duration)
[14]:
plot_runner_3d(runner)

[15]:
plot_runner_2d(runner)

Fractional-order Liu System
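The fractional-order Liu system is given by [4]:
\[D^{\alpha} x = -a x - e y^2, \quad D^{\alpha} y = b y - n x z, \quad D^{\alpha} z = -c z + m x y.\]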
[16]:
a, b, c, e, n, m = 1, 2.5, 5, 1, 4, 4
def liu_system(x, y, z, t):
    dx = -a * x - e * y ** 2
    dy = b * y - n * x * z
    dz = -c * z + m * x * y
    return dx, dy, dz
[17]:
dt = 0.001
duration = 100
inits = [0.2, 0, 0.5]
integrator = bp.fde.GLShortMemory(liu_system,
alpha=0.95, # fractional order
num_memory=1000,
inits=inits)
runner = bp.IntegratorRunner(integrator,
monitors=list('xyz'),
inits=inits,
dt=dt)
runner.run(duration)
[18]:
plot_runner_3d(runner)

[19]:
plot_runner_2d(runner)

Fractional-order Chen System
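The fractional-order Chen system is given by [5]:
\[D^{\alpha} x = a(y - x), \quad D^{\alpha} y = d x - x z + c y, \quad D^{\alpha} z = xy - bz.\]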
[20]:
a, b, c, d = 35, 3, 28, -7
def chen_system(x, y, z, t):
    dx = a * (y - x)
    dy = d * x - x * z + c * y
    dz = x * y - b * z
    return dx, dy, dz
[21]:
dt = 0.005
duration = 50
inits = [-9, -5, 14]
integrator = bp.fde.GLShortMemory(chen_system,
alpha=0.9, # fractional order
num_memory=1000,
inits=inits)
runner = bp.IntegratorRunner(integrator,
monitors=list('xyz'),
inits=inits,
dt=dt)
runner.run(duration)
[22]:
plot_runner_3d(runner)

[23]:
plot_runner_2d(runner)

References
[1] Petráš, Ivo. “A note on the fractional-order Chua’s system.” Chaos, Solitons & Fractals 38.1 (2008): 140-147.
[2] Wu, Xiangjun, and Yang Yang. “Chaos in the fractional-order Qi system and its synchronization using active control.” 2010 International Conference on Intelligent Computing and Cognitive Informatics. IEEE, 2010.
[3] Wu, Xiang-Jun, and Shi-Lei Shen. “Chaos in the fractional-order Lorenz system.” International Journal of Computer Mathematics 86.7 (2009): 1274-1282.
[4] Daftardar-Gejji, Varsha, and Sachin Bhalekar. “Chaos in fractional ordered Liu system.” Computers & mathematics with applications 59.3 (2010): 1117-1127.
[5] Lu, Jun Guo, and Guanrong Chen. “A note on the fractional-order Chen system.” Chaos, Solitons & Fractals 27.3 (2006): 685-688.
(Brette & Guigon, 2003): Reliability of spike timing
Implementation of the paper:
Brette R and E Guigon (2003). Reliability of Spike Timing Is a General Property of Spiking Model Neurons. Neural Computation 15, 279-308.
[1]:
import brainpy as bp
import brainpy.math as bm
import matplotlib.pyplot as plt
[2]:
class ExternalInput(bp.DynamicalSystem):
    def __init__(self, num):
        super(ExternalInput, self).__init__()
        # parameters
        self.num = num
        # variables
        self.I = bm.Variable(bm.zeros(num))
[3]:
class LinearLeakyModel(bp.NeuGroup):
    r"""A linear leaky model.

    .. math::

       \tau \frac{d V}{d t}=-V+R I(t)
    """

    def __init__(self, size, inputs):
        super(LinearLeakyModel, self).__init__(size)
        # parameters
        self.tau = 33  # ms
        self.R = 200 / 1e3  # MΩ
        self.Vt = 15  # mV
        self.Vr = -5  # mV
        self.sigma = 1.
        self.inputs = inputs
        # variables
        self.V = bm.Variable(bm.random.uniform(self.Vr, self.Vt, self.num))
        self.spike = bm.Variable(bm.zeros(self.num, dtype=bool))
        # functions: drift f and additive noise amplitude g of the SDE
        f = lambda V, t: (-V + self.R * self.inputs.I) / self.tau
        g = lambda V, t: self.sigma * bm.sqrt(2. / self.tau)
        self.integral = bp.sdeint(f=f, g=g)

    def update(self, tdi):
        self.inputs.update(tdi)
        self.V.value = self.integral(self.V, tdi.t, tdi.dt)
        self.spike.value = self.V >= self.Vt
        self.V.value = bm.where(self.spike, self.Vr, self.V)  # reset neurons that spiked
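`LinearLeakyModel` solves the stochastic differential equation \(dV = \frac{-V + RI(t)}{\tau}\,dt + \sigma\sqrt{2/\tau}\,dW\) with `bp.sdeint`. With this noise amplitude \(g = \sigma\sqrt{2/\tau}\), and without threshold or input, \(V\) is an Ornstein–Uhlenbeck process with stationary variance \(g^2\tau/2 = \sigma^2\), i.e. standard deviation \(\sigma\). The pure-Python sketch below assumes an Euler–Maruyama update; we have not checked which scheme `bp.sdeint` uses by default. It reproduces the threshold/reset rule of `update`. The function name and its arguments are ours.

```python
import math
import random
import statistics


def simulate_leaky_em(I, n_steps, dt=0.1, tau=33.0, R=0.2, Vt=15.0, Vr=-5.0,
                      sigma=1.0, V0=0.0, seed=0):
    """Euler-Maruyama for dV = (-V + R*I(t))/tau dt + sigma*sqrt(2/tau) dW,
    resetting V to Vr whenever V >= Vt. Returns (spike_times, voltage_trace)."""
    rng = random.Random(seed)
    g = sigma * math.sqrt(2.0 / tau)          # additive noise amplitude
    V, spikes, trace = V0, [], []
    for k in range(n_steps):
        t = k * dt
        V += (-V + R * I(t)) / tau * dt + g * math.sqrt(dt) * rng.gauss(0.0, 1.0)
        if V >= Vt:
            spikes.append(t + dt)
            V = Vr
        trace.append(V)
    return spikes, trace


# noise-free, constant input 125: V relaxes towards R*I = 25 and first crosses
# Vt = 15 from Vr = -5 after tau * ln(3), about 36.3 ms
spikes, _ = simulate_leaky_em(lambda t: 125.0, 10000, dt=0.01, sigma=0.0, V0=-5.0)

# no input and no threshold: the sample standard deviation of V approaches sigma
_, trace = simulate_leaky_em(lambda t: 0.0, 200000, dt=0.5, Vt=math.inf, sigma=1.0)
print(spikes[0], statistics.pstdev(trace[1000:]))
```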
[4]:
class NonlinearLeakyModel(bp.NeuGroup):
    r"""A nonlinear leaky model.

    .. math::

       \tau \frac{d V}{d t}=-a V^{3}+R I(t)
    """

    def __init__(self, size, inputs):
        super(NonlinearLeakyModel, self).__init__(size)
        # parameters
        self.tau = 33  # ms
        self.R = 200 / 1e3  # MΩ
        self.Vt = 15  # mV
        self.Vr = -5  # mV
        self.a = 4444 / 1e6  # V^(-2)
        self.sigma = 1.
        self.inputs = inputs
        # variables
        self.V = bm.Variable(bm.zeros(self.num))
        self.spike = bm.Variable(bm.zeros(self.num, dtype=bool))
        # functions
        f = lambda V, t: (-self.a * V ** 3 + self.R * self.inputs.I) / self.tau
        g = lambda V, t: self.sigma * bm.sqrt(2. / self.tau)
        self.integral = bp.sdeint(f=f, g=g)

    def update(self, tdi):
        self.inputs.update(tdi)
        self.V.value = self.integral(self.V, tdi.t, tdi.dt)
        self.spike.value = self.V >= self.Vt
        self.V.value = bm.where(self.spike, self.Vr, self.V)
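In `NonlinearLeakyModel`, the leak \(-aV^3\) replaces \(-V\). Without noise and with a constant input \(I\), \(V\) relaxes monotonically to the real root \(V^* = (RI/a)^{1/3}\) of \(-aV^3 + RI = 0\). A neuron starting below \(V^*\) therefore fires repetitively only if \(V^* > V_t\), i.e. if \(I > aV_t^3/R\). With the model's constants, the constant part \(I = 150\) of the input in figure7 and figure8 gives \(V^* \approx 18.9\), above \(V_t = 15\). The threshold input is about 75 in the model's units. A pure-Python check follows; the names are ours.

```python
import math

a, R, Vt = 4444 / 1e6, 200 / 1e3, 15.0   # constants of NonlinearLeakyModel


def equilibrium(I):
    """Real root of -a*V**3 + R*I = 0 (noise-free fixed point for constant I)."""
    drive = R * I
    return math.copysign(abs(drive / a) ** (1.0 / 3.0), drive)


def threshold_input():
    """Smallest constant input whose fixed point reaches Vt."""
    return a * Vt ** 3 / R


print(equilibrium(150.0), threshold_input())
```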
[5]:
class NonLeakyModel(bp.NeuGroup):
    r"""A non-leaky model.

    .. math::

       \tau \frac{d V}{d t}=V I(t)+k
    """

    def __init__(self, size, inputs):
        super(NonLeakyModel, self).__init__(size)
        # parameters
        self.Vt = 1
        self.Vr = 0
        self.tau = 33  # ms
        self.k = 1
        self.sigma = 0.02
        self.inputs = inputs
        # variables
        self.V = bm.Variable(bm.random.uniform(self.Vr, self.Vt, self.num))
        self.spike = bm.Variable(bm.zeros(self.num, dtype=bool))
        # functions
        f = lambda V, t: (V * self.inputs.I + self.k) / self.tau
        g = lambda V, t: self.sigma * (2 / self.tau) ** .5
        self.integral = bp.sdeint(f, g, method='euler')

    def update(self, tdi):
        self.inputs.update(tdi)
        self.V.value = self.integral(self.V, tdi.t, tdi.dt)
        self.spike.value = self.V >= self.Vt
        self.V.value = bm.where(self.spike, self.Vr, self.V)
[6]:
class PerfectIntegrator(bp.NeuGroup):
    r"""Integrate inputs.

    .. math::

       \tau \frac{d V}{d t}=f(t)

    where :math:`f(t)` is an input function.
    """

    def __init__(self, size, inputs):
        super(PerfectIntegrator, self).__init__(size)
        # parameters
        self.tau = 12.5  # ms
        self.Vt = 1
        self.Vr = 0
        self.sigma = 0.27
        self.inputs = inputs
        # variables
        self.V = bm.Variable(bm.random.uniform(self.Vr, self.Vt, self.num))
        self.spike = bm.Variable(bm.zeros(self.num, dtype=bool))
        # functions
        f = lambda V, t: self.inputs.I / self.tau
        g = lambda V, t: self.sigma * (2 / self.tau) ** .5
        self.integral = bp.sdeint(f, g, method='euler')

    def update(self, tdi):
        self.inputs.update(tdi)
        self.V.value = self.integral(self.V, tdi.t, tdi.dt)
        self.spike.value = self.V >= self.Vt
        self.V.value = bm.where(self.spike, self.Vr, self.V)
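`PerfectIntegrator` has no leak. Without noise, \(V\) simply accumulates \(\int I\,dt/\tau\), and each spike consumes \(V_t - V_r = 1\). With a constant input, the interspike interval is therefore \(\tau(V_t - V_r)/I\). For example, the constant part \(0.5\) of the input in figure11 gives \(12.5 \times 1 / 0.5 = 25\) ms. For a non-negative time-varying input, the spike count over a window is approximately \(\int I\,dt / (\tau(V_t - V_r))\). A noise-free pure-Python sketch follows; the names are ours.

```python
import math


def isi_constant_input(I, tau=12.5, Vt=1.0, Vr=0.0):
    """Interspike interval of the noise-free perfect integrator for constant I > 0."""
    return tau * (Vt - Vr) / I


def count_spikes(I, duration, dt=0.1, tau=12.5, Vt=1.0, Vr=0.0):
    """Euler integration of tau dV/dt = I(t) with threshold/reset, no noise."""
    V, n = Vr, 0
    for k in range(int(round(duration / dt))):
        V += dt * I(k * dt) / tau
        if V >= Vt:
            n += 1
            V = Vr      # reset discards any overshoot, as in the model's update
    return n


# input 0.5 + 0.5*sin(2*pi*t/50): its integral over 1000 ms is 500, so about 500/12.5 = 40 spikes
n = count_spikes(lambda t: 0.5 + 0.5 * math.sin(2 * math.pi * t / 50.0), 1000.0)
print(isi_constant_input(0.5), n)
```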
[7]:
def figure6(sigma=0.):
    class Noise(ExternalInput):
        def __init__(self, num):
            super(Noise, self).__init__(num)
            # parameters
            self.p = bm.linspace(0., 1., num)
            # variables
            self.B = bm.Variable(bm.zeros(1))

        def update(self, tdi):
            self.B[:] = 30 * bm.sin(40 * bm.pi * tdi.t / 1000)  # (1,)
            self.I.value = 85 + 40 * (1 - self.p) + self.p * self.B  # (num,)

    num = 500
    model = LinearLeakyModel(num, inputs=Noise(num))
    model.sigma = sigma
    runner = bp.DSRunner(model, monitors=['inputs.B', 'spike'])
    runner.run(2000.)
    fig, gs = bp.visualize.get_figure(2, 1, 4, 12)
    fig.add_subplot(gs[0, 0])
    bp.visualize.line_plot(runner.mon.ts, runner.mon['inputs.B'], title='shared inputs')
    fig.add_subplot(gs[1, 0])
    bp.visualize.raster_plot(runner.mon.ts, runner.mon['spike'], ylabel='p')
    plt.yticks(bm.linspace(0, num, 5).numpy(), bm.linspace(0., 1., 5).numpy())
    plt.show()
[8]:
figure6(0.)

[9]:
figure6(1.)
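In figure6, each neuron's input \(I = 85 + 40(1 - p) + pB(t)\) interpolates between a constant 125 (\(p = 0\)) and \(85 + B(t)\) (\(p = 1\)). The shared input \(B(t) = 30\sin(40\pi t/1000)\) has a period of 50 ms. Averaged over a period, \(B\) vanishes, so the mean drive is \(R\langle I\rangle = 0.2\,(125 - 40p)\). It falls from 25 to 17 but stays above \(V_t = 15\) for every \(p\). A pure-Python check follows; the names are ours, and \(R\) and \(V_t\) are the `LinearLeakyModel` values.

```python
import math

R, Vt = 200 / 1e3, 15.0     # values used by LinearLeakyModel


def figure6_input(p, t):
    """Input to the neuron with mixing weight p at time t (ms), as in figure6."""
    B = 30 * math.sin(40 * math.pi * t / 1000)    # shared input, period 50 ms
    return 85 + 40 * (1 - p) + p * B


def mean_drive(p, period=50.0, n=1000):
    """R times the input averaged over one period of B (uniform samples)."""
    return R * sum(figure6_input(p, period * k / n) for k in range(n)) / n


for p in (0.0, 0.5, 1.0):
    print(p, mean_drive(p))
```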

[10]:
def figure7(sigma):
    class Noise(ExternalInput):
        def __init__(self, num):
            super(Noise, self).__init__(num)
            # parameters
            self.p = bm.linspace(0., 1., num)
            # variables
            self.B = bm.Variable(bm.zeros(1))

        def update(self, tdi):
            self.B[:] = 150 * bm.sin(40 * bm.pi * tdi.t / 1000)  # (1,)
            self.I.value = 150 + self.p * self.B  # (num,)

    num = 500
    model = NonlinearLeakyModel(num, inputs=Noise(num))
    model.sigma = sigma
    runner = bp.DSRunner(model, monitors=['inputs.B', 'spike'])
    runner.run(1000.)
    fig, gs = bp.visualize.get_figure(2, 1, 4, 8)
    fig.add_subplot(gs[0, 0])
    bp.visualize.line_plot(runner.mon.ts, runner.mon['inputs.B'], title='shared inputs')
    fig.add_subplot(gs[1, 0])
    bp.visualize.raster_plot(runner.mon.ts, runner.mon['spike'], ylabel='p')
    plt.yticks(bm.linspace(0, num, 5).numpy(), bm.linspace(0., 1., 5).numpy())
    plt.show()
[11]:
figure7(1.5)

[12]:
def figure8(sigma):
    class Noise(ExternalInput):
        def __init__(self, num):
            super(Noise, self).__init__(num)
            # parameters
            self.p = bm.linspace(0., 1., num)
            # variables
            self.B = bm.Variable(bm.zeros(1))

        def update(self, tdi):
            self.B[:] = 150 * bm.sin(40 * bm.pi * tdi.t / 1000)  # (1,)
            self.I.value = 150 + self.p * self.B  # (num,)

    num = 500
    model = NonlinearLeakyModel(num, inputs=Noise(num))
    # unlike figure7, start from random membrane potentials instead of zeros
    model.V.value = bm.random.uniform(model.Vr, model.Vt, model.num)
    model.sigma = sigma
    runner = bp.DSRunner(model, monitors=['inputs.B', 'spike'])
    runner.run(1000.)
    fig, gs = bp.visualize.get_figure(2, 1, 4, 8)
    fig.add_subplot(gs[0, 0])
    bp.visualize.line_plot(runner.mon.ts, runner.mon['inputs.B'], title='shared inputs')
    fig.add_subplot(gs[1, 0])
    bp.visualize.raster_plot(runner.mon.ts, runner.mon['spike'], ylabel='p')
    plt.yticks(bm.linspace(0, num, 5).numpy(), bm.linspace(0., 1., 5).numpy())
    plt.show()
[13]:
figure8(1.5)

[14]:
def figure9(sigma):
    class Noise(ExternalInput):
        def __init__(self, num):
            super(Noise, self).__init__(num)
            # parameters
            self.p = bm.linspace(0., 1., num)
            # variables
            self.B = bm.Variable(bm.zeros(1))

        def update(self, tdi):
            self.B[:] = bm.sin(40 * bm.pi * tdi.t / 1000)  # (1,)
            self.I.value = 112.5 + (75 + 37.5 * self.p) * self.B + 37.5 * self.p  # (num,)

    num = 500
    # B should be a triangular wave taking values between -1 and 1, with rising
    # and falling times drawn uniformly between 10 ms and 50 ms.
    # Here we use a sine wave instead.
    model = LinearLeakyModel(num, inputs=Noise(num))
    model.sigma = sigma
    runner = bp.DSRunner(model, monitors=['inputs.B', 'spike'])
    runner.run(2000.)
    fig, gs = bp.visualize.get_figure(2, 1, 4, 12)
    fig.add_subplot(gs[0, 0])
    bp.visualize.line_plot(runner.mon.ts, runner.mon['inputs.B'], title='shared inputs')
    fig.add_subplot(gs[1, 0])
    bp.visualize.raster_plot(runner.mon.ts, runner.mon['spike'], ylabel='p')
    plt.yticks(bm.linspace(0, num, 5).numpy(), bm.linspace(0., 1., 5).numpy())
    plt.show()
[15]:
figure9(0.5)
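The comment in `figure9` describes the stimulus that the sine wave stands in for. It is a triangular wave between \(-1\) and \(1\) whose rising and falling ramps each last a duration drawn uniformly from 10 to 50 ms. A pure-Python generator for such a wave is sketched below; the function and its arguments are hypothetical. Its output could, for instance, be precomputed and indexed by time step in place of the sine.

```python
import random


def triangular_wave(duration, dt, t_min=10.0, t_max=50.0, seed=0):
    """Hypothetical generator: B ramps linearly between -1 and +1, and each
    rising or falling ramp lasts a time drawn uniformly from [t_min, t_max] ms."""
    rng = random.Random(seed)
    values = []
    level, direction = -1.0, 1.0          # start at the bottom, rising
    ramp = rng.uniform(t_min, t_max)      # duration of the current ramp
    t_in_ramp = 0.0
    for _ in range(int(round(duration / dt))):
        values.append(level + direction * 2.0 * t_in_ramp / ramp)
        t_in_ramp += dt
        if t_in_ramp >= ramp:
            # ramp finished: switch to the opposite extreme, flip direction,
            # and draw a new ramp duration
            level += direction * 2.0
            direction = -direction
            t_in_ramp -= ramp
            ramp = rng.uniform(t_min, t_max)
    return values


vals = triangular_wave(2000.0, 0.1, seed=1)
```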

[16]:
def figure10():
    class Noise(ExternalInput):
        def __init__(self, num):
            super(Noise, self).__init__(num)
            # parameters
            self.tau = 10  # ms
            self.p = bm.linspace(0., 1., num)
            # variables
            self.x = bm.Variable(bm.zeros(1))
            self.B = bm.Variable(bm.zeros(1))
            # functions
            f = lambda x, t: -x / self.tau
            g = lambda x, t: bm.sqrt(2. / self.tau)
            self.integral = bp.sdeint(f=f, g=g)

        def update(self, tdi):
            self.x.value = self.integral(self.x, tdi.t, tdi.dt)  # (1,)
            self.B.value = 2. / (1 + bm.exp(-2 * self.x)) - 1  # (1,)
            self.I.value = 0.5 + 3 * self.p * self.B  # (num,)

    num = 500
    model = NonLeakyModel(num, inputs=Noise(num))
    runner = bp.DSRunner(model, monitors=['inputs.B', 'spike'])
    runner.run(1000.)
    fig, gs = bp.visualize.get_figure(2, 1, 4, 8)
    fig.add_subplot(gs[0, 0])
    bp.visualize.line_plot(runner.mon.ts, runner.mon['inputs.B'], title='shared inputs')
    fig.add_subplot(gs[1, 0])
    bp.visualize.raster_plot(runner.mon.ts, runner.mon['spike'], ylabel='p')
    plt.yticks(bm.linspace(0, num, 5).numpy(), bm.linspace(0., 1., 5).numpy())
    plt.show()
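The shared input of `figure10` and `figure11` is built in two stages. First, \(x\) follows \(dx = -\frac{x}{\tau}dt + \sqrt{2/\tau}\,dW\), an Ornstein–Uhlenbeck process with unit stationary variance. Second, \(B = \frac{2}{1 + e^{-2x}} - 1\), which is algebraically \(\tanh x\), so \(B\) stays between \(-1\) and \(1\). Below is a pure-Python sketch under an Euler–Maruyama assumption; the names are ours.

```python
import math
import random


def shared_input(n_steps, dt=0.1, tau=10.0, seed=0):
    """OU process x (Euler-Maruyama) squashed to B = 2/(1+exp(-2x)) - 1 = tanh(x)."""
    rng = random.Random(seed)
    x, B = 0.0, []
    for _ in range(n_steps):
        x += -x / tau * dt + math.sqrt(2.0 / tau) * math.sqrt(dt) * rng.gauss(0.0, 1.0)
        B.append(2.0 / (1.0 + math.exp(-2.0 * x)) - 1.0)
    return B


B = shared_input(10000)   # 1000 ms at dt = 0.1 ms
```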
[17]:
def figure11(sigma):
    class Noise(ExternalInput):
        def __init__(self, num):
            super(Noise, self).__init__(num)
            # parameters
            self.tau = 10  # ms
            self.p = bm.linspace(0., 1., num)
            # variables
            self.x = bm.Variable(bm.zeros(1))
            self.B = bm.Variable(bm.zeros(1))
            # functions
            f = lambda x, t: -x / self.tau
            g = lambda x, t: bm.sqrt(2. / self.tau)
            self.integral = bp.sdeint(f=f, g=g)

        def update(self, tdi):
            self.x.value = self.integral(self.x, tdi.t, tdi.dt)  # (1,)
            self.B.value = 2. / (1 + bm.exp(-2 * self.x)) - 1  # (1,)
            self.I.value = 0.5 + self.p * self.B  # (num,)

    num = 500
    model = PerfectIntegrator(num, inputs=Noise(num))
    model.sigma = sigma
    runner = bp.DSRunner(model, monitors=['inputs.B', 'spike'])
    runner.run(1000.)
    fig, gs = bp.visualize.get_figure(2, 1, 4, 8)
    fig.add_subplot(gs[0, 0])
    bp.visualize.line_plot(runner.mon.ts, runner.mon['inputs.B'], title='shared inputs')
    fig.add_subplot(gs[1, 0])
    bp.visualize.raster_plot(runner.mon.ts, runner.mon['spike'], ylabel='p')
    plt.yticks(bm.linspace(0, num, 5).numpy(), bm.linspace(0., 1., 5).numpy())
    plt.show()
[18]:
figure11(0.)

[19]:
figure11(0.27)
