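Gap-Junction-Coupled FitzHugh-Nagumo Model

This notebook simulates small networks of FitzHugh-Nagumo units coupled by gap junctions and uses BrainPy's SlowPointFinder to locate their fixed points and analyse their Jacobians in 4D and 8D state spaces.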

Open in Colab | Open in Kaggle

[1]:
import brainpy as bp
import brainpy.math as bm

# To force a specific backend (e.g. CPU only), call: bp.math.set_platform('cpu')
# Use 64-bit floats (presumably so the very small fixed-point loss tolerances used below are resolvable).
bp.math.enable_x64()
# Seed BrainPy's global RNG so the random initial states and fixed-point candidates are reproducible.
bm.random.seed(42)
bp.__version__
[1]:
'2.4.3'
[2]:
class GJCoupledFHN(bp.DynamicalSystem):
  r"""FitzHugh-Nagumo units coupled to each other by uniform gap junctions.

  Each unit ``i`` evolves as

  .. math::

     \dot V_i = V_i - V_i^3 / 3 - w_i + I_{ext,i} + g \sum_j (V_j - V_i)

     \tau \dot w_i = V_i + a - b w_i

  where ``g`` is the gap-junction strength ``gjw``.

  Parameters
  ----------
  num : int
    Number of coupled units; the full state (V, w) has ``2 * num`` dimensions.
  method : str
    Integration method forwarded to ``bp.odeint``.
  a, b, tau : float
    Recovery-variable parameters (defaults 0.7, 0.8, 12.5).
  gjw : float
    Gap-junction coupling strength (default 1e-4). It may also be reassigned
    after construction, e.g. ``model.gjw = 0.1``.
  """

  def __init__(self, num=2, method='exp_auto', a=0.7, b=0.8, tau=12.5, gjw=0.0001):
    super().__init__()

    # parameters
    self.num = num
    self.a = a
    self.b = b
    self.tau = tau
    self.gjw = gjw

    # state: membrane potential V and recovery variable w, drawn uniformly between -2 and 2
    self.V = bm.Variable(bm.random.uniform(-2, 2, num))
    self.w = bm.Variable(bm.random.uniform(-2, 2, num))
    # external input current, one entry per unit
    self.Iext = bm.Variable(bm.zeros(num))

    # numerical integrators for the two state equations
    self.int_V = bp.odeint(self.dV, method=method)
    self.int_w = bp.odeint(self.dw, method=method)

  def dV(self, V, t, w, Iext=0.):
    """Right-hand side of the V equation, including the gap-junction current."""
    # Unit i receives gjw * sum_j (V_j - V_i) = gjw * (sum_j V_j - N * V_i).
    # This O(N) form equals summing the N x N pairwise-difference matrix, and it
    # reduces over the last axis so batched states of shape (..., num) also work.
    gj = self.gjw * (V.sum(axis=-1, keepdims=True) - V.shape[-1] * V)
    dV = V - V * V * V / 3 - w + Iext + gj
    return dV

  def dw(self, w, t, V):
    """Right-hand side of the recovery-variable equation."""
    dw = (V + self.a - self.b * w) / self.tau
    return dw

  def update(self):
    """Advance V and w by one step ``dt`` using the shared time arguments."""
    tdi = bp.share.get_shargs()
    # Snapshot V(t) so both equations are integrated from the same state; otherwise
    # w would be advanced with the already-updated V(t + dt).
    V_t = self.V.value
    self.V.value = self.int_V(V_t, tdi.t, self.w, self.Iext, tdi.dt)
    self.w.value = self.int_w(self.w, tdi.t, V_t, tdi.dt)
[3]:
def analyze_net(num=2, gjw=0.01, Iext=bm.asarray([0., 0.6]), duration=300., seed=None):
    """Simulate a gap-junction-coupled FHN network, then find and analyse its fixed points.

    The network is first simulated for ``duration`` time units and the
    membrane potentials of all units are plotted. A ``SlowPointFinder`` then
    runs gradient descent from 1000 random (V, w) candidates. It keeps those
    whose squared speed is below 1e-8 and removes duplicates. It prints the
    fixed points with their losses and plots the Jacobian analysis.

    Parameters
    ----------
    num : int
        Number of coupled units (the state is ``2 * num``-dimensional).
    gjw : float
        Gap-junction coupling strength assigned to the model.
    Iext : int, float or sequence of length ``num``
        External current per unit. The default has length 2 and therefore
        only matches ``num=2``.
    duration : float
        Simulation length passed to ``DSRunner.run``.
    seed : int, optional
        If given, seeds BrainPy's global RNG first. This makes the random
        initial state and fixed-point candidates reproducible.

    Raises
    ------
    ValueError
        If ``Iext`` is not a scalar and its length differs from ``num``.
    """
    # Raise instead of `assert`, which is stripped when Python runs with -O.
    if not isinstance(Iext, (int, float)) and len(Iext) != num:
        raise ValueError(f'Iext must be a scalar or have length num={num}, got length {len(Iext)}.')

    if seed is not None:
        bm.random.seed(seed)

    model = GJCoupledFHN(num)
    model.gjw = gjw
    model.Iext[:] = Iext

    # simulation
    runner = bp.DSRunner(model, monitors=['V'])
    runner.run(duration)
    bp.visualize.line_plot(runner.mon.ts, runner.mon.V, legend='V',
                           plot_ids=list(range(model.num)), show=True)

    # fixed-point analysis
    finder = bp.analysis.SlowPointFinder(f_cell=model,
                                         target_vars={'V': model.V, 'w': model.w})
    finder.find_fps_with_gd_method(
      candidates={'V': bm.random.normal(0., 2., (1000, model.num)),
                  'w': bm.random.normal(0., 2., (1000, model.num))},
      tolerance=1e-5,
      num_batch=200,
      optimizer=bp.optim.Adam(lr=bp.optim.ExponentialDecay(0.05, 1, 0.9999)),
    )
    finder.filter_loss(1e-8)
    # Nothing to de-duplicate or linearise if no candidate converged tightly enough.
    if len(finder.fixed_points['V']) == 0:
        print('No fixed points survived the loss filter; skipping Jacobian analysis.')
        return
    finder.keep_unique()

    print('fixed_points: ', finder.fixed_points)
    print('losses:', finder.losses)

    _ = finder.compute_jacobians(finder.fixed_points, plot=True)
No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)

4D system: two coupled units, each with state variables V and w

[4]:
# Two units (4D state); only the second unit receives an external current of 0.6.
analyze_net(Iext=bm.asarray([0., 0.6]), gjw=0.1, num=2)
../_images/dynamics_analysis_highdim_gj_coupled_fhn_5_1.png
Optimizing with Adam(lr=ExponentialDecay(0.05, decay_steps=1, decay_rate=0.9999), last_call=-1), beta1=0.9, beta2=0.999, eps=1e-08) to find fixed points:
    Batches 1-200 in 0.81 sec, Training loss 0.0002788616
    Batches 201-400 in 0.48 sec, Training loss 0.0002076722
    Batches 401-600 in 0.58 sec, Training loss 0.0001604245
    Batches 601-800 in 0.52 sec, Training loss 0.0001257214
    Batches 801-1000 in 0.55 sec, Training loss 0.0000993401
    Batches 1001-1200 in 0.25 sec, Training loss 0.0000791463
    Batches 1201-1400 in 0.64 sec, Training loss 0.0000634978
    Batches 1401-1600 in 0.35 sec, Training loss 0.0000512977
    Batches 1601-1800 in 0.57 sec, Training loss 0.0000417576
    Batches 1801-2000 in 0.41 sec, Training loss 0.0000342319
    Batches 2001-2200 in 0.50 sec, Training loss 0.0000282288
    Batches 2201-2400 in 0.59 sec, Training loss 0.0000234155
    Batches 2401-2600 in 0.50 sec, Training loss 0.0000195315
    Batches 2601-2800 in 0.58 sec, Training loss 0.0000164139
    Batches 2801-3000 in 0.40 sec, Training loss 0.0000138706
    Batches 3001-3200 in 0.59 sec, Training loss 0.0000118313
    Batches 3201-3400 in 0.33 sec, Training loss 0.0000101996
    Batches 3401-3600 in 0.72 sec, Training loss 0.0000088474
    Stop optimization as mean training loss 0.0000088474 is below tolerance 0.0000100000.
Excluding fixed points with squared speed above tolerance 1e-08:
    Kept 372/1000 fixed points with tolerance under 1e-08.
Excluding non-unique fixed points:
    Kept 1/372 unique fixed points with uniqueness tolerance 0.025.
fixed_points:  {'V': array([[-1.17315378, -0.73803683]]), 'w': array([[-0.59144223, -0.04754604]])}
losses: [1.781621e-27]
C:\Users\adadu\miniconda3\envs\brainpy\lib\site-packages\jax\_src\numpy\array_methods.py:329: FutureWarning: The arr.split() method is deprecated. Use jax.numpy.split instead.
  warnings.warn(
../_images/dynamics_analysis_highdim_gj_coupled_fhn_5_4.png
[5]:
# Same two-unit network, but the second unit is driven by a weaker current of 0.1.
analyze_net(Iext=bm.asarray([0., 0.1]), gjw=0.1, num=2)
../_images/dynamics_analysis_highdim_gj_coupled_fhn_6_1.png
Optimizing with Adam(lr=ExponentialDecay(0.05, decay_steps=1, decay_rate=0.9999), last_call=-1), beta1=0.9, beta2=0.999, eps=1e-08) to find fixed points:
    Batches 1-200 in 0.42 sec, Training loss 0.0002744392
    Batches 201-400 in 0.63 sec, Training loss 0.0002038559
    Batches 401-600 in 0.27 sec, Training loss 0.0001579408
    Batches 601-800 in 0.71 sec, Training loss 0.0001249296
    Batches 801-1000 in 0.34 sec, Training loss 0.0001004665
    Batches 1001-1200 in 0.60 sec, Training loss 0.0000817813
    Batches 1201-1400 in 0.45 sec, Training loss 0.0000672077
    Batches 1401-1600 in 0.48 sec, Training loss 0.0000557249
    Batches 1601-1800 in 0.48 sec, Training loss 0.0000465356
    Batches 1801-2000 in 0.45 sec, Training loss 0.0000391051
    Batches 2001-2200 in 0.58 sec, Training loss 0.0000330272
    Batches 2201-2400 in 0.46 sec, Training loss 0.0000280838
    Batches 2401-2600 in 0.61 sec, Training loss 0.0000240266
    Batches 2601-2800 in 0.36 sec, Training loss 0.0000206586
    Batches 2801-3000 in 0.64 sec, Training loss 0.0000178943
    Batches 3001-3200 in 0.28 sec, Training loss 0.0000155499
    Batches 3201-3400 in 0.65 sec, Training loss 0.0000135563
    Batches 3401-3600 in 0.29 sec, Training loss 0.0000119020
    Batches 3601-3800 in 0.63 sec, Training loss 0.0000105507
    Batches 3801-4000 in 0.39 sec, Training loss 0.0000094129
    Stop optimization as mean training loss 0.0000094129 is below tolerance 0.0000100000.
Excluding fixed points with squared speed above tolerance 1e-08:
    Kept 535/1000 fixed points with tolerance under 1e-08.
Excluding non-unique fixed points:
    Kept 1/535 unique fixed points with uniqueness tolerance 0.025.
fixed_points:  {'V': array([[-1.19613886, -1.14106509]]), 'w': array([[-0.62017365, -0.55133897]])}
losses: [1.17769156e-15]
../_images/dynamics_analysis_highdim_gj_coupled_fhn_6_3.png

8D system: four coupled units, each with state variables V and w

[6]:
# Four units (8D state); only the last unit receives an external current of 0.6.
analyze_net(Iext=bm.asarray([0., 0., 0., 0.6]), gjw=0.1, num=4)
../_images/dynamics_analysis_highdim_gj_coupled_fhn_8_1.png
Optimizing with Adam(lr=ExponentialDecay(0.05, decay_steps=1, decay_rate=0.9999), last_call=-1), beta1=0.9, beta2=0.999, eps=1e-08) to find fixed points:
    Batches 1-200 in 0.41 sec, Training loss 0.0002981253
    Batches 201-400 in 0.20 sec, Training loss 0.0002175304
    Batches 401-600 in 0.21 sec, Training loss 0.0001681047
    Batches 601-800 in 0.22 sec, Training loss 0.0001327457
    Batches 801-1000 in 0.24 sec, Training loss 0.0001062666
    Batches 1001-1200 in 0.20 sec, Training loss 0.0000860732
    Batches 1201-1400 in 0.20 sec, Training loss 0.0000704338
    Batches 1401-1600 in 0.21 sec, Training loss 0.0000581533
    Batches 1601-1800 in 0.22 sec, Training loss 0.0000483691
    Batches 1801-2000 in 0.22 sec, Training loss 0.0000405031
    Batches 2001-2200 in 0.22 sec, Training loss 0.0000341309
    Batches 2201-2400 in 0.22 sec, Training loss 0.0000289916
    Batches 2401-2600 in 0.22 sec, Training loss 0.0000248071
    Batches 2601-2800 in 0.22 sec, Training loss 0.0000213434
    Batches 2801-3000 in 0.22 sec, Training loss 0.0000184814
    Batches 3001-3200 in 0.23 sec, Training loss 0.0000161170
    Batches 3201-3400 in 0.29 sec, Training loss 0.0000141462
    Batches 3401-3600 in 0.20 sec, Training loss 0.0000125133
    Batches 3601-3800 in 0.20 sec, Training loss 0.0000111212
    Batches 3801-4000 in 0.20 sec, Training loss 0.0000099064
    Stop optimization as mean training loss 0.0000099064 is below tolerance 0.0000100000.
Excluding fixed points with squared speed above tolerance 1e-08:
    Kept 313/1000 fixed points with tolerance under 1e-08.
Excluding non-unique fixed points:
    Kept 2/313 unique fixed points with uniqueness tolerance 0.025.
fixed_points:  {'V': array([[-1.17876245, -1.17875408, -1.1875546 , -0.81695988],
       [-1.17414082, -1.17455959, -1.17427155, -0.79030303]]), 'w': array([[-0.5975094 , -0.59751575, -0.59054519, -0.14463357],
       [-0.59622882, -0.59590929, -0.59613388, -0.14083089]])}
losses: [3.68611836e-09 8.30820343e-09]
../_images/dynamics_analysis_highdim_gj_coupled_fhn_8_3.png