BrainPy Examples

This repository contains examples of using BrainPy to implement models of neurons, synapses, networks, and more. We welcome your implementations, which can be posted through our GitHub page.

If some code fails to run, please tell us through a GitHub issue: https://github.com/PKU-NIP-Lab/BrainPyExamples/issues

If you find these examples useful for your research, please kindly cite us.

If you want to add more examples, please fork our GitHub repository: https://github.com/PKU-NIP-Lab/BrainPyExamples

(Izhikevich, 2003) Izhikevich Model

@Chaoming Wang @Xinyu Liu

1. Model Overview

The Izhikevich neuron model reproduces the spiking and bursting behavior of known types of cortical neurons. It combines the biological plausibility of Hodgkin–Huxley-type dynamics (HH model) with the computational efficiency of integrate-and-fire neurons (LIF model).

Bifurcation methodologies enable us to reduce many biophysically accurate Hodgkin–Huxley-type neuronal models to a two-dimensional system of ordinary differential equations of the form:

\[\begin{split}\begin{array}{l} \frac{d v}{d t}=0.04 v^{2}+5 v+140-u+I\\ \frac{d u}{d t}=a(b v-u) \end{array}\end{split}\]

with the auxiliary after-spike resetting:

\[\begin{split}\text { if } v \geq 30 \mathrm{mV}, \text { then }\left\{\begin{array}{l} v \leftarrow c \\ u \leftarrow u+d . \end{array}\right.\end{split}\]

\(v\) represents the membrane potential of the neuron.

\(u\) represents a membrane recovery variable, which accounts for the activation of \(\mathrm{K}^{+}\) ionic currents and inactivation of \(\mathrm{Na}^{+}\) ionic currents, and it provides negative feedback to \(v\) .

About the parameters \(a, b, c, d\):

a: The parameter \(a\) describes the time scale of the recovery variable \(u\). Smaller values result in slower recovery. A typical value is \(a = 0.02\).

b: The parameter \(b\) describes the sensitivity of the recovery variable \(u\) to the subthreshold fluctuations of the membrane potential \(v\), and it determines the resting potential in the model (between \(-70\) and \(-60\) mV). Greater values couple \(v\) and \(u\) more strongly, resulting in possible subthreshold oscillations and low-threshold spiking dynamics. A typical value is \(b = 0.2\).

c: The parameter \(c\) describes the after-spike reset value of the membrane potential \(v\) caused by the fast high-threshold \(\mathrm{K}^{+}\) conductances. A typical value is \(c = -65\) mV.

d: The parameter \(d\) describes the after-spike reset of the recovery variable \(u\) caused by slow high-threshold \(\mathrm{Na}^{+}\) and \(\mathrm{K}^{+}\) conductances. A typical value is \(d = 2\).

The threshold value of the model neuron is between \(-70\) mV and \(-50\) mV, and it is dynamic, as in biological neurons.
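Before turning to the built-in bp.neurons.Izhikevich class used in the cells below, it may help to see the update rule written out explicitly. The following is a minimal sketch (plain NumPy, forward Euler with a 0.1 ms step, illustrative regular-spiking parameters) of how the two differential equations and the after-spike reset interact; it is not the BrainPy implementation:

import numpy as np

def izhikevich_euler(I, a=0.02, b=0.2, c=-65., d=8., dt=0.1):
    # initial membrane potential and recovery variable
    v, u = -65.0, b * (-65.0)
    V = np.zeros(len(I))
    spike_times = []
    for i in range(len(I)):
        v += dt * (0.04 * v ** 2 + 5 * v + 140 - u + I[i])
        u += dt * a * (b * v - u)
        if v >= 30.:                   # auxiliary after-spike resetting
            spike_times.append(i * dt)
            v, u = c, u + d
        V[i] = v
    return V, spike_times

# e.g. 200 ms of a constant 10-unit input with regular-spiking parameters
V_trace, spikes = izhikevich_euler(np.full(2000, 10.0))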

[1]:
import brainpy as bp

import matplotlib.pyplot as plt

Summary of the neuro-computational properties of biological spiking neurons:

image0

The model can exhibit the firing patterns of all known types of cortical neurons with an appropriate choice of the parameters \(a, b, c, d\), as given below.

2. Different firing patterns

The following interpretation of the most prominent features of biological spiking neurons is based on reference [1].

Neuro-computational properties    a       b      c     d
Tonic Spiking                     0.02    0.04   -65   2
Phasic Spiking                    0.02    0.25   -65   6
Tonic Bursting                    0.02    0.2    -50   2
Phasic Bursting                   0.02    0.25   -55   0.05
Mixed Mode                        0.02    0.2    -55   4
Spike Frequency Adaptation        0.01    0.2    -65   8
Class 1 Excitability              0.02    -0.1   -55   6
Class 2 Excitability              0.2     0.26   -65   0
Spike Latency                     0.02    0.2    -65   6
Subthreshold Oscillations         0.05    0.26   -60   0
Resonator                         0.1     0.26   -60   -1
Integrator                        0.02    -0.1   -55   6
Rebound Spike                     0.03    0.25   -60   4
Rebound Burst                     0.03    0.25   -52   0
Threshold Variability             0.03    0.25   -60   4
Bistability                       1       1.5    -60   0
Depolarizing After-Potentials     1       0.2    -60   -21
Accommodation                     0.02    1      -55   4
Inhibition-Induced Spiking        -0.02   -1     -60   8
Inhibition-Induced Bursting       -0.026  -1     -45   0

The table above gives the values of the parameters \(a, b, c, d\) for 20 types of firing patterns.
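All pattern-specific cells below follow the same recipe: create a bp.neurons.Izhikevich group, overwrite a, b, c, d with a row of the table, and drive the neuron with a section (or ramp) input. As a rough sketch, several rows of the table could be swept in a single loop; the fixed 10-unit input here is only illustrative, and the individual cells below use pattern-specific inputs and durations:

patterns = {  # values copied from the table above
    'Tonic Bursting':  (0.02, 0.20, -50., 2.),
    'Phasic Bursting': (0.02, 0.25, -55., 0.05),
    'Mixed Mode':      (0.02, 0.20, -55., 4.),
}
for name, (a, b, c, d) in patterns.items():
    neu = bp.neurons.Izhikevich(1)
    neu.a, neu.b, neu.c, neu.d = a, b, c, d
    current = bp.inputs.section_input(values=[0., 10.], durations=[50, 150])
    runner = bp.DSRunner(neu, inputs=['input', current, 'iter'], monitors=['V'])
    runner.run(duration=200.)
    plt.plot(runner.mon.ts, runner.mon.V[:, 0], label=name)
plt.xlabel('Time (ms)')
plt.ylabel('Membrane potential (mV)')
plt.legend()
plt.show()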

Tonic Spiking

While the input is on, the neuron continues to fire a train of spikes. This kind of behavior, called tonic spiking, can be observed in three types of cortical neurons: regular spiking (RS) excitatory neurons, low-threshold spiking (LTS) inhibitory neurons, and fast spiking (FS) inhibitory neurons. Continuous firing of such neurons indicates that there is a persistent input.

[2]:
neu = bp.neurons.Izhikevich(1)
neu.a, neu.b, neu.c, neu.d = 0.02, 0.40, -65.0, 2.0

current = bp.inputs.section_input(values=[0., 10.], durations=[50, 150])
runner = bp.DSRunner(neu, inputs=['input', current, 'iter'], monitors=['V', 'u'])
runner.run(duration=200.)

fig, ax1 = plt.subplots(figsize=(10, 5))
plt.title('Tonic Spiking')
ax1.plot(runner.mon.ts, runner.mon.V[:, 0], 'b', label='V')
ax1.set_xlabel('Time (ms)')
ax1.set_ylabel('Membrane potential (mV)', color='b')
ax1.set_xlim(-0.1, 200.1)
ax1.tick_params('y', colors='b')

ax2 = ax1.twinx()
ax2.plot(runner.mon.ts, current, 'r', label='Input')
ax2.set_xlabel('Time (ms)')
ax2.set_ylabel('Input (mV)', color='r')
ax2.set_ylim(0, 50)
ax2.tick_params('y', colors='r')
ax1.legend(loc=1)
ax2.legend(loc=3)
fig.tight_layout()
plt.show()
_images/neurons_Izhikevich_2003_Izhikevich_model_14_2.png

Phasic Spiking

A neuron may fire only a single spike at the onset of the input, and remain quiescent afterwards. Such a response is called phasic spiking, and it is useful for detection of the beginning of stimulation.

[3]:
neu = bp.neurons.Izhikevich(1)
neu.a, neu.b, neu.c, neu.d = 0.02, 0.25, -65.0, 6.0

current = bp.inputs.section_input(values=[0., 1.], durations=[50, 150])
runner = bp.DSRunner(neu, inputs=['input', current, 'iter'], monitors=['V', 'u'])
runner.run(duration=200.)

fig, ax1 = plt.subplots(figsize=(10, 5))
plt.title('Phasic Spiking')
ax1.plot(runner.mon.ts, runner.mon.V[:, 0], 'b', label='V')
ax1.set_xlabel('Time (ms)')
ax1.set_ylabel('Membrane potential (mV)', color='b')
ax1.set_xlim(-0.1, 200.1)
ax1.tick_params('y', colors='b')
ax2 = ax1.twinx()
ax2.plot(runner.mon.ts, current, 'r', label='Input')
ax2.set_xlabel('Time (ms)')
ax2.set_ylabel('Input (mV)', color='r')
ax2.set_ylim(0, 20)
ax2.tick_params('y', colors='r')
ax1.legend(loc=1)
ax2.legend(loc=3)
fig.tight_layout()
plt.show()
_images/neurons_Izhikevich_2003_Izhikevich_model_16_1.png

Tonic Bursting

Some neurons, such as the chattering neurons in cat neocortex, fire periodic bursts of spikes when stimulated. The interburst frequency may be as high as 50 Hz, and it is believed that such neurons contribute to the gamma-frequency oscillations in the brain.

[4]:
neu = bp.neurons.Izhikevich(1)
neu.a, neu.b, neu.c, neu.d = 0.02, 0.20, -50.0, 2.0

current = bp.inputs.section_input(values=[0., 15.], durations=[50, 150])
runner = bp.DSRunner(neu, inputs=['input', current, 'iter'], monitors=['V', 'u'])
runner.run(duration=200.)

fig, ax1 = plt.subplots(figsize=(10, 5))
plt.title('Tonic Bursting')
ax1.plot(runner.mon.ts, runner.mon.V[:, 0], 'b', label='V')
ax1.set_xlabel('Time (ms)')
ax1.set_ylabel('Membrane potential (mV)', color='b')
ax1.set_xlim(-0.1, 200.1)
ax1.tick_params('y', colors='b')
ax2 = ax1.twinx()
ax2.plot(runner.mon.ts, current, 'r', label='Input')
ax2.set_xlabel('Time (ms)')
ax2.set_ylabel('Input (mV)', color='r')
ax2.set_ylim(0, 50)
ax2.tick_params('y', colors='r')
ax1.legend(loc=1)
ax2.legend(loc=3)
fig.tight_layout()
plt.show()
_images/neurons_Izhikevich_2003_Izhikevich_model_18_1.png

Phasic Bursting

Similarly to the phasic spikers, some neurons are phasic bursters. Such neurons report the beginning of the stimulation by transmitting a burst.

[5]:
neu = bp.neurons.Izhikevich(1)
neu.a, neu.b, neu.c, neu.d = 0.02, 0.25, -55.0, 0.05

current = bp.inputs.section_input(values=[0., 1.], durations=[50, 150])
runner = bp.DSRunner(neu, inputs=['input', current, 'iter'], monitors=['V', 'u'])
runner.run(duration=200.)

fig, ax1 = plt.subplots(figsize=(10, 5))
plt.title('Phasic Bursting')
ax1.plot(runner.mon.ts, runner.mon.V[:, 0], 'b', label='V')
ax1.set_xlabel('Time (ms)')
ax1.set_ylabel('Membrane potential (mV)', color='b')
ax1.set_xlim(-0.1, 200.1)
ax1.tick_params('y', colors='b')
ax2 = ax1.twinx()
ax2.plot(runner.mon.ts, current, 'r', label='Input')
ax2.set_xlabel('Time (ms)')
ax2.set_ylabel('Input (mV)', color='r')
ax2.set_ylim(0, 20)
ax2.tick_params('y', colors='r')
ax1.legend(loc=1)
ax2.legend(loc=3)
fig.tight_layout()

plt.show()
_images/neurons_Izhikevich_2003_Izhikevich_model_20_1.png

Mixed Mode

Intrinsically bursting (IB) excitatory neurons in mammalian neocortex can exhibit a mixed type of spiking activity. They fire a phasic burst at the onset of stimulation and then switch to the tonic spiking mode.

[6]:
neu = bp.neurons.Izhikevich(1)
neu.a, neu.b, neu.c, neu.d = 0.02, 0.20, -55.0, 4.0

current = bp.inputs.section_input(values=[0., 10.], durations=[50, 150])
runner = bp.DSRunner(neu, inputs=['input', current, 'iter'], monitors=['V', 'u'])
runner.run(duration=200.)

fig, ax1 = plt.subplots(figsize=(10, 5))
plt.title('Mixed Mode')
ax1.plot(runner.mon.ts, runner.mon.V[:, 0], 'b', label='V')
ax1.set_xlabel('Time (ms)')
ax1.set_ylabel('Membrane potential (mV)', color='b')
ax1.set_xlim(-0.1, 200.1)
ax1.tick_params('y', colors='b')
ax2 = ax1.twinx()
ax2.plot(runner.mon.ts, current, 'r', label='Input')
ax2.set_xlabel('Time (ms)')
ax2.set_ylabel('Input (mV)', color='r')
ax2.set_ylim(0, 50)
ax2.tick_params('y', colors='r')
ax1.legend(loc=1)
ax2.legend(loc=3)
fig.tight_layout()
plt.show()
_images/neurons_Izhikevich_2003_Izhikevich_model_22_1.png

Spike Frequency Adaptation

The most common type of excitatory neuron in mammalian neocortex, namely the regular spiking (RS) cell, fires tonic spikes with decreasing frequency. That is, the frequency is relatively high at the onset of stimulation, and then it adapts. Low-threshold spiking (LTS) inhibitory neurons also have this property.

[7]:
neu = bp.neurons.Izhikevich(1)
neu.a, neu.b, neu.c, neu.d = 0.01, 0.20, -65.0, 8.0

current = bp.inputs.section_input(values=[0., 30.], durations=[50, 150])
runner = bp.DSRunner(neu, inputs=['input', current, 'iter'], monitors=['V', 'u'])
runner.run(duration=200.)

fig, ax1 = plt.subplots(figsize=(10, 5))
plt.title('Spike Frequency Adaptation')
ax1.plot(runner.mon.ts, runner.mon.V[:, 0], 'b', label='V')
ax1.set_xlabel('Time (ms)')
ax1.set_ylabel('Membrane potential (mV)', color='b')
ax1.set_xlim(-0.1, 200.1)
ax1.tick_params('y', colors='b')
ax2 = ax1.twinx()
ax2.plot(runner.mon.ts, current, 'r', label='Input')
ax2.set_xlabel('Time (ms)')
ax2.set_ylabel('Input (mV)', color='r')
ax2.set_ylim(0, 100)
ax2.tick_params('y', colors='r')
ax1.legend(loc=1)
ax2.legend(loc=3)
fig.tight_layout()
plt.show()
_images/neurons_Izhikevich_2003_Izhikevich_model_24_1.png

Class 1 Excitability

The frequency of tonic spiking of neocortical RS excitatory neurons depends on the strength of the input, and it may span the range from 2 Hz to 200 Hz, or even greater. The ability to fire low-frequency spikes when the input is weak (but superthreshold) is called Class 1 excitability.

[8]:
neu = bp.neurons.Izhikevich(1)
neu.a, neu.b, neu.c, neu.d = 0.02, -0.1, -55.0, 6.0

current = bp.inputs.ramp_input(c_start=0., c_end=80., t_start=50., t_end=200., duration=250)
runner = bp.DSRunner(neu, inputs=['input', current, 'iter'], monitors=['V', 'u'])
runner.run(duration=250.)

fig, ax1 = plt.subplots(figsize=(10, 5))
plt.title('Class 1')
ax1.plot(runner.mon.ts, runner.mon.V[:, 0], 'b', label='V')
ax1.set_xlabel('Time (ms)')
ax1.set_ylabel('Membrane potential (mV)', color='b')
ax1.set_xlim(-0.1, 250.1)
ax1.tick_params('y', colors='b')
ax2 = ax1.twinx()
ax2.plot(runner.mon.ts, current, 'r', label='Input')
ax2.set_xlabel('Time (ms)')
ax2.set_ylabel('Input (mV)', color='r')
ax2.set_ylim(0, 150)
ax2.tick_params('y', colors='r')
ax1.legend(loc=1)
ax2.legend(loc=3)
fig.tight_layout()
plt.show()
_images/neurons_Izhikevich_2003_Izhikevich_model_26_1.png

Class 2 Excitability

Some neurons cannot fire low-frequency spike trains. That is, they are either quiescent or fire a train of spikes with a certain relatively large frequency, say 40 Hz. Such neurons are called Class 2 excitable.

[9]:
neu = bp.neurons.Izhikevich(1)
neu.a, neu.b, neu.c, neu.d = 0.20, 0.26, -65.0, 0.0

current = bp.inputs.ramp_input(c_start=0., c_end=10., t_start=50., t_end=200., duration=250)
runner = bp.DSRunner(neu, inputs=['input', current, 'iter'], monitors=['V', 'u'])
runner.run(duration=250.)

fig, ax1 = plt.subplots(figsize=(10, 5))
plt.title('Class 2')
ax1.plot(runner.mon.ts, runner.mon.V[:, 0], 'b', label='V')
ax1.set_xlabel('Time (ms)')
ax1.set_ylabel('Membrane potential (mV)', color='b')
ax1.set_xlim(-0.1, 250.1)
ax1.tick_params('y', colors='b')
ax2 = ax1.twinx()
ax2.plot(runner.mon.ts, current, 'r', label='Input')
ax2.set_xlabel('Time (ms)')
ax2.set_ylabel('Input (mV)', color='r')
ax2.set_ylim(0, 100)
ax2.tick_params('y', colors='r')
ax1.legend(loc=1)
ax2.legend(loc=3)
fig.tight_layout()
plt.show()
_images/neurons_Izhikevich_2003_Izhikevich_model_28_1.png

Spike Latency

Most cortical neurons fire spikes with a delay that depends on the strength of the input signal. For a relatively weak but superthreshold input, the delay, also called spike latency, can be quite large.

[10]:
neu = bp.neurons.Izhikevich(1)
neu.a, neu.b, neu.c, neu.d = 0.02, 0.20, -65.0, 6.0

current = bp.inputs.section_input(values=[0., 50., 0.], durations=[15, 1, 15])
runner = bp.DSRunner(neu, inputs=['input', current, 'iter'], monitors=['V', 'u'])
runner.run(duration=31.)

fig, ax1 = plt.subplots(figsize=(10, 5))
plt.title('Spike Latency')
ax1.plot(runner.mon.ts, runner.mon.V[:, 0], 'b', label='V')
ax1.set_xlabel('Time (ms)')
ax1.set_ylabel('Membrane potential (mV)', color='b')
ax1.set_xlim(-0.1, 24.1)
ax1.tick_params('y', colors='b')
ax2 = ax1.twinx()
ax2.plot(runner.mon.ts, current, 'r', label='Input')
ax2.set_xlabel('Time (ms)')
ax2.set_ylabel('Input (mV)', color='r')
ax2.set_ylim(0, 100)
ax2.tick_params('y', colors='r')
ax1.legend(loc=1)
ax2.legend(loc=3)
fig.tight_layout()
plt.show()
_images/neurons_Izhikevich_2003_Izhikevich_model_30_1.png

Subthreshold Oscillations

Practically every brain structure has neurons capable of exhibiting oscillatory potentials. The frequency of such oscillations plays an important role, and such neurons act as bandpass filters.

[11]:
neu = bp.neurons.Izhikevich(1)
neu.a, neu.b, neu.c, neu.d = 0.05, 0.26, -60.0, 0.0

current = bp.inputs.section_input(values=[0., 50., 0.], durations=[15, 1, 200])
runner = bp.DSRunner(neu, inputs=['input', current, 'iter'], monitors=['V', 'u'])
runner.run(duration=216.)

fig, ax1 = plt.subplots(figsize=(10, 5))
plt.title('Subthreshold Oscillation')
ax1.plot(runner.mon.ts, runner.mon.V[:, 0], 'b', label='V')
ax1.set_xlabel('Time (ms)')
ax1.set_ylabel('Membrane potential (mV)', color='b')
ax1.tick_params('y', colors='b')
ax1.set_xlim(-0.1, 216.1)
ax2 = ax1.twinx()
ax2.plot(runner.mon.ts, current, 'r', label='Input')
ax2.set_xlabel('Time (ms)')
ax2.set_ylabel('Input (mV)', color='r')
ax2.set_ylim(0, 100)
ax2.tick_params('y', colors='r')
ax1.legend(loc=1)
ax2.legend(loc=3)
fig.tight_layout()
plt.show()
_images/neurons_Izhikevich_2003_Izhikevich_model_32_1.png

Rebound Spike

When a neuron receives and then is released from an inhibitory input, it may fire a post-inhibitory (rebound) spike. This phenomenon is related to the anodal break excitation in excitable membranes. Many spiking neurons can fire in response to brief inhibitory inputs thereby blurring the difference between excitation and inhibition.

[12]:
neu = bp.neurons.Izhikevich(1)
neu.a, neu.b, neu.c, neu.d = 0.03, 0.25, -60.0, 4.0

current = bp.inputs.section_input(values=[7., 0., 7.], durations=[10, 5, 40])
runner = bp.DSRunner(neu, inputs=['input', current, 'iter'], monitors=['V', 'u'])
runner.run(duration=55.)

fig, ax1 = plt.subplots(figsize=(10, 5))
plt.title('Rebound Spike')
ax1.plot(runner.mon.ts, runner.mon.V[:, 0], 'b', label='V')
ax1.set_xlabel('Time (ms)')
ax1.set_ylabel('Membrane potential (mV)', color='b')
ax1.set_xlim(-0.1, 55.1)
ax1.tick_params('y', colors='b')
ax2 = ax1.twinx()
ax2.plot(runner.mon.ts, current, 'r', label='Input')
ax2.set_xlabel('Time (ms)')
ax2.set_ylabel('Input (mV)', color='r')
ax2.set_ylim(0, 70)
ax2.tick_params('y', colors='r')
ax1.legend(loc=1)
ax2.legend(loc=3)
fig.tight_layout()
plt.show()
_images/neurons_Izhikevich_2003_Izhikevich_model_34_1.png

Depolarizing After-potentials

After firing a spike, the membrane potential of a neuron may exhibit a prolonged after-hyperpolarization (AHP) or a prolonged depolarizing after-potential (DAP). Such DAPs can appear because of dendritic influence, because of high-threshold inward currents activated during the spike, or because of an interplay between subthreshold voltage-gated currents.

[13]:
neu = bp.neurons.Izhikevich(1)
neu.a, neu.b, neu.c, neu.d = 1.00, 0.20, -60.0, -21.0

current = bp.inputs.section_input(values=[0., 23, 0], durations=[7, 1, 50])
runner = bp.DSRunner(neu, inputs=['input', current, 'iter'], monitors=['V', 'u'])
runner.run(duration=58.)

fig, ax1 = plt.subplots(figsize=(10, 5))
plt.title('Depolarizing After-Potentials')
ax1.plot(runner.mon.ts, runner.mon.V[:, 0], 'b', label='V')
ax1.set_xlabel('Time (ms)')
ax1.set_ylabel('Membrane potential (mV)', color='b')
ax1.set_xlim(-0.1, 58.1)
ax1.tick_params('y', colors='b')
ax2 = ax1.twinx()
ax2.plot(runner.mon.ts, current, 'r', label='Input')
ax2.set_xlabel('Time (ms)')
ax2.set_ylabel('Input (mV)', color='r')
ax2.set_ylim(0, 100)
ax2.tick_params('y', colors='r')
ax1.legend(loc=1)
ax2.legend(loc=3)
fig.tight_layout()
plt.show()
_images/neurons_Izhikevich_2003_Izhikevich_model_36_1.png

Resonator

Due to the resonance phenomenon, neurons having oscillatory potentials can respond selectively to inputs whose frequency content is similar to the frequency of their subthreshold oscillations. Such neurons can implement frequency-modulated (FM) interactions and multiplexing of signals.

[14]:
neu = bp.neurons.Izhikevich(1)
neu.a, neu.b, neu.c, neu.d = 0.10, 0.26, -60.0, -1.0

current = bp.inputs.section_input(values=[-1, 0., -1, 0, -1, 0, -1, 0, -1],
                                  durations=[10, 10, 10, 10, 100, 10, 30, 20, 30])
runner = bp.DSRunner(neu, inputs=['input', current, 'iter'], monitors=['V', 'u'])
runner.run(duration=230.)

fig, ax1 = plt.subplots(figsize=(10, 5))
plt.title('Resonator')
ax1.plot(runner.mon.ts, runner.mon.V[:, 0], 'b', label='V')
ax1.set_xlabel('Time (ms)')
ax1.set_ylabel('Membrane potential (mV)', color='b')
ax1.set_xlim(-0.1, 230.1)
ax1.tick_params('y', colors='b')
ax2 = ax1.twinx()
ax2.plot(runner.mon.ts, current, 'r', label='Input')
ax2.set_xlabel('Time (ms)')
ax2.set_ylabel('Input (mV)', color='r')
ax2.set_ylim(-2, 5)
ax2.tick_params('y', colors='r')
ax1.legend(loc=1)
ax2.legend(loc=3)
fig.tight_layout()
plt.show()
_images/neurons_Izhikevich_2003_Izhikevich_model_38_1.png

Integrator

Neurons without oscillatory potentials act as integrators. They prefer high-frequency input; the higher the frequency the more likely they fire.

[15]:
neu = bp.neurons.Izhikevich(1)
neu.a, neu.b, neu.c, neu.d = 0.02, -0.1, -55.0, 6.0

current = bp.inputs.section_input(values=[0, 48, 0, 48, 0, 48, 0, 48, 0],
                                  durations=[19, 1, 1, 1, 28, 1, 1, 1, 56])
runner = bp.DSRunner(neu, inputs=['input', current, 'iter'], monitors=['V', 'u'])
runner.run(duration=109.)

fig, ax1 = plt.subplots(figsize=(10, 5))
plt.title('Integrator')
ax1.plot(runner.mon.ts, runner.mon.V[:, 0], 'b', label='V')
ax1.set_xlabel('Time (ms)')
ax1.set_ylabel('Membrane potential (mV)', color='b')
ax1.set_xlim(-0.1, 109.1)
ax1.tick_params('y', colors='b')
ax2 = ax1.twinx()
ax2.plot(runner.mon.ts, current, 'r', label='Input')
ax2.set_xlabel('Time (ms)')
ax2.set_ylabel('Input (mV)', color='r')
ax2.set_ylim(0, 200)
ax2.tick_params('y', colors='r')
ax1.legend(loc=1)
ax2.legend(loc=3)
fig.tight_layout()
plt.show()
_images/neurons_Izhikevich_2003_Izhikevich_model_40_1.png

Threshold Variability

Biological neurons have a variable threshold that depends on the prior activity of the neurons.

[16]:
neu = bp.neurons.Izhikevich(1)
neu.a, neu.b, neu.c, neu.d = 0.03, 0.25, -60.0, 4.0

current = bp.inputs.section_input(values=[0, 5, 0, -5, 0, 5, 0],
                                  durations=[13, 3, 78, 2, 2, 3, 13])
runner = bp.DSRunner(neu, inputs=['input', current, 'iter'], monitors=['V', 'u'])
runner.run(duration=114.)

fig, ax1 = plt.subplots(figsize=(10, 5))
plt.title('Threshold Variability')
ax1.plot(runner.mon.ts, runner.mon.V[:, 0], 'b', label='V')
ax1.set_xlabel('Time (ms)')
ax1.set_ylabel('Membrane potential (mV)', color='b')
ax1.set_xlim(-0.1, 114.1)
ax1.tick_params('y', colors='b')
ax2 = ax1.twinx()
ax2.plot(runner.mon.ts, current, 'r', label='Input')
ax2.set_xlabel('Time (ms)')
ax2.set_ylabel('Input (mV)', color='r')
ax2.set_ylim(-6, 20)
ax2.tick_params('y', colors='r')
ax1.legend(loc=1)
ax2.legend(loc=3)
fig.tight_layout()
plt.show()
_images/neurons_Izhikevich_2003_Izhikevich_model_42_1.png

Bistability

Some neurons can exhibit two stable modes of operation: resting and tonic spiking (or even bursting). An excitatory or inhibitory pulse can switch between the modes, thereby creating an interesting possibility for bistability and short-term memory.

[17]:
neu = bp.neurons.Izhikevich(1)
neu.a, neu.b, neu.c, neu.d = 1.00, 1.50, -60.0, 0.0

current = bp.inputs.section_input(values=[0., 5., 0, 5, 0], durations=[10, 1, 10, 1, 10])
runner = bp.DSRunner(neu, inputs=['input', current, 'iter'], monitors=['V', 'u'])
runner.run(duration=32.)

fig, ax1 = plt.subplots(figsize=(10, 5))
plt.title('Bistability')
ax1.plot(runner.mon.ts, runner.mon.V[:, 0], 'b', label='V')
ax1.set_xlabel('Time (ms)')
ax1.set_ylabel('Membrane potential (mV)', color='b')
ax1.set_xlim(-0.1, 32.1)
ax1.tick_params('y', colors='b')
ax2 = ax1.twinx()
ax2.plot(runner.mon.ts, current, 'r', label='Input')
ax2.set_xlabel('Time (ms)')
ax2.set_ylabel('Input (mV)', color='r')
ax2.set_ylim(0, 50)
ax2.tick_params('y', colors='r')
ax1.legend(loc=1)
ax2.legend(loc=3)
fig.tight_layout()
plt.show()
_images/neurons_Izhikevich_2003_Izhikevich_model_44_1.png

3. Different types of neurons

The table below gives the values of the parameters \(a, b, c, d\) for 7 types of neurons.

Neuron                          a      b      c     d
Regular Spiking (RS)            0.02   0.2    -65   8
Intrinsically Bursting (IB)     0.02   0.2    -55   4
Chattering (CH)                 0.02   0.2    -50   2
Fast Spiking (FS)               0.1    0.2    -65   2
Thalamo-cortical (TC)           0.02   0.25   -65   0.05
Resonator (RZ)                  0.1    0.26   -65   2
Low-threshold Spiking (LTS)     0.02   0.25   -65   2

Regular Spiking (RS)

[18]:
neu = bp.neurons.Izhikevich(1)
neu.a, neu.b, neu.c, neu.d = 0.02, 0.2, -65, 8

current = bp.inputs.section_input(values=[0., 15.], durations=[50, 250])
runner = bp.DSRunner(neu, inputs=['input', current, 'iter'], monitors=['V', 'u'])
runner.run(duration=300.)

fig, ax1 = plt.subplots(figsize=(10, 5))
plt.title('Regular Spiking')
ax1.plot(runner.mon.ts, runner.mon.V[:, 0], 'b', label='V')
ax1.set_xlabel('Time (ms)')
ax1.set_ylabel('Membrane potential (mV)', color='b')
ax1.set_xlim(-0.1, 300.1)
ax1.tick_params('y', colors='b')
ax2 = ax1.twinx()
ax2.plot(runner.mon.ts, current, 'r', label='Input')
ax2.set_xlabel('Time (ms)')
ax2.set_ylabel('Input (mV)', color='r')
ax2.set_ylim(0, 50)
ax2.tick_params('y', colors='r')
ax1.legend(loc=1)
ax2.legend(loc=3)
fig.tight_layout()
plt.show()
_images/neurons_Izhikevich_2003_Izhikevich_model_49_1.png

Intrinsically Bursting (IB)

[19]:
neu = bp.neurons.Izhikevich(1)
neu.a, neu.b, neu.c, neu.d = 0.02, 0.2, -55, 4

current = bp.inputs.section_input(values=[0., 15.], durations=[50, 250])
runner = bp.DSRunner(neu, inputs=['input', current, 'iter'], monitors=['V', 'u'])
runner.run(duration=300.)

fig, ax1 = plt.subplots(figsize=(10, 5))
plt.title('Intrinsically Bursting')
ax1.plot(runner.mon.ts, runner.mon.V[:, 0], 'b', label='V')
ax1.set_xlabel('Time (ms)')
ax1.set_ylabel('Membrane potential (mV)', color='b')
ax1.set_xlim(-0.1, 300.1)
ax1.tick_params('y', colors='b')
ax2 = ax1.twinx()
ax2.plot(runner.mon.ts, current, 'r', label='Input')
ax2.set_xlabel('Time (ms)')
ax2.set_ylabel('Input (mV)', color='r')
ax2.set_ylim(0, 50)
ax2.tick_params('y', colors='r')
ax1.legend(loc=1)
ax2.legend(loc=3)
fig.tight_layout()
plt.show()
_images/neurons_Izhikevich_2003_Izhikevich_model_51_1.png

Chattering (CH)

[20]:
neu = bp.neurons.Izhikevich(1)
neu.a, neu.b, neu.c, neu.d = 0.02, 0.2, -50, 2

current = bp.inputs.section_input(values=[0., 10.], durations=[50, 350])
runner = bp.DSRunner(neu, inputs=['input', current, 'iter'], monitors=['V', 'u'])
runner.run(duration=400.)

fig, ax1 = plt.subplots(figsize=(10, 5))
plt.title('Chattering')
ax1.plot(runner.mon.ts, runner.mon.V[:, 0], 'b', label='V')
ax1.set_xlabel('Time (ms)')
ax1.set_ylabel('Membrane potential (mV)', color='b')
ax1.set_xlim(-0.1, 400.1)
ax1.tick_params('y', colors='b')
ax2 = ax1.twinx()
ax2.plot(runner.mon.ts, current, 'r', label='Input')
ax2.set_xlabel('Time (ms)')
ax2.set_ylabel('Input (mV)', color='r')
ax2.set_ylim(0, 50)
ax2.tick_params('y', colors='r')
ax1.legend(loc=1)
ax2.legend(loc=3)
fig.tight_layout()
plt.show()
_images/neurons_Izhikevich_2003_Izhikevich_model_53_1.png

Fast Spiking (FS)

[21]:
neu = bp.neurons.Izhikevich(1)
neu.a, neu.b, neu.c, neu.d = 0.1, 0.2, -65, 2

current = bp.inputs.section_input(values=[0., 10.], durations=[50, 150])
runner = bp.DSRunner(neu, inputs=['input', current, 'iter'], monitors=['V', 'u'])
runner.run(duration=200.)

fig, ax1 = plt.subplots(figsize=(10, 5))
plt.title('Fast Spiking')
ax1.plot(runner.mon.ts, runner.mon.V[:, 0], 'b', label='V')
ax1.set_xlabel('Time (ms)')
ax1.set_ylabel('Membrane potential (mV)', color='b')
ax1.set_xlim(-0.1, 200.1)
ax1.tick_params('y', colors='b')
ax2 = ax1.twinx()
ax2.plot(runner.mon.ts, current, 'r', label='Input')
ax2.set_xlabel('Time (ms)')
ax2.set_ylabel('Input (mV)', color='r')
ax2.set_ylim(0, 50)
ax2.tick_params('y', colors='r')
ax1.legend(loc=1)
ax2.legend(loc=3)
fig.tight_layout()
plt.show()
_images/neurons_Izhikevich_2003_Izhikevich_model_55_1.png

Thalamo-cortical (TC)

[22]:
neu = bp.neurons.Izhikevich(1)
neu.a, neu.b, neu.c, neu.d = 0.02, 0.25, -65, 0.05

current = bp.inputs.section_input(values=[0., 10.], durations=[50, 100])
runner = bp.DSRunner(neu, inputs=['input', current, 'iter'], monitors=['V', 'u'])
runner.run(duration=150.)

fig, ax1 = plt.subplots(figsize=(10, 5))
plt.title('Thalamo-cortical')
ax1.plot(runner.mon.ts, runner.mon.V[:, 0], 'b', label='V')
ax1.set_xlabel('Time (ms)')
ax1.set_ylabel('Membrane potential (mV)', color='b')
ax1.set_xlim(-0.1, 150.1)
ax1.tick_params('y', colors='b')
ax2 = ax1.twinx()
ax2.plot(runner.mon.ts, current, 'r', label='Input')
ax2.set_xlabel('Time (ms)')
ax2.set_ylabel('Input (mV)', color='r')
ax2.set_ylim(0, 50)
ax2.tick_params('y', colors='r')
ax1.legend(loc=1)
ax2.legend(loc=3)
fig.tight_layout()
plt.show()
_images/neurons_Izhikevich_2003_Izhikevich_model_57_1.png

Resonator (RZ)

[23]:
neu = bp.neurons.Izhikevich(1)
neu.a, neu.b, neu.c, neu.d = 0.1, 0.26, -65, 2

current = bp.inputs.section_input(values=[0., 5.], durations=[100, 300])
runner = bp.DSRunner(neu, inputs=['input', current, 'iter'], monitors=['V', 'u'])
runner.run(duration=400.)

fig, ax1 = plt.subplots(figsize=(10, 5))
plt.title('Resonator')
ax1.plot(runner.mon.ts, runner.mon.V[:, 0], 'b', label='V')
ax1.set_xlabel('Time (ms)')
ax1.set_ylabel('Membrane potential (mV)', color='b')
ax1.set_xlim(-0.1, 400.1)
ax1.tick_params('y', colors='b')
ax2 = ax1.twinx()
ax2.plot(runner.mon.ts, current, 'r', label='Input')
ax2.set_xlabel('Time (ms)')
ax2.set_ylabel('Input (mV)', color='r')
ax2.set_ylim(0, 50)
ax2.tick_params('y', colors='r')
ax1.legend(loc=1)
ax2.legend(loc=3)
fig.tight_layout()
plt.show()
_images/neurons_Izhikevich_2003_Izhikevich_model_59_1.png

Low-threshold Spiking (LTS)

[24]:
neu = bp.neurons.Izhikevich(1)
neu.a, neu.b, neu.c, neu.d = 0.02, 0.25, -65, 2

current = bp.inputs.section_input(values=[0., 10.], durations=[50, 150])
runner = bp.DSRunner(neu, inputs=['input', current, 'iter'], monitors=['V', 'u'])
runner.run(duration=200.)

fig, ax1 = plt.subplots(figsize=(10, 5))
plt.title('Low-threshold Spiking')
ax1.plot(runner.mon.ts, runner.mon.V[:, 0], 'b', label='V')
ax1.set_xlabel('Time (ms)')
ax1.set_ylabel('Membrane potential (mV)', color='b')
ax1.set_xlim(-0.1, 200.1)
ax1.tick_params('y', colors='b')
ax2 = ax1.twinx()
ax2.plot(runner.mon.ts, current, 'r', label='Input')
ax2.set_xlabel('Time (ms)')
ax2.set_ylabel('Input (mV)', color='r')
ax2.set_ylim(0, 50)
ax2.tick_params('y', colors='r')
ax1.legend(loc=1)
ax2.legend(loc=3)
fig.tight_layout()

plt.show()
_images/neurons_Izhikevich_2003_Izhikevich_model_61_1.png

References

[1] Izhikevich, E. M. (2004). Which model to use for cortical spiking neurons? IEEE Transactions on Neural Networks, 15(5): 1063-1070.

[2] Izhikevich, E. M. (2003). Simple model of spiking neurons. IEEE Transactions on Neural Networks, 14(6): 1569-1572.

(Brette, Romain. 2004) LIF phase locking

Implementation of the paper:

  • Brette, Romain. “Dynamics of one-dimensional spiking neuron models.” Journal of mathematical biology 48.1 (2004): 38-56.

Author:
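The model simulated below (read directly from the derivative method of the LIF class) is a leaky integrate-and-fire neuron driven by a constant input plus a sinusoidal modulation:

\[\tau \frac{dV}{dt} = -V + I + 2\sin\left(\frac{2\pi t}{\tau}\right), \qquad V \geq V_{th} \Rightarrow V \leftarrow V_r,\]

where each of the 2000 neurons receives a different constant drive \(I\) (linearly spaced between 2 and 4). The final scatter plot shows, for each value of \(I\), the phase of the periodic input at which spikes occur, i.e. how the neuron phase-locks to the drive.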

[1]:
import brainpy as bp
import brainpy.math as bm
[2]:
import matplotlib.pyplot as plt
[3]:
# set parameters
num = 2000
tau = 100.  # ms
Vth = 1.  # mV
Vr = 0.  # mV
inputs = bm.linspace(2., 4., num)
[4]:
class LIF(bp.NeuGroup):
  def __init__(self, size, **kwargs):
    super(LIF, self).__init__(size, **kwargs)

    self.V = bm.Variable(bm.zeros(size))
    self.spike = bm.Variable(bm.zeros(size, dtype=bool))
    self.integral = bp.odeint(self.derivative)

  def derivative(self, V, t):
    return (-V + inputs + 2 * bm.sin(2 * bm.pi * t / tau)) / tau

  def update(self, tdi):
    V = self.integral(self.V, tdi.t, tdi.dt)
    self.spike.value = V >= Vth
    self.V.value = bm.where(self.spike > 0., Vr, V)
[5]:
group = LIF(num)
runner = bp.DSRunner(group, monitors=['spike'])
[6]:
t = runner.run(duration=5 * 1000.)

indices, times = bp.measure.raster_plot(runner.mon.spike, runner.mon.ts)

# plt.plot((times % tau) / tau, inputs[indices], ',')

spike_phases = (times % tau) / tau
params = inputs[indices]
plt.scatter(x=spike_phases, y=params, c=spike_phases,
            marker=',', s=0.1, cmap="coolwarm")

plt.xlabel('Spike phase')
plt.ylabel('Parameter (input)')
plt.show()
_images/neurons_Romain_2004_LIF_phase_locking_7_1.png

(Gerstner, 2005): Adaptive Exponential Integrate-and-Fire model

The Adaptive Exponential Integrate-and-Fire (AdEx) neuron model is a spiking neuron model that describes single-neuron behavior and can generate many kinds of firing patterns by tuning its parameters.
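For reference, the standard AdEx equations, which bp.neurons.AdExIF implements up to notation (see Brette & Gerstner, 2005), are

\[\begin{split}\begin{aligned} \tau \frac{d V}{d t} &= -(V - V_{rest}) + \Delta_T \exp\left(\frac{V - V_T}{\Delta_T}\right) - R w + R I, \\ \tau_w \frac{d w}{d t} &= a (V - V_{rest}) - w, \end{aligned}\end{split}\]

with the reset \(V \leftarrow V_{reset}\), \(w \leftarrow w + b\) whenever \(V\) crosses \(V_{th}\). The parameters a, b, R, delta_T, tau, tau_w, V_reset, V_rest, V_th and V_T passed to the cells below map onto this form; varying a, b, tau_w and V_reset is what switches the model between the firing patterns shown.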

[1]:
import brainpy as bp

bp.math.set_dt(0.01)
bp.math.enable_x64()
bp.math.set_platform('cpu')

Tonic

[2]:
group = bp.neurons.AdExIF(size=1, a=0., b=60., R=.5, delta_T=2., tau=20., tau_w=30.,
                          V_reset=-55., V_rest=-70, V_th=-30, V_T=-50)

runner = bp.DSRunner(group, monitors=['V', 'w'], inputs=('input', 65.))
runner.run(500.)
fig, gs = bp.visualize.get_figure(2, 1, 3, 8)
fig.add_subplot(gs[0, 0])
bp.visualize.line_plot(runner.mon.ts, runner.mon.V, ylabel='V', title='tonic')
fig.add_subplot(gs[1, 0])
bp.visualize.line_plot(runner.mon.ts, runner.mon.w, ylabel='w', show=True)
_images/neurons_Gerstner_2005_AdExIF_model_4_1.png

Adapting

[3]:
group = bp.neurons.AdExIF(size=1, a=0., b=5., R=.5, tau=20., tau_w=100., delta_T=2.,
                          V_reset=-55., V_rest=-70, V_th=-30, V_T=-50)

runner = bp.DSRunner(group, monitors=['V', 'w'], inputs=('input', 65.))
runner.run(200.)
fig, gs = bp.visualize.get_figure(2, 1, 3, 8)
fig.add_subplot(gs[0, 0])
bp.visualize.line_plot(runner.mon.ts, runner.mon.V, ylabel='V', title='adapting')
fig.add_subplot(gs[1, 0])
bp.visualize.line_plot(runner.mon.ts, runner.mon.w, ylabel='w', show=True)
_images/neurons_Gerstner_2005_AdExIF_model_6_1.png

Init Bursting

[4]:
group = bp.neurons.AdExIF(size=1, a=.5, b=7., R=.5, tau=5., tau_w=100., delta_T=2.,
                          V_reset=-51, V_rest=-70, V_th=-30, V_T=-50)

runner = bp.DSRunner(group, monitors=['V', 'w'], inputs=('input', 65.))
runner.run(300.)
fig, gs = bp.visualize.get_figure(2, 1, 3, 8)
fig.add_subplot(gs[0, 0])
bp.visualize.line_plot(runner.mon.ts, runner.mon.V, ylabel='V', ylim=(-55., -35.), title='init_bursting')
fig.add_subplot(gs[1, 0])
bp.visualize.line_plot(runner.mon.ts, runner.mon.w, ylabel='w', show=True)
_images/neurons_Gerstner_2005_AdExIF_model_8_1.png

Bursting

[5]:
group = bp.neurons.AdExIF(size=1, a=-0.5, b=7., R=.5, delta_T=2., tau=5, tau_w=100,
                          V_reset=-46, V_rest=-70, V_th=-30, V_T=-50)

runner = bp.DSRunner(group, monitors=['V', 'w'], inputs=('input', 65.))
runner.run(500.)
fig, gs = bp.visualize.get_figure(2, 1, 3, 8)
fig.add_subplot(gs[0, 0])
bp.visualize.line_plot(runner.mon.ts, runner.mon.V, ylabel='V', ylim=(-60., -35.), title='bursting')
fig.add_subplot(gs[1, 0])
bp.visualize.line_plot(runner.mon.ts, runner.mon.w, ylabel='w', show=True)
_images/neurons_Gerstner_2005_AdExIF_model_10_1.png

Transient

[6]:
group = bp.neurons.AdExIF(size=1, a=1., b=10., R=.5, tau=10, tau_w=100, delta_T=2.,
                          V_reset=-60, V_rest=-70, V_th=-30, V_T=-50)

runner = bp.DSRunner(group, monitors=['V', 'w'], inputs=('input', 55.))
runner.run(500.)
fig, gs = bp.visualize.get_figure(2, 1, 3, 8)
fig.add_subplot(gs[0, 0])
bp.visualize.line_plot(runner.mon.ts, runner.mon.V, ylabel='V', ylim=(-60., -35.), title='transient')
fig.add_subplot(gs[1, 0])
bp.visualize.line_plot(runner.mon.ts, runner.mon.w, ylabel='w', show=True)
_images/neurons_Gerstner_2005_AdExIF_model_12_1.png

Delayed

[7]:
group = bp.neurons.AdExIF(size=1, a=-1., b=10., R=.5, delta_T=2., tau=5., tau_w=100.,
                          V_reset=-60, V_rest=-70, V_th=-30, V_T=-50)

runner = bp.DSRunner(group, monitors=['V', 'w'], inputs=('input', 20.))
runner.run(500.)
fig, gs = bp.visualize.get_figure(2, 1, 3, 8)
fig.add_subplot(gs[0, 0])
bp.visualize.line_plot(runner.mon.ts, runner.mon.V, ylabel='V', ylim=(-60., -35.), title='delayed')
fig.add_subplot(gs[1, 0])
bp.visualize.line_plot(runner.mon.ts, runner.mon.w, ylabel='w', show=True)
_images/neurons_Gerstner_2005_AdExIF_model_14_1.png

(Niebur, et al., 2009) Generalized integrate-and-fire model

Implementation of the paper: Mihalaş, Ştefan, and Ernst Niebur. “A generalized linear integrate-and-fire neural model produces diverse spiking behaviors.” Neural computation 21.3 (2009): 704-718.

[1]:
import matplotlib.pyplot as plt
import brainpy as bp

Model Overview

The generalized integrate-and-fire model is a spiking neuron model that describes single-neuron behavior and can generate most kinds of firing patterns by tuning its parameters.

The generalized IF model originates from the leaky integrate-and-fire (LIF) model, but differs from it in that it includes internal currents \(I_j\) in its expressions.

\[\frac{d I_j}{d t} = -k_j I_j\]
\[\tau\frac{d V}{d t} = - (V - V_{rest}) + R\sum_{j}I_j + RI\]
\[\frac{d V_{th}}{d t} = a(V - V_{rest}) - b(V_{th} - V_{th\infty})\]

The generalized IF neuron fires when \(V\) reaches \(V_{th}\), which triggers the resets:

\[I_j \leftarrow R_j I_j + A_j\]
\[V \leftarrow V_{reset}\]
\[V_{th} \leftarrow max(V_{th_{reset}}, V_{th})\]
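As a sketch of how these equations and reset rules fit together, a single forward-Euler step with two internal currents could look like the following (plain Python; the parameter values are placeholders for illustration and are not necessarily the BrainPy defaults):

def gif_step(V, V_th, I1, I2, I_ext, dt=0.1,
             k1=0.2, k2=0.02,              # decay rates of the internal currents
             a=0.005, b=0.01,              # threshold dynamics
             V_rest=-70., V_th_inf=-50., V_reset=-70., V_th_reset=-60.,
             R=20., tau=20.,
             R1=0., R2=1., A1=0., A2=0.):  # after-spike rules for I1, I2
    # internal currents decay exponentially
    I1 += dt * (-k1 * I1)
    I2 += dt * (-k2 * I2)
    # membrane potential integrates internal and external currents
    V += dt * (-(V - V_rest) + R * (I1 + I2) + R * I_ext) / tau
    # threshold tracks V and relaxes towards V_th_inf
    V_th += dt * (a * (V - V_rest) - b * (V_th - V_th_inf))
    if V >= V_th:                          # spike: apply the reset rules above
        I1, I2 = R1 * I1 + A1, R2 * I2 + A2
        V = V_reset
        V_th = max(V_th_reset, V_th)
    return V, V_th, I1, I2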

Different firing patterns

The arbitrary number of internal currents \(I_j\), which can be seen as currents caused by ion-channel dynamics, gives the GeneralizedIF model the flexibility to generate various firing patterns.

With appropriate parameters, we can reproduce most of the single-neuron firing patterns. In the original paper (Mihalaş et al., 2009), the authors used two internal currents \(I_1\) and \(I_2\).

[2]:
def run(model, duration, I_ext):
  runner = bp.DSRunner(model,
                       inputs=('input', I_ext, 'iter'),
                       monitors=['V', 'V_th'])
  runner.run(duration)

  ts = runner.mon.ts
  fig, gs = bp.visualize.get_figure(1, 1, 4, 8)
  ax1 = fig.add_subplot(gs[0, 0])
  #ax1.title.set_text(f'{mode}')

  ax1.plot(ts, runner.mon.V[:, 0], label='V')
  ax1.plot(ts, runner.mon.V_th[:, 0], label='V_th')
  ax1.set_xlabel('Time (ms)')
  ax1.set_ylabel('Membrane potential')
  ax1.set_xlim(-0.1, ts[-1] + 0.1)
  plt.legend()

  ax2 = ax1.twinx()
  ax2.plot(ts, I_ext, color='turquoise', label='input')
  ax2.set_xlabel('Time (ms)')
  ax2.set_ylabel('External input')
  ax2.set_xlim(-0.1, ts[-1] + 0.1)
  ax2.set_ylim(-5., 20.)
  plt.legend(loc='lower left')
  plt.show()

Below, we simulate generalized IF neuron groups to reproduce 20 different spiking patterns. Each simulation is labeled with the corresponding pattern name in the heading above it.

Tonic Spiking

[3]:
Iext, duration = bp.inputs.constant_input([(1.5, 200.)])
neu = bp.neurons.GIF(1)
run(neu, duration, Iext)
_images/neurons_Niebur_2009_GIF_13_2.png

Class 1 Excitability

[4]:
Iext, duration = bp.inputs.constant_input([(1. + 1e-6, 500.)])
neu = bp.neurons.GIF(1)
run(neu, duration, Iext)
_images/neurons_Niebur_2009_GIF_15_1.png

Spike Frequency Adaptation

[5]:
Iext, duration = bp.inputs.constant_input([(2., 200.)])
neu = bp.neurons.GIF(1, a=0.005)
run(neu, duration, Iext)
_images/neurons_Niebur_2009_GIF_17_1.png

Phasic Spiking

[6]:
Iext, duration = bp.inputs.constant_input([(1.5, 500.)])
neu = bp.neurons.GIF(1, a=0.005)
run(neu, duration, Iext)
_images/neurons_Niebur_2009_GIF_19_1.png

Accommodation

[7]:
Iext, duration = bp.inputs.constant_input([(1.5, 100.),
                                           (0, 500.),
                                           (0.5, 100.),
                                           (1., 100.),
                                           (1.5, 100.),
                                           (0., 100.)])
neu = bp.neurons.GIF(1, a=0.005)
run(neu, duration, Iext)
_images/neurons_Niebur_2009_GIF_21_1.png

Threshold Variability

[8]:
Iext, duration = bp.inputs.constant_input([(1.5, 20.),
                                           (0., 180.),
                                           (-1.5, 20.),
                                           (0., 20.),
                                           (1.5, 20.),
                                           (0., 140.)])
neu = bp.neurons.GIF(1, a=0.005)
run(neu, duration, Iext)
_images/neurons_Niebur_2009_GIF_23_1.png

Rebound Spiking

[9]:
Iext, duration = bp.inputs.constant_input([(0, 50.), (-3.5, 750.), (0., 200.)])
neu = bp.neurons.GIF(1, a=0.005)
run(neu, duration, Iext)
_images/neurons_Niebur_2009_GIF_25_1.png

Class 2 Excitability

[10]:
Iext, duration = bp.inputs.constant_input([(2 * (1. + 1e-6), 200.)])
neu = bp.neurons.GIF(1, a=0.005)
neu.V_th[:] = -30.
run(neu, duration, Iext)
_images/neurons_Niebur_2009_GIF_27_1.png

Integrator

[11]:
Iext, duration = bp.inputs.constant_input([(1.5, 20.),
                                           (0., 10.),
                                           (1.5, 20.),
                                           (0., 250.),
                                           (1.5, 20.),
                                           (0., 30.),
                                           (1.5, 20.),
                                           (0., 30.)])
neu = bp.neurons.GIF(1, a=0.005)
run(neu, duration, Iext)
_images/neurons_Niebur_2009_GIF_29_1.png

Input Bistability

[12]:
Iext, duration = bp.inputs.constant_input([(1.5, 100.),
                                           (1.7, 400.),
                                           (1.5, 100.),
                                           (1.7, 400.)])
neu = bp.neurons.GIF(1, a=0.005)
run(neu, duration, Iext)
_images/neurons_Niebur_2009_GIF_31_1.png

Hyperpolarization-induced Spiking

[13]:
Iext, duration = bp.inputs.constant_input([(-1., 400.)])
neu = bp.neurons.GIF(1, V_th_reset=-60., V_th_inf=-120.)
neu.V_th[:] = -50.
run(neu, duration, Iext)
_images/neurons_Niebur_2009_GIF_33_1.png

Hyperpolarization-induced Bursting

[14]:
Iext, duration = bp.inputs.constant_input([(-1., 400.)])
neu = bp.neurons.GIF(1, V_th_reset=-60., V_th_inf=-120., A1=10.,
                 A2=-0.6)
neu.V_th[:] = -50.
run(neu, duration, Iext)
_images/neurons_Niebur_2009_GIF_35_1.png

Tonic Bursting

[15]:
Iext, duration = bp.inputs.constant_input([(2., 500.)])
neu = bp.neurons.GIF(1, a=0.005, A1=10., A2=-0.6)
run(neu, duration, Iext)
_images/neurons_Niebur_2009_GIF_37_1.png

Phasic Bursting

[16]:
Iext, duration = bp.inputs.constant_input([(1.5, 500.)])
neu = bp.neurons.GIF(1, a=0.005, A1=10., A2=-0.6)
run(neu, duration, Iext)
_images/neurons_Niebur_2009_GIF_39_1.png

Rebound Bursting

[17]:
Iext, duration = bp.inputs.constant_input([(0, 100.), (-3.5, 500.), (0., 400.)])
neu = bp.neurons.GIF(1, a=0.005, A1=10., A2=-0.6)
run(neu, duration, Iext)
_images/neurons_Niebur_2009_GIF_41_1.png

Mixed Mode

[18]:
Iext, duration = bp.inputs.constant_input([(2., 500.)])
neu = bp.neurons.GIF(1, a=0.005, A1=5., A2=-0.3)
run(neu, duration, Iext)
_images/neurons_Niebur_2009_GIF_43_1.png

Afterpotentials

[19]:
Iext, duration = bp.inputs.constant_input([(2., 15.), (0, 185.)])
neu = bp.neurons.GIF(1, a=0.005, A1=5., A2=-0.3)
run(neu, duration, Iext)
_images/neurons_Niebur_2009_GIF_45_1.png

Basal Bistability

[20]:
Iext, duration = bp.inputs.constant_input([(5., 10.), (0., 90.), (5., 10.), (0., 90.)])
neu = bp.neurons.GIF(1, A1=8., A2=-0.1)
run(neu, duration, Iext)
_images/neurons_Niebur_2009_GIF_47_1.png

Preferred Frequency

[21]:
Iext, duration = bp.inputs.constant_input([(5., 10.),
                                           (0., 10.),
                                           (4., 10.),
                                           (0., 370.),
                                           (5., 10.),
                                           (0., 90.),
                                           (4., 10.),
                                           (0., 290.)])
neu = bp.neurons.GIF(1, a=0.005, A1=-3., A2=0.5)
run(neu, duration, Iext)
_images/neurons_Niebur_2009_GIF_49_1.png

Spike Latency

[22]:
Iext, duration = bp.inputs.constant_input([(8., 2.), (0, 48.)])
neu = bp.neurons.GIF(1, a=-0.08)
run(neu, duration, Iext)
_images/neurons_Niebur_2009_GIF_51_1.png

(Jansen & Rit, 1995): Jansen-Rit Model

The Jansen-Rit model is a neural mass model of the dynamic interactions between 3 populations:

  • pyramidal cells (PCs)

  • excitatory interneurons (EINs)

  • inhibitory interneurons (IINs)

Originally, the model has been developed to describe the waxing-and-waning of EEG activity in the alpha frequency range (8-12 Hz) in the visual cortex [1]. In the past years, however, it has been used as a generic model to describe the macroscopic electrophysiological activity within a cortical column [2].

By using the linearity of the convolution operation, the dynamic interactions between PCs, EINs and IINs can be expressed via 6 coupled ordinary differential equations, built from two operators: a rate-to-potential (synaptic) operator and a potential-to-rate (sigmoid) operator:

\[\begin{split}\begin{aligned} \dot V_{pce} &= I_{pce}, \\ \dot I_{pce} &= \frac{H_e}{\tau_e} c_4 S(c_3 V_{in}) - \frac{2 I_{pce}}{\tau_e} - \frac{V_{pce}}{\tau_e^2}, \\ \dot V_{pci} &= I_{pci}, \\ \dot I_{pci} &= \frac{H_i}{\tau_i} c_2 S(c_1 V_{in}) - \frac{2 I_{pci}}{\tau_i} - \frac{V_{pci}}{\tau_i^2}, \\ \dot V_{in} &= I_{in}, \\ \dot I_{in} &= \frac{H_e}{\tau_e} S(V_{pce} - V_{pci}) - \frac{2 I_{in}}{\tau_e} - \frac{V_{in}}{\tau_e^2}, \end{aligned}\end{split}\]

where \(V_{pce}\), \(V_{pci}\), \(V_{in}\) are used to represent the average membrane potential deflection caused by the excitatory synapses at the PC population, the inhibitory synapses at the PC population, and the excitatory synapses at both interneuron populations, respectively.
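The potential-to-rate operator \(S(\cdot)\) appearing in these equations is the sigmoid implemented by the sigmoid method of the class below:

\[S(v) = \frac{v_{max}}{1 + \exp\left(r\,(v_0 - v)\right)},\]

with maximum firing rate \(v_{max} = 5\), firing threshold \(v_0 = 6\), and sigmoid slope \(r = 0.56\) in the code.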

  • [1] B.H. Jansen & V.G. Rit (1995) Electroencephalogram and visual evoked potential generation in a mathematical model of coupled cortical columns. Biological Cybernetics, 73(4): 357-366.

  • [2] A. Spiegler, S.J. Kiebel, F.M. Atay, T.R. Knösche (2010) Bifurcation analysis of neural mass models: Impact of extrinsic inputs and dendritic time constants. NeuroImage, 52(3): 1041-1058, https://doi.org/10.1016/j.neuroimage.2009.12.081.

[1]:
import brainpy as bp
import brainpy.math as bm
[2]:
class JansenRitModel(bp.DynamicalSystem):
  def __init__(self, num, C=135., method='exp_auto'):
    super(JansenRitModel, self).__init__()

    self.num = num

    # parameters #
    self.v_max = 5.  # maximum firing rate
    self.v0 = 6.  # firing threshold
    self.r = 0.56  # slope of the sigmoid
    # other parameters
    self.A = 3.25
    self.B = 22.
    self.a = 100.
    self.tau_e = 0.01  # second
    self.tau_i = 0.02  # second
    self.b = 50.
    self.e0 = 2.5
    # The connectivity constants
    self.C1 = C
    self.C2 = 0.8 * C
    self.C3 = 0.25 * C
    self.C4 = 0.25 * C

    # variables #
    # y0, y1 and y2 are the post-synaptic potential outputs associated with the
    # pyramidal, excitatory and inhibitory populations; y3, y4 and y5 are their
    # time derivatives.
    self.y0 = bm.Variable(bm.zeros(self.num))
    self.y1 = bm.Variable(bm.zeros(self.num))
    self.y2 = bm.Variable(bm.zeros(self.num))
    self.y3 = bm.Variable(bm.zeros(self.num))
    self.y4 = bm.Variable(bm.zeros(self.num))
    self.y5 = bm.Variable(bm.zeros(self.num))
    self.p = bm.Variable(bm.ones(self.num) * 220.)

    # integral function
    self.derivative = bp.JointEq([self.dy0, self.dy1, self.dy2, self.dy3, self.dy4, self.dy5])
    self.integral = bp.odeint(self.derivative, method=method)

  def sigmoid(self, x):
    return self.v_max / (1. + bm.exp(self.r * (self.v0 - x)))

  def dy0(self, y0, t, y3): return y3

  def dy1(self, y1, t, y4): return y4

  def dy2(self, y2, t, y5): return y5

  def dy3(self, y3, t, y0, y1, y2):
    return (self.A * self.sigmoid(y1 - y2) - 2 * y3 - y0 / self.tau_e) / self.tau_e

  def dy4(self, y4, t, y0, y1, p):
    return (self.A * (p + self.C2 * self.sigmoid(self.C1 * y0)) - 2 * y4 - y1 / self.tau_e) / self.tau_e

  def dy5(self, y5, t, y0, y2):
    return (self.B * self.C4 * self.sigmoid(self.C3 * y0) - 2 * y5 - y2 / self.tau_i) / self.tau_i

  def update(self, tdi):
    self.y0.value, self.y1.value, self.y2.value, self.y3.value, self.y4.value, self.y5.value = \
      self.integral(self.y0, self.y1, self.y2, self.y3, self.y4, self.y5, tdi.t, p=self.p, dt=tdi.dt)
[3]:
def simulation(duration=5.):
  dt = 0.1 / 1e3
  # random input uniformly distributed between 120 and 320 pulses per second
  all_ps = bm.random.uniform(120, 320, size=(int(duration / dt), 1))
  jrm = JansenRitModel(num=6, C=bm.array([68., 128., 135., 270., 675., 1350.]))
  runner = bp.DSRunner(jrm,
                       monitors=['y0', 'y1', 'y2', 'y3', 'y4', 'y5'],
                       inputs=['p', all_ps, 'iter', '='],
                       dt=dt)
  runner.run(duration)

  start, end = int(2 / dt), int(duration / dt)
  fig, gs = bp.visualize.get_figure(6, 3, 2, 3)
  for i in range(6):
    fig.add_subplot(gs[i, 0])
    title = 'E' if i == 0 else None
    xlabel = 'time [s]' if i == 5 else None
    bp.visualize.line_plot(runner.mon.ts[start: end], runner.mon.y1[start: end, i],
                           title=title, xlabel=xlabel, ylabel='Hz')
    fig.add_subplot(gs[i, 1])
    title = 'P' if i == 0 else None
    bp.visualize.line_plot(runner.mon.ts[start: end], runner.mon.y0[start: end, i],
                           title=title, xlabel=xlabel)
    fig.add_subplot(gs[i, 2])
    title = 'I' if i == 0 else None
    bp.visualize.line_plot(runner.mon.ts[start: end], runner.mon.y2[start: end, i],
                           title=title, show=i==5, xlabel=xlabel)
[4]:
simulation()
_images/neurons_JR_1995_jansen_rit_model_5_2.png

(Teka, et al., 2018): Fractional-order Izhikevich neuron model

Implementation of the model:

  • Teka, Wondimu W., Ranjit Kumar Upadhyay, and Argha Mondal. “Spiking and bursting patterns of fractional-order Izhikevich model.” Communications in Nonlinear Science and Numerical Simulation 56 (2018): 161-176.
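Following the reference above, the fractional-order model replaces the first-order time derivatives of the classical Izhikevich equations with Caputo fractional derivatives of order \(\alpha \leq 1\) (shown here only schematically; see the paper for the precise definition of the fractional operator):

\[\begin{split}\begin{array}{l} \frac{d^{\alpha} v}{d t^{\alpha}}=0.04 v^{2}+5 v+140-u+I\\ \frac{d^{\alpha} u}{d t^{\alpha}}=a(b v-u) \end{array}\end{split}\]

with the same after-spike reset as the integer-order model. \(\alpha = 1\) recovers the classical dynamics, while smaller \(\alpha\) introduces memory into the dynamics, which is why bp.neurons.FractionalIzhikevich requires a num_memory length in the cells below.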

[1]:
import brainpy as bp

import matplotlib.pyplot as plt
[2]:
def run_model(dt=0.1, duration=500, alpha=1.0):
    inputs, length = bp.inputs.section_input([0, 10], [50, duration],
                                              dt=dt, return_length=True)
    neuron = bp.neurons.FractionalIzhikevich(1, num_memory=int(length / dt), alpha=alpha)
    runner = bp.DSRunner(neuron,
                         monitors=['V'],
                         inputs=['input', inputs, 'iter'],
                         dt=dt)
    runner.run(length)

    plt.plot(runner.mon.ts, runner.mon.V.flatten())
    plt.xlabel('Time [ms]')
    plt.ylabel('Potential [mV]')
    plt.title(r'$\alpha$=' + str(alpha))
    plt.show()

Regular spiking

[3]:
run_model(dt=0.1, duration=500, alpha=1.0)
_images/neurons_2018_Fractional_Izhikevich_model_5_2.png

Intrinsically bursting

[4]:
run_model(dt=0.1, duration=500, alpha=0.87)
_images/neurons_2018_Fractional_Izhikevich_model_7_1.png

Mixed Mode (Irregular)

[5]:
run_model(dt=0.1, duration=500, alpha=0.86)
_images/neurons_2018_Fractional_Izhikevich_model_9_1.png

Chattering

[6]:
run_model(dt=0.1, duration=500, alpha=0.8)
_images/neurons_2018_Fractional_Izhikevich_model_11_1.png

Bursting

[7]:
run_model(dt=0.1, duration=1000, alpha=0.7)
_images/neurons_2018_Fractional_Izhikevich_model_13_1.png

Bursting with longer bursts

[8]:
run_model(dt=0.1, duration=1000, alpha=0.5)
_images/neurons_2018_Fractional_Izhikevich_model_15_1.png

Fast spiking

[9]:
run_model(dt=0.1, duration=1000, alpha=0.3)
_images/neurons_2018_Fractional_Izhikevich_model_17_1.png

(Mondal, et al., 2019): Fractional-order FitzHugh-Rinzel bursting neuron model

Implementation of the paper:

  • Mondal, A., Sharma, S.K., Upadhyay, R.K. et al. Firing activities of a fractional-order FitzHugh-Rinzel bursting neuron model and its coupled dynamics. Sci Rep 9, 15721 (2019). https://doi.org/10.1038/s41598-019-52061-4
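As a summary of the model in the reference above, the FitzHugh-Rinzel system extends the FitzHugh-Nagumo model with a slow variable \(y\), and the fractional-order version applies a Caputo derivative of order \(\alpha\) to each equation (schematically):

\[\begin{split}\begin{aligned} \frac{d^{\alpha} v}{d t^{\alpha}} &= v - \frac{v^{3}}{3} - w + y + I_{ext}, \\ \frac{d^{\alpha} w}{d t^{\alpha}} &= \delta\,(a + v - b\,w), \\ \frac{d^{\alpha} y}{d t^{\alpha}} &= \mu\,(c - v - d\,y), \end{aligned}\end{split}\]

where the small parameter \(\mu\) makes \(y\) slow and produces the bursting envelope. The parameter sets below vary \(c\), \(\mu\), and \(\alpha\).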

[1]:
import brainpy as bp

import matplotlib.pyplot as plt
[2]:
def run_model(model, inputs, length):
    runner = bp.DSRunner(model, monitors=['V'], inputs=inputs)
    runner.run(length)
    return runner

Parameter set 1

[3]:
dt = 0.1
Iext = 0.3125
duration = 4000
neuron_pars = dict(a=0.7, b=0.8, c=-0.775, d=1., delta=0.08, mu=0.0001,
                   w_initializer=bp.init.Constant(-0.1),
                   y_initializer=bp.init.Constant(0.1))
[4]:
alphas = [1.0, 0.98, 0.95]

plt.figure(figsize=(9, 10))
for i, alpha in enumerate(alphas):
    neuron = bp.neurons.FractionalFHR(1,
                                      alpha=alpha,
                                      num_memory=4000,
                                      **neuron_pars)
    runner = run_model(neuron, inputs=['input', Iext], length=duration)

    plt.subplot(len(alphas), 1, i+1)
    plt.plot(runner.mon.ts, runner.mon.V[:, 0])
    plt.title(r'$\alpha$=' + str(alphas[i]))
    plt.ylabel('V')
plt.xlabel('Time [ms]')
plt.show()
_images/neurons_2019_Fractional_order_FHR_model_6_4.png

Parameter set 2

[5]:
Iext = 0.4
duration = 3500
neuron_pars = dict(a=0.7, b=0.8, c=-0.775, d=1., delta=0.08, mu=0.0001)
[6]:
alphas = [1.0, 0.92, 0.85, 0.68]

plt.figure(figsize=(12, 10))
for i, alpha in enumerate(alphas):
    neuron = bp.neurons.FractionalFHR(1,
                                      alpha=alpha,
                                      num_memory=4000,
                                      **neuron_pars)
    runner = run_model(neuron, inputs=['input', Iext], length=duration)

    plt.subplot(len(alphas), 1, i+1)
    plt.plot(runner.mon.ts, runner.mon.V[:, 0])
    plt.title(r'$\alpha$=' + str(alphas[i]))
    plt.ylabel('V')
plt.xlabel('Time [ms]')
plt.show()
_images/neurons_2019_Fractional_order_FHR_model_9_4.png

Parameter set 3

[7]:
Iext = 3
duration = 500
neuron_pars = dict(a=0.7, b=0.8, c=-0.775, d=1., delta=0.08, mu=0.18)
[8]:
alphas = [1.0, 0.99, 0.97, 0.95]

plt.figure(figsize=(12, 10))
for i, alpha in enumerate(alphas):
    neuron = bp.neurons.FractionalFHR(1,
                                      alpha=alpha,
                                      **neuron_pars)
    runner = run_model(neuron, inputs=['input', Iext], length=duration)

    plt.subplot(len(alphas), 1, i+1)
    plt.plot(runner.mon.ts, runner.mon.V[:, 0])
    plt.title(r'$\alpha$=' + str(alphas[i]))
    plt.ylabel('V')
plt.xlabel('Time [ms]')
plt.show()
_images/neurons_2019_Fractional_order_FHR_model_12_4.png

Parameter set 4

[9]:
Iext = 0.3125
duration = 3500
neuron_pars = dict(a=0.7, b=0.8, c=1.3, d=1., delta=0.08, mu=0.0001)
[10]:
alphas = [1.0, 0.85, 0.80]

plt.figure(figsize=(9, 10))
for i, alpha in enumerate(alphas):
    neuron = bp.neurons.FractionalFHR(1,
                                      alpha=alpha,
                                      num_memory=3000,
                                      **neuron_pars)
    runner = run_model(neuron, inputs=['input', Iext], length=duration)

    plt.subplot(len(alphas), 1, i+1)
    plt.plot(runner.mon.ts, runner.mon.V[:, 0])
    plt.title(r'$\alpha$=' + str(alphas[i]))
    plt.ylabel('V')
plt.xlabel('Time [ms]')
plt.show()
_images/neurons_2019_Fractional_order_FHR_model_15_3.png

Parameter set 5

[11]:
Iext = 0.3125
duration = 2500
neuron_pars = dict(a=0.7, b=0.8, c=-0.908, d=1., delta=0.08, mu=0.002)
[12]:
alphas = [1.0, 0.98, 0.95]

plt.figure(figsize=(9, 10))
for i, alpha in enumerate(alphas):
    neuron = bp.neurons.FractionalFHR(1, alpha=alpha, **neuron_pars)
    runner = run_model(neuron, inputs=['input', Iext], length=duration)

    plt.subplot(len(alphas), 1, i+1)
    plt.plot(runner.mon.ts, runner.mon.V[:, 0])
    plt.title(r'$\alpha$=' + str(alphas[i]))
    plt.ylabel('V')
plt.xlabel('Time [ms]')
plt.show()
_images/neurons_2019_Fractional_order_FHR_model_18_3.png

(Si Wu, 2008): Continuous-attractor Neural Network 1D

Here we show the implementation of the paper:

  • Si Wu, Kosuke Hamaguchi, and Shun-ichi Amari. “Dynamics and computation of continuous attractors.” Neural computation 20.4 (2008): 994-1025.

Author:

The mathematical equations of the Continuous-Attractor Neural Network (CANN) are given by:

\[\tau \frac{du(x,t)}{dt} = -u(x,t) + \rho \int dx' J(x,x') r(x',t)+I_{ext}\]
\[r(x,t) = \frac{u(x,t)^2}{1 + k \rho \int dx' u(x',t)^2}\]
\[J(x,x') = \frac{1}{\sqrt{2\pi}a}\exp(-\frac{|x-x'|^2}{2a^2})\]
\[I_{ext} = A\exp\left[-\frac{|x-z(t)|^2}{4a^2}\right]\]
[7]:
import brainpy as bp
import brainpy.math as bm
[8]:
class CANN1D(bp.dyn.NeuGroup):
  def __init__(self, num, tau=1., k=8.1, a=0.5, A=10., J0=4.,
               z_min=-bm.pi, z_max=bm.pi, **kwargs):
    super(CANN1D, self).__init__(size=num, **kwargs)

    # parameters
    self.tau = tau  # The synaptic time constant
    self.k = k  # Degree of the rescaled inhibition
    self.a = a  # Half-width of the range of excitatory connections
    self.A = A  # Magnitude of the external input
    self.J0 = J0  # maximum connection value

    # feature space
    self.z_min = z_min
    self.z_max = z_max
    self.z_range = z_max - z_min
    self.x = bm.linspace(z_min, z_max, num)  # The encoded feature values
    self.rho = num / self.z_range  # The neural density
    self.dx = self.z_range / num  # The stimulus density

    # variables
    self.u = bm.Variable(bm.zeros(num))
    self.input = bm.Variable(bm.zeros(num))

    # The connection matrix
    self.conn_mat = self.make_conn(self.x)

    # function
    self.integral = bp.odeint(self.derivative)

  def derivative(self, u, t, Iext):
    r1 = bm.square(u)
    r2 = 1.0 + self.k * bm.sum(r1)
    r = r1 / r2
    Irec = bm.dot(self.conn_mat, r)
    du = (-u + Irec + Iext) / self.tau
    return du

  def dist(self, d):
    d = bm.remainder(d, self.z_range)
    d = bm.where(d > 0.5 * self.z_range, d - self.z_range, d)
    return d

  def make_conn(self, x):
    assert bm.ndim(x) == 1
    x_left = bm.reshape(x, (-1, 1))
    x_right = bm.repeat(x.reshape((1, -1)), len(x), axis=0)
    d = self.dist(x_left - x_right)
    Jxx = self.J0 * bm.exp(-0.5 * bm.square(d / self.a)) / \
          (bm.sqrt(2 * bm.pi) * self.a)
    return Jxx

  def get_stimulus_by_pos(self, pos):
    return self.A * bm.exp(-0.25 * bm.square(self.dist(self.x - pos) / self.a))

  def update(self, tdi):
    self.u.value = self.integral(self.u, tdi.t, self.input, tdi.dt)
    self.input[:] = 0.
[9]:
cann = CANN1D(num=512, k=0.1)

Population coding

[10]:
I1 = cann.get_stimulus_by_pos(0.)
Iext, duration = bp.inputs.section_input(values=[0., I1, 0.],
                                         durations=[1., 8., 8.],
                                         return_length=True)
runner = bp.DSRunner(cann,
                     inputs=['input', Iext, 'iter'],
                     monitors=['u'])
runner.run(duration)
bp.visualize.animate_1D(
  dynamical_vars=[{'ys': runner.mon.u, 'xs': cann.x, 'legend': 'u'},
                  {'ys': Iext, 'xs': cann.x, 'legend': 'Iext'}],
  frame_step=1,
  frame_delay=100,
  show=True,
  # save_path='../../images/cann-encoding.gif'
)

image0

Template matching

The CANN can perform efficient population decoding via template matching.
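After the simulation below, the bump position can be read out from the monitored activity. The following is a minimal sketch (not part of the original notebook) of a population-vector decoder on the periodic feature space; it assumes the runner.mon.u and cann.x produced by the cell below:

import numpy as np

def decode_position(u, x):
    # circular mean of the feature values x weighted by the bump profile u
    return np.arctan2(np.sum(u * np.sin(x)), np.sum(u * np.cos(x)))

# e.g., after running: decode_position(np.asarray(runner.mon.u[-1]), np.asarray(cann.x))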

[11]:
cann.k = 8.1

dur1, dur2, dur3 = 10., 30., 0.
num1 = int(dur1 / bm.get_dt())
num2 = int(dur2 / bm.get_dt())
num3 = int(dur3 / bm.get_dt())
Iext = bm.zeros((num1 + num2 + num3,) + cann.size)
Iext[:num1] = cann.get_stimulus_by_pos(0.5)
Iext[num1:num1 + num2] = cann.get_stimulus_by_pos(0.)
Iext[num1:num1 + num2] += 0.1 * cann.A * bm.random.randn(num2, *cann.size)

runner = bp.dyn.DSRunner(cann,
                         inputs=('input', Iext, 'iter'),
                         monitors=['u'])
runner.run(dur1 + dur2 + dur3)
bp.visualize.animate_1D(
  dynamical_vars=[{'ys': runner.mon.u, 'xs': cann.x, 'legend': 'u'},
                  {'ys': Iext, 'xs': cann.x, 'legend': 'Iext'}],
  frame_step=5,
  frame_delay=50,
  show=True,
  # save_path='../../images/cann-decoding.gif'
)

image0

Smooth tracking

The CANN can track a moving stimulus.

[12]:
dur1, dur2, dur3 = 20., 20., 20.
num1 = int(dur1 / bm.get_dt())
num2 = int(dur2 / bm.get_dt())
num3 = int(dur3 / bm.get_dt())
position = bm.zeros(num1 + num2 + num3)
position[num1: num1 + num2] = bm.linspace(0., 12., num2)
position[num1 + num2:] = 12.
position = position.reshape((-1, 1))
Iext = cann.get_stimulus_by_pos(position)
runner = bp.dyn.DSRunner(cann,
                         inputs=('input', Iext, 'iter'),
                         monitors=['u'])
runner.run(dur1 + dur2 + dur3)
bp.visualize.animate_1D(
  dynamical_vars=[{'ys': runner.mon.u, 'xs': cann.x, 'legend': 'u'},
                  {'ys': Iext, 'xs': cann.x, 'legend': 'Iext'}],
  frame_step=5,
  frame_delay=50,
  show=True,
  # save_path='../../images/cann-tracking.gif'
)

image0

(Si Wu, 2008): Continuous-attractor Neural Network 2D

Implementation of the paper:

  • Si Wu, Kosuke Hamaguchi, and Shun-ichi Amari. “Dynamics and computation of continuous attractors.” Neural computation 20.4 (2008): 994-1025.

The mathematical equation of the Continuous-Attractor Neural Network (CANN) is given by:

\[\tau \frac{du(x,t)}{dt} = -u(x,t) + \rho \int dx' J(x,x') r(x',t)+I_{ext}\]
\[r(x,t) = \frac{u(x,t)^2}{1 + k \rho \int dx' u(x',t)^2}\]
\[J(x,x') = \frac{1}{\sqrt{2\pi}a}\exp(-\frac{|x-x'|^2}{2a^2})\]
\[I_{ext} = A\exp\left[-\frac{|x-z(t)|^2}{4a^2}\right]\]

The model is solved using the fast Fourier transform. It can run on CPU, GPU, and TPU devices.

[1]:
import matplotlib.pyplot as plt
import jax
[2]:
import brainpy as bp
import brainpy.math as bm
[3]:
class CANN2D(bp.dyn.NeuGroup):
  def __init__(self, length, tau=1., k=8.1, a=0.5, A=10., J0=4.,
               z_min=-bm.pi, z_max=bm.pi, name=None):
    super(CANN2D, self).__init__(size=(length, length), name=name)

    # parameters
    self.length = length
    self.tau = tau  # The synaptic time constant
    self.k = k  # Degree of the rescaled inhibition
    self.a = a  # Half-width of the range of excitatory connections
    self.A = A  # Magnitude of the external input
    self.J0 = J0  # maximum connection value

    # feature space
    self.z_min = z_min
    self.z_max = z_max
    self.z_range = z_max - z_min
    self.x = bm.linspace(z_min, z_max, length)  # The encoded feature values
    self.rho = length / self.z_range  # The neural density
    self.dx = self.z_range / length  # The stimulus density

    # The connections
    self.conn_mat = self.make_conn()

    # variables
    self.r = bm.Variable(bm.zeros((length, length)))
    self.u = bm.Variable(bm.zeros((length, length)))
    self.input = bm.Variable(bm.zeros((length, length)))

  def show_conn(self):
    plt.imshow(bm.as_numpy(self.conn_mat))
    plt.colorbar()
    plt.show()

  def dist(self, d):
    v_size = bm.asarray([self.z_range, self.z_range])
    return bm.where(d > v_size / 2, v_size - d, d)

  def make_conn(self):
    x1, x2 = bm.meshgrid(self.x, self.x)
    value = bm.stack([x1.flatten(), x2.flatten()]).T

    @jax.vmap
    def get_J(v):
      d = self.dist(bm.abs(v - value))
      d = bm.linalg.norm(d, axis=1)
      # d = d.reshape((self.length, self.length))
      Jxx = self.J0 * bm.exp(-0.5 * bm.square(d / self.a)) / (bm.sqrt(2 * bm.pi) * self.a)
      return Jxx

    return get_J(value)

  def get_stimulus_by_pos(self, pos):
    assert bm.size(pos) == 2
    x1, x2 = bm.meshgrid(self.x, self.x)
    value = bm.stack([x1.flatten(), x2.flatten()]).T
    d = self.dist(bm.abs(bm.asarray(pos) - value))
    d = bm.linalg.norm(d, axis=1)
    d = d.reshape((self.length, self.length))
    return self.A * bm.exp(-0.25 * bm.square(d / self.a))

  def update(self, tdi):
    r1 = bm.square(self.u)
    r2 = 1.0 + self.k * bm.sum(r1)
    self.r.value = r1 / r2
    interaction = (self.r.flatten() @ self.conn_mat).reshape((self.length, self.length))
    self.u.value = self.u + (-self.u + self.input + interaction) / self.tau * tdi.dt
    self.input[:] = 0.
[4]:
cann = CANN2D(length=100, k=0.1)
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
[5]:
cann.show_conn()
_images/cann_Wu_2008_CANN_2D_8_0.png
[6]:
# encoding
Iext, length = bp.inputs.section_input(
    values=[cann.get_stimulus_by_pos([0., 0.]), 0.],
    durations=[10., 20.],
    return_length=True
)
runner = bp.dyn.DSRunner(cann,
                         inputs=['input', Iext, 'iter'],
                         monitors=['r'],
                         dyn_vars=cann.vars())
runner.run(length)
[7]:
bp.visualize.animate_2D(values=runner.mon.r.reshape((-1, cann.num)),
                        net_size=(cann.length, cann.length))

encoding

[8]:
# tracking
length = 20
positions = bp.inputs.ramp_input(-bm.pi, bm.pi, duration=length, t_start=0)
positions = bm.stack([positions, positions]).T
Iext = jax.vmap(cann.get_stimulus_by_pos)(positions)
runner = bp.dyn.DSRunner(cann,
                         inputs=['input', Iext, 'iter'],
                         monitors=['r'],
                         dyn_vars=cann.vars())
runner.run(length)
[9]:
bp.visualize.animate_2D(values=runner.mon.r.reshape((-1, cann.num)),
                        net_size=(cann.length, cann.length))

tracking

CANN 1D Oscillatory Tracking

Implementation of the paper:

  • Si Wu, Kosuke Hamaguchi, and Shun-ichi Amari. “Dynamics and computation of continuous attractors.” Neural computation 20.4 (2008): 994-1025.

  • Mi, Y., Fung, C. C., Wong, M. K. Y., & Wu, S. (2014). Spike frequency adaptation implements anticipative tracking in continuous attractor neural networks. Advances in neural information processing systems, 1(January), 505.

[1]:
import brainpy as bp
import brainpy.math as bm

bm.set_platform('cpu')
[2]:
class CANN1D(bp.NeuGroup):
  def __init__(self, num, tau=1., tau_v=50., k=1., a=0.3, A=0.2, J0=1.,
               z_min=-bm.pi, z_max=bm.pi, m=0.3):
    super(CANN1D, self).__init__(size=num)

    # parameters
    self.tau = tau  # The synaptic time constant
    self.tau_v = tau_v
    self.k = k  # Degree of the rescaled inhibition
    self.a = a  # Half-width of the range of excitatory connections
    self.A = A  # Magnitude of the external input
    self.J0 = J0  # maximum connection value
    self.m = m

    # feature space
    self.z_min = z_min
    self.z_max = z_max
    self.z_range = z_max - z_min
    self.x = bm.linspace(z_min, z_max, num)  # The encoded feature values
    self.rho = num / self.z_range  # The neural density
    self.dx = self.z_range / num  # The stimulus density

    # The connection matrix
    self.conn_mat = self.make_conn()

    # variables
    self.r = bm.Variable(bm.zeros(num))
    self.u = bm.Variable(bm.zeros(num))
    self.v = bm.Variable(bm.zeros(num))
    self.input = bm.Variable(bm.zeros(num))

  def dist(self, d):
    d = bm.remainder(d, self.z_range)
    d = bm.where(d > 0.5 * self.z_range, d - self.z_range, d)
    return d

  def make_conn(self):
    x_left = bm.reshape(self.x, (-1, 1))
    x_right = bm.repeat(self.x.reshape((1, -1)), len(self.x), axis=0)
    d = self.dist(x_left - x_right)
    conn = self.J0 * bm.exp(-0.5 * bm.square(d / self.a)) / (bm.sqrt(2 * bm.pi) * self.a)
    return conn

  def get_stimulus_by_pos(self, pos):
    return self.A * bm.exp(-0.25 * bm.square(self.dist(self.x - pos) / self.a))

  def update(self, tdi):
    r1 = bm.square(self.u)
    r2 = 1.0 + self.k * bm.sum(r1)
    self.r.value = r1 / r2
    Irec = bm.dot(self.conn_mat, self.r)
    self.u.value = self.u + (-self.u + Irec + self.input - self.v) / self.tau * tdi.dt
    self.v.value = self.v + (-self.v + self.m * self.u) / self.tau_v * tdi.dt
    self.input[:] = 0.
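Written out, the update rule implemented above corresponds to the CANN equations augmented with a slow negative-feedback (spike-frequency adaptation) variable \(v\); this is a sketch reconstructed from the code, following the formulation of Mi et al. (2014):

\[\tau \frac{du(x,t)}{dt} = -u(x,t) + \rho \int dx' J(x,x') r(x',t) + I_{ext} - v(x,t)\]
\[\tau_v \frac{dv(x,t)}{dt} = -v(x,t) + m\, u(x,t)\]

where \(m\) controls the adaptation strength and \(\tau_v \gg \tau\) is the adaptation time constant. With sufficiently strong adaptation, the bump oscillates around the moving stimulus, which is the behavior visualized below.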
[3]:
cann = CANN1D(num=512)
[4]:
dur1, dur2, dur3 = 100., 2000., 500.
num1 = int(dur1 / bm.get_dt())
num2 = int(dur2 / bm.get_dt())
num3 = int(dur3 / bm.get_dt())
position = bm.zeros(num1 + num2 + num3)
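# the stimulus moves at a constant speed of 0.6 * a / tau_v (feature units per ms) during dur2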
final_pos = cann.a / cann.tau_v * 0.6 * dur2
position[num1: num1 + num2] = bm.linspace(0., final_pos, num2)
position[num1 + num2:] = final_pos
position = position.reshape((-1, 1))
Iext = cann.get_stimulus_by_pos(position)

runner = bp.DSRunner(cann,
                     inputs=('input', Iext, 'iter'),
                     monitors=['u', 'v'])
runner.run(dur1 + dur2 + dur3)
_ = bp.visualize.animate_1D(
  dynamical_vars=[
    {'ys': runner.mon.u, 'xs': cann.x, 'legend': 'u'},
    {'ys': runner.mon.v, 'xs': cann.x, 'legend': 'v'},
    {'ys': Iext, 'xs': cann.x, 'legend': 'Iext'}
  ],
  frame_step=30,
  frame_delay=5,
  show=True
)

image0

(Wang, 2002) Decision making spiking model

Implementation of the paper: Wang, Xiao-Jing. “Probabilistic decision making by slow reverberation in cortical circuits.” Neuron 36.5 (2002): 955-968.

Please refer to the GitHub examples folder.

(Wong & Wang, 2006) Decision making rate model

Please refer to Analysis of A Decision Making Model.

(Vreeswijk & Sompolinsky, 1996) E/I balanced network

Overview

Van Vreeswijk and Sompolinsky proposed the E-I balanced network in 1996 to explain temporally irregular spiking patterns. They suggested that this temporal variability may originate from the balance between excitatory and inhibitory inputs.

There are \(N_E\) excitatory neurons and \(N_I\) inhibitory neurons.

An important feature of the network is its random and sparse connectivity. The number of connections per neuron, \(K\), satisfies \(1 \ll K \ll N_E\).

Implementations

[1]:
import brainpy as bp
import brainpy.math as bm

bm.set_platform('cpu')

Dynamic of membrane potential is given as:

\[\tau \frac {dV_i}{dt} = -(V_i - V_{rest}) + I_i^{ext} + I_i^{net} (t)\]

where \(I_i^{net}(t)\) represents the synaptic current, i.e., the summed input from presynaptic excitatory and inhibitory neurons.

\[I_i^{net} (t) = J_E \sum_{j=1}^{pN_e} \sum_{t_j^\alpha < t} f(t-t_j^\alpha ) - J_I \sum_{j=1}^{pN_i} \sum_{t_j^\alpha < t} f(t-t_j^\alpha )\]

where

\[\begin{split} f(t) = \begin{cases} {\rm exp} (-\frac t {\tau_s} ), \quad t \geq 0 \\ 0, \quad t < 0 \end{cases}\end{split}\]

Parameters: \(J_E = \frac 1 {\sqrt {pN_e}}, J_I = \frac 1 {\sqrt {pN_i}}\)
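For the network sizes used below (\(N_E = N_I = 500\), \(p = 0.1\)), each neuron receives on average \(K = p N_E = 50\) excitatory and \(50\) inhibitory connections, so \(J_E = 1/\sqrt{50} \approx 0.14\) and \(J_I = -1/\sqrt{50} \approx -0.14\).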

We can see from the dynamics that the network is based on leaky integrate-and-fire (LIF) neurons, which are available in BrainPy as bp.neurons.LIF.

The kernel \(f(t)\) in \(I_i^{net}(t)\) is simply a synapse with single-exponential decay, available as bp.synapses.Exponential.

Network

Let’s create a neuron group with \(N_E\) excitatory neurons and \(N_I\) inhibitory neurons. Use conn=bp.connect.FixedProb(p) to implement the random and sparse connections.

[2]:
class EINet(bp.Network):
  def __init__(self, num_exc, num_inh, prob, JE, JI):
    # neurons
    pars = dict(V_rest=-52., V_th=-50., V_reset=-60., tau=10., tau_ref=0.,
                V_initializer=bp.init.Normal(-60., 10.))
    E = bp.neurons.LIF(num_exc, **pars)
    I = bp.neurons.LIF(num_inh, **pars)

    # synapses
    E2E = bp.synapses.Exponential(E, E, bp.conn.FixedProb(prob), g_max=JE, tau=2.)
    E2I = bp.synapses.Exponential(E, I, bp.conn.FixedProb(prob), g_max=JE, tau=2.)
    I2E = bp.synapses.Exponential(I, E, bp.conn.FixedProb(prob), g_max=JI, tau=2.)
    I2I = bp.synapses.Exponential(I, I, bp.conn.FixedProb(prob), g_max=JI, tau=2.)

    super(EINet, self).__init__(E2E, E2I, I2E, I2I, E=E, I=I)
[3]:
num_exc = 500
num_inh = 500
prob = 0.1

Ib = 3.
JE = 1 / bp.math.sqrt(prob * num_exc)
JI = -1 / bp.math.sqrt(prob * num_inh)
[4]:
net = EINet(num_exc, num_inh, prob=prob, JE=JE, JI=JI)

runner = bp.DSRunner(net,
                     monitors=['E.spike'],
                     inputs=[('E.input', Ib), ('I.input', Ib)])
t = runner.run(1000.)

Visualization

[5]:
import matplotlib.pyplot as plt

fig, gs = bp.visualize.get_figure(4, 1, 2, 10)

fig.add_subplot(gs[:3, 0])
bp.visualize.raster_plot(runner.mon.ts, runner.mon['E.spike'], xlim=(50, 950))

fig.add_subplot(gs[3, 0])
rates = bp.measure.firing_rate(runner.mon['E.spike'], 5.)
plt.plot(runner.mon.ts, rates)
plt.xlim(50, 950)
plt.show()
_images/ei_nets_Vreeswijk_1996_EI_net_14_0.png

Reference

[1] Van Vreeswijk, Carl, and Haim Sompolinsky. “Chaos in neuronal networks with balanced excitatory and inhibitory activity.” Science 274.5293 (1996): 1724-1726.

(Brette et al., 2007) COBA

Implementation of the paper:

  • Brette, R., Rudolph, M., Carnevale, T., Hines, M., Beeman, D., Bower, J. M., et al. (2007), Simulation of networks of spiking neurons: a review of tools and strategies., J. Comput. Neurosci., 23, 3, 349–98

which is based on the balanced network proposed by:

  • Vogels, T. P. and Abbott, L. F. (2005), Signal propagation and logic gating in networks of integrate-and-fire neurons., J. Neurosci., 25, 46, 10786–95

Authors:

[1]:
import brainpy as bp

bp.math.set_platform('cpu')

Version 1

[2]:
class EINet(bp.Network):
  def __init__(self, scale=1.0, method='exp_auto'):
    # network size
    num_exc = int(3200 * scale)
    num_inh = int(800 * scale)

    # neurons
    pars = dict(V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5.,
                V_initializer=bp.init.Normal(-55., 2.))
    E = bp.neurons.LIF(num_exc, **pars, method=method)
    I = bp.neurons.LIF(num_inh, **pars, method=method)

    # synapses
    we = 0.6 / scale  # excitatory synaptic weight (voltage)
    wi = 6.7 / scale  # inhibitory synaptic weight
    E2E = bp.synapses.Exponential(E, E, bp.conn.FixedProb(prob=0.02),
                                  g_max=we, tau=5., method=method,
                                  output=bp.synouts.COBA(E=0.))
    E2I = bp.synapses.Exponential(E, I, bp.conn.FixedProb(prob=0.02),
                                  g_max=we, tau=5., method=method,
                                  output=bp.synouts.COBA(E=0.))
    I2E = bp.synapses.Exponential(I, E, bp.conn.FixedProb(prob=0.02),
                                  g_max=wi, tau=10., method=method,
                                  output=bp.synouts.COBA(E=-80.))
    I2I = bp.synapses.Exponential(I, I, bp.conn.FixedProb(prob=0.02),
                                  g_max=wi, tau=10., method=method,
                                  output=bp.synouts.COBA(E=-80.))

    super(EINet, self).__init__(E2E, E2I, I2E, I2I, E=E, I=I)
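The output=bp.synouts.COBA(E=...) argument makes these synapses conductance-based: the synaptic variable \(g\) acts as a conductance, and the current delivered to the postsynaptic neuron depends on its membrane potential,

\[I_{syn} = g\,(E - V_{post}),\]

with reversal potential \(E = 0\) mV for excitation and \(E = -80\) mV for inhibition (the ExpCOBA class in the COBA-HH example below implements exactly this expression). The CUBA model later in this document omits the output term, so \(g\) is injected directly as a current.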
[3]:
# network
net = EINet()
[4]:
# simulation
runner = bp.DSRunner(
  net,
  monitors=['E.spike'],
  inputs=[('E.input', 20.), ('I.input', 20.)]
)
runner.run(100.)
[5]:
bp.visualize.raster_plot(runner.mon.ts, runner.mon['E.spike'], show=True)
_images/ei_nets_Brette_2007_COBA_7_0.png

Version 2

[6]:
class EINet_V2(bp.Network):
  def __init__(self, scale=1.0, method='exp_auto'):
    super(EINet_V2, self).__init__()

    # network size
    num_exc = int(3200 * scale)
    num_inh = int(800 * scale)

    # neurons
    self.N = bp.neurons.LIF(num_exc + num_inh,
                            V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5.,
                            method=method, V_initializer=bp.initialize.Normal(-55., 2.))

    # synapses
    we = 0.6 / scale  # excitatory synaptic weight (voltage)
    wi = 6.7 / scale  # inhibitory synaptic weight
    self.Esyn = bp.synapses.Exponential(pre=self.N[:num_exc],
                                        post=self.N,
                                        conn=bp.connect.FixedProb(0.02),
                                        g_max=we, tau=5.,
                                        output=bp.synouts.COBA(E=0.),
                                        method=method)
    self.Isyn = bp.synapses.Exponential(pre=self.N[num_exc:],
                                        post=self.N,
                                        conn=bp.connect.FixedProb(0.02),
                                        g_max=wi, tau=10.,
                                        output=bp.synouts.COBA(E=-80.),
                                        method=method)
[7]:
net = EINet_V2(scale=1., method='exp_auto')
# simulation
runner = bp.DSRunner(
  net,
  monitors={'spikes': net.N.spike},
  inputs=[(net.N.input, 20.)]
)
runner.run(100.)

# visualization
bp.visualize.raster_plot(runner.mon.ts, runner.mon['spikes'], show=True)
_images/ei_nets_Brette_2007_COBA_10_1.png

(Brette et al., 2007) CUBA

Implementation of the paper:

  • Brette, R., Rudolph, M., Carnevale, T., Hines, M., Beeman, D., Bower, J. M., et al. (2007), Simulation of networks of spiking neurons: a review of tools and strategies., J. Comput. Neurosci., 23, 3, 349–98

which is based on the balanced network proposed by:

  • Vogels, T. P. and Abbott, L. F. (2005), Signal propagation and logic gating in networks of integrate-and-fire neurons., J. Neurosci., 25, 46, 10786–95

Authors:

[1]:
import brainpy as bp

bp.math.set_platform('cpu')
[2]:
class CUBA(bp.Network):
  def __init__(self, scale=1.0, method='exp_auto'):
    # network size
    num_exc = int(3200 * scale)
    num_inh = int(800 * scale)

    # neurons
    pars = dict(V_rest=-49, V_th=-50., V_reset=-60, tau=20., tau_ref=5.,
                V_initializer=bp.init.Normal(-55., 2.))
    E = bp.neurons.LIF(num_exc, **pars, method=method)
    I = bp.neurons.LIF(num_inh, **pars, method=method)

    # synapses
    we = 1.62 / scale  # excitatory synaptic weight (voltage)
    wi = -9.0 / scale  # inhibitory synaptic weight
    E2E = bp.synapses.Exponential(E, E, bp.conn.FixedProb(0.02),
                                  g_max=we, tau=5., method=method)
    E2I = bp.synapses.Exponential(E, I, bp.conn.FixedProb(0.02),
                                  g_max=we, tau=5., method=method)
    I2E = bp.synapses.Exponential(I, E, bp.conn.FixedProb(0.02),
                                  g_max=wi, tau=10., method=method)
    I2I = bp.synapses.Exponential(I, I, bp.conn.FixedProb(0.02),
                                  g_max=wi, tau=10., method=method)

    super(CUBA, self).__init__(E2E, E2I, I2E, I2I, E=E, I=I)
[3]:
# network
net = CUBA()
[4]:
# simulation
runner = bp.DSRunner(net,
                     monitors=['E.spike'],
                     inputs=[('E.input', 20.), ('I.input', 20.)])
t = runner.run(100.)
[5]:
bp.visualize.raster_plot(runner.mon.ts, runner.mon['E.spike'], show=True)
_images/ei_nets_Brette_2007_CUBA_6_0.png

(Brette et al., 2007) COBA-HH

Implementation of the paper:

  • Brette, R., Rudolph, M., Carnevale, T., Hines, M., Beeman, D., Bower, J. M., et al. (2007), Simulation of networks of spiking neurons: a review of tools and strategies., J. Comput. Neurosci., 23, 3, 349–98

which is based on the balanced network proposed by:

  • Vogels, T. P. and Abbott, L. F. (2005), Signal propagation and logic gating in networks of integrate-and-fire neurons., J. Neurosci., 25, 46, 10786–95

Authors:

[1]:
import brainpy as bp
import brainpy.math as bm

bp.math.set_platform('cpu')

Parameters

[2]:
num_exc = 3200
num_inh = 800
Cm = 200  # Membrane Capacitance [pF]

gl = 10.  # Leak Conductance   [nS]
g_Na = 20. * 1000  # Na Conductance     [nS]
g_Kd = 6. * 1000  # K Conductance      [nS]
El = -60.  # Resting Potential [mV]
ENa = 50.  # reversal potential (Sodium) [mV]
EK = -90.  # reversal potential (Potassium) [mV]
VT = -63.
V_th = -20.

# Time constants
taue = 5.  # Excitatory synaptic time constant [ms]
taui = 10.  # Inhibitory synaptic time constant [ms]

# Reversal potentials
Ee = 0.  # Excitatory reversal potential (mV)
Ei = -80.  # Inhibitory reversal potential (Potassium) [mV]

# excitatory synaptic weight
we = 6.  # excitatory synaptic conductance [nS]

# inhibitory synaptic weight
wi = 67.  # inhibitory synaptic conductance [nS]

Implementation 1

[3]:
class HH(bp.NeuGroup):
  def __init__(self, size, method='exp_auto'):
    super(HH, self).__init__(size)

    # variables
    self.V = bm.Variable(El + (bm.random.randn(self.num) * 5 - 5))
    self.m = bm.Variable(bm.zeros(self.num))
    self.n = bm.Variable(bm.zeros(self.num))
    self.h = bm.Variable(bm.zeros(self.num))
    self.spike = bm.Variable(bm.zeros(self.num, dtype=bool))
    self.input = bm.Variable(bm.zeros(size))

    def dV(V, t, m, h, n, Isyn):
      gna = g_Na * (m * m * m) * h
      gkd = g_Kd * (n * n * n * n)
      dVdt = (-gl * (V - El) - gna * (V - ENa) - gkd * (V - EK) + Isyn) / Cm
      return dVdt

    def dm(m, t, V, ):
      m_alpha = 0.32 * (13 - V + VT) / (bm.exp((13 - V + VT) / 4) - 1.)
      m_beta = 0.28 * (V - VT - 40) / (bm.exp((V - VT - 40) / 5) - 1)
      dmdt = (m_alpha * (1 - m) - m_beta * m)
      return dmdt

    def dh(h, t, V):
      h_alpha = 0.128 * bm.exp((17 - V + VT) / 18)
      h_beta = 4. / (1 + bm.exp(-(V - VT - 40) / 5))
      dhdt = (h_alpha * (1 - h) - h_beta * h)
      return dhdt

    def dn(n, t, V):
      c = 15 - V + VT
      n_alpha = 0.032 * c / (bm.exp(c / 5) - 1.)
      n_beta = .5 * bm.exp((10 - V + VT) / 40)
      dndt = (n_alpha * (1 - n) - n_beta * n)
      return dndt

    # functions
    self.integral = bp.odeint(bp.JointEq([dV, dm, dh, dn]), method=method)

  def update(self, tdi):
    V, m, h, n = self.integral(self.V, self.m, self.h, self.n, tdi.t, Isyn=self.input, dt=tdi.dt)
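    # a spike is registered when the membrane potential V crosses V_th from below in this time step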
    self.spike.value = bm.logical_and(self.V < V_th, V >= V_th)
    self.m.value = m
    self.h.value = h
    self.n.value = n
    self.V.value = V
    self.input[:] = 0.
[4]:
class ExpCOBA(bp.TwoEndConn):
  def __init__(self, pre, post, conn, g_max=1., delay=0., tau=8.0, E=0.,
               method='exp_auto'):
    super(ExpCOBA, self).__init__(pre=pre, post=post, conn=conn)
    self.check_pre_attrs('spike')
    self.check_post_attrs('input', 'V')

    # parameters
    self.E = E
    self.tau = tau
    self.delay = delay
    self.g_max = g_max
    self.pre2post = self.conn.require('pre2post')

    # variables
    self.g = bm.Variable(bm.zeros(self.post.num))

    # function
    self.integral = bp.odeint(lambda g, t: -g / self.tau, method=method)

  def update(self, tdi):
    self.g.value = self.integral(self.g, tdi.t, dt=tdi.dt)
    post_sps = bm.pre2post_event_sum(self.pre.spike, self.pre2post, self.post.num, self.g_max)
    self.g.value += post_sps
    self.post.input += self.g * (self.E - self.post.V)
[5]:
class COBAHH(bp.Network):
  def __init__(self, scale=1., method='exp_auto'):
    num_exc = int(3200 * scale)
    num_inh = int(800 * scale)
    E = HH(num_exc, method=method)
    I = HH(num_inh, method=method)
    E2E = ExpCOBA(pre=E, post=E, conn=bp.conn.FixedProb(prob=0.02),
                  E=Ee, g_max=we / scale, tau=taue, method=method)
    E2I = ExpCOBA(pre=E, post=I, conn=bp.conn.FixedProb(prob=0.02),
                  E=Ee, g_max=we / scale, tau=taue, method=method)
    I2E = ExpCOBA(pre=I, post=E, conn=bp.conn.FixedProb(prob=0.02),
                  E=Ei, g_max=wi / scale, tau=taui, method=method)
    I2I = ExpCOBA(pre=I, post=I, conn=bp.conn.FixedProb(prob=0.02),
                  E=Ei, g_max=wi / scale, tau=taui, method=method)

    super(COBAHH, self).__init__(E2E, E2I, I2I, I2E, E=E, I=I)
[6]:
net = COBAHH()
[7]:
runner = bp.DSRunner(net, monitors=['E.spike'])
t = runner.run(100.)
[8]:
bp.visualize.raster_plot(runner.mon.ts, runner.mon['E.spike'], show=True)
_images/ei_nets_Brette_2007_COBAHH_11_0.png

Implementation 2

[9]:
class HH2(bp.CondNeuGroup):
  def __init__(self, size):
    super(HH2, self).__init__(size, )
    self.INa = bp.channels.INa_TM1991(size, g_max=100., V_sh=-63.)
    self.IK = bp.channels.IK_TM1991(size, g_max=30., V_sh=-63.)
    self.IL = bp.channels.IL(size, E=-60., g_max=0.05)

[10]:
class EINet_v2(bp.Network):
  def __init__(self, scale=1.):
    super(EINet_v2, self).__init__()

    prob = 0.02
    self.num_exc = int(3200 * scale)
    self.num_inh = int(800 * scale)

    self.N = HH2(self.num_exc + self.num_inh)
    self.Esyn = bp.synapses.Exponential(self.N[:self.num_exc],
                                        self.N,
                                        bp.conn.FixedProb(prob),
                                        g_max=0.03 / scale, tau=5,
                                        output=bp.synouts.COBA(E=0.))
    self.Isyn = bp.synapses.Exponential(self.N[self.num_exc:],
                                        self.N,
                                        bp.conn.FixedProb(prob),
                                        g_max=0.335 / scale, tau=10.,
                                        output=bp.synouts.COBA(E=-80))
[11]:
net = EINet_v2(scale=1)
runner = bp.DSRunner(net, monitors={'spikes': net.N.spike})
runner.run(100.)
bp.visualize.raster_plot(runner.mon.ts, runner.mon['spikes'], show=True)
_images/ei_nets_Brette_2007_COBAHH_15_1.png

(Tian, et al., 2020) E/I Net for fast response

Implementation of the paper:

  • Tian, Gengshuo, et al. “Excitation-Inhibition Balanced Neural Networks for Fast Signal Detection.” Frontiers in Computational Neuroscience 14 (2020): 79.

Author: Chaoming Wang

[ ]:
import brainpy as bp
import brainpy.math as bm

bp.math.set_platform('cpu')
[2]:
# set parameters

num = 10000
num_inh = int(num * 0.2)
num_exc = num - num_inh
prob = 0.25

tau_E = 15.
tau_I = 10.
V_reset = 0.
V_threshold = 15.
f_E = 3.
f_I = 2.
mu_f = 0.1

tau_Es = 6.
tau_Is = 5.
JEE = 0.25
JEI = -1.
JIE = 0.4
JII = -1.
[3]:
class LIF(bp.NeuGroup):
  def __init__(self, size, tau, **kwargs):
    super(LIF, self).__init__(size, **kwargs)

    # parameters
    self.tau = tau

    # variables
    self.V = bp.math.Variable(bp.math.zeros(size))
    self.spike = bp.math.Variable(bp.math.zeros(size, dtype=bool))
    self.input = bp.math.Variable(bp.math.zeros(size))

    # integral
    self.integral = bp.odeint(lambda V, t, Isyn: (-V + Isyn) / self.tau)

  def update(self, tdi):
    V = self.integral(self.V, tdi.t, self.input, tdi.dt)
    self.spike.value = V >= V_threshold
    self.V.value = bm.where(self.spike, V_reset, V)
    self.input[:] = 0.
[4]:
class EINet(bp.Network):
  def __init__(self):
    # neurons
    E = LIF(num_exc, tau=tau_E)
    I = LIF(num_inh, tau=tau_I)
    E.V[:] = bm.random.random(num_exc) * (V_threshold - V_reset) + V_reset
    I.V[:] = bm.random.random(num_inh) * (V_threshold - V_reset) + V_reset

    # synapses
    E2I = bp.synapses.Exponential(pre=E, post=I, conn=bp.conn.FixedProb(prob), tau=tau_Es, g_max=JIE)
    E2E = bp.synapses.Exponential(pre=E, post=E, conn=bp.conn.FixedProb(prob), tau=tau_Es, g_max=JEE)
    I2I = bp.synapses.Exponential(pre=I, post=I, conn=bp.conn.FixedProb(prob), tau=tau_Is, g_max=JII)
    I2E = bp.synapses.Exponential(pre=I, post=E, conn=bp.conn.FixedProb(prob), tau=tau_Is, g_max=JEI)

    super(EINet, self).__init__(E2E, E2I, I2E, I2I, E=E, I=I)
[5]:
net = EINet()
[6]:
runner = bp.DSRunner(net,
                     monitors=['E.spike', 'I.spike'],
                     inputs=[('E.input', f_E * bm.sqrt(num) * mu_f),
                             ('I.input', f_I * bm.sqrt(num) * mu_f)])
t = runner.run(100.)
[7]:
# visualization
fig, gs = bp.visualize.get_figure(5, 1, 1.5, 10)

fig.add_subplot(gs[:3, 0])
bp.visualize.raster_plot(runner.mon.ts, runner.mon['E.spike'], xlim=(0, 100))

fig.add_subplot(gs[3:, 0])
bp.visualize.raster_plot(runner.mon.ts, runner.mon['I.spike'], xlim=(0, 100), show=True)
_images/ei_nets_Tian_2020_EI_net_for_fast_response_8_0.png

Classify MNIST dataset by a fully connected LIF layer

This tutorial introduces how to train a simple SNN, using an input encoder and the surrogate gradient method, to classify MNIST.

Convolutional SNN to Classify Fashion-MNIST

In this tutorial, we will build a convolutional spiking neural network to classify the Fashion-MNIST dataset.

(2022, NeurIPS): Online Training Through Time for Spiking Neural Networks

Implementation of the paper:

  • Xiao, M., Meng, Q., Zhang, Z., He, D., & Lin, Z. (2022). Online Training Through Time for Spiking Neural Networks. ArXiv, abs/2210.04195.

(2019, Zenke, F.): SNN Surrogate Gradient Learning

Training a spiking neural network with surrogate gradient learning.

(2019, Zenke, F.): SNN Surrogate Gradient Learning to Classify Fashion-MNIST

Training a spiking neural network on a simple vision dataset.

(2021, Raminmh): Liquid time-constant Networks

Training a liquid time-constant network with BPTT.

Predicting Mackey-Glass timeseries

[1]:
import brainpy as bp
import brainpy.math as bm
import brainpy_datasets as bp_data

# bm.set_platform('cpu')
bm.set_environment(mode=bm.batching_mode, x64=True)
[2]:
import numpy as np
import matplotlib.pyplot as plt

Dataset
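
For reference, the Mackey-Glass series is generated from the delay differential equation (assuming the standard form, with parameters matching the MackeyGlassEq call below):

\[\frac{dx}{dt} = \beta \frac{x(t-\tau)}{1 + x(t-\tau)^{n}} - \gamma\, x(t)\]

with \(\beta = 0.2\), \(\gamma = 0.1\), \(n = 10\), and delay \(\tau = 17\), a setting that lies in the chaotic regime.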

[3]:
def plot_mackey_glass_series(ts, x_series, x_tau_series, num_sample):
  plt.figure(figsize=(13, 5))

  plt.subplot(121)
  plt.title(f"Timeseries - {num_sample} timesteps")
  plt.plot(ts[:num_sample], x_series[:num_sample], lw=2, color="lightgrey", zorder=0)
  plt.scatter(ts[:num_sample], x_series[:num_sample], c=ts[:num_sample], cmap="viridis", s=6)
  plt.xlabel("$t$")
  plt.ylabel("$P(t)$")

  ax = plt.subplot(122)
  ax.margins(0.05)
  plt.title(f"Phase diagram: $P(t) = f(P(t-\\tau))$")
  plt.plot(x_tau_series[: num_sample], x_series[: num_sample], lw=1, color="lightgrey", zorder=0)
  plt.scatter(x_tau_series[:num_sample], x_series[: num_sample], lw=0.5, c=ts[:num_sample], cmap="viridis", s=6)
  plt.xlabel("$P(t-\\tau)$")
  plt.ylabel("$P(t)$")
  cbar = plt.colorbar()
  # cbar.ax.set_ylabel('$t$', rotation=270)
  cbar.ax.set_ylabel('$t$')

  plt.tight_layout()
  plt.show()
[4]:
dt = 0.1

mg_data = bp_data.chaos.MackeyGlassEq(25000, dt=dt, tau=17, beta=0.2, gamma=0.1, n=10, inits=1.2, seed=123)

plot_mackey_glass_series(mg_data.ts, mg_data.xs, mg_data.ys, num_sample=int(1000 / dt))
WARNING:jax._src.lib.xla_bridge:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
_images/reservoir_computing_predicting_Mackey_Glass_timeseries_5_1.png
[5]:
forecast = int(10 / dt)  # predict 10 s ahead
train_length = int(20000 / dt)
sample_rate = int(1 / dt)

X_train = mg_data.xs[:train_length:sample_rate]
Y_train = mg_data.xs[forecast: train_length + forecast: sample_rate]
X_test = mg_data.xs[train_length: -forecast: sample_rate]
Y_test = mg_data.xs[train_length + forecast::sample_rate]

X_train = np.expand_dims(X_train, 0)
Y_train = np.expand_dims(Y_train, 0)
X_test = np.expand_dims(X_test, 0)
Y_test = np.expand_dims(Y_test, 0)
[6]:
sample = 300
fig = plt.figure(figsize=(15, 5))
plt.plot(X_train.flatten()[:sample], label="Training data")
plt.plot(Y_train.flatten()[:sample], label="True prediction")
plt.legend()
plt.show()
_images/reservoir_computing_predicting_Mackey_Glass_timeseries_7_0.png

Model

[7]:
class ESN(bp.DynamicalSystem):
  def __init__(self, num_in, num_hidden, num_out):
    super(ESN, self).__init__()
    self.r = bp.layers.Reservoir(num_in, num_hidden,
                                 Wrec_initializer=bp.init.KaimingNormal())
    self.o = bp.layers.Dense(num_hidden, num_out, mode=bm.training_mode)

  def update(self, sha, x):
    return self.o(sha, self.r(sha, x))


model = ESN(1, 100, 1)
[8]:
runner = bp.DSTrainer(model)
out = runner.predict(bm.asarray(X_train))

out.shape
[8]:
(1, 20000, 1)

Training

[9]:
trainer = bp.RidgeTrainer(model, alpha=1e-6)
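RidgeTrainer fits only the linear readout layer; the reservoir weights are left untouched. Assuming the standard ridge (Tikhonov-regularized) least-squares solution, the readout weights are obtained in closed form from the collected reservoir states \(R\) and the targets \(Y\):

\[W_{out} = (R^{\top} R + \alpha I)^{-1} R^{\top} Y\]

where \(\alpha\) is the regularization strength passed above (here \(10^{-6}\)).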
[10]:
_ = trainer.fit([bm.asarray(X_train),
                 bm.asarray(Y_train)])

Prediction

[11]:
ys_predict = trainer.predict(bm.asarray(X_train), reset_state=True)

start, end = 100, 600
plt.figure(figsize=(15, 7))
plt.subplot(211)
plt.plot(bm.arange(end - start).to_numpy(),
         bm.as_numpy(ys_predict)[0, start:end, 0],
         lw=3,
         label="ESN prediction")
plt.plot(bm.arange(end - start).to_numpy(),
         Y_train[0, start:end, 0],
         linestyle="--",
         lw=2,
         label="True value")
plt.legend()
plt.show()
_images/reservoir_computing_predicting_Mackey_Glass_timeseries_15_1.png
[12]:
ys_predict = trainer.predict(bm.asarray(X_test), reset_state=True)

start, end = 100, 600
plt.figure(figsize=(15, 7))
plt.subplot(211)
plt.plot(bm.arange(end - start).to_numpy(),
         bm.as_numpy(ys_predict)[0, start:end, 0],
         lw=3,
         label="ESN prediction")
plt.plot(bm.arange(end - start).to_numpy(),
         Y_test[0, start:end, 0],
         linestyle="--",
         lw=2,
         label="True value")
plt.title(f'Mean Square Error: {bp.losses.mean_squared_error(bm.as_numpy(ys_predict), Y_test)}')
plt.legend()
plt.show()
_images/reservoir_computing_predicting_Mackey_Glass_timeseries_16_1.png

(Gauthier et al., 2021): Next generation reservoir computing

Implementation of the paper:

  • Gauthier, D. J., Bollt, E., Griffith, A., & Barbosa, W. A. S. (2021). Next generation reservoir computing. Nature Communications, 12, 5564.

[1]:
import matplotlib.pyplot as plt
import numpy as np

import brainpy as bp
import brainpy.math as bm
import brainpy_datasets as bd

bm.set_environment(bm.batching_mode, x64=True)
[2]:
def plot_weights(Wout, coefs, bias=None):
  Wout = np.asarray(Wout)
  if bias is not None:
    bias = np.asarray(bias)
    Wout = np.concatenate([bias.reshape((1, 3)), Wout], axis=0)
    coefs.insert(0, 'bias')
  x_Wout, y_Wout, z_Wout = Wout[:, 0], Wout[:, 1], Wout[:, 2]

  fig = plt.figure(figsize=(10, 10))
  ax = fig.add_subplot(131)
  ax.grid(axis="y")
  ax.set_xlabel("$[W_{out}]_x$")
  ax.set_ylabel("Features")
  ax.set_yticks(np.arange(len(coefs)))
  ax.set_yticklabels(coefs)
  ax.barh(np.arange(x_Wout.size), x_Wout)

  ax1 = fig.add_subplot(132)
  ax1.grid(axis="y")
  ax1.set_yticks(np.arange(len(coefs)))
  ax1.set_xlabel("$[W_{out}]_y$")
  ax1.barh(np.arange(y_Wout.size), y_Wout)

  ax2 = fig.add_subplot(133)
  ax2.grid(axis="y")
  ax2.set_yticks(np.arange(len(coefs)))
  ax2.set_xlabel("$[W_{out}]_z$")
  ax2.barh(np.arange(z_Wout.size), z_Wout)

  plt.show()

Forecasting Lorenz63 strange attractor
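
For reference, the Lorenz-63 system generated by bd.chaos.LorenzEq is (assuming the classical parameters \(\sigma = 10\), \(\rho = 28\), \(\beta = 8/3\) as defaults):

\[\frac{dx}{dt} = \sigma (y - x), \qquad \frac{dy}{dt} = x(\rho - z) - y, \qquad \frac{dz}{dt} = x y - \beta z\]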

[3]:
def get_subset(data, start, end):
  res = {'x': data.xs[start: end],
         'y': data.ys[start: end],
         'z': data.zs[start: end]}
  res = bm.hstack([res['x'], res['y'], res['z']])
  return res.reshape((1, ) + res.shape)
[4]:
def plot_lorenz(ground_truth, predictions):
  fig = plt.figure(figsize=(15, 10))
  ax = fig.add_subplot(121, projection='3d')
  ax.set_title("Generated attractor")
  ax.set_xlabel("$x$")
  ax.set_ylabel("$y$")
  ax.set_zlabel("$z$")
  ax.grid(False)
  ax.plot(predictions[:, 0], predictions[:, 1], predictions[:, 2])

  ax2 = fig.add_subplot(122, projection='3d')
  ax2.set_title("Real attractor")
  ax2.grid(False)
  ax2.plot(ground_truth[:, 0], ground_truth[:, 1], ground_truth[:, 2])
  plt.show()
[5]:
dt = 0.01
t_warmup = 5.  # ms
t_train = 10.  # ms
t_test = 120.  # ms
num_warmup = int(t_warmup / dt)  # warm up NVAR
num_train = int(t_train / dt)
num_test = int(t_test / dt)

Datasets

[6]:
lorenz_series = bd.chaos.LorenzEq(t_warmup + t_train + t_test,
                                  dt=dt,
                                  inits={'x': 17.67715816276679,
                                         'y': 12.931379185960404,
                                         'z': 43.91404334248268})

X_warmup = get_subset(lorenz_series, 0, num_warmup - 1)
Y_warmup = get_subset(lorenz_series, 1, num_warmup)
X_train = get_subset(lorenz_series, num_warmup - 1, num_warmup + num_train - 1)
# Target: Lorenz[t] - Lorenz[t - 1]
dX_train = get_subset(lorenz_series, num_warmup, num_warmup + num_train) - X_train
X_test = get_subset(lorenz_series,
                    num_warmup + num_train - 1,
                    num_warmup + num_train + num_test - 1)
Y_test = get_subset(lorenz_series,
                    num_warmup + num_train,
                    num_warmup + num_train + num_test)
WARNING:jax._src.lib.xla_bridge:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)

Model

[7]:
class NGRC(bp.DynamicalSystem):
  def __init__(self, num_in):
    super(NGRC, self).__init__()
    self.r = bp.layers.NVAR(num_in, delay=2, order=2, constant=True,)
    self.di = bp.layers.Dense(self.r.num_out, num_in, b_initializer=None, mode=bm.training_mode)

  def update(self, sha, x):
    dx = self.di(sha, self.r(sha, x))
    return x + dx
[8]:
model = NGRC(3)
print(model.r.get_feature_names())
['1', 'x0(t)', 'x1(t)', 'x2(t)', 'x0(t-1)', 'x1(t-1)', 'x2(t-1)', 'x0(t)^2', 'x0(t) x1(t)', 'x0(t) x2(t)', 'x0(t) x0(t-1)', 'x0(t) x1(t-1)', 'x0(t) x2(t-1)', 'x1(t)^2', 'x1(t) x2(t)', 'x1(t) x0(t-1)', 'x1(t) x1(t-1)', 'x1(t) x2(t-1)', 'x2(t)^2', 'x2(t) x0(t-1)', 'x2(t) x1(t-1)', 'x2(t) x2(t-1)', 'x0(t-1)^2', 'x0(t-1) x1(t-1)', 'x0(t-1) x2(t-1)', 'x1(t-1)^2', 'x1(t-1) x2(t-1)', 'x2(t-1)^2']
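The 28 feature names printed above can be checked by hand: with \(d = 3\) input variables and \(k = 2\) delayed copies there are \(dk = 6\) linear terms; the quadratic block contains \(\binom{6}{2} + 6 = 21\) monomials; together with the constant this gives \(1 + 6 + 21 = 28\) features. Because the readout is trained on the one-step differences (dX_train above is Lorenz[t] - Lorenz[t-1]), the trained model acts as a one-step integrator,

\[x_{t+1} = x_t + W_{out}\, \phi(x_t, x_{t-1}),\]

where \(\phi\) denotes the NVAR feature vector.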

Training

[9]:
# warm-up
trainer = bp.RidgeTrainer(model, alpha=2.5e-6)

# training
outputs = trainer.predict(X_warmup)
print('Warmup NMS: ', bp.losses.mean_squared_error(outputs, Y_warmup))
trainer.fit([X_train, dX_train])
plot_weights(model.di.W, model.r.get_feature_names(for_plot=True), model.di.b)
Warmup NMS:  107865.37236001804
_images/reservoir_computing_Gauthier_2021_ngrc_14_4.png

Prediction

[10]:
model = bm.jit(model)
outputs = [model(dict(), X_test[:, 0])]
for i in range(1, X_test.shape[1]):
  outputs.append(model(dict(), outputs[i - 1]))
outputs = bm.asarray(outputs)
print('Prediction NMS: ', bp.losses.mean_squared_error(outputs, Y_test))
plot_lorenz(bm.as_numpy(Y_test).squeeze(), bm.as_numpy(outputs).squeeze())
Prediction NMS:  148.18244150701176
_images/reservoir_computing_Gauthier_2021_ngrc_16_1.png

Forecasting the double-scroll system

[11]:
def plot_double_scroll(ground_truth, predictions):
  fig = plt.figure(figsize=(15, 10))
  ax = fig.add_subplot(121, projection='3d')
  ax.set_title("Generated attractor")
  ax.set_xlabel("$x$")
  ax.set_ylabel("$y$")
  ax.set_zlabel("$z$")
  ax.grid(False)
  ax.plot(predictions[:, 0], predictions[:, 1], predictions[:, 2])

  ax2 = fig.add_subplot(122, projection='3d')
  ax2.set_title("Real attractor")
  ax2.grid(False)
  ax2.plot(ground_truth[:, 0], ground_truth[:, 1], ground_truth[:, 2])
  plt.show()
[12]:
dt = 0.02
t_warmup = 10.  # ms
t_train = 100.  # ms
t_test = 800.  # ms
num_warmup = int(t_warmup / dt)  # warm up NVAR
num_train = int(t_train / dt)
num_test = int(t_test / dt)

Datasets

[13]:
data_series = bd.chaos.DoubleScrollEq(t_warmup + t_train + t_test, dt=dt)

X_warmup = get_subset(data_series, 0, num_warmup - 1)
Y_warmup = get_subset(data_series, 1, num_warmup)
X_train = get_subset(data_series, num_warmup - 1, num_warmup + num_train - 1)
# Target: Lorenz[t] - Lorenz[t - 1]
dX_train = get_subset(data_series, num_warmup, num_warmup + num_train) - X_train
X_test = get_subset(data_series,
                    num_warmup + num_train - 1,
                    num_warmup + num_train + num_test - 1)
Y_test = get_subset(data_series,
                    num_warmup + num_train,
                    num_warmup + num_train + num_test)

Model

[14]:
class NGRC(bp.DynamicalSystem):
  def __init__(self, num_in):
    super(NGRC, self).__init__()
    self.r = bp.layers.NVAR(num_in, delay=2, order=3)
    self.di = bp.layers.Dense(self.r.num_out, num_in, mode=bm.training_mode)

  def update(self, sha, x):
    di = self.di(sha, self.r(sha, x))
    return x + di

model = NGRC(3)

Training

[15]:
# warm-up
trainer = bp.RidgeTrainer(model, alpha=1e-5, jit=True)
[16]:
# training
outputs = trainer.predict(X_warmup)
print('Warmup NMS: ', bp.losses.mean_squared_error(outputs, Y_warmup))
trainer.fit([X_train, dX_train])
plot_weights(model.di.W, model.r.get_feature_names(for_plot=True), model.di.b)
Warmup NMS:  3.1532995057589828
_images/reservoir_computing_Gauthier_2021_ngrc_26_4.png

Prediction

[17]:
model = bm.jit(model)
outputs = [model(dict(), X_test[:, 0])]
for i in range(1, X_test.shape[1]):
  outputs.append(model(dict(), outputs[i - 1]))
outputs = bm.asarray(outputs).squeeze()
print('Prediction NMS: ', bp.losses.mean_squared_error(outputs, Y_test))
plot_double_scroll(Y_test.numpy().squeeze(), outputs.numpy())
Prediction NMS:  1.4541218384321954
_images/reservoir_computing_Gauthier_2021_ngrc_28_1.png

Inferring dynamics of Lorenz63 strange attractor

[18]:
def get_subset(data, start, end):
  res = {'x': data.xs[start: end],
         'y': data.ys[start: end],
         'z': data.zs[start: end]}
  X = bm.hstack([res['x'], res['y']])
  X = X.reshape((1,) + X.shape)
  Y = res['z']
  Y = Y.reshape((1, ) + Y.shape)
  return X, Y
[19]:
def plot_lorenz2(x, y, true_z, predict_z, linewidth=.8):
  fig1 = plt.figure()
  fig1.set_figheight(8)
  fig1.set_figwidth(12)

  t_all = t_warmup + t_train + t_test
  ts = np.arange(0, t_all, dt)

  h = 240
  w = 2

  # top left of grid is 0,0
  axs1 = plt.subplot2grid(shape=(h, w), loc=(0, 0), colspan=2, rowspan=30)
  axs2 = plt.subplot2grid(shape=(h, w), loc=(36, 0), colspan=2, rowspan=30)
  axs3 = plt.subplot2grid(shape=(h, w), loc=(72, 0), colspan=2, rowspan=30)
  axs4 = plt.subplot2grid(shape=(h, w), loc=(132, 0), colspan=2, rowspan=30)
  axs5 = plt.subplot2grid(shape=(h, w), loc=(168, 0), colspan=2, rowspan=30)
  axs6 = plt.subplot2grid(shape=(h, w), loc=(204, 0), colspan=2, rowspan=30)

  # training phase x
  axs1.set_title('training phase')
  axs1.plot(ts[num_warmup:num_warmup + num_train],
            x[num_warmup:num_warmup + num_train],
            color='b', linewidth=linewidth)
  axs1.set_ylabel('x')
  axs1.axes.xaxis.set_ticklabels([])
  axs1.axes.set_xbound(t_warmup - .08, t_warmup + t_train + .05)
  axs1.axes.set_ybound(-21., 21.)
  axs1.text(-.14, .9, 'a)', ha='left', va='bottom', transform=axs1.transAxes)

  # training phase y
  axs2.plot(ts[num_warmup:num_warmup + num_train],
            y[num_warmup:num_warmup + num_train],
            color='b', linewidth=linewidth)
  axs2.set_ylabel('y')
  axs2.axes.xaxis.set_ticklabels([])
  axs2.axes.set_xbound(t_warmup - .08, t_warmup + t_train + .05)
  axs2.axes.set_ybound(-26., 26.)
  axs2.text(-.14, .9, 'b)', ha='left', va='bottom', transform=axs2.transAxes)

  # training phase z
  axs3.plot(ts[num_warmup:num_warmup + num_train],
            true_z[num_warmup:num_warmup + num_train],
            color='b', linewidth=linewidth)
  axs3.plot(ts[num_warmup:num_warmup + num_train],
            predict_z[num_warmup:num_warmup + num_train],
            color='r', linewidth=linewidth)
  axs3.set_ylabel('z')
  axs3.set_xlabel('time')
  axs3.axes.set_xbound(t_warmup - .08, t_warmup + t_train + .05)
  axs3.axes.set_ybound(3., 48.)
  axs3.text(-.14, .9, 'c)', ha='left', va='bottom', transform=axs3.transAxes)

  # testing phase x
  axs4.set_title('testing phase')
  axs4.plot(ts[num_warmup + num_train:num_warmup + num_train + num_test],
            x[num_warmup + num_train:num_warmup + num_train + num_test],
            color='b', linewidth=linewidth)
  axs4.set_ylabel('x')
  axs4.axes.xaxis.set_ticklabels([])
  axs4.axes.set_ybound(-21., 21.)
  axs4.axes.set_xbound(t_warmup + t_train - .5, t_all + .5)
  axs4.text(-.14, .9, 'd)', ha='left', va='bottom', transform=axs4.transAxes)

  # testing phase y
  axs5.plot(ts[num_warmup + num_train:num_warmup + num_train + num_test],
            y[num_warmup + num_train:num_warmup + num_train + num_test],
            color='b', linewidth=linewidth)
  axs5.set_ylabel('y')
  axs5.axes.xaxis.set_ticklabels([])
  axs5.axes.set_ybound(-26., 26.)
  axs5.axes.set_xbound(t_warmup + t_train - .5, t_all + .5)
  axs5.text(-.14, .9, 'e)', ha='left', va='bottom', transform=axs5.transAxes)

  # testing phase z
  axs6.plot(ts[num_warmup + num_train:num_warmup + num_train + num_test],
            true_z[num_warmup + num_train:num_warmup + num_train + num_test],
            color='b', linewidth=linewidth)
  axs6.plot(ts[num_warmup + num_train:num_warmup + num_train + num_test],
            predict_z[num_warmup + num_train:num_warmup + num_train + num_test],
            color='r', linewidth=linewidth)
  axs6.set_ylabel('z')
  axs6.set_xlabel('time')
  axs6.axes.set_ybound(3., 48.)
  axs6.axes.set_xbound(t_warmup + t_train - .5, t_all + .5)
  axs6.text(-.14, .9, 'f)', ha='left', va='bottom', transform=axs6.transAxes)

  plt.show()
[20]:
dt = 0.02
t_warmup = 10.  # ms
t_train = 20.  # ms
t_test = 50.  # ms
num_warmup = int(t_warmup / dt)  # warm up NVAR
num_train = int(t_train / dt)
num_test = int(t_test / dt)

Datasets

[21]:
lorenz_series = bd.chaos.LorenzEq(t_warmup + t_train + t_test,
                                  dt=dt,
                                  inits={'x': 17.67715816276679,
                                         'y': 12.931379185960404,
                                         'z': 43.91404334248268})

X_warmup, Y_warmup = get_subset(lorenz_series, 0, num_warmup)
X_train, Y_train = get_subset(lorenz_series, num_warmup, num_warmup + num_train)
X_test, Y_test = get_subset(lorenz_series, 0, num_warmup + num_train + num_test)

Model

[22]:
class NGRC(bp.DynamicalSystem):
  def __init__(self, num_in):
    super(NGRC, self).__init__()
    self.r = bp.layers.NVAR(num_in, delay=4, order=2, stride=5)
    self.o = bp.layers.Dense(self.r.num_out, 1, mode=bm.training_mode)

  def update(self, sha, x):
    return self.o(sha, self.r(sha, x))


model = NGRC(2)

Training

[23]:
trainer = bp.RidgeTrainer(model, alpha=0.05)

# warm-up
outputs = trainer.predict(X_warmup)
print('Warmup NMS: ', bp.losses.mean_squared_error(outputs, Y_warmup))

# training
_ = trainer.fit([X_train, Y_train])
Warmup NMS:  7268.475225524938

Prediction

[24]:
outputs = trainer.predict(X_test, reset_state=True)
print('Prediction NMS: ', bp.losses.mean_squared_error(outputs, Y_test))
Prediction NMS:  590.5516565428906
[25]:
plot_lorenz2(x=bm.as_numpy(lorenz_series.xs.flatten()),
             y=bm.as_numpy(lorenz_series.ys.flatten()),
             true_z=bm.as_numpy(lorenz_series.zs.flatten()),
             predict_z=bm.as_numpy(outputs.flatten()))
_images/reservoir_computing_Gauthier_2021_ngrc_41_0.png

(Sussillo & Abbott, 2009) FORCE Learning

Implementation of the paper:

  • Sussillo, David, and Larry F. Abbott. “Generating coherent patterns of activity from chaotic neural networks.” Neuron 63, no. 4 (2009): 544-557.

[1]:
import brainpy as bp
import brainpy.math as bm
bm.set_platform('cpu')
[2]:
import numpy as np
import matplotlib.pyplot as plt
[3]:
class EchoStateNet(bp.DynamicalSystem):
  r"""The continuous-time Echo State Network.

  .. math::

    \tau \frac{dh}{dt} = -h + W_{ir} * x + W_{rr} * r + W_{or} * z \\
    r = \tanh(h) \\
    o = W_{ro} * r
  """

  def __init__(self, num_input, num_hidden, num_output,
               tau=1.0, dt=0.1, g=1.8, alpha=1.0, **kwargs):
    super(EchoStateNet, self).__init__(**kwargs)

    # parameters
    self.num_input = num_input
    self.num_hidden = num_hidden
    self.num_output = num_output
    self.tau = tau
    self.dt = dt
    self.g = g
    self.alpha = alpha

    # weights
    self.w_ir = bm.random.normal(size=(num_input, num_hidden)) / bm.sqrt(num_input)
    self.w_rr = g * bm.random.normal(size=(num_hidden, num_hidden)) / bm.sqrt(num_hidden)
    self.w_or = bm.random.normal(size=(num_output, num_hidden))
    w_ro = bm.random.normal(size=(num_hidden, num_output)) / bm.sqrt(num_hidden)
    self.w_ro = bm.Variable(w_ro)

    # variables
    self.h = bm.Variable(bm.random.normal(size=num_hidden) * 0.5)  # hidden
    self.r = bm.Variable(bm.tanh(self.h))  # firing rate
    self.o = bm.Variable(bm.dot(self.r, w_ro))  # output unit
    self.P = bm.Variable(bm.eye(num_hidden) * self.alpha)  # inverse correlation matrix

  def update(self, x):
    # update the hidden and output state
    dhdt = -self.h + bm.dot(x, self.w_ir)
    dhdt += bm.dot(self.r, self.w_rr)
    dhdt += bm.dot(self.o, self.w_or)
    self.h += self.dt / self.tau * dhdt
    self.r.value = bm.tanh(self.h)
    self.o.value = bm.dot(self.r, self.w_ro)

  def rls(self, target):
    # update the inverse correlation matrix
    k = bm.expand_dims(bm.dot(self.P, self.r), axis=1)  # (num_hidden, 1)
    hPh = bm.dot(self.r.T, k)  # (1,)
    c = 1.0 / (1.0 + hPh)  # (1,)
    self.P -= bm.dot(k * c, k.T)  # (num_hidden, num_hidden)
    # update the output weights
    e = bm.atleast_2d(self.o - target)  # (1, num_output)
    dw = bm.dot(-c * k, e)  # (num_hidden, num_output)
    self.w_ro += dw

  def simulate(self, xs):
    f = bm.make_loop(self.update, dyn_vars=[self.h, self.r, self.o], out_vars=[self.r, self.o])
    return f(xs)

  def train(self, xs, targets):
    def _f(x):
      input, target = x
      self.update(input)
      self.rls(target)

    f = bm.make_loop(_f, dyn_vars=self.vars(), out_vars=[self.r, self.o])
    return f([xs, targets])
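The rls method above is a recursive least-squares (FORCE) update. Reconstructed from the code, each training step updates the inverse correlation matrix \(P\) and the readout weights \(W_{ro}\) as:

\[k = P r, \qquad c = \frac{1}{1 + r^{\top} P r}\]
\[P \leftarrow P - c\, k k^{\top}\]
\[e = o - f(t), \qquad W_{ro} \leftarrow W_{ro} - c\, k\, e^{\top}\]

where \(r\) is the vector of firing rates, \(o\) the current output, and \(f(t)\) the target signal.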
[4]:
def print_force(ts, rates, outs, targets, duration, ntoplot=10):
  """Plot activations and outputs for the Echo state network."""
  plt.figure(figsize=(16, 16))

  plt.subplot(321)
  plt.plot(ts, targets + 2 * np.arange(0, targets.shape[1]), 'g')
  plt.plot(ts, outs + 2 * np.arange(0, outs.shape[1]), 'r')
  plt.xlim((0, duration))
  plt.title('Target (green), Output (red)')
  plt.xlabel('Time')
  plt.ylabel('Dimension')

  plt.subplot(122)
  plt.imshow(rates.T, interpolation=None)
  plt.title('Hidden activations of ESN')
  plt.xlabel('Time')
  plt.ylabel('Dimension')

  plt.subplot(323)
  plt.plot(ts, rates[:, 0:ntoplot] + 2 * np.arange(0, ntoplot), 'b')
  plt.xlim((0, duration))
  plt.title('%d hidden activations of ESN' % (ntoplot))
  plt.xlabel('Time')
  plt.ylabel('Dimension')

  plt.subplot(325)
  plt.plot(ts, np.sqrt(np.square(outs - targets)), 'c')
  plt.xlim((0, duration))
  plt.title('Error - mean absolute error')
  plt.xlabel('Time')
  plt.ylabel('Error')

  plt.tight_layout()
  plt.show()
[5]:
def plot_params(net):
  """Plot some of the parameters associated with the ESN."""
  assert isinstance(net, EchoStateNet)

  plt.figure(figsize=(16, 10))
  plt.subplot(221)
  plt.imshow(bm.as_numpy(net.w_rr + net.w_ro @ net.w_or), interpolation=None)
  plt.colorbar()
  plt.title('Effective matrix - W_rr + W_ro * W_or')

  plt.subplot(222)
  plt.imshow(bm.as_numpy(net.w_ro), interpolation=None)
  plt.colorbar()
  plt.title('Readout weights - W_ro')

  x_circ = np.linspace(-1, 1, 1000)
  y_circ = np.sqrt(1 - x_circ ** 2)
  evals, _ = np.linalg.eig(bm.as_numpy(net.w_rr))
  plt.subplot(223)
  plt.plot(np.real(evals), np.imag(evals), 'o')
  plt.plot(x_circ, y_circ, 'k')
  plt.plot(x_circ, -y_circ, 'k')
  plt.axis('equal')
  plt.title('Eigenvalues of W_rr')

  evals, _ = np.linalg.eig(bm.as_numpy(net.w_rr + net.w_ro @ net.w_or))
  plt.subplot(224)
  plt.plot(np.real(evals), np.imag(evals), 'o', color='orange')
  plt.plot(x_circ, y_circ, 'k')
  plt.plot(x_circ, -y_circ, 'k')
  plt.axis('equal')
  plt.title('Eigenvalues of W_rr + W_ro * W_or')

  plt.tight_layout()
  plt.show()
[6]:
dt = 0.1
T = 30
times = bm.arange(0, T, dt)
xs = bm.zeros((times.shape[0], 1))

Generate target data by running an ESN and taking a subset of its hidden dimensions as the targets for the FORCE-trained network.

[7]:
esn1 = EchoStateNet(num_input=1, num_hidden=500, num_output=20, dt=dt, g=1.8)
rs, ys = esn1.simulate(xs)
targets = rs[:, 0: esn1.num_output]  # This will be the training data for the trained ESN
plt.plot(times, targets + 2 * np.arange(0, esn1.num_output), 'g')
plt.xlim((0, T))
plt.ylabel('Dimensions')
plt.xlabel('Time')
plt.show()
_images/recurrent_networks_Sussillo_Abbott_2009_FORCE_Learning_9_0.png

Un-trained ESN.

[8]:
esn2 = EchoStateNet(num_input=1, num_hidden=500, num_output=20, dt=dt, g=1.5)
rs, ys = esn2.simulate(xs)  # the untrained ESN
print_force(times, rates=rs, outs=ys, targets=targets, duration=T, ntoplot=10)
_images/recurrent_networks_Sussillo_Abbott_2009_FORCE_Learning_11_0.png

Trained ESN.

[9]:
esn3 = EchoStateNet(num_input=1, num_hidden=500, num_output=20, dt=dt, g=1.5, alpha=1.)
rs, ys = esn3.train(xs=xs, targets=targets)  # train once
print_force(times, rates=rs, outs=ys, targets=targets, duration=T, ntoplot=10)
_images/recurrent_networks_Sussillo_Abbott_2009_FORCE_Learning_13_0.png
[10]:
plot_params(esn3)
_images/recurrent_networks_Sussillo_Abbott_2009_FORCE_Learning_14_0.png

(Sherman & Rinzel, 1992) Gap junction leads to anti-synchronization

Implementation of the paper:

  • Sherman, A., & Rinzel, J. (1992). Rhythmogenic effects of weak electrotonic coupling in neuronal models. Proceedings of the National Academy of Sciences, 89(6), 2471-2474.

Author: Chaoming Wang

[1]:
import brainpy as bp
import brainpy.math as bm

bp.math.set_platform('cpu')
[2]:
import matplotlib.pyplot as plt

Fig 1: weakly coupled cells can oscillate antiphase

The “square-wave burster” model is given by:

\[\begin{split}\begin{split} \tau {dV \over dt} &= -I_{in}(V) - I_{out}(V) - g_s S (V-V_K) + I + I_j \\ \tau {dn \over dt} &= \lambda (n_{\infty} - n) \end{split}\end{split}\]

where

\[\begin{split}\begin{split} I_{in}(V) &= g_{Ca} m (V-V_{Ca}) \\ I_{out}(V) &= g_Kn(V-V_K) \\ m &= m_{\infty}(V) \\ x_{\infty} (V) &= {1 \over 1 + \exp[(V_x - V) / \theta_x]} \\ I_j &= -g_c (V-\bar{V}) \end{split}\end{split}\]

where

\[\begin{split}\begin{split} S &= 0.15 \\ \lambda &= 0.8 \\ V_m &= -20 \\ \theta_m &= 12 \\ V_n &= -17 \\ \theta_n &= 5.6 \\ V_{Ca} &= 25 \\ V_{K} &= -75 \\ \tau &= 20 \, \mathrm{ms} \\ g_{Ca} &= 3.6 \\ g_K &= 10 \\ g_s &= 4 \end{split}\end{split}\]

At t = 0.5 s the junctional-coupling conductance, \(g_c\), is raised to 0.08, and a small symmetry-breaking perturbation (0.3 mV) is applied to one of the cells. This destabilizes the single-cell oscillation and leads to an antiphase oscillation. At t = 5.5 s the single-cell behavior is restored by increasing \(g_c\) to 0.24; alternatively, one could set \(g_c\) to 0, but then the two cells would not be in-phase.

[3]:
lambda_ = 0.8
V_m = -20
theta_m = 12
V_n = -17
theta_n = 5.6
V_ca = 25
V_K = -75
tau = 20
g_ca = 3.6
g_K = 10
g_s = 4
[4]:
class Model1(bp.DynamicalSystem):
  def __init__(self, method='exp_auto'):
    super(Model1, self).__init__()

    # parameters
    self.gc = bm.Variable(bm.zeros(1))
    self.I = bm.Variable(bm.zeros(2))
    self.S = 0.15

    # variables
    self.V = bm.Variable(bm.zeros(2))
    self.n = bm.Variable(bm.zeros(2))

    # integral
    self.integral = bp.odeint(bp.JointEq([self.dV, self.dn]), method=method)

  def dV(self, V, t, n):
    I_in = g_ca / (1 + bp.math.exp((V_m - V) / theta_m)) * (V - V_ca)
    I_out = g_K * n * (V - V_K)
    Is = g_s * self.S * (V - V_K)
    Ij = self.gc * bm.array([V[0] - V[1], V[1] - V[0]])
    dV = (- I_in - I_out - Is - Ij + self.I) / tau
    return dV

  def dn(self, n, t, V):
    n_inf = 1 / (1 + bp.math.exp((V_n - V) / theta_n))
    dn = lambda_ * (n_inf - n) / tau
    return dn

  def update(self, tdi):
    V, n = self.integral(self.V, self.n, tdi.t, tdi.dt)
    self.V.value = V
    self.n.value = n
[5]:
def run_and_plot1(model, duration, inputs=None, plot_duration=None):
  runner = bp.dyn.DSRunner(model, inputs=inputs, monitors=['V', 'n', 'gc', 'I'])
  runner.run(duration)

  fig, gs = bp.visualize.get_figure(5, 1, 2, 12)
  plot_duration = (0, duration) if plot_duration is None else plot_duration

  fig.add_subplot(gs[0:3, 0])
  plt.plot(runner.mon.ts, runner.mon.V)
  plt.ylabel('V [mV]')
  plt.xlim(*plot_duration)

  fig.add_subplot(gs[3, 0])
  plt.plot(runner.mon.ts, runner.mon.gc)
  plt.ylabel(r'$g_c$')
  plt.xlim(*plot_duration)

  fig.add_subplot(gs[4, 0])
  plt.plot(runner.mon.ts, runner.mon.I[:, 0])
  plt.ylabel(r'$I_0$')
  plt.xlim(*plot_duration)

  plt.xlabel('Time [ms]')
  plt.show()
[6]:
model = Model1()
model.S = 0.15
model.V[:] = -55.
model.n[:] = 1 / (1 + bm.exp((V_n - model.V) / theta_n))
[7]:
gc = bp.inputs.section_input(values=[0., 0.0, 0.24], durations=[500, 5000, 1500])
Is = bp.inputs.section_input(values=[0., bm.array([0.3, 0.])], durations=[500., 6500.])
run_and_plot1(model, duration=7000, inputs=[('gc', gc, 'iter', '='),
                                            ('I', Is, 'iter', '=')])
_images/gj_nets_Sherman_1992_gj_antisynchrony_14_1.png

Fig 2: weak coupling can convert excitable cells into spikers

Cells are initially uncoupled and at rest, but one cell receives an injected current of strength 1.0 for 0.5 s, which evokes two spikes; spiking ends when the stimulus is removed, and the unstimulated cell stays at rest. At t = 2 s, \(g_c\) is increased to 0.04. This by itself leaves both cells at rest, but the system is now bistable: the rest state coexists with an antiphase oscillation. A second, identical current stimulus draws the cells close enough to the oscillatory solution that they keep oscillating after the stimulus ends.

[8]:
model = Model1()
model.S = 0.177
model.V[:] = -62.69
model.n[:] = 1 / (1 + bm.exp((V_n - model.V) / theta_n))
[9]:
gc = bp.inputs.section_input(values=[0, 0.04], durations=[2000, 2500])
Is = bp.inputs.section_input(values=[bm.array([1., 0.]), 0., bm.array([1., 0.]), 0.],
                             durations=[500, 2000, 500, 1500])

run_and_plot1(model, 4500, inputs=[('gc', gc, 'iter', '='),
                                   ('I', Is, 'iter', '=')])
_images/gj_nets_Sherman_1992_gj_antisynchrony_18_1.png

Fig 3: weak coupling can increase the period of bursting

We consider cells with endogenous bursting properties. Now \(S\) is a slow dynamic variable, satisfying

\[\tau_S {dS \over dt} = S_{\infty}(V) - S\]

with \(\tau_S \gg \tau\).
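The steady-state function \(S_{\infty}(V)\) has the same Boltzmann form as the other gating functions (it can be read off the dS method of Model2 below):

\[S_{\infty}(V) = {1 \over 1 + \exp[(V_S - V) / \theta_S]}\]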

[10]:
tau_S = 35 * 1e3  # ms
V_S = -38  # mV
theta_S = 10  # mV
[11]:
class Model2(bp.DynamicalSystem):
  def __init__(self, method='exp_auto'):
    super(Model2, self).__init__()

    # parameters
    self.lambda_ = 0.1
    self.gc = bm.Variable(bm.zeros(1))
    self.I = bm.Variable(bm.zeros(2))

    # variables
    self.V = bm.Variable(bm.zeros(2))
    self.n = bm.Variable(bm.zeros(2))
    self.S = bm.Variable(bm.zeros(2))

    # integral
    self.integral = bp.odeint(bp.JointEq([self.dV, self.dn, self.dS]), method=method)

  def dV(self, V, t, n, S):
    I_in = g_ca / (1 + bm.exp((V_m - V) / theta_m)) * (V - V_ca)
    I_out = g_K * n * (V - V_K)
    Is = g_s * S * (V - V_K)
    Ij = self.gc * bm.array([V[0] - V[1], V[1] - V[0]])
    dV = (- I_in - I_out - Is - Ij + self.I) / tau
    return dV

  def dn(self, n, t, V):
    n_inf = 1 / (1 + bm.exp((V_n - V) / theta_n))
    dn = self.lambda_ * (n_inf - n) / tau
    return dn

  def dS(self, S, t, V):
    S_inf = 1 / (1 + bm.exp((V_S - V) / theta_S))
    dS = (S_inf - S) / tau_S
    return dS

  def update(self, tdi):
    V, n, S = self.integral(self.V, self.n, self.S, tdi.t, dt=tdi.dt)
    self.V.value = V
    self.n.value = n
    self.S.value = S
[12]:
def run_and_plot2(model, duration, inputs=None, plot_duration=None):
  runner = bp.dyn.DSRunner(model, inputs=inputs, monitors=['V', 'S'])
  runner.run(duration)

  fig, gs = bp.visualize.get_figure(5, 1, 2, 12)
  plot_duration = (0, duration) if plot_duration is None else plot_duration

  fig.add_subplot(gs[0:3, 0])
  plt.plot(runner.mon.ts, runner.mon.V)
  plt.ylabel('V [mV]')
  plt.xlim(*plot_duration)

  fig.add_subplot(gs[3:, 0])
  plt.plot(runner.mon.ts, runner.mon.S)
  plt.ylabel('S')
  plt.xlim(*plot_duration)

  plt.xlabel('Time [ms]')
  plt.show()

With \(\lambda = 0.9\), an isolated cell alternates periodically between a depolarized spiking phase and a hyperpolarized silent phase.

[13]:
model = Model2()
model.lambda_ = 0.9
model.S[:] = 0.172
model.V[:] = V_S - theta_S * bm.log(1 / model.S - 1)
model.n[:] = 1 / (1 + bm.exp((V_n - model.V) / theta_n))
model.gc[:] = 0.
model.I[:] = 0.
[14]:
run_and_plot2(model, 50 * 1e3)
_images/gj_nets_Sherman_1992_gj_antisynchrony_26_1.png

When two identical bursters are coupled with \(g_c = 0.06\) and started in-phase, they initially follow the single-cell bursting solution. This behavior is unstable, however, and during the second burst a new stable burst pattern emerges, with smaller-amplitude, higher-frequency, antiphase spikes.

[15]:
model = Model2(method='exp_auto')
model.lambda_ = 0.9
model.S[:] = 0.172
model.V[:] = V_S - theta_S * bm.log(1 / model.S - 1)
model.n[:] = 1 / (1 + bm.exp((V_n - model.V) / theta_n))
model.gc[:] = 0.06
model.I[:] = 0.
[16]:
run_and_plot2(model, 50 * 1e3)
_images/gj_nets_Sherman_1992_gj_antisynchrony_29_1.png
[17]:
model = Model2(method='exp_auto')
model.lambda_ = 0.9
model.S[:] = 0.172
model.V[:] = V_S - theta_S * bm.log(1 / model.S - 1)
model.n[:] = 1 / (1 + bm.exp((V_n - model.V) / theta_n))
model.gc[:] = 0.06
model.I[:] = 0.
run_and_plot2(model, 4 * 1e3)
_images/gj_nets_Sherman_1992_gj_antisynchrony_30_1.png

Fig 4: weak coupling can convert spikers to bursters

Parameters are the same as in Fig. 3, except \(\lambda = 0.8\), which yields repetitive spiking (beating) instead of bursting; oscillations in \(S\) are nearly abolished. Two identical cells are started from identical initial conditions (only one is shown for clarity). At t = 20 s, \(g_c\) is increased to 0.04 and a small symmetry-breaking perturbation (0.3 mV) is applied to one cell. After a brief transient, the two cells burst in-phase but with antiphase spikes, as in Fig. 3.

[18]:
model = Model2()
model.lambda_ = 0.8
model.S[:] = 0.172
model.V[:] = V_S - theta_S * bm.log(1 / model.S - 1)
model.n[:] = 1 / (1 + bm.exp((V_n - model.V) / theta_n))
model.gc[:] = 0.06
model.I[:] = 0.
[19]:
gc = bp.inputs.section_input(values=[0., 0.04], durations=[20 * 1e3, 30 * 1e3])
Is = bp.inputs.section_input(values=[0., bp.math.array([0.3, 0.])], durations=[20 * 1e3, 30 * 1e3])
run_and_plot2(model, 50 * 1e3, inputs=[('gc', gc, 'iter', '='),
                                       ('I', Is, 'iter', '=')])
_images/gj_nets_Sherman_1992_gj_antisynchrony_34_1.png

(Fazli and Bertram, 2022): Electrically Coupled Bursting Pituitary Cells

Implementation of the paper:

  • Fazli, Mehran, and Richard Bertram. “Network Properties of Electrically Coupled Bursting Pituitary Cells.” Frontiers in Endocrinology 13 (2022).

[1]:
import brainpy as bp
import brainpy.math as bm
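
For reference, the membrane equation implemented by the PituitaryCell class in the next cell (read directly from its dV method) balances five ionic currents plus the injected input:

\[\begin{split}\begin{split} c_m \frac{dV}{dt} &= -\left(I_{Ca} + I_{SK} + I_{BK} + I_{K} + I_{L} + I_{\mathrm{input}}\right) \\ I_{Ca} &= g_{Ca}\, m_{\infty}(V)\,(V - V_{Ca}), \qquad m_{\infty}(V) = \frac{1}{1 + \exp[(v_m - V)/s_m]} \\ I_{SK} &= g_{SK}\, \frac{c^2}{c^2 + k_s^2}\,(V - V_{K}) \\ I_{BK} &= g_{BK}\, b\,(V - V_{K}) \\ I_{K} &= g_{K}\, n\,(V - V_{K}) \\ I_{L} &= g_{L}\,(V - V_{L}) \end{split}\end{split}\]

The gating variables \(n\) and \(b\) relax toward voltage-dependent steady states, and the intracellular calcium concentration \(c\) follows the calcium influx, as defined in the dn, db and dc methods.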
[2]:
class PituitaryCell(bp.NeuGroup):
  def __init__(self, size, name=None):
    super(PituitaryCell, self).__init__(size, name=name)

    # parameter values
    self.vn = -5
    self.kc = 0.12
    self.ff = 0.005
    self.vca = 60
    self.vk = -75
    self.vl = -50.0
    self.gk = 2.5
    self.cm = 5
    self.gbk = 1
    self.gca = 2.1
    self.gsk = 2
    self.vm = -20
    self.vb = -5
    self.sn = 10
    self.sm = 12
    self.sbk = 2
    self.taun = 30
    self.taubk = 5
    self.ks = 0.4
    self.alpha = 0.0015
    self.gl = 0.2

    # variables
    self.V = bm.Variable(bm.random.random(self.num) * -90 + 20)
    self.n = bm.Variable(bm.random.random(self.num) / 2)
    self.b = bm.Variable(bm.random.random(self.num) / 2)
    self.c = bm.Variable(bm.random.random(self.num))
    self.input = bm.Variable(self.num)

    # integrators
    self.integral = bp.odeint(bp.JointEq(self.dV, self.dn, self.dc, self.db), method='exp_euler')

  def dn(self, n, t, V):
    ninf = 1 / (1 + bm.exp((self.vn - V) / self.sn))
    return (ninf - n) / self.taun

  def db(self, b, t, V):
    bkinf = 1 / (1 + bm.exp((self.vb - V) / self.sbk))
    return (bkinf - b) / self.taubk

  def dc(self, c, t, V):
    minf = 1 / (1 + bm.exp((self.vm - V) / self.sm))
    ica = self.gca * minf * (V - self.vca)
    return -self.ff * (self.alpha * ica + self.kc * c)

  def dV(self, V, t, n, b, c):
    minf = 1 / (1 + bm.exp((self.vm - V) / self.sm))
    cinf = c ** 2 / (c ** 2 + self.ks * self.ks)
    ica = self.gca * minf * (V - self.vca)
    isk = self.gsk * cinf * (V - self.vk)
    ibk = self.gbk * b * (V - self.vk)
    ikdr = self.gk * n * (V - self.vk)
    il = self.gl * (V - self.vl)
    return -(ica + isk + ibk + ikdr + il + self.input) / self.cm

  def update(self, tdi, x=None):
    V, n, c, b = self.integral(self.V.value, self.n.value, self.c.value, self.b.value, tdi.t, tdi.dt)
    self.V.value = V
    self.n.value = n
    self.c.value = c
    self.b.value = b

  def clear_input(self):
    self.input.value = bm.zeros_like(self.input)
[3]:
class PituitaryNetwork(bp.Network):
  def __init__(self, num, gc):
    super(PituitaryNetwork, self).__init__()

    self.N = PituitaryCell(num)
    self.gj = bp.synapses.GapJunction(self.N, self.N, bp.conn.All2All(include_self=False), g_max=gc)
[4]:
net = PituitaryNetwork(2, 0.002)
runner = bp.DSRunner(net, monitors={'V': net.N.V}, dt=0.5)
runner.run(10 * 1e3)

fig, gs = bp.visualize.get_figure(1, 1, 6, 10)
fig.add_subplot(gs[0, 0])
bp.visualize.line_plot(runner.mon.ts, runner.mon.V, plot_ids=(0, 1), show=True)
_images/gj_nets_Fazli_2022_gj_coupled_bursting_pituitary_cells_5_2.png

(Wang & Buzsáki, 1996) Gamma Oscillation

Here we show the implementation of the gamma oscillation model proposed by Xiao-Jing Wang and György Buzsáki (1996). They demonstrated that GABA\(_A\) synaptic transmission provides a suitable mechanism for synchronized gamma oscillations in a network of fast-spiking interneurons.

Let’s first import brainpy and set profiles.

[1]:
import brainpy as bp
import brainpy.math as bm

bp.math.set_dt(0.05)

The network is constructed with Hodgkin–Huxley (HH) type neurons and GABA\(_A\) synapses.

The dynamics of the HH type neurons is given by:

\[C \frac {dV} {dt} = -(I_{Na} + I_{K} + I_L) + I(t)\]

where \(I(t)\) is the injected current, the leak current is \(I_L = g_L (V - E_L)\), and the transient sodium current is

\[I_{Na} = g_{Na} m_{\infty}^3 h (V - E_{Na})\]

where the activation variable \(m\) is assumed to be fast and is substituted by its steady-state function \(m_{\infty} = \alpha_m / (\alpha_m + \beta_m)\). The inactivation variable \(h\) obeys first-order kinetics:

\[\frac {dh} {dt} = \phi (\alpha_h (1-h) - \beta_h h)\]
The delayed rectifier potassium current is

\[I_K = g_K n^4 (V - E_K)\]

where the activation variable \(n\) also obeys first-order kinetics:

\[\frac {dn} {dt} = \phi (\alpha_n (1-n) - \beta_n n)\]
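
The voltage-dependent rate functions used in the implementation (they can be read off the code in the next cell) are:

\[\begin{split}\begin{split} \alpha_m(V) &= \frac{-0.1 (V + 35)}{\exp[-0.1 (V + 35)] - 1}, \qquad \beta_m(V) = 4 \exp[-(V + 60)/18] \\ \alpha_h(V) &= 0.07 \exp[-(V + 58)/20], \qquad \beta_h(V) = \frac{1}{\exp[-0.1 (V + 28)] + 1} \\ \alpha_n(V) &= \frac{-0.01 (V + 34)}{\exp[-0.1 (V + 34)] - 1}, \qquad \beta_n(V) = 0.125 \exp[-(V + 44)/80] \end{split}\end{split}\]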
[2]:
class HH(bp.NeuGroup):
  def __init__(self, size, ENa=55., EK=-90., EL=-65, C=1.0, gNa=35.,
               gK=9., gL=0.1, V_th=20., phi=5.0, method='exp_auto'):
    super(HH, self).__init__(size=size)

    # parameters
    self.ENa = ENa
    self.EK = EK
    self.EL = EL
    self.C = C
    self.gNa = gNa
    self.gK = gK
    self.gL = gL
    self.V_th = V_th
    self.phi = phi

    # variables
    self.V = bm.Variable(bm.ones(size) * -65.)
    self.h = bm.Variable(bm.ones(size) * 0.6)
    self.n = bm.Variable(bm.ones(size) * 0.32)
    self.spike = bm.Variable(bm.zeros(size, dtype=bool))
    self.input = bm.Variable(bm.zeros(size))
    self.t_last_spike = bm.Variable(bm.ones(size) * -1e7)

    # integral
    self.integral = bp.odeint(bp.JointEq([self.dV, self.dh, self.dn]), method=method)

  def dh(self, h, t, V):
    alpha = 0.07 * bm.exp(-(V + 58) / 20)
    beta = 1 / (bm.exp(-0.1 * (V + 28)) + 1)
    dhdt = alpha * (1 - h) - beta * h
    return self.phi * dhdt

  def dn(self, n, t, V):
    alpha = -0.01 * (V + 34) / (bm.exp(-0.1 * (V + 34)) - 1)
    beta = 0.125 * bm.exp(-(V + 44) / 80)
    dndt = alpha * (1 - n) - beta * n
    return self.phi * dndt

  def dV(self, V, t, h, n, Iext):
    m_alpha = -0.1 * (V + 35) / (bm.exp(-0.1 * (V + 35)) - 1)
    m_beta = 4 * bm.exp(-(V + 60) / 18)
    m = m_alpha / (m_alpha + m_beta)
    INa = self.gNa * m ** 3 * h * (V - self.ENa)
    IK = self.gK * n ** 4 * (V - self.EK)
    IL = self.gL * (V - self.EL)
    dVdt = (- INa - IK - IL + Iext) / self.C

    return dVdt

  def update(self, tdi):
    V, h, n = self.integral(self.V, self.h, self.n, tdi.t, self.input, tdi.dt)
    self.spike.value = bm.logical_and(self.V < self.V_th, V >= self.V_th)
    self.t_last_spike.value = bm.where(self.spike, tdi.t, self.t_last_spike)
    self.V.value = V
    self.h.value = h
    self.n.value = n
    self.input[:] = 0.

Let’s run a simulation of a network of 100 neurons driven by a constant input of 1 \(\mu\)A/cm\(^2\).

[3]:
num = 100
neu = HH(num)
neu.V[:] = -70. + bm.random.normal(size=num) * 20

syn = bp.synapses.GABAa(pre=neu, post=neu, conn=bp.connect.All2All(include_self=False))
syn.g_max = 0.1 / num

net = bp.Network(neu=neu, syn=syn)
runner = bp.DSRunner(net, monitors=['neu.spike', 'neu.V'], inputs=['neu.input', 1.])
runner.run(duration=500.)
[4]:
fig, gs = bp.visualize.get_figure(2, 1, 3, 8)

fig.add_subplot(gs[0, 0])
bp.visualize.line_plot(runner.mon.ts, runner.mon['neu.V'], ylabel='Membrane potential (N0)')

fig.add_subplot(gs[1, 0])
bp.visualize.raster_plot(runner.mon.ts, runner.mon['neu.spike'], show=True)
_images/oscillation_synchronization_Wang_1996_gamma_oscillation_9_0.png

We can see from the simulation that cells starting from random, asynchronous initial conditions quickly become synchronized, and their spike times become nearly perfectly in-phase within 5-6 oscillatory cycles.

Reference:

  • Wang, Xiao-Jing, and György Buzsáki. “Gamma oscillation by synaptic inhibition in a hippocampal interneuronal network model.” Journal of neuroscience 16.20 (1996): 6402-6413.

(Brunel & Hakim, 1999) Fast Global Oscillation

Implementation of the paper:

  • Brunel, Nicolas, and Vincent Hakim. “Fast global oscillations in networks of integrate-and-fire neurons with low firing rates.” Neural computation 11.7 (1999): 1621-1671.

Author: Chaoming Wang

[1]:
import brainpy as bp
import brainpy.math as bm

bm.set_platform('cpu')
[2]:
Vr = 10.  # mV
theta = 20.  # mV
tau = 20.  # ms
delta = 2.  # ms
taurefr = 2.  # ms
duration = 100.  # ms
J = .1  # mV
muext = 25.  # mV
sigmaext = 1.  # mV
C = 1000
N = 5000
sparseness = float(C) / N
[3]:
class LIF(bp.NeuGroup):
  def __init__(self, size, **kwargs):
    super(LIF, self).__init__(size, **kwargs)

    # variables
    self.V = bm.Variable(bm.ones(self.num) * Vr)
    self.t_last_spike = bm.Variable(-1e7 * bm.ones(self.num))
    self.spike = bm.Variable(bm.zeros(self.num, dtype=bool))
    self.refractory = bm.Variable(bm.zeros(self.num, dtype=bool))

    # integration functions
    fv = lambda V, t: (-V + muext) / tau
    gv = lambda V, t: sigmaext / bm.sqrt(tau)
    self.int_v = bp.sdeint(f=fv, g=gv)

  def update(self, tdi):
    V = self.int_v(self.V, tdi.t, tdi.dt)
    in_ref = (tdi.t - self.t_last_spike) < taurefr
    V = bm.where(in_ref, self.V, V)
    spike = V >= theta
    self.spike.value = spike
    self.V.value = bm.where(spike, Vr, V)
    self.t_last_spike.value = bm.where(spike, tdi.t, self.t_last_spike)
    self.refractory.value = bm.logical_or(in_ref, spike)
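
For reference, the stochastic differential equation integrated by bp.sdeint above is the Brunel–Hakim membrane equation driven by Gaussian white noise:

\[\tau \frac{dV_i}{dt} = -V_i + \mu_{\mathrm{ext}} + \sigma_{\mathrm{ext}} \sqrt{\tau}\, \eta_i(t)\]

where \(\eta_i(t)\) is unit-variance white noise; the drift fv and diffusion gv defined in the constructor correspond to \((-V + \mu_{\mathrm{ext}})/\tau\) and \(\sigma_{\mathrm{ext}}/\sqrt{\tau}\), respectively.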
[4]:
group = LIF(N)
syn = bp.synapses.Delta(group, group,
                        conn=bp.conn.FixedProb(sparseness),
                        delay_step=int(delta / bm.get_dt()),
                        post_ref_key='refractory',
                        output=bp.synouts.CUBA(target_var='V'),
                        g_max=-J)
net = bp.Network(syn, group=group)
[5]:
runner = bp.DSRunner(net, monitors=['group.spike'])
runner.run(duration)
bp.visualize.raster_plot(runner.mon.ts, runner.mon['group.spike'],
                         xlim=(0, duration), show=True)
_images/oscillation_synchronization_Brunel_Hakim_1999_fast_oscillation_6_1.png

(Diesmann, et al., 1999) Synfire Chains

Implementation of the paper:

  • Diesmann, Markus, Marc-Oliver Gewaltig, and Ad Aertsen. “Stable propagation of synchronous spiking in cortical neural networks.” Nature 402.6761 (1999): 529-533.

Author: Chaoming Wang

[1]:
import brainpy as bp
import brainpy.math as bm

bp.math.set_platform('cpu')
[2]:
duration = 100.  # ms

# Neuron model parameters
Vr = -70.  # mV
Vt = -55.  # mV
tau_m = 10.  # ms
tau_ref = 1.  # ms
tau_psp = 0.325  # ms
weight = 4.86  # mV
noise = 39.24  # mV

# Neuron groups
n_groups = 10
group_size = 100
spike_sigma = 1.

# Synapse parameter
delay = 5.0  # ms
[3]:
# neuron model
# ------------


class Groups(bp.NeuGroup):
  def __init__(self, size, **kwargs):
    super(Groups, self).__init__(size, **kwargs)

    self.V = bm.Variable(Vr + bm.random.random(self.num) * (Vt - Vr))
    self.x = bm.Variable(bm.zeros(self.num))
    self.y = bm.Variable(bm.zeros(self.num))
    self.spike = bm.Variable(bm.zeros(self.num, dtype=bool))
    self.refractory = bm.Variable(bm.zeros(self.num, dtype=bool))
    self.t_last_spike = bm.Variable(bm.ones(self.num) * -1e7)

    # integral functions
    self.int_V = bp.odeint(lambda V, t, x: (-(V - Vr) + x) / tau_m)
    self.int_x = bp.odeint(lambda x, t, y: (-x + y) / tau_psp)
    self.int_y = bp.sdeint(f=lambda y, t: -y / tau_psp + 25.27,
                           g=lambda y, t: noise)

  def update(self, tdi):
    self.x[:] = self.int_x(self.x, tdi.t, self.y, tdi.dt)
    self.y[:] = self.int_y(self.y, tdi.t, tdi.dt)
    in_ref = (tdi.t - self.t_last_spike) < tau_ref
    V = self.int_V(self.V, tdi.t, self.x, tdi.dt)
    V = bm.where(in_ref, self.V, V)
    self.spike.value = V >= Vt
    self.t_last_spike.value = bm.where(self.spike, tdi.t, self.t_last_spike)
    self.V.value = bm.where(self.spike, Vr, V)
    self.refractory.value = bm.logical_or(in_ref, self.spike)
[4]:
# synaptic  model
# ---------------

class SynBetweenGroups(bp.TwoEndConn):
  def __init__(self, group, ext_group, **kwargs):
    super(SynBetweenGroups, self).__init__(group, group, **kwargs)

    self.group = group
    self.ext = ext_group

    # variables
    self.delay_step = int(delay/bm.get_dt())
    self.g = bm.LengthDelay(bm.zeros(self.group.num), self.delay_step)

  def update(self, tdi):
    # synapse model between external and group 1
    g = bm.zeros(self.group.num)
    g[:group_size] = weight * self.ext.spike.sum()
    # feed-forward connection
    for i in range(1, n_groups):
      s1 = (i - 1) * group_size
      s2 = i * group_size
      s3 = (i + 1) * group_size
      g[s2: s3] = weight * self.group.spike[s1: s2].sum()
    # delay push
    self.g.update(g)
    # delay pull
    self.group.y += self.g(self.delay_step)
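
In other words, the update method above implements a purely feedforward chain with a transmission delay of 5 ms: every neuron in group \(i+1\) receives a post-synaptic drive proportional to the total number of spikes emitted by group \(i\) (and the first group receives the external spikes in the same way),

\[y_j^{(i+1)}(t) \leftarrow y_j^{(i+1)}(t) + w \sum_{k \in \mathrm{group}\, i} s_k(t - d), \qquad d = 5 \, \mathrm{ms}.\]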
[5]:
# network running
# ---------------

def run_network(spike_num=48):
  bm.random.seed(123)
  times = bm.random.randn(spike_num) * spike_sigma + 20
  ext_group = bp.neurons.SpikeTimeGroup(spike_num, times=times.value, indices=bm.arange(spike_num).value)
  group = Groups(size=n_groups * group_size)
  syn_conn = SynBetweenGroups(group, ext_group)
  net = bp.Network(ext_group=ext_group, syn_conn=syn_conn, group=group)

  # simulation
  runner = bp.DSRunner(net,
                       monitors=['group.spike'],
                       dyn_vars=net.vars() + dict(rng=bm.random.DEFAULT))
  runner.run(duration)

  # visualization
  bp.visualize.raster_plot(runner.mon.ts, runner.mon['group.spike'],
                           xlim=(0, duration), show=True)
[6]:
run_network(spike_num=51)
_images/oscillation_synchronization_Diesmann_1999_synfire_chains_7_1.png

When the number of external spikes is 44, the synchronous excitation disperses and eventually dies out.

[7]:
run_network(spike_num=44)
_images/oscillation_synchronization_Diesmann_1999_synfire_chains_9_1.png

(Li, et al., 2017): Unified Thalamus Oscillation Model

Implementation of the model:

  • Li, Guoshi, Craig S. Henriquez, and Flavio Fröhlich. “Unified thalamic model generates multiple distinct oscillations with state-dependent entrainment by stimulation.” PLoS computational biology 13.10 (2017): e1005797.

[1]:
from typing import Dict
import matplotlib.pyplot as plt
import numpy as np

import brainpy as bp
import brainpy.math as bm
from brainpy import channels, synapses, synouts, synplast

HTC neuron

[2]:
class HTC(bp.CondNeuGroup):
  def __init__(self, size, gKL=0.01, V_initializer=bp.init.OneInit(-65.), ):
    gL = 0.01 if size == 1 else bp.init.Uniform(0.0075, 0.0125)
    IL = channels.IL(size, g_max=gL, E=-70)
    IKL = channels.IKL(size, g_max=gKL)
    INa = channels.INa_Ba2002(size, V_sh=-30)
    IDR = channels.IKDR_Ba2002(size, V_sh=-30., phi=0.25)
    Ih = channels.Ih_HM1992(size, g_max=0.01, E=-43)

    ICaL = channels.ICaL_IS2008(size, g_max=0.5)
    IAHP = channels.IAHP_De1994(size, g_max=0.3, E=-90.)
    ICaN = channels.ICaN_IS2008(size, g_max=0.5)
    ICaT = channels.ICaT_HM1992(size, g_max=2.1)
    ICaHT = channels.ICaHT_HM1992(size, g_max=3.0)
    Ca = channels.CalciumDetailed(size, C_rest=5e-5, tau=10., d=0.5, ICaL=ICaL,
                                  IAHP=IAHP, ICaN=ICaN, ICaT=ICaT, ICaHT=ICaHT)

    super(HTC, self).__init__(size, A=2.9e-4, V_initializer=V_initializer, V_th=20.,
                              IL=IL, IKL=IKL, INa=INa, IDR=IDR, Ih=Ih, Ca=Ca)


RTC neuron

[3]:
class RTC(bp.CondNeuGroup):
  def __init__(self, size, gKL=0.01, V_initializer=bp.init.OneInit(-65.), ):
    gL = 0.01 if size == 1 else bp.init.Uniform(0.0075, 0.0125)
    IL = channels.IL(size, g_max=gL, E=-70)
    IKL = channels.IKL(size, g_max=gKL)
    INa = channels.INa_Ba2002(size, V_sh=-40)
    IDR = channels.IKDR_Ba2002(size, V_sh=-40, phi=0.25)
    Ih = channels.Ih_HM1992(size, g_max=0.01, E=-43)

    ICaL = channels.ICaL_IS2008(size, g_max=0.3)
    IAHP = channels.IAHP_De1994(size, g_max=0.1, E=-90.)
    ICaN = channels.ICaN_IS2008(size, g_max=0.6)
    ICaT = channels.ICaT_HM1992(size, g_max=2.1)
    ICaHT = channels.ICaHT_HM1992(size, g_max=0.6)
    Ca = channels.CalciumDetailed(size, C_rest=5e-5, tau=10., d=0.5, ICaL=ICaL,
                                  IAHP=IAHP, ICaN=ICaN, ICaT=ICaT, ICaHT=ICaHT)

    super(RTC, self).__init__(size, A=2.9e-4, V_initializer=V_initializer, V_th=20.,
                              IL=IL, IKL=IKL, INa=INa, IDR=IDR, Ih=Ih, Ca=Ca)

IN neuron

[4]:
class IN(bp.CondNeuGroup):
  def __init__(self, size, gKL=0.01, V_initializer=bp.init.OneInit(-70.), ):
    gL = 0.01 if size == 1 else bp.init.Uniform(0.0075, 0.0125)
    IL = channels.IL(size, g_max=gL, E=-60)
    IKL = channels.IKL(size, g_max=gKL)
    INa = channels.INa_Ba2002(size, V_sh=-30)
    IDR = channels.IKDR_Ba2002(size, V_sh=-30, phi=0.25)
    Ih = channels.Ih_HM1992(size, g_max=0.05, E=-43)

    IAHP = channels.IAHP_De1994(size, g_max=0.2, E=-90.)
    ICaN = channels.ICaN_IS2008(size, g_max=0.1)
    ICaHT = channels.ICaHT_HM1992(size, g_max=2.5)
    Ca = channels.CalciumDetailed(size, C_rest=5e-5, tau=10., d=0.5,
                                  IAHP=IAHP, ICaN=ICaN, ICaHT=ICaHT)

    super(IN, self).__init__(size, A=1.7e-4, V_initializer=V_initializer, V_th=20.,
                             IL=IL, IKL=IKL, INa=INa, IDR=IDR, Ih=Ih, Ca=Ca)

TRN neuron

[5]:
class TRN(bp.CondNeuGroup):
  def __init__(self, size, gKL=0.01, V_initializer=bp.init.OneInit(-70.), ):
    gL = 0.01 if size == 1 else bp.init.Uniform(0.0075, 0.0125)
    IL = channels.IL(size, g_max=gL, E=-60)
    IKL = channels.IKL(size, g_max=gKL)
    INa = channels.INa_Ba2002(size, V_sh=-40)
    IDR = channels.IKDR_Ba2002(size, V_sh=-40)

    IAHP = channels.IAHP_De1994(size, g_max=0.2, E=-90.)
    ICaN = channels.ICaN_IS2008(size, g_max=0.2)
    ICaT = channels.ICaT_HP1992(size, g_max=1.3)
    Ca = channels.CalciumDetailed(size, C_rest=5e-5, tau=100., d=0.5,
                                  IAHP=IAHP, ICaN=ICaN, ICaT=ICaT)

    super(TRN, self).__init__(size, A=1.43e-4,
                              V_initializer=V_initializer, V_th=20.,
                              IL=IL, IKL=IKL, INa=INa, IDR=IDR, Ca=Ca)

Thalamus network

[6]:
class MgBlock(bp.SynOut):
  def __init__(self, E=0.):
    super(MgBlock, self).__init__()
    self.E = E

  def filter(self, g):
    V = self.master.post.V.value
    return g * (self.E - V) / (1 + bm.exp(-(V + 25) / 12.5))
[7]:
class Thalamus(bp.Network):
  def __init__(
      self,
      g_input: Dict[str, float],
      g_KL: Dict[str, float],
      HTC_V_init=bp.init.OneInit(-65.),
      RTC_V_init=bp.init.OneInit(-65.),
      IN_V_init=bp.init.OneInit(-70.),
      RE_V_init=bp.init.OneInit(-70.),
  ):
    super(Thalamus, self).__init__()

    # populations
    self.HTC = HTC(size=(7, 7), gKL=g_KL['TC'], V_initializer=HTC_V_init)
    self.RTC = RTC(size=(12, 12), gKL=g_KL['TC'], V_initializer=RTC_V_init)
    self.RE = TRN(size=(10, 10), gKL=g_KL['RE'], V_initializer=IN_V_init)
    self.IN = IN(size=(8, 8), gKL=g_KL['IN'], V_initializer=RE_V_init)

    # noises
    self.poisson_HTC = bp.neurons.PoissonGroup(self.HTC.size, freqs=100)
    self.poisson_RTC = bp.neurons.PoissonGroup(self.RTC.size, freqs=100)
    self.poisson_IN = bp.neurons.PoissonGroup(self.IN.size, freqs=100)
    self.poisson_RE = bp.neurons.PoissonGroup(self.RE.size, freqs=100)
    self.noise2HTC = synapses.Exponential(self.poisson_HTC, self.HTC, bp.conn.One2One(),
                                          output=synouts.COBA(E=0.), tau=5.,
                                          g_max=g_input['TC'])
    self.noise2RTC = synapses.Exponential(self.poisson_RTC, self.RTC, bp.conn.One2One(),
                                          output=synouts.COBA(E=0.), tau=5.,
                                          g_max=g_input['TC'])
    self.noise2IN = synapses.Exponential(self.poisson_IN, self.IN, bp.conn.One2One(),
                                         output=synouts.COBA(E=0.), tau=5.,
                                         g_max=g_input['IN'])
    self.noise2RE = synapses.Exponential(self.poisson_RE, self.RE, bp.conn.One2One(),
                                         output=synouts.COBA(E=0.), tau=5.,
                                         g_max=g_input['RE'])

    # HTC cells were connected with gap junctions
    self.gj_HTC = synapses.GapJunction(self.HTC, self.HTC,
                                       bp.conn.ProbDist(dist=2., prob=0.3, ),
                                       comp_method='sparse',
                                       g_max=1e-2)

    # HTC provides feedforward excitation to INs
    self.HTC2IN_ampa = synapses.AMPA(self.HTC, self.IN, bp.conn.FixedProb(0.3),
                                     delay_step=int(2 / bm.get_dt()),
                                     stp=synplast.STD(tau=700, U=0.07),
                                     alpha=0.94,
                                     beta=0.18,
                                     g_max=6e-3)
    self.HTC2IN_nmda = synapses.AMPA(self.HTC, self.IN, bp.conn.FixedProb(0.3),
                                     delay_step=int(2 / bm.get_dt()),
                                     stp=synplast.STD(tau=700, U=0.07),
                                     output=MgBlock(),
                                     alpha=1.,
                                     beta=0.0067,
                                     g_max=3e-3)

    # INs delivered feedforward inhibition to RTC cells
    self.IN2RTC = synapses.GABAa(self.IN, self.RTC, bp.conn.FixedProb(0.3),
                                 delay_step=int(2 / bm.get_dt()),
                                 stp=synplast.STD(tau=700, U=0.07),
                                 output=synouts.COBA(E=-80),
                                 alpha=10.5,
                                 beta=0.166,
                                 g_max=3e-3)

    # 20% of RTC cells are electrically connected with HTC cells
    self.gj_RTC2HTC = synapses.GapJunction(self.RTC, self.HTC,
                                           bp.conn.ProbDist(dist=2., prob=0.3, pre_ratio=0.2),
                                           comp_method='sparse',
                                           g_max=1 / 300)

    # Both HTC and RTC cells send glutamatergic synapses to RE neurons, while
    # receiving GABAergic feedback inhibition from the RE population
    self.HTC2RE_ampa = synapses.AMPA(self.HTC, self.RE, bp.conn.FixedProb(0.2),
                                     delay_step=int(2 / bm.get_dt()),
                                     stp=synplast.STD(tau=700, U=0.07),
                                     alpha=0.94,
                                     beta=0.18,
                                     g_max=4e-3)
    self.RTC2RE_ampa = synapses.AMPA(self.RTC, self.RE, bp.conn.FixedProb(0.2),
                                     delay_step=int(2 / bm.get_dt()),
                                     stp=synplast.STD(tau=700, U=0.07),
                                     alpha=0.94,
                                     beta=0.18,
                                     g_max=4e-3)
    self.HTC2RE_nmda = synapses.AMPA(self.HTC, self.RE, bp.conn.FixedProb(0.2),
                                     delay_step=int(2 / bm.get_dt()),
                                     stp=synplast.STD(tau=700, U=0.07),
                                     output=MgBlock(),
                                     alpha=1.,
                                     beta=0.0067,
                                     g_max=2e-3)
    self.RTC2RE_nmda = synapses.AMPA(self.RTC, self.RE, bp.conn.FixedProb(0.2),
                                     delay_step=int(2 / bm.get_dt()),
                                     stp=synplast.STD(tau=700, U=0.07),
                                     output=MgBlock(),
                                     alpha=1.,
                                     beta=0.0067,
                                     g_max=2e-3)
    self.RE2HTC = synapses.GABAa(self.RE, self.HTC, bp.conn.FixedProb(0.2),
                                 delay_step=int(2 / bm.get_dt()),
                                 stp=synplast.STD(tau=700, U=0.07),
                                 output=synouts.COBA(E=-80),
                                 alpha=10.5,
                                 beta=0.166,
                                 g_max=3e-3)
    self.RE2RTC = synapses.GABAa(self.RE, self.RTC, bp.conn.FixedProb(0.2),
                                 delay_step=int(2 / bm.get_dt()),
                                 stp=synplast.STD(tau=700, U=0.07),
                                 output=synouts.COBA(E=-80),
                                 alpha=10.5,
                                 beta=0.166,
                                 g_max=3e-3)

    # RE neurons were connected with both gap junctions and GABAergic synapses
    self.gj_RE = synapses.GapJunction(self.RE, self.RE,
                                      bp.conn.ProbDist(dist=2., prob=0.3, pre_ratio=0.2),
                                      comp_method='sparse',
                                      g_max=1 / 300)
    self.RE2RE = synapses.GABAa(self.RE, self.RE, bp.conn.FixedProb(0.2),
                                delay_step=int(2 / bm.get_dt()),
                                stp=synplast.STD(tau=700, U=0.07),
                                output=synouts.COBA(E=-70),
                                alpha=10.5, beta=0.166,
                                g_max=1e-3)

    # 10% of RE neurons project GABAergic synapses to local interneurons;
    # a low connection probability (0.05) was used for the RE->IN synapses, according to experimental data
    self.RE2IN = synapses.GABAa(self.RE, self.IN, bp.conn.FixedProb(0.05, pre_ratio=0.1),
                                delay_step=int(2 / bm.get_dt()),
                                stp=synplast.STD(tau=700, U=0.07),
                                output=synouts.COBA(E=-80),
                                alpha=10.5, beta=0.166,
                                g_max=1e-3, )

Simulation

[8]:

states = {
  'delta': dict(g_input={'IN': 1e-4, 'RE': 1e-4, 'TC': 1e-4},
                g_KL={'TC': 0.035, 'RE': 0.03, 'IN': 0.01}),
  'spindle': dict(g_input={'IN': 3e-4, 'RE': 3e-4, 'TC': 3e-4},
                  g_KL={'TC': 0.01, 'RE': 0.02, 'IN': 0.015}),
  'alpha': dict(g_input={'IN': 1.5e-3, 'RE': 1.5e-3, 'TC': 1.5e-3},
                g_KL={'TC': 0., 'RE': 0.01, 'IN': 0.02}),
  'gamma': dict(g_input={'IN': 1.5e-3, 'RE': 1.5e-3, 'TC': 1.7e-2},
                g_KL={'TC': 0., 'RE': 0.01, 'IN': 0.02}),
}
[9]:

def rhythm_const_input(amp, freq, length, duration, t_start=0., t_end=None, dt=None):
  if t_end is None:
    t_end = duration
  if length > duration:
    raise ValueError(f'Expected length <= duration, while we got {length} > {duration}')
  sec_length = 1e3 / freq
  values, durations = [0.], [t_start]
  for t in np.arange(t_start, t_end, sec_length):
    values.append(amp)
    if t + length <= t_end:
      durations.append(length)
      values.append(0.)
      if t + sec_length <= t_end:
        durations.append(sec_length - length)
      else:
        durations.append(t_end - t - length)
    else:
      durations.append(t_end - t)
  values.append(0.)
  durations.append(duration - t_end)
  return bp.inputs.section_input(values=values, durations=durations, dt=dt)
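
A minimal usage sketch (the same arguments are used by try_network below): a 4 Hz train of 10-ms current pulses of amplitude 2e-4, delivered between 1 s and 2 s of a 3-s-long input.

pulses = rhythm_const_input(2e-4, freq=4., length=10., duration=3e3,
                            t_start=1e3, t_end=2e3)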

[10]:

def try_trn_neuron():
  trn = TRN(1)
  I, length = bp.inputs.section_input(values=[0, -0.05, 0],
                                      durations=[100, 100, 500],
                                      return_length=True,
                                      dt=0.01)
  runner = bp.DSRunner(trn, monitors=['V'], inputs=['input', I, 'iter'], dt=0.01)
  runner.run(length)
  bp.visualize.line_plot(runner.mon.ts, runner.mon.V, show=True)

[11]:
try_trn_neuron()
_images/oscillation_synchronization_Li_2017_unified_thalamus_oscillation_model_18_2.png
[12]:

def try_network(state='delta'):
  duration = 3e3
  net = Thalamus(
    IN_V_init=bp.init.OneInit(-70.),
    RE_V_init=bp.init.OneInit(-70.),
    HTC_V_init=bp.init.OneInit(-80.),
    RTC_V_init=bp.init.OneInit(-80.),
    **states[state],
  )
  net.reset()

  currents = rhythm_const_input(2e-4, freq=4., length=10., duration=duration,
                                t_end=2e3, t_start=1e3)
  plt.plot(currents)
  plt.title('Current')
  plt.show()

  runner = bp.DSRunner(
    net,
    monitors=['HTC.spike', 'RTC.spike', 'RE.spike', 'IN.spike',
              'HTC.V', 'RTC.V', 'RE.V', 'IN.V'],
    inputs=[('HTC.input', currents, 'iter'),
            ('RTC.input', currents, 'iter'),
            ('IN.input', currents, 'iter')],
  )
  runner.run(duration)

  fig, gs = bp.visualize.get_figure(4, 2, 2, 5)
  fig.add_subplot(gs[0, 0])
  bp.visualize.line_plot(runner.mon.ts, runner.mon.get('HTC.V'), ylabel='HTC', xlim=(0, duration))
  fig.add_subplot(gs[1, 0])
  bp.visualize.line_plot(runner.mon.ts, runner.mon.get('RTC.V'), ylabel='RTC', xlim=(0, duration))
  fig.add_subplot(gs[2, 0])
  bp.visualize.line_plot(runner.mon.ts, runner.mon.get('IN.V'), ylabel='IN', xlim=(0, duration))
  fig.add_subplot(gs[3, 0])
  bp.visualize.line_plot(runner.mon.ts, runner.mon.get('RE.V'), ylabel='RE', xlim=(0, duration))
  fig.add_subplot(gs[0, 1])
  bp.visualize.raster_plot(runner.mon.ts, runner.mon.get('HTC.spike'), xlim=(0, duration))
  fig.add_subplot(gs[1, 1])
  bp.visualize.raster_plot(runner.mon.ts, runner.mon.get('RTC.spike'), xlim=(0, duration))
  fig.add_subplot(gs[2, 1])
  bp.visualize.raster_plot(runner.mon.ts, runner.mon.get('IN.spike'), xlim=(0, duration))
  fig.add_subplot(gs[3, 1])
  bp.visualize.raster_plot(runner.mon.ts, runner.mon.get('RE.spike'), xlim=(0, duration))
  plt.show()
[13]:
try_network()
_images/oscillation_synchronization_Li_2017_unified_thalamus_oscillation_model_20_0.png
_images/oscillation_synchronization_Li_2017_unified_thalamus_oscillation_model_20_2.png

(Susin & Destexhe, 2021): Asynchronous Network

Implementation of the paper:

  • Susin, Eduarda, and Alain Destexhe. “Integration, coincidence detection and resonance in networks of spiking neurons expressing gamma oscillations and asynchronous states.” PLoS computational biology 17.9 (2021): e1009416.

[1]:
import numpy as np
import matplotlib.pyplot as plt
from scipy.signal import kaiserord, lfilter, firwin, hilbert

import brainpy as bp
import brainpy.math as bm
[2]:
# Table 1: specific neuron model parameters
RS_par = dict(Vth=-40, delta=2., tau_ref=5., tau_w=500, a=4, b=20, C=150, gL=10, EL=-65, V_reset=-65,
              E_e=0., E_i=-80.)
FS_par = dict(Vth=-47.5, delta=0.5, tau_ref=5., tau_w=500, a=0, b=0, C=150, gL=10, EL=-65, V_reset=-65,
              E_e=0., E_i=-80.)
Ch_par = dict(Vth=-47.5, delta=0.5, tau_ref=1., tau_w=50, a=80, b=150, C=150, gL=10, EL=-58, V_reset=-65,
              E_e=0., E_i=-80.)
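
For reference, the adaptive exponential integrate-and-fire (AdEx) dynamics implemented by the AdEx class below (read off its dV and dw methods) are:

\[\begin{split}\begin{split} C \frac{dV}{dt} &= g_L (E_L - V) + g_L \Delta \exp\left(\frac{V - V_{th}}{\Delta}\right) - w + g_e (E_e - V) + g_i (E_i - V) + I_{\mathrm{ext}} \\ \tau_w \frac{dw}{dt} &= a (V - E_L) - w \end{split}\end{split}\]

When \(V\) crosses the spike threshold V_sp_th, \(V\) is reset to V_reset and \(w\) is incremented by \(b\); the synaptic conductances \(g_e\) and \(g_i\) decay exponentially with time constants \(\tau_e\) and \(\tau_i\).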
[3]:
class AdEx(bp.NeuGroup):
  def __init__(
      self,
      size,

      # neuronal parameters
      Vth=-40, delta=2., tau_ref=5., tau_w=500, a=4, b=20, C=150,
      gL=10, EL=-65, V_reset=-65, V_sp_th=-30.,

      # synaptic parameters
      tau_e=1.5, tau_i=7.5, E_e=0., E_i=-80.,

      # other parameters
      name=None, method='exp_euler',
      V_initializer=bp.init.Uniform(-65, -50),
      w_initializer=bp.init.Constant(0.),
  ):
    super(AdEx, self).__init__(size=size, name=name)

    # neuronal parameters
    self.Vth = Vth
    self.delta = delta
    self.tau_ref = tau_ref
    self.tau_w = tau_w
    self.a = a
    self.b = b
    self.C = C
    self.gL = gL
    self.EL = EL
    self.V_reset = V_reset
    self.V_sp_th = V_sp_th

    # synaptic parameters
    self.tau_e = tau_e
    self.tau_i = tau_i
    self.E_e = E_e
    self.E_i = E_i

    # neuronal variables
    self.V = bp.init.variable_(V_initializer, self.num)
    self.w = bp.init.variable_(w_initializer, self.num)
    self.spike = bm.Variable(self.num, dtype=bool)
    self.refractory = bm.Variable(self.num, dtype=bool)
    self.t_last_spike = bm.Variable(bm.ones(self.num) * -1e8)

    # synaptic parameters
    self.ge = bm.Variable(self.num)
    self.gi = bm.Variable(self.num)

    # integral
    self.integral = bp.odeint(bp.JointEq(self.dV, self.dw, self.dge, self.dgi), method=method)

  def dge(self, ge, t):
    return -ge / self.tau_e

  def dgi(self, gi, t):
    return -gi / self.tau_i

  def dV(self, V, t, w, ge, gi, Iext=None):
    I = ge * (self.E_e - V) + gi * (self.E_i - V)
    if Iext is not None: I += Iext
    dVdt = (self.gL * self.delta * bm.exp((V - self.Vth) / self.delta)
            - w + self.gL * (self.EL - V) + I) / self.C
    return dVdt

  def dw(self, w, t, V):
    dwdt = (self.a * (V - self.EL) - w) / self.tau_w
    return dwdt

  def update(self, tdi, x=None):
    V, w, ge, gi = self.integral(self.V.value, self.w.value, self.ge.value, self.gi.value,
                                 tdi.t, Iext=x, dt=tdi.dt)
    refractory = (tdi.t - self.t_last_spike) <= self.tau_ref
    V = bm.where(refractory, self.V.value, V)
    spike = V >= self.V_sp_th
    self.V.value = bm.where(spike, self.V_reset, V)
    self.w.value = bm.where(spike, w + self.b, w)
    self.ge.value = ge
    self.gi.value = gi
    self.spike.value = spike
    self.refractory.value = bm.logical_or(refractory, spike)
    self.t_last_spike.value = bm.where(spike, tdi.t, self.t_last_spike)
[4]:
class AINet(bp.Network):
  def __init__(self, ext_varied_rates, ext_weight=1., method='exp_euler', dt=bm.get_dt()):
    super(AINet, self).__init__()

    self.num_exc = 20000
    self.num_inh = 5000
    self.exc_syn_tau = 5.  # ms
    self.inh_syn_tau = 5.  # ms
    self.exc_syn_weight = 1.  # nS
    self.inh_syn_weight = 5.  # nS
    self.num_delay_step = int(1.5 / dt)
    self.ext_varied_rates = ext_varied_rates

    # neuronal populations
    RS_par_ = RS_par.copy()
    FS_par_ = FS_par.copy()
    RS_par_.update(Vth=-50, V_sp_th=-40)
    FS_par_.update(Vth=-50, V_sp_th=-40)
    self.rs_pop = AdEx(self.num_exc, tau_e=self.exc_syn_tau, tau_i=self.inh_syn_tau, method=method, **RS_par_)
    self.fs_pop = AdEx(self.num_inh, tau_e=self.exc_syn_tau, tau_i=self.inh_syn_tau, method=method, **FS_par_)
    self.ext_pop = bp.neurons.PoissonGroup(self.num_exc, freqs=bm.Variable(1))

    # Poisson inputs
    self.ext_to_FS = bp.synapses.Delta(self.ext_pop, self.fs_pop, bp.conn.FixedProb(0.02),
                                       output=bp.synouts.CUBA(target_var='ge'),
                                       g_max=ext_weight)
    self.ext_to_RS = bp.synapses.Delta(self.ext_pop, self.rs_pop, bp.conn.FixedProb(0.02),
                                       output=bp.synouts.CUBA(target_var='ge'),
                                       g_max=ext_weight)

    # synaptic projections
    self.RS_to_FS = bp.synapses.Delta(self.rs_pop, self.fs_pop, bp.conn.FixedProb(0.02),
                                      output=bp.synouts.CUBA(target_var='ge'),
                                      g_max=self.exc_syn_weight,
                                      delay_step=self.num_delay_step)
    self.RS_to_RS = bp.synapses.Delta(self.rs_pop, self.rs_pop, bp.conn.FixedProb(0.02),
                                      output=bp.synouts.CUBA(target_var='ge'),
                                      g_max=self.exc_syn_weight,
                                      delay_step=self.num_delay_step)
    self.FS_to_RS = bp.synapses.Delta(self.fs_pop, self.rs_pop, bp.conn.FixedProb(0.02),
                                      output=bp.synouts.CUBA(target_var='gi'),
                                      g_max=self.inh_syn_weight,
                                      delay_step=self.num_delay_step)
    self.FS_to_FS = bp.synapses.Delta(self.fs_pop, self.fs_pop, bp.conn.FixedProb(0.02),
                                      output=bp.synouts.CUBA(target_var='gi'),
                                      g_max=self.inh_syn_weight,
                                      delay_step=self.num_delay_step)

  def change_freq(self, tdi):
    self.ext_pop.freqs[0] = self.ext_varied_rates[tdi.i]

[5]:
def get_inputs(c_low, c_high, t_transition, t_min_plato, t_max_plato, t_gap, t_total, dt=None):
  dt = bm.get_dt() if dt is None else dt
  t = 0
  num_gap = int(t_gap / dt)
  num_total = int(t_total / dt)
  num_transition = int(t_transition / dt)

  inputs = []
  ramp_up = np.linspace(c_low, c_high, num_transition)
  ramp_down = np.linspace(c_high, c_low, num_transition)
  plato_base = np.ones(num_gap) * c_low
  while t < num_total:
    num_plato = int(np.random.uniform(low=t_min_plato, high=t_max_plato, size=1) / dt)
    inputs.extend([plato_base, ramp_up, np.ones(num_plato) * c_high, ramp_down])
    t += (num_gap + num_transition + num_plato + num_transition)
  return bm.asarray(np.concatenate(inputs)[:num_total])

[6]:
def signal_phase_by_Hilbert(signal, signal_time, low_cut, high_cut, sampling_space):
  # sampling_space: in seconds (no units)
  # signal_time: in seconds (no units)
  # low_cut: in Hz (no units)(band to filter)
  # high_cut: in Hz (no units)(band to filter)

  signal = signal - np.mean(signal)
  width = 5.0  # The desired width in Hz of the transition from pass to stop
  ripple_db = 60.0  # The desired attenuation in the stop band, in dB.
  sampling_rate = 1. / sampling_space
  Nyquist = sampling_rate / 2.

  num_taps, beta = kaiserord(ripple_db, width / Nyquist)
  if num_taps % 2 == 0:
    num_taps = num_taps + 1  # Numtaps must be odd
  taps = firwin(num_taps, [low_cut / Nyquist, high_cut / Nyquist], window=('kaiser', beta), nyq=1.0,
                pass_zero=False, scale=True)
  filtered_signal = lfilter(taps, 1.0, signal)
  delay = 0.5 * (num_taps - 1) / sampling_rate  # to correct to zero phase
  delay_index = int(np.floor(delay * sampling_rate))
  filtered_signal = filtered_signal[num_taps - 1:]  # taking out the "corrupted" signal
  # correcting the delay and taking out the "corrupted" signal part
  filtered_time = signal_time[num_taps - 1:] - delay
  cutted_signal = signal[(num_taps - 1 - delay_index): (len(signal) - (num_taps - 1 - delay_index))]

  # --------------------------------------------------------------------------
  # The Hilbert transform is very slow when the signal has an odd length.
  # This part checks whether the length is odd and, if so, appends a zero to the end
  # of all the vectors related to the filtered signal:
  if len(filtered_signal) % 2 != 0:  # if the length is odd
    tmp1 = filtered_signal.tolist()
    tmp1.append(0)
    tmp2 = filtered_time.tolist()
    tmp2.append((len(filtered_time) + 1) * sampling_space + filtered_time[0])
    tmp3 = cutted_signal.tolist()
    tmp3.append(0)
    filtered_signal = np.asarray(tmp1)
    filtered_time = np.asarray(tmp2)
    cutted_signal = np.asarray(tmp3)
  # --------------------------------------------------------------------------

  ht_filtered_signal = hilbert(filtered_signal)
  envelope = np.abs(ht_filtered_signal)
  phase = np.angle(ht_filtered_signal)  # The phase is between -pi and pi in radians

  return filtered_time, filtered_signal, cutted_signal, envelope, phase
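
A minimal usage sketch on synthetic data (not part of the original analysis): band-pass a noisy 40 Hz oscillation between 30 and 50 Hz and recover its envelope and phase, with the signal sampled at the default simulation step of 0.1 ms (1e-4 s):

ts = np.arange(0., 2., 1e-4)                     # 2 s of signal sampled every 0.1 ms
sig = np.sin(2 * np.pi * 40. * ts) + 0.1 * np.random.randn(ts.size)
f_ts, filt, raw, envelope, phase = signal_phase_by_Hilbert(sig, ts, 30., 50., 1e-4)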

[7]:
def visualize_simulation_results(times, spikes, example_potentials, varied_rates,
                                 xlim=None, t_lfp_start=None, t_lfp_end=None, filename=None):
  fig, gs = bp.visualize.get_figure(7, 1, 1, 12)
  # 1. input firing rate
  ax = fig.add_subplot(gs[0])
  plt.plot(times, varied_rates)
  if xlim is None:
    xlim = (0, times[-1])
  ax.set_xlim(*xlim)
  ax.set_xticks([])
  ax.set_ylabel('External\nRate (Hz)')

  # 2. spike raster plot of each population
  ax = fig.add_subplot(gs[1: 3])
  i = 0
  y_ticks = ([], [])
  for key, (sp_matrix, sp_type) in spikes.items():
    iis, sps = np.where(sp_matrix)
    tts = times[iis]
    plt.plot(tts, sps + i, '.', markersize=1, label=key)
    y_ticks[0].append(i + sp_matrix.shape[1] / 2)
    y_ticks[1].append(key)
    i += sp_matrix.shape[1]
  ax.set_xlim(*xlim)
  ax.set_xlabel('')
  ax.set_ylabel('Neuron Index')
  ax.set_xticks([])
  ax.set_yticks(*y_ticks)
  # ax.legend()

  # 3. example membrane potential
  ax = fig.add_subplot(gs[3: 5])
  for key, potential in example_potentials.items():
    vs = np.where(spikes[key][0][:, 0], 0, potential)
    plt.plot(times, vs, label=key)
  ax.set_xlim(*xlim)
  ax.set_xticks([])
  ax.set_ylabel('V (mV)')
  ax.legend()

  # 4. LFP
  ax = fig.add_subplot(gs[5:7])
  ax.set_xlim(*xlim)
  t1 = int(t_lfp_start / bm.get_dt()) if t_lfp_start is not None else 0
  t2 = int(t_lfp_end / bm.get_dt()) if t_lfp_end is not None else len(times)
  times = times[t1: t2]
  lfp = 0
  for sp_matrix, sp_type in spikes.values():
    lfp += bp.measure.unitary_LFP(times, sp_matrix[t1: t2], sp_type)
  phase_ts, filtered, cutted, envelope, _ = signal_phase_by_Hilbert(bm.as_numpy(lfp), times * 1e-3, 30, 50,
                                                                    bm.get_dt() * 1e-3)
  plt.plot(phase_ts * 1e3, cutted, color='k', label='Raw LFP')
  plt.plot(phase_ts * 1e3, filtered, color='orange', label="Filtered LFP (30-50 Hz)")
  plt.plot(phase_ts * 1e3, envelope, color='purple', label="Hilbert Envelope")
  plt.legend(loc='best')
  plt.xlabel('Time (ms)')

  # save or show
  if filename:
    plt.savefig(filename, dpi=500)
  plt.show()

[8]:
def simulate_ai_net():
  duration = 2e3
  varied_rates = get_inputs(2., 2., 50., 150, 600, 1e3, duration)

  net = AINet(varied_rates, ext_weight=1.)
  runner = bp.DSRunner(
    net,
    inputs=net.change_freq,
    monitors={'FS.V0': lambda tdi: net.fs_pop.V[0],
              'RS.V0': lambda tdi: net.rs_pop.V[0],
              'FS.spike': lambda tdi: net.fs_pop.spike,
              'RS.spike': lambda tdi: net.rs_pop.spike}
  )
  runner.run(duration)

  visualize_simulation_results(times=runner.mon.ts,
                               spikes={'FS': (runner.mon['FS.spike'], 'inh'),
                                       'RS': (runner.mon['RS.spike'], 'exc')},
                               example_potentials={'FS': runner.mon['FS.V0'],
                                                   'RS': runner.mon['RS.V0']},
                               varied_rates=varied_rates.to_numpy())

[9]:
simulate_ai_net()
_images/oscillation_synchronization_Susin_Destexhe_2021_gamma_oscillation_AI_10_2.png

(Susin & Destexhe, 2021): CHING Network for Generating Gamma Oscillation

Implementation of the paper:

  • Susin, Eduarda, and Alain Destexhe. “Integration, coincidence detection and resonance in networks of spiking neurons expressing gamma oscillations and asynchronous states.” PLoS computational biology 17.9 (2021): e1009416.

[1]:
import numpy as np
import matplotlib.pyplot as plt
from scipy.signal import kaiserord, lfilter, firwin, hilbert

import brainpy as bp
import brainpy.math as bm
[2]:
# Table 1: specific neuron model parameters
RS_par = dict(Vth=-40, delta=2., tau_ref=5., tau_w=500, a=4, b=20, C=150, gL=10, EL=-65, V_reset=-65,
              E_e=0., E_i=-80.)
FS_par = dict(Vth=-47.5, delta=0.5, tau_ref=5., tau_w=500, a=0, b=0, C=150, gL=10, EL=-65, V_reset=-65,
              E_e=0., E_i=-80.)
Ch_par = dict(Vth=-47.5, delta=0.5, tau_ref=1., tau_w=50, a=80, b=150, C=150, gL=10, EL=-58, V_reset=-65,
              E_e=0., E_i=-80.)
[3]:
class AdEx(bp.NeuGroup):
  def __init__(
      self,
      size,

      # neuronal parameters
      Vth=-40, delta=2., tau_ref=5., tau_w=500, a=4, b=20, C=150,
      gL=10, EL=-65, V_reset=-65, V_sp_th=-30.,

      # synaptic parameters
      tau_e=1.5, tau_i=7.5, E_e=0., E_i=-80.,

      # other parameters
      name=None, method='exp_euler',
      V_initializer=bp.init.Uniform(-65, -50),
      w_initializer=bp.init.Constant(0.),
  ):
    super(AdEx, self).__init__(size=size, name=name)

    # neuronal parameters
    self.Vth = Vth
    self.delta = delta
    self.tau_ref = tau_ref
    self.tau_w = tau_w
    self.a = a
    self.b = b
    self.C = C
    self.gL = gL
    self.EL = EL
    self.V_reset = V_reset
    self.V_sp_th = V_sp_th

    # synaptic parameters
    self.tau_e = tau_e
    self.tau_i = tau_i
    self.E_e = E_e
    self.E_i = E_i

    # neuronal variables
    self.V = bp.init.variable_(V_initializer, self.num)
    self.w = bp.init.variable_(w_initializer, self.num)
    self.spike = bm.Variable(self.num, dtype=bool)
    self.refractory = bm.Variable(self.num, dtype=bool)
    self.t_last_spike = bm.Variable(bm.ones(self.num) * -1e8)

    # synaptic parameters
    self.ge = bm.Variable(self.num)
    self.gi = bm.Variable(self.num)

    # integral
    self.integral = bp.odeint(bp.JointEq(self.dV, self.dw, self.dge, self.dgi), method=method)

  def dge(self, ge, t):
    return -ge / self.tau_e

  def dgi(self, gi, t):
    return -gi / self.tau_i

  def dV(self, V, t, w, ge, gi, Iext=None):
    I = ge * (self.E_e - V) + gi * (self.E_i - V)
    if Iext is not None: I += Iext
    dVdt = (self.gL * self.delta * bm.exp((V - self.Vth) / self.delta)
            - w + self.gL * (self.EL - V) + I) / self.C
    return dVdt

  def dw(self, w, t, V):
    dwdt = (self.a * (V - self.EL) - w) / self.tau_w
    return dwdt

  def update(self, tdi, x=None):
    V, w, ge, gi = self.integral(self.V.value, self.w.value, self.ge.value, self.gi.value,
                                 tdi.t, Iext=x, dt=tdi.dt)
    refractory = (tdi.t - self.t_last_spike) <= self.tau_ref
    V = bm.where(refractory, self.V.value, V)
    spike = V >= self.V_sp_th
    self.V.value = bm.where(spike, self.V_reset, V)
    self.w.value = bm.where(spike, w + self.b, w)
    self.ge.value = ge
    self.gi.value = gi
    self.spike.value = spike
    self.refractory.value = bm.logical_or(refractory, spike)
    self.t_last_spike.value = bm.where(spike, tdi.t, self.t_last_spike)
[4]:
class CHINGNet(bp.Network):
  def __init__(self, ext_varied_rates, method='exp_euler', dt=bm.get_dt()):
    super(CHINGNet, self).__init__()

    self.num_rs = 19000
    self.num_fs = 5000
    self.num_ch = 1000
    self.exc_syn_tau = 5.  # ms
    self.inh_syn_tau = 5.  # ms
    self.exc_syn_weight = 1.  # nS
    self.inh_syn_weight1 = 7.  # nS
    self.inh_syn_weight2 = 5.  # nS
    self.ext_weight1 = 1.  # nS
    self.ext_weight2 = 0.75  # nS
    self.num_delay_step = int(1.5 / dt)
    self.ext_varied_rates = ext_varied_rates

    # neuronal populations
    RS_par_ = RS_par.copy()
    FS_par_ = FS_par.copy()
    Ch_par_ = Ch_par.copy()
    RS_par_.update(Vth=-50, V_sp_th=-40)
    FS_par_.update(Vth=-50, V_sp_th=-40)
    Ch_par_.update(Vth=-50, V_sp_th=-40)
    self.rs_pop = AdEx(self.num_rs, tau_e=self.exc_syn_tau, tau_i=self.inh_syn_tau, method=method, **RS_par_)
    self.fs_pop = AdEx(self.num_fs, tau_e=self.exc_syn_tau, tau_i=self.inh_syn_tau, method=method, **FS_par_)
    self.ch_pop = AdEx(self.num_ch, tau_e=self.exc_syn_tau, tau_i=self.inh_syn_tau, method=method, **Ch_par_)
    self.ext_pop = bp.neurons.PoissonGroup(self.num_rs, freqs=bm.Variable(1))

    # Poisson inputs
    self.ext_to_FS = bp.synapses.Delta(self.ext_pop, self.fs_pop, bp.conn.FixedProb(0.02),
                                       output=bp.synouts.CUBA(target_var='ge'),
                                       g_max=self.ext_weight2)
    self.ext_to_RS = bp.synapses.Delta(self.ext_pop, self.rs_pop, bp.conn.FixedProb(0.02),
                                       output=bp.synouts.CUBA(target_var='ge'),
                                       g_max=self.ext_weight1)
    self.ext_to_CH = bp.synapses.Delta(self.ext_pop, self.ch_pop, bp.conn.FixedProb(0.02),
                                       output=bp.synouts.CUBA(target_var='ge'),
                                       g_max=self.ext_weight1)

    # synaptic projections
    self.RS_to_FS = bp.synapses.Delta(self.rs_pop, self.fs_pop, bp.conn.FixedProb(0.02),
                                      output=bp.synouts.CUBA(target_var='ge'),
                                      g_max=self.exc_syn_weight,
                                      delay_step=self.num_delay_step)
    self.RS_to_RS = bp.synapses.Delta(self.rs_pop, self.rs_pop, bp.conn.FixedProb(0.02),
                                      output=bp.synouts.CUBA(target_var='ge'),
                                      g_max=self.exc_syn_weight,
                                      delay_step=self.num_delay_step)
    self.RS_to_Ch = bp.synapses.Delta(self.rs_pop, self.ch_pop, bp.conn.FixedProb(0.02),
                                      output=bp.synouts.CUBA(target_var='ge'),
                                      g_max=self.exc_syn_weight,
                                      delay_step=self.num_delay_step)

    self.FS_to_RS = bp.synapses.Delta(self.fs_pop, self.rs_pop, bp.conn.FixedProb(0.02),
                                      output=bp.synouts.CUBA(target_var='gi'),
                                      g_max=self.inh_syn_weight1,
                                      delay_step=self.num_delay_step)
    self.FS_to_FS = bp.synapses.Delta(self.fs_pop, self.fs_pop, bp.conn.FixedProb(0.02),
                                      output=bp.synouts.CUBA(target_var='gi'),
                                      g_max=self.inh_syn_weight2,
                                      delay_step=self.num_delay_step)
    self.FS_to_Ch = bp.synapses.Delta(self.fs_pop, self.ch_pop, bp.conn.FixedProb(0.02),
                                      output=bp.synouts.CUBA(target_var='gi'),
                                      g_max=self.inh_syn_weight1,
                                      delay_step=self.num_delay_step)

    self.Ch_to_RS = bp.synapses.Delta(self.ch_pop, self.rs_pop, bp.conn.FixedProb(0.02),
                                      output=bp.synouts.CUBA(target_var='ge'),
                                      g_max=self.exc_syn_weight,
                                      delay_step=self.num_delay_step)
    self.Ch_to_FS = bp.synapses.Delta(self.ch_pop, self.fs_pop, bp.conn.FixedProb(0.02),
                                      output=bp.synouts.CUBA(target_var='ge'),
                                      g_max=self.exc_syn_weight,
                                      delay_step=self.num_delay_step)
    self.Ch_to_Ch = bp.synapses.Delta(self.ch_pop, self.ch_pop, bp.conn.FixedProb(0.02),
                                      output=bp.synouts.CUBA(target_var='ge'),
                                      g_max=self.exc_syn_weight,
                                      delay_step=self.num_delay_step)

  def change_freq(self, tdi):
    self.ext_pop.freqs[0] = self.ext_varied_rates[tdi.i]

[5]:
def get_inputs(c_low, c_high, t_transition, t_min_plato, t_max_plato, t_gap, t_total, dt=None):
  dt = bm.get_dt() if dt is None else dt
  t = 0
  num_gap = int(t_gap / dt)
  num_total = int(t_total / dt)
  num_transition = int(t_transition / dt)

  inputs = []
  ramp_up = np.linspace(c_low, c_high, num_transition)
  ramp_down = np.linspace(c_high, c_low, num_transition)
  plato_base = np.ones(num_gap) * c_low
  while t < num_total:
    num_plato = int(np.random.uniform(low=t_min_plato, high=t_max_plato, size=1) / dt)
    inputs.extend([plato_base, ramp_up, np.ones(num_plato) * c_high, ramp_down])
    t += (num_gap + num_transition + num_plato + num_transition)
  return bm.asarray(np.concatenate(inputs)[:num_total])

[6]:
def signal_phase_by_Hilbert(signal, signal_time, low_cut, high_cut, sampling_space):
  # sampling_space: in seconds (no units)
  # signal_time: in seconds (no units)
  # low_cut: in Hz (no units)(band to filter)
  # high_cut: in Hz (no units)(band to filter)

  signal = signal - np.mean(signal)
  width = 5.0  # The desired width in Hz of the transition from pass to stop
  ripple_db = 60.0  # The desired attenuation in the stop band, in dB.
  sampling_rate = 1. / sampling_space
  Nyquist = sampling_rate / 2.

  num_taps, beta = kaiserord(ripple_db, width / Nyquist)
  if num_taps % 2 == 0:
    num_taps = num_taps + 1  # Numtaps must be odd
  taps = firwin(num_taps, [low_cut / Nyquist, high_cut / Nyquist], window=('kaiser', beta), nyq=1.0,
                pass_zero=False, scale=True)
  filtered_signal = lfilter(taps, 1.0, signal)
  delay = 0.5 * (num_taps - 1) / sampling_rate  # to correct to zero phase
  delay_index = int(np.floor(delay * sampling_rate))
  filtered_signal = filtered_signal[num_taps - 1:]  # taking out the "corrupted" signal
  # correcting the delay and taking out the "corrupted" signal part
  filtered_time = signal_time[num_taps - 1:] - delay
  cutted_signal = signal[(num_taps - 1 - delay_index): (len(signal) - (num_taps - 1 - delay_index))]

  # --------------------------------------------------------------------------
  # The Hilbert transform is very slow when the signal has an odd length.
  # If the length is odd, append a zero to the end of all vectors related
  # to the filtered signal:
  if len(filtered_signal) % 2 != 0:  # if the length is odd
    tmp1 = filtered_signal.tolist()
    tmp1.append(0)
    tmp2 = filtered_time.tolist()
    tmp2.append((len(filtered_time) + 1) * sampling_space + filtered_time[0])
    tmp3 = cutted_signal.tolist()
    tmp3.append(0)
    filtered_signal = np.asarray(tmp1)
    filtered_time = np.asarray(tmp2)
    cutted_signal = np.asarray(tmp3)
  # --------------------------------------------------------------------------

  ht_filtered_signal = hilbert(filtered_signal)
  envelope = np.abs(ht_filtered_signal)
  phase = np.angle(ht_filtered_signal)  # The phase is between -pi and pi in radians

  return filtered_time, filtered_signal, cutted_signal, envelope, phase
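For reference, here is a minimal sketch (not part of the original notebook) of how this helper behaves on a synthetic signal; the 40 Hz test tone and 1 kHz sampling rate below are illustrative only:

# Sketch only: a noisy 40 Hz tone sampled at 1 kHz, filtered in the 30-50 Hz band.
dt_s = 1e-3                                   # sampling interval in seconds
ts = np.arange(0., 2., dt_s)                  # 2 seconds of signal
sig = np.sin(2 * np.pi * 40. * ts) + 0.5 * np.random.randn(ts.size)
f_time, f_sig, raw, env, phase = signal_phase_by_Hilbert(sig, ts, 30., 50., dt_s)
plt.plot(f_time, raw, label='raw')
plt.plot(f_time, f_sig, label='filtered (30-50 Hz)')
plt.plot(f_time, env, label='Hilbert envelope')
plt.legend()
plt.show()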

[7]:
def visualize_simulation_results(times, spikes, example_potentials, varied_rates,
                                 xlim=None, t_lfp_start=None, t_lfp_end=None, filename=None):
  fig, gs = bp.visualize.get_figure(7, 1, 1, 12)
  # 1. input firing rate
  ax = fig.add_subplot(gs[0])
  plt.plot(times, varied_rates)
  if xlim is None:
    xlim = (0, times[-1])
  ax.set_xlim(*xlim)
  ax.set_xticks([])
  ax.set_ylabel('External\nRate (Hz)')

  # 2. spike raster plot of each population
  ax = fig.add_subplot(gs[1: 3])
  i = 0
  y_ticks = ([], [])
  for key, (sp_matrix, sp_type) in spikes.items():
    iis, sps = np.where(sp_matrix)
    tts = times[iis]
    plt.plot(tts, sps + i, '.', markersize=1, label=key)
    y_ticks[0].append(i + sp_matrix.shape[1] / 2)
    y_ticks[1].append(key)
    i += sp_matrix.shape[1]
  ax.set_xlim(*xlim)
  ax.set_xlabel('')
  ax.set_ylabel('Neuron Index')
  ax.set_xticks([])
  ax.set_yticks(*y_ticks)
  # ax.legend()

  # 3. example membrane potential
  ax = fig.add_subplot(gs[3: 5])
  for key, potential in example_potentials.items():
    vs = np.where(spikes[key][0][:, 0], 0, potential)
    plt.plot(times, vs, label=key)
  ax.set_xlim(*xlim)
  ax.set_xticks([])
  ax.set_ylabel('V (mV)')
  ax.legend()

  # 4. LFP
  ax = fig.add_subplot(gs[5:7])
  ax.set_xlim(*xlim)
  t1 = int(t_lfp_start / bm.get_dt()) if t_lfp_start is not None else 0
  t2 = int(t_lfp_end / bm.get_dt()) if t_lfp_end is not None else len(times)
  times = times[t1: t2]
  lfp = 0
  for sp_matrix, sp_type in spikes.values():
    lfp += bp.measure.unitary_LFP(times, sp_matrix[t1: t2], sp_type)
  phase_ts, filtered, cutted, envelope, _ = signal_phase_by_Hilbert(bm.as_numpy(lfp), times * 1e-3, 30, 50,
                                                                    bm.get_dt() * 1e-3)
  plt.plot(phase_ts * 1e3, cutted, color='k', label='Raw LFP')
  plt.plot(phase_ts * 1e3, filtered, color='orange', label="Filtered LFP (30-50 Hz)")
  plt.plot(phase_ts * 1e3, envelope, color='purple', label="Hilbert Envelope")
  plt.legend(loc='best')
  plt.xlabel('Time (ms)')

  # save or show
  if filename:
    plt.savefig(filename, dpi=500)
  plt.show()

[8]:
def simulate_ching_net():
  duration = 6e3
  varied_rates = get_inputs(1., 2., 50., 150, 600, 1e3, duration)

  net = CHINGNet(varied_rates)
  runner = bp.DSRunner(
    net,
    inputs=net.change_freq,
    monitors={'FS.V0': lambda tdi: net.fs_pop.V[0],
              'CH.V0': lambda tdi: net.ch_pop.V[0],
              'RS.V0': lambda tdi: net.rs_pop.V[0],
              'FS.spike': lambda tdi: net.fs_pop.spike,
              'CH.spike': lambda tdi: net.ch_pop.spike,
              'RS.spike': lambda tdi: net.rs_pop.spike}
  )
  runner.run(duration)

  visualize_simulation_results(times=runner.mon.ts,
                               spikes={'FS': (runner.mon['FS.spike'], 'inh'),
                                       'CH': (runner.mon['CH.spike'], 'exc'),
                                       'RS': (runner.mon['RS.spike'], 'exc')},
                               example_potentials={'FS': runner.mon['FS.V0'],
                                                   'CH': runner.mon['CH.V0'],
                                                   'RS': runner.mon['RS.V0']},
                               varied_rates=varied_rates.to_numpy(),
                               xlim=(2e3, 3.4e3), t_lfp_start=1e3, t_lfp_end=5e3)
[9]:
simulate_ching_net()
WARNING:jax._src.lib.xla_bridge:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
_images/oscillation_synchronization_Susin_Destexhe_2021_gamma_oscillation_CHING_10_2.png

(Susin & Destexhe, 2021): ING Network for Generating Gamma Oscillation

Implementation of the paper:

  • Susin, Eduarda, and Alain Destexhe. “Integration, coincidence detection and resonance in networks of spiking neurons expressing gamma oscillations and asynchronous states.” PLoS computational biology 17.9 (2021): e1009416.

[1]:
import numpy as np
import matplotlib.pyplot as plt
from scipy.signal import kaiserord, lfilter, firwin, hilbert

import brainpy as bp
import brainpy.math as bm
[2]:
# Table 1: specific neuron model parameters
RS_par = dict(Vth=-40, delta=2., tau_ref=5., tau_w=500, a=4, b=20, C=150, gL=10, EL=-65, V_reset=-65,
              E_e=0., E_i=-80.)
FS_par = dict(Vth=-47.5, delta=0.5, tau_ref=5., tau_w=500, a=0, b=0, C=150, gL=10, EL=-65, V_reset=-65,
              E_e=0., E_i=-80.)
Ch_par = dict(Vth=-47.5, delta=0.5, tau_ref=1., tau_w=50, a=80, b=150, C=150, gL=10, EL=-58, V_reset=-65,
              E_e=0., E_i=-80.)
[3]:
class AdEx(bp.NeuGroup):
  def __init__(
      self,
      size,

      # neuronal parameters
      Vth=-40, delta=2., tau_ref=5., tau_w=500, a=4, b=20, C=150,
      gL=10, EL=-65, V_reset=-65, V_sp_th=-30.,

      # synaptic parameters
      tau_e=1.5, tau_i=7.5, E_e=0., E_i=-80.,

      # other parameters
      name=None, method='exp_euler',
      V_initializer=bp.init.Uniform(-65, -50),
      w_initializer=bp.init.Constant(0.),
  ):
    super(AdEx, self).__init__(size=size, name=name)

    # neuronal parameters
    self.Vth = Vth
    self.delta = delta
    self.tau_ref = tau_ref
    self.tau_w = tau_w
    self.a = a
    self.b = b
    self.C = C
    self.gL = gL
    self.EL = EL
    self.V_reset = V_reset
    self.V_sp_th = V_sp_th

    # synaptic parameters
    self.tau_e = tau_e
    self.tau_i = tau_i
    self.E_e = E_e
    self.E_i = E_i

    # neuronal variables
    self.V = bp.init.variable_(V_initializer, self.num)
    self.w = bp.init.variable_(w_initializer, self.num)
    self.spike = bm.Variable(self.num, dtype=bool)
    self.refractory = bm.Variable(self.num, dtype=bool)
    self.t_last_spike = bm.Variable(bm.ones(self.num) * -1e8)

    # synaptic parameters
    self.ge = bm.Variable(self.num)
    self.gi = bm.Variable(self.num)

    # integral
    self.integral = bp.odeint(bp.JointEq(self.dV, self.dw, self.dge, self.dgi), method=method)

  def dge(self, ge, t):
    return -ge / self.tau_e

  def dgi(self, gi, t):
    return -gi / self.tau_i

  def dV(self, V, t, w, ge, gi, Iext=None):
    I = ge * (self.E_e - V) + gi * (self.E_i - V)
    if Iext is not None: I += Iext
    dVdt = (self.gL * self.delta * bm.exp((V - self.Vth) / self.delta)
            - w + self.gL * (self.EL - V) + I) / self.C
    return dVdt

  def dw(self, w, t, V):
    dwdt = (self.a * (V - self.EL) - w) / self.tau_w
    return dwdt

  def update(self, tdi, x=None):
    V, w, ge, gi = self.integral(self.V.value, self.w.value, self.ge.value, self.gi.value,
                                 tdi.t, Iext=x, dt=tdi.dt)
    refractory = (tdi.t - self.t_last_spike) <= self.tau_ref
    V = bm.where(refractory, self.V.value, V)
    spike = V >= self.V_sp_th
    self.V.value = bm.where(spike, self.V_reset, V)
    self.w.value = bm.where(spike, w + self.b, w)
    self.ge.value = ge
    self.gi.value = gi
    self.spike.value = spike
    self.refractory.value = bm.logical_or(refractory, spike)
    self.t_last_spike.value = bm.where(spike, tdi.t, self.t_last_spike)
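As a quick illustration (a minimal sketch, not part of the original notebook), a single AdEx neuron with the RS parameters can be driven by a step current; the amplitude of 500 used below is purely illustrative:

# Sketch only: drive one RS-type AdEx neuron with a step current and plot its membrane potential.
neuron = AdEx(1, **RS_par)
currents, length = bp.inputs.section_input([0., 500., 0.], [50., 200., 50.], return_length=True)
runner = bp.DSRunner(neuron, monitors=['V'])
runner.run(length, inputs=currents)  # per-step currents are passed to AdEx.update as ``x``
plt.plot(runner.mon.ts, runner.mon.V.flatten())
plt.xlabel('Time (ms)')
plt.ylabel('V (mV)')
plt.show()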
[4]:
class INGNet(bp.Network):
  def __init__(self, ext_varied_rates, ext_weight=0.9, method='exp_euler', dt=bm.get_dt()):
    super(INGNet, self).__init__()

    self.num_rs = 20000
    self.num_fs = 4000
    self.num_fs2 = 1000
    self.exc_syn_tau = 5.  # ms
    self.inh_syn_tau = 5.  # ms
    self.exc_syn_weight = 1.  # nS
    self.inh_syn_weight = 5.  # nS
    self.num_delay_step = int(1.5 / dt)
    self.ext_varied_rates = ext_varied_rates

    # neuronal populations
    RS_par_ = RS_par.copy()
    FS_par_ = FS_par.copy()
    FS2_par_ = FS_par.copy()
    RS_par_.update(Vth=-50, V_sp_th=-40)
    FS_par_.update(Vth=-50, V_sp_th=-40)
    FS2_par_.update(Vth=-50, V_sp_th=-40)
    self.rs_pop = AdEx(self.num_rs, tau_e=self.exc_syn_tau, tau_i=self.inh_syn_tau, method=method, **RS_par_)
    self.fs_pop = AdEx(self.num_fs, tau_e=self.exc_syn_tau, tau_i=self.inh_syn_tau, method=method, **FS_par_)
    self.fs2_pop = AdEx(self.num_fs2, tau_e=self.exc_syn_tau, tau_i=self.inh_syn_tau, method=method, **FS2_par_)
    self.ext_pop = bp.neurons.PoissonGroup(self.num_rs, freqs=bm.Variable(1))

    # Poisson inputs
    self.ext_to_FS = bp.synapses.Delta(self.ext_pop, self.fs_pop, bp.conn.FixedProb(0.02),
                                       output=bp.synouts.CUBA(target_var='ge'),
                                       g_max=ext_weight)
    self.ext_to_RS = bp.synapses.Delta(self.ext_pop, self.rs_pop, bp.conn.FixedProb(0.02),
                                       output=bp.synouts.CUBA(target_var='ge'),
                                       g_max=ext_weight)
    self.ext_to_FS2 = bp.synapses.Delta(self.ext_pop, self.fs2_pop, bp.conn.FixedProb(0.02),
                                        output=bp.synouts.CUBA(target_var='ge'),
                                        g_max=ext_weight)

    # synaptic projections
    self.RS_to_FS = bp.synapses.Delta(self.rs_pop, self.fs_pop, bp.conn.FixedProb(0.02),
                                      output=bp.synouts.CUBA(target_var='ge'),
                                      g_max=self.exc_syn_weight,
                                      delay_step=self.num_delay_step)
    self.RS_to_RS = bp.synapses.Delta(self.rs_pop, self.rs_pop, bp.conn.FixedProb(0.02),
                                      output=bp.synouts.CUBA(target_var='ge'),
                                      g_max=self.exc_syn_weight,
                                      delay_step=self.num_delay_step)
    self.RS_to_FS2 = bp.synapses.Delta(self.rs_pop, self.fs2_pop, bp.conn.FixedProb(0.15),
                                       output=bp.synouts.CUBA(target_var='ge'),
                                       g_max=self.exc_syn_weight,
                                       delay_step=self.num_delay_step)

    self.FS_to_RS = bp.synapses.Delta(self.fs_pop, self.rs_pop, bp.conn.FixedProb(0.02),
                                      output=bp.synouts.CUBA(target_var='gi'),
                                      g_max=self.inh_syn_weight,
                                      delay_step=self.num_delay_step)
    self.FS_to_FS = bp.synapses.Delta(self.fs_pop, self.fs_pop, bp.conn.FixedProb(0.02),
                                      output=bp.synouts.CUBA(target_var='gi'),
                                      g_max=self.inh_syn_weight,
                                      delay_step=self.num_delay_step)
    self.FS_to_FS2 = bp.synapses.Delta(self.fs_pop, self.fs2_pop, bp.conn.FixedProb(0.03),
                                       output=bp.synouts.CUBA(target_var='gi'),
                                       g_max=self.inh_syn_weight,
                                       delay_step=self.num_delay_step)

    self.FS2_to_RS = bp.synapses.Delta(self.fs2_pop, self.rs_pop, bp.conn.FixedProb(0.15),
                                       output=bp.synouts.CUBA(target_var='gi'),
                                       g_max=self.exc_syn_weight,
                                       delay_step=self.num_delay_step)
    self.FS2_to_FS = bp.synapses.Delta(self.fs2_pop, self.fs_pop, bp.conn.FixedProb(0.15),
                                       output=bp.synouts.CUBA(target_var='gi'),
                                       g_max=self.exc_syn_weight,
                                       delay_step=self.num_delay_step)
    self.FS2_to_FS2 = bp.synapses.Delta(self.fs2_pop, self.fs2_pop, bp.conn.FixedProb(0.6),
                                        output=bp.synouts.CUBA(target_var='gi'),
                                        g_max=self.exc_syn_weight,
                                        delay_step=self.num_delay_step)

  def change_freq(self, tdi):
    self.ext_pop.freqs[0] = self.ext_varied_rates[tdi.i]

[5]:
def get_inputs(c_low, c_high, t_transition, t_min_plato, t_max_plato, t_gap, t_total, dt=None):
  dt = bm.get_dt() if dt is None else dt
  t = 0
  num_gap = int(t_gap / dt)
  num_total = int(t_total / dt)
  num_transition = int(t_transition / dt)

  inputs = []
  ramp_up = np.linspace(c_low, c_high, num_transition)
  ramp_down = np.linspace(c_high, c_low, num_transition)
  plato_base = np.ones(num_gap) * c_low
  while t < num_total:
    num_plato = int(np.random.uniform(low=t_min_plato, high=t_max_plato, size=1) / dt)
    inputs.extend([plato_base, ramp_up, np.ones(num_plato) * c_high, ramp_down])
    t += (num_gap + num_transition + num_plato + num_transition)
  return bm.asarray(np.concatenate(inputs)[:num_total])

[6]:
def signal_phase_by_Hilbert(signal, signal_time, low_cut, high_cut, sampling_space):
  # sampling_space: in seconds (no units)
  # signal_time: in seconds (no units)
  # low_cut: in Hz (no units)(band to filter)
  # high_cut: in Hz (no units)(band to filter)

  signal = signal - np.mean(signal)
  width = 5.0  # The desired width in Hz of the transition from pass to stop
  ripple_db = 60.0  # The desired attenuation in the stop band, in dB.
  sampling_rate = 1. / sampling_space
  Nyquist = sampling_rate / 2.

  num_taps, beta = kaiserord(ripple_db, width / Nyquist)
  if num_taps % 2 == 0:
    num_taps = num_taps + 1  # Numtaps must be odd
  taps = firwin(num_taps, [low_cut / Nyquist, high_cut / Nyquist], window=('kaiser', beta), nyq=1.0,
                pass_zero=False, scale=True)
  filtered_signal = lfilter(taps, 1.0, signal)
  delay = 0.5 * (num_taps - 1) / sampling_rate  # group delay of the linear-phase FIR filter (to correct to zero phase)
  delay_index = int(np.floor(delay * sampling_rate))
  filtered_signal = filtered_signal[num_taps - 1:]  # drop the filter's transient ("corrupted") samples
  # correcting the delay and taking out the "corrupted" signal part
  filtered_time = signal_time[num_taps - 1:] - delay
  cutted_signal = signal[(num_taps - 1 - delay_index): (len(signal) - (num_taps - 1 - delay_index))]

  # --------------------------------------------------------------------------
  # The Hilbert transform is very slow when the signal has an odd length.
  # If the length is odd, append a zero to the end of all vectors related
  # to the filtered signal:
  if len(filtered_signal) % 2 != 0:  # if the length is odd
    tmp1 = filtered_signal.tolist()
    tmp1.append(0)
    tmp2 = filtered_time.tolist()
    tmp2.append((len(filtered_time) + 1) * sampling_space + filtered_time[0])
    tmp3 = cutted_signal.tolist()
    tmp3.append(0)
    filtered_signal = np.asarray(tmp1)
    filtered_time = np.asarray(tmp2)
    cutted_signal = np.asarray(tmp3)
  # --------------------------------------------------------------------------

  ht_filtered_signal = hilbert(filtered_signal)
  envelope = np.abs(ht_filtered_signal)
  phase = np.angle(ht_filtered_signal)  # The phase is between -pi and pi in radians

  return filtered_time, filtered_signal, cutted_signal, envelope, phase

[7]:
def visualize_simulation_results(times, spikes, example_potentials, varied_rates,
                                 xlim=None, t_lfp_start=None, t_lfp_end=None, filename=None):
  fig, gs = bp.visualize.get_figure(7, 1, 1, 12)
  # 1. input firing rate
  ax = fig.add_subplot(gs[0])
  plt.plot(times, varied_rates)
  if xlim is None:
    xlim = (0, times[-1])
  ax.set_xlim(*xlim)
  ax.set_xticks([])
  ax.set_ylabel('External\nRate (Hz)')

  # 2. spike raster plot of each population
  ax = fig.add_subplot(gs[1: 3])
  i = 0
  y_ticks = ([], [])
  for key, (sp_matrix, sp_type) in spikes.items():
    iis, sps = np.where(sp_matrix)
    tts = times[iis]
    plt.plot(tts, sps + i, '.', markersize=1, label=key)
    y_ticks[0].append(i + sp_matrix.shape[1] / 2)
    y_ticks[1].append(key)
    i += sp_matrix.shape[1]
  ax.set_xlim(*xlim)
  ax.set_xlabel('')
  ax.set_ylabel('Neuron Index')
  ax.set_xticks([])
  ax.set_yticks(*y_ticks)
  # ax.legend()

  # 3. example membrane potential
  ax = fig.add_subplot(gs[3: 5])
  for key, potential in example_potentials.items():
    vs = np.where(spikes[key][0][:, 0], 0, potential)
    plt.plot(times, vs, label=key)
  ax.set_xlim(*xlim)
  ax.set_xticks([])
  ax.set_ylabel('V (mV)')
  ax.legend()

  # 4. LFP
  ax = fig.add_subplot(gs[5:7])
  ax.set_xlim(*xlim)
  t1 = int(t_lfp_start / bm.get_dt()) if t_lfp_start is not None else 0
  t2 = int(t_lfp_end / bm.get_dt()) if t_lfp_end is not None else len(times)
  times = times[t1: t2]
  lfp = 0
  for sp_matrix, sp_type in spikes.values():
    lfp += bp.measure.unitary_LFP(times, sp_matrix[t1: t2], sp_type)
  phase_ts, filtered, cutted, envelope, _ = signal_phase_by_Hilbert(bm.as_numpy(lfp), times * 1e-3, 30, 50,
                                                                    bm.get_dt() * 1e-3)
  plt.plot(phase_ts * 1e3, cutted, color='k', label='Raw LFP')
  plt.plot(phase_ts * 1e3, filtered, color='orange', label="Filtered LFP (30-50 Hz)")
  plt.plot(phase_ts * 1e3, envelope, color='purple', label="Hilbert Envelope")
  plt.legend(loc='best')
  plt.xlabel('Time (ms)')

  # save or show
  if filename:
    plt.savefig(filename, dpi=500)
  plt.show()

[8]:
def simulate_ing_net():
  duration = 6e3
  varied_rates = get_inputs(2., 3., 50., 350, 600, 1e3, duration)

  net = INGNet(varied_rates, ext_weight=0.9)
  runner = bp.DSRunner(
    net,
    inputs=net.change_freq,
    monitors={'FS.V0': lambda tdi: net.fs_pop.V[0],
              'FS2.V0': lambda tdi: net.fs2_pop.V[0],
              'RS.V0': lambda tdi: net.rs_pop.V[0],
              'FS.spike': lambda tdi: net.fs_pop.spike,
              'FS2.spike': lambda tdi: net.fs2_pop.spike,
              'RS.spike': lambda tdi: net.rs_pop.spike}
  )
  runner.run(duration)

  visualize_simulation_results(times=runner.mon.ts,
                               spikes={'FS': (runner.mon['FS.spike'], 'inh'),
                                       'FS2': (runner.mon['FS2.spike'], 'inh'),
                                       'RS': (runner.mon['RS.spike'], 'exc')},
                               example_potentials={'FS': runner.mon['FS.V0'],
                                                   'FS2': runner.mon['FS2.V0'],
                                                   'RS': runner.mon['RS.V0']},
                               varied_rates=varied_rates.to_numpy(),
                               xlim=(2e3, 3.4e3), t_lfp_start=1e3, t_lfp_end=5e3)
[9]:
simulate_ing_net()
WARNING:jax._src.lib.xla_bridge:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
_images/oscillation_synchronization_Susin_Destexhe_2021_gamma_oscillation_ING_10_2.png

(Susin & Destexhe, 2021): PING Network for Generating Gamma Oscillation

Implementation of the paper:

  • Susin, Eduarda, and Alain Destexhe. “Integration, coincidence detection and resonance in networks of spiking neurons expressing gamma oscillations and asynchronous states.” PLoS computational biology 17.9 (2021): e1009416.

[1]:
import numpy as np
import matplotlib.pyplot as plt
from scipy.signal import kaiserord, lfilter, firwin, hilbert

import brainpy as bp
import brainpy.math as bm
[2]:
# Table 1: specific neuron model parameters
RS_par = dict(Vth=-40, delta=2., tau_ref=5., tau_w=500, a=4, b=20, C=150, gL=10, EL=-65, V_reset=-65,
              E_e=0., E_i=-80.)
FS_par = dict(Vth=-47.5, delta=0.5, tau_ref=5., tau_w=500, a=0, b=0, C=150, gL=10, EL=-65, V_reset=-65,
              E_e=0., E_i=-80.)
Ch_par = dict(Vth=-47.5, delta=0.5, tau_ref=1., tau_w=50, a=80, b=150, C=150, gL=10, EL=-58, V_reset=-65,
              E_e=0., E_i=-80.)
[3]:
class AdEx(bp.NeuGroup):
  def __init__(
      self,
      size,

      # neuronal parameters
      Vth=-40, delta=2., tau_ref=5., tau_w=500, a=4, b=20, C=150,
      gL=10, EL=-65, V_reset=-65, V_sp_th=-30.,

      # synaptic parameters
      tau_e=1.5, tau_i=7.5, E_e=0., E_i=-80.,

      # other parameters
      name=None, method='exp_euler',
      V_initializer=bp.init.Uniform(-65, -50),
      w_initializer=bp.init.Constant(0.),
  ):
    super(AdEx, self).__init__(size=size, name=name)

    # neuronal parameters
    self.Vth = Vth
    self.delta = delta
    self.tau_ref = tau_ref
    self.tau_w = tau_w
    self.a = a
    self.b = b
    self.C = C
    self.gL = gL
    self.EL = EL
    self.V_reset = V_reset
    self.V_sp_th = V_sp_th

    # synaptic parameters
    self.tau_e = tau_e
    self.tau_i = tau_i
    self.E_e = E_e
    self.E_i = E_i

    # neuronal variables
    self.V = bp.init.variable_(V_initializer, self.num)
    self.w = bp.init.variable_(w_initializer, self.num)
    self.spike = bm.Variable(self.num, dtype=bool)
    self.refractory = bm.Variable(self.num, dtype=bool)
    self.t_last_spike = bm.Variable(bm.ones(self.num) * -1e8)

    # synaptic parameters
    self.ge = bm.Variable(self.num)
    self.gi = bm.Variable(self.num)

    # integral
    self.integral = bp.odeint(bp.JointEq(self.dV, self.dw, self.dge, self.dgi), method=method)

  def dge(self, ge, t):
    return -ge / self.tau_e

  def dgi(self, gi, t):
    return -gi / self.tau_i

  def dV(self, V, t, w, ge, gi, Iext=None):
    I = ge * (self.E_e - V) + gi * (self.E_i - V)
    if Iext is not None: I += Iext
    dVdt = (self.gL * self.delta * bm.exp((V - self.Vth) / self.delta)
            - w + self.gL * (self.EL - V) + I) / self.C
    return dVdt

  def dw(self, w, t, V):
    dwdt = (self.a * (V - self.EL) - w) / self.tau_w
    return dwdt

  def update(self, tdi, x=None):
    V, w, ge, gi = self.integral(self.V.value, self.w.value, self.ge.value, self.gi.value,
                                 tdi.t, Iext=x, dt=tdi.dt)
    refractory = (tdi.t - self.t_last_spike) <= self.tau_ref
    V = bm.where(refractory, self.V.value, V)
    spike = V >= self.V_sp_th
    self.V.value = bm.where(spike, self.V_reset, V)
    self.w.value = bm.where(spike, w + self.b, w)
    self.ge.value = ge
    self.gi.value = gi
    self.spike.value = spike
    self.refractory.value = bm.logical_or(refractory, spike)
    self.t_last_spike.value = bm.where(spike, tdi.t, self.t_last_spike)
[4]:
class PINGNet(bp.Network):
  def __init__(self, ext_varied_rates, ext_weight=4., method='exp_euler', dt=bm.get_dt()):
    super(PINGNet, self).__init__()

    self.num_exc = 20000
    self.num_inh = 5000
    self.exc_syn_tau = 1.  # ms
    self.inh_syn_tau = 7.5  # ms
    self.exc_syn_weight = 5.  # nS
    self.inh_syn_weight = 3.34  # nS
    self.num_delay_step = int(1.5 / dt)
    self.ext_varied_rates = ext_varied_rates

    # neuronal populations
    RS_par_ = RS_par.copy()
    FS_par_ = FS_par.copy()
    RS_par_.update(Vth=-50, V_sp_th=-40)
    FS_par_.update(Vth=-50, V_sp_th=-40)
    self.rs_pop = AdEx(self.num_exc, tau_e=self.exc_syn_tau, tau_i=self.inh_syn_tau, method=method, **RS_par_)
    self.fs_pop = AdEx(self.num_inh, tau_e=self.exc_syn_tau, tau_i=self.inh_syn_tau, method=method, **FS_par_)
    self.ext_pop = bp.neurons.PoissonGroup(self.num_exc, freqs=bm.Variable(1))

    # Poisson inputs
    self.ext_to_FS = bp.synapses.Delta(self.ext_pop, self.fs_pop, bp.conn.FixedProb(0.02),
                                       output=bp.synouts.CUBA(target_var='ge'),
                                       g_max=ext_weight)
    self.ext_to_RS = bp.synapses.Delta(self.ext_pop, self.rs_pop, bp.conn.FixedProb(0.02),
                                       output=bp.synouts.CUBA(target_var='ge'),
                                       g_max=ext_weight)

    # synaptic projections
    self.RS_to_FS = bp.synapses.Delta(self.rs_pop, self.fs_pop, bp.conn.FixedProb(0.02),
                                      output=bp.synouts.CUBA(target_var='ge'),
                                      g_max=self.exc_syn_weight,
                                      delay_step=self.num_delay_step)
    self.RS_to_RS = bp.synapses.Delta(self.rs_pop, self.rs_pop, bp.conn.FixedProb(0.02),
                                      output=bp.synouts.CUBA(target_var='ge'),
                                      g_max=self.exc_syn_weight,
                                      delay_step=self.num_delay_step)
    self.FS_to_RS = bp.synapses.Delta(self.fs_pop, self.rs_pop, bp.conn.FixedProb(0.02),
                                      output=bp.synouts.CUBA(target_var='gi'),
                                      g_max=self.inh_syn_weight,
                                      delay_step=self.num_delay_step)
    self.FS_to_FS = bp.synapses.Delta(self.fs_pop, self.fs_pop, bp.conn.FixedProb(0.02),
                                      output=bp.synouts.CUBA(target_var='gi'),
                                      g_max=self.inh_syn_weight,
                                      delay_step=self.num_delay_step)

  def change_freq(self, tdi):
    self.ext_pop.freqs[0] = self.ext_varied_rates[tdi.i]
[5]:
def get_inputs(c_low, c_high, t_transition, t_min_plato, t_max_plato, t_gap, t_total, dt=None):
  dt = bm.get_dt() if dt is None else dt
  t = 0
  num_gap = int(t_gap / dt)
  num_total = int(t_total / dt)
  num_transition = int(t_transition / dt)

  inputs = []
  ramp_up = np.linspace(c_low, c_high, num_transition)
  ramp_down = np.linspace(c_high, c_low, num_transition)
  plato_base = np.ones(num_gap) * c_low
  while t < num_total:
    num_plato = int(np.random.uniform(low=t_min_plato, high=t_max_plato, size=1) / dt)
    inputs.extend([plato_base, ramp_up, np.ones(num_plato) * c_high, ramp_down])
    t += (num_gap + num_transition + num_plato + num_transition)
  return bm.asarray(np.concatenate(inputs)[:num_total])

[6]:
def signal_phase_by_Hilbert(signal, signal_time, low_cut, high_cut, sampling_space):
  # sampling_space: in seconds (no units)
  # signal_time: in seconds (no units)
  # low_cut: in Hz (no units)(band to filter)
  # high_cut: in Hz (no units)(band to filter)

  signal = signal - np.mean(signal)
  width = 5.0  # The desired width in Hz of the transition from pass to stop
  ripple_db = 60.0  # The desired attenuation in the stop band, in dB.
  sampling_rate = 1. / sampling_space
  Nyquist = sampling_rate / 2.

  num_taps, beta = kaiserord(ripple_db, width / Nyquist)
  if num_taps % 2 == 0:
    num_taps = num_taps + 1  # Numtaps must be odd
  taps = firwin(num_taps, [low_cut / Nyquist, high_cut / Nyquist], window=('kaiser', beta), nyq=1.0,
                pass_zero=False, scale=True)
  filtered_signal = lfilter(taps, 1.0, signal)
  delay = 0.5 * (num_taps - 1) / sampling_rate  # group delay of the linear-phase FIR filter (to correct to zero phase)
  delay_index = int(np.floor(delay * sampling_rate))
  filtered_signal = filtered_signal[num_taps - 1:]  # drop the filter's transient ("corrupted") samples
  # correcting the delay and taking out the "corrupted" signal part
  filtered_time = signal_time[num_taps - 1:] - delay
  cutted_signal = signal[(num_taps - 1 - delay_index): (len(signal) - (num_taps - 1 - delay_index))]

  # --------------------------------------------------------------------------
  # The Hilbert transform is very slow when the signal has an odd length.
  # If the length is odd, append a zero to the end of all vectors related
  # to the filtered signal:
  if len(filtered_signal) % 2 != 0:  # if the length is odd
    tmp1 = filtered_signal.tolist()
    tmp1.append(0)
    tmp2 = filtered_time.tolist()
    tmp2.append((len(filtered_time) + 1) * sampling_space + filtered_time[0])
    tmp3 = cutted_signal.tolist()
    tmp3.append(0)
    filtered_signal = np.asarray(tmp1)
    filtered_time = np.asarray(tmp2)
    cutted_signal = np.asarray(tmp3)
  # --------------------------------------------------------------------------

  ht_filtered_signal = hilbert(filtered_signal)
  envelope = np.abs(ht_filtered_signal)
  phase = np.angle(ht_filtered_signal)  # The phase is between -pi and pi in radians

  return filtered_time, filtered_signal, cutted_signal, envelope, phase

[7]:
def visualize_simulation_results(times, spikes, example_potentials, varied_rates,
                                 xlim=None, t_lfp_start=None, t_lfp_end=None, filename=None):
  fig, gs = bp.visualize.get_figure(7, 1, 1, 12)
  # 1. input firing rate
  ax = fig.add_subplot(gs[0])
  plt.plot(times, varied_rates)
  if xlim is None:
    xlim = (0, times[-1])
  ax.set_xlim(*xlim)
  ax.set_xticks([])
  ax.set_ylabel('External\nRate (Hz)')

  # 2. spike raster plot of each population
  ax = fig.add_subplot(gs[1: 3])
  i = 0
  y_ticks = ([], [])
  for key, (sp_matrix, sp_type) in spikes.items():
    iis, sps = np.where(sp_matrix)
    tts = times[iis]
    plt.plot(tts, sps + i, '.', markersize=1, label=key)
    y_ticks[0].append(i + sp_matrix.shape[1] / 2)
    y_ticks[1].append(key)
    i += sp_matrix.shape[1]
  ax.set_xlim(*xlim)
  ax.set_xlabel('')
  ax.set_ylabel('Neuron Index')
  ax.set_xticks([])
  ax.set_yticks(*y_ticks)
  # ax.legend()

  # 3. example membrane potential
  ax = fig.add_subplot(gs[3: 5])
  for key, potential in example_potentials.items():
    vs = np.where(spikes[key][0][:, 0], 0, potential)
    plt.plot(times, vs, label=key)
  ax.set_xlim(*xlim)
  ax.set_xticks([])
  ax.set_ylabel('V (mV)')
  ax.legend()

  # 4. LFP
  ax = fig.add_subplot(gs[5:7])
  ax.set_xlim(*xlim)
  t1 = int(t_lfp_start / bm.get_dt()) if t_lfp_start is not None else 0
  t2 = int(t_lfp_end / bm.get_dt()) if t_lfp_end is not None else len(times)
  times = times[t1: t2]
  lfp = 0
  for sp_matrix, sp_type in spikes.values():
    lfp += bp.measure.unitary_LFP(times, sp_matrix[t1: t2], sp_type)
  phase_ts, filtered, cutted, envelope, _ = signal_phase_by_Hilbert(bm.as_numpy(lfp), times * 1e-3, 30, 50,
                                                                    bm.get_dt() * 1e-3)
  plt.plot(phase_ts * 1e3, cutted, color='k', label='Raw LFP')
  plt.plot(phase_ts * 1e3, filtered, color='orange', label="Filtered LFP (30-50 Hz)")
  plt.plot(phase_ts * 1e3, envelope, color='purple', label="Hilbert Envelope")
  plt.legend(loc='best')
  plt.xlabel('Time (ms)')

  # save or show
  if filename:
    plt.savefig(filename, dpi=500)
  plt.show()

[8]:
def simulate_ping_net():
  duration = 6e3
  varied_rates = get_inputs(2., 3., 50., 150, 600, 1e3, duration)

  net = PINGNet(varied_rates, ext_weight=4.)
  runner = bp.DSRunner(
    net,
    inputs=net.change_freq,
    monitors={'FS.V0': lambda tdi: net.fs_pop.V[0],
              'RS.V0': lambda tdi: net.rs_pop.V[0],
              'FS.spike': lambda tdi: net.fs_pop.spike,
              'RS.spike': lambda tdi: net.rs_pop.spike}
  )
  runner.run(duration)

  visualize_simulation_results(times=runner.mon.ts,
                               spikes={'FS': (runner.mon['FS.spike'], 'inh'),
                                       'RS': (runner.mon['RS.spike'], 'exc')},
                               example_potentials={'FS': runner.mon['FS.V0'],
                                                   'RS': runner.mon['RS.V0']},
                               varied_rates=varied_rates.to_numpy(),
                               xlim=(2e3, 3.4e3), t_lfp_start=1e3, t_lfp_end=5e3)

[9]:
simulate_ping_net()
WARNING:jax._src.lib.xla_bridge:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
_images/oscillation_synchronization_Susin_Destexhe_2021_gamma_oscillation_PING_10_2.png

(Joglekar, et al., 2018): Inter-areal Balanced Amplification Figure 1

Implementation of Figure 1 of:

  • Joglekar, Madhura R., et al. “Inter-areal balanced amplification enhances signal propagation in a large-scale circuit model of the primate cortex.” Neuron 98.1 (2018): 222-234.

[1]:
import brainpy as bp
import brainpy.math as bm
from jax import vmap, jit
import numpy as np
import matplotlib.pyplot as plt
[2]:
wIE = 4 + 2.0 / 7.0  # synaptic weight E to I
wII = wIE * 1.1  # synaptic weight I to I
[3]:
class LocalCircuit(bp.DynamicalSystem):
  r"""The model is given by:

  .. math::

     \tau_E \frac{dv_E}{dt} = -v_E + a_1 [I_E]_+ + a_2 [I_I]_+ \\
     \tau_I \frac{dv_I}{dt} = -v_I + a_3 [I_E]_+ + a_4 [I_I]_+

  where :math:`[I_E]_+=max(I_E, 0)`. :math:`v_E` and :math:`v_I` denote the firing rates
  of the excitatory and inhibitory populations respectively, :math:`\tau_E` and
  :math:`\tau_I` are the corresponding intrinsic time constants.
  """

  def __init__(self, wEE, wEI, tau_e=0.02, tau_i=0.02):
    super(LocalCircuit, self).__init__()
    # parameters
    self.gc = bm.asarray([[wEE, -wEI],
                          [wIE, -wII]])
    self.tau = bm.asarray([tau_e, tau_i])  # time constant [s]
    # variables
    self.state = bm.Variable(bm.asarray([1., 0.]))

  def update(self, tdi):
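    # forward-Euler step of tau * dstate/dt = -state + gc @ state; rates are clipped at zero on the next line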
    self.state += (-self.state + self.gc @ self.state) / self.tau * tdi.dt
    self.state.value = bm.maximum(self.state, 0.)

[4]:
def simulate(wEE, wEI, duration, dt=0.0001, numpy_mon_after_run=True):
  model = LocalCircuit(wEE=wEE, wEI=wEI)
  runner = bp.DSRunner(model, monitors=['state'], dt=dt,
                       numpy_mon_after_run=numpy_mon_after_run,
                       progress_bar=False)
  runner.run(duration)
  return runner.mon.state

[5]:

@jit
@vmap
def get_max_amplitude(wEE, wEI):
  states = simulate(wEE, wEI, duration=2., dt=0.0001, numpy_mon_after_run=False)
  return states[:, 0].max()
[6]:
@jit
@vmap
def get_eigen_value(wEE, wEI):
  A = bm.array([[wEE, -wEI], [wIE, -wII]])
  w, _ = bm.linalg.eig(A)
  return w.real.max()
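A short note on why this eigenvalue check works (my reading; the notebook does not spell it out): since \(\tau_E = \tau_I\) here, the unrectified dynamics can be written as \(\tau \dot{v} = (G - I)\,v\), with \(G\) the coupling matrix gc above. The rest state is therefore stable exactly when every eigenvalue \(\lambda\) of \(G\) satisfies \(\mathrm{Re}(\lambda) < 1\), which is the criterion used below to select stable \((w_{EE}, w_{EI})\) pairs.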
[7]:
# =================== Figure 1B in the paper ==================================#

length, step = 0.6, 0.0001
wEEweak = 4.45  # synaptic weight E to E weak LBA
wEIweak = 4.7  # synaptic weight I to E weak LBA
wEEstrong = 6  # synaptic weight E to E strong LBA
wEIstrong = 6.7  # synaptic weight I to E strong LBA
weak = simulate(wEEweak, wEIweak, duration=length, dt=step)
strong = simulate(wEEstrong, wEIstrong, duration=length, dt=step)

fig, gs = bp.visualize.get_figure(1, 1, 5, 8)
ax = fig.add_subplot(gs[0, 0])
ax.plot(np.arange(step, length, step), weak[:, 0], 'g')
ax.plot(np.arange(step, length, step), strong[:, 0], 'm')
ax.set_ylabel('Excitatory rate (Hz)', fontsize='x-large')
ax.set_xlabel('Time (s)', fontsize='x-large')
ax.set_ylim([0, 6])
ax.set_xlim([0, length])
ax.set_yticks([0, 2, 4, 6])
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)
ax.legend(['weak LBA', 'strong LBA'], prop={'size': 15}, frameon=False)
plt.show()
WARNING:jax._src.lib.xla_bridge:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
_images/large_scale_modeling_Joglekar_2018_InterAreal_Balanced_Amplification_figure1_8_1.png
[8]:
# =================== Figure 1C in the paper ==================================#

all_wEE = bm.arange(4, 6.5, .025)
all_wEI = bm.arange(4.5, 7, .025)
shape = (len(all_wEE), len(all_wEI))
max_amplitude = bm.ones(shape) * -5
all_wEE, all_wEI = bm.meshgrid(all_wEE, all_wEI)
# select all parameters that lead to a stable model
max_eigen_values = get_eigen_value(all_wEE.flatten(), all_wEI.flatten())
selected_ids = bm.where(max_eigen_values.reshape(shape) < 1)
# get maximum amplitude of each stable model
num = 100
for i in range(0, selected_ids[0].size, num):
  ids = (selected_ids[0][i: i + num], selected_ids[1][i: i + num])
  max_amps = get_max_amplitude(all_wEE[ids], all_wEI[ids])
  max_amplitude[ids] = max_amps

fig, gs = bp.visualize.get_figure(1, 1, 5, 8)
ax = fig.add_subplot(gs[0, 0])
X, Y = bm.as_numpy(all_wEE), bm.as_numpy(all_wEI)
levels = np.linspace(-5, 8, 20)
plt.contourf(X, Y, max_amplitude.numpy(), levels=levels, cmap='hot')
plt.contourf(X, Y, max_amplitude.numpy(), levels=np.linspace(-5, 0), colors='silver')
plt.plot(wEEstrong, wEIstrong, 'mx')
plt.plot(wEEweak, wEIweak, 'gx')
plt.ylabel('Local I to E coupling', fontsize=15)
plt.xlabel('Local E to E coupling', fontsize=15)
plt.show()
_images/large_scale_modeling_Joglekar_2018_InterAreal_Balanced_Amplification_figure1_9_0.png

(Joglekar, et al., 2018): Inter-areal Balanced Amplification Figure 2

Implementation of Figure 2 of:

  • Joglekar, Madhura R., et al. “Inter-areal balanced amplification enhances signal propagation in a large-scale circuit model of the primate cortex.” Neuron 98.1 (2018): 222-234.

[1]:
import brainpy as bp
import brainpy.math as bm
import matplotlib.pyplot as plt
import numpy as np
from jax import vmap, jit
from scipy import io as sio
from functools import partial
[2]:
bm.set_dt(dt=2e-4)
[3]:

class ThreshLinearModel(bp.NeuGroup):
  def __init__(
      self, size, hier, fln, inp_idx, inp_data,
      eta=.68, betaE=.066, betaI=.351, tauE=2e-2, tauI=1e-2,
      omegaEE=24.3, omegaEI=19.7, omegaIE=12.2, omegaII=12.5,
      muIE=25.3, muEE=28., noiseE=None, noiseI=None, seed=None,
      desired_ss=None, name=None
  ):
    super(ThreshLinearModel, self).__init__(size, name=name)

    # parameters
    self.hier = hier
    self.fln = fln
    self.eta = bp.init.parameter(eta, self.num, False)
    self.betaE = bp.init.parameter(betaE, self.num, False)
    self.betaI = bp.init.parameter(betaI, self.num, False)
    self.tauE = bp.init.parameter(tauE, self.num, False)
    self.tauI = bp.init.parameter(tauI, self.num, False)
    self.omegaEE = bp.init.parameter(omegaEE, self.num, False)
    self.omegaEI = bp.init.parameter(omegaEI, self.num, False)
    self.omegaIE = bp.init.parameter(omegaIE, self.num, False)
    self.omegaII = bp.init.parameter(omegaII, self.num, False)
    self.muIE = bp.init.parameter(muIE, self.num, False)
    self.muEE = bp.init.parameter(muEE, self.num, False)
    self.desired_ss = desired_ss
    self.seed = seed
    self.noiseE = bp.init.parameter(noiseE, self.num, True)
    self.noiseI = bp.init.parameter(noiseI, self.num, True)
    self.inp_idx, self.inp_data = inp_idx, inp_data

    # Synaptic weights for intra-areal connections
    self.wEE_intra = betaE * omegaEE * (1 + eta * hier)
    self.wIE_intra = betaI * omegaIE * (1 + eta * hier)
    self.wEI_intra = -betaE * omegaEI
    self.wII_intra = -betaI * omegaII

    # Synaptic weights for inter-areal connections
    self.wEE_inter = bm.asarray(fln.T * (betaE * muEE * (1 + eta * hier))).T
    self.wIE_inter = bm.asarray(fln.T * (betaI * muIE * (1 + eta * hier))).T

    # Variables
    self.re = bm.Variable(bm.zeros(self.num))
    self.ri = bm.Variable(bm.zeros(self.num))
    if not (self.noiseE is None and self.noiseI is None):
      self.rng = bm.random.RandomState(seed)
    else:
      self.rng = None

    # get background input
    if desired_ss is None:
      self.bgE = bm.zeros(self.num)
      self.bgI = bm.zeros(self.num)
    else:
      if len(desired_ss) != 2: raise ValueError
      if len(desired_ss[0]) != self.num: raise ValueError
      if len(desired_ss[1]) != self.num: raise ValueError
      self.bgE, self.bgI = self.get_background_current(*desired_ss)

  def get_background_current(self, ssE, ssI):
    # total weights
    wEEaux = bm.diag(-1 + self.wEE_intra) + self.wEE_inter
    wEIaux = self.wEI_intra * bm.eye(self.num)
    wIEaux = bm.diag(self.wIE_intra) + self.wIE_inter
    wIIaux = (-1 + self.wII_intra) * bm.eye(self.num)
    # temp matrices to create matrix A
    A1 = bm.concatenate((wEEaux, wEIaux), axis=1)
    A2 = bm.concatenate((wIEaux, wIIaux), axis=1)
    A = bm.concatenate([A1, A2])
    ss = bm.concatenate((ssE, ssI))
    cur = -bm.dot(A, ss)
    self.re.value, self.ri.value = ssE, ssI
    # state = bm.linalg.lstsq(-A, cur, rcond=None)[0]
    # self.re.value, self.ri.value = bm.split(state, 2)
    return bm.split(cur, 2)

  def reset(self):
    if self.desired_ss is None:
      self.re[:] = 0.
      self.ri[:] = 0.
    else:
      self.re.value = self.desired_ss[0]
      self.ri.value = self.desired_ss[1]
    if self.rng is not None:
      self.rng.seed(self.seed)

  def update(self, tdi):
    # E population
    Ie = bm.dot(self.wEE_inter, self.re) + self.wEE_intra * self.re
    Ie += self.wEI_intra * self.ri + self.bgE
    if self.noiseE is not None:
      Ie += self.noiseE * self.rng.randn(self.num) / bm.sqrt(tdi.dt)
    Ie[self.inp_idx] += self.inp_data[bm.asarray(tdi.t / tdi.dt, dtype=int).value]
    self.re.value = bm.maximum(self.re + (-self.re + bm.maximum(Ie, 10.)) / self.tauE * tdi.dt, 0)
    # I population
    Ii = bm.dot(self.wIE_inter, self.re) + self.wIE_intra * self.re
    Ii += self.wII_intra * self.ri + self.bgI
    if self.noiseI is not None:
      Ii += self.noiseI * self.rng.randn(self.num) / bm.sqrt(tdi.dt)
    self.ri.value = bm.maximum(self.ri + (-self.ri + bm.maximum(Ii, 35.)) / self.tauI * tdi.dt, 0)

[4]:


def simulate(num_node, muEE, fln, hier, input4v1, duration):
  model = ThreshLinearModel(int(num_node), hier=hier, fln=fln, inp_idx=0, inp_data=input4v1, muEE=muEE,
                            desired_ss=(bm.ones(num_node) * 10, bm.ones(num_node) * 35))
  runner = bp.dyn.DSRunner(model, monitors=['re'], progress_bar=False, numpy_mon_after_run=False)
  runner.run(duration)
  return runner.mon.ts, runner.mon.re
[5]:

def show_firing_rates(ax, hist_t, hist_re, show_duration=None, title=None):
  hist_t = bm.as_numpy(hist_t)
  hist_re = bm.as_numpy(hist_re)
  if show_duration is None:
    i_start, i_end = (1.75, 5.)
  else:
    i_start, i_end = show_duration
  i_start = round(i_start / bm.get_dt())
  i_end = round(i_end / bm.get_dt())
  # visualization
  rateV1 = np.maximum(1e-2, hist_re[i_start:i_end, 0] - hist_re[i_start, 0])
  rate24 = np.maximum(1e-2, hist_re[i_start:i_end, -1] - hist_re[i_start, -1])
  ax.semilogy(hist_t[i_start:i_end], rateV1, 'dodgerblue')
  ax.semilogy(hist_t[i_start:i_end], rate24, 'forestgreen')
  ax.set_ylim([1e-2, 1e2 + 100])
  # ax.set_xlim([-0.25, 2.25])
  ax.spines['top'].set_visible(False)
  ax.spines['right'].set_visible(False)
  ax.set_ylabel('Change in firing rate (Hz)', fontsize='large')
  ax.set_xlabel('Time (s)', fontsize='large')
  if title:
    ax.set_title(title)
[6]:

def show_maximum_rate(ax, muEE_range, peak_rates, title=''):
  ax.semilogy(muEE_range[4], np.squeeze(peak_rates[4]), 'cornflowerblue',
              marker="o", markersize=12, markerfacecolor='w')
  ax.semilogy(muEE_range, np.squeeze(peak_rates), 'cornflowerblue', marker=".", markersize=10)
  ax.semilogy(muEE_range[3], np.squeeze(peak_rates[3]), 'cornflowerblue', marker="x", markersize=15)
  ax.spines['top'].set_visible(False)
  ax.spines['right'].set_visible(False)
  ax.set_ylim([1e-6, 1e3])
  ax.set_ylabel('Maximum rate in 24c (Hz)', fontsize='large')
  ax.set_xlabel('Global E to E coupling', fontsize='large')
  if title:
    ax.set_title(title)
[7]:

def figure2B():
  # data
  data = sio.loadmat('Joglekar_2018_data/subgraphData.mat')
  num_node = data['nNodes'][0, 0]
  hier = bm.asarray(data['hierVals'].squeeze() / max(data['hierVals']))  # normalize hierarchical position
  fln = bm.asarray(data['flnMat'])

  # inputs
  ampl = 21.8 * 1.9
  inputs, duration = bp.inputs.section_input([0, ampl, 0], [2., 0.25, 7.75], return_length=True)

  # Fig 2B
  ax = plt.subplot(1, 2, 1)
  times, res = simulate(int(num_node), fln=fln, hier=hier, muEE=34, input4v1=inputs, duration=duration)
  show_firing_rates(ax, hist_t=times, hist_re=res, show_duration=(1.75, 5.), title=r'$\mu$EE=34')

  ax = plt.subplot(1, 2, 2)
  times, res = simulate(int(num_node), fln=fln, hier=hier, muEE=36, input4v1=inputs, duration=duration)
  show_firing_rates(ax, hist_t=times, hist_re=res, show_duration=(1.75, 5.), title=r'$\mu$EE=36')
  plt.show()
[8]:
figure2B()
WARNING:jax._src.lib.xla_bridge:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
_images/large_scale_modeling_Joglekar_2018_InterAreal_Balanced_Amplification_figure2_9_1.png
[9]:

def figure2C():
  # data
  data = sio.loadmat('Joglekar_2018_data/subgraphData.mat')
  hier = bm.asarray(data['hierVals'].squeeze() / max(data['hierVals']))  # normalize hierarchical position
  fln = bm.asarray(data['flnMat'])

  # inputs
  ampl = 21.8 * 1.9
  inputs, duration = bp.inputs.section_input([0, ampl, 0], [2., 0.25, 7.75], return_length=True)
  all_muEE = bm.arange(28, 44, 2)
  i_start, i_end = (round(1.75 / bm.get_dt()), round(5. / bm.get_dt()))

  @jit
  @partial(vmap, in_axes=(0, None))
  def peak_firing_rate(muEE, fln):
    _, res = simulate(fln.shape[0], fln=fln, hier=hier, muEE=muEE, input4v1=inputs, duration=duration)
    return (res[i_start:i_end, -1] - res[i_start, -1]).max()

  # with feedback
  ax = plt.subplot(1, 2, 1)
  area2peak = peak_firing_rate(all_muEE, fln)
  area2peak = bm.where(area2peak > 500, 500, area2peak)
  show_maximum_rate(ax, all_muEE.to_numpy(), area2peak.to_numpy(), title='With feedback')

  # without feedback
  ax = plt.subplot(1, 2, 2)
  area2peak = peak_firing_rate(all_muEE, bm.tril(fln))
  area2peak = bm.where(area2peak > 500, 500, area2peak)
  show_maximum_rate(ax, all_muEE.to_numpy(), area2peak.to_numpy(), title='Without feedback')
  plt.show()

[10]:
figure2C()
_images/large_scale_modeling_Joglekar_2018_InterAreal_Balanced_Amplification_figure2_11_0.png
[11]:

sStrong = 1000
sWeak = 100


def show_multiple_area_rates(axes, times, rates, plot_duration, color='green'):
  t_start, t_end = plot_duration
  i_start, i_end = round(t_start / bm.get_dt()), round(t_end / bm.get_dt())
  areas2plot = [0, 2, 5, 7, 8, 12, 16, 18, 28]
  for i, j in enumerate(areas2plot):
    ax = axes[i]
    ax.plot(times[i_start:i_end] - 100, rates[i_start:i_end, j] - rates[i_start, j], color)
    ax.set_xlim([-98.25, -96])
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    ax.spines['bottom'].set_visible(False)
    plt.setp(ax.get_xticklabels(), visible=False)
    ax.tick_params(axis='both', which='both', length=0)
    peak = (rates[i_start:i_end, j] - rates[i_start, j]).max()
    if j == 0:
      ax.set_ylim([0, 140])
      ax.set_yticks([round(sWeak * peak, 1) / sWeak])
      ax.set_title('Weak GBA')
    else:
      ax.set_ylim([0, 1.2 * peak])
      ax.set_yticks([round(sWeak * peak, 1) / sWeak])
    if i == 4:
      ax.set_ylabel('Change in firing rate (Hz)', fontsize='large')

[12]:

def figure3BD():
  # data
  data = sio.loadmat('Joglekar_2018_data/subgraphData.mat')
  hier = bm.asarray(data['hierVals'].squeeze() / max(data['hierVals']))
  fln = bm.asarray(data['flnMat'])
  num_node = fln.shape[0]
  fig, axes = plt.subplots(9, 2)

  # weak GBA
  muEE = 33.7
  omegaEI = 19.7
  ampl = 22.05 * 1.9
  inputs, duration = bp.inputs.section_input([0, ampl, 0], [2., 0.25, 7.75], return_length=True)
  model = ThreshLinearModel(fln.shape[0], hier=hier, fln=fln, inp_idx=0, inp_data=inputs, muEE=muEE,
                            omegaEI=omegaEI, desired_ss=(bm.ones(num_node) * 10, bm.ones(num_node) * 35))
  runner1 = bp.dyn.DSRunner(model, monitors=['re'])
  runner1.run(duration)
  show_multiple_area_rates(times=runner1.mon.ts, rates=runner1.mon.re, plot_duration=(1.25, 5.),
                           axes=[axis[0] for axis in axes])

  # strong GBA
  muEE = 51.5
  omegaEI = 25.2
  ampl = 11.54 * 1.9
  inputs, duration = bp.inputs.section_input([0, ampl, 0], [2., 0.25, 7.75], return_length=True)
  model = ThreshLinearModel(fln.shape[0], hier=hier, fln=fln, inp_idx=0, inp_data=inputs, muEE=muEE,
                            omegaEI=omegaEI, desired_ss=(bm.ones(num_node) * 10, bm.ones(num_node) * 35))
  runner2 = bp.dyn.DSRunner(model, monitors=['re'])
  runner2.run(duration)
  show_multiple_area_rates(times=runner1.mon.ts, rates=runner1.mon.re, plot_duration=(1.75, 5.),
                           axes=[axis[1] for axis in axes], color='green')
  show_multiple_area_rates(times=runner2.mon.ts, rates=runner2.mon.re, plot_duration=(1.75, 5.),
                           axes=[axis[1] for axis in axes], color='purple')
  plt.tight_layout()
  plt.show()

[13]:
figure3BD()
_images/large_scale_modeling_Joglekar_2018_InterAreal_Balanced_Amplification_figure2_14_2.png
[14]:
def figure3E():
  # data
  data = sio.loadmat('Joglekar_2018_data/subgraphData.mat')
  hier = bm.asarray(data['hierVals'].squeeze() / max(data['hierVals']))
  fln = bm.asarray(data['flnMat'])
  num_node = fln.shape[0]
  i_start, i_end = round(1.75 / bm.get_dt()), round(5. / bm.get_dt())

  # weak GBA
  muEE = 33.7
  omegaEI = 19.7
  ampl = 22.05 * 1.9
  inputs, duration = bp.inputs.section_input([0, ampl, 0], [2., 0.25, 7.75], return_length=True)
  model = ThreshLinearModel(fln.shape[0], hier=hier, fln=fln, inp_idx=0, inp_data=inputs, muEE=muEE,
                            omegaEI=omegaEI, desired_ss=(bm.ones(num_node) * 10, bm.ones(num_node) * 35))
  runner1 = bp.dyn.DSRunner(model, monitors=['re'], )
  runner1.run(duration)
  peak1 = (runner1.mon.re[i_start: i_end] - runner1.mon.re[i_start]).max(axis=0)

  # strong GBA
  muEE = 51.5
  omegaEI = 25.2
  ampl = 11.54 * 1.9
  inputs, duration = bp.inputs.section_input([0, ampl, 0], [2., 0.25, 7.75], return_length=True)
  model = ThreshLinearModel(fln.shape[0], hier=hier, fln=fln, inp_idx=0, inp_data=inputs, muEE=muEE,
                            omegaEI=omegaEI, desired_ss=(bm.ones(num_node) * 10, bm.ones(num_node) * 35))
  runner2 = bp.dyn.DSRunner(model, monitors=['re'], )
  runner2.run(duration)
  peak2 = (runner2.mon.re[i_start: i_end] - runner2.mon.re[i_start]).max(axis=0)

  # visualization
  fig, ax = plt.subplots()
  ax.semilogy(np.arange(0, num_node), 100 * peak1 / peak1[0], 'green', marker=".", markersize=5)
  ax.semilogy(np.arange(0, num_node), 100 * peak2 / peak2[0], 'purple', marker=".", markersize=5)
  ax.spines['top'].set_visible(False)
  ax.spines['right'].set_visible(False)
  ax.set_ylim([1e-4, 1e2])
  ax.set_xlim([0, num_node])
  ax.legend(['weak GBA', 'strong GBA'], prop={'size': 10}, loc='upper right',
            bbox_to_anchor=(1.0, 1.2), frameon=False)
  ax.set_xticks(np.arange(0, num_node))
  ax.set_xticklabels(data['areaList'].squeeze(), rotation='vertical', fontsize=10)
  plt.show()

[15]:
figure3E()
_images/large_scale_modeling_Joglekar_2018_InterAreal_Balanced_Amplification_figure2_16_2.png
[16]:
def figure3F():
  # data
  data = sio.loadmat('Joglekar_2018_data/subgraphData.mat')
  hier = bm.asarray(data['hierVals'].squeeze() / max(data['hierVals']))
  fln = bm.asarray(data['flnMat'])
  num_node = fln.shape[0]
  i_start, i_end = round(1.75 / bm.get_dt()), round(3.5 / bm.get_dt())

  # input
  ampl = 22.05 * 1.9
  inputs, duration = bp.inputs.section_input([0, ampl, 0], [2., 0.25, 7.75], return_length=True)

  @partial(vmap, in_axes=(0, None))
  def maximum_rate(muEE, omegaEI=None):
    if omegaEI is None: omegaEI = 19.7 + (muEE - 33.7) * 55 / 178
    model = ThreshLinearModel(num_node, hier=hier, fln=fln, inp_idx=0, inp_data=inputs, muEE=muEE,
                              omegaEI=omegaEI, desired_ss=(bm.ones(num_node) * 10, bm.ones(num_node) * 35))
    runner = bp.dyn.DSRunner(model, monitors=['re'], progress_bar=False, numpy_mon_after_run=False)
    runner.run(duration)
    return (runner.mon.re[i_start: i_end, -1] - runner.mon.re[i_start, -1]).max()

  # visualization
  muEErange = bm.arange(20, 52, 2)
  peaks_with_gba = maximum_rate(muEErange, None)
  peaks_without_gba = maximum_rate(muEErange, 19.7)
  peaks_with_gba = bm.where(peaks_with_gba > 500, 500, peaks_with_gba)
  peaks_without_gba = bm.where(peaks_without_gba > 500, 500, peaks_without_gba)
  fig, ax = plt.subplots()
  ax.semilogy(muEErange, peaks_without_gba.to_numpy(), 'cornflowerblue', marker=".", markersize=5)
  ax.semilogy(muEErange, peaks_with_gba.to_numpy(), 'black', marker=".", markersize=5)
  ax.spines['top'].set_visible(False)
  ax.spines['right'].set_visible(False)
  ax.set_ylim([1e-8, 1e4])
  ax.set_xlim([20, 50])
  ax.set_ylabel('Maximum rate in 24c (Hz)', fontsize='large')
  ax.set_xlabel('Global E to E coupling', fontsize='large')
  plt.show()

[17]:
figure3F()
_images/large_scale_modeling_Joglekar_2018_InterAreal_Balanced_Amplification_figure2_18_0.png

(Joglekar, et al., 2018): Inter-areal Balanced Amplification Figure 5

Implementation of Figure 5 of:

  • Joglekar, Madhura R., et al. “Inter-areal balanced amplification enhances signal propagation in a large-scale circuit model of the primate cortex.” Neuron 98.1 (2018): 222-234.

[1]:
import brainpy as bp
import brainpy.math as bm
from brainpy.dyn import neurons

import matplotlib.pyplot as plt
import numpy as np
from jax import vmap
from scipy.io import loadmat
[2]:
# This model should be run on a GPU device

bm.set_platform('gpu')
[3]:
class MultiAreaNet(bp.Network):
  def __init__(
      self, hier, conn, delay_mat, muIE=0.0475, muEE=.0375, wII=.075,
      wEE=.01, wIE=.075, wEI=.0375, extE=15.4, extI=14.0, alpha=4., seed=None,
  ):
    super(MultiAreaNet, self).__init__()

    # data
    self.hier = hier
    self.conn = conn
    self.delay_mat = delay_mat

    # parameters
    self.muIE = muIE
    self.muEE = muEE
    self.wII = wII
    self.wEE = wEE
    self.wIE = wIE
    self.wEI = wEI
    self.extE = extE
    self.extI = extI
    self.alpha = alpha
    num_area = hier.size
    self.num_area = num_area
    rng = bm.random.RandomState(seed)

    # neuron models
    self.E = neurons.LIF((num_area, 1600), V_th=-50., V_reset=-60.,
                         V_rest=-70., tau=20., tau_ref=2.,
                         noise=3. / bm.sqrt(20.),
                         V_initializer=bp.init.Uniform(-70., -50.),
                         method='euler',
                         keep_size=True)
    self.I = neurons.LIF((num_area, 400), V_th=-50., V_reset=-60.,
                         V_rest=-70., tau=10., tau_ref=2., noise=3. / bm.sqrt(10.),
                         V_initializer=bp.init.Uniform(-70., -50.),
                         method='euler',
                         keep_size=True)

    # delays
    self.intra_delay_step = int(2. / bm.get_dt())
    self.E_delay_steps = bm.asarray(delay_mat.T / bm.get_dt(), dtype=int)
    bm.fill_diagonal(self.E_delay_steps, self.intra_delay_step)
    self.Edelay = bm.LengthDelay(self.E.spike, delay_len=int(self.E_delay_steps.max()))
    self.Idelay = bm.LengthDelay(self.I.spike, delay_len=self.intra_delay_step)

    # synapse model
    syn_fun = lambda pre_spike, weight, conn_mat: weight * (pre_spike @ conn_mat)
    self.f_E_current = vmap(syn_fun)
    self.f_I_current = vmap(syn_fun, in_axes=(0, None, 0))

    # synapses from I
    self.intra_I2E_conn = rng.random((num_area, 400, 1600)) < 0.1
    self.intra_I2I_conn = rng.random((num_area, 400, 400)) < 0.1
    self.intra_I2E_weight = -wEI
    self.intra_I2I_weight = -wII

    # synapses from E
    self.E2E_conns = [rng.random((num_area, 1600, 1600)) < 0.1 for _ in range(num_area)]
    self.E2I_conns = [rng.random((num_area, 1600, 400)) < 0.1 for _ in range(num_area)]
    self.E2E_weights = (1 + alpha * hier) * muEE * conn.T  # inter-area connections
    bm.fill_diagonal(self.E2E_weights, (1 + alpha * hier) * wEE)  # intra-area connections
    self.E2I_weights = (1 + alpha * hier) * muIE * conn.T  # inter-area connections
    bm.fill_diagonal(self.E2I_weights, (1 + alpha * hier) * wIE)  # intra-area connections

  def update(self, tdi, v1_input):
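    # v1_input: external current delivered at this time step to the excitatory cells of area 0 (V1)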
    self.E.input[0] += v1_input
    self.E.input += self.extE
    self.I.input += self.extI
    E_not_ref = bm.logical_not(self.E.refractory)
    I_not_ref = bm.logical_not(self.I.refractory)

    # synapses from E
    for i in range(self.num_area):
      delayed_E_spikes = self.Edelay(self.E_delay_steps[i], i).astype(float)
      current = self.f_E_current(delayed_E_spikes, self.E2E_weights[i], self.E2E_conns[i])
      self.E.V += current * E_not_ref  # E2E
      current = self.f_E_current(delayed_E_spikes, self.E2I_weights[i], self.E2I_conns[i])
      self.I.V += current * I_not_ref  # E2I

    # synapses from I
    delayed_I_spikes = self.Idelay(self.intra_delay_step).astype(float)
    current = self.f_I_current(delayed_I_spikes, self.intra_I2E_weight, self.intra_I2E_conn)
    self.E.V += current * E_not_ref  # I2E
    current = self.f_I_current(delayed_I_spikes, self.intra_I2I_weight, self.intra_I2I_conn)
    self.I.V += current * I_not_ref  # I2I

    # updates
    self.Edelay.update(self.E.spike)
    self.Idelay.update(self.I.spike)
    self.E.update(tdi)
    self.I.update(tdi)
[4]:
def raster_plot(xValues, yValues, duration):
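  # each area contributes 1600 excitatory cells, so dividing the flattened neuron
  # index by 4 * 400 = 1600 maps spikes onto their area row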
  ticks = np.round(np.arange(0, 29) + 0.5, 2)
  areas = ['V1', 'V2', 'V4', 'DP', 'MT', '8m', '5', '8l', 'TEO', '2', 'F1',
           'STPc', '7A', '46d', '10', '9/46v', '9/46d', 'F5', 'TEpd', 'PBr',
           '7m', '7B', 'F2', 'STPi', 'PROm', 'F7', '8B', 'STPr', '24c']
  N = len(ticks)
  plt.figure(figsize=(8, 6))
  plt.plot(xValues, yValues / (4 * 400), '.', markersize=1)
  plt.plot([0, duration], np.arange(N + 1).repeat(2).reshape(-1, 2).T, 'k-')
  plt.ylabel('Area')
  plt.yticks(np.arange(N))
  plt.xlabel('Time [ms]')
  plt.ylim(0, N)
  plt.yticks(ticks, areas)
  plt.xlim(0, duration)
  plt.tight_layout()
  plt.show()
[5]:
# hierarchy values
hierVals = loadmat('Joglekar_2018_data/hierValspython.mat')
hierValsnew = hierVals['hierVals'].flatten()
hier = bm.asarray(hierValsnew / max(hierValsnew))  # hierarchy normalized.

# fraction of labeled neurons
flnMatp = loadmat('Joglekar_2018_data/efelenMatpython.mat')
conn = bm.asarray(flnMatp['flnMatpython'].squeeze())  # FLN values: Cij is the connection strength from area j to area i

# Distance
speed = 3.5  # axonal conduction velocity
distMatp = loadmat('Joglekar_2018_data/subgraphWiring29.mat')
distMat = distMatp['wiring'].squeeze()  # distances between areas
delayMat = bm.asarray(distMat / speed)
[6]:
pars = dict(extE=14.2, extI=14.7, wII=.075, wEE=.01, wIE=.075, wEI=.0375, muEE=.0375, muIE=0.0475)
inps = dict(value=15, duration=150)
[7]:
inputs, length = bp.inputs.section_input(values=[0, inps['value'], 0.],
                                         durations=[300., inps['duration'], 500],
                                         return_length=True)
[8]:
net = MultiAreaNet(hier, conn, delayMat, **pars)
runner = bp.DSRunner(net, fun_monitors={'E.spike': lambda tdi: net.E.spike.flatten()})
runner.run(length, inputs=inputs)
[9]:
times, indices = np.where(runner.mon['E.spike'])
times = runner.mon.ts[times]
raster_plot(times, indices, length)
_images/large_scale_modeling_Joglekar_2018_InterAreal_Balanced_Amplification_figure5_10_0.png

Simulating 1-million-neuron networks with 1GB GPU memory

[1]:
import brainpy as bp
import brainpy.math as bm
import jax

assert bp.__version__ >= '2.3.8'
[2]:
bm.set(dt=0.4)
[3]:
_default_g_max = dict(type='homo', value=1., prob=0.1, seed=123)
_default_uniform = dict(type='uniform', w_low=0.1, w_high=1., prob=0.1, seed=123)
_default_normal = dict(type='normal', w_mu=0.1, w_sigma=1., prob=0.1, seed=123)
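These dictionaries parameterize the just-in-time (JIT) connectivity operators used by the Exponential synapse defined below: connections and weights are regenerated from a fixed seed inside every call, so no (num_pre, num_post) matrix is ever stored, which is the main trick that keeps a million-neuron network within a small memory budget (together with the memory_efficient=True option passed to DSRunner below). A minimal sketch of the homogeneous-weight variant, with toy sizes chosen only for illustration:

import brainpy.math as bm

# A toy boolean event vector standing in for presynaptic spikes.
events = bm.random.rand(4000) < 0.05

# Same call pattern as in `Exponential.update` below: the sparse connectivity is
# re-sampled from `seed` on the fly, and the result has one entry per postsynaptic neuron.
post_current = bm.event_matvec_prob_conn_homo_weight(
    events, 1., conn_prob=0.1, shape=(4000, 4000), seed=123, transpose=True)
print(post_current.shape)  # (4000,)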
[4]:
class Exponential(bp.TwoEndConnNS):
  def __init__(
      self,
      pre: bp.NeuGroup,
      post: bp.NeuGroup,
      output: bp.SynOut = bp.synouts.CUBA(),
      g_max_par=_default_g_max,
      delay_step=None,
      tau=8.0,
      method: str = 'exp_auto',
      name: str = None,
      mode: bm.Mode = None,
  ):
    super().__init__(pre, post, None, output=output, name=name, mode=mode)
    self.tau = tau
    self.g_max_par = g_max_par
    self.g = bp.init.variable_(bm.zeros, self.post.num, self.mode)
    self.delay_step = self.register_delay(f"{self.pre.name}.spike", delay_step, self.pre.spike)
    self.integral = bp.odeint(lambda g, t: -g / self.tau, method=method)

  def reset_state(self, batch_size=None):
    self.g.value = bp.init.variable_(bm.zeros, self.post.num, batch_size)

  def update(self):
    t = bp.share.load('t')
    dt = bp.share.load('dt')
    pre_spike = self.get_delay_data(f"{self.pre.name}.spike", self.delay_step)
    if self.g_max_par['type'] == 'homo':
      f = lambda s: bm.event_matvec_prob_conn_homo_weight(s,
                                                          self.g_max_par['value'],
                                                          conn_prob=self.g_max_par['prob'],
                                                          shape=(self.pre.num, self.post.num),
                                                          seed=self.g_max_par['seed'],
                                                          transpose=True)
    elif self.g_max_par['type'] == 'uniform':
      f = lambda s: bm.event_matvec_prob_conn_uniform_weight(s,
                                                             w_low=self.g_max_par['w_low'],
                                                             w_high=self.g_max_par['w_high'],
                                                             conn_prob=self.g_max_par['prob'],
                                                             shape=(self.pre.num, self.post.num),
                                                             seed=self.g_max_par['seed'],
                                                             transpose=True)
    elif self.g_max_par['type'] == 'normal':
      f = lambda s: bm.event_matvec_prob_conn_normal_weight(s,
                                                            w_mu=self.g_max_par['w_mu'],
                                                            w_sigma=self.g_max_par['w_sigma'],
                                                            conn_prob=self.g_max_par['prob'],
                                                            shape=(self.pre.num, self.post.num),
                                                            seed=self.g_max_par['seed'],
                                                            transpose=True)
    else:
      raise ValueError
    if isinstance(self.mode, bm.BatchingMode):
      f = jax.vmap(f)
    post_vs = f(pre_spike)
    self.g.value = self.integral(self.g.value, t, dt) + post_vs
    return self.output(self.g)
[5]:
class EINet(bp.Network):
  def __init__(self, scale=1.0, method='exp_auto'):
    super().__init__()
    num_exc = int(3200 * scale)
    num_inh = int(800 * scale)
    pars = dict(V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5., V_initializer=bp.init.Normal(-55., 2.))
    self.N = bp.neurons.LIF(num_exc + num_inh, **pars, method=method)
    self.E = Exponential(self.N[:num_exc], self.N,
                         g_max_par=dict(type='homo', value=0.6 / scale, prob=0.02, seed=123),
                         tau=5., method=method, output=bp.synouts.COBA(E=0.))
    self.I = Exponential(self.N[num_exc:], self.N,
                         g_max_par=dict(type='homo', value=6.7 / scale, prob=0.02, seed=12345),
                         tau=10., method=method, output=bp.synouts.COBA(E=-80.))
[6]:
duration = 1e2
net = EINet(scale=250)  # (3200 + 800) * 250 = 1,000,000 neurons
[7]:
runner = bp.DSRunner(
  net,
  monitors=['N.spike'],
  inputs=('N.input', 20.),
  memory_efficient=True
)
runner.run(duration)
[8]:
bp.visualize.raster_plot(runner.mon.ts, runner.mon['N.spike'], show=True)
_images/large_scale_modeling_EI_net_with_1m_neurons_8_0.png

Integrator RNN Model

In this notebook, we train a vanilla RNN to integrate white noise. This example is useful on its own to understand how RNN training works.

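The task is easy to state: each input sequence is a small random bias plus white noise \(u_t\), and the target at every step is the running sum of the inputs,

\[y_{t}=\sum_{s \le t} u_{s},\]

which is exactly what build_inputs_and_targets constructs below with bm.cumsum.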
[1]:
from functools import partial
import matplotlib.pyplot as plt

import brainpy as bp
import brainpy.math as bm

bm.set_environment(bm.training_mode)

Parameters

[2]:
dt = 0.04
num_step = int(1.0 / dt)
num_batch = 128

Data

[3]:
@partial(bm.jit,
         dyn_vars=bp.TensorCollector({'a': bm.random.DEFAULT}),
         static_argnames=['batch_size'])
def build_inputs_and_targets(mean=0.025, scale=0.01, batch_size=10):
  # Create the white noise input
  sample = bm.random.normal(size=(batch_size, 1, 1))
  bias = mean * 2.0 * (sample - 0.5)
  samples = bm.random.normal(size=(batch_size, num_step, 1))
  noise_t = scale / dt ** 0.5 * samples
  inputs = bias + noise_t
  targets = bm.cumsum(inputs, axis=1)
  return inputs, targets
[4]:
def train_data():
  for _ in range(100):
    yield build_inputs_and_targets(batch_size=num_batch)

Model

[5]:
class RNN(bp.DynamicalSystem):
  def __init__(self, num_in, num_hidden):
    super(RNN, self).__init__()
    self.rnn = bp.layers.RNNCell(num_in, num_hidden, train_state=True)
    self.out = bp.layers.Dense(num_hidden, 1)

  def update(self, sha, x):
    return self.out(sha, self.rnn(sha, x))

model = RNN(1, 100)
WARNING:jax._src.lib.xla_bridge:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)

Training

[6]:
# define loss function
def loss(predictions, targets, l2_reg=2e-4):
  mse = bp.losses.mean_squared_error(predictions, targets)
  l2 = l2_reg * bp.losses.l2_norm(model.train_vars().unique().dict()) ** 2
  return mse + l2


# define optimizer
lr = bp.optim.ExponentialDecay(lr=0.025, decay_steps=1, decay_rate=0.99975)
opt = bp.optim.Adam(lr=lr, eps=1e-1)


# create a trainer
trainer = bp.BPTT(model, loss_fun=loss, optimizer=opt)
trainer.fit(train_data,
            num_epoch=30,
            num_report=200)
Train 200 steps, use 2.0416 s, loss 0.43417808413505554
Train 400 steps, use 1.1465 s, loss 0.027894824743270874
Train 600 steps, use 1.1387 s, loss 0.02194945327937603
Train 800 steps, use 4.5845 s, loss 0.021538913249969482
Train 1000 steps, use 5.0299 s, loss 0.02128899097442627
Train 1200 steps, use 4.9541 s, loss 0.02115160971879959
Train 1400 steps, use 5.0622 s, loss 0.021017059683799744
Train 1600 steps, use 5.0935 s, loss 0.020916711539030075
Train 1800 steps, use 4.9851 s, loss 0.020782889798283577
Train 2000 steps, use 4.8506 s, loss 0.020689304918050766
Train 2200 steps, use 5.1480 s, loss 0.020607156679034233
Train 2400 steps, use 5.0867 s, loss 0.020528702065348625
Train 2600 steps, use 4.9022 s, loss 0.02044598013162613
Train 2800 steps, use 4.8018 s, loss 0.020371917635202408
Train 3000 steps, use 4.9188 s, loss 0.020304910838603973
[7]:
plt.plot(bm.as_numpy(trainer.get_hist_metric()))
plt.show()
_images/recurrent_networks_integrator_rnn_12_0.png

Testing

[8]:
model.reset_state(1)
x, y = build_inputs_and_targets(batch_size=1)
predicts = trainer.predict(x)

plt.figure(figsize=(8, 2))
plt.plot(bm.as_numpy(y[0]).flatten(), label='Ground Truth')
plt.plot(bm.as_numpy(predicts[0]).flatten(), label='Prediction')
plt.legend()
plt.show()
_images/recurrent_networks_integrator_rnn_14_1.png

Train RNN to Solve Parametric Working Memory

[1]:
import brainpy as bp
import brainpy.math as bm
bp.math.set_platform('cpu')
[2]:
import numpy as np
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA
[3]:
# We will import the task from the neurogym library.
# Please install neurogym:
#
# https://github.com/neurogym/neurogym

import neurogym as ngym
[4]:
# Environment
task = 'DelayComparison-v0'
timing = {'delay': ('choice', [200, 400, 800, 1600, 3200]),
          'response': ('constant', 500)}
kwargs = {'dt': 100, 'timing': timing}
seq_len = 100

# Make supervised dataset
dataset = ngym.Dataset(task,
                       env_kwargs=kwargs,
                       batch_size=16,
                       seq_len=seq_len)

# A sample environment from dataset
env = dataset.env
# Visualize the environment with 2 sample trials
_ = ngym.utils.plot_env(env, num_trials=2, def_act=0, fig_kwargs={'figsize': (8, 6)})
plt.show()
_images/recurrent_networks_ParametricWorkingMemory_4_0.png
[5]:
input_size = env.observation_space.shape[0]
output_size = env.action_space.n
batch_size = dataset.batch_size
[6]:
class RNN(bp.DynamicalSystem):
  def __init__(self, num_input, num_hidden, num_output, num_batch,
               w_ir=bp.init.KaimingNormal(scale=1.),
               w_rr=bp.init.KaimingNormal(scale=1.),
               w_ro=bp.init.KaimingNormal(scale=1.),
               dt=None, seed=None):
    super(RNN, self).__init__()

    # parameters
    self.tau = 100
    self.num_batch = num_batch
    self.num_input = num_input
    self.num_hidden = num_hidden
    self.num_output = num_output
    if dt is None:
      self.alpha = 1
    else:
      self.alpha = dt / self.tau
    self.rng = bm.random.RandomState(seed=seed)

    # input weight
    self.w_ir = bm.TrainVar(bp.init.parameter(w_ir, size=(num_input, num_hidden)))

    # recurrent weight
    bound = 1 / num_hidden ** 0.5
    self.w_rr = bm.TrainVar(bp.init.parameter(w_rr, size=(num_hidden, num_hidden)))
    self.b_rr = bm.TrainVar(self.rng.uniform(-bound, bound, num_hidden))

    # readout weight
    self.w_ro = bm.TrainVar(bp.init.parameter(w_ro, size=(num_hidden, num_output)))
    self.b_ro = bm.TrainVar(self.rng.uniform(-bound, bound, num_output))

    # variables
    self.h = bm.Variable(bm.zeros((num_batch, num_hidden)))
    self.o = bm.Variable(bm.zeros((num_batch, num_output)))

  def cell(self, x, h):
    ins = x @ self.w_ir + h @ self.w_rr + self.b_rr
    state = h * (1 - self.alpha) + ins * self.alpha
    return bm.relu(state)

  def readout(self, h):
    return h @ self.w_ro + self.b_ro

  def update(self, x):
    self.h.value = self.cell(x, self.h.value)
    self.o.value = self.readout(self.h.value)
    return self.h.value, self.o.value

  def predict(self, xs):
    self.h[:] = 0.
    return bm.for_loop(self.update, xs)

  def loss(self, xs, ys):
    hs, os = self.predict(xs)
    os = os.reshape((-1, os.shape[-1]))
    loss = bp.losses.cross_entropy_loss(os, ys.flatten())
    return loss, os
[7]:
# Instantiate the network and print information
hidden_size = 64
net = RNN(num_input=input_size,
          num_hidden=hidden_size,
          num_output=output_size,
          num_batch=batch_size,
          dt=env.dt)
[8]:
predict = bm.jit(net.predict)
[9]:
# Adam optimizer
opt = bp.optim.Adam(lr=0.001, train_vars=net.train_vars().unique())
[10]:
# gradient function
grad = bm.grad(net.loss,
               child_objs=net,
               grad_vars=net.train_vars().unique(),
               return_value=True,
               has_aux=True)
[11]:
@bm.jit
@bm.to_object(child_objs=(grad, opt))
def train(xs, ys):
  grads, loss, os = grad(xs, ys)
  opt.update(grads)
  return loss, os
[12]:
running_acc = 0
running_loss = 0
for i in range(2000):
  inputs, labels_np = dataset()
  inputs = bm.asarray(inputs)
  labels = bm.asarray(labels_np)
  loss, outputs = train(inputs, labels)
  running_loss += loss
  # Compute performance
  output_np = np.argmax(bm.as_numpy(outputs), axis=-1).flatten()
  labels_np = labels_np.flatten()
  ind = labels_np > 0  # Only analyze time points when target is not fixation
  running_acc += np.mean(labels_np[ind] == output_np[ind])
  if i % 100 == 99:
    running_loss /= 100
    running_acc /= 100
    print('Step {}, Loss {:0.4f}, Acc {:0.3f}'.format(i + 1, running_loss, running_acc))
    running_loss = 0
    running_acc = 0
Step 100, Loss 0.1653, Acc 0.149
Step 200, Loss 0.0321, Acc 0.674
Step 300, Loss 0.0238, Acc 0.764
Step 400, Loss 0.0187, Acc 0.825
Step 500, Loss 0.0150, Acc 0.857
Step 600, Loss 0.0124, Acc 0.884
Step 700, Loss 0.0099, Acc 0.909
Step 800, Loss 0.0110, Acc 0.894
Step 900, Loss 0.0092, Acc 0.912
Step 1000, Loss 0.0089, Acc 0.913
Step 1100, Loss 0.0080, Acc 0.926
Step 1200, Loss 0.0079, Acc 0.921
Step 1300, Loss 0.0079, Acc 0.923
Step 1400, Loss 0.0075, Acc 0.924
Step 1500, Loss 0.0080, Acc 0.921
Step 1600, Loss 0.0071, Acc 0.930
Step 1700, Loss 0.0071, Acc 0.930
Step 1800, Loss 0.0074, Acc 0.920
Step 1900, Loss 0.0069, Acc 0.928
Step 2000, Loss 0.0071, Acc 0.926
[13]:
def run(num_trial=1):
  env.reset(no_step=True)
  perf = 0
  activity_dict = {}
  trial_infos = {}
  for i in range(num_trial):
    env.new_trial()
    ob, gt = env.ob, env.gt
    inputs = bm.asarray(ob[:, np.newaxis, :])
    rnn_activity, action_pred = predict(inputs)
    rnn_activity = bm.as_numpy(rnn_activity)[:, 0, :]
    activity_dict[i] = rnn_activity
    trial_infos[i] = env.trial

  # Concatenate activity for PCA
  activity = np.concatenate(list(activity_dict[i] for i in range(num_trial)), axis=0)
  print('Shape of the neural activity: (Time points, Neurons): ', activity.shape)

  # Print trial information
  for i in range(5):
    if i >= num_trial: break
    print('Trial ', i, trial_infos[i])

  pca = PCA(n_components=2)
  pca.fit(activity)
  # print('Shape of the projected activity: (Time points, PCs): ', activity_pc.shape)

  fig, (ax1, ax2) = plt.subplots(1, 2, sharey=True, sharex=True, figsize=(6, 3))
  for i in range(num_trial):
    activity_pc = pca.transform(activity_dict[i])
    trial = trial_infos[i]
    color = 'red' if trial['ground_truth'] == 0 else 'blue'
    _ = ax1.plot(activity_pc[:, 0], activity_pc[:, 1], 'o-', color=color)
    if i < 3:
      _ = ax2.plot(activity_pc[:, 0], activity_pc[:, 1], 'o-', color=color)
  ax1.set_xlabel('PC 1')
  ax1.set_ylabel('PC 2')
  plt.show()
[14]:
run(num_trial=1)
Shape of the neural activity: (Time points, Neurons):  (32, 64)
Trial  0 {'ground_truth': 1, 'vpair': (26, 18), 'v1': 26, 'v2': 18}
_images/recurrent_networks_ParametricWorkingMemory_14_1.png
[15]:
run(num_trial=20)
Shape of the neural activity: (Time points, Neurons):  (562, 64)
Trial  0 {'ground_truth': 2, 'vpair': (26, 18), 'v1': 18, 'v2': 26}
Trial  1 {'ground_truth': 1, 'vpair': (26, 18), 'v1': 26, 'v2': 18}
Trial  2 {'ground_truth': 2, 'vpair': (26, 18), 'v1': 18, 'v2': 26}
Trial  3 {'ground_truth': 2, 'vpair': (30, 22), 'v1': 22, 'v2': 30}
Trial  4 {'ground_truth': 1, 'vpair': (18, 10), 'v1': 18, 'v2': 10}
_images/recurrent_networks_ParametricWorkingMemory_15_1.png
[16]:
run(num_trial=100)
Shape of the neural activity: (Time points, Neurons):  (2778, 64)
Trial  0 {'ground_truth': 2, 'vpair': (30, 22), 'v1': 22, 'v2': 30}
Trial  1 {'ground_truth': 1, 'vpair': (26, 18), 'v1': 26, 'v2': 18}
Trial  2 {'ground_truth': 1, 'vpair': (26, 18), 'v1': 26, 'v2': 18}
Trial  3 {'ground_truth': 2, 'vpair': (22, 14), 'v1': 14, 'v2': 22}
Trial  4 {'ground_truth': 1, 'vpair': (22, 14), 'v1': 22, 'v2': 14}
_images/recurrent_networks_ParametricWorkingMemory_16_1.png

(Song, et al., 2016): Training excitatory-inhibitory recurrent network

Implementation of the paper:

  • Song, H. F. , G. R. Yang , and X. J. Wang . “Training Excitatory-Inhibitory Recurrent Neural Networks for Cognitive Tasks: A Simple and Flexible Framework.” Plos Computational Biology 12.2(2016):e1004792.

The original code is based on PyTorch (https://github.com/gyyang/nn-brain/blob/master/EI_RNN.ipynb). Compared with the PyTorch implementation, training with BrainPy is nearly four times faster.

Here we will train a recurrent neural network with excitatory and inhibitory neurons on a simple perceptual decision-making task.

[1]:
import brainpy as bp
import brainpy.math as bm
bp.math.set_platform('cpu')
[2]:
import numpy as np
import matplotlib.pyplot as plt

Defining a perceptual decision making task

[3]:
# We will import the task from the neurogym library.
# Please install neurogym:
#
# https://github.com/neurogym/neurogym

import neurogym as ngym
[4]:
# Environment
task = 'PerceptualDecisionMaking-v0'
timing = {
  'fixation': ('choice', (50, 100, 200, 400)),
  'stimulus': ('choice', (100, 200, 400, 800)),
}
kwargs = {'dt': 20, 'timing': timing}
seq_len = 100

# Make supervised dataset
dataset = ngym.Dataset(task,
                       env_kwargs=kwargs,
                       batch_size=16,
                       seq_len=seq_len)

# A sample environment from dataset
env = dataset.env

# Visualize the environment with 2 sample trials
_ = ngym.utils.plot_env(env, num_trials=2, fig_kwargs={'figsize': (10, 6)})
plt.show()
_images/recurrent_networks_Song_2016_EI_RNN_7_0.png
[5]:
input_size = env.observation_space.shape[0]
output_size = env.action_space.n
batch_size = dataset.batch_size

print(f'Input size = {input_size}')
print(f'Output size = {output_size}')
print(f'Batch size = {batch_size}')
Input size = 3
Output size = 3
Batch size = 16

Define E-I recurrent network

Here we define an E-I recurrent network; in particular, no self-connections are allowed. A small sketch of the sign mask that enforces these constraints is given below.

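The excitatory/inhibitory sign constraint and the no-self-connection rule are implemented through a fixed mask that is multiplied element-wise with \(|W^{\mathrm{rec}}|\). A minimal sketch of that mask, with toy sizes assumed purely for illustration:

import numpy as np

e_size, i_size = 3, 2          # toy sizes for illustration
num_hidden = e_size + i_size

# Same construction as in RNN.__init__ below: +1 entries for the excitatory part,
# -1 entries for the inhibitory part, and a zero diagonal so that self-connections
# are removed from the effective weight |w_rr| * mask.
mask = np.tile([1] * e_size + [-1] * i_size, (num_hidden, 1))
np.fill_diagonal(mask, 0)
print(mask)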
[6]:
class RNN(bp.DynamicalSystem):
  r"""E-I RNN.

  The RNNs are described by the equations

  .. math::

      \begin{gathered}
      \tau \dot{\mathbf{x}}=-\mathbf{x}+W^{\mathrm{rec}} \mathbf{r}+W^{\mathrm{in}}
      \mathbf{u}+\sqrt{2 \tau \sigma_{\mathrm{rec}}^{2}} \xi \\
      \mathbf{r}=[\mathbf{x}]_{+} \\
      \mathbf{z}=W^{\text {out }} \mathbf{r}
      \end{gathered}

  In practice, the continuous-time dynamics are discretized to Euler form
  in time steps of size :math:`\Delta t` as

  .. math::

     \begin{gathered}
      \mathbf{x}_{t}=(1-\alpha) \mathbf{x}_{t-1}+\alpha\left(W^{\mathrm{rec}} \mathbf{r}_{t-1}+
      W^{\mathrm{in}} \mathbf{u}_{t}\right)+\sqrt{2 \alpha \sigma_{\mathrm{rec}}^{2}} \mathbf{N}(0,1) \\
      \mathbf{r}_{t}=\left[\mathbf{x}_{t}\right]_{+} \\
      \mathbf{z}_{t}=W^{\mathrm{out}} \mathbf{r}_{t}
      \end{gathered}

  where :math:`\alpha = \Delta t/\tau` and :math:`N(0, 1)` are normally distributed
  random numbers with zero mean and unit variance, sampled independently at every time step.
  """

  def __init__(self, num_input, num_hidden, num_output, num_batch,
               dt=None, e_ratio=0.8, sigma_rec=0., seed=None,
               w_ir=bp.init.KaimingUniform(scale=1.),
               w_rr=bp.init.KaimingUniform(scale=1.),
               w_ro=bp.init.KaimingUniform(scale=1.)):
    super(RNN, self).__init__()

    # parameters
    self.tau = 100
    self.num_batch = num_batch
    self.num_input = num_input
    self.num_hidden = num_hidden
    self.num_output = num_output
    self.e_size = int(num_hidden * e_ratio)
    self.i_size = num_hidden - self.e_size
    if dt is None:
      self.alpha = 1
    else:
      self.alpha = dt / self.tau
    self.sigma_rec = (2 * self.alpha) ** 0.5 * sigma_rec  # Recurrent noise
    self.rng = bm.random.RandomState(seed=seed)

    # hidden mask
    mask = np.tile([1] * self.e_size + [-1] * self.i_size, (num_hidden, 1))
    np.fill_diagonal(mask, 0)
    self.mask = bm.asarray(mask, dtype=bm.float_)

    # input weight
    self.w_ir = bm.TrainVar(bp.init.parameter(w_ir, (num_input, num_hidden)))

    # recurrent weight
    bound = 1 / num_hidden ** 0.5
    self.w_rr = bm.TrainVar(bp.init.parameter(w_rr, (num_hidden, num_hidden)))
    self.w_rr[:, :self.e_size] /= (self.e_size / self.i_size)
    self.b_rr = bm.TrainVar(self.rng.uniform(-bound, bound, num_hidden))

    # readout weight
    bound = 1 / self.e_size ** 0.5
    self.w_ro = bm.TrainVar(bp.init.parameter(w_ro, (self.e_size, num_output)))
    self.b_ro = bm.TrainVar(self.rng.uniform(-bound, bound, num_output))

    # variables
    self.h = bm.Variable(bm.zeros((num_batch, num_hidden)))
    self.o = bm.Variable(bm.zeros((num_batch, num_output)))

  def cell(self, x, h):
    ins = x @ self.w_ir + h @ (bm.abs(self.w_rr) * self.mask) + self.b_rr
    state = h * (1 - self.alpha) + ins * self.alpha
    state += self.sigma_rec * self.rng.randn(self.num_hidden)
    return bm.relu(state)

  def readout(self, h):
    return h @ self.w_ro + self.b_ro

  def update(self, x):
    self.h.value = self.cell(x, self.h)
    self.o.value = self.readout(self.h[:, :self.e_size])
    return self.h.value, self.o.value

  def predict(self, xs):
    self.h[:] = 0.
    return bm.for_loop(self.update, xs)

  def loss(self, xs, ys):
    hs, os = self.predict(xs)
    os = os.reshape((-1, os.shape[-1]))
    return bp.losses.cross_entropy_loss(os, ys.flatten())

Train the network on the decision making task

[7]:
# Instantiate the network and print information
hidden_size = 50
net = RNN(num_input=input_size,
          num_hidden=hidden_size,
          num_output=output_size,
          num_batch=batch_size,
          dt=env.dt,
          sigma_rec=0.15)
[8]:
# Adam optimizer
opt = bp.optim.Adam(lr=0.001, train_vars=net.train_vars().unique())
[9]:
# gradient function
grad_f = bm.grad(net.loss,
                 child_objs=net,
                 grad_vars=net.train_vars().unique(),
                 return_value=True)
[10]:
@bm.jit(child_objs=(net, opt))
def train(xs, ys):
  grads, loss = grad_f(xs, ys)
  opt.update(grads)
  return loss

The training is nearly four times faster than with the original PyTorch code.

[11]:
running_loss = 0
print_step = 200
for i in range(5000):
  inputs, labels = dataset()
  inputs = bm.asarray(inputs)
  labels = bm.asarray(labels)
  loss = train(inputs, labels)
  running_loss += loss
  if i % print_step == (print_step - 1):
    running_loss /= print_step
    print('Step {}, Loss {:0.4f}'.format(i + 1, running_loss))
    running_loss = 0
Step 200, Loss 0.6556
Step 400, Loss 0.4587
Step 600, Loss 0.4140
Step 800, Loss 0.3671
Step 1000, Loss 0.3321
Step 1200, Loss 0.3048
Step 1400, Loss 0.2851
Step 1600, Loss 0.2638
Step 1800, Loss 0.2431
Step 2000, Loss 0.2230
Step 2200, Loss 0.2083
Step 2400, Loss 0.1932
Step 2600, Loss 0.1787
Step 2800, Loss 0.1673
Step 3000, Loss 0.1595
Step 3200, Loss 0.1457
Step 3400, Loss 0.1398
Step 3600, Loss 0.1335
Step 3800, Loss 0.1252
Step 4000, Loss 0.1204
Step 4200, Loss 0.1151
Step 4400, Loss 0.1099
Step 4600, Loss 0.1075
Step 4800, Loss 0.1027
Step 5000, Loss 0.0976

Run the network post-training and record neural activity

[12]:
predict = bm.jit(net.predict, dyn_vars=net.vars())
[13]:
env.reset(no_step=True)
env.timing.update({'fixation': ('constant', 500), 'stimulus': ('constant', 500)})
perf = 0
num_trial = 500
activity_dict = {}
trial_infos = {}
stim_activity = [[], []]  # response for ground-truth 0 and 1
for i in range(num_trial):
  env.new_trial()
  ob, gt = env.ob, env.gt
  inputs = bm.asarray(ob[:, np.newaxis, :])
  rnn_activity, action_pred = predict(inputs)

  # Compute performance
  action_pred = bm.as_numpy(action_pred)
  choice = np.argmax(action_pred[-1, 0, :])
  correct = choice == gt[-1]

  # Log trial info
  trial_info = env.trial
  trial_info.update({'correct': correct, 'choice': choice})
  trial_infos[i] = trial_info

  # Log stimulus period activity
  rnn_activity = bm.as_numpy(rnn_activity)[:, 0, :]
  activity_dict[i] = rnn_activity

  # Log each neuron's stimulus-period responses in trials where ground_truth = 0 and 1,
  # respectively; these are used later to compute stimulus selectivity for all units
  rnn_activity = rnn_activity[env.start_ind['stimulus']: env.end_ind['stimulus']]
  stim_activity[env.trial['ground_truth']].append(rnn_activity)

print('Average performance', np.mean([val['correct'] for val in trial_infos.values()]))
Average performance 0.81

Plot neural activity from sample trials

[14]:
trial = 2
plt.figure(figsize=(8, 6))
_ = plt.plot(activity_dict[trial][:, :net.e_size], color='blue', label='Excitatory')
_ = plt.plot(activity_dict[trial][:, net.e_size:], color='red', label='Inhibitory')
plt.xlabel('Time step')
plt.ylabel('Activity')
plt.show()
_images/recurrent_networks_Song_2016_EI_RNN_22_0.png

Compute stimulus selectivity for sorting neurons

Here, for each neuron, we compute its stimulus-period selectivity \(d'\), defined below.

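Concretely, the selectivity computed in the next cell is the d-prime statistic

\[d^{\prime}=\frac{\mu_{0}-\mu_{1}}{\sqrt{\left(\sigma_{0}^{2}+\sigma_{1}^{2}\right) / 2}}\]

where \(\mu_{g}\) and \(\sigma_{g}\) are the mean and standard deviation of a unit's stimulus-period activity over trials with ground truth \(g\); a small constant is added under the square root in the code for numerical stability.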
[15]:
mean_activity = []
std_activity = []
for ground_truth in [0, 1]:
  activity = np.concatenate(stim_activity[ground_truth], axis=0)
  mean_activity.append(np.mean(activity, axis=0))
  std_activity.append(np.std(activity, axis=0))

# Compute d'
selectivity = (mean_activity[0] - mean_activity[1])
selectivity /= np.sqrt((std_activity[0] ** 2 + std_activity[1] ** 2 + 1e-7) / 2)

# Sort index for selectivity, separately for E and I
ind_sort = np.concatenate((np.argsort(selectivity[:net.e_size]),
                           np.argsort(selectivity[net.e_size:]) + net.e_size))

Plot network connectivity sorted by stimulus selectivity

[16]:
# Plot distribution of stimulus selectivity
plt.figure(figsize=(6, 4))
plt.hist(selectivity)
plt.xlabel('Selectivity')
plt.ylabel('Number of neurons')
plt.show()
_images/recurrent_networks_Song_2016_EI_RNN_26_0.png
[17]:
W = (bm.abs(net.w_rr) * net.mask).numpy()
# Sort by selectivity
W = W[:, ind_sort][ind_sort, :]
wlim = np.max(np.abs(W))

plt.figure(figsize=(10, 10))
plt.imshow(W, cmap='bwr_r', vmin=-wlim, vmax=wlim)
plt.colorbar()
plt.xlabel('From neurons')
plt.ylabel('To neurons')
plt.title('Network connectivity')
plt.tight_layout()
plt.show()
_images/recurrent_networks_Song_2016_EI_RNN_27_0.png

(Masse, et al., 2019): RNN with STP for Working Memory

Re-implementation of the paper with BrainPy:

  • Masse, Nicolas Y., Guangyu R. Yang, H. Francis Song, Xiao-Jing Wang, and David J. Freedman. “Circuit mechanisms for the maintenance and manipulation of information in working memory.” Nature neuroscience 22, no. 7 (2019): 1159-1167.

Thanks to the original GitHub code: https://github.com/nmasse/Short-term-plasticity-RNN

For the implementation of the Task class, please refer to Masse_2019_STP_RNN_tasks.py.

For the analysis methods, please refer to the original repository: https://github.com/nmasse/Short-term-plasticity-RNN/blob/master/analysis.py

[11]:
import brainpy as bp
import brainpy.math as bm
bp.math.set_platform('cpu')
[12]:
import os
import math
import pickle
import numpy as np
from Masse_2019_STP_RNN_tasks import Task
[13]:
# Time parameters
dt = 100  # ms
dt_sec = dt / 1000
time_constant = 100  # ms
alpha = dt / time_constant
[14]:
# Loss parameters
spike_regularization = 'L2'  # 'L1' or 'L2'
spike_cost = 2e-2
weight_cost = 0.
clip_max_grad_val = 0.1
[15]:
# Training specs
batch_size = 1024
learning_rate = 2e-2
[16]:
def initialize(shape, prob, size):
  # Gamma-distributed weights; each entry is kept with probability `prob`, otherwise zeroed.
  w = bm.random.gamma(shape, size=size)
  w *= (bm.random.random(size) < prob)
  return bm.asarray(w, dtype=bm.float32)

Model

[17]:
class Model(bp.DynamicalSystem):
  def __init__(self, task, num_hidden=100, name=None):
    super(Model, self).__init__(name=name)

    assert isinstance(task, Task)
    self.task = task

    # Network configuration
    self.exc_inh_prop = 0.8  # excitatory/inhibitory ratio
    self.conn_prob = 0.2

    # Network shape
    self.num_output = task.num_output
    self.num_hidden = num_hidden
    self.num_input = task.num_input

    # EI
    self.num_exc = int(self.num_hidden * self.exc_inh_prop)
    self.num_inh = self.num_hidden - self.num_exc
    self.EI_list = bm.ones(self.num_hidden)
    self.EI_list[self.num_exc:] = -1.
    self.EI_matrix = bm.diag(self.EI_list)
    self.inh_index = bm.arange(self.num_exc, self.num_hidden)

    # Input and noise
    self.noise_rnn = math.sqrt(2 * alpha) * 0.5

    # Synaptic plasticity specs
    self.tau_fast = 200  # ms
    self.tau_slow = 1500  # ms
    self.U_stf = 0.15
    self.U_std = 0.45

    # Initial hidden values
    self.init_h = bm.TrainVar(bm.ones((batch_size, self.num_hidden)) * 0.1)
    self.h = bm.Variable(bm.ones((batch_size, self.num_hidden)) * 0.1)

    # Input/recurrent/output weights
    #   1. w_ir (input => recurrent)
    prob = self.conn_prob * task.num_receptive_fields
    self.w_ir = bm.TrainVar(initialize(0.2, prob, (self.num_input, self.num_hidden)))
    self.w_ir_mask = bm.ones((self.num_input, self.num_hidden))
    if task.trial_type == 'location_DMS':
      self.w_ir_mask *= 0.
      target_ind = [range(0, self.num_hidden, 3), range(1, self.num_hidden, 3), range(2, self.num_hidden, 3)]
      for n in range(self.num_input):
        u = int(n // (self.num_input / 3))
        self.w_ir_mask[n, target_ind[u]] = 1.
      self.w_ir *= self.w_ir_mask  # keep only the connections allowed by the mask
    #   2. w_rr (recurrent => recurrent)
    self.w_rr = bm.TrainVar(initialize(0.1, self.conn_prob, (self.num_hidden, self.num_hidden)))
    self.w_rr[:, self.num_exc:] = initialize(0.2, self.conn_prob, (self.num_hidden, self.num_inh))
    self.w_rr[self.num_exc:, :] = initialize(0.2, self.conn_prob, (self.num_inh, self.num_hidden))
    self.w_rr_mask = bm.ones((self.num_hidden, self.num_hidden)) - bm.eye(self.num_hidden)
    self.w_rr *= self.w_rr_mask  # remove self-connections
    self.b_rr = bm.TrainVar(bm.zeros((1, self.num_hidden)))
    #   3. w_ro (recurrent => output)
    self.w_ro = bm.TrainVar(initialize(0.1, self.conn_prob, (self.num_hidden, self.num_output)))
    self.w_ro_mask = bm.ones((self.num_hidden, self.num_output))
    self.w_ro_mask[self.num_exc:, :] = 0.
    self.w_ro *= self.w_ro_mask  # remove inhibitory-to-output connections
    #   4. b_ro (bias)
    self.b_ro = bm.TrainVar(bm.zeros((1, self.num_output)))

    # Synaptic variables
    #   - The first row (first half of the neurons) holds facilitating synapses
    #   - The second row (second half of the neurons) holds depressing synapses
    alpha_stf = bm.ones((2, int(self.num_hidden / 2)))
    alpha_stf[0] = dt / self.tau_slow
    alpha_stf[1] = dt / self.tau_fast
    alpha_std = bm.ones((2, int(self.num_hidden / 2)))
    alpha_std[0] = dt / self.tau_fast
    alpha_std[1] = dt / self.tau_slow
    U = bm.ones((2, int(self.num_hidden / 2)))
    U[0] = 0.15
    U[1] = 0.45
    u = bm.ones((batch_size, 2, int(self.num_hidden / 2))) * 0.3
    u[:, 0] = 0.15
    u[:, 1] = 0.45
    #   - final
    self.alpha_stf = alpha_stf.reshape((1, -1))
    self.alpha_std = alpha_std.reshape((1, -1))
    self.U = U.reshape((1, -1))
    self.u = bm.Variable(u.reshape((batch_size, -1)))
    self.x = bm.Variable(bm.ones((batch_size, self.num_hidden)))
    self.y = bm.Variable(bm.ones((batch_size, self.num_output)))
    self.y_hist = bm.Variable(bm.zeros((task.num_steps, batch_size, task.num_output)))

    # Loss
    self.loss = bm.Variable(bm.zeros(1))
    self.perf_loss = bm.Variable(bm.zeros(1))
    self.spike_loss = bm.Variable(bm.zeros(1))
    self.weight_loss = bm.Variable(bm.zeros(1))

  def reset_state(self, batch_size):
    u = bm.ones((batch_size, 2, int(self.num_hidden / 2))) * 0.3
    u[:, 0] = 0.15
    u[:, 1] = 0.45
    self.u.value = u.reshape((batch_size, -1))
    self.x.value = bm.ones((batch_size, self.num_hidden))
    self.loss[:] = 0.
    self.perf_loss[:] = 0.
    self.spike_loss[:] = 0.
    self.weight_loss[:] = 0.

  def update(self, input):
    # update STP variables
    self.x += (self.alpha_std * (1 - self.x) - dt_sec * self.u * self.x * self.h)
    self.u += (self.alpha_stf * (self.U - self.u) + dt_sec * self.U * (1 - self.u) * self.h)
    self.x.value = bm.minimum(1., bm.relu(self.x))
    self.u.value = bm.minimum(1., bm.relu(self.u))
    h_post = self.u * self.x * self.h

    # Update the hidden state. Only use excitatory projections from input layer to RNN
    # All input and RNN activity will be non-negative
    state = alpha * (input @ bm.relu(self.w_ir) + h_post @ self.w_rr + self.b_rr)
    state += bm.random.normal(0, self.noise_rnn, self.h.shape)
    self.h.value = bm.relu(state) + self.h * (1 - alpha)
    self.y.value = self.h @ bm.relu(self.w_ro) + self.b_ro

  def predict(self, inputs):
    self.h[:] = self.init_h
    scan = bm.make_loop(body_fun=self.update,
                        dyn_vars=[self.x, self.u, self.h, self.y],
                        out_vars=[self.y, self.h])
    logits, hist_h = scan(inputs)
    self.y_hist[:] = logits
    return logits, hist_h

  def loss_func(self, inputs, targets, mask):
    logits, hist_h = self.predict(inputs)

    # Calculate the performance loss
    perf_loss = bp.losses.cross_entropy_loss(logits, targets, reduction='none') * mask
    self.perf_loss[:] = bm.mean(perf_loss)

    # L1/L2 penalty term on hidden state activity to encourage low spike rate solutions
    n = 2 if spike_regularization == 'L2' else 1
    self.spike_loss[:] = bm.mean(hist_h ** n)
    self.weight_loss[:] = bm.mean(bm.relu(self.w_rr) ** n)

    # final loss
    self.loss[:] = self.perf_loss + spike_cost * self.spike_loss + weight_cost * self.weight_loss
    return self.loss.mean()

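For reference, the update method above is a direct implementation of the standard short-term depression/facilitation dynamics (restated here from the code, with \(\Delta t\) in seconds and \(h\) the hidden firing rate):

\[\begin{split}\begin{array}{l}
x \leftarrow \min \left(1,\left[x+\alpha_{\mathrm{std}}(1-x)-\Delta t \, u \, x \, h\right]_{+}\right) \\
u \leftarrow \min \left(1,\left[u+\alpha_{\mathrm{stf}}(U-u)+\Delta t \, U(1-u) \, h\right]_{+}\right)
\end{array}\end{split}\]

The effective presynaptic drive entering the recurrent weights is then \(u \, x \, h\).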
Analysis

[18]:
def get_perf(target, output, mask):
  """Calculate task accuracy by comparing the actual network output to the desired output
    only examine time points when test stimulus is on, e.g. when y[:,:,0] = 0 """
  target = target.numpy()
  output = output.numpy()
  mask = mask.numpy()

  mask_full = mask > 0
  mask_test = mask_full * (target[:, :, 0] == 0)
  mask_non_match = mask_full * (target[:, :, 1] == 1)
  mask_match = mask_full * (target[:, :, 2] == 1)
  target_max = np.argmax(target, axis=2)
  output_max = np.argmax(output, axis=2)

  match = target_max == output_max
  accuracy = np.sum(match * mask_test) / np.sum(mask_test)
  acc_non_match = np.sum(match * np.squeeze(mask_non_match)) / np.sum(mask_non_match)
  acc_match = np.sum(match * np.squeeze(mask_match)) / np.sum(mask_match)
  return accuracy, acc_non_match, acc_match

Training

[19]:
def trial(task_name, save_fn=None, num_iterations=2000, iter_between_outputs=5):
  task = Task(task_name, dt=dt, tau=time_constant, batch_size=batch_size)
  # trial_info = task.generate_trial(set_rule=None)
  # task.plot_neural_input(trial_info)

  model = Model(task)
  opt = bp.optim.Adam(learning_rate, train_vars=model.train_vars())
  grad_f = bm.grad(model.loss_func,
                   dyn_vars=model.vars(),
                   grad_vars=model.train_vars(),
                   return_value=True)

  @bm.jit
  @bm.to_object(child_objs=(model, opt))
  def train_op(x, y, mask):
    grads, _ = grad_f(x, y, mask)
    capped_gs = dict()
    for key, grad in grads.items():
      if 'w_rr' in key: grad *= model.w_rr_mask
      elif 'w_ro' in key: grad *= model.w_ro_mask
      elif 'w_ir' in key: grad *= model.w_ir_mask
      capped_gs[key] = bm.clip_by_norm(grad, clip_max_grad_val)
    opt.update(grads=capped_gs)

  # keep track of the model performance across training
  model_performance = {'accuracy': [], 'loss': [], 'perf_loss': [],
                       'spike_loss': [], 'weight_loss': [], 'iteration': []}

  for i in range(num_iterations):
    model.reset_state(batch_size)
    # generate batch of batch_train_size
    trial_info = task.generate_trial(set_rule=None)
    inputs = bm.array(trial_info['neural_input'], dtype=bm.float32)
    targets = bm.array(trial_info['desired_output'], dtype=bm.float32)
    mask = bm.array(trial_info['train_mask'], dtype=bm.float32)

    # Run the model
    train_op(inputs, targets, mask)

    # get metrics
    accuracy, _, _ = get_perf(targets, model.y_hist, mask)
    model_performance['accuracy'].append(accuracy)
    model_performance['loss'].append(model.loss)
    model_performance['perf_loss'].append(model.perf_loss)
    model_performance['spike_loss'].append(model.spike_loss)
    model_performance['weight_loss'].append(model.weight_loss)
    model_performance['iteration'].append(i)

    # Save the network model and output model performance to screen
    if i % iter_between_outputs == 0:
      print(task_name +
            f' Iter {i:4d}' +
            f' | Accuracy {accuracy:0.4f}' +
            f' | Perf loss {model.perf_loss[0]:0.4f}' +
            f' | Spike loss {model.spike_loss[0]:0.4f}' +
            f' | Weight loss {model.weight_loss[0]:0.4f}' +
            f' | Mean activity {bm.mean(model.h):0.4f}')

  if save_fn:
    if not os.path.exists(os.path.dirname(save_fn)):
      os.makedirs(os.path.dirname(save_fn))

    # Save model and results
    weights = model.train_vars().unique().dict()
    results = {'weights': weights, 'parameters': {}}
    for k, v in model_performance.items():
      results[k] = v
    pickle.dump(results, open(save_fn, 'wb'))
[20]:
trial('DMS')
DMS Iter    0 | Accuracy 0.5211 | Perf loss 2.7038 | Spike loss 7.8653 | Weight loss 0.0315 | Mean activity 3.3380
DMS Iter    5 | Accuracy 0.4184 | Perf loss 0.4879 | Spike loss 5.5533 | Weight loss 0.0315 | Mean activity 1.6523
DMS Iter   10 | Accuracy 0.4473 | Perf loss 0.2892 | Spike loss 4.3892 | Weight loss 0.0312 | Mean activity 1.2179
DMS Iter   15 | Accuracy 0.4654 | Perf loss 0.2386 | Spike loss 2.8176 | Weight loss 0.0308 | Mean activity 1.0014
DMS Iter   20 | Accuracy 0.5000 | Perf loss 0.2102 | Spike loss 2.4600 | Weight loss 0.0309 | Mean activity 0.8940
DMS Iter   25 | Accuracy 0.4783 | Perf loss 0.2018 | Spike loss 2.2587 | Weight loss 0.0311 | Mean activity 0.8501
DMS Iter   30 | Accuracy 0.5119 | Perf loss 0.1815 | Spike loss 2.0880 | Weight loss 0.0313 | Mean activity 0.7721
DMS Iter   35 | Accuracy 0.5340 | Perf loss 0.1697 | Spike loss 1.7911 | Weight loss 0.0318 | Mean activity 0.7324
DMS Iter   40 | Accuracy 0.5469 | Perf loss 0.1640 | Spike loss 1.7178 | Weight loss 0.0324 | Mean activity 0.6913
DMS Iter   45 | Accuracy 0.5764 | Perf loss 0.1524 | Spike loss 1.6305 | Weight loss 0.0332 | Mean activity 0.6684
DMS Iter   50 | Accuracy 0.6229 | Perf loss 0.1436 | Spike loss 1.5921 | Weight loss 0.0339 | Mean activity 0.6497
DMS Iter   55 | Accuracy 0.6301 | Perf loss 0.1382 | Spike loss 1.5635 | Weight loss 0.0346 | Mean activity 0.6360
DMS Iter   60 | Accuracy 0.6254 | Perf loss 0.1392 | Spike loss 1.4905 | Weight loss 0.0353 | Mean activity 0.6193
DMS Iter   65 | Accuracy 0.6494 | Perf loss 0.1325 | Spike loss 1.4990 | Weight loss 0.0360 | Mean activity 0.6177
DMS Iter   70 | Accuracy 0.6910 | Perf loss 0.1232 | Spike loss 1.4614 | Weight loss 0.0365 | Mean activity 0.6192
DMS Iter   75 | Accuracy 0.6977 | Perf loss 0.1206 | Spike loss 1.4657 | Weight loss 0.0372 | Mean activity 0.6092
DMS Iter   80 | Accuracy 0.7135 | Perf loss 0.1178 | Spike loss 1.4238 | Weight loss 0.0377 | Mean activity 0.6133
DMS Iter   85 | Accuracy 0.7240 | Perf loss 0.1133 | Spike loss 1.4442 | Weight loss 0.0383 | Mean activity 0.6201
DMS Iter   90 | Accuracy 0.7273 | Perf loss 0.1093 | Spike loss 1.4717 | Weight loss 0.0390 | Mean activity 0.6170
DMS Iter   95 | Accuracy 0.7404 | Perf loss 0.1055 | Spike loss 1.4073 | Weight loss 0.0395 | Mean activity 0.6214
DMS Iter  100 | Accuracy 0.7670 | Perf loss 0.0974 | Spike loss 1.4685 | Weight loss 0.0403 | Mean activity 0.6321
DMS Iter  105 | Accuracy 0.7639 | Perf loss 0.0996 | Spike loss 1.4525 | Weight loss 0.0410 | Mean activity 0.6321
DMS Iter  110 | Accuracy 0.7947 | Perf loss 0.0879 | Spike loss 1.4711 | Weight loss 0.0416 | Mean activity 0.6412
DMS Iter  115 | Accuracy 0.7881 | Perf loss 0.0893 | Spike loss 1.4721 | Weight loss 0.0426 | Mean activity 0.6400
DMS Iter  120 | Accuracy 0.7902 | Perf loss 0.0853 | Spike loss 1.4998 | Weight loss 0.0433 | Mean activity 0.6495
DMS Iter  125 | Accuracy 0.8094 | Perf loss 0.0775 | Spike loss 1.4819 | Weight loss 0.0440 | Mean activity 0.6519
DMS Iter  130 | Accuracy 0.8127 | Perf loss 0.0780 | Spike loss 1.4739 | Weight loss 0.0446 | Mean activity 0.6535
DMS Iter  135 | Accuracy 0.8295 | Perf loss 0.0725 | Spike loss 1.4526 | Weight loss 0.0452 | Mean activity 0.6481
DMS Iter  140 | Accuracy 0.8303 | Perf loss 0.0702 | Spike loss 1.4571 | Weight loss 0.0458 | Mean activity 0.6508
DMS Iter  145 | Accuracy 0.8248 | Perf loss 0.0719 | Spike loss 1.4244 | Weight loss 0.0463 | Mean activity 0.6533
DMS Iter  150 | Accuracy 0.8318 | Perf loss 0.0707 | Spike loss 1.4362 | Weight loss 0.0468 | Mean activity 0.6520
DMS Iter  155 | Accuracy 0.8227 | Perf loss 0.0718 | Spike loss 1.4525 | Weight loss 0.0476 | Mean activity 0.6625
DMS Iter  160 | Accuracy 0.8293 | Perf loss 0.0687 | Spike loss 1.4380 | Weight loss 0.0481 | Mean activity 0.6525
DMS Iter  165 | Accuracy 0.8318 | Perf loss 0.0672 | Spike loss 1.4201 | Weight loss 0.0484 | Mean activity 0.6538
DMS Iter  170 | Accuracy 0.8379 | Perf loss 0.0639 | Spike loss 1.4193 | Weight loss 0.0488 | Mean activity 0.6541
DMS Iter  175 | Accuracy 0.8297 | Perf loss 0.0673 | Spike loss 1.4213 | Weight loss 0.0494 | Mean activity 0.6478
DMS Iter  180 | Accuracy 0.8422 | Perf loss 0.0619 | Spike loss 1.4273 | Weight loss 0.0497 | Mean activity 0.6586
DMS Iter  185 | Accuracy 0.8475 | Perf loss 0.0614 | Spike loss 1.3814 | Weight loss 0.0500 | Mean activity 0.6567
DMS Iter  190 | Accuracy 0.8451 | Perf loss 0.0628 | Spike loss 1.3811 | Weight loss 0.0505 | Mean activity 0.6565
DMS Iter  195 | Accuracy 0.8518 | Perf loss 0.0601 | Spike loss 1.3828 | Weight loss 0.0511 | Mean activity 0.6549
DMS Iter  200 | Accuracy 0.8529 | Perf loss 0.0569 | Spike loss 1.3763 | Weight loss 0.0515 | Mean activity 0.6515
DMS Iter  205 | Accuracy 0.8508 | Perf loss 0.0577 | Spike loss 1.3716 | Weight loss 0.0520 | Mean activity 0.6461
DMS Iter  210 | Accuracy 0.8514 | Perf loss 0.0579 | Spike loss 1.3755 | Weight loss 0.0524 | Mean activity 0.6428
DMS Iter  215 | Accuracy 0.8512 | Perf loss 0.0573 | Spike loss 1.3652 | Weight loss 0.0529 | Mean activity 0.6437
DMS Iter  220 | Accuracy 0.8521 | Perf loss 0.0559 | Spike loss 1.3662 | Weight loss 0.0533 | Mean activity 0.6355
DMS Iter  225 | Accuracy 0.8609 | Perf loss 0.0540 | Spike loss 1.3805 | Weight loss 0.0535 | Mean activity 0.6425
DMS Iter  230 | Accuracy 0.8682 | Perf loss 0.0520 | Spike loss 1.3702 | Weight loss 0.0541 | Mean activity 0.6354
DMS Iter  235 | Accuracy 0.8682 | Perf loss 0.0522 | Spike loss 1.3818 | Weight loss 0.0544 | Mean activity 0.6353
DMS Iter  240 | Accuracy 0.8629 | Perf loss 0.0529 | Spike loss 1.3609 | Weight loss 0.0548 | Mean activity 0.6275
DMS Iter  245 | Accuracy 0.8494 | Perf loss 0.0576 | Spike loss 1.3546 | Weight loss 0.0555 | Mean activity 0.6225
DMS Iter  250 | Accuracy 0.8623 | Perf loss 0.0535 | Spike loss 1.3358 | Weight loss 0.0557 | Mean activity 0.6133
DMS Iter  255 | Accuracy 0.8684 | Perf loss 0.0512 | Spike loss 1.3440 | Weight loss 0.0562 | Mean activity 0.6131
DMS Iter  260 | Accuracy 0.8705 | Perf loss 0.0520 | Spike loss 1.3385 | Weight loss 0.0566 | Mean activity 0.6083
DMS Iter  265 | Accuracy 0.8678 | Perf loss 0.0519 | Spike loss 1.3221 | Weight loss 0.0567 | Mean activity 0.6024
DMS Iter  270 | Accuracy 0.8740 | Perf loss 0.0492 | Spike loss 1.3103 | Weight loss 0.0571 | Mean activity 0.5915
DMS Iter  275 | Accuracy 0.8762 | Perf loss 0.0465 | Spike loss 1.2993 | Weight loss 0.0574 | Mean activity 0.5922
DMS Iter  280 | Accuracy 0.8723 | Perf loss 0.0489 | Spike loss 1.2954 | Weight loss 0.0578 | Mean activity 0.5907
DMS Iter  285 | Accuracy 0.8660 | Perf loss 0.0495 | Spike loss 1.3016 | Weight loss 0.0583 | Mean activity 0.5868
DMS Iter  290 | Accuracy 0.8783 | Perf loss 0.0478 | Spike loss 1.2908 | Weight loss 0.0589 | Mean activity 0.5819
DMS Iter  295 | Accuracy 0.8711 | Perf loss 0.0487 | Spike loss 1.2881 | Weight loss 0.0591 | Mean activity 0.5714
DMS Iter  300 | Accuracy 0.8666 | Perf loss 0.0502 | Spike loss 1.2648 | Weight loss 0.0594 | Mean activity 0.5663
DMS Iter  305 | Accuracy 0.8715 | Perf loss 0.0476 | Spike loss 1.2757 | Weight loss 0.0599 | Mean activity 0.5588
DMS Iter  310 | Accuracy 0.8713 | Perf loss 0.0486 | Spike loss 1.2846 | Weight loss 0.0603 | Mean activity 0.5546
DMS Iter  315 | Accuracy 0.8686 | Perf loss 0.0498 | Spike loss 1.2558 | Weight loss 0.0604 | Mean activity 0.5455
DMS Iter  320 | Accuracy 0.8760 | Perf loss 0.0485 | Spike loss 1.2669 | Weight loss 0.0610 | Mean activity 0.5511
DMS Iter  325 | Accuracy 0.8746 | Perf loss 0.0500 | Spike loss 1.2692 | Weight loss 0.0613 | Mean activity 0.5438
DMS Iter  330 | Accuracy 0.8809 | Perf loss 0.0441 | Spike loss 1.2730 | Weight loss 0.0615 | Mean activity 0.5410
DMS Iter  335 | Accuracy 0.8809 | Perf loss 0.0454 | Spike loss 1.2474 | Weight loss 0.0619 | Mean activity 0.5316
DMS Iter  340 | Accuracy 0.8803 | Perf loss 0.0442 | Spike loss 1.2417 | Weight loss 0.0623 | Mean activity 0.5207
DMS Iter  345 | Accuracy 0.8844 | Perf loss 0.0444 | Spike loss 1.2522 | Weight loss 0.0626 | Mean activity 0.5163
DMS Iter  350 | Accuracy 0.8924 | Perf loss 0.0418 | Spike loss 1.2379 | Weight loss 0.0631 | Mean activity 0.5123
DMS Iter  355 | Accuracy 0.8795 | Perf loss 0.0457 | Spike loss 1.1955 | Weight loss 0.0632 | Mean activity 0.5204
DMS Iter  360 | Accuracy 0.8844 | Perf loss 0.0438 | Spike loss 1.2291 | Weight loss 0.0637 | Mean activity 0.5052
DMS Iter  365 | Accuracy 0.8867 | Perf loss 0.0446 | Spike loss 1.2282 | Weight loss 0.0642 | Mean activity 0.5023
DMS Iter  370 | Accuracy 0.8916 | Perf loss 0.0419 | Spike loss 1.2145 | Weight loss 0.0645 | Mean activity 0.4940
DMS Iter  375 | Accuracy 0.8951 | Perf loss 0.0407 | Spike loss 1.2218 | Weight loss 0.0651 | Mean activity 0.4938
DMS Iter  380 | Accuracy 0.8928 | Perf loss 0.0445 | Spike loss 1.2002 | Weight loss 0.0654 | Mean activity 0.4889
DMS Iter  385 | Accuracy 0.8953 | Perf loss 0.0413 | Spike loss 1.2121 | Weight loss 0.0656 | Mean activity 0.4712
DMS Iter  390 | Accuracy 0.8969 | Perf loss 0.0414 | Spike loss 1.2223 | Weight loss 0.0663 | Mean activity 0.4747
DMS Iter  395 | Accuracy 0.8936 | Perf loss 0.0415 | Spike loss 1.1969 | Weight loss 0.0664 | Mean activity 0.4663
DMS Iter  400 | Accuracy 0.8949 | Perf loss 0.0417 | Spike loss 1.1999 | Weight loss 0.0667 | Mean activity 0.4605
DMS Iter  405 | Accuracy 0.8977 | Perf loss 0.0417 | Spike loss 1.2007 | Weight loss 0.0666 | Mean activity 0.4563
DMS Iter  410 | Accuracy 0.8984 | Perf loss 0.0400 | Spike loss 1.1890 | Weight loss 0.0669 | Mean activity 0.4540
DMS Iter  415 | Accuracy 0.8967 | Perf loss 0.0408 | Spike loss 1.1682 | Weight loss 0.0672 | Mean activity 0.4449
DMS Iter  420 | Accuracy 0.8969 | Perf loss 0.0408 | Spike loss 1.1809 | Weight loss 0.0681 | Mean activity 0.4277
DMS Iter  425 | Accuracy 0.8932 | Perf loss 0.0405 | Spike loss 1.1743 | Weight loss 0.0683 | Mean activity 0.4309
DMS Iter  430 | Accuracy 0.8910 | Perf loss 0.0416 | Spike loss 1.1341 | Weight loss 0.0684 | Mean activity 0.4317
DMS Iter  435 | Accuracy 0.9074 | Perf loss 0.0386 | Spike loss 1.1491 | Weight loss 0.0686 | Mean activity 0.4341
DMS Iter  440 | Accuracy 0.8957 | Perf loss 0.0398 | Spike loss 1.1909 | Weight loss 0.0693 | Mean activity 0.4319
DMS Iter  445 | Accuracy 0.8938 | Perf loss 0.0407 | Spike loss 1.1691 | Weight loss 0.0694 | Mean activity 0.4270
DMS Iter  450 | Accuracy 0.9043 | Perf loss 0.0377 | Spike loss 1.1948 | Weight loss 0.0698 | Mean activity 0.4231
DMS Iter  455 | Accuracy 0.8969 | Perf loss 0.0411 | Spike loss 1.1601 | Weight loss 0.0700 | Mean activity 0.4285
DMS Iter  460 | Accuracy 0.9043 | Perf loss 0.0369 | Spike loss 1.1469 | Weight loss 0.0705 | Mean activity 0.4100
DMS Iter  465 | Accuracy 0.9037 | Perf loss 0.0378 | Spike loss 1.1253 | Weight loss 0.0711 | Mean activity 0.4051
DMS Iter  470 | Accuracy 0.9025 | Perf loss 0.0377 | Spike loss 1.1440 | Weight loss 0.0718 | Mean activity 0.4085
DMS Iter  475 | Accuracy 0.9043 | Perf loss 0.0376 | Spike loss 1.1172 | Weight loss 0.0720 | Mean activity 0.4043
DMS Iter  480 | Accuracy 0.9043 | Perf loss 0.0368 | Spike loss 1.1174 | Weight loss 0.0725 | Mean activity 0.3943
DMS Iter  485 | Accuracy 0.9133 | Perf loss 0.0339 | Spike loss 1.1094 | Weight loss 0.0729 | Mean activity 0.4056
DMS Iter  490 | Accuracy 0.9012 | Perf loss 0.0386 | Spike loss 1.1034 | Weight loss 0.0733 | Mean activity 0.3849
DMS Iter  495 | Accuracy 0.9037 | Perf loss 0.0377 | Spike loss 1.1030 | Weight loss 0.0735 | Mean activity 0.3882
DMS Iter  500 | Accuracy 0.9082 | Perf loss 0.0361 | Spike loss 1.0884 | Weight loss 0.0740 | Mean activity 0.3877
DMS Iter  505 | Accuracy 0.9098 | Perf loss 0.0352 | Spike loss 1.0988 | Weight loss 0.0743 | Mean activity 0.3768
DMS Iter  510 | Accuracy 0.9088 | Perf loss 0.0348 | Spike loss 1.0985 | Weight loss 0.0744 | Mean activity 0.3860
DMS Iter  515 | Accuracy 0.9053 | Perf loss 0.0370 | Spike loss 1.0939 | Weight loss 0.0746 | Mean activity 0.3713
DMS Iter  520 | Accuracy 0.9172 | Perf loss 0.0343 | Spike loss 1.1245 | Weight loss 0.0750 | Mean activity 0.3869
DMS Iter  525 | Accuracy 0.9039 | Perf loss 0.0362 | Spike loss 1.1007 | Weight loss 0.0754 | Mean activity 0.3688
DMS Iter  530 | Accuracy 0.9111 | Perf loss 0.0355 | Spike loss 1.0753 | Weight loss 0.0756 | Mean activity 0.3803
DMS Iter  535 | Accuracy 0.9064 | Perf loss 0.0361 | Spike loss 1.1138 | Weight loss 0.0761 | Mean activity 0.3709
DMS Iter  540 | Accuracy 0.9160 | Perf loss 0.0343 | Spike loss 1.1043 | Weight loss 0.0764 | Mean activity 0.3653
DMS Iter  545 | Accuracy 0.9195 | Perf loss 0.0323 | Spike loss 1.1468 | Weight loss 0.0772 | Mean activity 0.3768
DMS Iter  550 | Accuracy 0.9115 | Perf loss 0.0336 | Spike loss 1.1138 | Weight loss 0.0775 | Mean activity 0.3667
DMS Iter  555 | Accuracy 0.9125 | Perf loss 0.0343 | Spike loss 1.0900 | Weight loss 0.0775 | Mean activity 0.3603
DMS Iter  560 | Accuracy 0.9193 | Perf loss 0.0323 | Spike loss 1.0804 | Weight loss 0.0775 | Mean activity 0.3686
DMS Iter  565 | Accuracy 0.9209 | Perf loss 0.0317 | Spike loss 1.0787 | Weight loss 0.0778 | Mean activity 0.3531
DMS Iter  570 | Accuracy 0.9199 | Perf loss 0.0322 | Spike loss 1.0756 | Weight loss 0.0781 | Mean activity 0.3547
DMS Iter  575 | Accuracy 0.9135 | Perf loss 0.0340 | Spike loss 1.0826 | Weight loss 0.0782 | Mean activity 0.3441
DMS Iter  580 | Accuracy 0.9203 | Perf loss 0.0333 | Spike loss 1.0866 | Weight loss 0.0788 | Mean activity 0.3429
DMS Iter  585 | Accuracy 0.9219 | Perf loss 0.0335 | Spike loss 1.1002 | Weight loss 0.0795 | Mean activity 0.3590
DMS Iter  590 | Accuracy 0.9244 | Perf loss 0.0302 | Spike loss 1.0977 | Weight loss 0.0799 | Mean activity 0.3601
DMS Iter  595 | Accuracy 0.9254 | Perf loss 0.0317 | Spike loss 1.0713 | Weight loss 0.0804 | Mean activity 0.3472
DMS Iter  600 | Accuracy 0.9271 | Perf loss 0.0307 | Spike loss 1.0641 | Weight loss 0.0804 | Mean activity 0.3495
DMS Iter  605 | Accuracy 0.9217 | Perf loss 0.0315 | Spike loss 1.0845 | Weight loss 0.0808 | Mean activity 0.3475
DMS Iter  610 | Accuracy 0.9273 | Perf loss 0.0308 | Spike loss 1.0668 | Weight loss 0.0809 | Mean activity 0.3464
DMS Iter  615 | Accuracy 0.9299 | Perf loss 0.0301 | Spike loss 1.0581 | Weight loss 0.0810 | Mean activity 0.3463
DMS Iter  620 | Accuracy 0.9311 | Perf loss 0.0292 | Spike loss 1.0662 | Weight loss 0.0813 | Mean activity 0.3406
DMS Iter  625 | Accuracy 0.9271 | Perf loss 0.0301 | Spike loss 1.0467 | Weight loss 0.0815 | Mean activity 0.3372
DMS Iter  630 | Accuracy 0.9264 | Perf loss 0.0307 | Spike loss 1.0596 | Weight loss 0.0816 | Mean activity 0.3505
DMS Iter  635 | Accuracy 0.9287 | Perf loss 0.0316 | Spike loss 1.0671 | Weight loss 0.0818 | Mean activity 0.3264
DMS Iter  640 | Accuracy 0.9295 | Perf loss 0.0289 | Spike loss 1.0595 | Weight loss 0.0822 | Mean activity 0.3410
DMS Iter  645 | Accuracy 0.9375 | Perf loss 0.0290 | Spike loss 1.0762 | Weight loss 0.0826 | Mean activity 0.3229
DMS Iter  650 | Accuracy 0.9303 | Perf loss 0.0300 | Spike loss 1.0576 | Weight loss 0.0822 | Mean activity 0.3397
DMS Iter  655 | Accuracy 0.9340 | Perf loss 0.0278 | Spike loss 1.0979 | Weight loss 0.0827 | Mean activity 0.3331
DMS Iter  660 | Accuracy 0.9328 | Perf loss 0.0274 | Spike loss 1.0629 | Weight loss 0.0830 | Mean activity 0.3175
DMS Iter  665 | Accuracy 0.9381 | Perf loss 0.0261 | Spike loss 1.0790 | Weight loss 0.0833 | Mean activity 0.3261
DMS Iter  670 | Accuracy 0.9350 | Perf loss 0.0271 | Spike loss 1.0582 | Weight loss 0.0834 | Mean activity 0.3257
DMS Iter  675 | Accuracy 0.9350 | Perf loss 0.0263 | Spike loss 1.0180 | Weight loss 0.0833 | Mean activity 0.3198
DMS Iter  680 | Accuracy 0.9412 | Perf loss 0.0250 | Spike loss 1.0213 | Weight loss 0.0836 | Mean activity 0.3153
DMS Iter  685 | Accuracy 0.9391 | Perf loss 0.0272 | Spike loss 1.0157 | Weight loss 0.0838 | Mean activity 0.3290
DMS Iter  690 | Accuracy 0.9355 | Perf loss 0.0280 | Spike loss 1.0102 | Weight loss 0.0841 | Mean activity 0.3112
DMS Iter  695 | Accuracy 0.9357 | Perf loss 0.0272 | Spike loss 1.0224 | Weight loss 0.0847 | Mean activity 0.3169
DMS Iter  700 | Accuracy 0.9373 | Perf loss 0.0267 | Spike loss 1.0023 | Weight loss 0.0850 | Mean activity 0.3212
DMS Iter  705 | Accuracy 0.9395 | Perf loss 0.0256 | Spike loss 1.0093 | Weight loss 0.0854 | Mean activity 0.3088
DMS Iter  710 | Accuracy 0.9389 | Perf loss 0.0270 | Spike loss 0.9992 | Weight loss 0.0852 | Mean activity 0.3150
DMS Iter  715 | Accuracy 0.9396 | Perf loss 0.0248 | Spike loss 1.0078 | Weight loss 0.0856 | Mean activity 0.3174
DMS Iter  720 | Accuracy 0.9404 | Perf loss 0.0256 | Spike loss 1.0017 | Weight loss 0.0862 | Mean activity 0.3120
DMS Iter  725 | Accuracy 0.9459 | Perf loss 0.0250 | Spike loss 1.0026 | Weight loss 0.0867 | Mean activity 0.3089
DMS Iter  730 | Accuracy 0.9354 | Perf loss 0.0262 | Spike loss 1.0120 | Weight loss 0.0868 | Mean activity 0.3107
DMS Iter  735 | Accuracy 0.9387 | Perf loss 0.0261 | Spike loss 0.9864 | Weight loss 0.0867 | Mean activity 0.2896
DMS Iter  740 | Accuracy 0.9406 | Perf loss 0.0270 | Spike loss 1.0056 | Weight loss 0.0873 | Mean activity 0.3082
DMS Iter  745 | Accuracy 0.9443 | Perf loss 0.0236 | Spike loss 0.9958 | Weight loss 0.0875 | Mean activity 0.3037
DMS Iter  750 | Accuracy 0.9434 | Perf loss 0.0238 | Spike loss 1.0074 | Weight loss 0.0875 | Mean activity 0.2983
DMS Iter  755 | Accuracy 0.9479 | Perf loss 0.0227 | Spike loss 0.9761 | Weight loss 0.0874 | Mean activity 0.2986
DMS Iter  760 | Accuracy 0.9424 | Perf loss 0.0240 | Spike loss 0.9849 | Weight loss 0.0879 | Mean activity 0.3040
DMS Iter  765 | Accuracy 0.9471 | Perf loss 0.0233 | Spike loss 0.9560 | Weight loss 0.0883 | Mean activity 0.2987
DMS Iter  770 | Accuracy 0.9453 | Perf loss 0.0247 | Spike loss 0.9739 | Weight loss 0.0887 | Mean activity 0.2964
DMS Iter  775 | Accuracy 0.9445 | Perf loss 0.0240 | Spike loss 1.0132 | Weight loss 0.0887 | Mean activity 0.3047
DMS Iter  780 | Accuracy 0.9447 | Perf loss 0.0230 | Spike loss 1.0303 | Weight loss 0.0893 | Mean activity 0.3056
DMS Iter  785 | Accuracy 0.9451 | Perf loss 0.0232 | Spike loss 1.0083 | Weight loss 0.0894 | Mean activity 0.2938
DMS Iter  790 | Accuracy 0.9514 | Perf loss 0.0232 | Spike loss 0.9913 | Weight loss 0.0896 | Mean activity 0.2990
DMS Iter  795 | Accuracy 0.9465 | Perf loss 0.0242 | Spike loss 1.0077 | Weight loss 0.0900 | Mean activity 0.2884
DMS Iter  800 | Accuracy 0.9465 | Perf loss 0.0230 | Spike loss 0.9780 | Weight loss 0.0897 | Mean activity 0.2900
DMS Iter  805 | Accuracy 0.9422 | Perf loss 0.0254 | Spike loss 0.9774 | Weight loss 0.0902 | Mean activity 0.2843
DMS Iter  810 | Accuracy 0.9461 | Perf loss 0.0255 | Spike loss 0.9957 | Weight loss 0.0903 | Mean activity 0.2958
DMS Iter  815 | Accuracy 0.9488 | Perf loss 0.0239 | Spike loss 0.9833 | Weight loss 0.0907 | Mean activity 0.2843
DMS Iter  820 | Accuracy 0.9557 | Perf loss 0.0214 | Spike loss 1.0125 | Weight loss 0.0907 | Mean activity 0.2788
DMS Iter  825 | Accuracy 0.9504 | Perf loss 0.0232 | Spike loss 0.9782 | Weight loss 0.0907 | Mean activity 0.2872
DMS Iter  830 | Accuracy 0.9484 | Perf loss 0.0241 | Spike loss 1.0000 | Weight loss 0.0909 | Mean activity 0.2910
DMS Iter  835 | Accuracy 0.9479 | Perf loss 0.0232 | Spike loss 1.0095 | Weight loss 0.0911 | Mean activity 0.2851
DMS Iter  840 | Accuracy 0.9563 | Perf loss 0.0207 | Spike loss 1.0116 | Weight loss 0.0911 | Mean activity 0.2792
DMS Iter  845 | Accuracy 0.9521 | Perf loss 0.0226 | Spike loss 0.9752 | Weight loss 0.0912 | Mean activity 0.2781
DMS Iter  850 | Accuracy 0.9367 | Perf loss 0.0310 | Spike loss 0.9490 | Weight loss 0.0914 | Mean activity 0.2766
DMS Iter  855 | Accuracy 0.9434 | Perf loss 0.0261 | Spike loss 1.0368 | Weight loss 0.0929 | Mean activity 0.2730
DMS Iter  860 | Accuracy 0.9529 | Perf loss 0.0222 | Spike loss 1.0724 | Weight loss 0.0933 | Mean activity 0.2844
DMS Iter  865 | Accuracy 0.9504 | Perf loss 0.0234 | Spike loss 1.0631 | Weight loss 0.0938 | Mean activity 0.2806
DMS Iter  870 | Accuracy 0.9473 | Perf loss 0.0234 | Spike loss 1.0414 | Weight loss 0.0940 | Mean activity 0.2765
DMS Iter  875 | Accuracy 0.9490 | Perf loss 0.0230 | Spike loss 1.0663 | Weight loss 0.0946 | Mean activity 0.2675
DMS Iter  880 | Accuracy 0.9527 | Perf loss 0.0210 | Spike loss 1.0612 | Weight loss 0.0950 | Mean activity 0.2682
DMS Iter  885 | Accuracy 0.9563 | Perf loss 0.0204 | Spike loss 1.0489 | Weight loss 0.0951 | Mean activity 0.2637
DMS Iter  890 | Accuracy 0.9564 | Perf loss 0.0208 | Spike loss 1.0035 | Weight loss 0.0946 | Mean activity 0.2539
DMS Iter  895 | Accuracy 0.9547 | Perf loss 0.0207 | Spike loss 1.0087 | Weight loss 0.0948 | Mean activity 0.2553
DMS Iter  900 | Accuracy 0.9477 | Perf loss 0.0230 | Spike loss 0.9910 | Weight loss 0.0949 | Mean activity 0.2516
DMS Iter  905 | Accuracy 0.9557 | Perf loss 0.0221 | Spike loss 0.9653 | Weight loss 0.0950 | Mean activity 0.2548
DMS Iter  910 | Accuracy 0.9541 | Perf loss 0.0216 | Spike loss 0.9534 | Weight loss 0.0953 | Mean activity 0.2559
DMS Iter  915 | Accuracy 0.9557 | Perf loss 0.0194 | Spike loss 0.9345 | Weight loss 0.0950 | Mean activity 0.2515
DMS Iter  920 | Accuracy 0.9486 | Perf loss 0.0233 | Spike loss 0.9277 | Weight loss 0.0951 | Mean activity 0.2611
DMS Iter  925 | Accuracy 0.9553 | Perf loss 0.0231 | Spike loss 0.9238 | Weight loss 0.0951 | Mean activity 0.2570
DMS Iter  930 | Accuracy 0.9574 | Perf loss 0.0205 | Spike loss 0.9232 | Weight loss 0.0952 | Mean activity 0.2584
DMS Iter  935 | Accuracy 0.9545 | Perf loss 0.0219 | Spike loss 0.9053 | Weight loss 0.0956 | Mean activity 0.2516
DMS Iter  940 | Accuracy 0.9518 | Perf loss 0.0212 | Spike loss 0.9380 | Weight loss 0.0963 | Mean activity 0.2494
DMS Iter  945 | Accuracy 0.9557 | Perf loss 0.0200 | Spike loss 0.9615 | Weight loss 0.0968 | Mean activity 0.2467
DMS Iter  950 | Accuracy 0.9488 | Perf loss 0.0231 | Spike loss 0.9416 | Weight loss 0.0972 | Mean activity 0.2489
DMS Iter  955 | Accuracy 0.9553 | Perf loss 0.0201 | Spike loss 0.9724 | Weight loss 0.0975 | Mean activity 0.2519
DMS Iter  960 | Accuracy 0.9572 | Perf loss 0.0198 | Spike loss 0.9679 | Weight loss 0.0974 | Mean activity 0.2542
DMS Iter  965 | Accuracy 0.9578 | Perf loss 0.0188 | Spike loss 0.9389 | Weight loss 0.0972 | Mean activity 0.2496
DMS Iter  970 | Accuracy 0.9563 | Perf loss 0.0207 | Spike loss 0.9153 | Weight loss 0.0973 | Mean activity 0.2544
DMS Iter  975 | Accuracy 0.9596 | Perf loss 0.0205 | Spike loss 0.9059 | Weight loss 0.0972 | Mean activity 0.2454
DMS Iter  980 | Accuracy 0.9549 | Perf loss 0.0203 | Spike loss 0.9092 | Weight loss 0.0971 | Mean activity 0.2442
DMS Iter  985 | Accuracy 0.9523 | Perf loss 0.0236 | Spike loss 0.9461 | Weight loss 0.0975 | Mean activity 0.2440
DMS Iter  990 | Accuracy 0.9543 | Perf loss 0.0211 | Spike loss 0.9405 | Weight loss 0.0982 | Mean activity 0.2470
DMS Iter  995 | Accuracy 0.9605 | Perf loss 0.0198 | Spike loss 0.9788 | Weight loss 0.0992 | Mean activity 0.2469
DMS Iter 1000 | Accuracy 0.9555 | Perf loss 0.0207 | Spike loss 0.9792 | Weight loss 0.1000 | Mean activity 0.2461
DMS Iter 1005 | Accuracy 0.9541 | Perf loss 0.0207 | Spike loss 0.9978 | Weight loss 0.1002 | Mean activity 0.2378
DMS Iter 1010 | Accuracy 0.9580 | Perf loss 0.0211 | Spike loss 0.9724 | Weight loss 0.1003 | Mean activity 0.2392
DMS Iter 1015 | Accuracy 0.9502 | Perf loss 0.0214 | Spike loss 0.9557 | Weight loss 0.1001 | Mean activity 0.2403
DMS Iter 1020 | Accuracy 0.9551 | Perf loss 0.0217 | Spike loss 0.9422 | Weight loss 0.0999 | Mean activity 0.2300
DMS Iter 1025 | Accuracy 0.9590 | Perf loss 0.0203 | Spike loss 0.9418 | Weight loss 0.1002 | Mean activity 0.2368
DMS Iter 1030 | Accuracy 0.9574 | Perf loss 0.0206 | Spike loss 0.9781 | Weight loss 0.1001 | Mean activity 0.2349
DMS Iter 1035 | Accuracy 0.9627 | Perf loss 0.0194 | Spike loss 0.9598 | Weight loss 0.0998 | Mean activity 0.2411
DMS Iter 1040 | Accuracy 0.9592 | Perf loss 0.0208 | Spike loss 0.9538 | Weight loss 0.0998 | Mean activity 0.2443
DMS Iter 1045 | Accuracy 0.9578 | Perf loss 0.0201 | Spike loss 0.9230 | Weight loss 0.0996 | Mean activity 0.2371
DMS Iter 1050 | Accuracy 0.9525 | Perf loss 0.0219 | Spike loss 0.9439 | Weight loss 0.1003 | Mean activity 0.2424
DMS Iter 1055 | Accuracy 0.9578 | Perf loss 0.0204 | Spike loss 1.0266 | Weight loss 0.1015 | Mean activity 0.2359
DMS Iter 1060 | Accuracy 0.9410 | Perf loss 0.0260 | Spike loss 1.0010 | Weight loss 0.1021 | Mean activity 0.2344
DMS Iter 1065 | Accuracy 0.9602 | Perf loss 0.0201 | Spike loss 1.0120 | Weight loss 0.1031 | Mean activity 0.2487
DMS Iter 1070 | Accuracy 0.9568 | Perf loss 0.0196 | Spike loss 1.0056 | Weight loss 0.1035 | Mean activity 0.2442
DMS Iter 1075 | Accuracy 0.9563 | Perf loss 0.0197 | Spike loss 0.9897 | Weight loss 0.1033 | Mean activity 0.2477
DMS Iter 1080 | Accuracy 0.9588 | Perf loss 0.0201 | Spike loss 0.9577 | Weight loss 0.1032 | Mean activity 0.2462
DMS Iter 1085 | Accuracy 0.9516 | Perf loss 0.0208 | Spike loss 0.9455 | Weight loss 0.1034 | Mean activity 0.2504
DMS Iter 1090 | Accuracy 0.9605 | Perf loss 0.0182 | Spike loss 0.9264 | Weight loss 0.1029 | Mean activity 0.2421
DMS Iter 1095 | Accuracy 0.9641 | Perf loss 0.0180 | Spike loss 0.9008 | Weight loss 0.1024 | Mean activity 0.2502
DMS Iter 1100 | Accuracy 0.9576 | Perf loss 0.0197 | Spike loss 0.8982 | Weight loss 0.1028 | Mean activity 0.2418
DMS Iter 1105 | Accuracy 0.9551 | Perf loss 0.0217 | Spike loss 0.9124 | Weight loss 0.1040 | Mean activity 0.2322
DMS Iter 1110 | Accuracy 0.9627 | Perf loss 0.0177 | Spike loss 0.9525 | Weight loss 0.1051 | Mean activity 0.2316
DMS Iter 1115 | Accuracy 0.9596 | Perf loss 0.0185 | Spike loss 0.9270 | Weight loss 0.1055 | Mean activity 0.2248
DMS Iter 1120 | Accuracy 0.9602 | Perf loss 0.0183 | Spike loss 0.9226 | Weight loss 0.1058 | Mean activity 0.2201
DMS Iter 1125 | Accuracy 0.9607 | Perf loss 0.0206 | Spike loss 0.9030 | Weight loss 0.1056 | Mean activity 0.2226
DMS Iter 1130 | Accuracy 0.9588 | Perf loss 0.0176 | Spike loss 0.9071 | Weight loss 0.1052 | Mean activity 0.2286
DMS Iter 1135 | Accuracy 0.9594 | Perf loss 0.0185 | Spike loss 0.9155 | Weight loss 0.1056 | Mean activity 0.2299
DMS Iter 1140 | Accuracy 0.9609 | Perf loss 0.0191 | Spike loss 0.9045 | Weight loss 0.1057 | Mean activity 0.2285
DMS Iter 1145 | Accuracy 0.9643 | Perf loss 0.0176 | Spike loss 0.8940 | Weight loss 0.1052 | Mean activity 0.2287
DMS Iter 1150 | Accuracy 0.9543 | Perf loss 0.0206 | Spike loss 0.8978 | Weight loss 0.1052 | Mean activity 0.2285
DMS Iter 1155 | Accuracy 0.9594 | Perf loss 0.0186 | Spike loss 0.9168 | Weight loss 0.1056 | Mean activity 0.2298
DMS Iter 1160 | Accuracy 0.9566 | Perf loss 0.0184 | Spike loss 0.9188 | Weight loss 0.1057 | Mean activity 0.2252
DMS Iter 1165 | Accuracy 0.9641 | Perf loss 0.0170 | Spike loss 0.9011 | Weight loss 0.1057 | Mean activity 0.2345
DMS Iter 1170 | Accuracy 0.9596 | Perf loss 0.0192 | Spike loss 0.8788 | Weight loss 0.1055 | Mean activity 0.2340
DMS Iter 1175 | Accuracy 0.9568 | Perf loss 0.0186 | Spike loss 0.8635 | Weight loss 0.1054 | Mean activity 0.2299
DMS Iter 1180 | Accuracy 0.9611 | Perf loss 0.0177 | Spike loss 0.8675 | Weight loss 0.1056 | Mean activity 0.2248
DMS Iter 1185 | Accuracy 0.9625 | Perf loss 0.0167 | Spike loss 0.8573 | Weight loss 0.1056 | Mean activity 0.2284
DMS Iter 1190 | Accuracy 0.9563 | Perf loss 0.0189 | Spike loss 0.8538 | Weight loss 0.1055 | Mean activity 0.2284
DMS Iter 1195 | Accuracy 0.9641 | Perf loss 0.0178 | Spike loss 0.8566 | Weight loss 0.1054 | Mean activity 0.2249
DMS Iter 1200 | Accuracy 0.9580 | Perf loss 0.0188 | Spike loss 0.8522 | Weight loss 0.1053 | Mean activity 0.2254
DMS Iter 1205 | Accuracy 0.9604 | Perf loss 0.0184 | Spike loss 0.8308 | Weight loss 0.1052 | Mean activity 0.2275
DMS Iter 1210 | Accuracy 0.9629 | Perf loss 0.0186 | Spike loss 0.8230 | Weight loss 0.1053 | Mean activity 0.2253
DMS Iter 1215 | Accuracy 0.9633 | Perf loss 0.0176 | Spike loss 0.8264 | Weight loss 0.1052 | Mean activity 0.2207
DMS Iter 1220 | Accuracy 0.9584 | Perf loss 0.0182 | Spike loss 0.8192 | Weight loss 0.1052 | Mean activity 0.2272
DMS Iter 1225 | Accuracy 0.9641 | Perf loss 0.0170 | Spike loss 0.8148 | Weight loss 0.1053 | Mean activity 0.2230
DMS Iter 1230 | Accuracy 0.9621 | Perf loss 0.0179 | Spike loss 0.8174 | Weight loss 0.1053 | Mean activity 0.2225
DMS Iter 1235 | Accuracy 0.9588 | Perf loss 0.0189 | Spike loss 0.8043 | Weight loss 0.1053 | Mean activity 0.2274
DMS Iter 1240 | Accuracy 0.9631 | Perf loss 0.0177 | Spike loss 0.8178 | Weight loss 0.1057 | Mean activity 0.2176
DMS Iter 1245 | Accuracy 0.9637 | Perf loss 0.0166 | Spike loss 0.8162 | Weight loss 0.1059 | Mean activity 0.2195
DMS Iter 1250 | Accuracy 0.9652 | Perf loss 0.0165 | Spike loss 0.8167 | Weight loss 0.1062 | Mean activity 0.2236
DMS Iter 1255 | Accuracy 0.9590 | Perf loss 0.0190 | Spike loss 0.8093 | Weight loss 0.1065 | Mean activity 0.2233
DMS Iter 1260 | Accuracy 0.9586 | Perf loss 0.0203 | Spike loss 0.8348 | Weight loss 0.1070 | Mean activity 0.2247
DMS Iter 1265 | Accuracy 0.9621 | Perf loss 0.0187 | Spike loss 0.8456 | Weight loss 0.1071 | Mean activity 0.2237
DMS Iter 1270 | Accuracy 0.9582 | Perf loss 0.0199 | Spike loss 0.8415 | Weight loss 0.1074 | Mean activity 0.2230
DMS Iter 1275 | Accuracy 0.9602 | Perf loss 0.0180 | Spike loss 0.8544 | Weight loss 0.1078 | Mean activity 0.2218
DMS Iter 1280 | Accuracy 0.9646 | Perf loss 0.0171 | Spike loss 0.8512 | Weight loss 0.1082 | Mean activity 0.2227
DMS Iter 1285 | Accuracy 0.9566 | Perf loss 0.0218 | Spike loss 0.8360 | Weight loss 0.1081 | Mean activity 0.2254
DMS Iter 1290 | Accuracy 0.9613 | Perf loss 0.0180 | Spike loss 0.8803 | Weight loss 0.1081 | Mean activity 0.2248
DMS Iter 1295 | Accuracy 0.9654 | Perf loss 0.0172 | Spike loss 0.8958 | Weight loss 0.1080 | Mean activity 0.2216
DMS Iter 1300 | Accuracy 0.9609 | Perf loss 0.0178 | Spike loss 0.8779 | Weight loss 0.1077 | Mean activity 0.2268
DMS Iter 1305 | Accuracy 0.9631 | Perf loss 0.0170 | Spike loss 0.8719 | Weight loss 0.1079 | Mean activity 0.2253
DMS Iter 1310 | Accuracy 0.9631 | Perf loss 0.0166 | Spike loss 0.8505 | Weight loss 0.1080 | Mean activity 0.2237
DMS Iter 1315 | Accuracy 0.9652 | Perf loss 0.0168 | Spike loss 0.8249 | Weight loss 0.1082 | Mean activity 0.2165
DMS Iter 1320 | Accuracy 0.9637 | Perf loss 0.0168 | Spike loss 0.8397 | Weight loss 0.1081 | Mean activity 0.2241
DMS Iter 1325 | Accuracy 0.9645 | Perf loss 0.0164 | Spike loss 0.8215 | Weight loss 0.1082 | Mean activity 0.2139
DMS Iter 1330 | Accuracy 0.9619 | Perf loss 0.0181 | Spike loss 0.7980 | Weight loss 0.1082 | Mean activity 0.2095
DMS Iter 1335 | Accuracy 0.9648 | Perf loss 0.0157 | Spike loss 0.7928 | Weight loss 0.1080 | Mean activity 0.2122
DMS Iter 1340 | Accuracy 0.9600 | Perf loss 0.0188 | Spike loss 0.8036 | Weight loss 0.1082 | Mean activity 0.2132
DMS Iter 1345 | Accuracy 0.9596 | Perf loss 0.0176 | Spike loss 0.8062 | Weight loss 0.1082 | Mean activity 0.2136
DMS Iter 1350 | Accuracy 0.9588 | Perf loss 0.0217 | Spike loss 0.8128 | Weight loss 0.1081 | Mean activity 0.2039
DMS Iter 1355 | Accuracy 0.9605 | Perf loss 0.0230 | Spike loss 0.7970 | Weight loss 0.1081 | Mean activity 0.2139
DMS Iter 1360 | Accuracy 0.9605 | Perf loss 0.0243 | Spike loss 0.8114 | Weight loss 0.1088 | Mean activity 0.2005
DMS Iter 1365 | Accuracy 0.9475 | Perf loss 0.0305 | Spike loss 0.8264 | Weight loss 0.1094 | Mean activity 0.2037
DMS Iter 1370 | Accuracy 0.9498 | Perf loss 0.0268 | Spike loss 0.9131 | Weight loss 0.1122 | Mean activity 0.2222
DMS Iter 1375 | Accuracy 0.9627 | Perf loss 0.0213 | Spike loss 0.9920 | Weight loss 0.1148 | Mean activity 0.2216
DMS Iter 1380 | Accuracy 0.9596 | Perf loss 0.0233 | Spike loss 0.9783 | Weight loss 0.1151 | Mean activity 0.2213
DMS Iter 1385 | Accuracy 0.9600 | Perf loss 0.0208 | Spike loss 0.9487 | Weight loss 0.1151 | Mean activity 0.2169
DMS Iter 1390 | Accuracy 0.9598 | Perf loss 0.0205 | Spike loss 0.9231 | Weight loss 0.1150 | Mean activity 0.2160
DMS Iter 1395 | Accuracy 0.9627 | Perf loss 0.0188 | Spike loss 0.8971 | Weight loss 0.1146 | Mean activity 0.2105
DMS Iter 1400 | Accuracy 0.9645 | Perf loss 0.0171 | Spike loss 0.8679 | Weight loss 0.1142 | Mean activity 0.2143
DMS Iter 1405 | Accuracy 0.9613 | Perf loss 0.0194 | Spike loss 0.8436 | Weight loss 0.1140 | Mean activity 0.2101
DMS Iter 1410 | Accuracy 0.9596 | Perf loss 0.0201 | Spike loss 0.8148 | Weight loss 0.1135 | Mean activity 0.1999
DMS Iter 1415 | Accuracy 0.9576 | Perf loss 0.0204 | Spike loss 0.8096 | Weight loss 0.1127 | Mean activity 0.2089
DMS Iter 1420 | Accuracy 0.9631 | Perf loss 0.0182 | Spike loss 0.8169 | Weight loss 0.1127 | Mean activity 0.2145
DMS Iter 1425 | Accuracy 0.9645 | Perf loss 0.0180 | Spike loss 0.8101 | Weight loss 0.1126 | Mean activity 0.2188
DMS Iter 1430 | Accuracy 0.9635 | Perf loss 0.0187 | Spike loss 0.8028 | Weight loss 0.1127 | Mean activity 0.2143
DMS Iter 1435 | Accuracy 0.9596 | Perf loss 0.0191 | Spike loss 0.8156 | Weight loss 0.1135 | Mean activity 0.2054
DMS Iter 1440 | Accuracy 0.9648 | Perf loss 0.0187 | Spike loss 0.8341 | Weight loss 0.1143 | Mean activity 0.2148
DMS Iter 1445 | Accuracy 0.9689 | Perf loss 0.0192 | Spike loss 0.8380 | Weight loss 0.1146 | Mean activity 0.2115
DMS Iter 1450 | Accuracy 0.9600 | Perf loss 0.0202 | Spike loss 0.8435 | Weight loss 0.1150 | Mean activity 0.2156
DMS Iter 1455 | Accuracy 0.9578 | Perf loss 0.0209 | Spike loss 0.8924 | Weight loss 0.1161 | Mean activity 0.2215
DMS Iter 1460 | Accuracy 0.9623 | Perf loss 0.0182 | Spike loss 0.8875 | Weight loss 0.1166 | Mean activity 0.2158
DMS Iter 1465 | Accuracy 0.9580 | Perf loss 0.0208 | Spike loss 0.9091 | Weight loss 0.1172 | Mean activity 0.2109
DMS Iter 1470 | Accuracy 0.9588 | Perf loss 0.0179 | Spike loss 0.9209 | Weight loss 0.1175 | Mean activity 0.2048
DMS Iter 1475 | Accuracy 0.9678 | Perf loss 0.0167 | Spike loss 0.9268 | Weight loss 0.1177 | Mean activity 0.2073
DMS Iter 1480 | Accuracy 0.9641 | Perf loss 0.0176 | Spike loss 0.9199 | Weight loss 0.1178 | Mean activity 0.2083
DMS Iter 1485 | Accuracy 0.9639 | Perf loss 0.0173 | Spike loss 0.8908 | Weight loss 0.1178 | Mean activity 0.2164
DMS Iter 1490 | Accuracy 0.9578 | Perf loss 0.0196 | Spike loss 0.8575 | Weight loss 0.1176 | Mean activity 0.2109
DMS Iter 1495 | Accuracy 0.9639 | Perf loss 0.0165 | Spike loss 0.8525 | Weight loss 0.1177 | Mean activity 0.2117
DMS Iter 1500 | Accuracy 0.9676 | Perf loss 0.0158 | Spike loss 0.8359 | Weight loss 0.1171 | Mean activity 0.2040
DMS Iter 1505 | Accuracy 0.9613 | Perf loss 0.0182 | Spike loss 0.8243 | Weight loss 0.1170 | Mean activity 0.2112
DMS Iter 1510 | Accuracy 0.9592 | Perf loss 0.0174 | Spike loss 0.8195 | Weight loss 0.1173 | Mean activity 0.2075
DMS Iter 1515 | Accuracy 0.9639 | Perf loss 0.0171 | Spike loss 0.8180 | Weight loss 0.1177 | Mean activity 0.2046
DMS Iter 1520 | Accuracy 0.9666 | Perf loss 0.0169 | Spike loss 0.8242 | Weight loss 0.1178 | Mean activity 0.2081
DMS Iter 1525 | Accuracy 0.9609 | Perf loss 0.0172 | Spike loss 0.8321 | Weight loss 0.1184 | Mean activity 0.2012
DMS Iter 1530 | Accuracy 0.9617 | Perf loss 0.0176 | Spike loss 0.8266 | Weight loss 0.1186 | Mean activity 0.2057
DMS Iter 1535 | Accuracy 0.9668 | Perf loss 0.0162 | Spike loss 0.8740 | Weight loss 0.1194 | Mean activity 0.2095
DMS Iter 1540 | Accuracy 0.9607 | Perf loss 0.0216 | Spike loss 0.8984 | Weight loss 0.1201 | Mean activity 0.2094
DMS Iter 1545 | Accuracy 0.9602 | Perf loss 0.0186 | Spike loss 0.8961 | Weight loss 0.1209 | Mean activity 0.2056
DMS Iter 1550 | Accuracy 0.9600 | Perf loss 0.0254 | Spike loss 0.8728 | Weight loss 0.1216 | Mean activity 0.2052
DMS Iter 1555 | Accuracy 0.9598 | Perf loss 0.0212 | Spike loss 0.8902 | Weight loss 0.1228 | Mean activity 0.2036
DMS Iter 1560 | Accuracy 0.9582 | Perf loss 0.0221 | Spike loss 0.9120 | Weight loss 0.1241 | Mean activity 0.2131
DMS Iter 1565 | Accuracy 0.9535 | Perf loss 0.0230 | Spike loss 0.8988 | Weight loss 0.1242 | Mean activity 0.2163
DMS Iter 1570 | Accuracy 0.9563 | Perf loss 0.0234 | Spike loss 0.9593 | Weight loss 0.1249 | Mean activity 0.2230
DMS Iter 1575 | Accuracy 0.9551 | Perf loss 0.0220 | Spike loss 1.0544 | Weight loss 0.1264 | Mean activity 0.2237
DMS Iter 1580 | Accuracy 0.9486 | Perf loss 0.0342 | Spike loss 1.0935 | Weight loss 0.1277 | Mean activity 0.2217
DMS Iter 1585 | Accuracy 0.9621 | Perf loss 0.0195 | Spike loss 1.1112 | Weight loss 0.1280 | Mean activity 0.2241
DMS Iter 1590 | Accuracy 0.9615 | Perf loss 0.0198 | Spike loss 1.0739 | Weight loss 0.1275 | Mean activity 0.2263
DMS Iter 1595 | Accuracy 0.9623 | Perf loss 0.0200 | Spike loss 1.0368 | Weight loss 0.1270 | Mean activity 0.2160
DMS Iter 1600 | Accuracy 0.9563 | Perf loss 0.0193 | Spike loss 1.0115 | Weight loss 0.1268 | Mean activity 0.2078
DMS Iter 1605 | Accuracy 0.9631 | Perf loss 0.0181 | Spike loss 0.9856 | Weight loss 0.1264 | Mean activity 0.2050
DMS Iter 1610 | Accuracy 0.9633 | Perf loss 0.0179 | Spike loss 0.9625 | Weight loss 0.1259 | Mean activity 0.2071
DMS Iter 1615 | Accuracy 0.9633 | Perf loss 0.0175 | Spike loss 0.9268 | Weight loss 0.1253 | Mean activity 0.2140
DMS Iter 1620 | Accuracy 0.9588 | Perf loss 0.0214 | Spike loss 0.9385 | Weight loss 0.1259 | Mean activity 0.2098
DMS Iter 1625 | Accuracy 0.9676 | Perf loss 0.0171 | Spike loss 1.0057 | Weight loss 0.1272 | Mean activity 0.2155
DMS Iter 1630 | Accuracy 0.9580 | Perf loss 0.0209 | Spike loss 1.0043 | Weight loss 0.1276 | Mean activity 0.2070
DMS Iter 1635 | Accuracy 0.9615 | Perf loss 0.0177 | Spike loss 1.0109 | Weight loss 0.1278 | Mean activity 0.2061
DMS Iter 1640 | Accuracy 0.9613 | Perf loss 0.0197 | Spike loss 1.0107 | Weight loss 0.1283 | Mean activity 0.2072
DMS Iter 1645 | Accuracy 0.9680 | Perf loss 0.0160 | Spike loss 1.0006 | Weight loss 0.1282 | Mean activity 0.2114
DMS Iter 1650 | Accuracy 0.9627 | Perf loss 0.0167 | Spike loss 0.9525 | Weight loss 0.1277 | Mean activity 0.2022
DMS Iter 1655 | Accuracy 0.9602 | Perf loss 0.0190 | Spike loss 0.9282 | Weight loss 0.1274 | Mean activity 0.2034
DMS Iter 1660 | Accuracy 0.9660 | Perf loss 0.0161 | Spike loss 0.9054 | Weight loss 0.1271 | Mean activity 0.2124
DMS Iter 1665 | Accuracy 0.9680 | Perf loss 0.0150 | Spike loss 0.8919 | Weight loss 0.1267 | Mean activity 0.2075
DMS Iter 1670 | Accuracy 0.9701 | Perf loss 0.0154 | Spike loss 0.8714 | Weight loss 0.1264 | Mean activity 0.2108
DMS Iter 1675 | Accuracy 0.9650 | Perf loss 0.0182 | Spike loss 0.8894 | Weight loss 0.1261 | Mean activity 0.2144
DMS Iter 1680 | Accuracy 0.9682 | Perf loss 0.0172 | Spike loss 0.8823 | Weight loss 0.1261 | Mean activity 0.2129
DMS Iter 1685 | Accuracy 0.9670 | Perf loss 0.0164 | Spike loss 0.9009 | Weight loss 0.1269 | Mean activity 0.2084
DMS Iter 1690 | Accuracy 0.9664 | Perf loss 0.0156 | Spike loss 0.9136 | Weight loss 0.1277 | Mean activity 0.2111
DMS Iter 1695 | Accuracy 0.9609 | Perf loss 0.0198 | Spike loss 0.9154 | Weight loss 0.1278 | Mean activity 0.2100
DMS Iter 1700 | Accuracy 0.9666 | Perf loss 0.0149 | Spike loss 0.9216 | Weight loss 0.1275 | Mean activity 0.2079
DMS Iter 1705 | Accuracy 0.9674 | Perf loss 0.0154 | Spike loss 0.9027 | Weight loss 0.1271 | Mean activity 0.2043
DMS Iter 1710 | Accuracy 0.9639 | Perf loss 0.0181 | Spike loss 0.8879 | Weight loss 0.1270 | Mean activity 0.2094
DMS Iter 1715 | Accuracy 0.9625 | Perf loss 0.0167 | Spike loss 0.8706 | Weight loss 0.1265 | Mean activity 0.2025
DMS Iter 1720 | Accuracy 0.9629 | Perf loss 0.0181 | Spike loss 0.8676 | Weight loss 0.1265 | Mean activity 0.1984
DMS Iter 1725 | Accuracy 0.9652 | Perf loss 0.0159 | Spike loss 0.8568 | Weight loss 0.1265 | Mean activity 0.1971
DMS Iter 1730 | Accuracy 0.9686 | Perf loss 0.0148 | Spike loss 0.8430 | Weight loss 0.1263 | Mean activity 0.1966
DMS Iter 1735 | Accuracy 0.9646 | Perf loss 0.0195 | Spike loss 0.8329 | Weight loss 0.1260 | Mean activity 0.1966
DMS Iter 1740 | Accuracy 0.9662 | Perf loss 0.0163 | Spike loss 0.8172 | Weight loss 0.1259 | Mean activity 0.1943
DMS Iter 1745 | Accuracy 0.9639 | Perf loss 0.0163 | Spike loss 0.8191 | Weight loss 0.1260 | Mean activity 0.1986
DMS Iter 1750 | Accuracy 0.9688 | Perf loss 0.0151 | Spike loss 0.8169 | Weight loss 0.1259 | Mean activity 0.1961
DMS Iter 1755 | Accuracy 0.9703 | Perf loss 0.0144 | Spike loss 0.8091 | Weight loss 0.1258 | Mean activity 0.1986
DMS Iter 1760 | Accuracy 0.9592 | Perf loss 0.0193 | Spike loss 0.8076 | Weight loss 0.1257 | Mean activity 0.1952
DMS Iter 1765 | Accuracy 0.9645 | Perf loss 0.0172 | Spike loss 0.7946 | Weight loss 0.1256 | Mean activity 0.1965
DMS Iter 1770 | Accuracy 0.9625 | Perf loss 0.0171 | Spike loss 0.7859 | Weight loss 0.1256 | Mean activity 0.1962
DMS Iter 1775 | Accuracy 0.9639 | Perf loss 0.0183 | Spike loss 0.7740 | Weight loss 0.1258 | Mean activity 0.1944
DMS Iter 1780 | Accuracy 0.9662 | Perf loss 0.0156 | Spike loss 0.7814 | Weight loss 0.1262 | Mean activity 0.1977
DMS Iter 1785 | Accuracy 0.9625 | Perf loss 0.0162 | Spike loss 0.8024 | Weight loss 0.1267 | Mean activity 0.1944
DMS Iter 1790 | Accuracy 0.9719 | Perf loss 0.0138 | Spike loss 0.8086 | Weight loss 0.1273 | Mean activity 0.1952
DMS Iter 1795 | Accuracy 0.9672 | Perf loss 0.0151 | Spike loss 0.8010 | Weight loss 0.1276 | Mean activity 0.1906
DMS Iter 1800 | Accuracy 0.9680 | Perf loss 0.0148 | Spike loss 0.7924 | Weight loss 0.1277 | Mean activity 0.1936
DMS Iter 1805 | Accuracy 0.9699 | Perf loss 0.0137 | Spike loss 0.7918 | Weight loss 0.1275 | Mean activity 0.1901
DMS Iter 1810 | Accuracy 0.9672 | Perf loss 0.0152 | Spike loss 0.7841 | Weight loss 0.1272 | Mean activity 0.1907
DMS Iter 1815 | Accuracy 0.9684 | Perf loss 0.0147 | Spike loss 0.7764 | Weight loss 0.1266 | Mean activity 0.1949
DMS Iter 1820 | Accuracy 0.9641 | Perf loss 0.0170 | Spike loss 0.7545 | Weight loss 0.1262 | Mean activity 0.1955
DMS Iter 1825 | Accuracy 0.9654 | Perf loss 0.0175 | Spike loss 0.7483 | Weight loss 0.1261 | Mean activity 0.1876
DMS Iter 1830 | Accuracy 0.9680 | Perf loss 0.0169 | Spike loss 0.7479 | Weight loss 0.1263 | Mean activity 0.1868
DMS Iter 1835 | Accuracy 0.9631 | Perf loss 0.0177 | Spike loss 0.7709 | Weight loss 0.1277 | Mean activity 0.1898
DMS Iter 1840 | Accuracy 0.9662 | Perf loss 0.0155 | Spike loss 0.8514 | Weight loss 0.1294 | Mean activity 0.1925
DMS Iter 1845 | Accuracy 0.9658 | Perf loss 0.0191 | Spike loss 0.8877 | Weight loss 0.1302 | Mean activity 0.1952
DMS Iter 1850 | Accuracy 0.9572 | Perf loss 0.0198 | Spike loss 0.8889 | Weight loss 0.1305 | Mean activity 0.1948
DMS Iter 1855 | Accuracy 0.9645 | Perf loss 0.0170 | Spike loss 0.8962 | Weight loss 0.1304 | Mean activity 0.1962
DMS Iter 1860 | Accuracy 0.9641 | Perf loss 0.0163 | Spike loss 0.9061 | Weight loss 0.1308 | Mean activity 0.1955
DMS Iter 1865 | Accuracy 0.9652 | Perf loss 0.0161 | Spike loss 0.9028 | Weight loss 0.1311 | Mean activity 0.1890
DMS Iter 1870 | Accuracy 0.9613 | Perf loss 0.0167 | Spike loss 0.8960 | Weight loss 0.1311 | Mean activity 0.1936
DMS Iter 1875 | Accuracy 0.9631 | Perf loss 0.0168 | Spike loss 0.8865 | Weight loss 0.1312 | Mean activity 0.1906
DMS Iter 1880 | Accuracy 0.9688 | Perf loss 0.0154 | Spike loss 0.8781 | Weight loss 0.1314 | Mean activity 0.1909
DMS Iter 1885 | Accuracy 0.9648 | Perf loss 0.0166 | Spike loss 0.8425 | Weight loss 0.1312 | Mean activity 0.1904
DMS Iter 1890 | Accuracy 0.9715 | Perf loss 0.0149 | Spike loss 0.8410 | Weight loss 0.1314 | Mean activity 0.1846
DMS Iter 1895 | Accuracy 0.9643 | Perf loss 0.0153 | Spike loss 0.8335 | Weight loss 0.1312 | Mean activity 0.1834
DMS Iter 1900 | Accuracy 0.9684 | Perf loss 0.0146 | Spike loss 0.8169 | Weight loss 0.1310 | Mean activity 0.1786
DMS Iter 1905 | Accuracy 0.9646 | Perf loss 0.0166 | Spike loss 0.8116 | Weight loss 0.1310 | Mean activity 0.1794
DMS Iter 1910 | Accuracy 0.9715 | Perf loss 0.0149 | Spike loss 0.8409 | Weight loss 0.1313 | Mean activity 0.1780
DMS Iter 1915 | Accuracy 0.9611 | Perf loss 0.0183 | Spike loss 0.8462 | Weight loss 0.1313 | Mean activity 0.1881
DMS Iter 1920 | Accuracy 0.9627 | Perf loss 0.0215 | Spike loss 0.8177 | Weight loss 0.1311 | Mean activity 0.1832
DMS Iter 1925 | Accuracy 0.9631 | Perf loss 0.0203 | Spike loss 0.8071 | Weight loss 0.1312 | Mean activity 0.1819
DMS Iter 1930 | Accuracy 0.9697 | Perf loss 0.0135 | Spike loss 0.7942 | Weight loss 0.1309 | Mean activity 0.1781
DMS Iter 1935 | Accuracy 0.9674 | Perf loss 0.0156 | Spike loss 0.7923 | Weight loss 0.1307 | Mean activity 0.1771
DMS Iter 1940 | Accuracy 0.9699 | Perf loss 0.0130 | Spike loss 0.7848 | Weight loss 0.1304 | Mean activity 0.1755
DMS Iter 1945 | Accuracy 0.9660 | Perf loss 0.0151 | Spike loss 0.7687 | Weight loss 0.1304 | Mean activity 0.1756
DMS Iter 1950 | Accuracy 0.9717 | Perf loss 0.0134 | Spike loss 0.7760 | Weight loss 0.1305 | Mean activity 0.1797
DMS Iter 1955 | Accuracy 0.9658 | Perf loss 0.0149 | Spike loss 0.7787 | Weight loss 0.1305 | Mean activity 0.1841
DMS Iter 1960 | Accuracy 0.9670 | Perf loss 0.0184 | Spike loss 0.7772 | Weight loss 0.1306 | Mean activity 0.1826
DMS Iter 1965 | Accuracy 0.9641 | Perf loss 0.0193 | Spike loss 0.7668 | Weight loss 0.1308 | Mean activity 0.1853
DMS Iter 1970 | Accuracy 0.9609 | Perf loss 0.0163 | Spike loss 0.7810 | Weight loss 0.1319 | Mean activity 0.1779
DMS Iter 1975 | Accuracy 0.9707 | Perf loss 0.0163 | Spike loss 0.7959 | Weight loss 0.1327 | Mean activity 0.1697
DMS Iter 1980 | Accuracy 0.9623 | Perf loss 0.0163 | Spike loss 0.8059 | Weight loss 0.1326 | Mean activity 0.1691
DMS Iter 1985 | Accuracy 0.9680 | Perf loss 0.0154 | Spike loss 0.7976 | Weight loss 0.1324 | Mean activity 0.1717
DMS Iter 1990 | Accuracy 0.9697 | Perf loss 0.0140 | Spike loss 0.7889 | Weight loss 0.1322 | Mean activity 0.1747
DMS Iter 1995 | Accuracy 0.9635 | Perf loss 0.0159 | Spike loss 0.7953 | Weight loss 0.1324 | Mean activity 0.1737

(Yang, 2020): Dynamical system analysis for RNN

Implementation of the paper:

  • Yang G R, Wang X J. Artificial neural networks for neuroscientists: A primer. Neuron, 2020, 107(6): 1048-1070.

The original implementation is based on PyTorch: https://github.com/gyyang/nn-brain/blob/master/RNN%2BDynamicalSystemAnalysis.ipynb

[1]:
import brainpy as bp
import brainpy.math as bm
bp.math.set_platform('cpu')
[2]:
import numpy as np
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA

In this tutorial, we will use supervised learning to train a recurrent neural network on a simple perceptual decision making task, and analyze the trained network using dynamical system analysis.

Defining a cognitive task

[3]:
# We will import the task from the neurogym library.
# Please install neurogym:
#
# https://github.com/neurogym/neurogym

import neurogym as ngym
[4]:
# Environment
task = 'PerceptualDecisionMaking-v0'
kwargs = {'dt': 100}
seq_len = 100

# Make supervised dataset
dataset = ngym.Dataset(task, env_kwargs=kwargs, batch_size=16,
                       seq_len=seq_len)

# A sample environment from dataset
env = dataset.env
# Visualize the environment with 2 sample trials
_ = ngym.utils.plot_env(env, num_trials=2, fig_kwargs={'figsize': (8, 6)})
_images/recurrent_networks_Yang_2020_RNN_Analysis_7_0.png
[5]:
input_size = env.observation_space.shape[0]
output_size = env.action_space.n
batch_size = dataset.batch_size

Define a vanilla continuous-time recurrent network

Here we will define a continuous-time neural network but discretize it in time using the Euler method.

\begin{align} \tau \frac{d\mathbf{r}}{dt} = -\mathbf{r}(t) + f(W_r \mathbf{r}(t) + W_x \mathbf{x}(t) + \mathbf{b}_r). \end{align}

This continuous-time system can then be discretized using the Euler method with a time step of \(\Delta t\),

\begin{align} \mathbf{r}(t+\Delta t) = \mathbf{r}(t) + \Delta \mathbf{r} = \mathbf{r}(t) + \frac{\Delta t}{\tau}[-\mathbf{r}(t) + f(W_r \mathbf{r}(t) + W_x \mathbf{x}(t) + \mathbf{b}_r)]. \end{align}
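Note that with the settings used in this tutorial (the task provides \(\Delta t = 100\) ms and the model uses \(\tau = 100\) ms), the discretization factor is

\begin{align} \alpha = \frac{\Delta t}{\tau} = \frac{100\ \mathrm{ms}}{100\ \mathrm{ms}} = 1, \end{align}

so the update reduces to the standard discrete-time RNN \(\mathbf{r}(t+\Delta t) = f(W_r \mathbf{r}(t) + W_x \mathbf{x}(t) + \mathbf{b}_r)\).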
[6]:
class RNN(bp.DynamicalSystem):
  def __init__(self,
               num_input,
               num_hidden,
               num_output,
               num_batch,
               dt=None, seed=None,
               w_ir=bp.init.KaimingNormal(scale=1.),
               w_rr=bp.init.KaimingNormal(scale=1.),
               w_ro=bp.init.KaimingNormal(scale=1.)):
    super(RNN, self).__init__()

    # parameters
    self.tau = 100
    self.num_batch = num_batch
    self.num_input = num_input
    self.num_hidden = num_hidden
    self.num_output = num_output
    if dt is None:
      self.alpha = 1
    else:
      self.alpha = dt / self.tau
    self.rng = bm.random.RandomState(seed=seed)

    # input weight
    self.w_ir = bm.TrainVar(bp.init.parameter(w_ir, (num_input, num_hidden)))

    # recurrent weight
    bound = 1 / num_hidden ** 0.5
    self.w_rr = bm.TrainVar(bp.init.parameter(w_rr, (num_hidden, num_hidden)))
    self.b_rr = bm.TrainVar(self.rng.uniform(-bound, bound, num_hidden))

    # readout weight
    self.w_ro = bm.TrainVar(bp.init.parameter(w_ro, (num_hidden, num_output)))
    self.b_ro = bm.TrainVar(self.rng.uniform(-bound, bound, num_output))

    # variables
    self.h = bm.Variable(bm.zeros((num_batch, num_hidden)))
    self.o = bm.Variable(bm.zeros((num_batch, num_output)))

  def cell(self, x, h):
    ins = x @ self.w_ir + h @ self.w_rr + self.b_rr
    state = h * (1 - self.alpha) + ins * self.alpha
    return bm.relu(state)

  def readout(self, h):
    return h @ self.w_ro + self.b_ro

  def update(self, x):
    self.h.value = self.cell(x, self.h)
    self.o.value = self.readout(self.h)
    return self.h.value, self.o.value

  def predict(self, xs):
    self.h[:] = 0.
    return bm.for_loop(self.update, xs, dyn_vars=[self.h, self.o])

  def loss(self, xs, ys):
    hs, os = self.predict(xs)
    os = os.reshape((-1, os.shape[-1]))
    loss = bp.losses.cross_entropy_loss(os, ys.flatten())
    return loss, os

Train the recurrent network on the decision-making task

[7]:
# Instantiate the network and print information
hidden_size = 64
with bm.training_environment():
    net = RNN(num_input=input_size,
              num_hidden=hidden_size,
              num_output=output_size,
              num_batch=batch_size,
              dt=env.dt)
[8]:
# Adam optimizer
opt = bp.optim.Adam(lr=0.001, train_vars=net.train_vars().unique())

# gradient function
grad_f = bm.grad(net.loss,
                 child_objs=net,
                 grad_vars=net.train_vars().unique(),
                 return_value=True,
                 has_aux=True)

# training function
@bm.jit
@bm.to_object(child_objs=(grad_f, opt))
def train(xs, ys):
  grads, loss, os = grad_f(xs, ys)
  opt.update(grads)
  return loss, os
[9]:
running_acc = 0
running_loss = 0
for i in range(1500):
  inputs, labels_np = dataset()
  inputs = bm.asarray(inputs)
  labels = bm.asarray(labels_np)
  loss, outputs = train(inputs, labels)
  running_loss += loss
  # Compute performance
  output_np = np.argmax(bm.as_numpy(outputs), axis=-1).flatten()
  labels_np = labels_np.flatten()
  ind = labels_np > 0  # Only analyze time points when target is not fixation
  running_acc += np.mean(labels_np[ind] == output_np[ind])
  if i % 100 == 99:
    running_loss /= 100
    running_acc /= 100
    print('Step {}, Loss {:0.4f}, Acc {:0.3f}'.format(i + 1, running_loss, running_acc))
    running_loss = 0
    running_acc = 0
Step 100, Loss 0.0994, Acc 0.458
Step 200, Loss 0.0201, Acc 0.868
Step 300, Loss 0.0155, Acc 0.876
Step 400, Loss 0.0136, Acc 0.883
Step 500, Loss 0.0129, Acc 0.878
Step 600, Loss 0.0127, Acc 0.881
Step 700, Loss 0.0122, Acc 0.884
Step 800, Loss 0.0123, Acc 0.884
Step 900, Loss 0.0116, Acc 0.885
Step 1000, Loss 0.0116, Acc 0.885
Step 1100, Loss 0.0111, Acc 0.888
Step 1200, Loss 0.0111, Acc 0.885
Step 1300, Loss 0.0108, Acc 0.891
Step 1400, Loss 0.0111, Acc 0.885
Step 1500, Loss 0.0106, Acc 0.891

Visualize neural activity for sample trials

We will run the network for 100 sample trials, then visualize the neural activity trajectories in PCA space.

[10]:
env.reset(no_step=True)
perf = 0
num_trial = 100
activity_dict = {}
trial_infos = {}
for i in range(num_trial):
    env.new_trial()
    ob, gt = env.ob, env.gt
    inputs = bm.asarray(ob[:, np.newaxis, :])
    rnn_activity, action_pred = net.predict(inputs)
    rnn_activity = bm.as_numpy(rnn_activity)[:, 0, :]
    activity_dict[i] = rnn_activity
    trial_infos[i] = env.trial

# Concatenate activity for PCA
activity = np.concatenate(list(activity_dict[i] for i in range(num_trial)), axis=0)
print('Shape of the neural activity: (Time points, Neurons): ', activity.shape)

# Print trial information
for i in range(5):
    print('Trial ', i, trial_infos[i])
Shape of the neural activity: (Time points, Neurons):  (2200, 64)
Trial  0 {'ground_truth': 0, 'coh': 25.6}
Trial  1 {'ground_truth': 1, 'coh': 0.0}
Trial  2 {'ground_truth': 0, 'coh': 6.4}
Trial  3 {'ground_truth': 0, 'coh': 6.4}
Trial  4 {'ground_truth': 1, 'coh': 0.0}
[11]:
pca = PCA(n_components=2)
pca.fit(activity)
[11]:
PCA(n_components=2)

Transform individual trials and visualize them in PC space, colored by the ground truth. We see that the neural activity is organized by the stimulus ground truth along PC1.

[12]:
plt.rcdefaults()
fig, (ax1, ax2) = plt.subplots(1, 2, sharey=True, sharex=True, figsize=(12, 5))
for i in range(num_trial):
    activity_pc = pca.transform(activity_dict[i])
    trial = trial_infos[i]
    color = 'red' if trial['ground_truth'] == 0 else 'blue'
    _ = ax1.plot(activity_pc[:, 0], activity_pc[:, 1], 'o-', color=color)
    if i < 5:
        _ = ax2.plot(activity_pc[:, 0], activity_pc[:, 1], 'o-', color=color)

ax1.set_xlabel('PC 1')
ax1.set_ylabel('PC 2')
plt.show()
_images/recurrent_networks_Yang_2020_RNN_Analysis_20_0.png

Search for approximate fixed points

Here we search for approximate fixed points and visualize them in the same PC space. In a generic dynamical system,

\begin{align} \frac{d\mathbf{x}}{dt} = F(\mathbf{x}), \end{align}

We can search for fixed points by solving the optimization problem

\begin{align} \mathrm{argmin}_{\mathbf{x}} |F(\mathbf{x})|^2. \end{align}
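Below we use BrainPy's SlowPointFinder to run this optimization on the trained network's discrete-time update. As a self-contained illustration of the idea (a toy one-dimensional map, not the trained RNN), the same objective can be minimized with plain gradient descent:

import jax
import jax.numpy as jnp

# Toy discrete-time map: fixed points of F satisfy F(h) == h.
F = lambda h: jnp.tanh(2. * h)
# Squared one-step "speed"; it is zero exactly at fixed points of F.
speed = lambda h: jnp.sum((F(h) - h) ** 2)

h = jnp.array([0.6])
for _ in range(200):
  # plain gradient descent on the speed
  h = h - 0.1 * jax.grad(speed)(h)
print(h, F(h))  # h has converged onto a (non-trivial) fixed point of F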
[13]:
f_cell = lambda h: net.cell(bm.asarray([1, 0.5, 0.5]), h)
[14]:
fp_candidates = bm.vstack([activity_dict[i] for i in range(num_trial)])
fp_candidates.shape
[14]:
(2200, 64)
[15]:
finder = bp.analysis.SlowPointFinder(f_cell=f_cell, f_type='discrete')
finder.find_fps_with_gd_method(
  candidates=fp_candidates,
  tolerance=1e-5, num_batch=200,
    optimizer=bp.optim.Adam(lr=bp.optim.ExponentialDecay(0.01, 1, 0.9999)),
)
finder.filter_loss(tolerance=1e-5)
finder.keep_unique(tolerance=0.03)
finder.exclude_outliers(0.1)
fixed_points = finder.fixed_points
Optimizing with Adam(lr=ExponentialDecay(0.01, decay_steps=1, decay_rate=0.9999), last_call=-1), beta1=0.9, beta2=0.999, eps=1e-08) to find fixed points:
    Batches 1-200 in 0.74 sec, Training loss 0.0012789472
    Batches 201-400 in 0.69 sec, Training loss 0.0005247793
    Batches 401-600 in 0.68 sec, Training loss 0.0003411663
    Batches 601-800 in 0.69 sec, Training loss 0.0002612533
    Batches 801-1000 in 0.69 sec, Training loss 0.0002112046
    Batches 1001-1200 in 0.73 sec, Training loss 0.0001740992
    Batches 1201-1400 in 0.70 sec, Training loss 0.0001455246
    Batches 1401-1600 in 0.74 sec, Training loss 0.0001233126
    Batches 1601-1800 in 0.70 sec, Training loss 0.0001058684
    Batches 1801-2000 in 0.77 sec, Training loss 0.0000921902
    Batches 2001-2200 in 0.72 sec, Training loss 0.0000814899
    Batches 2201-2400 in 0.57 sec, Training loss 0.0000731918
    Batches 2401-2600 in 0.75 sec, Training loss 0.0000667152
    Batches 2601-2800 in 0.74 sec, Training loss 0.0000616670
    Batches 2801-3000 in 0.70 sec, Training loss 0.0000576638
    Batches 3001-3200 in 0.56 sec, Training loss 0.0000543614
    Batches 3201-3400 in 0.60 sec, Training loss 0.0000515535
    Batches 3401-3600 in 0.61 sec, Training loss 0.0000491180
    Batches 3601-3800 in 0.71 sec, Training loss 0.0000469084
    Batches 3801-4000 in 0.75 sec, Training loss 0.0000448209
    Batches 4001-4200 in 0.67 sec, Training loss 0.0000428211
    Batches 4201-4400 in 0.73 sec, Training loss 0.0000408868
    Batches 4401-4600 in 0.71 sec, Training loss 0.0000390110
    Batches 4601-4800 in 0.77 sec, Training loss 0.0000371877
    Batches 4801-5000 in 0.58 sec, Training loss 0.0000354225
    Batches 5001-5200 in 0.58 sec, Training loss 0.0000337130
    Batches 5201-5400 in 0.69 sec, Training loss 0.0000320665
    Batches 5401-5600 in 0.65 sec, Training loss 0.0000304907
    Batches 5601-5800 in 0.76 sec, Training loss 0.0000289806
    Batches 5801-6000 in 0.64 sec, Training loss 0.0000275472
    Batches 6001-6200 in 0.55 sec, Training loss 0.0000261840
    Batches 6201-6400 in 0.60 sec, Training loss 0.0000248957
    Batches 6401-6600 in 0.68 sec, Training loss 0.0000236831
    Batches 6601-6800 in 0.67 sec, Training loss 0.0000225569
    Batches 6801-7000 in 0.71 sec, Training loss 0.0000215175
    Batches 7001-7200 in 0.66 sec, Training loss 0.0000205547
    Batches 7201-7400 in 0.70 sec, Training loss 0.0000196768
    Batches 7401-7600 in 0.68 sec, Training loss 0.0000188892
    Batches 7601-7800 in 0.71 sec, Training loss 0.0000181973
    Batches 7801-8000 in 0.67 sec, Training loss 0.0000175776
    Batches 8001-8200 in 0.72 sec, Training loss 0.0000170112
    Batches 8201-8400 in 0.75 sec, Training loss 0.0000165215
    Batches 8401-8600 in 0.54 sec, Training loss 0.0000160938
    Batches 8601-8800 in 0.60 sec, Training loss 0.0000157249
    Batches 8801-9000 in 0.52 sec, Training loss 0.0000154070
    Batches 9001-9200 in 0.68 sec, Training loss 0.0000151405
    Batches 9201-9400 in 0.54 sec, Training loss 0.0000149151
    Batches 9401-9600 in 0.53 sec, Training loss 0.0000147287
    Batches 9601-9800 in 0.53 sec, Training loss 0.0000145786
    Batches 9801-10000 in 0.53 sec, Training loss 0.0000144657
Excluding fixed points with squared speed above tolerance 1e-05:
    Kept 1157/2200 fixed points with tolerance under 1e-05.
Excluding non-unique fixed points:
    Kept 10/1157 unique fixed points with uniqueness tolerance 0.03.
Excluding outliers:
    Kept 8/10 fixed points with within outlier tolerance 0.1.

Visualize the approximate fixed points found above.

We see that the search found an approximate line attractor, corresponding to PC1, along which evidence is integrated during the stimulus period.

[16]:
# Plot in the same space as activity
plt.figure(figsize=(10, 5))
for i in range(10):
    activity_pc = pca.transform(activity_dict[i])
    trial = trial_infos[i]
    color = 'red' if trial['ground_truth'] == 0 else 'blue'
    plt.plot(activity_pc[:, 0], activity_pc[:, 1], 'o-', color=color, alpha=0.1)

# Fixed points are shown in cross
fixedpoints_pc = pca.transform(fixed_points)
plt.plot(fixedpoints_pc[:, 0], fixedpoints_pc[:, 1], 'x', label='fixed points')

plt.xlabel('PC 1')
plt.ylabel('PC 2')
plt.legend()
plt.show()
_images/recurrent_networks_Yang_2020_RNN_Analysis_27_0.png

Computing the Jacobian and finding the line attractor

[17]:
from jax import jacobian
[19]:
dFdh = jacobian(f_cell)(fixed_points[2])

eigval, eigvec = np.linalg.eig(bm.as_numpy(dFdh))
[20]:
# Plot distribution of eigenvalues in a 2-d real-imaginary plot
plt.figure()
plt.scatter(np.real(eigval), np.imag(eigval))
plt.plot([1, 1], [-1, 1], '--')
plt.xlabel('Real')
plt.ylabel('Imaginary')
plt.show()
_images/recurrent_networks_Yang_2020_RNN_Analysis_31_0.png
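An eigenvalue close to 1 marks a direction along which the discrete dynamics barely move, i.e. the slow mode of the line attractor. As a hypothetical follow-up (reusing eigval, eigvec and pca from the cells above), one could extract that eigenvector and check how well it aligns with PC1:

# Eigenvector whose eigenvalue is closest to 1: the slow, integration-like mode.
# We take the real part; the slow eigenvalue is typically real here.
slow_idx = np.argmin(np.abs(eigval - 1.))
slow_vec = np.real(eigvec[:, slow_idx])
slow_vec /= np.linalg.norm(slow_vec)

# First principal component of the trial activity.
pc1 = pca.components_[0] / np.linalg.norm(pca.components_[0])

# Expected to be close to 1 when the line attractor lies along PC1.
print('|cos angle| between slow eigenvector and PC1:', np.abs(np.dot(slow_vec, pc1)))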

(Bellec, et al., 2020): eprop for Evidence Accumulation Task

Implementation of the paper:

  • Bellec, G., Scherr, F., Subramoney, A., Hajek, E., Salaj, D., Legenstein, R., & Maass, W. (2020). A solution to the learning dilemma for recurrent networks of spiking neurons. Nature communications, 11(1), 1-15.

[1]:
import matplotlib.pyplot as plt
import numpy as np
import brainpy as bp
import brainpy.math as bm
from jax.lax import stop_gradient
from matplotlib import patches

bm.set_environment(bm.training_mode, dt=1.)
[2]:
# training parameters
n_batch = 128  # batch size

# neuron model and simulation parameters
reg_f = 1.  # regularization coefficient for firing rate
reg_rate = 10  # target firing rate for regularization [Hz]

# Experiment parameters
t_cue_spacing = 150  # distance between two consecutive cues in ms

# Frequencies
input_f0 = 40. / 1000.  # poisson firing rate of input neurons in khz
regularization_f0 = reg_rate / 1000.  # mean target network firing frequency
[3]:
class EligSNN(bp.Network):
  def __init__(self, num_in, num_rec, num_out, eprop=True, tau_a=2e3, tau_v=2e1):
    super(EligSNN, self).__init__()

    # parameters
    self.num_in = num_in
    self.num_rec = num_rec
    self.num_out = num_out
    self.eprop = eprop

    # neurons
    self.i = bp.neurons.InputGroup(num_in)
    self.o = bp.neurons.LeakyIntegrator(num_out, tau=20, mode=bm.training_mode)

    n_regular = int(num_rec / 2)
    n_adaptive = num_rec - n_regular
    beta1 = bm.exp(- bm.get_dt() / tau_a)
    beta2 = 1.7 * (1 - beta1) / (1 - bm.exp(-1 / tau_v))
    beta = bm.concatenate([bm.ones(n_regular), bm.ones(n_adaptive) * beta2])
    self.r = bp.neurons.ALIFBellec2020(
      num_rec,
      V_rest=0.,
      tau_ref=5.,
      V_th=0.6,
      tau_a=tau_a,
      tau=tau_v,
      beta=beta,
      V_initializer=bp.init.ZeroInit(),
      a_initializer=bp.init.ZeroInit(),
      mode=bm.training_mode, eprop=eprop
    )

    # synapses
    self.i2r = bp.layers.Dense(num_in, num_rec,
                               W_initializer=bp.init.KaimingNormal(),
                               b_initializer=None)
    self.i2r.W *= tau_v
    self.r2r = bp.layers.Dense(num_rec, num_rec,
                               W_initializer=bp.init.KaimingNormal(),
                               b_initializer=None)
    self.r2r.W *= tau_v
    self.r2o = bp.layers.Dense(num_rec, num_out,
                               W_initializer=bp.init.KaimingNormal(),
                               b_initializer=None)

  def update(self, shared, x):
    self.r.input += self.i2r(shared, x)
    z = stop_gradient(self.r.spike.value) if self.eprop else self.r.spike.value
    self.r.input += self.r2r(shared, z)
    self.r(shared)
    self.o.input += self.r2o(shared, self.r.spike.value)
    self.o(shared)
    return self.o.V.value

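The eprop flag above decides whether the recurrent spikes are wrapped in jax.lax.stop_gradient, which blocks backpropagation through the recurrent spike pathway; this is the approximation that distinguishes e-prop from full BPTT. A minimal, standalone illustration of what stop_gradient does to a gradient (not part of the model):

import jax
import jax.numpy as jnp
from jax.lax import stop_gradient

# stop_gradient removes its argument from the backward pass, so the
# product-rule term through it vanishes.
f = lambda w: jnp.sum(w * stop_gradient(w))
print(jax.grad(f)(jnp.ones(3)))  # [1. 1. 1.] rather than [2. 2. 2.]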
[4]:
net = EligSNN(num_in=40, num_rec=100, num_out=2, eprop=False)
WARNING:jax._src.lib.xla_bridge:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
[5]:
@bp.tools.numba_jit
def generate_click_task_data(batch_size, seq_len, n_neuron, recall_duration, prob, f0=0.5,
                             n_cues=7, t_cue=100, t_interval=150, n_input_symbols=4):
  n_channel = n_neuron // n_input_symbols

  # assign input spike probabilities
  probs = np.where(np.random.random((batch_size, 1)) < 0.5, prob, 1 - prob)

  # for each example in batch, draw which cues are going to be active (left or right)
  cue_assignments = np.asarray(np.random.random(n_cues) > probs, dtype=np.int_)

  # generate input nums - 0: left, 1: right, 2:recall, 3:background noise
  input_nums = 3 * np.ones((batch_size, seq_len), dtype=np.int_)
  input_nums[:, :n_cues] = cue_assignments
  input_nums[:, -1] = 2

  # generate input spikes
  input_spike_prob = np.zeros((batch_size, seq_len, n_neuron))
  d_silence = t_interval - t_cue
  for b in range(batch_size):
    for k in range(n_cues):
      # input channels only fire when they are selected (left or right)
      c = cue_assignments[b, k]
      # reverse order of cues
      i_seq = d_silence + k * t_interval
      i_neu = c * n_channel
      input_spike_prob[b, i_seq:i_seq + t_cue, i_neu:i_neu + n_channel] = f0
  # recall cue
  input_spike_prob[:, -recall_duration:, 2 * n_channel:3 * n_channel] = f0
  # background noise
  input_spike_prob[:, :, 3 * n_channel:] = f0 / 4.
  input_spikes = input_spike_prob > np.random.rand(*input_spike_prob.shape)

  # generate targets
  target_mask = np.zeros((batch_size, seq_len), dtype=np.bool_)
  target_mask[:, -1] = True
  target_nums = (np.sum(cue_assignments, axis=1) > n_cues / 2).astype(np.int_)
  return input_spikes, input_nums, target_nums, target_mask
[6]:
def get_data(batch_size, n_in, t_interval, f0):
  # used for obtaining a new randomly generated batch of examples
  def generate_data():
    seq_len = int(t_interval * 7 + 1200)
    for _ in range(10):
      spk_data, _, target_data, _ = generate_click_task_data(
        batch_size=batch_size, seq_len=seq_len, n_neuron=n_in, recall_duration=150,
        prob=0.3, t_cue=100, n_cues=7, t_interval=t_interval, f0=f0, n_input_symbols=4
      )
      yield spk_data, target_data

  return generate_data
[7]:
def loss_fun(predicts, targets):
  predicts, mon = predicts

  # we only use network output at the end for classification
  output_logits = predicts[:, -t_cue_spacing:]

  # Define the accuracy
  y_predict = bm.argmax(bm.mean(output_logits, axis=1), axis=1)
  accuracy = bm.equal(targets, y_predict).astype(bm.dftype()).mean()

  # loss function
  tiled_targets = bm.tile(bm.expand_dims(targets, 1), (1, t_cue_spacing))
  loss_cls = bm.mean(bp.losses.cross_entropy_loss(output_logits, tiled_targets))

  # Firing rate regularization:
  # For historical reasons we often use this regularization,
  # but the other one is easier to implement in an "online" fashion by a single agent.
  av = bm.mean(mon['r.spike'], axis=(0, 1)) / bm.get_dt()
  loss_reg_f = bm.sum(bm.square(av - regularization_f0) * reg_f)

  # Aggregate the losses #
  loss = loss_reg_f + loss_cls
  loss_res = {'loss': loss, 'loss reg': loss_reg_f, 'accuracy': accuracy}
  return bm.as_jax(loss), loss_res

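For reference, the firing-rate regularizer computed above is

\begin{align} \mathcal{L}_{\mathrm{reg}} = c_{\mathrm{reg}} \sum_{j} \left( \bar{z}_j - f_0 \right)^2, \qquad \bar{z}_j = \frac{1}{B\,T\,\Delta t} \sum_{b,t} z^{b,t}_j, \end{align}

where \(\bar{z}_j\) is the firing rate of recurrent neuron \(j\) averaged over batch and time (in kHz, since \(\Delta t\) is in ms), \(f_0\) is regularization_f0 (10 Hz) and \(c_{\mathrm{reg}}\) is reg_f.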
Training

[8]:
# Training
trainer = bp.BPTT(
  net,
  loss_fun,
  loss_has_aux=True,
  optimizer=bp.optim.Adam(lr=0.01),
  monitors={'r.spike': net.r.spike},
)
trainer.fit(get_data(n_batch, n_in=net.num_in, t_interval=t_cue_spacing, f0=input_f0),
            num_epoch=40,
            num_report=10)
Train 10 steps, use 9.8286 s, loss 0.7034175992012024, accuracy 0.4281249940395355, loss reg 0.004963034763932228
Train 20 steps, use 7.3619 s, loss 0.7023941874504089, accuracy 0.4046874940395355, loss reg 0.004842016380280256
Train 30 steps, use 7.3519 s, loss 0.6943514943122864, accuracy 0.596875011920929, loss reg 0.004735519178211689
Train 40 steps, use 7.3225 s, loss 0.6893576383590698, accuracy 0.5882812738418579, loss reg 0.004621113184839487
Train 50 steps, use 7.4316 s, loss 0.7021943926811218, accuracy 0.45703125, loss reg 0.0045189810916781425
Train 60 steps, use 7.3455 s, loss 0.7054350972175598, accuracy 0.3843750059604645, loss reg 0.004483774304389954
Train 70 steps, use 7.5527 s, loss 0.6898735761642456, accuracy 0.56640625, loss reg 0.0044373138807713985
Train 80 steps, use 7.3871 s, loss 0.69098299741745, accuracy 0.5992187857627869, loss reg 0.004469707142561674
Train 90 steps, use 7.3317 s, loss 0.6886122822761536, accuracy 0.6968750357627869, loss reg 0.004408594686537981
Train 100 steps, use 7.4113 s, loss 0.6826426386833191, accuracy 0.749218761920929, loss reg 0.004439009819179773
Train 110 steps, use 7.4938 s, loss 0.6881702542304993, accuracy 0.633593738079071, loss reg 0.004656294826418161
Train 120 steps, use 7.2987 s, loss 0.6720143556594849, accuracy 0.772656261920929, loss reg 0.004621779080480337
Train 130 steps, use 7.4011 s, loss 0.6499411463737488, accuracy 0.8023437857627869, loss reg 0.004800095688551664
Train 140 steps, use 7.3568 s, loss 0.6571835875511169, accuracy 0.710156261920929, loss reg 0.005033737048506737
Train 150 steps, use 7.1787 s, loss 0.6523110866546631, accuracy 0.7250000238418579, loss reg 0.00559642817825079
Train 160 steps, use 7.1760 s, loss 0.608909010887146, accuracy 0.821093738079071, loss reg 0.005754651036113501
Train 170 steps, use 7.2558 s, loss 0.5620784163475037, accuracy 0.844531238079071, loss reg 0.006214872468262911
Train 180 steps, use 7.2844 s, loss 0.5986811518669128, accuracy 0.750781238079071, loss reg 0.006925530731678009
Train 190 steps, use 7.4182 s, loss 0.544775664806366, accuracy 0.848437488079071, loss reg 0.006775358226150274
Train 200 steps, use 7.4347 s, loss 0.5496039390563965, accuracy 0.831250011920929, loss reg 0.007397319655865431
Train 210 steps, use 7.4629 s, loss 0.5447431206703186, accuracy 0.813281238079071, loss reg 0.006942986976355314
Train 220 steps, use 7.3833 s, loss 0.5015143752098083, accuracy 0.85546875, loss reg 0.0072592394426465034
Train 230 steps, use 7.4328 s, loss 0.5421426296234131, accuracy 0.854687511920929, loss reg 0.0077950432896614075
Train 240 steps, use 7.4438 s, loss 0.4893417954444885, accuracy 0.875781238079071, loss reg 0.007711453828960657
Train 250 steps, use 7.3671 s, loss 0.48076897859573364, accuracy 0.8203125, loss reg 0.006535724736750126
Train 260 steps, use 7.3650 s, loss 0.46686863899230957, accuracy 0.8617187738418579, loss reg 0.007533709984272718
Train 270 steps, use 7.4364 s, loss 0.4155255854129791, accuracy 0.9156250357627869, loss reg 0.007653679233044386
Train 280 steps, use 7.5231 s, loss 0.5252839922904968, accuracy 0.8070312738418579, loss reg 0.0074622235260903835
Train 290 steps, use 7.4474 s, loss 0.4552551209926605, accuracy 0.840624988079071, loss reg 0.007330414839088917
Train 300 steps, use 7.3283 s, loss 0.4508514404296875, accuracy 0.8617187738418579, loss reg 0.007133393082767725
Train 310 steps, use 7.3453 s, loss 0.38369470834732056, accuracy 0.925000011920929, loss reg 0.007317659445106983
Train 320 steps, use 7.2690 s, loss 0.4067922532558441, accuracy 0.9125000238418579, loss reg 0.00806522835046053
Train 330 steps, use 7.4205 s, loss 0.4162019193172455, accuracy 0.8843750357627869, loss reg 0.0077808513306081295
Train 340 steps, use 7.3557 s, loss 0.42762160301208496, accuracy 0.8695312738418579, loss reg 0.00763324648141861
Train 350 steps, use 7.4115 s, loss 0.38524919748306274, accuracy 0.8984375, loss reg 0.00784334447234869
Train 360 steps, use 7.3625 s, loss 0.36755821108818054, accuracy 0.905468761920929, loss reg 0.007520874030888081
Train 370 steps, use 7.3553 s, loss 0.4653354585170746, accuracy 0.839062511920929, loss reg 0.007807845715433359
Train 380 steps, use 7.4220 s, loss 0.46386781334877014, accuracy 0.828906238079071, loss reg 0.007937172427773476
Train 390 steps, use 7.3668 s, loss 0.5748793482780457, accuracy 0.75, loss reg 0.007791445590555668
Train 400 steps, use 7.2759 s, loss 0.3976801037788391, accuracy 0.8812500238418579, loss reg 0.007725016679614782

Visualization

[9]:
# visualization
dataset, _ = next(get_data(20, n_in=net.num_in, t_interval=t_cue_spacing, f0=input_f0)())
runner = bp.DSTrainer(net, monitors={'spike': net.r.spike})
outs = runner.predict(dataset, reset_state=True)

for i in range(10):
  fig, gs = bp.visualize.get_figure(3, 1, 2., 6.)
  ax_inp = fig.add_subplot(gs[0, 0])
  ax_rec = fig.add_subplot(gs[1, 0])
  ax_out = fig.add_subplot(gs[2, 0])

  data = dataset[i]
  # insert empty row
  n_channel = data.shape[1] // 4
  zero_fill = np.zeros((data.shape[0], int(n_channel / 2)))
  data = np.concatenate((data[:, 3 * n_channel:], zero_fill,
                         data[:, 2 * n_channel:3 * n_channel], zero_fill,
                         data[:, :n_channel], zero_fill,
                         data[:, n_channel:2 * n_channel]), axis=1)
  ax_inp.set_yticklabels([])
  ax_inp.add_patch(patches.Rectangle((0, 2 * n_channel + 2 * int(n_channel / 2)),
                                     data.shape[0], n_channel,
                                     facecolor="red", alpha=0.1))
  ax_inp.add_patch(patches.Rectangle((0, 3 * n_channel + 3 * int(n_channel / 2)),
                                     data.shape[0], n_channel,
                                     facecolor="blue", alpha=0.1))
  bp.visualize.raster_plot(runner.mon.ts, data, ax=ax_inp, marker='|')
  ax_inp.set_ylabel('Input Activity')
  ax_inp.set_xticklabels([])
  ax_inp.set_xticks([])

  # spiking activity
  bp.visualize.raster_plot(runner.mon.ts, runner.mon['spike'][i], ax=ax_rec, marker='|')
  ax_rec.set_ylabel('Spiking Activity')
  ax_rec.set_xticklabels([])
  ax_rec.set_xticks([])
  # decision activity
  ax_out.set_yticks([0, 0.5, 1])
  ax_out.set_ylabel('Output Activity')
  ax_out.plot(runner.mon.ts, outs[i, :, 0], label='Readout 0', alpha=0.7)
  ax_out.plot(runner.mon.ts, outs[i, :, 1], label='Readout 1', alpha=0.7)
  ax_out.set_xticklabels([])
  ax_out.set_xticks([])
  ax_out.set_xlabel('Time [ms]')
  plt.legend()
  plt.show()
_images/recurrent_networks_Bellec_2020_eprop_evidence_accumulation_12_1.png
_images/recurrent_networks_Bellec_2020_eprop_evidence_accumulation_12_2.png
_images/recurrent_networks_Bellec_2020_eprop_evidence_accumulation_12_3.png
_images/recurrent_networks_Bellec_2020_eprop_evidence_accumulation_12_4.png
_images/recurrent_networks_Bellec_2020_eprop_evidence_accumulation_12_5.png
_images/recurrent_networks_Bellec_2020_eprop_evidence_accumulation_12_6.png
_images/recurrent_networks_Bellec_2020_eprop_evidence_accumulation_12_7.png
_images/recurrent_networks_Bellec_2020_eprop_evidence_accumulation_12_8.png
_images/recurrent_networks_Bellec_2020_eprop_evidence_accumulation_12_9.png
_images/recurrent_networks_Bellec_2020_eprop_evidence_accumulation_12_10.png

(Bouchacourt & Buschman, 2019) Flexible Working Memory Model

Implementation of:

  • Bouchacourt, Flora, and Timothy J. Buschman. “A flexible model of working memory.” Neuron 103.1 (2019): 147-160.

Author:

[1]:
import matplotlib.pyplot as plt
[2]:
import brainpy as bp
import brainpy.math as bm
[3]:
# increase in order to run multiple trials with the same network
num_trials = 1
num_item_to_load = 6
[4]:
# Parameters for network architecture
# ------------------------------------

num_sensory_neuron = 512  # Number of recurrent neurons per sensory network
num_sensory_pool = 8  # Number of ring-like sensory network
num_all_sensory = num_sensory_pool * num_sensory_neuron
num_all_random = 1024  # Number of neuron in the random network
fI_slope = 0.4  # slope of the non-linear f-I function
bias = 0.  # bias in the neuron firing response (cf page 1 right column of Burak, Fiete 2012)
tau = 10.  # Synaptic time constant [ms]
init_range = 0.01  # Range to randomly initialize synaptic variables
[5]:
# Parameters for sensory network
# -------------------------------

k2 = 0.25  # width of negative surround suppression
k1 = 1.  # width of positive amplification
A = 2.  # amplitude of weight function
lambda_ = 0.28  # baseline of weight matrix for recurrent network
[6]:
# Parameters for interaction of
# random network <-> sensory network
# -----------------------------------

forward_balance = -1.  # if -1, perfect feed-forward balance from SN to RN
backward_balance = -1.  # if -1, perfect feedback balance from RN to SN
alpha = 2.1  # parameter used to compute the feedforward weight, before balancing
beta = 0.2  # parameter used to compute the feedback weight, before balancing
gamma = 0.35  # connectivity (gamma in the paper)
factor = 1000  # factor for computing weights values
[7]:
# Parameters for stimulus
# -----------------------

simulation_time = 1100  # the simulation time [ms]
start_stimulation = 100  # [ms]
end_stimulation = 200  # [ms]
input_strength = 10  # strength of the stimulation
num_sensory_input_width = 32
# the width for input stimulation of the gaussian distribution
sigma = round(num_sensory_neuron / num_sensory_input_width)
three_sigma = 3 * sigma
activity_threshold = 3
[8]:
# Weights initialization
# ----------------------

# weight matrix within sensory network
sensory_encoding = 2. * bm.pi * bm.arange(1, num_sensory_neuron + 1) / num_sensory_neuron
diff = sensory_encoding.reshape((-1, 1)) - sensory_encoding
weight_mat_of_sensory = lambda_ + A * bm.exp(k1 * (bm.cos(diff) - 1)) - A * bm.exp(k2 * (bm.cos(diff) - 1))
diag = bm.arange(num_sensory_neuron)
weight_mat_of_sensory[diag, diag] = 0.
WARNING:jax._src.lib.xla_bridge:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
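Written out, the recurrent weight between two sensory neurons whose preferred angles differ by \(\Delta\theta\) is the difference-of-exponentials (local excitation, broad suppression) profile built above:

\begin{align} W(\Delta\theta) = \lambda + A\, e^{k_1(\cos\Delta\theta - 1)} - A\, e^{k_2(\cos\Delta\theta - 1)}, \end{align}

with amplitude \(A = 2\), a narrow excitatory width \(k_1 = 1\), a broad suppressive width \(k_2 = 0.25\) and baseline \(\lambda = 0.28\); self-connections are set to zero.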
[9]:
# connectivity matrix between sensory and random network
conn_matrix_sensory2random = bm.random.rand(num_all_sensory, num_all_random) < gamma
[10]:
# weight matrix of sensory2random
ws = factor * alpha / conn_matrix_sensory2random.sum(axis=0)
weight_mat_sensory2random = conn_matrix_sensory2random * ws.reshape((1, -1))
ws = weight_mat_sensory2random.sum(axis=0).reshape((1, -1))
weight_mat_sensory2random += forward_balance / num_all_sensory * ws  # balance
[11]:
# weight matrix of random2sensory
ws = factor * beta / conn_matrix_sensory2random.sum(axis=1)
weight_mat_random2sensory = conn_matrix_sensory2random.T * ws.reshape((1, -1))
ws = weight_mat_random2sensory.sum(axis=0).reshape((1, -1))
weight_mat_random2sensory += backward_balance / num_all_random * ws  # balance
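
As a quick sanity check of the balancing step (a minimal sketch using the matrices defined above), with forward_balance = backward_balance = -1 each column of the balanced weight matrices should sum to approximately zero:

# verify the feed-forward/feedback balance: column sums should be ~0
print(bm.abs(weight_mat_sensory2random.sum(axis=0)).max())
print(bm.abs(weight_mat_random2sensory.sum(axis=0)).max())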
[12]:
@bm.jit
def f(inp_ids, center):
  inp_scale = bm.exp(-(inp_ids - center) ** 2 / 2 / sigma ** 2) / (bm.sqrt(2 * bm.pi) * sigma)
  inp_scale /= bm.max(inp_scale)
  inp_ids = bm.remainder(inp_ids - 1, num_sensory_neuron)
  input = bm.zeros(num_sensory_neuron)
  input[inp_ids] = input_strength * inp_scale
  input -= bm.sum(input) / num_sensory_neuron
  return input


def get_input(center):
  inp_ids = bm.arange(bm.asarray(center - three_sigma, dtype=bm.int32),
                      bm.asarray(center + three_sigma + 1, dtype=bm.int32),
                      1)
  return f(inp_ids, center)
[13]:
def get_activity_vector(rates):
  exp_stim_encoding = bm.exp(1j * sensory_encoding)
  timed_abs = bm.zeros(num_sensory_pool)
  timed_angle = bm.zeros(num_sensory_pool)
  for si in range(num_sensory_pool):
    start = si * num_sensory_neuron
    end = (si + 1) * num_sensory_neuron
    exp_rates = bm.multiply(rates[start:end], exp_stim_encoding)
    mean_rates = bm.mean(exp_rates)
    timed_angle[si] = bm.angle(mean_rates) * num_sensory_neuron / (2 * bm.pi)
    timed_abs[si] = bm.absolute(mean_rates)
  timed_angle[timed_angle < 0] += num_sensory_neuron
  return timed_abs, timed_angle
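
The readout above is a population-vector decoder: for each sensory pool it forms the complex-valued circular mean of the firing rates over the preferred features \(\theta_j = 2\pi j / N\) (with \(N\) = num_sensory_neuron), restated from the function above:

\[m=\frac{1}{N} \sum_{j=1}^{N} r_{j} e^{i \theta_{j}}, \qquad \text{amplitude}=|m|, \qquad \text{decoded position}=\frac{N}{2 \pi} \arg (m)\]

Pools whose amplitude exceeds activity_threshold are later counted as holding a memory, and the decoded angle gives the remembered feature value.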
[14]:
class PoissonNeuron(bp.NeuGroup):
  def __init__(self, size, **kwargs):
    super(PoissonNeuron, self).__init__(size=size, **kwargs)

    self.s = bm.Variable(bm.zeros(self.num))
    self.r = bm.Variable(bm.zeros(self.num))
    self.input = bm.Variable(bm.zeros(self.num))
    self.spike = bm.Variable(bm.zeros(self.num, dtype=bool))
    self.rng = bm.random.RandomState()

    self.int_s = bp.odeint(lambda s, t: -s / tau, method='exp_euler')

  def update(self, tdi):
    self.s.value = self.int_s(self.s, tdi.t, tdi.dt)
    self.r.value = 0.4 * (1. + bm.tanh(fI_slope * (self.input + self.s + bias) - 3.)) / tau
    self.spike.value = self.rng.random(self.s.shape) < self.r * tdi.dt
    self.input[:] = 0.

  def reset_state(self, batch_size=None):
    self.s.value = self.rng.random(self.num) * init_range
    self.r.value = 0.4 * (1. + bm.tanh(fI_slope * (bias + self.s) - 3.)) / tau
    self.input.value = bm.zeros(self.num)
    self.spike.value = bm.zeros(self.num, dtype=bool)
[15]:
class Sen2SenSyn(bp.SynConn):
  def __init__(self, pre, post, **kwargs):
    super(Sen2SenSyn, self).__init__(pre=pre, post=post, **kwargs)

  def update(self, tdi):
    for i in range(num_sensory_pool):
      start = i * num_sensory_neuron
      end = (i + 1) * num_sensory_neuron
      self.post.s[start: end] += bm.dot(self.pre.spike[start: end],
                                        weight_mat_of_sensory)
[16]:
class OtherSyn(bp.SynConn):
  def __init__(self, pre, post, weights, **kwargs):
    super(OtherSyn, self).__init__(pre=pre, post=post, **kwargs)
    self.weights = weights

  def update(self, tdi):
    self.post.s += bm.dot(self.pre.spike, self.weights)
[17]:
class Network(bp.Network):
  def __init__(self):
    super(Network, self).__init__()

    self.sensory = PoissonNeuron(num_all_sensory)
    self.random = PoissonNeuron(num_all_random)
    self.sensory2sensory = Sen2SenSyn(pre=self.sensory, post=self.sensory)
    self.random2sensory = OtherSyn(pre=self.random,
                                   post=self.sensory,
                                   weights=weight_mat_random2sensory)
    self.sensory2random = OtherSyn(pre=self.sensory,
                                   post=self.random,
                                   weights=weight_mat_sensory2random)
[18]:
for trial_idx in range(num_trials):
  # inputs
  # ------
  pools_receiving_inputs = bm.random.choice(num_sensory_pool, num_item_to_load, replace=False)
  print(f"Load {num_item_to_load} items in trial {trial_idx}.\n")

  input_center = bm.ones(num_sensory_pool) * num_sensory_neuron / 2
  inp_vector = bm.zeros((num_sensory_pool, num_sensory_neuron))
  for si in pools_receiving_inputs:
    inp_vector[si, :] = get_input(input_center[si])
  inp_vector = inp_vector.flatten()
  Iext, duration = bp.inputs.constant_input(
    [(0., start_stimulation),
     (inp_vector, end_stimulation - start_stimulation),
     (0., simulation_time - end_stimulation)]
  )

  # running
  # -------
  net = Network()
  runner = bp.dyn.DSRunner(net,
                           inputs=(net.sensory.input, Iext, 'iter'),
                           monitors={'S.r': net.sensory.r,
                                     'S.spike': net.sensory.spike,
                                     'R.spike': net.random.spike})
  runner.predict(duration, reset_state=True)

  # results
  # --------

  rate_abs, rate_angle = get_activity_vector(runner.mon['S.r'][-1] * 1e3)
  print(f"Stimulus is given in: {bm.sort(pools_receiving_inputs)}")
  print(f"Memory is found in: {bm.where(rate_abs > activity_threshold)[0]}")

  prob_maintained, prob_spurious = 0, 0
  for si in range(num_sensory_pool):
    if rate_abs[si] > activity_threshold:
      if si in pools_receiving_inputs:
        prob_maintained += 1
      else:
        prob_spurious += 1
  print(str(prob_maintained) + ' maintained memories')
  print(str(pools_receiving_inputs.shape[0] - prob_maintained) + ' forgotten memories')
  print(str(prob_spurious) + ' spurious memories\n')
  prob_maintained /= float(num_item_to_load)
  if num_item_to_load != num_sensory_pool:
    prob_spurious /= float(num_sensory_pool - num_item_to_load)

  # visualization
  # -------------
  fig, gs = bp.visualize.get_figure(6, 1, 1.5, 12)
  xlim = (0, duration)

  fig.add_subplot(gs[0:4, 0])
  bp.visualize.raster_plot(runner.mon.ts,
                           runner.mon['S.spike'],
                           ylabel='Sensory Network', xlim=xlim)
  for index_sn in range(num_sensory_pool + 1):
    plt.axhline(index_sn * num_sensory_neuron)
  plt.yticks([num_sensory_neuron * (i + 0.5) for i in range(num_sensory_pool)],
             [f'pool-{i}' for i in range(num_sensory_pool)])

  fig.add_subplot(gs[4:6, 0])
  bp.visualize.raster_plot(runner.mon.ts,
                           runner.mon['R.spike'],
                           ylabel='Random Network', xlim=xlim, show=True)
Load 6 items in trial 0.

Stimulus is given in: [0 1 2 4 5 7]
Memory is found in: [0 7]
2 maintained memories
4 forgotten memories
0 spurious memories

_images/working_memory_Bouchacourt_2019_Flexible_working_memory_19_3.png

(Mi, et. al., 2017) STP for Working Memory Capacity

Implementation of the paper:

  • Mi, Yuanyuan, Mikhail Katkov, and Misha Tsodyks. “Synaptic correlates of working memory capacity.” Neuron 93.2 (2017): 323-330.

[1]:
import matplotlib.colors as mcolors
import matplotlib.pyplot as plt
[2]:
import brainpy as bp
import brainpy.math as bm

bm.set_platform('cpu')
bm.enable_x64()
[3]:
alpha = 1.5
J_EE = 8.  # the connection strength in each excitatory neural cluster
J_IE = 1.75  # Synaptic efficacy E → I
J_EI = 1.1  # Synaptic efficacy I → E
tau_f = 1.5  # time constant of STF  [s]
tau_d = .3  # time constant of STD  [s]
U = 0.3  # minimum STF value
tau = 0.008  # time constant of firing rate of the excitatory neurons [s]
tau_I = tau  # time constant of firing rate of the inhibitory neurons

Ib = 8.  # background input and external input
Iinh = 0.  # the background input of the inhibitory neurons

cluster_num = 16  # the number of the clusters
[4]:
# the parameters of external input

stimulus_num = 5
Iext_train = 225  # the strength of the external input
Ts_interval = 0.070  # the time interval between consecutive external inputs [s]
Ts_duration = 0.030  # the time duration of the external input [s]
duration = 2.500  # [s]
[5]:
# the excitatory cluster model and the inhibitory pool model

class WorkingMemoryModel(bp.DynamicalSystem):
  def __init__(self, **kwargs):
    super(WorkingMemoryModel, self).__init__(**kwargs)

    # variables
    self.u = bm.Variable(bm.ones(cluster_num) * U)
    self.x = bm.Variable(bm.ones(cluster_num))
    self.h = bm.Variable(bm.zeros(cluster_num))
    self.r = bm.Variable(self.log(self.h))
    self.input = bm.Variable(bm.zeros(cluster_num))
    self.inh_h = bm.Variable(bm.zeros(1))
    self.inh_r = bm.Variable(self.log(self.inh_h))

  def log(self, h):
    # return alpha * bm.log(1. + bm.exp(h / alpha))
    return alpha * bm.log1p(bm.exp(h / alpha))

  def update(self, tdi):
    uxr = self.u * self.x * self.r
    du = (U - self.u) / tau_f + U * (1 - self.u) * self.r
    dx = (1 - self.x) / tau_d - uxr
    dh = (-self.h + J_EE * uxr - J_EI * self.inh_r + self.input + Ib) / tau
    dhi = (-self.inh_h + J_IE * bm.sum(self.r) + Iinh) / tau_I
    self.u += du * tdi.dt
    self.x += dx * tdi.dt
    self.h += dh * tdi.dt
    self.inh_h += dhi * tdi.dt
    self.r[:] = self.log(self.h)
    self.inh_r[:] = self.log(self.inh_h)
    self.input[:] = 0.
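
For reference, the update function above takes an Euler step of the rate equations from Mi et al. (2017), restated from the code:

\[\begin{split}\begin{aligned} \frac{d u_i}{d t} &=\frac{U-u_i}{\tau_f}+U\left(1-u_i\right) r_i \\ \frac{d x_i}{d t} &=\frac{1-x_i}{\tau_d}-u_i x_i r_i \\ \tau \frac{d h_i}{d t} &=-h_i+J_{EE} u_i x_i r_i-J_{EI} r_{I}+I_i+I_b \\ \tau_I \frac{d h_{I}}{d t} &=-h_{I}+J_{IE} \sum_j r_j+I_{inh} \end{aligned}\end{split}\]

with the firing rates given by \(r = \alpha \ln(1 + e^{h/\alpha})\) for both the excitatory clusters and the inhibitory pool.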
[6]:
dt = 0.0001  # [s]
# the external input
I_inputs = bm.zeros((int(duration / dt), cluster_num))
for i in range(stimulus_num):
  t_start = (Ts_interval + Ts_duration) * i + Ts_interval
  t_end = t_start + Ts_duration
  idx_start, idx_end = int(t_start / dt), int(t_end / dt)
  I_inputs[idx_start: idx_end, i] = Iext_train
[7]:
# running
runner = bp.DSRunner(WorkingMemoryModel(),
                     inputs=['input', I_inputs, 'iter'],
                     monitors=['u', 'x', 'r', 'h'],
                     dt=dt)
runner(duration)
[8]:
# visualization
colors = list(dict(mcolors.BASE_COLORS, **mcolors.CSS4_COLORS).keys())

fig, gs = bp.visualize.get_figure(5, 1, 2, 12)
fig.add_subplot(gs[0, 0])
for i in range(stimulus_num):
  plt.plot(runner.mon.ts, runner.mon.r[:, i], label='Cluster-{}'.format(i))
plt.ylabel("$r (Hz)$")
plt.legend(loc='right')

fig.add_subplot(gs[1, 0])
hist_Jux = J_EE * runner.mon.u * runner.mon.x
for i in range(stimulus_num):
  plt.plot(runner.mon.ts, hist_Jux[:, i])
plt.ylabel("$J_{EE}ux$")

fig.add_subplot(gs[2, 0])
for i in range(stimulus_num):
  plt.plot(runner.mon.ts, runner.mon.u[:, i], colors[i])
plt.ylabel('u')

fig.add_subplot(gs[3, 0])
for i in range(stimulus_num):
  plt.plot(runner.mon.ts, runner.mon.x[:, i], colors[i])
plt.ylabel('x')

fig.add_subplot(gs[4, 0])
for i in range(stimulus_num):
  plt.plot(runner.mon.ts, runner.mon.h[:, i], colors[i])
plt.ylabel('h')
plt.xlabel('time [s]')

plt.show()
_images/working_memory_Mi_2017_working_memory_capacity_9_0.png

[1D] Simple systems

[1]:
import brainpy as bp

bp.math.enable_x64()
# bp.math.set_platform('cpu')

Phase plane

Here we will show the phase plane and bifurcation analysis of a 1D system with a dummy test neuronal model.

\[\dot{x} = x^3-x + I\]

First, let’s define the model.

[2]:
@bp.odeint
def int_x(x, t, Iext):
  dx = x ** 3 - x + Iext
  return dx
[3]:
analyzer = bp.analysis.PhasePlane1D(int_x,
                                    target_vars={'x': [-10, 10]},
                                    pars_update={'Iext': 0.})
analyzer.plot_vector_field()
analyzer.plot_fixed_point(show=True)
WARNING:jax._src.lib.xla_bridge:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
I am creating the vector field ...
I am searching fixed points ...
Fixed point #1 at x=-1.0000000000000002 is a unstable point.
Fixed point #2 at x=-7.771561172376096e-16 is a stable point.
Fixed point #3 at x=1.0000000000000002 is a unstable point.
_images/dynamics_analysis_1d_simple_systems_5_1.png

Codimension 1

Then, create a bifurcation analyzer with bp.analysis.Bifurcation1D.

[4]:
an = bp.analysis.Bifurcation1D(
  int_x,
  target_pars={'Iext': [-0.5, 0.5]},
  target_vars={"x": [-2, 2]},
  resolutions={'Iext': 0.001}
)
an.plot_bifurcation(show=True)
I am making bifurcation analysis ...
_images/dynamics_analysis_1d_simple_systems_8_1.png

Codimension 2

Here we define the following 1D model for codimension 2 bifurcation testing.

\[\dot{x} = \mu+ \lambda x - x^3\]
[5]:
@bp.odeint
def int_x(x, t, mu, lambda_):
  dxdt = mu + lambda_ * x - x ** 3
  return dxdt
[6]:
analyzer = bp.analysis.Bifurcation1D(
  int_x,
  target_pars={'lambda_': [-1, 4], 'mu': [-4, 4], },
  target_vars={'x': [-3, 3]},
  resolutions={'lambda_': 0.1, 'mu': 0.1}
)
analyzer.plot_bifurcation(show=True)
I am making bifurcation analysis ...
_images/dynamics_analysis_1d_simple_systems_12_1.png

[2D] NaK model analysis

Here we will show the neurodynamics analysis of a two-dimensional system, using the \(I_{\rm{Na,p}}+I_{\rm{K}}\) model as an example.

The dynamical system is given by:

\[C\dot{V} = I_{ext} - g_L * (V-E_L)-g_{Na}*m_\infty(V)(V-E_{Na})-g_K*n*(V-E_K)\]
\[\dot{n} = \frac{n_\infty(V)-n}{\tau(V)}\]

where

\[m_\infty(V) = 1 \ / \ ({1+\exp(\frac{V_{\rm{m_{half}}}-V}{k_m})})\]
\[n_\infty(V) = 1 \ / \ ({1+\exp(\frac{V_{\rm{n_{half}}}-V}{k_n})})\]

This model specifies a leak current \(I_L\), a persistent sodium current \(I_{\rm{Na, p}}\) with instantaneous activation kinetics, and a relatively slower persistent potassium current \(I_K\) with either a high or a low threshold (the two choices result in fundamentally different dynamics).

[1]:
import brainpy as bp
import brainpy.math as bm

bm.set_platform('cpu')
bm.set_dt(dt=0.02)
bm.enable_x64()
[2]:
C = 1
E_L = -78  # different from high-threshold model
g_L = 8
g_Na = 20
g_K = 10
E_K = -90
E_Na = 60
Vm_half = -20
k_m = 15
Vn_half = -45  # different from high-threshold model
k_n = 5
tau = 1


@bp.odeint(method='exp_auto')
def int_V(V, t, n, Iext):
    m_inf = 1 / (1 + bm.exp((Vm_half - V) / k_m))
    I_leak = g_L * (V - E_L)
    I_Na = g_Na * m_inf * (V - E_Na)
    I_K = g_K * n * (V - E_K)
    dvdt = (-I_leak - I_Na - I_K + Iext) / C
    return dvdt


@bp.odeint(method='exp_auto')
def int_n(n, t, V):
    n_inf = 1 / (1 + bm.exp((Vn_half - V) / k_n))
    dndt = (n_inf - n) / tau
    return dndt

Phase plane analysis

[3]:
analyzer = bp.analysis.PhasePlane2D(
    model=[int_n, int_V],
    target_vars={'n': [0., 1.], 'V': [-90, 20]},
    pars_update={'Iext': 50.},
    resolutions={'n': 0.01, 'V': 0.1}
)
analyzer.plot_nullcline()
analyzer.plot_vector_field()
analyzer.plot_fixed_point()
analyzer.plot_trajectory({'n': [0.2, 0.4], 'V': [-10, -80]},
                         duration=100., show=True)
I am computing fx-nullcline ...
I am evaluating fx-nullcline by optimization ...
I am computing fy-nullcline ...
I am evaluating fy-nullcline by optimization ...
I am creating the vector field ...
I am searching fixed points ...
I am trying to find fixed points by optimization ...
        There are 680 candidates
I am trying to filter out duplicate fixed points ...
        Found 1 fixed points.
        #1 n=0.210529358619151, V=-51.60868761096133 is a unstable focus.
I am plotting the trajectory ...
_images/dynamics_analysis_2d_NaK_model_5_1.png

Codimension 1 bifurcation analysis

Here we show the codimension 1 bifurcation analysis of the \(I_{\rm{Na,p}}+I_{\rm{K}}\) model, in which \(I_{ext}\) is varied over \([0, 50]\).

[4]:
analyzer = bp.analysis.Bifurcation2D(
  model=[int_V, int_n],
  target_vars={"V": [-90., 20.], 'n': [0., 1.]},
  target_pars={'Iext': [0, 50.]},
  resolutions={'Iext': 0.1},
)
analyzer.plot_bifurcation(num_rank=30)
analyzer.plot_limit_cycle_by_sim(duration=100., show=True)
I am making bifurcation analysis ...
I am filtering out fixed point candidates with auxiliary function ...
I am trying to find fixed points by optimization ...
        There are 15000 candidates
I am trying to filter out duplicate fixed points ...
        Found 325 fixed points.
I am plotting the limit cycle ...
_images/dynamics_analysis_2d_NaK_model_8_1.png
_images/dynamics_analysis_2d_NaK_model_8_2.png

Reference

  1. Izhikevich, Eugene M. Dynamical systems in neuroscience (Chapter 4). MIT press, 2007.

[2D] Wilson-Cowan model

[1]:
import brainpy as bp
import brainpy.math as bm

# bp.math.set_platform('cpu')
bp.math.enable_x64()
[2]:
class WilsonCowanModel(bp.DynamicalSystem):
  def __init__(self, num, method='exp_auto'):
    super(WilsonCowanModel, self).__init__()

    # Connection weights
    self.wEE = 12
    self.wEI = 4
    self.wIE = 13
    self.wII = 11

    # Refractory parameter
    self.r = 1

    # Excitatory parameters
    self.E_tau = 1  # Timescale of excitatory population
    self.E_a = 1.2  # Gain of excitatory population
    self.E_theta = 2.8  # Threshold of excitatory population

    # Inhibitory parameters
    self.I_tau = 1  # Timescale of inhibitory population
    self.I_a = 1  # Gain of inhibitory population
    self.I_theta = 4  # Threshold of inhibitory population

    # variables
    self.i = bm.Variable(bm.ones(num))
    self.e = bm.Variable(bm.ones(num))
    self.Iext = bm.Variable(bm.zeros(num))

    # functions
    def F(x, a, theta):
      return 1 / (1 + bm.exp(-a * (x - theta))) - 1 / (1 + bm.exp(a * theta))

    def de(e, t, i, Iext=0.):
      x = self.wEE * e - self.wEI * i + Iext
      return (-e + (1 - self.r * e) * F(x, self.E_a, self.E_theta)) / self.E_tau

    def di(i, t, e):
      x = self.wIE * e - self.wII * i
      return (-i + (1 - self.r * i) * F(x, self.I_a, self.I_theta)) / self.I_tau

    self.int_e = bp.odeint(de, method=method)
    self.int_i = bp.odeint(di, method=method)

  def update(self, tdi):
    self.e.value = self.int_e(self.e, tdi.t, self.i, self.Iext, tdi.dt)
    self.i.value = self.int_i(self.i, tdi.t, self.e, tdi.dt)
    self.Iext[:] = 0.

Simulation

[3]:
model = WilsonCowanModel(2)
model.e[:] = [-0.1, 0.4]
model.i[:] = [0.5, 0.6]
WARNING:jax._src.lib.xla_bridge:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
[4]:
# simulation
runner = bp.DSRunner(model, monitors=['e', 'i'])
runner.run(100)
[5]:
fig, gs = bp.visualize.get_figure(2, 1, 3, 8)
fig.add_subplot(gs[0, 0])
bp.visualize.line_plot(runner.mon.ts, runner.mon.e, plot_ids=[0], legend='e', linestyle='--')
bp.visualize.line_plot(runner.mon.ts, runner.mon.i, plot_ids=[0], legend='i', linestyle='--')
fig.add_subplot(gs[1, 0])
bp.visualize.line_plot(runner.mon.ts, runner.mon.e, plot_ids=[1], legend='e')
bp.visualize.line_plot(runner.mon.ts, runner.mon.i, plot_ids=[1], legend='i', show=True)
_images/dynamics_analysis_2d_wilson_cowan_model_6_0.png

Phase Plane Analysis

[6]:
# phase plane analysis
pp = bp.analysis.PhasePlane2D(
  model,
  target_vars={'e': [-0.2, 1.], 'i': [-0.2, 1.]},
  resolutions=0.001,
)
pp.plot_vector_field()
pp.plot_nullcline(coords={'i': 'i-e'})
pp.plot_fixed_point()
pp.plot_trajectory(initials={'i': [0.5, 0.6], 'e': [-0.1, 0.4]},
                   duration=10, dt=0.1)
pp.show_figure()
I am creating the vector field ...
I am computing fx-nullcline ...
I am evaluating fx-nullcline by optimization ...
I am computing fy-nullcline ...
I am evaluating fy-nullcline by optimization ...
I am searching fixed points ...
I am trying to find fixed points by optimization ...
        There are 1966 candidates
I am trying to filter out duplicate fixed points ...
        Found 3 fixed points.
        #1 e=-1.6868322611378924e-07, i=-3.4570638908889285e-08 is a stable node.
        #2 e=0.17803581298567323, i=0.06279970624572738 is a saddle node.
        #3 e=0.46246485137398247, i=0.24336408752379407 is a stable node.
I am plotting the trajectory ...
_images/dynamics_analysis_2d_wilson_cowan_model_8_1.png

[2D] Decision Making Model with SlowPointFinder

[1]:
import brainpy as bp
import brainpy.math as bm

bp.math.enable_x64()
# bp.math.set_platform('cpu')
[2]:
# parameters
gamma = 0.641  # Saturation factor for gating variable
tau = 0.06  # Synaptic time constant [sec]
a = 270.
b = 108.
d = 0.154
[3]:
JE = 0.3725  # self-coupling strength [nA]
JI = -0.1137  # cross-coupling strength [nA]
JAext = 0.00117  # Stimulus input strength [nA]
[4]:
mu = 20.  # Stimulus firing rate [spikes/sec]
coh = 0.5  # Stimulus coherence [%]
Ib1 = 0.3297
Ib2 = 0.3297
[5]:
@bp.odeint
def int_s1(s1, t, s2, coh=0.5, mu=20.):
  I1 = JE * s1 + JI * s2 + Ib1 + JAext * mu * (1. + coh)
  r1 = (a * I1 - b) / (1. - bm.exp(-d * (a * I1 - b)))
  return - s1 / tau + (1. - s1) * gamma * r1
[6]:
@bp.odeint
def int_s2(s2, t, s1, coh=0.5, mu=20.):
  I2 = JE * s2 + JI * s1 + Ib2 + JAext * mu * (1. - coh)
  r2 = (a * I2 - b) / (1. - bm.exp(-d * (a * I2 - b)))
  return - s2 / tau + (1. - s2) * gamma * r2
[7]:
def cell(s):
  ds1 = int_s1.f(s[0], 0., s[1])
  ds2 = int_s2.f(s[1], 0., s[0])
  return bm.asarray([ds1, ds2])
[8]:
finder = bp.analysis.SlowPointFinder(f_cell=cell, f_type='continuous')
finder.find_fps_with_gd_method(
  candidates=bm.random.random((1000, 2)),
  tolerance=1e-5, num_batch=200,
  optimizer=bp.optim.Adam(lr=bp.optim.ExponentialDecay(0.01, 1, 0.9999)),
)
finder.filter_loss(1e-5)
finder.keep_unique()

print('fixed_points: ', finder.fixed_points)
print('losses:', finder.losses)
WARNING:jax._src.lib.xla_bridge:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
Optimizing with Adam(lr=ExponentialDecay(0.01, decay_steps=1, decay_rate=0.9999), last_call=-1), beta1=0.9, beta2=0.999, eps=1e-08) to find fixed points:
    Batches 1-200 in 0.53 sec, Training loss 0.0209022794
    Batches 201-400 in 0.39 sec, Training loss 0.0034664037
    Batches 401-600 in 0.40 sec, Training loss 0.0009952186
    Batches 601-800 in 1.03 sec, Training loss 0.0003619036
    Batches 801-1000 in 0.95 sec, Training loss 0.0001457642
    Batches 1001-1200 in 0.42 sec, Training loss 0.0000614027
    Batches 1201-1400 in 0.39 sec, Training loss 0.0000262476
    Batches 1401-1600 in 0.49 sec, Training loss 0.0000111700
    Batches 1601-1800 in 1.29 sec, Training loss 0.0000046670
    Stop optimization as mean training loss 0.0000046670 is below tolerance 0.0000100000.
Excluding fixed points with squared speed above tolerance 1e-05:
    Kept 957/1000 fixed points with tolerance under 1e-05.
Excluding non-unique fixed points:
    Kept 3/957 unique fixed points with uniqueness tolerance 0.025.
fixed_points:  [[0.70045189 0.00486431]
 [0.01394651 0.65738904]
 [0.28276323 0.40635171]]
losses: [2.16451334e-30 1.51498296e-28 3.08194113e-16]
[9]:
jac = finder.compute_jacobians(finder.fixed_points, plot=True, num_col=1)
_images/dynamics_analysis_2d_decision_making_model_9_0.png

[2D] Decision Making Model with Low-dimensional Analyzer

Please refer to Analysis of a Decision Making Model.

[3D] Hindmarsh Rose Model

[1]:
import brainpy as bp

# bp.math.set_platform('cpu')
bp.math.enable_x64()
[2]:
import matplotlib.pyplot as plt
import numpy as np
[3]:
class HindmarshRose(bp.DynamicalSystem):
  def __init__(self, method='exp_auto'):
    super(HindmarshRose, self).__init__()

    # parameters
    self.a = 1.
    self.b = 2.5
    self.c = 1.
    self.d = 5.
    self.s = 4.
    self.x_r = -1.6
    self.r = 0.001

    # variables
    self.x = bp.math.Variable(bp.math.ones(1))
    self.y = bp.math.Variable(bp.math.ones(1))
    self.z = bp.math.Variable(bp.math.ones(1))
    self.I = bp.math.Variable(bp.math.zeros(1))

    # integral functions
    def dx(x, t, y, z, Isyn):
      return y - self.a * x ** 3 + self.b * x * x - z + Isyn

    def dy(y, t, x):
      return self.c - self.d * x * x - y

    def dz(z, t, x):
      return self.r * (self.s * (x - self.x_r) - z)

    self.int_x = bp.odeint(f=dx, method=method)
    self.int_y = bp.odeint(f=dy, method=method)
    self.int_z = bp.odeint(f=dz, method=method)

  def update(self, tdi):
    self.x.value = self.int_x(self.x, tdi.t, self.y, self.z, self.I, tdi.dt)
    self.y.value = self.int_y(self.y, tdi.t, self.x, tdi.dt)
    self.z.value = self.int_z(self.z, tdi.t, self.x, tdi.dt)
    self.I[:] = 0.

Simulation

[4]:
model = HindmarshRose()

runner = bp.DSRunner(model, monitors=['x', 'y', 'z'], inputs=['I', 1.5])
runner.run(2000.)
bp.visualize.line_plot(runner.mon.ts, runner.mon.x, legend='x', show=True)
WARNING:jax._src.lib.xla_bridge:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
_images/dynamics_analysis_3d_hindmarsh_rose_model_5_2.png

Bifurcation analysis

[5]:
analyzer = bp.analysis.FastSlow2D(
    [model.int_x, model.int_y, model.int_z],
    fast_vars={'x': [-3, 2], 'y': [-20., 3.]},
    slow_vars={'z': [-0.5, 3.]},
    pars_update={'Isyn': 1.5},
    resolutions={'z': 0.01},
    # options={bp.analysis.C.y_by_x_in_fy: lambda x: model.c - model.d * x * x}
)
analyzer.plot_bifurcation(num_rank=20)
analyzer.plot_trajectory({'x': [1.], 'y': [1.], 'z': [1.]},
                       duration=1700,
                       plot_durations=[360, 1680])
analyzer.show_figure()
I am making bifurcation analysis ...
I am filtering out fixed point candidates with auxiliary function ...
I am trying to find fixed points by optimization ...
        There are 7000 candidates
I am trying to filter out duplicate fixed points ...
        Found 789 fixed points.
I am plotting the trajectory ...
_images/dynamics_analysis_3d_hindmarsh_rose_model_7_1.png
_images/dynamics_analysis_3d_hindmarsh_rose_model_7_2.png

Phase plane analysis

[6]:
for z in np.arange(0., 2.5, 0.3):
    analyzer = bp.analysis.PhasePlane2D(
      [model.int_x, model.int_y],
      target_vars={'x': [-3, 2], 'y': [-20., 3.]},
      pars_update={'Isyn': 1.5, 'z': z},
      resolutions={'x': 0.01, 'y': 0.01},
    )
    analyzer.plot_nullcline()
    analyzer.plot_vector_field()
    fps = analyzer.plot_fixed_point(with_return=True)
    analyzer.plot_trajectory({'x': [fps[-1, 0] + 0.1], 'y': [fps[-1, 1] + 0.1]},
                             duration=500, plot_durations=[400, 500])
    plt.title(f'z={z:.2f}')
    plt.show()
I am computing fx-nullcline ...
I am evaluating fx-nullcline by optimization ...
I am computing fy-nullcline ...
I am evaluating fy-nullcline by optimization ...
I am creating the vector field ...
I am searching fixed points ...
I am trying to find fixed points by optimization ...
        There are 2111 candidates
I am trying to filter out duplicate fixed points ...
        Found 1 fixed points.
        #1 x=0.862288349294799, y=-2.7177058646654935 is a unstable focus.
I am plotting the trajectory ...
_images/dynamics_analysis_3d_hindmarsh_rose_model_9_1.png
I am computing fx-nullcline ...
I am evaluating fx-nullcline by optimization ...
I am computing fy-nullcline ...
I am evaluating fy-nullcline by optimization ...
I am creating the vector field ...
I am searching fixed points ...
I am trying to find fixed points by optimization ...
        There are 2141 candidates
I am trying to filter out duplicate fixed points ...
        Found 3 fixed points.
        #1 x=-1.872654283356628, y=-16.5341715842593 is a stable node.
        #2 x=-1.4420391946597404, y=-9.397382585390593 is a saddle node.
        #3 x=0.8146858103927409, y=-2.318565022506274 is a unstable focus.
I am plotting the trajectory ...
_images/dynamics_analysis_3d_hindmarsh_rose_model_9_3.png
I am computing fx-nullcline ...
I am evaluating fx-nullcline by optimization ...
I am computing fy-nullcline ...
I am evaluating fy-nullcline by optimization ...
I am creating the vector field ...
I am searching fixed points ...
I am trying to find fixed points by optimization ...
        There are 2171 candidates
I am trying to filter out duplicate fixed points ...
        Found 3 fixed points.
        #1 x=-2.0462132406600912, y=-19.934943131130087 is a stable node.
        #2 x=-1.2168555131297534, y=-6.40368673337723 is a saddle node.
        #3 x=0.7630687306850773, y=-1.911369548008847 is a unstable focus.
I am plotting the trajectory ...
_images/dynamics_analysis_3d_hindmarsh_rose_model_9_5.png
I am computing fx-nullcline ...
I am evaluating fx-nullcline by optimization ...
I am computing fy-nullcline ...
I am evaluating fy-nullcline by optimization ...
I am creating the vector field ...
I am searching fixed points ...
I am trying to find fixed points by optimization ...
        There are 2201 candidates
I am trying to filter out duplicate fixed points ...
        Found 3 fixed points.
        #1 x=-2.1556926349477843, y=-22.235053679257277 is a stable node.
        #2 x=-1.05070800710585, y=-4.519936589150076 is a saddle node.
        #3 x=0.7064006730222961, y=-1.4950094398974472 is a unstable focus.
I am plotting the trajectory ...
_images/dynamics_analysis_3d_hindmarsh_rose_model_9_7.png
I am computing fx-nullcline ...
I am evaluating fx-nullcline by optimization ...
I am computing fy-nullcline ...
I am evaluating fy-nullcline by optimization ...
I am creating the vector field ...
I am searching fixed points ...
I am trying to find fixed points by optimization ...
        There are 2231 candidates
I am trying to filter out duplicate fixed points ...
        Found 3 fixed points.
        #1 x=-2.2411861683124883, y=-24.1145772063309 is a stable node.
        #2 x=-0.901932179236277, y=-3.067408474244581 is a saddle node.
        #3 x=0.6431188901157783, y=-1.0680095250383574 is a unstable focus.
I am plotting the trajectory ...
_images/dynamics_analysis_3d_hindmarsh_rose_model_9_9.png
I am computing fx-nullcline ...
I am evaluating fx-nullcline by optimization ...
I am computing fy-nullcline ...
I am evaluating fy-nullcline by optimization ...
I am creating the vector field ...
I am searching fixed points ...
I am trying to find fixed points by optimization ...
        There are 2261 candidates
I am trying to filter out duplicate fixed points ...
        Found 3 fixed points.
        #1 x=-2.3130990884027134, y=-25.75213708580098 is a stable node.
        #2 x=-0.7575691876779478, y=-1.8695551041910168 is a saddle node.
        #3 x=0.5706681801629564, y=-0.6283107114652488 is a unstable focus.
I am plotting the trajectory ...
_images/dynamics_analysis_3d_hindmarsh_rose_model_9_11.png
I am computing fx-nullcline ...
I am evaluating fx-nullcline by optimization ...
I am computing fy-nullcline ...
I am evaluating fy-nullcline by optimization ...
I am creating the vector field ...
I am searching fixed points ...
I am trying to find fixed points by optimization ...
        There are 2291 candidates
I am trying to filter out duplicate fixed points ...
        Found 3 fixed points.
        #1 x=-2.3760052650455847, y=-27.227005096799942 is a stable node.
        #2 x=-0.6083084538224832, y=-0.8501958918559419 is a saddle node.
        #3 x=0.48431473687387644, y=-0.17280300724589612 is a unstable focus.
I am plotting the trajectory ...
_images/dynamics_analysis_3d_hindmarsh_rose_model_9_13.png
I am computing fx-nullcline ...
I am evaluating fx-nullcline by optimization ...
I am computing fy-nullcline ...
I am evaluating fy-nullcline by optimization ...
I am creating the vector field ...
I am searching fixed points ...
I am trying to find fixed points by optimization ...
        There are 2321 candidates
I am trying to filter out duplicate fixed points ...
        Found 3 fixed points.
        #1 x=-2.4323928622096647, y=-28.582675180900974 is a stable node.
        #2 x=-0.4407309093835196, y=0.028781414591295275 is a saddle node.
        #3 x=0.37312367222430437, y=0.30389369711344255 is a unstable focus.
I am plotting the trajectory ...
_images/dynamics_analysis_3d_hindmarsh_rose_model_9_15.png
I am computing fx-nullcline ...
I am evaluating fx-nullcline by optimization ...
I am computing fy-nullcline ...
I am evaluating fy-nullcline by optimization ...
I am creating the vector field ...
I am searching fixed points ...
I am trying to find fixed points by optimization ...
        There are 2351 candidates
I am trying to filter out duplicate fixed points ...
        Found 3 fixed points.
        #1 x=-2.4837904744530963, y=-29.84607557713767 is a stable node.
        #2 x=-0.20892015454771234, y=0.7817619225980599 is a saddle node.
        #3 x=0.19271039914787816, y=0.8143135053948518 is a stable focus.
I am plotting the trajectory ...
_images/dynamics_analysis_3d_hindmarsh_rose_model_9_17.png

Continuous-attractor Neural Network

[1]:
import matplotlib.pyplot as plt
import numpy as np
from sklearn.decomposition import PCA
[2]:
import brainpy as bp
import brainpy.math as bm

bm.set_platform('cpu')

Model

[3]:
class CANN1D(bp.NeuGroup):
  def __init__(self, num, tau=1., k=8.1, a=0.5, A=10., J0=4.,
               z_min=-bm.pi, z_max=bm.pi, **kwargs):
    super(CANN1D, self).__init__(size=num, **kwargs)

    # parameters
    self.tau = tau  # The synaptic time constant
    self.k = k  # Degree of the rescaled inhibition
    self.a = a  # Half-width of the range of excitatory connections
    self.A = A  # Magnitude of the external input
    self.J0 = J0  # maximum connection value

    # feature space
    self.z_min = z_min
    self.z_max = z_max
    self.z_range = z_max - z_min
    self.x = bm.linspace(z_min, z_max, num)  # The encoded feature values
    self.rho = num / self.z_range  # The neural density
    self.dx = self.z_range / num  # The stimulus density

    # variables
    self.u = bm.Variable(bm.zeros(num))
    self.input = bm.Variable(bm.zeros(num))

    # The connection matrix
    self.conn_mat = self.make_conn(self.x)

    # function
    self.integral = bp.odeint(self.derivative)

  def derivative(self, u, t, Iext):
    r1 = bm.square(u)
    r2 = 1.0 + self.k * bm.sum(r1)
    r = r1 / r2
    Irec = bm.dot(self.conn_mat, r)
    du = (-u + Irec + Iext) / self.tau
    return du

  def dist(self, d):
    d = bm.remainder(d, self.z_range)
    d = bm.where(d > 0.5 * self.z_range, d - self.z_range, d)
    return d

  def make_conn(self, x):
    assert bm.ndim(x) == 1
    x_left = bm.reshape(x, (-1, 1))
    x_right = bm.repeat(x.reshape((1, -1)), len(x), axis=0)
    d = self.dist(x_left - x_right)
    Jxx = self.J0 * bm.exp(-0.5 * bm.square(d / self.a)) / (bm.sqrt(2 * bm.pi) * self.a)
    return Jxx

  def get_stimulus_by_pos(self, pos):
    return self.A * bm.exp(-0.25 * bm.square(self.dist(self.x - pos) / self.a))

  def update(self, tdi):
    self.u.value = self.integral(self.u, tdi.t, self.input, tdi.dt)
    self.input[:] = 0.
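
For reference, the class above implements the following continuous-attractor dynamics (restated from the code, with \(x - x'\) taken as the periodic distance on the feature space):

\[\begin{split}\begin{aligned} \tau \frac{d u(x, t)}{d t} &=-u(x, t)+\sum_{x'} J\left(x, x'\right) r\left(x', t\right)+I_{ext}(x, t) \\ r(x, t) &=\frac{u^{2}(x, t)}{1+k \sum_{x'} u^{2}\left(x', t\right)} \\ J\left(x, x'\right) &=\frac{J_{0}}{\sqrt{2 \pi}\, a} \exp \left(-\frac{\left(x-x'\right)^{2}}{2 a^{2}}\right) \end{aligned}\end{split}\]

and the external stimulus centered at \(z\) is \(I_{ext}(x)=A \exp \left(-\frac{(x-z)^{2}}{4 a^{2}}\right)\).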

Find fixed points

[4]:
model = CANN1D(num=512, k=0.1, A=30, a=0.5)
[5]:
candidates = model.get_stimulus_by_pos(bm.arange(-bm.pi, bm.pi, 0.005).reshape((-1, 1)))

finder = bp.analysis.SlowPointFinder(f_cell=model, target_vars={'u': model.u})
finder.find_fps_with_gd_method(
  candidates={'u': candidates},
  tolerance=1e-6,
  num_batch=200,
  optimizer=bp.optim.Adam(lr=bp.optim.ExponentialDecay(0.1, 2, 0.999)),
)
finder.filter_loss(1e-7)
finder.keep_unique()
# finder.exclude_outliers(tolerance=1e1)

print('Losses of fixed points:')
print(finder.losses)
Optimizing with Adam(lr=ExponentialDecay(0.1, decay_steps=2, decay_rate=0.999), last_call=-1), beta1=0.9, beta2=0.999, eps=1e-08) to find fixed points:
    Batches 1-200 in 5.00 sec, Training loss 0.0000000000
    Stop optimization as mean training loss 0.0000000000 is below tolerance 0.0000010000.
Excluding fixed points with squared speed above tolerance 1e-07:
    Kept 1257/1257 fixed points with tolerance under 1e-07.
Excluding non-unique fixed points:
    Kept 1257/1257 unique fixed points with uniqueness tolerance 0.025.
Losses of fixed points:
[0. 0. 0. ... 0. 0. 0.]

Visualize fixed points

[6]:
pca = PCA(2)
fp_pcs = pca.fit_transform(finder.fixed_points['u'])
plt.plot(fp_pcs[:, 0], fp_pcs[:, 1], 'x', label='fixed points')
plt.xlabel('PC 1')
plt.ylabel('PC 2')
plt.title('Fixed points PCA')
plt.legend()
plt.show()
_images/dynamics_analysis_highdim_CANN_9_0.png
[7]:
fps = finder.fixed_points['u']
plot_ids = (10, 100, 200, 300,)
plot_ids = np.asarray(plot_ids)

for i in plot_ids:
  plt.plot(model.x, fps[i], label=f'FP-{i}')
plt.legend()
plt.show()
_images/dynamics_analysis_highdim_CANN_10_0.png

Verify the stabilities of fixed points

[8]:
from jax.tree_util import tree_map

_ = finder.compute_jacobians(
  tree_map(lambda x: x[plot_ids], finder._fixed_points),
  plot=True
)
_images/dynamics_analysis_highdim_CANN_12_0.png

Gap junction-coupled FitzHugh-Nagumo Model

[1]:
import brainpy as bp
import brainpy.math as bm

# bp.math.set_platform('cpu')
bp.math.enable_x64()
[2]:
class GJCoupledFHN(bp.DynamicalSystem):
  def __init__(self, num=2, method='exp_auto'):
    super(GJCoupledFHN, self).__init__()

    # parameters
    self.num = num
    self.a = 0.7
    self.b = 0.8
    self.tau = 12.5
    self.gjw = 0.0001

    # variables
    self.V = bm.Variable(bm.random.uniform(-2, 2, num))
    self.w = bm.Variable(bm.random.uniform(-2, 2, num))
    self.Iext = bm.Variable(bm.zeros(num))

    # functions
    self.int_V = bp.odeint(self.dV, method=method)
    self.int_w = bp.odeint(self.dw, method=method)

  def dV(self, V, t, w, Iext=0.):
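    # gap-junction coupling: neuron j receives gjw * sum_i (V_i - V_j)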
    gj = (V.reshape((-1, 1)) - V).sum(axis=0) * self.gjw
    dV = V - V * V * V / 3 - w + Iext + gj
    return dV

  def dw(self, w, t, V):
    dw = (V + self.a - self.b * w) / self.tau
    return dw

  def update(self, tdi):
    self.V.value = self.int_V(self.V, tdi.t, self.w, self.Iext, tdi.dt)
    self.w.value = self.int_w(self.w, tdi.t, self.V, tdi.dt)
[3]:
def analyze_net(num=2, gjw=0.01, Iext=bm.asarray([0., 0.6])):
    assert isinstance(Iext, (int, float)) or (len(Iext) == num)

    model = GJCoupledFHN(num)
    model.gjw = gjw
    model.Iext[:] = Iext

    # simulation
    runner = bp.DSRunner(model, monitors=['V'])
    runner.run(300.)
    bp.visualize.line_plot(runner.mon.ts, runner.mon.V, legend='V',
                           plot_ids=list(range(model.num)), show=True)

    # analysis
    finder = bp.analysis.SlowPointFinder(f_cell=model,
                                         target_vars={'V': model.V, 'w': model.w})
    finder.find_fps_with_gd_method(
      candidates={'V': bm.random.normal(0., 2., (1000, model.num)),
                  'w': bm.random.normal(0., 2., (1000, model.num))},
      tolerance=1e-5,
      num_batch=200,
      optimizer=bp.optim.Adam(lr=bp.optim.ExponentialDecay(0.05, 1, 0.9999)),
    )
    finder.filter_loss(1e-8)
    finder.keep_unique()

    print('fixed_points: ', finder.fixed_points)
    print('losses:', finder.losses)

    _ = finder.compute_jacobians(finder.fixed_points, plot=True)
WARNING:jax._src.lib.xla_bridge:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)

4D system

[4]:
analyze_net(num=2, gjw=0.1, Iext=bm.asarray([0., 0.6]))
_images/dynamics_analysis_highdim_gj_coupled_fhn_5_1.png
Optimizing with Adam(lr=ExponentialDecay(0.05, decay_steps=1, decay_rate=0.9999), last_call=-1), beta1=0.9, beta2=0.999, eps=1e-08) to find fixed points:
    Batches 1-200 in 0.47 sec, Training loss 0.0002525318
    Batches 201-400 in 1.82 sec, Training loss 0.0002021251
    Batches 401-600 in 1.06 sec, Training loss 0.0001682558
    Batches 601-800 in 0.64 sec, Training loss 0.0001426594
    Batches 801-1000 in 0.88 sec, Training loss 0.0001222235
    Batches 1001-1200 in 1.86 sec, Training loss 0.0001054805
    Batches 1201-1400 in 0.51 sec, Training loss 0.0000914966
    Batches 1401-1600 in 0.60 sec, Training loss 0.0000796636
    Batches 1601-1800 in 0.75 sec, Training loss 0.0000695541
    Batches 1801-2000 in 2.01 sec, Training loss 0.0000608892
    Batches 2001-2200 in 0.63 sec, Training loss 0.0000534558
    Batches 2201-2400 in 0.63 sec, Training loss 0.0000470366
    Batches 2401-2600 in 1.92 sec, Training loss 0.0000414673
    Batches 2601-2800 in 0.86 sec, Training loss 0.0000366226
    Batches 2801-3000 in 0.44 sec, Training loss 0.0000323996
    Batches 3001-3200 in 0.47 sec, Training loss 0.0000287081
    Batches 3201-3400 in 0.65 sec, Training loss 0.0000254832
    Batches 3401-3600 in 0.53 sec, Training loss 0.0000226843
    Batches 3601-3800 in 0.80 sec, Training loss 0.0000202542
    Batches 3801-4000 in 1.89 sec, Training loss 0.0000181205
    Batches 4001-4200 in 0.64 sec, Training loss 0.0000162401
    Batches 4201-4400 in 0.52 sec, Training loss 0.0000145707
    Batches 4401-4600 in 2.00 sec, Training loss 0.0000130935
    Batches 4601-4800 in 0.67 sec, Training loss 0.0000118004
    Batches 4801-5000 in 0.59 sec, Training loss 0.0000106516
    Batches 5001-5200 in 0.86 sec, Training loss 0.0000096305
    Stop optimization as mean training loss 0.0000096305 is below tolerance 0.0000100000.
Excluding fixed points with squared speed above tolerance 1e-08:
    Kept 306/1000 fixed points with tolerance under 1e-08.
Excluding non-unique fixed points:
    Kept 1/306 unique fixed points with uniqueness tolerance 0.025.
fixed_points:  {'V': array([[-1.1731516 , -0.73801606]]), 'w': array([[-0.59144118, -0.04753834]])}
losses: [6.8918834e-15]
_images/dynamics_analysis_highdim_gj_coupled_fhn_5_3.png
[5]:
analyze_net(num=2, gjw=0.1, Iext=bm.asarray([0., 0.1]))
_images/dynamics_analysis_highdim_gj_coupled_fhn_6_1.png
Optimizing with Adam(lr=ExponentialDecay(0.05, decay_steps=1, decay_rate=0.9999), last_call=-1), beta1=0.9, beta2=0.999, eps=1e-08) to find fixed points:
    Batches 1-200 in 1.39 sec, Training loss 0.0002524889
    Batches 201-400 in 0.47 sec, Training loss 0.0002039485
    Batches 401-600 in 0.39 sec, Training loss 0.0001722068
    Batches 601-800 in 0.49 sec, Training loss 0.0001481642
    Batches 801-1000 in 0.67 sec, Training loss 0.0001288039
    Batches 1001-1200 in 0.69 sec, Training loss 0.0001127157
    Batches 1201-1400 in 2.42 sec, Training loss 0.0000990425
    Batches 1401-1600 in 0.71 sec, Training loss 0.0000872609
    Batches 1601-1800 in 0.45 sec, Training loss 0.0000770819
    Batches 1801-2000 in 1.56 sec, Training loss 0.0000682273
    Batches 2001-2200 in 1.07 sec, Training loss 0.0000605319
    Batches 2201-2400 in 0.59 sec, Training loss 0.0000538263
    Batches 2401-2600 in 0.38 sec, Training loss 0.0000479664
    Batches 2601-2800 in 1.37 sec, Training loss 0.0000428171
    Batches 2801-3000 in 0.77 sec, Training loss 0.0000382962
    Batches 3001-3200 in 0.53 sec, Training loss 0.0000343242
    Batches 3201-3400 in 0.39 sec, Training loss 0.0000307939
    Batches 3401-3600 in 1.54 sec, Training loss 0.0000276377
    Batches 3601-3800 in 0.71 sec, Training loss 0.0000248339
    Batches 3801-4000 in 0.41 sec, Training loss 0.0000223625
    Batches 4001-4200 in 0.46 sec, Training loss 0.0000202066
    Batches 4201-4400 in 0.58 sec, Training loss 0.0000182917
    Batches 4401-4600 in 0.47 sec, Training loss 0.0000165570
    Batches 4601-4800 in 1.13 sec, Training loss 0.0000150186
    Batches 4801-5000 in 0.96 sec, Training loss 0.0000136606
    Batches 5001-5200 in 0.50 sec, Training loss 0.0000124335
    Batches 5201-5400 in 0.35 sec, Training loss 0.0000113094
    Batches 5401-5600 in 1.40 sec, Training loss 0.0000103006
    Batches 5601-5800 in 0.83 sec, Training loss 0.0000093835
    Stop optimization as mean training loss 0.0000093835 is below tolerance 0.0000100000.
Excluding fixed points with squared speed above tolerance 1e-08:
    Kept 485/1000 fixed points with tolerance under 1e-08.
Excluding non-unique fixed points:
    Kept 1/485 unique fixed points with uniqueness tolerance 0.025.
fixed_points:  {'V': array([[-1.19613916, -1.14106972]]), 'w': array([[-0.62017396, -0.55133715]])}
losses: [2.01213341e-24]
_images/dynamics_analysis_highdim_gj_coupled_fhn_6_3.png

8D system

[6]:
analyze_net(num=4, gjw=0.1, Iext=bm.asarray([0., 0., 0., 0.6]))
_images/dynamics_analysis_highdim_gj_coupled_fhn_8_1.png
Optimizing with Adam(lr=ExponentialDecay(0.05, decay_steps=1, decay_rate=0.9999), last_call=-1), beta1=0.9, beta2=0.999, eps=1e-08) to find fixed points:
    Batches 1-200 in 0.58 sec, Training loss 0.0002563626
    Batches 201-400 in 0.63 sec, Training loss 0.0002068401
    Batches 401-600 in 1.46 sec, Training loss 0.0001746516
    Batches 601-800 in 0.53 sec, Training loss 0.0001505581
    Batches 801-1000 in 0.34 sec, Training loss 0.0001312724
    Batches 1001-1200 in 0.39 sec, Training loss 0.0001152650
    Batches 1201-1400 in 0.57 sec, Training loss 0.0001017105
    Batches 1401-1600 in 0.36 sec, Training loss 0.0000900637
    Batches 1601-1800 in 1.23 sec, Training loss 0.0000799450
    Batches 1801-2000 in 0.90 sec, Training loss 0.0000710771
    Batches 2001-2200 in 0.58 sec, Training loss 0.0000632785
    Batches 2201-2400 in 0.39 sec, Training loss 0.0000564160
    Batches 2401-2600 in 1.61 sec, Training loss 0.0000503661
    Batches 2601-2800 in 0.55 sec, Training loss 0.0000449990
    Batches 2801-3000 in 0.51 sec, Training loss 0.0000402453
    Batches 3001-3200 in 0.72 sec, Training loss 0.0000360206
    Batches 3201-3400 in 1.17 sec, Training loss 0.0000322757
    Batches 3401-3600 in 0.47 sec, Training loss 0.0000289431
    Batches 3601-3800 in 0.44 sec, Training loss 0.0000259914
    Batches 3801-4000 in 0.49 sec, Training loss 0.0000233745
    Batches 4001-4200 in 1.22 sec, Training loss 0.0000210534
    Batches 4201-4400 in 0.56 sec, Training loss 0.0000189919
    Batches 4401-4600 in 0.30 sec, Training loss 0.0000171629
    Batches 4601-4800 in 0.33 sec, Training loss 0.0000155418
    Batches 4801-5000 in 0.46 sec, Training loss 0.0000141007
    Batches 5001-5200 in 0.32 sec, Training loss 0.0000128158
    Batches 5201-5400 in 0.51 sec, Training loss 0.0000116671
    Batches 5401-5600 in 1.17 sec, Training loss 0.0000106446
    Batches 5601-5800 in 0.57 sec, Training loss 0.0000097291
    Stop optimization as mean training loss 0.0000097291 is below tolerance 0.0000100000.
Excluding fixed points with squared speed above tolerance 1e-08:
    Kept 266/1000 fixed points with tolerance under 1e-08.
Excluding non-unique fixed points:
    Kept 1/266 unique fixed points with uniqueness tolerance 0.025.
fixed_points:  {'V': array([[-1.17905465, -1.19006061, -1.17904012, -0.81757202]]), 'w': array([[-0.59764936, -0.58889147, -0.59766085, -0.14496394]])}
losses: [5.78896779e-09]
_images/dynamics_analysis_highdim_gj_coupled_fhn_8_3.png

Hénon map

The Hénon map is a discrete-time dynamical system. It is one of the most studied examples of dynamical systems that exhibit chaotic behavior. The Hénon map takes a point \((x_n, y_n)\) in the plane and maps it to a new point

\[\begin{split}\begin{cases}x_{n+1} = 1-a x_n^2 + y_n\\y_{n+1} = b x_n.\end{cases}\end{split}\]
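
As a quick illustration of the update rule (a plain-Python sketch, separate from the BrainPy implementation below), a few iterations from the origin with the classical parameters \(a=1.4\) and \(b=0.3\):

# iterate the Hénon map by hand for a few steps (illustration only)
a, b = 1.4, 0.3
x, y = 0.0, 0.0
for n in range(5):
  x, y = 1 - a * x * x + y, b * x
  print(n + 1, round(x, 4), round(y, 4))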
[1]:
import brainpy as bp
import brainpy.math as bm
import matplotlib.pyplot as plt
[2]:
class HenonMap(bp.DynamicalSystem):
  """Hénon map."""

  def __init__(self, num, a=1.4, b=0.3):
    super(HenonMap, self).__init__()

    # parameters
    self.a = a
    self.b = b
    self.num = num

    # variables
    self.x = bm.Variable(bm.zeros(num))
    self.y = bm.Variable(bm.zeros(num))

  def update(self, tdi):
    x_new = 1 - self.a * self.x * self.x + self.y
    self.y.value = self.b * self.x
    self.x.value = x_new
[3]:
map = HenonMap(4)
map.a = bm.asarray([0.5, 1.0, 1.4, 2.0])
WARNING:jax._src.lib.xla_bridge:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
[4]:
runner = bp.DSRunner(map, monitors=['x', 'y'], dt=1.)
runner.run(10000)
[5]:
fig, gs = bp.visualize.get_figure(4, 1, 4, 6)
for i in range(map.num):
  fig.add_subplot(gs[i, 0])
  plt.plot(runner.mon.x[:, i], runner.mon.y[:, i], '.k')
  plt.xlabel('x')
  plt.ylabel('y')
  plt.title(f'a={map.a[i]}')
  if (i + 1) == map.num:
    plt.show()
_images/classical_dynamical_systems_henon_map_6_0.png

The strange attractor illustrated above is obtained for \(a=1.4\) and \(b=0.3\).

[6]:
map = HenonMap(1)
map.a = 0.2
map.b = 0.9991

runner = bp.DSRunner(map, monitors=['x', 'y'], dt=1.)
runner.run(10000)
plt.plot(runner.mon.x[:, 0], runner.mon.y[:, 0], ',k')
plt.xlabel('x')
plt.ylabel('y')
plt.xlim([-5, 5])
plt.ylim([-5, 5])
plt.show()
_images/classical_dynamical_systems_henon_map_8_1.png
[7]:
map = HenonMap(1)
map.a = 0.2
map.b = -0.9999

runner = bp.DSRunner(map, monitors=['x', 'y'], dt=1.)
runner.run(10000)
plt.plot(runner.mon.x[:, 0], runner.mon.y[:, 0], ',k')
plt.xlabel('x')
plt.ylabel('y')
plt.show()
_images/classical_dynamical_systems_henon_map_9_1.png

Logistic map

The logistic map is a one-dimensional discrete-time dynamical system that is defined by the equation

\[x_{n+1} =\lambda x_{n}(1-x_{n})\]

For an initial value \(0\leq x_{0}\leq1\), this map generates a sequence of values \(x_{0},x_{1},\ldots,x_{n},x_{n+1},\ldots\) The growth parameter is chosen to be

\[0<\lambda\leq4\]

which implies that for all \(n\) the state variable will remain bounded in the unit interval. Despite its simplicity, this famous dynamical system can exhibit remarkable dynamical richness.
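
A plain-Python sketch (separate from the BrainPy class below) hints at this richness: after a transient, the iterates settle to a fixed point for \(\lambda=2.5\), converge to a period-2 cycle for \(\lambda=3.2\), and remain irregular for \(\lambda=3.9\):

# discard a transient, then print a few iterates for several growth parameters
for lam in (2.5, 3.2, 3.9):
  x = 0.2
  for _ in range(1000):
    x = lam * x * (1 - x)
  tail = []
  for _ in range(4):
    x = lam * x * (1 - x)
    tail.append(round(x, 4))
  print('lambda =', lam, ':', tail)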

[6]:
import brainpy as bp
import brainpy.math as bm
import matplotlib.pyplot as plt
[7]:
class LogisticMap(bp.NeuGroup):
  def __init__(self, num, mu0=2., mu1=4.):
    super(LogisticMap, self).__init__(num)

    self.mu = bm.linspace(mu0, mu1, num)
    self.x = bm.Variable(bm.ones(num) * 0.2)

  def update(self, tdi):
    self.x.value = self.mu * self.x * (1 - self.x)
[8]:
map = LogisticMap(10000, 2, 4)
[9]:
runner = bp.DSRunner(map, monitors=['x'], dt=1.)
runner.run(1100)
[10]:
plt.plot(map.mu, runner.mon.x[1000:].T, ',k', alpha=0.25)
plt.xlabel('mu')
plt.ylabel('x')
plt.show()
_images/classical_dynamical_systems_logistic_map_6_0.png

Lorenz system

The Lorenz system, originally intended as a simplified model of atmospheric convection, has instead become a standard example of sensitive dependence on initial conditions; that is, tiny differences in the starting condition for the system rapidly become magnified. The system also exhibits what is known as the “Lorenz attractor”, that is, the collection of trajectories for different starting points tends to approach a peculiar butterfly-shaped region.

The Lorenz system includes three ordinary differential equations:

\[\begin{split}\begin{aligned} \frac{dx}{dt} &= \sigma (y - x) \\ \frac{dy}{dt} &= x (\rho - z) - y \\ \frac{dz}{dt} &= x y - \beta z \end{aligned}\end{split}\]

where the parameters \(\beta\), \(\rho\), and \(\sigma\) are usually assumed to be positive. The classic case uses the parameter values

\[\beta = 8/3, \quad \rho = 28, \quad \sigma = 10.\]
[1]:
import brainpy as bp
import matplotlib.pyplot as plt
[2]:
sigma = 10
beta = 8 / 3
rho = 28
[3]:
dx = lambda x, t, y: sigma * (y - x)
dy = lambda y, t, x, z: x * (rho - z) - y
dz = lambda z, t, x, y: x * y - beta * z
[4]:
integral = bp.odeint(bp.JointEq([dx, dy, dz]), method='exp_auto')
[5]:
runner = bp.IntegratorRunner(integral,
                             monitors=['x', 'y', 'z'],
                             inits=dict(x=8, y=1, z=1),
                             dt=0.01)
runner.run(100)
WARNING:jax._src.lib.xla_bridge:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
[6]:
fig = plt.figure()
fig.add_subplot(111, projection='3d')
plt.plot(runner.mon.x[100:, 0], runner.mon.y[100:, 0], runner.mon.z[100:, 0])
plt.show()
_images/classical_dynamical_systems_lorenz_system_7_0.png

Mackey-Glass equation

The Mackey-Glass equation is the nonlinear time delay differential equation

\[\frac{dx}{dt} = \beta \frac{ x_{\tau} }{1+{x_{\tau}}^n}-\gamma x, \quad \gamma,\beta,n > 0,\]

where \(\beta, \gamma, \tau, n\) are real numbers, and \(x_{\tau}\) represents the value of the variable \(x\) at time \((t−\tau)\). Depending on the values of the parameters, this equation displays a range of periodic and chaotic dynamics.

[1]:
import brainpy as bp
import brainpy.math as bm
import matplotlib.pyplot as plt
[2]:
bm.set_dt(0.05)
[3]:
class MackeyGlassEq(bp.NeuGroup):
  def __init__(self, num, beta=2., gamma=1., tau=2., n=9.65):
    super(MackeyGlassEq, self).__init__(num)

    # parameters
    self.beta = beta
    self.gamma = gamma
    self.tau = tau
    self.n = n
    self.delay_len = int(tau/bm.get_dt())

    # variables
    self.x = bm.Variable(bm.zeros(num))
    self.x_delay = bm.LengthDelay(
      self.x,
      delay_len=self.delay_len,
      initial_delay_data=lambda sh, dtype: 1.2+0.2*(bm.random.random(sh)-0.5)
    )
    self.x_oldest = bm.Variable(self.x_delay(self.delay_len))

    # functions
    self.integral = bp.odeint(lambda x, t, x_tau: self.beta * x_tau / (1 + x_tau ** self.n) - self.gamma * x,
                              method='exp_auto')

  def update(self, tdi):
    self.x.value = self.integral(self.x.value, tdi.t, self.x_oldest.value, tdi.dt)
    self.x_delay.update(self.x.value)
    self.x_oldest.value = self.x_delay(self.delay_len)
[4]:
eq = MackeyGlassEq(1, beta=0.2, gamma=0.1, tau=17, n=10)
# eq = MackeyGlassEq(1, )
WARNING:jax._src.lib.xla_bridge:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
[5]:
runner = bp.DSRunner(eq, monitors=['x', 'x_oldest'])
runner.run(1000)
[6]:
plt.plot(runner.mon.x[1000:, 0], runner.mon.x_oldest[1000:, 0])
plt.show()
_images/classical_dynamical_systems_mackey_glass_eq_7_0.png

Multiscroll chaotic attractor

Multiscroll attractors, also called n-scroll attractors, include the Chen attractor, the Lu Chen attractor, the modified Chen chaotic attractor, the PWL Duffing attractor, the Rabinovich-Fabrikant attractor, and the modified Chua chaotic attractor; they contain multiple scrolls within a single attractor.

[1]:
import brainpy as bp
import brainpy.math as bm
import matplotlib.pyplot as plt
[2]:
def run_and_visualize(runner, duration=100, dim=3, args=None):
  assert dim in [3, 2]
  if args is None:
      runner.run(duration)
  else:
      runner.run(duration, args=args)

  if dim == 3:
    fig = plt.figure()
    ax = fig.add_subplot(111, projection='3d')
    for i in range(runner.mon.x.shape[1]):
      plt.plot(runner.mon.x[100:, i], runner.mon.y[100:, i], runner.mon.z[100:, i])
    ax.set_xlabel('x')
    ax.set_ylabel('y')
    ax.set_zlabel('z')
  else:
    for i in range(runner.mon.x.shape[1]):
      plt.plot(runner.mon.x[100:, i], runner.mon.y[100:, i])
    plt.xlabel('x')
    plt.ylabel('y')
  plt.show()

Chen attractor

The Chen system is defined as follows [1]

\[\begin{split}\begin{aligned} \frac{d x(t)}{d t} &=a(y(t)-x(t)) \\ \frac{d y(t)}{d t} &=(c-a) x(t)-x(t) z(t)+c y(t) \\ \frac{d z(t)}{d t} &=x(t) y(t)-b z(t) \end{aligned}\end{split}\]
[3]:
@bp.odeint(method='euler')
def chen_system(x, y, z, t, a=40, b=3, c=28):
  dx = a * (y - x)
  dy = (c - a) * x - x * z + c * y
  dz = x * y - b * z
  return dx, dy, dz
[4]:
runner = bp.IntegratorRunner(
    chen_system,
    monitors=['x', 'y', 'z'],
    inits=dict(x=-0.1, y=0.5, z=-0.6),
    dt=0.001
)
run_and_visualize(runner, 100)
WARNING:jax._src.lib.xla_bridge:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
_images/classical_dynamical_systems_Multiscroll_attractor_7_2.png

Lu Chen attractor

An extended Chen system with multiscroll was proposed by Jinhu Lu (吕金虎) and Guanrong Chen [2].

\[\begin{split}\begin{aligned} \frac{d x(t)}{d t} &=a(y(t)-x(t)) \\ \frac{d y(t)}{d t} &=x(t)-x(t) z(t)+c y(t)+u \\ \frac{d z(t)}{d t} &=x(t) y(t)-b z(t) \end{aligned}\end{split}\]
  1. When u ≤ -11, the Lu Chen chaotic attractor is a left-scroll chaotic attractor;

  2. when u is between -10 and 10, it is a twisted (pretzel-like) attractor;

  3. when u ≥ 11, it is a right-scroll chaotic attractor.

[5]:
@bp.odeint(method='rk4')
def lu_chen_system(x, y, z, t, a=36, c=20, b=3, u=-15.15):
  dx = a * (y - x)
  dy = x - x * z + c * y + u
  dz = x * y - b * z
  return dx, dy, dz
[6]:
runner = bp.IntegratorRunner(
    lu_chen_system,
    monitors=['x', 'y', 'z'],
    inits=dict(x=0.1, y=0.3, z=-0.6),
    dt=0.002
)
run_and_visualize(runner, 100, args=dict(u=-15.15),)

runner = bp.IntegratorRunner(
    lu_chen_system,
    monitors=['x', 'y', 'z'],
    inits=dict(x=0.1, y=0.3, z=-0.6),
    dt=0.002
)
run_and_visualize(runner, 100, args=dict(u=-8),)
_images/classical_dynamical_systems_Multiscroll_attractor_12_1.png
_images/classical_dynamical_systems_Multiscroll_attractor_12_3.png
[7]:
runner = bp.IntegratorRunner(
    lu_chen_system,
    monitors=['x', 'y', 'z'],
    inits=dict(x=0.1, y=0.3, z=-0.6),
    dt=0.002
)
run_and_visualize(runner, 100, args=dict(u=11),)
_images/classical_dynamical_systems_Multiscroll_attractor_13_1.png
[8]:
runner = bp.IntegratorRunner(
    lu_chen_system,
    monitors=['x', 'y', 'z'],
    inits=dict(x=0.1, y=0.3, z=-0.6),
    dt=0.002
)
run_and_visualize(runner, 100, args=dict(u=12),)
_images/classical_dynamical_systems_Multiscroll_attractor_14_1.png

Modified Lu Chen attractor

System equations:

\[\begin{split}\begin{aligned} \frac{d x(t)}{d t} &=a(y(t)-x(t)) \\ \frac{d y(t)}{d t} &=(c-a) x(t)-x(t) f+c y(t) \\ \frac{d z(t)}{d t} &=x(t) y(t)-b z(t) \end{aligned}\end{split}\]

where

\[f = d_0 z(t) + d_1 z(t-\tau) - d_2 \sin(z(t-\tau))\]
[9]:
class ModifiedLuChenSystem(bp.dyn.NeuGroup):
  def __init__(self, num, a=35, b=3, c=28, d0=1, d1=1, d2=0., tau=.2, dt=0.1):
    super(ModifiedLuChenSystem, self).__init__(num)

    # parameters
    self.a = a
    self.b = b
    self.c = c
    self.d0 = d0
    self.d1 = d1
    self.d2 = d2
    self.tau = tau
    self.delay_len = int(tau/dt)

    # variables
    self.z = bm.Variable(bm.ones(num) * 14)
    self.z_delay = bm.LengthDelay(self.z,
                                  delay_len=self.delay_len,
                                  initial_delay_data=14.)
    self.x = bm.Variable(bm.ones(num))
    self.y = bm.Variable(bm.ones(num))

    # functions
    def derivative(x, y, z, t, z_delay):
      dx = self.a * (y - x)
      f = self.d0 * z + self.d1 * z_delay - self.d2 * bm.sin(z_delay)
      dy = (self.c - self.a) * x - x * f + self.c * y
      dz = x * y - self.b * z
      return dx, dy, dz

    self.integral = bp.odeint(derivative, method='rk4')

  def update(self, tdi):
    # integrate (x, y, z), passing the value of z delayed by tau (delay_len steps)
    self.x.value, self.y.value, self.z.value = self.integral(
        self.x.value, self.y.value, self.z.value, tdi.t,
        self.z_delay(self.delay_len), tdi.dt
    )
    # push the current z into the delay buffer
    self.z_delay.update(self.z)
[10]:
runner = bp.DSRunner(ModifiedLuChenSystem(1, dt=0.001),
                     monitors=['x', 'y', 'z'],
                     dt=0.001)

run_and_visualize(runner, 50)
_images/classical_dynamical_systems_Multiscroll_attractor_18_1.png

Chua’s system

The classic Chua’s system is described by the following dimensionless equations [3]

\[\begin{split}\begin{aligned} &\dot{x}=\alpha(y-x)-\alpha f(x) \\ &\dot{y}=x-y+z \\ &\dot{z}=-\beta y-\gamma z \end{aligned}\end{split}\]

where \(\alpha\), \(\beta\), and \(\gamma\) are constant parameters and

\[f(x)=b x+0.5(a-b)(|x+1|-|x-1|)\]

with \(a,b\) being the slopes of the inner and outer segments of \(f(x)\). Note that the parameter \(\gamma\) is usually ignored; that is, \(\gamma=0\).
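Expanding the absolute values makes the slope interpretation explicit: \(|x+1|-|x-1|\) equals \(2x\) for \(|x| \leq 1\) and \(\pm 2\) otherwise, so

\[\begin{split}f(x)= \begin{cases} a x, & |x| \leq 1 \\ b x+(a-b), & x \geq 1 \\ b x-(a-b), & x \leq -1 \end{cases}\end{split}\]

i.e., \(a\) is the slope of the inner segment and \(b\) the slope of the two outer segments.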

[11]:
@bp.odeint(method='rk4')
def chua_system(x, y, z, t, alpha=10, beta=14.514, gamma=0, a=-1.197, b=-0.6464):
  fx = b * x + 0.5 * (a - b) * (bm.abs(x + 1) - bm.abs(x - 1))
  dx = alpha * (y - x) - alpha * fx
  dy = x - y + z
  dz = -beta * y - gamma * z
  return dx, dy, dz
[12]:
runner = bp.IntegratorRunner(
    chua_system,
    monitors=['x', 'y', 'z'],
    inits=dict(x=0.001, y=0, z=0.),
    dt=0.002
)
run_and_visualize(runner, 100)
_images/classical_dynamical_systems_Multiscroll_attractor_22_1.png
[13]:
runner = bp.IntegratorRunner(
    chua_system,
    monitors=['x', 'y', 'z'],
    inits=dict(x=9.4287, y=-0.5945, z=-13.4705),
    dt=0.002
)
run_and_visualize(runner, 100, args=dict(alpha=8.8, beta=12.0732, gamma=0.0052, a=-0.1768, b=-1.1468),)
_images/classical_dynamical_systems_Multiscroll_attractor_23_1.png
[14]:
runner = bp.IntegratorRunner(
    chua_system,
    monitors=['x', 'y', 'z'],
    inits=dict(x=[-6.0489, 6.0489],
               y=[0.0839, -0.0839],
               z=[8.7739, -8.7739]),
    dt=0.002
)
run_and_visualize(runner, 100, args=dict(alpha=8.4562, beta=12.0732, gamma=0.0052, a=-0.1768, b=-1.1468),)
_images/classical_dynamical_systems_Multiscroll_attractor_24_1.png

Modified Chua chaotic attractor

In 2001, Tang et al. proposed a modified Chua chaotic system:

\[\begin{split}\begin{aligned} &\frac{d x(t)}{d t}=\alpha(y(t)-h) \\ &\frac{d y(t)}{d t}=x(t)-y(t)+z(t) \\ &\frac{d z(t)}{d t}=-\beta y(t) \end{aligned}\end{split}\]

where

\[h=-b \sin \left(\frac{\pi x(t)}{2 a}+d\right)\]
[15]:
@bp.odeint(method='rk4')
def modified_chua_system(x, y, z, t, alpha=10.82, beta=14.286, a=1.3, b=.11, d=0):
  dx = alpha * (y + b * bm.sin(bm.pi * x / 2 / a + d))
  dy = x - y + z
  dz = -beta * y
  return dx, dy, dz
[16]:
runner = bp.IntegratorRunner(
    modified_chua_system,
    monitors=['x', 'y', 'z'],
    inits=dict(x=1, y=1, z=0.),
    dt=0.01
)
run_and_visualize(runner, 1000)
_images/classical_dynamical_systems_Multiscroll_attractor_28_1.png

PWL Duffing chaotic attractor

Aziz Alaoui investigated the PWL Duffing equation in 2000:

\[\begin{split}\begin{aligned} &\frac{d x(t)}{d t}=y(t) \\ &\frac{d y(t)}{d t}=-m_{1} x(t)-\left(1 / 2\left(m_{0}-m_{1}\right)\right)(|x(t)+1|-|x(t)-1|)-e y(t)+\gamma \cos (\omega t) \end{aligned}\end{split}\]
[17]:
@bp.odeint(method='rk4')
def PWL_duffing_eq(x, y, t, e=0.25, m0=-0.0845, m1=0.66, omega=1, i=-14):
  gamma = 0.14 + i / 20
  dx = y
  dy = -m1 * x - (0.5 * (m0 - m1)) * (abs(x + 1) - abs(x - 1)) - e * y + gamma * bm.cos(omega * t)
  return dx, dy
[18]:
runner = bp.IntegratorRunner(
    PWL_duffing_eq,
    monitors=['x', 'y'],
    inits=dict(x=0, y=0),
    dt=0.01
)
run_and_visualize(runner, 1000, dim=2)
_images/classical_dynamical_systems_Multiscroll_attractor_32_1.png

Modified Lorenz chaotic system

Miranda & Stone proposed a modified Lorenz system:

\[\begin{split}\begin{aligned} &\frac{d x(t)}{d t}=1 / 3 *(-(a+1) x(t)+a-c+z(t) y(t))+\left((1-a)\left(x(t)^{2}-y(t)^{2}\right)+(2(a+c-z(t))) x(t) y(t)\right) \frac{1}{3 \sqrt{x(t)^{2}+y(t)^{2}}} \\ &\frac{d y(t)}{d t}=1 / 3((c-a-z(t)) x(t)-(a+1) y(t))+\left((2(a-1)) x(t) y(t)+(a+c-z(t))\left(x(t)^{2}-y(t)^{2}\right)\right) \frac{1}{3 \sqrt{x(t)^{2}+y(t)^{2}}} \\ &\frac{d z(t)}{d t}=1 / 2\left(3 x(t)^{2} y(t)-y(t)^{3}\right)-b z(t) \end{aligned}\end{split}\]
[19]:
@bp.odeint(method='euler')
def modified_Lorenz(x, y, z, t, a=10, b=8 / 3, c=137 / 5):
  temp = 3 * bm.sqrt(x * x + y * y)
  dx = (-(a + 1) * x + a - c + z * y) / 3 + ((1 - a) * (x * x - y * y) + (2 * (a + c - z)) * x * y) / temp
  dy = ((c - a - z) * x - (a + 1) * y) / 3 + (2 * (a - 1) * x * y + (a + c - z) * (x * x - y * y)) / temp
  dz = (3 * x * x * y - y * y * y) / 2 - b * z
  return dx, dy, dz
[20]:
runner = bp.IntegratorRunner(
    modified_Lorenz,
    monitors=['x', 'y', 'z'],
    inits=dict(x=-8, y=4, z=10),
    dt=0.001
)
run_and_visualize(runner, 100, dim=3)
_images/classical_dynamical_systems_Multiscroll_attractor_36_1.png

References

  • [1] Chen, Guanrong; Ueta, Tetsushi (July 1999). “Yet Another Chaotic Attractor”. International Journal of Bifurcation and Chaos. 9 (7): 1465–1466.

  • [2] Lu, Jinhu; Chen, Guanrong (2006). “Generating Multiscroll Chaotic Attractors: Theories, Methods and Applications”. International Journal of Bifurcation and Chaos. 16 (4): 775–858. doi:10.1142/s0218127406015179.

  • [3] L. Fortuna, M. Frasca, and M. G. Xibilia, “Chua’s Circuit Implementations: Yesterday, Today and Tomorrow,” World Scientific, 2009.

Rabinovich-Fabrikant equations

The Rabinovich–Fabrikant equations are a set of three coupled ordinary differential equations exhibiting chaotic behaviour for certain values of the parameters. They are named after Mikhail Rabinovich and Anatoly Fabrikant, who described them in 1979.

\[\begin{split}\begin{aligned} &\dot{x}=y\left(z-1+x^{2}\right)+\gamma x \\ &\dot{y}=x\left(3 z+1-x^{2}\right)+\gamma y \\ &\dot{z}=-2 z(\alpha+x y) \end{aligned}\end{split}\]

where \(\alpha, \gamma\) are constants that control the evolution of the system.

[1]:
import brainpy as bp
import matplotlib.pyplot as plt
[2]:
@bp.odeint(method='rk4')
def rf_eqs(x, y, z, t, alpha=1.1, gamma=0.803):
    dx = y * (z - 1 + x * x) + gamma * x
    dy = x * (3 * z + 1 - x * x) + gamma * y
    dz = -2 * z * (alpha + x * y)
    return dx, dy, dz
[3]:
def run_and_visualize(runner, duration=100, dim=3, args=None):
  assert dim in [3, 2]
  if args is None:
    runner.run(duration)
  else:
    runner.run(duration, args=args)

  if dim == 3:
    fig = plt.figure()
    ax = fig.add_subplot(111, projection='3d')
    for i in range(runner.mon.x.shape[1]):
      plt.plot(runner.mon.x[100:, i], runner.mon.y[100:, i], runner.mon.z[100:, i])
    ax.set_xlabel('x')
    ax.set_ylabel('y')
    ax.set_zlabel('z')
  else:
    for i in range(runner.mon.x.shape[1]):
      plt.plot(runner.mon.x[100:, i], runner.mon.y[100:, i])
    plt.xlabel('x')
    plt.ylabel('y')
  plt.show()
[4]:
runner = bp.IntegratorRunner(
    rf_eqs,
    monitors=['x', 'y', 'z'],
    inits=dict(x=-1, y=0, z=0.5),
    dt=0.001
)
run_and_visualize(runner, 100, args=dict(alpha=1.1, gamma=0.87),)
_images/classical_dynamical_systems_Rabinovich_Fabrikant_eq_5_2.png
[5]:
runner = bp.IntegratorRunner(
    rf_eqs,
    monitors=['x', 'y', 'z'],
    inits=dict(x=-1, y=0, z=0.5),
    dt=0.001
)
run_and_visualize(runner, 100, args=dict(alpha=0.98, gamma=0.1),)
_images/classical_dynamical_systems_Rabinovich_Fabrikant_eq_6_1.png
[6]:
runner = bp.IntegratorRunner(
    rf_eqs,
    monitors=['x', 'y', 'z'],
    inits=dict(x=-1, y=0, z=0.5),
    dt=0.001
)
run_and_visualize(runner, 300, args=dict(alpha=0.14, gamma=0.1),)
_images/classical_dynamical_systems_Rabinovich_Fabrikant_eq_7_1.png
[7]:
runner = bp.IntegratorRunner(
    rf_eqs,
    monitors=['x', 'y', 'z'],
    inits=dict(x=-1, y=0, z=0.5),
    dt=0.001
)
run_and_visualize(runner, 50, dim=2, args=dict(alpha=0.05, gamma=0.1),)
_images/classical_dynamical_systems_Rabinovich_Fabrikant_eq_8_1.png
[8]:
runner = bp.IntegratorRunner(
    rf_eqs,
    monitors=['x', 'y', 'z'],
    inits=dict(x=-1, y=0, z=0.5),
    dt=0.001
)
run_and_visualize(runner, 100, dim=2, args=dict(alpha=0.25, gamma=0.1),)
_images/classical_dynamical_systems_Rabinovich_Fabrikant_eq_9_1.png
[9]:
runner = bp.IntegratorRunner(
    rf_eqs,
    monitors=['x', 'y', 'z'],
    inits=dict(x=-1, y=0, z=0.5),
    dt=0.001
)
run_and_visualize(runner, 100, dim=2, args=dict(alpha=1.1, gamma=0.86666666666666666667),)

runner = bp.IntegratorRunner(
    rf_eqs,
    monitors=['x', 'y', 'z'],
    inits=dict(x=-1, y=0, z=0.5),
    dt=0.001
)
run_and_visualize(runner, 100, dim=3, args=dict(alpha=1.1, gamma=0.86666666666666666667),)
_images/classical_dynamical_systems_Rabinovich_Fabrikant_eq_10_1.png
_images/classical_dynamical_systems_Rabinovich_Fabrikant_eq_10_3.png
[10]:
runner = bp.IntegratorRunner(
    rf_eqs,
    monitors=['x', 'y', 'z'],
    inits=dict(x=0.1, y=-0.1, z=0.1),
    dt=0.001
)
run_and_visualize(runner, 60, dim=3, args=dict(alpha=0.05, gamma=0.1),)
_images/classical_dynamical_systems_Rabinovich_Fabrikant_eq_11_1.png

(Brette & Guigon, 2003): Reliability of spike timing

Implementation of the paper:

  • Brette R and E Guigon (2003). Reliability of Spike Timing Is a General Property of Spiking Model Neurons. Neural Computation 15, 279-308.

[1]:
import brainpy as bp
import brainpy.math as bm
import matplotlib.pyplot as plt
[2]:
class ExternalInput(bp.DynamicalSystem):
  def __init__(self, num):
    super(ExternalInput, self).__init__()

    # parameters
    self.num = num

    # variables
    self.I = bm.Variable(bm.zeros(num))
[3]:
class LinearLeakyModel(bp.NeuGroup):
  r"""A nonlinear leaky model.

  .. math::

     \tau \frac{d V}{d t}=-V+R I(t)

  """

  def __init__(self, size, inputs):
    super(LinearLeakyModel, self).__init__(size)
    # parameters
    self.tau = 33  # ms
    self.R = 200 / 1e3  # MΩ
    self.Vt = 15  # mV
    self.Vr = -5  # mV
    self.sigma = 1.
    self.inputs = inputs

    # variables
    self.V = bm.Variable(bm.random.uniform(self.Vr, self.Vt, self.num))
    self.spike = bm.Variable(bm.zeros(self.num, dtype=bool))

    # functions
    f = lambda V, t: (-V + self.R * self.inputs.I) / self.tau
    g = lambda V, t: self.sigma * bm.sqrt(2. / self.tau)
    self.integral = bp.sdeint(f=f, g=g)

  def update(self, tdi):
    self.inputs.update(tdi)
    self.V.value = self.integral(self.V, tdi.t, tdi.dt)
    self.spike.value = self.V >= self.Vt
    self.V.value = bm.where(self.spike, self.Vr, self.V)
[4]:
class NonlinearLeakyModel(bp.NeuGroup):
  r"""A nonlinear leaky model.

  .. math::

     \tau \frac{d V}{d t}=-a V^{3}+R I(t)

  """

  def __init__(self, size, inputs):
    super(NonlinearLeakyModel, self).__init__(size)
    # parameters
    self.tau = 33  # ms
    self.R = 200 / 1e3  # MΩ
    self.Vt = 15  # mV
    self.Vr = -5  # mV
    self.a = 4444 / 1e6  # V^(-2)
    self.sigma = 1.
    self.inputs = inputs

    # variables
    self.V = bm.Variable(bm.zeros(self.num))
    self.spike = bm.Variable(bm.zeros(self.num, dtype=bool))

    # functions
    f = lambda V, t: (-self.a * V ** 3 + self.R * self.inputs.I) / self.tau
    g = lambda V, t: self.sigma * bm.sqrt(2. / self.tau)
    self.integral = bp.sdeint(f=f, g=g)

  def update(self, tdi):
    self.inputs.update(tdi)
    self.V.value = self.integral(self.V, tdi.t, tdi.dt)
    self.spike.value = self.V >= self.Vt
    self.V.value = bm.where(self.spike, self.Vr, self.V)
[5]:
class NonLeakyModel(bp.NeuGroup):
  r"""A non-leaky model.

  .. math::

     \tau \frac{d V}{d t}=V I(t)+k

  """

  def __init__(self, size, inputs):
    super(NonLeakyModel, self).__init__(size)

    # parameters
    self.Vt = 1
    self.Vr = 0
    self.tau = 33  # ms
    self.k = 1
    self.sigma = 0.02
    self.inputs = inputs

    # variables
    self.V = bm.Variable(bm.random.uniform(self.Vr, self.Vt, self.num))
    self.spike = bm.Variable(bm.zeros(self.num, dtype=bool))

    # functions
    f = lambda V, t: (V * self.inputs.I + self.k) / self.tau
    g = lambda V, t: self.sigma * (2 / self.tau) ** .5
    self.integral = bp.sdeint(f, g, method='euler')

  def update(self, tdi):
    self.inputs.update(tdi)
    self.V.value = self.integral(self.V, tdi.t, tdi.dt)
    self.spike.value = self.V >= self.Vt
    self.V.value = bm.where(self.spike, self.Vr, self.V)
[6]:
class PerfectIntegrator(bp.NeuGroup):
  r"""Integrate inputs.

  .. math::

     \tau \frac{d V}{d t}=f(t)

  where :math:`f(t)` is an input function.

  """

  def __init__(self, size, inputs):
    super(PerfectIntegrator, self).__init__(size)

    # parameters
    self.tau = 12.5  # ms
    self.Vt = 1
    self.Vr = 0
    self.sigma = 0.27
    self.inputs = inputs

    # variables
    self.V = bm.Variable(bm.random.uniform(self.Vr, self.Vt, self.num))
    self.spike = bm.Variable(bm.zeros(self.num, dtype=bool))

    # functions
    f = lambda V, t: self.inputs.I / self.tau
    g = lambda V, t: self.sigma * (2 / self.tau) ** .5
    self.integral = bp.sdeint(f, g, method='euler')

  def update(self, tdi):
    self.inputs.update(tdi)
    self.V.value = self.integral(self.V, tdi.t, tdi.dt)
    self.spike.value = self.V >= self.Vt
    self.V.value = bm.where(self.spike, self.Vr, self.V)
[7]:
def figure6(sigma=0.):
  class Noise(ExternalInput):
    def __init__(self, num):
      super(Noise, self).__init__(num)
      # parameters
      self.p = bm.linspace(0., 1., num)
      # variables
      self.B = bm.Variable(bm.zeros(1))

    def update(self, tdi):
      self.B[:] = 30 * bm.sin(40 * bm.pi * tdi.t / 1000)  # (1,)
      self.I.value = 85 + 40 * (1 - self.p) + self.p * self.B  # (num,)

  num = 500
  model = LinearLeakyModel(num, inputs=Noise(num))
  model.sigma = sigma

  runner = bp.DSRunner(model, monitors=['inputs.B', 'spike'])
  runner.run(2000.)

  fig, gs = bp.visualize.get_figure(2, 1, 4, 12)
  fig.add_subplot(gs[0, 0])
  bp.visualize.line_plot(runner.mon.ts, runner.mon['inputs.B'], title='shared inputs')
  fig.add_subplot(gs[1, 0])
  bp.visualize.raster_plot(runner.mon.ts, runner.mon['spike'], ylabel='p')
  plt.yticks(bm.linspace(0, num, 5).numpy(), bm.linspace(0., 1., 5).numpy())
  plt.show()
[8]:
figure6(0.)
_images/others_Brette_Guigon_2003_spike_timing_reliability_9_2.png
[9]:
figure6(1.)
_images/others_Brette_Guigon_2003_spike_timing_reliability_10_1.png
[10]:
def figure7(sigma):
  class Noise(ExternalInput):
    def __init__(self, num):
      super(Noise, self).__init__(num)
      # parameters
      self.p = bm.linspace(0., 1., num)
      # variables
      self.B = bm.Variable(bm.zeros(1))

    def update(self, tdi):
      self.B[:] = 150 * bm.sin(40 * bm.pi * tdi.t / 1000)  # (1,)
      self.I.value = 150 + self.p * self.B  # (num,)

  num = 500
  model = NonlinearLeakyModel(num, inputs=Noise(num))
  model.sigma = sigma

  runner = bp.DSRunner(model, monitors=['inputs.B', 'spike'])
  runner.run(1000.)

  fig, gs = bp.visualize.get_figure(2, 1, 4, 8)
  fig.add_subplot(gs[0, 0])
  bp.visualize.line_plot(runner.mon.ts, runner.mon['inputs.B'], title='shared inputs')
  fig.add_subplot(gs[1, 0])
  bp.visualize.raster_plot(runner.mon.ts, runner.mon['spike'], ylabel='p')
  plt.yticks(bm.linspace(0, num, 5).numpy(), bm.linspace(0., 1., 5).numpy())
  plt.show()
[11]:
figure7(1.5)
_images/others_Brette_Guigon_2003_spike_timing_reliability_12_1.png
[12]:
def figure8(sigma):
  class Noise(ExternalInput):
    def __init__(self, num):
      super(Noise, self).__init__(num)
      # parameters
      self.p = bm.linspace(0., 1., num)
      # variables
      self.B = bm.Variable(bm.zeros(1))

    def update(self, tdi):
      self.B[:] = 150 * bm.sin(40 * bm.pi * tdi.t / 1000)  # (1,)
      self.I.value = 150 + self.p * self.B  # (num,)

  num = 500
  model = NonlinearLeakyModel(num, inputs=Noise(num))
  model.V.value = bm.random.uniform(model.Vr, model.Vt, model.num)
  model.sigma = sigma

  runner = bp.DSRunner(model, monitors=['inputs.B', 'spike'])
  runner.run(1000.)

  fig, gs = bp.visualize.get_figure(2, 1, 4, 8)
  fig.add_subplot(gs[0, 0])
  bp.visualize.line_plot(runner.mon.ts, runner.mon['inputs.B'], title='shared inputs')
  fig.add_subplot(gs[1, 0])
  bp.visualize.raster_plot(runner.mon.ts, runner.mon['spike'], ylabel='p')
  plt.yticks(bm.linspace(0, num, 5).numpy(), bm.linspace(0., 1., 5).numpy())
  plt.show()
[13]:
figure8(1.5)
_images/others_Brette_Guigon_2003_spike_timing_reliability_14_1.png
[14]:
def figure9(sigma):
  class Noise(ExternalInput):
    def __init__(self, num):
      super(Noise, self).__init__(num)
      # parameters
      self.p = bm.linspace(0., 1., num)
      # variables
      self.B = bm.Variable(bm.zeros(1))

    def update(self, tdi):
      self.B[:] = bm.sin(40 * bm.pi * tdi.t / 1000)  # (1,)
      self.I.value = 112.5 + (75 + 37.5 * self.p) * self.B + 37.5 * self.p  # (num,)

  num = 500
  # In the paper, B is a triangular wave taking values between -1 and 1, with
  # rising and falling times drawn uniformly between 10 ms and 50 ms.
  # Here a sine wave is used instead; a sketch of a triangular-wave input
  # follows this cell.
  model = LinearLeakyModel(num, inputs=Noise(num))
  model.sigma = sigma

  runner = bp.DSRunner(model, monitors=['inputs.B', 'spike'])
  runner.run(2000.)

  fig, gs = bp.visualize.get_figure(2, 1, 4, 12)
  fig.add_subplot(gs[0, 0])
  bp.visualize.line_plot(runner.mon.ts, runner.mon['inputs.B'], title='shared inputs')
  fig.add_subplot(gs[1, 0])
  bp.visualize.raster_plot(runner.mon.ts, runner.mon['spike'], ylabel='p')
  plt.yticks(bm.linspace(0, num, 5).numpy(), bm.linspace(0., 1., 5).numpy())
  plt.show()
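In the paper, the shared signal B of this figure is a triangular wave in \([-1, 1]\) whose rising and falling times are drawn uniformly between 10 ms and 50 ms; figure9 above approximates it with a sine wave. Below is a minimal NumPy sketch of such a triangular wave, not part of the original implementation; the helper name triangular_wave and feeding its values into a Noise-style input are assumptions for illustration.

import numpy as np

def triangular_wave(duration, dt, t_min=10., t_max=50., rng=None):
  # Piecewise-linear wave in [-1, 1]; each rising or falling segment lasts
  # a time drawn uniformly from [t_min, t_max] (in ms), as in the paper.
  rng = np.random.default_rng() if rng is None else rng
  num_steps = int(duration / dt)
  wave = np.empty(num_steps)
  i, value, direction = 0, -1.0, 1.0
  while i < num_steps:
    seg_len = max(1, int(rng.uniform(t_min, t_max) / dt))
    seg = value + direction * 2.0 * np.arange(1, seg_len + 1) / seg_len
    n = min(seg_len, num_steps - i)
    wave[i:i + n] = seg[:n]
    value += direction * 2.0
    direction = -direction
    i += n
  return wave

# e.g., a 2000 ms signal sampled at dt = 0.1 ms (BrainPy's default step),
# matching the simulation duration used in figure9 above
B_values = triangular_wave(2000., 0.1)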
[15]:
figure9(0.5)
_images/others_Brette_Guigon_2003_spike_timing_reliability_16_1.png
[16]:
def figure10():
  class Noise(ExternalInput):
    def __init__(self, num):
      super(Noise, self).__init__(num)
      # parameters
      self.tau = 10  # ms
      self.p = bm.linspace(0., 1., num)
      # variables
      self.x = bm.Variable(bm.zeros(1))
      self.B = bm.Variable(bm.zeros(1))
      # functions
      f = lambda x, t: -x / self.tau
      g = lambda x, t: bm.sqrt(2. / self.tau)
      self.integral = bp.sdeint(f=f, g=g)

    def update(self, tdi):
      self.x.value = self.integral(self.x, tdi.t, tdi.dt)  # (1,)
      self.B.value = 2. / (1 + bm.exp(-2 * self.x)) - 1  # (1,)
      self.I.value = 0.5 + 3 * self.p * self.B  # (num,)

  num = 500
  model = NonLeakyModel(num, inputs=Noise(num))

  runner = bp.DSRunner(model, monitors=['inputs.B', 'spike'])
  runner.run(1000.)

  fig, gs = bp.visualize.get_figure(2, 1, 4, 8)
  fig.add_subplot(gs[0, 0])
  bp.visualize.line_plot(runner.mon.ts, runner.mon['inputs.B'], title='shared inputs')
  fig.add_subplot(gs[1, 0])
  bp.visualize.raster_plot(runner.mon.ts, runner.mon['spike'], ylabel='p')
  plt.yticks(bm.linspace(0, num, 5).numpy(), bm.linspace(0., 1., 5).numpy())
  plt.show()
[17]:
def figure11(sigma):
  class Noise(ExternalInput):
    def __init__(self, num):
      super(Noise, self).__init__(num)
      # parameters
      self.tau = 10  # ms
      self.p = bm.linspace(0., 1., num)
      # variables
      self.x = bm.Variable(bm.zeros(1))
      self.B = bm.Variable(bm.zeros(1))
      # functions
      f = lambda x, t: -x / self.tau
      g = lambda x, t: bm.sqrt(2. / self.tau)
      self.integral = bp.sdeint(f=f, g=g)

    def update(self, tdi):
      self.x.value = self.integral(self.x, tdi.t, tdi.dt)  # (1,)
      self.B.value = 2. / (1 + bm.exp(-2 * self.x)) - 1  # (1,)
      self.I.value = 0.5 + self.p * self.B  # (num,)

  num = 500
  model = PerfectIntegrator(num, inputs=Noise(num))
  model.sigma = sigma

  runner = bp.DSRunner(model, monitors=['inputs.B', 'spike'])
  runner.run(1000.)

  fig, gs = bp.visualize.get_figure(2, 1, 4, 8)
  fig.add_subplot(gs[0, 0])
  bp.visualize.line_plot(runner.mon.ts, runner.mon['inputs.B'], title='shared inputs')
  fig.add_subplot(gs[1, 0])
  bp.visualize.raster_plot(runner.mon.ts, runner.mon['spike'], ylabel='p')
  plt.yticks(bm.linspace(0, num, 5).numpy(), bm.linspace(0., 1., 5).numpy())
  plt.show()
[18]:
figure11(0.)
_images/others_Brette_Guigon_2003_spike_timing_reliability_19_1.png
[19]:
figure11(0.27)
_images/others_Brette_Guigon_2003_spike_timing_reliability_20_1.png

Indices and tables