同步请求响应的实现

声明:本文代码均为 C++ 伪代码,仅用于表示含义,出于描述简单,语法使用 C++11。

最近在接手的工作中发现了一段有点问题的代码,功能上大致是要实现这样的目标:客户端发送一个请求,然后在当前线程内等待服务器的响应,但由于底层使用的网络库是异步 Reactor 模式的,需要进行一个基本的同步操作来实现发送线程内接收响应。以伪代码的形式表现大概如下:

struct Packet {};

class Synchronizer {};

void SendRequest(Packet *_req)
{
    unsigned int threadId = ::GetCurrentThreadId();
    _req->threadId = threadId;
    network->SendData(target, _req);

    Synchronizer::GetInstance()->RegisterThread(threadId);
    Packet *res = Synchronizer::GetInstance()->WaitThreadData(threadId);
}

显然,WaitThreadData 的操作不能因为收不到对端响应就一直阻塞,需要引入超时等机制来应对对端崩溃或响应慢的情况,代码里直接利用了信号量所提供的参数来实现超时机制,可以满足条件。但是,向调用者返回超时并不意味着这个报文消失了,底层所用的网络库提供的是保证可靠的报文模式,所以这个报文最终一定会到来,因此,必须处理掉它以避免影响到后续请求。

然而,接手的代码在实现上对这个问题没有考虑全面,它大致的伪代码如下:

struct Packet
{
    unsigned int threadId;
};
struct ThreadLocalStorage
{
    Packet *data;
    Semaphore *semaphore;
};

class Synchronizer
{
    Mutex m_mutex;
    map<unsigned int, ThreaLocalStorage> m_threads;

public:
    void OnReceiveData(Packet *_data)
    {
        MutexLocker locker(m_mutex);

        auto iter = m_threads.find(_data->threadId);
        if (m_threads.end() == iter)
        {
            LOG(WARNING, "Bad data");
            delete _data;
            return;
        }

        // 检查残留
        auto &tls = iter->second;
        if (tls.data != nullptr)
        {
            delete tls.data;
            tls.data = _data;
        }

        POST_SEMAPHORE(tls.semaphore);
    }

    void RegisterThread(unsigned int _threadId)
    {
        MutexLocker locker(m_mutex);
        m_threads[_threadId];
    }

    Packet* WaitThreadData(unsigned int _threadId)
    {
        delctype(m_threads.find(_threadId) iter;
        {
            MutexLocker locker(m_mutex);
            iter = m_threads.find(_threadId);
            if (m_threads.end() == iter)
            {
                LOG(WARNING, "Unregistered thread");
                return nullptr;
            }
        }

        auto &tls = iter->second;
        if (ERROR == WAIT_SEMAPHORE(tls.semaphore, DEFAULT_TIMEOUT)) // 10000ms
        {
            // 超时,会导致本次响应残留在 TLS 中
            LOG(WARNING, "Timeout");
            return nullptr;
        }

        MutexLocker locker(m_mutex);
        Packet *data = tls.data;
        tls.data = nullptr;
        return data;
    }
};

稍微解释下代码:

  1. 因为 Synchronizer 会被多个线程使用,所以使用了 TLS 的机制来维护各自的信号量及接收到的数据。
  2. OnReceiveData 是提供给底层网络库的回调函数,接收到数据后按照报文中的线程 ID 选择合适的 TLS 进行存储,这里加入了一定的检查残留机制
  3. RegisterThread 其实没做什么,创建了该线程 ID 需要的 TLS。
  4. WaitThreadData 根据参数选择合适的 TLS,等待其中的信号量以同步异步的数据。

这样的实现基本满足了同步的要求,数据异步到来,存储并通知信号量,信号量 WAIT 返回后得到数据。但是,当发生 WAIT 超时的时候,该请求对应的响应将残留在网络中,并且最终一定会被 Synchronizer 接收并 POST,从而导致此响应有可能会作为下一个请求的响应,即发生了报文错位

因为发送和接收实际上是两个队列,这种问题发生的概率还挺大的,以两个连续请求(Req1 和 Req2)为例:

  1. 发送 Req1,由于网络原因没有收到 Res1,WaitThreadData 超时返回 nullptr;
  2. 发送 Req2,调用 WaitThreadData 开始等待 Res2;
  3. 此时 Res1 到来,OnReceiveData 存储到 TLS 中并 POST 信号量;
  4. Req2 调用的 WaitThreadData 中的 WAIT 返回,获取 TLS 中的数据,返回了 Res1;
  5. 错误:Res1 响应了 Req2。

显然,这种错误有较大概率会一直继续下去。即使上层加入检查(比如该层的序列号)抛弃掉错误的 Res,依旧没办法主动帮助错位问题恢复。

恢复完全看运气,一种自然恢复的情况(虽然我觉得概率不大):

  1. 前 3 步同上;
  2. Req2 的 WaitThreadData 所在线程虽然可能已经得到信号量的 POST,但由于 CPU 调度,还没有去访问 TLS 获取数据;
  3. Res2 到来,该线程恰巧被调度了 CPU,检查 TLS 时发现了残留(Res1),清理;
  4. Req2 的 WaitThreadData 所在线程得到 CPU 时间,访问 TLS,获取数据;
  5. 正确:Res2 响应了 Req2。

这种恢复有个副作用,信号量被 POST 了两次,却只有一次成功的 WAIT,从而导致其中的值增加了 1,但是,由于每次获取数据时都要调用 WaitThreadData,这个副作用不会影响到响应的正确性。

为了改进这个问题,需要在 WAIT 得到数据后进行检查,判断它是否是所需要的数据。而为了能够判断,我们可以给每个请求都增加上序列号,并要求响应在回复的数据中也设置对应的序列号。序列号由 TLS 维护,保证能够标识该线程内的各个请求即可,此处直接利用了 64 位无符号数来代表(这个数真的是大啊)。

最终版本如下:

struct Packet
{
    enum Type { ... };

