python - 如何在 pycuda SourceModule 中生成高斯噪声?
问题描述
我正在尝试按照具有平均值和标准偏差的高斯定律生成随机数。现在我写这段代码。
import pycuda.driver as cuda
import pycuda.autoinit
from pycuda.compiler import SourceModule
import numpy as np
import matplotlib.pyplot as plt
import time
class GN:
def __init__(self, ):
self.NbCells = int(1024 * 100)
self.init_vectors()
self.Create_GPU_SourceModule()
BLOCK_SIZE = 1024
self.grid = (int(self.NbCells / BLOCK_SIZE), 1, 1)
self.block = (BLOCK_SIZE, 1, 1)
def put_vect_on_GPU(self, Variable):
Variable_gpu = cuda.mem_alloc(Variable.nbytes)
cuda.memcpy_htod(Variable_gpu, Variable)
return Variable_gpu
def init_vectors(self):
self.V = self.put_vect_on_GPU(np.zeros((self.NbCells), dtype=np.float32))
self.m = self.put_vect_on_GPU(np.ones((self.NbCells), dtype=np.float32) * 120)
self.s = self.put_vect_on_GPU(np.ones((self.NbCells), dtype=np.float32) * 60)
def Create_GPU_SourceModule(self): #
self.mod = SourceModule("""
#include <math.h>
#include <curand.h>
#include <cuda.h>
__global__ void randgauss( float *m, float *s, float *res)
{
int idx = threadIdx.x + blockDim.x * blockIdx.x;
int n=1;
curandGenerator_t gen ;
float d_normals;
curandCreateGenerator(&gen, CURAND_RNG_PSEUDO_MTGP32) ;
curandGenerateNormal(gen, &d_normals, n, m[idx], s[idx]);
res[idx] = d_normals;
}
""")
def updateParameters(self):
func = self.mod.get_function("sinus")
func(self.m, self.s, self.V, block=self.block, grid=self.grid)
def gen(self, N):
V = np.zeros((N, self.NbCells), dtype=np.float32)
for k in range(N):
self.updateParameters()
cuda.memcpy_dtoh(V[k, :], self.V)
return V
GN = GN()
t0 = time.time()
Vm = GN.gen(10000)
print('GPU', time.time() - t0)
plt.figure()
plt.subplot(111)
plt.plot(Vm[:, 0] ) # plt.plot(t,Vm[:,0::1000])
#
plt.show()
当我运行它时,我收到以下消息:
kernel.cu(13): error: calling a __host__ function("curandCreateGenerator") from a __global__ function("randgauss") is not allowed
kernel.cu(13): error: identifier "curandCreateGenerator" is undefined in device code
我不明白我应该如何curandGenerateNormal
正确使用该功能。
解决方案
推荐阅读
- reactjs - 为什么我在控制台中收到这样的警告,即标签按钮在此浏览器中无法识别
- ios - 架构 arm64 的未定义符号,似乎只是随机发生在某些使用 cocoapods 的第三方框架中
- sed - 如何使用 sed 从 CSV 文件中删除动态字符串?
- git - git中损坏的文件和树
- c# - 使用 .Net 中的 IAmsiStream 会导致 AccessViolationException
- javascript - 在子标题之间选择相邻的兄弟姐妹作为单独的组?
- angular - 使用 codepipeline 将 Angular 8 应用程序自动部署到 Elastic Beanstalk
- javascript - onMouseOver:防止桌面上的默认行为,同时保留在手机上
- java - 如何获取视频文件夹?
- android - 在 Xamarin Android 中为一组按钮设置动画