Commit d57c3eee authored by John Selbie's avatar John Selbie

Support multithreading with SO_REUSEPORT option and new threading model

parent 7b74a63a
......@@ -295,8 +295,17 @@ HRESULT CStunSocket::InitCommon(int socktype, const CSocketAddress& addrlocal, S
if (fSetReuseFlag)
{
int socket_option_reuse = SO_REUSEADDR;
// for now, just do SO_REUSEPORT on the UDP thread
// There's still some validation and we need to do on the TCP side to decide how to enable threading
#ifdef SO_REUSEPORT
if (socktype == SOCK_DGRAM)
{
socket_option_reuse = SO_REUSEPORT;
}
#endif
int fAllow = 1;
ret = ::setsockopt(sock, SOL_SOCKET, SO_REUSEADDR, &fAllow, sizeof(fAllow));
ret = ::setsockopt(sock, SOL_SOCKET, socket_option_reuse, &fAllow, sizeof(fAllow));
ChkIf(ret == -1, ERRNOHR);
}
......@@ -329,4 +338,19 @@ HRESULT CStunSocket::TCPInit(const CSocketAddress& local, SocketRole role, bool
return InitCommon(SOCK_STREAM, local, role, fSetReuseFlag);
}
HRESULT CStunSocket::SetRecvTimeout(int milliseconds)
{
HRESULT hr = S_OK;
timeval tv = {};
int result = 0;
ChkIfA(_sock == -1, E_UNEXPECTED);
tv.tv_sec = milliseconds / 1000;
tv.tv_usec = (milliseconds % 1000) * 1000;
result = ::setsockopt(_sock, SOL_SOCKET, SO_RCVTIMEO, &tv, sizeof(tv));
ChkIf(result == -1, ERRNOHR);
hr = S_OK;
Cleanup:
return hr;
}
......@@ -59,6 +59,8 @@ public:
HRESULT EnablePktInfoOption(bool fEnable);
HRESULT SetNonBlocking(bool fEnable);
HRESULT SetRecvTimeout(int milliseconds);
void UpdateAddresses();
HRESULT UDPInit(const CSocketAddress& local, SocketRole role, bool fSetReuseFlag);
......
......@@ -163,22 +163,20 @@ void DumpConfig(CStunServerConfig &config)
std::string strSocket;
if (config.fHasPP)
{
config.addrPP.ToString(&strSocket);
Logging::LogMsg(LL_DEBUG, "PP = %s", strSocket.c_str());
}
if (config.fHasPA)
if (config.fIsFullMode)
{
config.addrPA.ToString(&strSocket);
Logging::LogMsg(LL_DEBUG, "PA = %s", strSocket.c_str());
}
if (config.fHasAP)
if (config.fIsFullMode)
{
config.addrAP.ToString(&strSocket);
Logging::LogMsg(LL_DEBUG, "AP = %s", strSocket.c_str());
}
if (config.fHasAA)
if (config.fIsFullMode)
{
config.addrAA.ToString(&strSocket);
Logging::LogMsg(LL_DEBUG, "AA = %s", strSocket.c_str());
......@@ -398,13 +396,13 @@ HRESULT BuildServerConfigurationFromArgs(StartupArgs& argsIn, CStunServerConfig*
if (mode == Basic)
{
uint16_t port = (uint16_t)((int16_t)nPrimaryPort);
config.fIsFullMode = false;
// in basic mode, if no adapter is specified, bind to all of them
if (args.strPrimaryInterface.length() == 0)
{
if (family == AF_INET)
{
config.addrPP = CSocketAddress(0, port);
config.fHasPP = true;
}
else if (family == AF_INET6)
{
......@@ -412,7 +410,6 @@ HRESULT BuildServerConfigurationFromArgs(StartupArgs& argsIn, CStunServerConfig*
addr6.sin6_family = AF_INET6;
config.addrPP = CSocketAddress(addr6);
config.addrPP.SetPort(port);
config.fHasPP = true;
}
}
else
......@@ -425,7 +422,6 @@ HRESULT BuildServerConfigurationFromArgs(StartupArgs& argsIn, CStunServerConfig*
Chk(hr);
}
config.addrPP = addr;
config.fHasPP = true;
}
}
else // Full mode
......@@ -466,19 +462,17 @@ HRESULT BuildServerConfigurationFromArgs(StartupArgs& argsIn, CStunServerConfig*
config.addrPP = addrPrimary;
config.addrPP.SetPort(portPrimary);
config.fHasPP = true;
config.addrPA = addrPrimary;
config.addrPA.SetPort(portAlternate);
config.fHasPA = true;
config.addrAP = addrAlternate;
config.addrAP.SetPort(portPrimary);
config.fHasAP = true;
config.addrAA = addrAlternate;
config.addrAA.SetPort(portAlternate);
config.fHasAA = true;
config.fIsFullMode = true;
}
......@@ -501,7 +495,7 @@ HRESULT BuildServerConfigurationFromArgs(StartupArgs& argsIn, CStunServerConfig*
if (mode != Full)
{
Logging::LogMsg(LL_ALWAYS, "Error. --altadvertised was specified, but --mode param was not set to FULL.");
ChkIf(config.fHasAA, E_INVALIDARG);
ChkIf(config.fIsFullMode, E_INVALIDARG);
}
hr = ::NumericIPToAddress(family, pszAltAdvertised, &config.addrAlternateAdvertised);
......@@ -528,7 +522,7 @@ HRESULT BuildServerConfigurationFromArgs(StartupArgs& argsIn, CStunServerConfig*
Logging::LogMsg(LL_ALWAYS, "Error with --threading. required argument must be between 0 - 64");
Chk(hr);
}
config.nThreadsPerSocket = threadcount;
config.nThreading = threadcount;
}
*pConfigOut = config;
......@@ -858,7 +852,7 @@ int main(int argc, char** argv)
}
Logging::LogMsg(LL_DEBUG, "Server is exiting");
Logging::LogMsg(LL_DEBUG, "Server is exiting. This may take a few seconds to complete.");
for (auto itor = udpServers.begin(); itor != udpServers.end(); itor++)
......
......@@ -21,27 +21,22 @@
#include "stunsocketthread.h"
#include "server.h"
#include "ratelimiter.h"
#include "stunsocket.h"
CStunServerConfig::CStunServerConfig() :
fHasPP(false),
fHasPA(false),
fHasAP(false),
fHasAA(false),
nThreadsPerSocket(0),
fTCP(false),
nThreading(0),
nMaxConnections(0), // zero means default
fEnableDosProtection(false),
fReuseAddr(false)
fReuseAddr(false),
fIsFullMode(false),
fTCP(false)
{
;
}
CStunServer::CStunServer() :
_arrSockets() // zero-init
_tsa() // zero-init
{
;
}
......@@ -51,42 +46,88 @@ CStunServer::~CStunServer()
Shutdown();
}
HRESULT CStunServer::AddSocket(TransportAddressSet* pTSA, SocketRole role, const CSocketAddress& addrListen, const CSocketAddress& addrAdvertise, bool fSetReuseFlag)
HRESULT CStunServer::CreateSocket(SocketRole role, const CSocketAddress& addr, bool fReuseAddr)
{
HRESULT hr = S_OK;
auto spSocket = std::make_shared<CStunSocket>();
Chk(spSocket->UDPInit(addr, role, fReuseAddr));
ChkA(spSocket->EnablePktInfoOption(true)); // todo - why is this not always set inside UDPInit
ChkA(spSocket->SetRecvTimeout(10000));
_sockets.push_back(spSocket);
Cleanup:
return hr;
}
ASSERT(IsValidSocketRole(role));
HRESULT CStunServer::InitializeTSA(const CStunServerConfig& config)
{
_tsa = {};
Chk(_arrSockets[role].UDPInit(addrListen, role, fSetReuseFlag));
ChkA(_arrSockets[role].EnablePktInfoOption(true));
CSocketAddress addr[4] = {config.addrPP, config.addrPA, config.addrAP, config.addrAA};
CSocketAddress advertised[4] = {config.addrPrimaryAdvertised, config.addrPrimaryAdvertised, config.addrAlternateAdvertised, config.addrAlternateAdvertised};
bool validity[4] = {true, config.fIsFullMode, config.fIsFullMode, config.fIsFullMode};
static_assert(RolePP == (SocketRole)0);
static_assert(RolePA == (SocketRole)1);
static_assert(RoleAP == (SocketRole)2);
static_assert(RoleAA == (SocketRole)3);
#ifdef DEBUG
for (size_t i = 0; i < 4; i++)
{
if (validity[i])
{
CSocketAddress addrLocal = _arrSockets[role].GetLocalAddress();
_tsa.set[i].fValid = true;
if (advertised[i].IsIPAddressZero() == false)
{
// set the TSA for this socket to what the configuration wants us to advertise this address for in ORIGIN and OTHER address attributes
_tsa.set[i].addr = advertised[i];
_tsa.set[i].addr.SetPort(addr[i].GetPort());
}
else
{
// use the socket's IP and port (OK if this is INADDR_ANY)
// the message handler code will use the local ip for ORIGIN
_tsa.set[i].addr = _sockets[i]->GetLocalAddress();
}
}
}
return S_OK;
}
// addrListen is the address we asked the socket to listen on via a call to bind()
// addrLocal is the socket address returned by getsockname after the socket is binded
HRESULT CStunServer::CreateSockets(const CStunServerConfig& config)
{
// four possible config types:
// 1. basic mode with 1 thread and 1 socket
// 2. basic mode with N threads and N sockets
// 3. full mode with 1 thread and 4 sockets
// 4. full mode with 4N threads and 4N sockets
// I can't think of any case where addrListen != addrLocal
// the ports will be different if addrListen.GetPort() is 0, but that
// should never happen.
// 1 and 2 are equivalent, so when we are in basic mode, just assume config.nThreading == 0 is the the same as config.nThreading == 1
// So really, only three configs: basic, full single threaded, and full multi-threaded
// but if the assert below fails, I want to know about it
ASSERT(addrLocal.IsSameIP_and_Port(addrListen));
}
#endif
HRESULT hr = S_OK;
uint32_t numberOfThreads = (config.nThreading == 0) ? 1 : config.nThreading;
bool fReuseAddr = config.fReuseAddr || config.nThreading > 0; // allow SO_REUSEPORT to be set, even if threading was explicitly set to 1
ASSERT(_sockets.size() == 0);
_sockets.clear();
pTSA->set[role].fValid = true;
if (addrAdvertise.IsIPAddressZero() == false)
if (config.fIsFullMode == false)
{
// set the TSA for this socket to what the configuration wants us to advertise this address for in ORIGIN and OTHER address attributes
pTSA->set[role].addr = addrAdvertise;
pTSA->set[role].addr.SetPort(addrListen.GetPort()); // use the original port
for (size_t i = 0; i < numberOfThreads; i++)
{
Chk(CreateSocket(RolePP, config.addrPP, fReuseAddr));
}
}
else
{
pTSA->set[role].addr = addrListen; // use the socket's IP and port (OK if this is INADDR_ANY)
// in full mode, we create 4 * numberOfThreads sockets
for (size_t i = 0; i < numberOfThreads; i++)
{
Chk(CreateSocket(RolePP, config.addrPP, fReuseAddr));
Chk(CreateSocket(RolePA, config.addrPA, fReuseAddr));
Chk(CreateSocket(RoleAP, config.addrAP, fReuseAddr));
Chk(CreateSocket(RoleAA, config.addrAA, fReuseAddr));
}
}
Cleanup:
......@@ -96,10 +137,9 @@ Cleanup:
HRESULT CStunServer::Initialize(const CStunServerConfig& config)
{
HRESULT hr = S_OK;
int socketcount = 0;
std::shared_ptr<IStunAuth> _spAuth;
TransportAddressSet tsa = {};
std::shared_ptr<RateLimiter> spLimiter;
size_t numberOfThreads = 1;
_spAuth = nullptr;
// cleanup any thing that's going on now
Shutdown();
......@@ -108,74 +148,70 @@ HRESULT CStunServer::Initialize(const CStunServerConfig& config)
// set the _spAuth member to reference it
// _spAuth = std::make_shared<CYourAuthProvider>();
// Create the sockets and initialize the TSA thing
if (config.fHasPP)
{
Chk(AddSocket(&tsa, RolePP, config.addrPP, config.addrPrimaryAdvertised, config.fReuseAddr));
socketcount++;
}
Chk(CreateSockets(config));
ChkA(InitializeTSA(config));
if (config.fHasPA)
{
Chk(AddSocket(&tsa, RolePA, config.addrPA, config.addrPrimaryAdvertised, config.fReuseAddr));
socketcount++;
}
if (config.fHasAP)
{
Chk(AddSocket(&tsa, RoleAP, config.addrAP, config.addrAlternateAdvertised, config.fReuseAddr));
socketcount++;
}
ChkIfA(_sockets.size() == 0, E_UNEXPECTED);
ChkIfA((config.fIsFullMode && _sockets.size() < 4), E_UNEXPECTED);
ChkIfA((config.fIsFullMode && _sockets.size() % 4), E_UNEXPECTED);
if (config.fHasAA)
if ((config.fIsFullMode == false) || (config.nThreading > 0))
{
Chk(AddSocket(&tsa, RoleAA, config.addrAA, config.addrAlternateAdvertised, config.fReuseAddr));
socketcount++;
numberOfThreads = _sockets.size();
}
ChkIf(socketcount == 0, E_INVALIDARG);
if (config.fEnableDosProtection)
{
Logging::LogMsg(LL_DEBUG, "Creating rate limiter for ddos protection\n");
// hard coding to 25000 ip addresses
bool fMultiThreaded = (config.nThreadsPerSocket > 0);
bool fMultiThreaded = (numberOfThreads > 1);
Logging::LogMsg(LL_DEBUG, "Creating rate limiter for ddos protection (%s)\n", fMultiThreaded ? "multi-threaded" : "single-threaded");
spLimiter = std::shared_ptr<RateLimiter>(new RateLimiter(25000, fMultiThreaded));
}
if (config.nThreadsPerSocket <= 0)
{
Logging::LogMsg(LL_DEBUG, "Configuring single threaded mode\n");
// create one thread for all the sockets
CStunSocketThread* pThread = new CStunSocketThread();
ChkIf(pThread==nullptr, E_OUTOFMEMORY);
Logging::LogMsg(LL_DEBUG, "Configuring multi threaded mode with %d sockets on %d threads\n", _sockets.size(), numberOfThreads);
_threads.push_back(pThread);
Chk(pThread->Init(_arrSockets, &tsa, _spAuth, (SocketRole)-1, spLimiter));
}
else
for (size_t i = 0; i < numberOfThreads; i++)
{
Logging::LogMsg(LL_DEBUG, "Configuring multi-threaded mode with %d threads per socket\n", config.nThreadsPerSocket);
std::vector<std::shared_ptr<CStunSocket>> arrayOfFourSockets;
SocketRole role;
// N threads for every socket
CStunSocketThread* pThread = nullptr;
for (size_t index = 0; index < ARRAYSIZE(_arrSockets); index++)
if (config.fIsFullMode)
{
if (_arrSockets[index].IsValid())
size_t baseindex = (i/4) * 4;
arrayOfFourSockets.push_back(_sockets[baseindex + 0]);
arrayOfFourSockets.push_back(_sockets[baseindex + 1]);
arrayOfFourSockets.push_back(_sockets[baseindex + 2]);
arrayOfFourSockets.push_back(_sockets[baseindex + 3]);
ASSERT(arrayOfFourSockets[0]->GetRole() == RolePP);
ASSERT(arrayOfFourSockets[2]->GetRole() == RolePA);
ASSERT(arrayOfFourSockets[3]->GetRole() == RoleAP);
ASSERT(arrayOfFourSockets[3]->GetRole() == RoleAA);
if (config.nThreading == 0)
{
SocketRole rolePrimaryRecv = _arrSockets[index].GetRole();
ASSERT(rolePrimaryRecv == (SocketRole)index);
for (int t = 0; t < config.nThreadsPerSocket; t++)
role = (SocketRole)-1; // the thread will recognize this as "listen on all four sockets"
ASSERT(numberOfThreads == 1);
}
else
{
pThread = new CStunSocketThread();
ChkIf(pThread==nullptr, E_OUTOFMEMORY);
_threads.push_back(pThread);
Chk(pThread->Init(_arrSockets, &tsa, _spAuth, rolePrimaryRecv, spLimiter));
role = (SocketRole)(i % 4);
}
}
else
{
arrayOfFourSockets.push_back(_sockets[i]);
arrayOfFourSockets.push_back(nullptr);
arrayOfFourSockets.push_back(nullptr);
arrayOfFourSockets.push_back(nullptr);
role = RolePP;
ASSERT(_sockets[i]->GetRole() == RolePP);
ASSERT(_sockets[i]->IsValid());
}
CStunSocketThread* pThread = new CStunSocketThread();
_threads.push_back(pThread);
Chk(pThread->Init(arrayOfFourSockets, _tsa, _spAuth, role, spLimiter));
}
......@@ -192,25 +228,22 @@ Cleanup:
HRESULT CStunServer::Shutdown()
{
size_t len;
Stop();
// release the sockets and the thread
for (size_t index = 0; index < ARRAYSIZE(_arrSockets); index++)
for (auto pThread : _threads)
{
_arrSockets[index].Close();
delete pThread;
}
_threads.clear();
len = _threads.size();
for (size_t index = 0; index < len; index++)
for (auto spSocket : _sockets)
{
CStunSocketThread* pThread = _threads[index];
delete pThread;
_threads[index] = nullptr;
if (spSocket != nullptr)
{
spSocket->Close();
}
_threads.clear();
}
_sockets.clear();
_spAuth.reset();
......@@ -218,7 +251,6 @@ HRESULT CStunServer::Shutdown()
}
HRESULT CStunServer::Start()
{
HRESULT hr = S_OK;
......@@ -247,8 +279,6 @@ Cleanup:
HRESULT CStunServer::Stop()
{
size_t len = _threads.size();
for (size_t index = 0; index < len; index++)
......@@ -257,37 +287,84 @@ HRESULT CStunServer::Stop()
if (pThread != nullptr)
{
// set the "exit flag" that each thread looks at when it wakes up from waiting
pThread->SignalForStop(false);
pThread->SignalForStop();
}
}
PostWakeupMessages();
for (size_t index = 0; index < len; index++)
{
CStunSocketThread* pThread = _threads[index];
// Post a bunch of empty buffers to get the threads unblocked from whatever socket call they are on
// In multi-threaded mode, this may wake up a different thread. But that's ok, since all threads start and stop together
if (pThread != nullptr)
{
pThread->SignalForStop(true);
pThread->WaitForStopAndClose();
}
}
for (size_t index = 0; index < len; index++)
return S_OK;
}
void CStunServer::PostWakeupMessages()
{
// This is getting harder to maintain.
// When all the threads shared the same socket, we just had to invoke sendto once for each thread.
// Then each thread would wakeup from its recvfrom call and detect the exit condition.
// In the new multi-threaded design mode, where each socket has the SO_REUSEPORT option set, the packets
// from the same source ip:port get queued into the same listening socket. (And when the socket
// closes, there's no requeuing to another listening socket). And there's
// some hash table lookup by which the OS maps to each. So we can't guarantee that sending
// N packets will get received by all N threads.
// Workarounds:
// 1. Use a random port for each send, and loop multiple times
// 2. Use a SO_RCVTIMEO on each socket so that they eventually all wake up
// 3. Combo of 1 and 2 with a conditional variable
size_t count = _sockets.size();
for (size_t i = 0; i < count; i++)
{
CStunSocketThread* pThread = _threads[index];
CStunSocket sock;
CSocketAddress addrLocal;
if (pThread != nullptr)
if (_sockets[i]->GetLocalAddress().GetFamily() == AF_INET)
{
pThread->WaitForStopAndClose();
addrLocal = CSocketAddress(0,0);
}
else
{
sockaddr_in6 addr6 = {};
addr6.sin6_family = AF_INET6;
addrLocal = CSocketAddress(addr6);
}
// bind socket to port 0 (auto assign), using the same family as the socket we are trying to unblock
HRESULT hr = sock.UDPInit(addrLocal, RolePP, false);
ASSERT(SUCCEEDED(hr));
return S_OK;
}
if (SUCCEEDED(hr))
{
char data = 'x';
CSocketAddress addr(_sockets[i]->GetLocalAddress());
// If no specific adapter was binded to, IP will be 0.0.0.0
// Linux evidently treats 0.0.0.0 IP as loopback (and works)
// On Windows you can't send to 0.0.0.0. sendto will fail - switch to sending to localhost
if (addr.IsIPAddressZero())
{
CSocketAddress addrLocal;
CSocketAddress::GetLocalHost(addr.GetFamily(), &addrLocal);
addrLocal.SetPort(addr.GetPort());
addr = addrLocal;
}
::sendto(sock.GetSocketHandle(), &data, 1, 0, addr.GetSockAddr(), addr.GetSockAddrLength());
}
}
}
......@@ -14,7 +14,6 @@
limitations under the License.
*/
#ifndef STUN_SERVER_H
#define STUN_SERVER_H
......@@ -23,34 +22,25 @@
#include "stunauth.h"
#include "messagehandler.h"
class CStunServerConfig
{
public:
uint32_t nThreading; // when set to 0, all sockets on 1 thread. Otherwise, N threads per socket
bool fHasPP; // PP: Primary ip, Primary port
bool fHasPA; // PA: Primary ip, Alternate port
bool fHasAP; // AP: Alternate ip, Primary port
bool fHasAA; // AA: Alternate ip, Alternate port
int nThreadsPerSocket; // when set to > 0, each socket gets N threads assigned to it, otherwise, all sockets on 1 thread
bool fTCP; // if true, then use TCP instead of UDP
uint32_t nMaxConnections; // only valid for TCP (on a per-thread basis)
CSocketAddress addrPP; // address for PP
CSocketAddress addrPA; // address for PA
CSocketAddress addrAP; // address for AP
CSocketAddress addrAA; // address for AA
CSocketAddress addrPA; // address for PA, ignored if fIsFullMode==false
CSocketAddress addrAP; // address for AP, ignored if fIsFullMode==false
CSocketAddress addrAA; // address for AA, ignored if fIsFullMode==false
CSocketAddress addrPrimaryAdvertised; // public-IP for PP and PA (port is ignored)
CSocketAddress addrAlternateAdvertised; // public-IP for AP and AA (port is ignored)
bool fEnableDosProtection; // enable denial of service protection
bool fReuseAddr; // if true, the socket option SO_REUSEADDR will be set
bool fIsFullMode; // indicated that we are listening on PA, AP, and AA addresses above
bool fTCP; // if true, then use TCP instead of UDP
CStunServerConfig();
};
......@@ -59,13 +49,17 @@ public:
class CStunServer
{
private:
CStunSocket _arrSockets[4];
std::vector<std::shared_ptr<CStunSocket>> _sockets;
std::vector<CStunSocketThread*> _threads;
TransportAddressSet _tsa;
std::shared_ptr<IStunAuth> _spAuth;
HRESULT AddSocket(TransportAddressSet* pTSA, SocketRole role, const CSocketAddress& addrListen, const CSocketAddress& addrAdvertise, bool fSetReuseFlag);
HRESULT InitializeTSA(const CStunServerConfig& config);
HRESULT CreateSocket(SocketRole role, const CSocketAddress& addr, bool fReuseAddr);
HRESULT CreateSockets(const CStunServerConfig& config);
void PostWakeupMessages();
public:
CStunServer();
......@@ -78,7 +72,4 @@ public:
HRESULT Stop();
};
#endif /* SERVER_H */
......@@ -37,57 +37,65 @@ _tsa() // zero-init
CStunSocketThread::~CStunSocketThread()
{
SignalForStop(true);
SignalForStop();
WaitForStopAndClose();
}
void CStunSocketThread::ClearSocketArray()
{
_arrSendSockets = nullptr;
_arrSendSockets.clear();
_socks.clear();
}
HRESULT CStunSocketThread::Init(CStunSocket* arrayOfFourSockets, TransportAddressSet* pTSA, std::shared_ptr<IStunAuth> spAuth, SocketRole rolePrimaryRecv, std::shared_ptr<RateLimiter>& spLimiter)
void CStunSocketThread::DumpInitParams(std::vector<std::shared_ptr<CStunSocket>>& arrayOfFourSockets, const TransportAddressSet& tsa, SocketRole rolePrimaryRecv)
{
Logging::LogMsg(LL_VERBOSE, "CStunSocketThread initialized with:");
for (auto spStunSocket : arrayOfFourSockets)
{
Logging::LogMsg(LL_VERBOSE, "sock handle: %d", (spStunSocket == nullptr) ? -1 : (int)spStunSocket->GetSocketHandle());
}
}
HRESULT CStunSocketThread::Init(std::vector<std::shared_ptr<CStunSocket>>& arrayOfFourSockets, const TransportAddressSet& tsa, std::shared_ptr<IStunAuth> spAuth, SocketRole rolePrimaryRecv, std::shared_ptr<RateLimiter>& spLimiter)
{
HRESULT hr = S_OK;
DumpInitParams(arrayOfFourSockets, tsa, rolePrimaryRecv);
// if -1 was passed, then we are in "multi socket mode", otherwise, 1 socket to receive on
bool fSingleSocketRecv = ::IsValidSocketRole(rolePrimaryRecv);
ChkIfA(_fThreadIsValid, E_UNEXPECTED);
ChkIfA(arrayOfFourSockets == nullptr, E_INVALIDARG);
ChkIfA(pTSA == nullptr, E_INVALIDARG);
ChkIfA(arrayOfFourSockets.size() == 0, E_INVALIDARG);
// if this thread was configured to listen on a single socket (aka "multi-threaded mode"), then
// validate that it exists
if (fSingleSocketRecv)
{
ChkIfA(arrayOfFourSockets[rolePrimaryRecv].IsValid()==false, E_UNEXPECTED);
ChkIfA(arrayOfFourSockets[rolePrimaryRecv] == nullptr, E_UNEXPECTED);
ChkIfA(arrayOfFourSockets[rolePrimaryRecv]->IsValid()==false, E_UNEXPECTED);
}
_arrSendSockets = arrayOfFourSockets;
// initialize the TSA thing
_tsa = *pTSA;
_tsa = tsa;
if (fSingleSocketRecv)
{
// only one socket to listen on
_socks.push_back(&_arrSendSockets[rolePrimaryRecv]);
_socks.push_back(_arrSendSockets[rolePrimaryRecv]);
}
else
{
for (size_t i = 0; i < 4; i++)
for (auto spSocket : arrayOfFourSockets)
{
if (_arrSendSockets[i].IsValid())
if (spSocket != nullptr && spSocket->IsValid())
{
_socks.push_back(&_arrSendSockets[i]);
_socks.push_back(spSocket);
}
}
}
Chk(InitThreadBuffers());
_fNeedToExit = false;
......@@ -151,45 +159,10 @@ Cleanup:
return hr;
}
HRESULT CStunSocketThread::SignalForStop(bool fPostMessages)
HRESULT CStunSocketThread::SignalForStop()
{
HRESULT hr = S_OK;
_fNeedToExit = true;
// have the socket send a message to itself
// if another thread is sharing the same socket, this may wake that thread up to
// but all the threads should be started and shutdown together
if (fPostMessages)
{
for (size_t index = 0; index < _socks.size(); index++)
{
char data = 'x';
ASSERT(_socks[index] != nullptr);
CSocketAddress addr(_socks[index]->GetLocalAddress());
// If no specific adapter was binded to, IP will be 0.0.0.0
// Linux evidently treats 0.0.0.0 IP as loopback (and works)
// On Windows you can't send to 0.0.0.0. sendto will fail - switch to sending to localhost
if (addr.IsIPAddressZero())
{
CSocketAddress addrLocal;
CSocketAddress::GetLocalHost(addr.GetFamily(), &addrLocal);
addrLocal.SetPort(addr.GetPort());
addr = addrLocal;
}
::sendto(_socks[index]->GetSocketHandle(), &data, 1, 0, addr.GetSockAddr(), addr.GetSockAddrLength());
}
}
return hr;
}
......@@ -261,7 +234,7 @@ CStunSocket* CStunSocketThread::WaitForSocketData()
if (FD_ISSET(sock, &set))
{
pReadySocket = _socks[indexconverted];
pReadySocket = _socks[indexconverted].get(); // todo - let this method return a shared_ptr
break;
}
}
......@@ -277,7 +250,7 @@ void CStunSocketThread::Run()
size_t nSocketCount = _socks.size();
bool fMultiSocketMode = (nSocketCount > 1);
int recvflags = fMultiSocketMode ? MSG_DONTWAIT : 0;
CStunSocket* pSocket = _socks[0];
CStunSocket* pSocket = _socks[0].get();
int ret;
char szIPRemote[100] = {};
char szIPLocal[100] = {};
......@@ -321,13 +294,26 @@ void CStunSocketThread::Run()
_spBufferIn->SetSize(0);
ret = ::recvfromex(pSocket->GetSocketHandle(), _spBufferIn->GetData(), _spBufferIn->GetAllocatedSize(), recvflags, &_msgIn.addrRemote, &_msgIn.addrLocal);
if (ret < 0)
{
int err = errno;
if ((err == EAGAIN) || (err == EWOULDBLOCK))
{
Logging::LogMsg(LL_VERBOSE_EXTREME, "recvfromex returned timeout error");
}
else
{
Logging::LogMsg(LL_VERBOSE, "recvfromex returned error: %d", err);
}
continue;
}
// recvfromex no longer sets the port value on the local address
if (ret >= 0)
if (_fNeedToExit)
{
_msgIn.addrLocal.SetPort(pSocket->GetLocalAddress().GetPort());
break;
}
_msgIn.addrLocal.SetPort(pSocket->GetLocalAddress().GetPort());
if (Logging::GetLogLevel() >= LL_VERBOSE)
{
......@@ -340,28 +326,16 @@ void CStunSocketThread::Run()
szIPLocal[0] = '\0';
}
Logging::LogMsg(LL_VERBOSE, "recvfrom returns %d from %s on local interface %s on thread %lu", ret, szIPRemote, szIPLocal, (unsigned long)threadid);
Logging::LogMsg(LL_VERBOSE, "recvfrom returns %d from %s on local interface %s on thread %lu sr=%d", ret, szIPRemote, szIPLocal, (unsigned long)threadid, (int)pSocket->GetRole());
allowed_to_pass = (_spLimiter.get() != nullptr) ? _spLimiter->RateCheck(_msgIn.addrRemote) : true;
if (allowed_to_pass == false)
{
Logging::LogMsg(LL_VERBOSE, "RateLimiter signals false for packet from %s", szIPRemote);
}
if ((ret < 0) || (allowed_to_pass == false))
{
// error
continue;
}
if (_fNeedToExit)
{
break;
}
_spBufferIn->SetSize(ret);
_msgIn.socketrole = pSocket->GetRole();
......@@ -395,8 +369,8 @@ HRESULT CStunSocketThread::ProcessRequestAndSendResponse()
Chk(CStunRequestHandler::ProcessRequest(_msgIn, _msgOut, &_tsa, _spAuth.get()));
ASSERT(_tsa.set[_msgOut.socketrole].fValid);
ASSERT(_arrSendSockets[_msgOut.socketrole].IsValid());
sockout = _arrSendSockets[_msgOut.socketrole].GetSocketHandle();
ASSERT(_arrSendSockets[_msgOut.socketrole]->IsValid());
sockout = _arrSendSockets[_msgOut.socketrole]->GetSocketHandle();
ASSERT(sockout != -1);
// find the socket that matches the role specified by msgOut
......
......@@ -14,8 +14,6 @@
limitations under the License.
*/
#ifndef STUNSOCKETTHREAD_H
#define STUNSOCKETTHREAD_H
......@@ -33,14 +31,12 @@ public:
CStunSocketThread();
~CStunSocketThread();
HRESULT Init(CStunSocket* arrayOfFourSockets, TransportAddressSet* pTSA, std::shared_ptr<IStunAuth> spAuth, SocketRole rolePrimaryRecv, std::shared_ptr<RateLimiter>& _spRateLimiter);
HRESULT Init(std::vector<std::shared_ptr<CStunSocket>>& arrayOfFourSockets, const TransportAddressSet& tsa, std::shared_ptr<IStunAuth> spAuth, SocketRole rolePrimaryRecv, std::shared_ptr<RateLimiter>& _spRateLimiter);
HRESULT Start();
HRESULT SignalForStop(bool fPostMessages);
HRESULT SignalForStop();
HRESULT WaitForStopAndClose();
private:
// this is the function that runs in a thread
......@@ -50,8 +46,8 @@ private:
CStunSocket* WaitForSocketData();
CStunSocket* _arrSendSockets; // matches CStunServer::_arrSockets
std::vector<CStunSocket*> _socks; // sockets for receiving on
std::vector<std::shared_ptr<CStunSocket>> _arrSendSockets; // 1 socket in basic mode. 4 sockets in full mode
std::vector<std::shared_ptr<CStunSocket>> _socks; // 1 socket in multi-threaded or basic mode. 4 sockets in single-threaded full mode
bool _fNeedToExit;
pthread_t _pthread;
......@@ -67,7 +63,7 @@ private:
CStunMessageReader _reader;
CRefCountedBuffer _spBufferReader; // buffer internal to the reader
CRefCountedBuffer _spBufferIn; // buffer we receive requests on
CRefCountedBuffer _spBufferOut; // buffer we send response on
CRefCountedBuffer _spBufferOut; // buffer we send responses on
StunMessageIn _msgIn;
StunMessageOut _msgOut;
......@@ -80,10 +76,9 @@ private:
void ClearSocketArray();
};
void DumpInitParams(std::vector<std::shared_ptr<CStunSocket>>& arrayOfFourSockets, const TransportAddressSet& tsa, SocketRole rolePrimaryRecv);
};
#endif /* STUNSOCKETTHREAD_H */
......@@ -844,15 +844,15 @@ HRESULT CTCPServer::Initialize(const CStunServerConfig& config)
// tsaHandler is sort of a hack for TCP. It's really just a glorified indication to the the
// CStunRequestHandler code to figure out if it can offer a CHANGED-ADDRESS attribute.
InitTSA(&tsaHandler, RolePP, config.fHasPP, config.addrPP, config.addrPrimaryAdvertised);
InitTSA(&tsaHandler, RolePA, config.fHasPA, config.addrPA, config.addrPrimaryAdvertised);
InitTSA(&tsaHandler, RoleAP, config.fHasAP, config.addrAP, config.addrAlternateAdvertised);
InitTSA(&tsaHandler, RoleAA, config.fHasAA, config.addrAA, config.addrAlternateAdvertised);
InitTSA(&tsaListenAll, RolePP, config.fHasPP, config.addrPP, CSocketAddress());
InitTSA(&tsaListenAll, RolePA, config.fHasPA, config.addrPA, CSocketAddress());
InitTSA(&tsaListenAll, RoleAP, config.fHasAP, config.addrAP, CSocketAddress());
InitTSA(&tsaListenAll, RoleAA, config.fHasAA, config.addrAA, CSocketAddress());
InitTSA(&tsaHandler, RolePP, true, config.addrPP, config.addrPrimaryAdvertised);
InitTSA(&tsaHandler, RolePA, config.fIsFullMode, config.addrPA, config.addrPrimaryAdvertised);
InitTSA(&tsaHandler, RoleAP, config.fIsFullMode, config.addrAP, config.addrAlternateAdvertised);
InitTSA(&tsaHandler, RoleAA, config.fIsFullMode, config.addrAA, config.addrAlternateAdvertised);
InitTSA(&tsaListenAll, RolePP, true, config.addrPP, CSocketAddress());
InitTSA(&tsaListenAll, RolePA, config.fIsFullMode, config.addrPA, CSocketAddress());
InitTSA(&tsaListenAll, RoleAP, config.fIsFullMode, config.addrAP, CSocketAddress());
InitTSA(&tsaListenAll, RoleAA, config.fIsFullMode, config.addrAA, CSocketAddress());
if (config.fEnableDosProtection)
{
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment