c - 提高通用交换的性能
问题描述
语境
在 C 中实现适用于多种类型的泛型函数时,void*
经常使用。该libc
函数qsort()
是一个经典的例子。在内部qsort()
和许多其他算法都需要一个swap()
函数。
通用交换的一个简单但典型的实现如下所示:
void swap(void* x, void* y, size_t size) {
char t[size];
memcpy(t, x, size);
memcpy(x, y, size);
memcpy(y, t, size);
}
对于较大的类型,可以使用逐字节交换,或者malloc
这会很慢,但这里的重点是当这个泛型swap()
用于小类型时会发生什么。
更好的通用交换?
事实证明,如果我们匹配一些常见的类型大小(x86_64 上的 4 和 8 字节的 int 和 long)也包括 float、double、pointer 等,我们可以获得令人惊讶的性能提升:
void swap(void* x, void* y, size_t size) {
if (size == sizeof(int)) {
int t = *((int*)x);
*((int*)x) = *((int*)y);
*((int*)y) = t;
} else if (size == sizeof(long)) {
long t = *((long*)x);
*((long*)x) = *((long*)y);
*((long*)y) = t;
} else {
char t[size];
memcpy(t, x, size);
memcpy(x, y, size);
memcpy(y, t, size);
}
}
注意:这显然可以改进为使用#if
而不是if/else
更多类型。
在以下通用实现的上下文中,与更标准的仅顶部交换quicksort()
相比,上述交换为 10,000,000 个随机 int 排序提供了约 2 倍的性能改进。memcpy()
这是在 ubuntu 20.04 上使用 gcc-9 或 clang-10 和-O3
.
问题
这似乎是一个了不起的结果。
- 这是否违反任何标准?
- 任何人都可以验证这一点吗?
- 是什么让这种收益成为可能?它是简单地复制“更广泛的词”还是一些编译器优化/内联在起作用?
- 如果它确实有效,为什么还没有完成呢?或者是吗?
注意:我还没有检查生成的汇编代码。
#include <stdbool.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <time.h>
typedef bool (*cmp)(const void*, const void*);
bool cmp_ints_less(const void* a, const void* b) {
return *(const int*)a < *(const int*)b;
}
bool cmp_ints_greater(const void* a, const void* b) {
return *(const int*)a > *(const int*)b;
}
bool cmp_floats_less(const void* a, const void* b) {
return *(const float*)a < *(const float*)b;
}
bool cmp_floats_greater(const void* a, const void* b) {
return *(const float*)a > *(const float*)b;
}
bool cmp_doubles_less(const void* a, const void* b) {
return *(const double*)a < *(const double*)b;
}
bool cmp_doubles_greater(const void* a, const void* b) {
return *(const double*)a > *(const double*)b;
}
bool cmp_strs_less(const void* a, const void* b) {
return strcmp(*((const char**)a), *((const char**)b)) < 0;
}
bool cmp_strs_greater(const void* a, const void* b) {
return strcmp(*((const char**)a), *((const char**)b)) > 0;
}
void swap(void* x, void* y, size_t size) {
if (size == sizeof(int)) {
int t = *((int*)x);
*((int*)x) = *((int*)y);
*((int*)y) = t;
} else if (size == sizeof(long)) {
long t = *((long*)x);
*((long*)x) = *((long*)y);
*((long*)y) = t;
} else {
char t[size];
memcpy(t, x, size);
memcpy(x, y, size);
memcpy(y, t, size);
}
}
void* partition(void* start, void* end, size_t size, cmp predicate) {
if (start == NULL || end == NULL || start == end) return start;
char* storage = (char*)start;
char* last = (char*)end - size; // used as pivot
for (char* current = start; current != last; current += size) {
if (predicate(current, last)) {
swap(current, storage, size);
storage += size;
}
}
swap(storage, last, size);
return storage; // returns position of pivot
}
void quicksort(void* start, void* end, size_t size, cmp predicate) {
if (start == end) return;
void* middle = partition(start, end, size, predicate);
quicksort(start, middle, size, predicate);
quicksort((char*)middle + size, end, size, predicate);
}
void print(const int* start, int size) {
for (int i = 0; i < size; ++i) printf("%3d", start[i]);
printf("\n");
}
void rand_seed() {
int seed = 0;
FILE* fp = fopen("/dev/urandom", "re");
if (!fp) {
fprintf(stderr, "Warning: couldn't open source of randomness, falling back to time(NULL)");
srand(time(NULL));
return;
}
if (fread(&seed, sizeof(int), 1, fp) < 1) {
fprintf(stderr, "Warning: couldn't read random seed, falling back to time(NULL)");
fclose(fp);
srand(time(NULL));
return;
}
fclose(fp);
srand(seed); // nice seed for rand()
}
int rand_range(int start, int end) {
return start + rand() / (RAND_MAX / (end - start + 1) + 1);
}
int main() {
// int demo
rand_seed();
#define int_count 20
int* ints = malloc(int_count * sizeof(int));
if (!ints) {
fprintf(stderr, "couldn't allocate memory");
exit(EXIT_FAILURE);
}
for (int i = 0; i < int_count; ++i) ints[i] = rand_range(1, int_count / 2);
print(ints, int_count);
quicksort(ints, ints + int_count, sizeof(int), &cmp_ints_less);
print(ints, int_count);
free(ints);
// string demo
const char* strings[] = {
"material", "rare", "fade", "aloof", "way", "torpid",
"men", "purring", "abhorrent", "unpack", "zinc", "unsightly",
};
const int str_count = sizeof(strings) / sizeof(strings[0]);
quicksort(strings, strings + str_count, sizeof(char*), &cmp_strs_greater);
for (int i = 0; i < str_count; ++i) printf("%s\n", strings[i]);
// double demo
#define dbl_count 20
double doubles[dbl_count];
for (int i = 0; i < dbl_count; ++i) doubles[i] = rand() / (RAND_MAX / 100.0);
quicksort(doubles, doubles + dbl_count, sizeof(char*), &cmp_doubles_less);
for (int i = 0; i < dbl_count; ++i) printf("%20.16f\n", doubles[i]);
return EXIT_SUCCESS;
}
编辑:
仅供参考 Compiler Explorer 报告了替代通用的非常明显的以下程序集swap()
:
那里的样本main()
是:
int main() {
int two = 2;
int three = 3;
swap(&two, &three, sizeof(int));
swap2(&two, &three, sizeof(int));
return two - three;
}
下面是完整的汇编程序swap2()
,但值得注意的是编译器已内联swap2()
但不 swap()
包含对memcopy
. 这可能是一些(全部?)的区别?
swap2:
push rbp
mov rbp, rsp
push r14
mov r14, rdi
push r13
mov r13, rsi
push r12
push rbx
cmp rdx, 4
je .L9
mov r12, rdx
cmp rdx, 8
jne .L7
mov rax, QWORD PTR [rdi]
mov rdx, QWORD PTR [rsi]
mov QWORD PTR [rdi], rdx
mov QWORD PTR [rsi], rax
lea rsp, [rbp-32]
pop rbx
pop r12
pop r13
pop r14
pop rbp
ret
.L7:
lea rax, [rdx+15]
mov rbx, rsp
mov rsi, rdi
and rax, -16
sub rsp, rax
mov rdi, rsp
call memcpy
mov rdx, r12
mov rsi, r13
mov rdi, r14
call memcpy
mov rdx, r12
mov rsi, rsp
mov rdi, r13
call memcpy
mov rsp, rbx
lea rsp, [rbp-32]
pop rbx
pop r12
pop r13
pop r14
pop rbp
ret
.L9:
mov eax, DWORD PTR [rdi]
mov edx, DWORD PTR [rsi]
mov DWORD PTR [rdi], edx
mov DWORD PTR [rsi], eax
lea rsp, [rbp-32]
pop rbx
pop r12
pop r13
pop r14
pop rbp
ret
解决方案
这是否违反任何标准?
是的。
这是一个严格的别名违规,并且可能违反6.3.2.3 指针,第 7 段:“指向对象类型的指针可能会转换为指向不同对象类型的指针。如果生成的指针未正确对齐引用的类型,则行为未定义……”
推荐阅读
- hadoop - HiveAccessControlException 权限被拒绝:用户 [hive] 在 [hdfs://sandbox-....:8020/user/..] 上没有 [ALL] 权限(状态=42000,代码=40000)
- wso2 - WSO2IS 5.3.0 - 联合 SAML 和启用身份验证请求签名
- java - Jersey 在启动时失败并出现以下堆栈跟踪。我在 Cent Os 7.9 上,在 Tomcat 9 上运行 openjdk 8
- android - 从 https://dl.google.com/ 在 macos 上下载失败的原因可能是什么
- batch-file - 批处理文件过滤 ping 结果并附加 .csv
- c - 使用 DT_FILTER 隐藏 DSO 中的符号
- r - 当列表是动态的时,在列表中查找变量的特定项目
- cassandra - cassandra 购物车数据建模讲解
- python - 如何在每一行中添加制表符?
- sql-server - SQL Server - 映射表以识别要汇总的字段