首页 > 解决方案 > 提高通用交换的性能

问题描述

语境

在 C 中实现适用于多种类型的泛型函数时,void*经常使用。该libc函数qsort()是一个经典的例子。在内部qsort()和许多其他算法都需要一个swap()函数。

通用交换的一个简单但典型的实现如下所示:

void swap(void* x, void* y, size_t size) {
    char t[size];
    memcpy(t, x, size);
    memcpy(x, y, size);
    memcpy(y, t, size);
}

对于较大的类型,可以使用逐字节交换,或者malloc这会很慢,但这里的重点是当这个泛型swap()用于小类型时会发生什么。

更好的通用交换?

事实证明,如果我们匹配一些常见的类型大小(x86_64 上的 4 和 8 字节的 int 和 long)也包括 float、double、pointer 等,我们可以获得令人惊讶的性能提升:

void swap(void* x, void* y, size_t size) {
  if (size == sizeof(int)) {
    int t      = *((int*)x);
    *((int*)x) = *((int*)y);
    *((int*)y) = t;
  } else if (size == sizeof(long)) {
    long t      = *((long*)x);
    *((long*)x) = *((long*)y);
    *((long*)y) = t;
  } else {
    char t[size];
    memcpy(t, x, size);
    memcpy(x, y, size);
    memcpy(y, t, size);
  }
}

注意:这显然可以改进为使用#if而不是if/else更多类型。

在以下通用实现的上下文中,与更标准的仅顶部交换quicksort()相比,上述交换为 10,000,000 个随机 int 排序提供了约 2 倍的性能改进。memcpy()这是在 ubuntu 20.04 上使用 gcc-9 或 clang-10 和-O3.

问题

这似乎是一个了不起的结果。

注意:我还没有检查生成的汇编代码。

#include <stdbool.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <time.h>

typedef bool (*cmp)(const void*, const void*);

bool cmp_ints_less(const void* a, const void* b) {
  return *(const int*)a < *(const int*)b;
}

bool cmp_ints_greater(const void* a, const void* b) {
  return *(const int*)a > *(const int*)b;
}

bool cmp_floats_less(const void* a, const void* b) {
  return *(const float*)a < *(const float*)b;
}

bool cmp_floats_greater(const void* a, const void* b) {
  return *(const float*)a > *(const float*)b;
}

bool cmp_doubles_less(const void* a, const void* b) {
  return *(const double*)a < *(const double*)b;
}

bool cmp_doubles_greater(const void* a, const void* b) {
  return *(const double*)a > *(const double*)b;
}

bool cmp_strs_less(const void* a, const void* b) {
  return strcmp(*((const char**)a), *((const char**)b)) < 0;
}

bool cmp_strs_greater(const void* a, const void* b) {
  return strcmp(*((const char**)a), *((const char**)b)) > 0;
}

void swap(void* x, void* y, size_t size) {
  if (size == sizeof(int)) {
    int t      = *((int*)x);
    *((int*)x) = *((int*)y);
    *((int*)y) = t;
  } else if (size == sizeof(long)) {
    long t      = *((long*)x);
    *((long*)x) = *((long*)y);
    *((long*)y) = t;
  } else {
    char t[size];
    memcpy(t, x, size);
    memcpy(x, y, size);
    memcpy(y, t, size);
  }
}

void* partition(void* start, void* end, size_t size, cmp predicate) {
  if (start == NULL || end == NULL || start == end) return start;
  char* storage = (char*)start;
  char* last    = (char*)end - size; // used as pivot
  for (char* current = start; current != last; current += size) {
    if (predicate(current, last)) {
      swap(current, storage, size);
      storage += size;
    }
  }
  swap(storage, last, size);
  return storage; // returns position of pivot
}

void quicksort(void* start, void* end, size_t size, cmp predicate) {
  if (start == end) return;
  void* middle = partition(start, end, size, predicate);
  quicksort(start, middle, size, predicate);
  quicksort((char*)middle + size, end, size, predicate);
}

void print(const int* start, int size) {
  for (int i = 0; i < size; ++i) printf("%3d", start[i]);
  printf("\n");
}

