Commit ea8d74b9 authored by John Selbie's avatar John Selbie

Refactoring for upcoming TCP support

parent 5007dd15
/*
* File: main.cpp
* Author: jselbie
*
* Created on December 3, 2011, 11:18 PM
*/
#ifndef FASTHASH_H
#define FASTHASH_H
// FastHash is a cheap and dirty hash table template class
// It is "fast" because it never allocates memory beyond the static arrays inside each instance
// Hence, it can be used off the stack or in cases where memory allocations impact performance
// Limitations:
// Fixed number of insertions (specified by FSIZE)
// Does not support removals
// Made for simple types and structs of simple types - as items are pre-allocated (no regards to constructors or destructors)
// Duplicate key insertions will not remove the previous item
// Additional:
// FastHash keeps a static array of items inserted (in insertion order)
// Then a hash table of <K,int> to map keys back to index values
// This allows calling code to be able to iterate over the table in insertion order
// Template parameters
// K = key type
// V = value type
// FSIZE = max number of items in the hash table (default is 100)
// TSIZE = hash table width (higher value reduces collisions, but with extra memory overhead - default is 37). Usually a prime number.
inline size_t FastHash_Hash(unsigned int x)
{
return (size_t)x;
}
inline size_t FastHash_Hash(signed int x)
{
return (size_t)x;
}
const size_t FAST_HASH_DEFAULT_CAPACITY = 100;
const size_t FASH_HASH_DEFAULT_TABLE_SIZE = 37;
template <class K, class V, size_t FSIZE=FAST_HASH_DEFAULT_CAPACITY, size_t TSIZE=FASH_HASH_DEFAULT_TABLE_SIZE>
class FastHash
{
private:
struct ItemNode
{
K key;
int index; // index into _list where this item is stored
ItemNode* pNext;
};
V _list[FSIZE]; // list of items
size_t _count; // number of items inserted so far
ItemNode _tablenodes[FSIZE];
ItemNode* _table[TSIZE];
public:
FastHash()
{
Reset();
}
void Reset()
{
_count = 0;
memset(_table, '\0', sizeof(_table));
}
size_t Size()
{
return _count;
}
HRESULT Insert(K key, const V& val)
{
size_t tableindex = FastHash_Hash(key) % TSIZE;
if (_count >= FSIZE)
{
return false;
}
_list[_count] = val;
_tablenodes[_count].index = _count;
_tablenodes[_count].key = key;
_tablenodes[_count].pNext = _table[tableindex];
_table[tableindex] = &_tablenodes[_count];
_count++;
return true;
}
V* Lookup(K key, int* pIndex=NULL)
{
size_t tableindex = FastHash_Hash(key) % TSIZE;
V* pFoundItem = NULL;
ItemNode* pHead = _table[tableindex];
if (pIndex)
{
*pIndex = -1;
}
while (pHead)
{
if (pHead->key == key)
{
pFoundItem = &_list[pHead->index];
if (pIndex)
{
*pIndex = pHead->index;
}
break;
}
pHead = pHead->pNext;
}
return pFoundItem;
}
bool Exists(K key)
{
V* pItem = Lookup(key);
return (pItem != NULL);
}
V* GetItemByIndex(int index)
{
if ((index < 0) || (((size_t)index) >= _count))
{
return NULL;
}
return &_list[index];
}
};
#endif
\ No newline at end of file
/*
Copyright 2011 John Selbie
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef STUN_SERVER_H
#define STUN_SERVER_H
#include "stunsocket.h"
#include "stunauth.h"
#include "server.h"
#endif /* SERVER_H */
/*
Copyright 2011 John Selbie
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef STUN_SERVER_H
#define STUN_SERVER_H
#include "stunsocket.h"
#include "stunauth.h"
#include "server.h"
#endif /* SERVER_H */
...@@ -26,6 +26,7 @@ ...@@ -26,6 +26,7 @@
#include <openssl/hmac.h> #include <openssl/hmac.h>
#include <openssl/md5.h> #include <openssl/md5.h>
#include "stunauth.h" #include "stunauth.h"
#include "fasthash.h"
...@@ -35,17 +36,10 @@ CStunMessageReader::CStunMessageReader() : ...@@ -35,17 +36,10 @@ CStunMessageReader::CStunMessageReader() :
_fAllowLegacyFormat(false), _fAllowLegacyFormat(false),
_fMessageIsLegacyFormat(false), _fMessageIsLegacyFormat(false),
_state(HeaderNotRead), _state(HeaderNotRead),
_nAttributeCount(0),
_transactionid(), _transactionid(),
_msgTypeNormalized(0xffff), _msgTypeNormalized(0xffff),
_msgClass(StunMsgClassInvalidMessageClass), _msgClass(StunMsgClassInvalidMessageClass),
_msgLength(0), _msgLength(0)
_indexFingerprint(-1),
_indexResponsePort(-1),
_indexChangeRequest(-1),
_indexPaddingAttribute(-1),
_indexErrorCode(-1),
_indexMessageIntegrity(-1)
{ {
; ;
} }
...@@ -79,13 +73,14 @@ uint16_t CStunMessageReader::HowManyBytesNeeded() ...@@ -79,13 +73,14 @@ uint16_t CStunMessageReader::HowManyBytesNeeded()
bool CStunMessageReader::HasFingerprintAttribute() bool CStunMessageReader::HasFingerprintAttribute()
{ {
return (_indexFingerprint >= 0); StunAttribute *pAttrib = _mapAttributes.Lookup(STUN_ATTRIBUTE_FINGERPRINT);
return (pAttrib != NULL);
} }
bool CStunMessageReader::IsFingerprintAttributeValid() bool CStunMessageReader::IsFingerprintAttributeValid()
{ {
HRESULT hr = S_OK; HRESULT hr = S_OK;
StunAttribute attrib={}; StunAttribute* pAttrib = _mapAttributes.Lookup(STUN_ATTRIBUTE_FINGERPRINT);
CRefCountedBuffer spBuffer; CRefCountedBuffer spBuffer;
size_t size=0; size_t size=0;
boost::crc_32_type crc; boost::crc_32_type crc;
...@@ -97,11 +92,11 @@ bool CStunMessageReader::IsFingerprintAttributeValid() ...@@ -97,11 +92,11 @@ bool CStunMessageReader::IsFingerprintAttributeValid()
// the fingerprint attribute MUST be the last attribute in the stream. // the fingerprint attribute MUST be the last attribute in the stream.
// If it's not, then the code below will return false // If it's not, then the code below will return false
ChkIf(pAttrib==NULL, E_FAIL);
ChkIfA(pAttrib->attributeType != STUN_ATTRIBUTE_FINGERPRINT, E_FAIL);
GetAttributeByIndex(_indexFingerprint, &attrib); ChkIf(pAttrib->size != 4, E_FAIL);
ChkIfA(attrib.attributeType != STUN_ATTRIBUTE_FINGERPRINT, E_FAIL);
ChkIf(attrib.size != 4, E_FAIL);
ChkIf(_state != BodyValidated, E_FAIL); ChkIf(_state != BodyValidated, E_FAIL);
Chk(_stream.GetBuffer(&spBuffer)); Chk(_stream.GetBuffer(&spBuffer));
...@@ -115,7 +110,7 @@ bool CStunMessageReader::IsFingerprintAttributeValid() ...@@ -115,7 +110,7 @@ bool CStunMessageReader::IsFingerprintAttributeValid()
computedValue = crc.checksum(); computedValue = crc.checksum();
computedValue = computedValue ^ STUN_FINGERPRINT_XOR; computedValue = computedValue ^ STUN_FINGERPRINT_XOR;
readValue = *(uint32_t*)(ptr+attrib.offset); readValue = *(uint32_t*)(ptr+pAttrib->offset);
readValue = ntohl(readValue); readValue = ntohl(readValue);
hr = (readValue==computedValue) ? S_OK : E_FAIL; hr = (readValue==computedValue) ? S_OK : E_FAIL;
...@@ -125,13 +120,14 @@ Cleanup: ...@@ -125,13 +120,14 @@ Cleanup:
bool CStunMessageReader::HasMessageIntegrityAttribute() bool CStunMessageReader::HasMessageIntegrityAttribute()
{ {
return (_indexMessageIntegrity >= 0); return (NULL != _mapAttributes.Lookup(STUN_ATTRIBUTE_MESSAGEINTEGRITY));
} }
HRESULT CStunMessageReader::ValidateMessageIntegrity(uint8_t* key, size_t keylength) HRESULT CStunMessageReader::ValidateMessageIntegrity(uint8_t* key, size_t keylength)
{ {
HRESULT hr = S_OK; HRESULT hr = S_OK;
int lastAttributeIndex = _nAttributeCount - 1;
int lastAttributeIndex = ((int)_mapAttributes.Size()) - 1;
bool fFingerprintAdjustment = false; bool fFingerprintAdjustment = false;
bool fNoOtherAttributesAfterIntegrity = false; bool fNoOtherAttributesAfterIntegrity = false;
const size_t c_hmacsize = 20; const size_t c_hmacsize = 20;
...@@ -143,29 +139,34 @@ HRESULT CStunMessageReader::ValidateMessageIntegrity(uint8_t* key, size_t keylen ...@@ -143,29 +139,34 @@ HRESULT CStunMessageReader::ValidateMessageIntegrity(uint8_t* key, size_t keylen
size_t len, nChunks; size_t len, nChunks;
CDataStream stream; CDataStream stream;
CRefCountedBuffer spBuffer; CRefCountedBuffer spBuffer;
StunAttribute attribIntegrity; StunAttribute* pAttribIntegrity=NULL;
int indexMessageIntegrity = 0;
int indexFingerprint = -1;
int cmp = 0; int cmp = 0;
bool fContextInit = false; bool fContextInit = false;
ChkIf(_state != BodyValidated, E_FAIL); ChkIf(_state != BodyValidated, E_FAIL);
ChkIf(_indexMessageIntegrity < 0, E_FAIL);
// can a key be empty? // can a key be empty?
ChkIfA(key==NULL, E_INVALIDARG); ChkIfA(key==NULL, E_INVALIDARG);
ChkIfA(keylength==0, E_INVALIDARG); ChkIfA(keylength==0, E_INVALIDARG);
pAttribIntegrity = _mapAttributes.Lookup(::STUN_ATTRIBUTE_MESSAGEINTEGRITY, &indexMessageIntegrity);
_mapAttributes.Lookup(::STUN_ATTRIBUTE_FINGERPRINT, &indexFingerprint);
Chk(this->GetAttributeByIndex(_indexMessageIntegrity, &attribIntegrity));
ChkIf(attribIntegrity.size != c_hmacsize, E_FAIL); ChkIf(pAttribIntegrity->size != c_hmacsize, E_FAIL);
ChkIfA(lastAttributeIndex < 0, E_FAIL); ChkIfA(lastAttributeIndex < 0, E_FAIL);
// first, check to make sure that no other attributes (other than fingerprint) follow the message integrity // first, check to make sure that no other attributes (other than fingerprint) follow the message integrity
fNoOtherAttributesAfterIntegrity = (_indexMessageIntegrity == lastAttributeIndex) || ((_indexMessageIntegrity == (lastAttributeIndex-1)) && (_indexFingerprint == lastAttributeIndex)); fNoOtherAttributesAfterIntegrity = (indexMessageIntegrity == lastAttributeIndex) || ((indexMessageIntegrity == (lastAttributeIndex-1)) && (indexFingerprint == lastAttributeIndex));
ChkIf(fNoOtherAttributesAfterIntegrity==false, E_FAIL); ChkIf(fNoOtherAttributesAfterIntegrity==false, E_FAIL);
fFingerprintAdjustment = (_indexMessageIntegrity == (lastAttributeIndex-1)); fFingerprintAdjustment = (indexMessageIntegrity == (lastAttributeIndex-1));
Chk(GetBuffer(&spBuffer)); Chk(GetBuffer(&spBuffer));
stream.Attach(spBuffer, false); stream.Attach(spBuffer, false);
...@@ -193,9 +194,9 @@ HRESULT CStunMessageReader::ValidateMessageIntegrity(uint8_t* key, size_t keylen ...@@ -193,9 +194,9 @@ HRESULT CStunMessageReader::ValidateMessageIntegrity(uint8_t* key, size_t keylen
HMAC_Update(&ctx, (unsigned char*)&chunk16, sizeof(chunk16)); HMAC_Update(&ctx, (unsigned char*)&chunk16, sizeof(chunk16));
// now include everything up to the hash attribute itself. // now include everything up to the hash attribute itself.
len = _attributes[_indexMessageIntegrity].offset; len = pAttribIntegrity->offset;
len -= 4; // subtract the size of the attribute header len -= 4; // subtract the size of the attribute header
len -= 4; // subtrack the size of the message header (not including the transaction id) len -= 4; // subtract the size of the message header (not including the transaction id)
// len should be the number of bytes from the start of the transaction ID up through to the start of the integrity attribute header // len should be the number of bytes from the start of the transaction ID up through to the start of the integrity attribute header
// the stun message has to be a multiple of 4 bytes, so we can read in 32 bit chunks // the stun message has to be a multiple of 4 bytes, so we can read in 32 bit chunks
...@@ -210,7 +211,7 @@ HRESULT CStunMessageReader::ValidateMessageIntegrity(uint8_t* key, size_t keylen ...@@ -210,7 +211,7 @@ HRESULT CStunMessageReader::ValidateMessageIntegrity(uint8_t* key, size_t keylen
HMAC_Final(&ctx, hmaccomputed, &hmaclength); HMAC_Final(&ctx, hmaccomputed, &hmaclength);
// now compare the bytes // now compare the bytes
cmp = memcmp(hmaccomputed, spBuffer->GetData() + attribIntegrity.offset, c_hmacsize); cmp = memcmp(hmaccomputed, spBuffer->GetData() + pAttribIntegrity->offset, c_hmacsize);
hr = (cmp == 0 ? S_OK : E_FAIL); hr = (cmp == 0 ? S_OK : E_FAIL);
...@@ -289,58 +290,60 @@ Cleanup: ...@@ -289,58 +290,60 @@ Cleanup:
HRESULT CStunMessageReader::GetAttributeByType(uint16_t attributeType, StunAttribute* pAttribute) HRESULT CStunMessageReader::GetAttributeByType(uint16_t attributeType, StunAttribute* pAttribute)
{ {
HRESULT hr = E_FAIL; StunAttribute* pFound = _mapAttributes.Lookup(attributeType, NULL);
for (size_t index = 0; index < _nAttributeCount; index++) if (pFound == NULL)
{ {
if (attributeType == _attributes[index].attributeType) return E_FAIL;
{
hr = S_OK;
// pAttribute can be NULL - useful if the caller just wants to detect the presence of an attribute
if (pAttribute)
{
*pAttribute = _attributes[index];
}
}
} }
return hr; if (pAttribute)
{
*pAttribute = *pFound;
}
return S_OK;
} }
HRESULT CStunMessageReader::GetAttributeByIndex(int index, StunAttribute* pAttribute) HRESULT CStunMessageReader::GetAttributeByIndex(int index, StunAttribute* pAttribute)
{ {
HRESULT hr = E_FAIL; StunAttribute* pFound = _mapAttributes.GetItemByIndex(index);
if ((index >= 0) && (index < (int)_nAttributeCount)) if (pFound == NULL)
{ {
hr = S_OK; return E_FAIL;
}
if (pAttribute)
{ if (pAttribute)
*pAttribute = _attributes[index]; {
} *pAttribute = *pFound;
} }
return S_OK;
}
return hr; int CStunMessageReader::GetAttributeCount()
{
return (int)(_mapAttributes.Size());
} }
HRESULT CStunMessageReader::GetResponsePort(uint16_t* pPort) HRESULT CStunMessageReader::GetResponsePort(uint16_t* pPort)
{ {
StunAttribute attrib; StunAttribute* pAttrib = NULL;
HRESULT hr = S_OK; HRESULT hr = S_OK;
uint16_t portNBO; uint16_t portNBO;
uint8_t *pData = NULL; uint8_t *pData = NULL;
ChkIfA(pPort == NULL, E_INVALIDARG); ChkIfA(pPort == NULL, E_INVALIDARG);
Chk(GetAttributeByIndex(_indexResponsePort, &attrib)); pAttrib = _mapAttributes.Lookup(STUN_ATTRIBUTE_RESPONSE_PORT);
ChkIf(attrib.size != STUN_ATTRIBUTE_RESPONSE_PORT_SIZE, E_UNEXPECTED); ChkIf(pAttrib == NULL, E_FAIL);
ChkIf(pAttrib->size != STUN_ATTRIBUTE_RESPONSE_PORT_SIZE, E_UNEXPECTED);
pData = _stream.GetDataPointerUnsafe(); pData = _stream.GetDataPointerUnsafe();
ChkIf(pData==NULL, E_UNEXPECTED); ChkIf(pData==NULL, E_UNEXPECTED);
memcpy(&portNBO, pData + attrib.offset, STUN_ATTRIBUTE_RESPONSE_PORT_SIZE); memcpy(&portNBO, pData + pAttrib->offset, STUN_ATTRIBUTE_RESPONSE_PORT_SIZE);
*pPort = ntohs(portNBO); *pPort = ntohs(portNBO);
Cleanup: Cleanup:
return hr; return hr;
...@@ -350,18 +353,20 @@ HRESULT CStunMessageReader::GetChangeRequest(StunChangeRequestAttribute* pChange ...@@ -350,18 +353,20 @@ HRESULT CStunMessageReader::GetChangeRequest(StunChangeRequestAttribute* pChange
{ {
HRESULT hr = S_OK; HRESULT hr = S_OK;
uint8_t *pData = NULL; uint8_t *pData = NULL;
StunAttribute attrib; StunAttribute *pAttrib;
uint32_t value = 0; uint32_t value = 0;
ChkIfA(pChangeRequest == NULL, E_INVALIDARG); ChkIfA(pChangeRequest == NULL, E_INVALIDARG);
pAttrib = _mapAttributes.Lookup(STUN_ATTRIBUTE_CHANGEREQUEST);
ChkIf(pAttrib == NULL, E_FAIL);
Chk(GetAttributeByIndex(_indexChangeRequest, &attrib)); ChkIf(pAttrib->size != STUN_ATTRIBUTE_CHANGEREQUEST_SIZE, E_UNEXPECTED);
ChkIf(attrib.size != STUN_ATTRIBUTE_CHANGEREQUEST_SIZE, E_UNEXPECTED);
pData = _stream.GetDataPointerUnsafe(); pData = _stream.GetDataPointerUnsafe();
ChkIf(pData==NULL, E_UNEXPECTED); ChkIf(pData==NULL, E_UNEXPECTED);
memcpy(&value, pData + attrib.offset, STUN_ATTRIBUTE_CHANGEREQUEST_SIZE); memcpy(&value, pData + pAttrib->offset, STUN_ATTRIBUTE_CHANGEREQUEST_SIZE);
value = ntohl(value); value = ntohl(value);
...@@ -382,15 +387,17 @@ Cleanup: ...@@ -382,15 +387,17 @@ Cleanup:
HRESULT CStunMessageReader::GetPaddingAttributeSize(uint16_t* pSizePadding) HRESULT CStunMessageReader::GetPaddingAttributeSize(uint16_t* pSizePadding)
{ {
HRESULT hr = S_OK; HRESULT hr = S_OK;
StunAttribute attrib; StunAttribute *pAttrib;
ChkIfA(pSizePadding == NULL, E_INVALIDARG); ChkIfA(pSizePadding == NULL, E_INVALIDARG);
*pSizePadding = 0; *pSizePadding = 0;
pAttrib = _mapAttributes.Lookup(STUN_ATTRIBUTE_PADDING);
Chk(GetAttributeByIndex(_indexPaddingAttribute, &attrib)); ChkIf(pAttrib == NULL, E_FAIL);
*pSizePadding = attrib.size; *pSizePadding = pAttrib->size;
Cleanup: Cleanup:
return hr; return hr;
...@@ -403,15 +410,16 @@ HRESULT CStunMessageReader::GetErrorCode(uint16_t* pErrorNumber) ...@@ -403,15 +410,16 @@ HRESULT CStunMessageReader::GetErrorCode(uint16_t* pErrorNumber)
uint8_t cl = 0; uint8_t cl = 0;
uint8_t num = 0; uint8_t num = 0;
StunAttribute attrib; StunAttribute* pAttrib;
ChkIf(pErrorNumber==NULL, E_INVALIDARG); ChkIf(pErrorNumber==NULL, E_INVALIDARG);
Chk(GetAttributeByIndex(_indexErrorCode, &attrib)); pAttrib = _mapAttributes.Lookup(::STUN_ATTRIBUTE_ERRORCODE);
ChkIf(pAttrib == NULL, E_FAIL);
// first 21 bits of error-code attribute must be zero. // first 21 bits of error-code attribute must be zero.
// followed by 3 bits of "class" and 8 bits for the error number modulo 100 // followed by 3 bits of "class" and 8 bits for the error number modulo 100
ptr = _stream.GetDataPointerUnsafe() + attrib.offset + 2; ptr = _stream.GetDataPointerUnsafe() + pAttrib->offset + 2;
cl = *ptr++; cl = *ptr++;
cl = cl & 0x07; cl = cl & 0x07;
...@@ -426,12 +434,13 @@ Cleanup: ...@@ -426,12 +434,13 @@ Cleanup:
HRESULT CStunMessageReader::GetAddressHelper(uint16_t attribType, CSocketAddress* pAddr) HRESULT CStunMessageReader::GetAddressHelper(uint16_t attribType, CSocketAddress* pAddr)
{ {
HRESULT hr = S_OK; HRESULT hr = S_OK;
StunAttribute attrib={}; StunAttribute* pAttrib = _mapAttributes.Lookup(attribType);
uint8_t *pAddrStart = NULL; uint8_t *pAddrStart = NULL;
Chk(GetAttributeByType(attribType, &attrib)); ChkIf(pAttrib == NULL, E_FAIL);
pAddrStart = _stream.GetDataPointerUnsafe() + attrib.offset;
Chk(::GetMappedAddress(pAddrStart, attrib.size, pAddr)); pAddrStart = _stream.GetDataPointerUnsafe() + pAttrib->offset;
Chk(::GetMappedAddress(pAddrStart, pAttrib->size, pAddr));
Cleanup: Cleanup:
return hr; return hr;
...@@ -475,17 +484,16 @@ Cleanup: ...@@ -475,17 +484,16 @@ Cleanup:
HRESULT CStunMessageReader::GetStringAttributeByType(uint16_t attributeType, char* pszValue, /*in-out*/ size_t size) HRESULT CStunMessageReader::GetStringAttributeByType(uint16_t attributeType, char* pszValue, /*in-out*/ size_t size)
{ {
HRESULT hr = S_OK; HRESULT hr = S_OK;
StunAttribute attrib = {}; StunAttribute* pAttrib = _mapAttributes.Lookup(attributeType);
ChkIfA(pszValue == NULL, E_INVALIDARG); ChkIfA(pszValue == NULL, E_INVALIDARG);
ChkIf(pAttrib == NULL, E_INVALIDARG);
Chk(GetAttributeByType(attributeType, &attrib));
// size needs to be 1 greater than attrib.size so we can properly copy over a null char at the end // size needs to be 1 greater than attrib.size so we can properly copy over a null char at the end
ChkIf(attrib.size >= size, E_INVALIDARG); ChkIf(pAttrib->size >= size, E_INVALIDARG);
memcpy(pszValue, _stream.GetDataPointerUnsafe() + attrib.offset, attrib.size); memcpy(pszValue, _stream.GetDataPointerUnsafe() + pAttrib->offset, pAttrib->size);
pszValue[attrib.size] = '\0'; pszValue[pAttrib->size] = '\0';
Cleanup: Cleanup:
return hr; return hr;
...@@ -599,41 +607,18 @@ HRESULT CStunMessageReader::ReadBody() ...@@ -599,41 +607,18 @@ HRESULT CStunMessageReader::ReadBody()
hr = (attributeLength <= MAX_STUN_ATTRIBUTE_SIZE) ? S_OK : E_FAIL; hr = (attributeLength <= MAX_STUN_ATTRIBUTE_SIZE) ? S_OK : E_FAIL;
} }
if (SUCCEEDED(hr))
{
hr = (_nAttributeCount < MAX_NUM_ATTRIBUTES) ? S_OK : E_FAIL;
}
if (SUCCEEDED(hr)) if (SUCCEEDED(hr))
{ {
StunAttribute attrib; StunAttribute attrib;
int attributeIndex = _nAttributeCount;
attrib.attributeType = attributeType; attrib.attributeType = attributeType;
attrib.size = attributeLength; attrib.size = attributeLength;
attrib.offset = attributeOffset; attrib.offset = attributeOffset;
_attributes[_nAttributeCount++] = attrib;
hr = _mapAttributes.Insert(attributeType, attrib);
// now if this attribute is one we want to cache the index for, let's do it here }
// todo - think about a "fast map" for this operation
switch (attributeType) if (SUCCEEDED(hr))
{ {
case STUN_ATTRIBUTE_FINGERPRINT:
_indexFingerprint = attributeIndex; break;
case STUN_ATTRIBUTE_RESPONSE_PORT:
_indexResponsePort = attributeIndex; break;
case STUN_ATTRIBUTE_CHANGEREQUEST:
_indexChangeRequest = attributeIndex; break;
case STUN_ATTRIBUTE_PADDING:
_indexPaddingAttribute = attributeIndex; break;
case STUN_ATTRIBUTE_ERRORCODE:
_indexErrorCode = attributeIndex; break;
case STUN_ATTRIBUTE_MESSAGEINTEGRITY:
_indexMessageIntegrity = attributeIndex; break;
default: break;
}
hr = _stream.SeekRelative(attributeLength); hr = _stream.SeekRelative(attributeLength);
} }
......
...@@ -24,6 +24,7 @@ ...@@ -24,6 +24,7 @@
#include "stuntypes.h" #include "stuntypes.h"
#include "datastream.h" #include "datastream.h"
#include "socketaddress.h" #include "socketaddress.h"
#include "fasthash.h"
class CStunMessageReader class CStunMessageReader
...@@ -46,8 +47,13 @@ private: ...@@ -46,8 +47,13 @@ private:
ReaderParseState _state; ReaderParseState _state;
static const size_t MAX_NUM_ATTRIBUTES = 30; static const size_t MAX_NUM_ATTRIBUTES = 30;
StunAttribute _attributes[MAX_NUM_ATTRIBUTES]; //StunAttribute _attributes[MAX_NUM_ATTRIBUTES];
size_t _nAttributeCount; //size_t _nAttributeCount;
typedef FastHash<uint16_t, StunAttribute, MAX_NUM_ATTRIBUTES, 53> AttributeHashTable; // 53 is a prime number for a reasonable table width
AttributeHashTable _mapAttributes;
StunTransactionId _transactionid; StunTransactionId _transactionid;
uint16_t _msgTypeNormalized; uint16_t _msgTypeNormalized;
......
include ../common.inc include ../common.inc
PROJECT_TARGET := stuntestcode PROJECT_TARGET := stuntestcode
PROJECT_OBJS := testbuilder.o testclientlogic.o testcmdline.o testcode.o testdatastream.o testintegrity.o testmessagehandler.o testreader.o testrecvfromex.o PROJECT_OBJS := testbuilder.o testclientlogic.o testcmdline.o testcode.o testdatastream.o testfasthash.o testintegrity.o testmessagehandler.o testreader.o testrecvfromex.o
INCLUDES := $(BOOST_INCLUDE) $(OPENSSL_INCLUDE) -I../common -I../stuncore -I../networkutils INCLUDES := $(BOOST_INCLUDE) $(OPENSSL_INCLUDE) -I../common -I../stuncore -I../networkutils
LIB_PATH := -L../networkutils -L../stuncore -L../common LIB_PATH := -L../networkutils -L../stuncore -L../common
......
...@@ -26,6 +26,7 @@ ...@@ -26,6 +26,7 @@
#include "testintegrity.h" #include "testintegrity.h"
#include "testclientlogic.h" #include "testclientlogic.h"
#include "testrecvfromex.h" #include "testrecvfromex.h"
#include "testfasthash.h"
#include "cmdlineparser.h" #include "cmdlineparser.h"
#include "oshelper.h" #include "oshelper.h"
#include "prettyprint.h" #include "prettyprint.h"
...@@ -76,6 +77,7 @@ void RunUnitTests() ...@@ -76,6 +77,7 @@ void RunUnitTests()
boost::shared_ptr<CTestCmdLineParser> spTestCmdLineParser(new CTestCmdLineParser); boost::shared_ptr<CTestCmdLineParser> spTestCmdLineParser(new CTestCmdLineParser);
boost::shared_ptr<CTestClientLogic> spTestClientLogic(new CTestClientLogic); boost::shared_ptr<CTestClientLogic> spTestClientLogic(new CTestClientLogic);
boost::shared_ptr<CTestRecvFromEx> spTestRecvFromEx(new CTestRecvFromEx); boost::shared_ptr<CTestRecvFromEx> spTestRecvFromEx(new CTestRecvFromEx);
boost::shared_ptr<CTestFastHash> spTestFastHash(new CTestFastHash);
vecTests.push_back(spTestDataStream.get()); vecTests.push_back(spTestDataStream.get());
vecTests.push_back(spTestReader.get()); vecTests.push_back(spTestReader.get());
...@@ -85,6 +87,8 @@ void RunUnitTests() ...@@ -85,6 +87,8 @@ void RunUnitTests()
vecTests.push_back(spTestCmdLineParser.get()); vecTests.push_back(spTestCmdLineParser.get());
vecTests.push_back(spTestClientLogic.get()); vecTests.push_back(spTestClientLogic.get());
vecTests.push_back(spTestRecvFromEx.get()); vecTests.push_back(spTestRecvFromEx.get());
vecTests.push_back(spTestFastHash.get());
for (size_t index = 0; index < vecTests.size(); index++) for (size_t index = 0; index < vecTests.size(); index++)
{ {
......
/*
Copyright 2011 John Selbie
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "commonincludes.h"
#include "testfasthash.h"
#include "fasthash.h"
HRESULT CTestFastHash::Run()
{
return TestFastHash();
}
HRESULT CTestFastHash::TestFastHash()
{
HRESULT hr = S_OK;
const size_t c_maxsize = 500;
FastHash<int, Item, c_maxsize> hash;
for (size_t index = 0; index < c_maxsize; index++)
{
Item item;
item.key = (int)index;
ChkA(hash.Insert((int)index, item));
}
// validate that all the items are in the table
for (size_t index = 0; index < c_maxsize; index++)
{
Item* pItem = NULL;
Item* pItemDirect = NULL;
int insertindex = -1;
ChkIfA(hash.Exists(index)==false, E_FAIL);
pItem = hash.Lookup((int)index, &insertindex);
ChkIfA(pItem == NULL, E_FAIL);
ChkIfA(pItem->key != (int)index, E_FAIL);
ChkIfA((int)index != insertindex, E_FAIL);
pItemDirect = hash.GetItemByIndex((int)index);
ChkIfA(pItemDirect != pItem, E_FAIL);
}
// validate that items aren't in the table don't get returned
for (size_t index = c_maxsize; index < (c_maxsize*2); index++)
{
ChkIfA(hash.Exists((int)index), E_FAIL);
ChkIfA(hash.Lookup((int)index)!=NULL, E_FAIL);
ChkIfA(hash.GetItemByIndex((int)index)!=NULL, E_FAIL);
}
Cleanup:
return hr;
}
/*
Copyright 2011 John Selbie
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef TEST_FAST_HASH_H
#define TEST_FAST_HASH_H
#include "commonincludes.h"
#include "unittest.h"
class CTestFastHash : public IUnitTest
{
private:
HRESULT TestFastHash();
struct Item
{
int key;
};
public:
virtual HRESULT Run();
UT_DECLARE_TEST_NAME("CTestFastHash");
};
#endif
/*
Copyright 2011 John Selbie
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "commonincludes.h" #include "commonincludes.h"
#include "stuncore.h" #include "stuncore.h"
#include "testintegrity.h" #include "testintegrity.h"
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment