c++ - 如何知道推力::partition_copy 的结果中有多少元素
问题描述
我正在尝试使用推力库的 partition_copy 函数对数组进行分区。
我见过传递指针的例子,但我需要知道每个分区中有多少元素。
我尝试将设备向量作为 OutputIterator 参数传递,如下所示:
#include <thrust/device_vector.h>
#include <thrust/device_ptr.h>
#include <thrust/partition.h>
struct is_even {
__host__ __device__ bool operator()(const int &x) {
return (x % 2) == 0;
}
};
int N;
int *d_data;
cudaMalloc(&d_data, N*sizeof(int));
//... Some data is put in the d_data array
thrust::device_ptr<int> dptr_data(d_data);
thrust::device_vector<int> out_true(N);
thrust::device_vector<int> out_false(N);
thrust::partition_copy(dptr_data, dptr_data + N, out_true, out_false, is_even());
当我尝试编译时,出现此错误:
error: class "thrust::iterator_system<thrust::device_vector<int, thrust::device_allocator<int>>>" has no member "type"
detected during instantiation of "thrust::pair<OutputIterator1, OutputIterator2> thrust::partition_copy(InputIterator, InputIterator, OutputIterator1, OutputIterator2, Predicate) [with InputIterator=thrust::device_ptr<int>, OutputIterator1=thrust::device_vector<int, thrust::device_allocator<int>>, OutputIterator2=thrust::device_vector<int, thrust::device_allocator<int>>, Predicate=leq]"
所以我的问题是:如何使用推力::分区或推力::分区复制并知道每个分区中有多少元素?
解决方案
您的编译错误是由于您在此处传递向量而不是迭代器:
thrust::partition_copy(dptr_data, dptr_data + N, out_true, out_false, is_even());
^^^^^^^^^^^^^^^^^^^
相反,您应该基于这些容器传递迭代器:
thrust::partition_copy(dptr_data, dptr_data + N, out_true.begin(), out_false.begin(), is_even());
为了得到结果的长度,我们必须使用thrust::partition copy()的返回值:
返回一对 p 使得 p.first 是从 out_true 开始的输出范围的结尾,而 p.second 是从 out_false 开始的输出范围的结尾。
像这样的东西:
auto r = thrust::partition_copy(dptr_data, dptr_data + N, out_true.begin(), out_false.begin(), is_even());
int length_true = r.first - out_true.begin();
int length_false = r.second - out_false.begin();
请注意,类似的方法可以与其他推力算法一起使用。那些不返回元组的将更容易使用。
例如:
auto length = (thrust::remove_if(A.begin(), A.end(), ...) - A.begin());
推荐阅读
- html - 用 Beautiful Soup 将脚本刮成 Html
- python - 在测试期间模拟一个 dnspython dns 查询
- amazon-web-services - AWS CloudFormation 堆栈卡在状态 UPDATE_ROLLBACK_IN_PROGRESS
- c++ - 为什么 C++ 将大数四舍五入到 ceil 并将小数四舍五入到 floor
- qt - 如何在自定义组件中声明 EventHandler(插槽?)类型的属性并在其使用中为其分配函数
- flutter - W/Choreographer(11277):未来的帧时间是 13.988632 毫秒!检查图形 HAL 是否使用正确的时基生成 vsync 时间戳
- android - 在这个递归方法中返回什么?
- selenium - Azure DevOps VSTest (Selenium) 超时任务失败
- python-3.x - 是否可以将 cli 输出(字符串)转换为 Python 3 中的字典
- haskell - 编程语言中参数化多态函数(不是临时多态)操作的全部空间是多少?