首页 > 解决方案 > 这个发布/订阅实现安全吗?

问题描述

我正在开发用于进程内通信的自定义发布/订阅系统。假设发布者想要向所有注册的订阅者发送一条数据(一条消息),并且所有这些实体都可能在不同的线程中执行。发布者和订阅者之间的每个连接都实现shared_ptr<const Message>.

因为我需要避免动态分配,所以每个发布者都预先分配了一个指向消息的共享指针池。当用户调用Publisher::publish(msg)时,我从池中获取消息(如果可用),将用户提供的消息复制为*msgptr_from_pool = msg,并将此类共享指针推送到连接到所有注册订阅者的所有队列。为了从池中获取消息共享指针,我检查了它的use_count. 如果是 1,则表示没有订阅者仍在使用(阅读)该消息,我可以安全地回收它。请注意,我不会从池中删除消息。简单地说,一旦它们被推送到队列中,它们的使用计数将 > 1。

在订阅者端,用户从队列中消费消息共享指针,但从不在任何地方存储消费的共享指针。用户提供了此类回调,但void(const Message& msg)从未真正看到底层共享指针。因此,一旦所有订阅者都完成了一条消息,它的使用计数将为 1。

我的问题是:考虑到标准 C++ 要求在多线程上下文中共享指针的保证,这种模式安全吗?更具体地说,我需要保证当我看到共享指针的 use count = 1 时,我可以安全地修改指向的消息,因为没有其他线程可以读取它了

编辑:根据要求,我添加了一个可以编译和运行的简化代码

完全可编译的简化示例(需要最新版本的 Boost.Lockfree、C++11 支持并且必须针对 pthread 链接)

#include <boost/lockfree/spsc_queue.hpp>
#include <memory>
#include <functional>
#include <thread>


template <typename Msg>
class Publisher
{

public:

    typedef std::shared_ptr<Msg> MsgPtr;
    typedef boost::lockfree::spsc_queue<MsgPtr> Queue;
    typedef std::shared_ptr<Queue> QueuePtr;

    Publisher(QueuePtr q):
        _q(q)
    {
        // some pool size
        const int pool_size = 10;

        // preallocate msg pool
        for(int i = 0; i < pool_size; i++)
        {
            _pool.push_back(std::make_shared<Msg>());
        }
    }

    bool publish(const Msg& m)
    {
        // find a msg ptr with use_count = 1
        auto it = std::find_if(_pool.begin(), _pool.end(),
                               [](const MsgPtr& mptr)
                               {
                                   return mptr.unique();
                               });

        // it does not exist..
        if(it == _pool.end())
        {
            return false;
        }

        // it exists, fill it with the provided message
        **it = m;

        // try to push to the queue
        return _q->push(*it);
    }

private:

    QueuePtr _q;

    std::vector<MsgPtr> _pool;

};


template <typename Msg>
class Subscriber
{

public:

    typedef std::shared_ptr<Msg> MsgPtr;
    typedef boost::lockfree::spsc_queue<MsgPtr> Queue;
    typedef std::shared_ptr<Queue> QueuePtr;
    typedef std::function<void(const Msg& m)> Callback;

    Subscriber(QueuePtr q, Callback cb):
        _q(q), _cb(cb)
    {

    }

    void consume()
    {
        _q->consume_all([this](const MsgPtr& mptr)
                        {
                            _cb(*mptr);
                        });
    }

private:

    QueuePtr _q;
    Callback _cb;

};


// test util to generate random strings
std::string random_string(std::string::size_type length);

/**
 * @brief The MsgType struct is the message type
 * to be exchanged between threads. Contains a random 
 * string and its hash, to kind of check for data corruption
 * due to wrong synchronization
 */
struct MsgType
{
    std::string message;
    size_t hash;

    MsgType()
    {
        message = random_string(256);
        hash = std::hash<std::string>()(message);
    }

    bool check() const
    {
        return std::hash<std::string>()(message) == hash;
    }

};

std::atomic_int cb_called(0);

// test callback
void on_msg_recv(const MsgType& msg)
{

    if(!msg.check()) // we received garbage, abort
    {
        abort();
    }

    cb_called++; // increment global counter
}

// test main
int main()
{
    using Queue = boost::lockfree::spsc_queue<std::shared_ptr<MsgType>>;
    const int queue_size = 10;


    auto queue = std::make_shared<Queue>(queue_size);

    std::atomic_bool run(1);
    std::atomic_int pub_msgs(0);

    auto producer = [&]()
    {
        Publisher<MsgType> p(queue);

        while(run)
        {
            MsgType m;
            pub_msgs += p.publish(m);
        }
    };

    auto consumer = [&]()
    {
        Subscriber<MsgType> s(queue, on_msg_recv);

        while(run)
        {
            s.consume();
        }
    };

    std::thread tp(producer);
    std::thread tc(consumer);

    std::this_thread::sleep_for(std::chrono::seconds(10));

    run = 0;

    tp.join();
    tc.join();

    printf("Message published: %d \n", pub_msgs.load());
    printf("Callback called    %d times \n", cb_called.load());

}

std::string random_string(std::string::size_type length)
{
    static auto& chrs = "0123456789"
                        "abcdefghijklmnopqrstuvwxyz"
                        "ABCDEFGHIJKLMNOPQRSTUVWXYZ";

    thread_local static std::mt19937 rg{std::random_device{}()};
    thread_local static std::uniform_int_distribution<std::string::size_type> pick(0, sizeof(chrs) - 2);

    std::string s;

    s.reserve(length);

    while(length--)
        s += chrs[pick(rg)];

    return s;
}

在我的系统上,输出是

Message published: 967752 
Callback called    967751 times

标签: c++multithreadingshared-ptr

解决方案


推荐阅读