同步请求响应的实现
声明:本文代码均为 C++ 伪代码,仅用于表示含义,出于描述简单,语法使用 C++11。
最近在接手的工作中发现了一段有点问题的代码,功能上大致是要实现这样的目标:客户端发送一个请求,然后在当前线程内等待服务器的响应,但由于底层使用的网络库是异步 Reactor 模式的,需要进行一个基本的同步操作来实现发送线程内接收响应。以伪代码的形式表现大概如下:
struct Packet {};
class Synchronizer {};
void SendRequest(Packet *_req)
{
unsigned int threadId = ::GetCurrentThreadId();
_req->threadId = threadId;
network->SendData(target, _req);
Synchronizer::GetInstance()->RegisterThread(threadId);
Packet *res = Synchronizer::GetInstance()->WaitThreadData(threadId);
}
显然,WaitThreadData 的操作不能因为收不到对端响应就一直阻塞,需要引入超时等机制来应对对端崩溃或响应慢的情况,代码里直接利用了信号量所提供的参数来实现超时机制,可以满足条件。但是,向调用者返回超时并不意味着这个报文消失了,底层所用的网络库提供的是保证可靠的报文模式,所以这个报文最终一定会到来,因此,必须处理掉它以避免影响到后续请求。
然而,接手的代码在实现上对这个问题没有考虑全面,它大致的伪代码如下:
struct Packet
{
unsigned int threadId;
};
struct ThreadLocalStorage
{
Packet *data;
Semaphore *semaphore;
};
class Synchronizer
{
Mutex m_mutex;
map<unsigned int, ThreaLocalStorage> m_threads;
public:
void OnReceiveData(Packet *_data)
{
MutexLocker locker(m_mutex);
auto iter = m_threads.find(_data->threadId);
if (m_threads.end() == iter)
{
LOG(WARNING, "Bad data");
delete _data;
return;
}
// 检查残留
auto &tls = iter->second;
if (tls.data != nullptr)
{
delete tls.data;
tls.data = _data;
}
POST_SEMAPHORE(tls.semaphore);
}
void RegisterThread(unsigned int _threadId)
{
MutexLocker locker(m_mutex);
m_threads[_threadId];
}
Packet* WaitThreadData(unsigned int _threadId)
{
delctype(m_threads.find(_threadId) iter;
{
MutexLocker locker(m_mutex);
iter = m_threads.find(_threadId);
if (m_threads.end() == iter)
{
LOG(WARNING, "Unregistered thread");
return nullptr;
}
}
auto &tls = iter->second;
if (ERROR == WAIT_SEMAPHORE(tls.semaphore, DEFAULT_TIMEOUT)) // 10000ms
{
// 超时,会导致本次响应残留在 TLS 中
LOG(WARNING, "Timeout");
return nullptr;
}
MutexLocker locker(m_mutex);
Packet *data = tls.data;
tls.data = nullptr;
return data;
}
};
稍微解释下代码:
- 因为 Synchronizer 会被多个线程使用,所以使用了 TLS 的机制来维护各自的信号量及接收到的数据。
- OnReceiveData 是提供给底层网络库的回调函数,接收到数据后按照报文中的线程 ID 选择合适的 TLS 进行存储,这里加入了一定的检查残留机制。
- RegisterThread 其实没做什么,创建了该线程 ID 需要的 TLS。
- WaitThreadData 根据参数选择合适的 TLS,等待其中的信号量以同步异步的数据。
这样的实现基本满足了同步的要求,数据异步到来,存储并通知信号量,信号量 WAIT 返回后得到数据。但是,当发生 WAIT 超时的时候,该请求对应的响应将残留在网络中,并且最终一定会被 Synchronizer 接收并 POST,从而导致此响应有可能会作为下一个请求的响应,即发生了报文错位。
因为发送和接收实际上是两个队列,这种问题发生的概率还挺大的,以两个连续请求(Req1 和 Req2)为例:
- 发送 Req1,由于网络原因没有收到 Res1,WaitThreadData 超时返回 nullptr;
- 发送 Req2,调用 WaitThreadData 开始等待 Res2;
- 此时 Res1 到来,OnReceiveData 存储到 TLS 中并 POST 信号量;
- Req2 调用的 WaitThreadData 中的 WAIT 返回,获取 TLS 中的数据,返回了 Res1;
- 错误:Res1 响应了 Req2。
显然,这种错误有较大概率会一直继续下去。即使上层加入检查(比如该层的序列号)抛弃掉错误的 Res,依旧没办法主动帮助错位问题恢复。
恢复完全看运气,一种自然恢复的情况(虽然我觉得概率不大):
- 前 3 步同上;
- Req2 的 WaitThreadData 所在线程虽然可能已经得到信号量的 POST,但由于 CPU 调度,还没有去访问 TLS 获取数据;
- Res2 到来,该线程恰巧被调度了 CPU,检查 TLS 时发现了残留(Res1),清理;
- Req2 的 WaitThreadData 所在线程得到 CPU 时间,访问 TLS,获取数据;
- 正确:Res2 响应了 Req2。
这种恢复有个副作用,信号量被 POST 了两次,却只有一次成功的 WAIT,从而导致其中的值增加了 1,但是,由于每次获取数据时都要调用 WaitThreadData,这个副作用不会影响到响应的正确性。
为了改进这个问题,需要在 WAIT 得到数据后进行检查,判断它是否是所需要的数据。而为了能够判断,我们可以给每个请求都增加上序列号,并要求响应在回复的数据中也设置对应的序列号。序列号由 TLS 维护,保证能够标识该线程内的各个请求即可,此处直接利用了 64 位无符号数来代表(这个数真的是大啊)。
最终版本如下:
struct Packet
{
enum Type { ... };
Type type;
unsigned long long seqNum;
unsigned int threadId;
};
struct ThreadLocalStorage
{
Packet *data = nullptr;
Semaphore *semaphore = nullptr;
unsigned long long curSeqNum = 0;
};
class Synchronizer
{
Mutex m_mutex;
map<unsigned int, ThreaLocalStorage> m_threads;
public:
void OnReceiveData(Packet *_data)
{
MutexLocker locker(m_mutex);
auto iter = m_threads.find(_data->threadId);
if (m_threads.end() == iter)
{
LOG(WARNING, "Bad data");
delete _data;
return;
}
// 检查残留
auto &tls = iter->second;
if (tls.data != nullptr)
{
delete tls.data;
tls.data = nullptr;
}
tls.data = _data;
POST_SEMAPHORE(tls.semaphore);
}
unsigned long long RegisterThread(unsigned int _threadId)
{
MutexLocker locker(m_mutex);
// 实际实现的时候,请务必注意此处,因为对一个不存在的 key 取下标的话,在 map 里会先进行 insert,
// 随后返回的时候可能会涉及到拷贝,而 TLS 对象里如果有信号量这种句柄的话,拷贝就必须自己稍微
// 操心一下了(必须手写拷贝构造和拷贝赋值函数来确定语义)。当然更通用地方式是使用指针作为值类型,
// 并且将此处修改成 find 检查结合 insert 操作。
auto &tls = m_threads[_threadId];
return tls.curSeqNum++;
}
Packet* WaitThreadData(unsigned int _threadId, unsigned long long _seqNum, unsigned int _timeout = 10000)
{
delctype(m_threads.find(_threadId) iter;
{
MutexLocker locker(m_mutex);
iter = m_threads.find(_threadId);
if (m_threads.end() == iter)
{
LOG(WARNING, "Unregistered thread");
return nullptr;
}
}
auto &tls = iter->second;
while (true)
{
if (ERROR == WAIT_SEMAPHORE(tls.semaphore, _timeout))
{
// 超时,会导致本次响应残留在 TLS 中
LOG(WARNING, "Timeout");
return nullptr;
}
MutexLocker locker(m_mutex);
Packet *data = tls.data;
tls.data = nullptr;
if (nullptr == data) continue;
// 正确的序列号
if (data->seqNum == _seqNum) return data;
// 序列号不相等,小于为正常情况,大于则为异常(因为此序列号是由发送端提供的)
if (tls.data->seqNum > _seqNum)
{
LOG(ERROR, "Exception");
}
else
{
LOG(WARNING, "Wrong data");
}
delete data;
}
return nullptr;
}
};
void SendRequest(Packet *_req)
{
unsigned int threadId = ::GetCurrentThreadId();
unsigned long long seqNum = Synchonizer::GetInstance()->RegisterThread(threadId);
_req->threadId = threadId;
_req->seqNum = seqNum;
network->Send(_req);
Packet *res = Synchonizer::GetInstance()->WaitThreadData(threadId, seqNum);
if (res != nullptr) { ... }
}
解释:
- RegisterThread 返回此报文应该附带的序列号。
- WaitThreadData 参数中指定了序列号,在获取数据后检查了得到的数据是否满足要求,不满足则继续 WAIT 直至满足。同时,这个检查的过程也修复了之前提到的信号量值的副作用。
- WaitThreadData 接收数据后加入了空判断,这是由于信号量返回和锁获取存在间隔导致的一个必要检查。举个不检查可能引起问题的场景:连续两个请求 Req1 和 Req2 的响应被 Res1 和 Res2 被判定为超时,它们都仍在路上,此时 Req3 发起,Res1 到来,信号量的 WAIT 被通过,在 WaitThreadData 获取锁之前,Res2 到来,信号量再次变为可触发的,WaitThreadData 发现 Res2 的序列号不对,抛弃并置空,进入下一个循环,显然,此时的 WAIT 可直接通过,但 data 字段是空的。解决方法就是跳过空的情况,等待下一次触发,即 Res3 的到来。
小小的总结下,这次改进主要是为了解决掉错位问题,并且追求最小的改动量,目前也只想到了这样的方法。如果要类比一个常见词汇的话,感觉这个问题的需求有点像同步 RPC。
问了下实现过异步 RPC 的同学,其实也在请求中加入了 Request ID 这样的标识,从而在异步接收时才能判别该调用哪个回调方法。这期间我也猜了下浏览器可能的实现,觉得主流浏览器应该也是为每个 HTTP Request 分配相应标识(感觉可以用 SOCKET 句柄),在收到 Response 数据后,通过标识找到对应 Request,进而找到对应的 Tab,由 Tab 实例去负责处理数据。
如果这个问题有其他方式的实现可以解决的话,欢迎指教,谢谢!