首页 > 解决方案 > python中用于实值函数的FFT卷积运算符

问题描述

我正在尝试从设计的模型中观察一些行动潜力。它是一个时空偏微分方程。所以,我在两边都应用了空间傅里叶变换。然后,模型归结为一组时间方程。空间域中的乘法是时间域中的卷积。我没有尝试定义时间卷积,而是让 python 通过在必要的部分上应用 signal.fftconvolve 来为我做这件事。但是,python 很难对实值函数进行卷积。这是代码;

"""
Bio-Physical Model for Cellular Electrophysiology
"""
import math
from odeintw import odeintw
import numpy as np
import random
import scipy as sp
from scipy import signal
from sympy import DiracDelta
from scipy.fftpack import rfft, irfft, fftfreq, fft, ifft, fft2, ifft2
import pylab as plt
from scipy.integrate import odeint


class SpatioTemporal_4AP_V3():

    def __init__(self):
        # ------------------------------------------------------------
        """ion channel pars"""

        self.D_K = 1.96 #5
        '''Ref; https://www.brown.edu/Departments/Engineering/Courses/En123/Lectures/electroN.htm'''
        '''diffusion coefficients'''

        self.V0 = -100 # membrane potential
        self.K0 = 5 # potassium
        self.phi0 = -45 # firing threshold
        self.sD0 = 1 # synaptic depression
        X0 = np.asarray([self.V0, self.K0, self.phi0, self.sD0])
        '''initial guesses in Euclidean space'''

        X0_hat = fft(X0)
        self.V_hat0 = X0_hat[0]
        self.K_hat0 = X0_hat[1]
        self.phi_hat0 = X0_hat[2]
        self.s_hat0 = X0_hat[3]
        '''initial guesses in Fourier domain'''

        self.kB = 1.380649 * 1e-23
        '''Boltzmann coefficient'''

        self.Temp = 37
        '''absolute temperature'''

        self.gK = 0.5
        '''potassium conductance'''

        self.q = 1.602176634 * 1e-19
        '''proton charge'''

        self.z_Na = 1
        self.z_Ca = 2
        self.z_Mg = 2
        self.z_Cl = -1
        self.z_K = 1
        '''ion charges'''

        self.V_k0 = 26.6 * math.log10(3 / 130)
        '''Reversal potassium par initial'''

        # ------------------------------------------------------------
        """conductance model pars"""

        self.C = 100
        '''cell capacitance'''

        # ------------------------------------------------------------
        """input current pars"""

        self.G_syn = 5
        '''Post synaptic charge'''

        self.sigma = 25
        '''noise amount'''

        # ------------------------------------------------------------
        """Time scale"""

        self.dt = 0.01
        self.t = np.arange(0.0, 10.0, self.dt)
        """ The time to integrate over in s """

        # ------------------------------------------------------------
        """Synaptic depression"""

        self.tau_D = 2
        '''in s'''

        self.delta_sD = 0.01
        '''synaptic resource'''

        self.xD_0 = 1
        '''synaptic resource initial val'''

        # -----------------------------------------------------------
        """Mean firing rate pars"""

        self.v_max = 100
        """maximum firing  rate in Hz"""

        self.V_th = 25
        """threshold potential"""

        self.k_v = 20.0
        """gain in mV"""

        self.fmax = 200 # populations maximal firing rate
        self.phi0 = -45 # firing threshold
        self.delta_Phi = 0.3
        self.beta = 1.5
        self.tau_Phi = 100

        Mfr = np.asarray([self.v_max, self.V_th, self.k_v, self.fmax, self.phi0, self.delta_Phi, self.beta, self.tau_Phi]) # np.asarray([self.fmax, self.phi0, self.delta_Phi, self.beta, self.tau_Phi])
        Mfr_hat = fft(Mfr)

        self.v_max_hat = Mfr_hat[0]
        self.V_th_hat = Mfr_hat[1]
        self.k_v_hat = Mfr_hat[2]

        self.fmax_hat = Mfr_hat[3]
        self.phi0_hat = Mfr_hat[4]
        self.delta_Phi = Mfr_hat[5]
        self.beta_hat = Mfr_hat[6]
        self.tau_Phi = Mfr_hat[7]

        # ----------------------------------------------------------
        """Discretized domain"""

        self.Ch_ID = 356 # pick a node among the vectorized array
        self.node_xdim = 64
        self.node_ydim = 64
        self.node_dim = self.node_xdim * self.node_ydim
        '''number of nodes in both directions'''

        nodex, nodey = np.mgrid[0: self.node_xdim - 1: 64j, 0: self.node_ydim - 1: 64j]  # 2-D
        self.nodexy = np.vstack((nodex.flatten(), nodey.flatten())).transpose()
        '''whole chip set as stacked'''

    # ============================================================

    def GenerateNoise(self, time):
        '''
        time dependant noise model to be contributed on the external input
        '''

        noise = np.random.normal(0, 1)
        noise = noise * time / np.sqrt(self.dt)
        return noise

    # ============================================================

    # def r_Electrode(self):
    #
    #     '''electrodes are self generated'''
    #
    #     # Source current position; assumed to be on the point [0,0]
    #     r0 = np.asarray([np.zeros((1, 2))])
    #     # Rest of the electrodes
    #     r = np.asarray([np.asarray([random.sample(range(0, 63), 1), random.sample(range(0, 63), 1)]) for _ in range(4096)])
    #     #r = r.reshape(self.Region, 1, 2)
    #     r_diff = r - r0
    #
    #     Abs_ElectrodeDiff = self.AbsEucDist(r_diff)
    #
    #     return r, r0, Abs_ElectrodeDiff

    # ============================================================

    def AbsEucDist(self, x):
        '''absolute euclidean distance'''
        x_squared = np.power(x, 2)
        x_sqr_sum = np.sum(x_squared)
        return math.sqrt(x_sqr_sum)

    # ============================================================

    def chipconv(self):

        (rads, phi) = self.cart2pol(self.nodexy[0:self.node_dim][:, 0], self.nodexy[0:self.node_dim][:, 1])
        '''conversion from cartesian to polar coordinates to get radius and elevation'''

        return rads, phi

    # ============================================================

    def ext_input(self, f, s_D, time):
        '''
        external input current I
        '''

        noise = self.GenerateNoise(time)
        (rads, phi) = self.chipconv()
        s = np.random.lognormal(0, 1)
        dW = np.random.normal(0, 1) /np.sqrt(self.dt) / np.sqrt(self.node_dim)
        W = np.cumsum(dW)

        # fig, ax = plt.subplots(figsize=(16, 6))
        # fig.subplots_adjust(left=0.0625, right=0.95, wspace=0.1)
        #
        # ax.plot(self.t, W)
        # ax.set_xlabel('Time in s')
        # ax.set_ylabel('W(t)')
        # ax.set_title('Discretized Brownian path')
        # ax.grid()
        # plt.show()

        return self.G_syn * (signal.fftconvolve(f, s_D, mode='same', axes=None) - 0.5*fft(f)) + self.sigma * W * np.sqrt(2*np.pi)*DiracDelta(rads[self.Ch_ID])#noise * np.sqrt(2*np.pi)*DiracDelta(rads[self.Ch_ID])

    # ============================================================

    def firing_rate_slope_control(self, V, phi):
        '''
        firing slope control as a part of mean firing rate
        '''

        return V - self.beta*phi

    # ============================================================

    def mean_firing_rate(self, V, phi):
        '''
        mean firing rate
        '''

        nu = self.firing_rate_slope_control(V, phi)

        fr_in = 2. / (1 + np.exp(2 * (self.V_th_hat - V) / self.k_v_hat)) - 1
        fr = np.clip(fr_in, a_min=0, a_max=None)
        v = self.v_max_hat * fr

        return v #self.fmax/(1 + np.exp(-2*nu/(20*self.beta)))

    # ============================================================

    def Potasium_Rev(self, K_o):
        """
        Nernst equation
        """
        return self.V_k0*np.log(K_o/130.)

    # ============================================================

    def I_inj(self, t):
        """
        External Current
        """
        Idv = np.where(np.logical_and(0 < t, t < 1), 150.0, 0.0) + np.where(np.logical_and(10.0 < t, t < 11.0), 50.0, 0.0)
        Idv_hat = fft(Idv)
        return Idv_hat

    # ============================================================

    def cart2pol(self, x, y):
        '''
        conversion from cartesion coordinates into polar coordinates
        '''

        rad = np.sqrt(x ** 2 + y ** 2)
        phi = np.arctan2(y, x)
        return (rad, phi)

    # ============================================================

    # def fft_coord_tr(self):
    #     '''
    #     conversion from cartesion coordinates into Fourier domain
    #     '''
    #
    #     return fft(self.node)

   # =============================================================


    def setofpdes(self, X, time):

        """set of temporal PDEs defined on the Fourier domain"""

        V, K, phi, s = X
        '''projecting all onto the Fourier domain'''

        f = self.mean_firing_rate(V, phi)

        (rads, phi) = self.chipconv()
        '''conversion from cartesian to polar coordinates to get radius and elevation'''

        dVdt = -1/self.C * (signal.fftconvolve(K, V, mode='same', axes=None)) + 1/self.C*self.ext_input(f, s, time)
        '''temporal diffusion for the membrane potential'''

        dKdt = -self.D_K * np.power(rads[self.Ch_ID], 2) * (K_hat + 2*self.z_K*self.q/(self.kB*self.Temp)*signal.fftconvolve(K, V, mode='same', axes=None))
        '''ion channels'''

        dphidt = 1/self.tau_Phi * (fft(self.phi0 - phi) + self.delta_Phi * fft(f))
        '''firing threshold'''

        dsdt = 1/self.tau_D*fft(1 - s) - self.delta_sD * fft(s * f) #np.sqrt(2*np.pi)*DiracDelta(rads[self.Ch_ID])
        '''synaptic depression'''

        dVdt, dKdt, dphidt, dsdt = ifft(np.asarray([dVdt, dKdt, dphidt, dsdt]))

        return dVdt, dKdt, dphidt, dsdt

    # =============================================================

    def main(self):

        """main executer"""
        initial_real = [self.V0, self.K0, self.phi0, self.sD0]
        initial_Fourier = [self.V_hat0, self.K_hat0, self.phi_hat0, self.s_hat0]
        #X_hat = []

        time = self.t
        X_hat = odeint(self.setofpdes, initial_real, time)
        # for count in range(len(self.t)):
        #     X_hat.append(odeint(self.setofpdes, initial_Fourier, time)[count])
        '''the ode solver'''
        #X_hat = np.vstack(X_hat)
        V = X_hat[:, 0] # ifft(ifft(X_hat[:, 0]))
        K = X_hat[:, 1] # ifft(ifft(X_hat[:, 4]))
        phi = X_hat[:, 2] # ifft(ifft(X_hat[:, 5]))
        s = X_hat[:, 3] # ifft(ifft(X_hat[:, 6]))
        '''back Fourier projection'''

        Mfr = self.mean_firing_rate(V, phi)

        fig, ax = plt.subplots(3, 2, figsize=(16, 6))
        fig.subplots_adjust(left=0.0625, right=0.95, wspace=0.1)

        ax[0, 0].plot(time, V)
        #ax[0, 0].plot(time, V.imag, label='imaginary')
        ax[0, 0].set_xlabel('Time in s')
        ax[0, 0].set_ylabel('Amplitude')
        ax[0, 0].set_title('Membrane potential; V')
        #ax[0, 0].legend(loc='best')
        ax[0, 0].grid()

        ax[0, 1].plot(time, K, label='Potassium - K')
        ax[0, 1].legend(loc='best')
        ax[0, 1].set_title('Ion Channels')
        ax[0, 1].grid()

        ax[1, 0].plot(time, phi)
        ax[1, 0].set_title('Firing threshold')
        ax[1, 0].grid()

        ax[1, 1].plot(time, s)
        ax[1, 1].set_title('Synaptic depression')
        ax[1, 1].grid()

        ax[2, 0].plot(time, Mfr)
        ax[2, 0].set_title('Mean firing rate')
        ax[2, 0].grid()

        plt.show()


        return np.asarray([V, K, phi, s])


if __name__ == '__main__':
    runner = SpatioTemporal_4AP_V3()
    runner.main

标签: pythonfftconvolutiondifferential-equations

解决方案


推荐阅读