    Type type;
    unsigned long long seqNum;
    unsigned int threadId;
};

struct ThreadLocalStorage
{
    Packet *data = nullptr;
    Semaphore *semaphore = nullptr;
    unsigned long long curSeqNum = 0;
};

class Synchronizer
{
    Mutex m_mutex;
    map<unsigned int, ThreaLocalStorage> m_threads;

public: 
    void OnReceiveData(Packet *_data)
    {
        MutexLocker locker(m_mutex);

        auto iter = m_threads.find(_data->threadId);
        if (m_threads.end() == iter)
        {
            LOG(WARNING, "Bad data");
            delete _data;
            return;
        }

        // 检查残留
        auto &tls = iter->second;
        if (tls.data != nullptr)
        {
            delete tls.data;
            tls.data = nullptr;
        }

        tls.data = _data;
        POST_SEMAPHORE(tls.semaphore);
    }

    unsigned long long RegisterThread(unsigned int _threadId)
    {
        MutexLocker locker(m_mutex);
        // 实际实现的时候,请务必注意此处,因为对一个不存在的 key 取下标的话,在 map 里会先进行 insert,
        // 随后返回的时候可能会涉及到拷贝,而 TLS 对象里如果有信号量这种句柄的话,拷贝就必须自己稍微
        // 操心一下了(必须手写拷贝构造和拷贝赋值函数来确定语义)。当然更通用地方式是使用指针作为值类型,
        // 并且将此处修改成 find 检查结合 insert 操作。
        auto &tls = m_threads[_threadId];
        return tls.curSeqNum++;
    }

    Packet* WaitThreadData(unsigned int _threadId, unsigned long long _seqNum, unsigned int _timeout = 10000)
    {
        delctype(m_threads.find(_threadId) iter;
        {
            MutexLocker locker(m_mutex);
            iter = m_threads.find(_threadId);
            if (m_threads.end() == iter)
            {
                LOG(WARNING, "Unregistered thread");
                return nullptr;
            }
        }

        auto &tls = iter->second;
        while (true)
        {
            if (ERROR == WAIT_SEMAPHORE(tls.semaphore, _timeout))
            {
                // 超时,会导致本次响应残留在 TLS 中
                LOG(WARNING, "Timeout");
                return nullptr;
            }

            MutexLocker locker(m_mutex);
            Packet *data = tls.data;
            tls.data = nullptr;
            if (nullptr == data) continue;
            // 正确的序列号
            if (data->seqNum == _seqNum) return data;

            // 序列号不相等,小于为正常情况,大于则为异常(因为此序列号是由发送端提供的)
            if (tls.data->seqNum > _seqNum)
            {
                LOG(ERROR, "Exception");
            }
            else
            {
                LOG(WARNING, "Wrong data");
            }
            delete data;
        }

        return nullptr;
    }
};

void SendRequest(Packet *_req)
{
    unsigned int threadId = ::GetCurrentThreadId();
    unsigned long long seqNum = Synchonizer::GetInstance()->RegisterThread(threadId);
    _req->threadId = threadId;
    _req->seqNum = seqNum;
    network->Send(_req);

    Packet *res = Synchonizer::GetInstance()->WaitThreadData(threadId, seqNum);
    if (res != nullptr) { ... }
}

解释:

  1. RegisterThread 返回此报文应该附带的序列号。
  2. WaitThreadData 参数中指定了序列号,在获取数据后检查了得到的数据是否满足要求,不满足则继续 WAIT 直至满足。同时,这个检查的过程也修复了之前提到的信号量值的副作用。
  3. WaitThreadData 接收数据后加入了空判断,这是由于信号量返回和锁获取存在间隔导致的一个必要检查。举个不检查可能引起问题的场景:连续两个请求 Req1 和 Req2 的响应被 Res1 和 Res2 被判定为超时,它们都仍在路上,此时 Req3 发起,Res1 到来,信号量的 WAIT 被通过,在 WaitThreadData 获取锁之前,Res2 到来,信号量再次变为可触发的,WaitThreadData 发现 Res2 的序列号不对,抛弃并置空,进入下一个循环,显然,此时的 WAIT 可直接通过,但 data 字段是空的。解决方法就是跳过空的情况,等待下一次触发,即 Res3 的到来。

小小的总结下,这次改进主要是为了解决掉错位问题,并且追求最小的改动量,目前也只想到了这样的方法。如果要类比一个常见词汇的话,感觉这个问题的需求有点像同步 RPC。

问了下实现过异步 RPC 的同学,其实也在请求中加入了 Request ID 这样的标识,从而在异步接收时才能判别该调用哪个回调方法。这期间我也猜了下浏览器可能的实现,觉得主流浏览器应该也是为每个 HTTP Request 分配相应标识(感觉可以用 SOCKET 句柄),在收到 Response 数据后,通过标识找到对应 Request,进而找到对应的 Tab,由 Tab 实例去负责处理数据。

如果这个问题有其他方式的实现可以解决的话,欢迎指教,谢谢!

有啥想说的就留个言呗~

comments powered by Disqus