首页 > 解决方案 > 如何在多线程中使用 memory_buffer_alloc_init()?

问题描述

我有一个简单的代码用于对加密消息mbedtls执行操作。RSA decryption我正在尝试使用多线程同时执行相同的操作。我有兴趣使用stack memory而不是heap. mbedtls为基于堆栈的内存分配器提供memory_buffer_alloc.h。文件memory_buffer_alloc_init()说,

初始化使用基于堆栈的内存分配器。基于堆栈的分配器在提供的缓冲区内进行内存管理,并且不调用 malloc() 和 free()。它将全局 polarssl_malloc() 和 polarssl_free() 指针设置为它自己的函数。(如果定义了 POLARSSL_THREADING_C,则提供的 polarssl_malloc() 和 polarssl_free() 是线程安全的

因此,我在我的config.h文件中添加了以下配置,

#define POLARSSL_THREADING_PTHREAD
#define POLARSSL_THREADING_C
#define POLARSSL_MEMORY_C
#define POLARSSL_MEMORY_BUFFER_ALLOC_C
#define POLARSSL_PLATFORM_MEMORY
#define POLARSSL_PLATFORM_C

我的代码适用于单个线程。但是,当我增加线程数时,我的代码显示错误。以下是我的源代码,

#include "rsa/config.h"
#include "rsa/aes.h"
#include "rsa/bignum.h"
#include "rsa/rsa.h"
#include <sys/wait.h>
#include <sys/types.h>
#include <sys/stat.h> 
#include <fcntl.h>
#include <errno.h>
#include <stdlib.h>
#include <sys/resource.h>
#include <pthread.h>
#include <time.h>
#include <stdio.h>
#include <string.h>
#include <unistd.h>
#include "rsa/key.h"

#include "rsa/memory_buffer_alloc.h"
#include "rsa/memory.h"
#include "rsa/platform.h"
#include "rsa/threading.h"


#define NUM_OF_THREAD 2  

//threading_mutex_t lock;


void decryption(){

    // initialize stack memory
    unsigned char alloc_buf[10000];
    memory_buffer_alloc_init( &alloc_buf, sizeof(alloc_buf) );

    unsigned char private_encrypt[KEY_BUFFER_SIZE];
    int total_dec=5;
    unsigned char * buffer = 0;
    long length;
    unsigned char msg_decrypted[KEY_LEN];


    // reading encrypted msg
    FILE * fp2 = fopen ("msg.enc", "rb");
    int size1=KEY_BUFFER_SIZE;
    if(fp2){
        while(size1>0){
            fread(private_encrypt,1,sizeof (private_encrypt),fp2);
            size1=size1-1;
        }
    }
    fclose(fp2);

    // reading rsa-private key
    FILE * fp = fopen ("rsa_priv.txt", "rb");
    if (fp){
        fseek (fp, 0, SEEK_END);
        length = ftell (fp);
        fseek (fp, 0, SEEK_SET);
        buffer = calloc (1,length+1);
        if (buffer){
            fread (buffer, 1, length, fp);
        }
    fclose (fp);
    }

    // initialize rsaContext
    rsa_context rsaContext;
    rsa_init(&rsaContext,RSA_PKCS_V15, 0);
    rsaContext.len=KEY_LEN;

    // spliting keys and load into rsa context
    const char s[3] = "= ";
    char *token;
    int k=0, size;
    char *rest=buffer;

    // get the first token
    token = strtok_r(rest,s,&rest);

    // walk through other tokens
    while( token != NULL ) {
        size = strlen(token);
        switch (k) {
            case 1:
                token[size-1]='\0';
                mpi_read_string(&rsaContext.N, 16, token);
                break;

            case 3:
                token[size-1]='\0';
                mpi_read_string(&rsaContext.E, 16, token);
                break;

            case 5:
                token[size-1]='\0';
                mpi_read_string(&rsaContext.D, 16, token);
                break;

            case 7:
                token[size-1]='\0';
                mpi_read_string(&rsaContext.P, 16, token);
                break;

            case 9:
                token[size-1]='\0';
                mpi_read_string(&rsaContext.Q, 16, token);
                break;

            case 11:
                token[size-1]='\0';
                mpi_read_string(&rsaContext.DP, 16, token);
                break;

            case 13:
                token[size-1]='\0';
                mpi_read_string(&rsaContext.DQ, 16, token);
                break;

            case 15:
                token[size-1]='\0';
                mpi_read_string(&rsaContext.QP, 16, token);
                break;
        }
        k=k+1;
        token = strtok_r(rest, "= \n", &rest);
    }


    if( rsa_private(&rsaContext,private_encrypt, msg_decrypted) != 0 ) {
        printf( "Decryption failed! %d\n", rsa_private(&rsaContext,private_encrypt, msg_decrypted));
    }else{
        printf("Decrypted plaintext-----> %s\n",msg_decrypted );
    }

   // free memory 
   memory_buffer_alloc_free();

}


void thread_function(void * input){

    printf("Test thread\n");
    int total_loop=5;
    while(total_loop>0){
        //pthread_mutex_lock(&lock); <-- multi-thread works with lock
        decryption();
        //pthread_mutex_unlock(&lock); 
        total_loop--;           
    }
}


int main(){ 
    int i;
    
    // total number of thread
    pthread_t ths[NUM_OF_THREAD];

    for (i = 0; i < NUM_OF_THREAD; i++) {
        pthread_create(&ths[i], NULL, thread_function, NULL);
    }

    for (i = 0; i < NUM_OF_THREAD; i++) {
        void* res;
        pthread_join(ths[i], &res);
    }
    return 0;
}

如果我使用mutex,上面的代码有效。我不想使用锁。花了很长时间。谁能告诉我我做错了什么?我该如何解决?

以前,我使用堆内存,我问为什么多个 pthread_create() 调用同一个函数最终会出现分段错误?. 感谢@Ingo Leonhardt @AlexM 的快速解决方案。我能够使用堆内存使用多线程。
源代码可在此处获得:https ://github.com/AlamShariful/stackMemory_multithread

标签: cmultithreadingmbedtlspolarssl

解决方案


推荐阅读