void rand_seed() {
  int   seed = 0;
  FILE* fp   = fopen("/dev/urandom", "re");
  if (!fp) {
    fprintf(stderr, "Warning: couldn't open source of randomness, falling back to time(NULL)");
    srand(time(NULL));
    return;
  }
  if (fread(&seed, sizeof(int), 1, fp) < 1) {
    fprintf(stderr, "Warning: couldn't read random seed, falling back to time(NULL)");
    fclose(fp);
    srand(time(NULL));
    return;
  }
  fclose(fp);
  srand(seed); // nice seed for rand()
}

int rand_range(int start, int end) {
  return start + rand() / (RAND_MAX / (end - start + 1) + 1);
}

int main() {
  // int demo
  rand_seed();
#define int_count 20
  int* ints = malloc(int_count * sizeof(int));
  if (!ints) {
    fprintf(stderr, "couldn't allocate memory");
    exit(EXIT_FAILURE);
  }
  for (int i = 0; i < int_count; ++i) ints[i] = rand_range(1, int_count / 2);
  print(ints, int_count);
  quicksort(ints, ints + int_count, sizeof(int), &cmp_ints_less);
  print(ints, int_count);
  free(ints);

  // string demo
  const char* strings[] = {
      "material", "rare",    "fade",      "aloof",  "way",  "torpid",
      "men",      "purring", "abhorrent", "unpack", "zinc", "unsightly",
  };
  const int str_count = sizeof(strings) / sizeof(strings[0]);
  quicksort(strings, strings + str_count, sizeof(char*), &cmp_strs_greater);
  for (int i = 0; i < str_count; ++i) printf("%s\n", strings[i]);

// double demo
#define dbl_count 20
  double doubles[dbl_count];
  for (int i = 0; i < dbl_count; ++i) doubles[i] = rand() / (RAND_MAX / 100.0);
  quicksort(doubles, doubles + dbl_count, sizeof(char*), &cmp_doubles_less);
  for (int i = 0; i < dbl_count; ++i) printf("%20.16f\n", doubles[i]);

  return EXIT_SUCCESS;
}

编辑:

仅供参考 Compiler Explorer 报告了替代通用的非常明显的以下程序集swap()

https://godbolt.org/z/GhvsY4

那里的样本main()是:

int main() {
  int two = 2;
  int three = 3;

  swap(&two, &three, sizeof(int));
  swap2(&two, &three, sizeof(int));

  return two - three;
}

下面是完整的汇编程序swap2(),但值得注意的是编译器已内联swap2() swap()包含对memcopy. 这可能是一些(全部?)的区别?

swap2:
        push    rbp
        mov     rbp, rsp
        push    r14
        mov     r14, rdi
        push    r13
        mov     r13, rsi
        push    r12
        push    rbx
        cmp     rdx, 4
        je      .L9
        mov     r12, rdx
        cmp     rdx, 8
        jne     .L7
        mov     rax, QWORD PTR [rdi]
        mov     rdx, QWORD PTR [rsi]
        mov     QWORD PTR [rdi], rdx
        mov     QWORD PTR [rsi], rax
        lea     rsp, [rbp-32]
        pop     rbx
        pop     r12
        pop     r13
        pop     r14
        pop     rbp
        ret
.L7:
        lea     rax, [rdx+15]
        mov     rbx, rsp
        mov     rsi, rdi
        and     rax, -16
        sub     rsp, rax
        mov     rdi, rsp
        call    memcpy
        mov     rdx, r12
        mov     rsi, r13
        mov     rdi, r14
        call    memcpy
        mov     rdx, r12
        mov     rsi, rsp
        mov     rdi, r13
        call    memcpy
        mov     rsp, rbx
        lea     rsp, [rbp-32]
        pop     rbx
        pop     r12
        pop     r13
        pop     r14
        pop     rbp
        ret
.L9:
        mov     eax, DWORD PTR [rdi]
        mov     edx, DWORD PTR [rsi]
        mov     DWORD PTR [rdi], edx
        mov     DWORD PTR [rsi], eax
        lea     rsp, [rbp-32]
        pop     rbx
        pop     r12
        pop     r13
        pop     r14
        pop     rbp
        ret

标签: cgenericsswapmemcpy

解决方案


这是否违反任何标准?

是的。

这是一个严格的别名违规,并且可能违反6.3.2.3 指针,第 7 段:“指向对象类型的指针可能会转换为指向不同对象类型的指针。如果生成的指针未正确对齐引用的类型,则行为未定义……”


推荐阅读