Program Listing for File socket.cpp

Return to documentation for file (blam\components\networking\socket.cpp)

#include "socket.h"

#include <accctrl.h>
#include <cassert>

#include "components\core\config\config.h"
#include "components/core/logger/logger.h"

// Number of slots in each PacketBuffer ring. Must stay a power of two, because
// ring indices are wrapped with a bitwise AND against BufferMask.
UINT32 BufferCap{2048};
// Wrap mask for ring indices (BufferCap - 1).
UINT32 BufferMask{BufferCap - 1};

// Scale factor bufferPush multiplies by its fake-lag setting; presumably meant to be
// the performance-counter frequency (ticks per second).
static LARGE_INTEGER clockFreq{};

// Value reported by Blam::Network::IsConnected(); not written anywhere else in this file.
bool isConnected{false};

namespace Internal
{
    // Converts a host-byte-order Blam::Endpoint into a Winsock IPv4 address.
    // The struct is zero-initialised first so the sin_zero padding is never garbage.
    SOCKADDR_IN ip_endpoint_to_sockaddr_in(Blam::Endpoint* ip_endpoint)
    {
        SOCKADDR_IN addr = {};
        addr.sin_family = AF_INET;
        addr.sin_addr.s_addr = htonl(ip_endpoint->address);
        addr.sin_port = htons(ip_endpoint->port);
        return addr;
    }

    // Sets a SOL_SOCKET integer option, then reads it back.
    // Returns true only if both calls succeed and the read-back value equals `val`.
    static bool set_sock_opt(SOCKET sock, int opt, int val)
    {
        int length = sizeof(int);
        int actual = 0;

        if (setsockopt(sock, SOL_SOCKET, opt, (char*)&val, length) == SOCKET_ERROR)
        {
            return false;
        }

        if (getsockopt(sock, SOL_SOCKET, opt, (char*)&actual, &length) == SOCKET_ERROR)
        {
            return false;
        }

        return actual == val;
    }

    // Creates a non-blocking IPv4 UDP socket and stores its handle in *outSocket.
    // Returns false (logging the reason) if the socket cannot be created or switched to
    // non-blocking mode; on that failure path no handle is leaked.
    bool socket(Blam::Socket* outSocket)
    {
        SOCKET sock = ::socket(AF_INET, SOCK_DGRAM, IPPROTO_UDP);

        // Validate the handle before passing it to any other socket call.
        if (sock == INVALID_SOCKET)
        {
            Blam::LogEvent("socket() failed");
            return false;
        }

        // Request 2 MiB kernel receive/send buffers; failure is logged but not fatal.
        if (!set_sock_opt(sock, SO_RCVBUF, 2 * 1024 * 1024))
        {
            Blam::LogEvent("failed to set rcvbuf size");
        }
        if (!set_sock_opt(sock, SO_SNDBUF, 2 * 1024 * 1024))
        {
            Blam::LogEvent("failed to set sndbuf size");
        }

        u_long enabled = 1;

        if (ioctlsocket(sock, FIONBIO, &enabled) == SOCKET_ERROR)
        {
            Blam::LogEvent("ioctlsocket() failed");

            // The caller never receives this handle, so release it here.
            closesocket(sock);
            return false;
        }

        *outSocket = {};
        outSocket->handle = sock;

        return true;
    }

    // Closes the socket handle. A failing closesocket trips the assert in debug builds.
    void socket_close(Blam::Socket* sock)
    {
        int result = closesocket(sock->handle);

        assert(result != SOCKET_ERROR);
        (void)result; // silence unused-variable warnings when NDEBUG disables assert
    }

    // Binds the socket to localEndpoint. Logs and returns false on failure.
    bool socket_bind(Blam::Socket* sock, Blam::Endpoint* localEndpoint)
    {
        SOCKADDR_IN localAddr = ip_endpoint_to_sockaddr_in(localEndpoint);

        if (bind(sock->handle, (SOCKADDR*)&localAddr, sizeof(localAddr)) == SOCKET_ERROR)
        {
            Blam::LogEvent("bind() failed");

            return false;
        }

        return true;
    }

    // Sends one datagram of packetSize bytes to `endpoint`. Logs and returns false on failure.
    bool socket_send(Blam::Socket* sock, UINT8* packet, UINT32 packetSize, Blam::Endpoint* endpoint)
    {
        SOCKADDR_IN serverAddr = ip_endpoint_to_sockaddr_in(endpoint);

        if (sendto(sock->handle, (const char*)packet, static_cast<int>(packetSize), 0,
                   (SOCKADDR*)&serverAddr, sizeof(serverAddr)) == SOCKET_ERROR)
        {
            Blam::LogEvent("sendto() failed");

            return false;
        }

        return true;
    }

    // Receives at most bufferSize bytes of one datagram into `buffer`.
    // On success, writes the byte count to *outSize and the sender to *endpoint.
    // Returns false when nothing could be read.
    bool socket_receive(Blam::Socket* sock, UINT8* buffer, UINT32 bufferSize, UINT32* outSize, Blam::Endpoint* endpoint)
    {
        int flags = 0;
        SOCKADDR_IN from = {};
        int fromSize = sizeof(from);

        int bytesReceived = recvfrom(sock->handle, (char*)buffer, static_cast<int>(bufferSize), flags,
                                     (SOCKADDR*)&from, &fromSize);

        if (bytesReceived == SOCKET_ERROR)
        {
            const int error = WSAGetLastError();

            // The socket is non-blocking, so WSAEWOULDBLOCK only means "nothing queued".
            // Any other code is a real failure and should not be silently discarded.
            if (error != WSAEWOULDBLOCK)
            {
                Blam::LogEvent("recvfrom() failed");
            }

            return false;
        }

        *outSize = static_cast<UINT32>(bytesReceived);

        *endpoint = {};
        endpoint->address = ntohl(from.sin_addr.S_un.S_addr);
        endpoint->port = ntohs(from.sin_port);

        return true;
    }
}

// Sets up a bump allocator over a freshly heap-allocated arena of `size` bytes.
// NOTE(review): no function in this file releases `memory` — confirm ownership elsewhere.
void Blam::linearAllocCreate(Blam::LinearAllocator* alloc, UINT64 size)
{
    UINT8* arena = new UINT8[size];

    alloc->bytesRemaining = size;
    alloc->memory = arena;
    alloc->next = arena;
}

// Carves `size` bytes off the front of the arena and returns a pointer to them.
// Requests larger than the remaining capacity trip an assert in debug builds and
// return nullptr (leaving the allocator untouched), instead of handing out memory
// past the end of the arena and wrapping bytesRemaining.
UINT8* Blam::linearAlloc(Blam::LinearAllocator* alloc, UINT64 size)
{
    assert(size <= alloc->bytesRemaining && "linear allocator exhausted");

    if (size > alloc->bytesRemaining)
    {
        return nullptr;
    }

    UINT8* mem = alloc->next;
    alloc->next += size;
    alloc->bytesRemaining -= size;

    return mem;
}

// Builds an empty PacketBuffer ring with BufferCap slots, carving all storage from
// `allocator`: BufferCap fixed 1024-byte packet slots, plus per-slot size, endpoint
// and scheduled-time arrays.
static Blam::Network::PacketBuffer packetBuffer(Blam::LinearAllocator* allocator)
{
    // bufferPush scales clockFreq by a lag value, so it must hold the counter
    // *frequency* (ticks per second). QueryPerformanceCounter would store the
    // current tick count instead.
    QueryPerformanceFrequency(&clockFreq);

    Blam::Network::PacketBuffer packetBuffer = {};

    packetBuffer.index = 0;
    packetBuffer.size = 0;
    packetBuffer.packets = Blam::linearAlloc(allocator, BufferCap * 1024);
    packetBuffer.packetSizes = (UINT32*) Blam::linearAlloc(allocator, sizeof(UINT32) * BufferCap);
    packetBuffer.endpoints = (Blam::Endpoint*) Blam::linearAlloc(allocator, sizeof(Blam::Endpoint) * BufferCap);
    packetBuffer.times = (LARGE_INTEGER*) Blam::linearAlloc(allocator, sizeof(LARGE_INTEGER) * BufferCap);

    return packetBuffer;
}

// True when all BufferCap slots of the ring are occupied.
static bool bufferIsFull(Blam::Network::PacketBuffer* pBuffer)
{
    const bool full = (BufferCap == pBuffer->size);
    return full;
}

// Appends a copy of `packet` to the ring at the write index, together with its size,
// endpoint and the counter time at which it becomes eligible for bufferPop.
// Packets that do not fit in a 1024-byte slot are logged and dropped, rather than
// overflowing into the next slot (or past the end of the ring's storage).
// Callers must ensure the ring is not full.
static void bufferPush(Blam::Network::PacketBuffer* pBuffer, UINT8* packet, UINT32 packetSize, Blam::Endpoint* endpoint)
{
    // Each slot in pBuffer->packets is a fixed 1024 bytes (see the index * 1024 stride below).
    const UINT32 slotSize = 1024;
    // Artificial latency applied to every queued packet, in milliseconds (0 = disabled).
    const LONGLONG fakeLagMs = 0;

    assert(pBuffer->size < BufferCap);

    if (packetSize > slotSize)
    {
        Blam::LogEvent("Packet larger than buffer slot, dropping");
        return;
    }

    LARGE_INTEGER now;
    QueryPerformanceCounter(&now);

    // Convert the lag from milliseconds into counter ticks (clockFreq is ticks per second).
    LARGE_INTEGER then;
    then.QuadPart = now.QuadPart + (clockFreq.QuadPart * fakeLagMs) / 1000;

    UINT32 index = pBuffer->index;
    pBuffer->times[index] = then;
    pBuffer->packetSizes[index] = packetSize;
    pBuffer->endpoints[index] = *endpoint;

    UINT8* dstPacket = &pBuffer->packets[index * slotSize];
    memcpy(dstPacket, packet, packetSize);

    pBuffer->size = pBuffer->size + 1;
    pBuffer->index = (pBuffer->index + 1) & BufferMask;
}

// Removes the oldest entry from the ring, ignoring its scheduled time.
// *outPacket points into the ring's own storage; the data is not copied.
// Emptiness is not checked here — callers test `size` first.
static void bufferForcePop(Blam::Network::PacketBuffer* pBuffer, UINT8** outPacket, UINT32* outPacketSize, Blam::Endpoint* endpoint)
{
    // The oldest entry sits `size` slots behind the write index, wrapped by the mask.
    const UINT32 tail = (pBuffer->index - pBuffer->size) & BufferMask;

    *endpoint = pBuffer->endpoints[tail];
    *outPacketSize = pBuffer->packetSizes[tail];
    *outPacket = pBuffer->packets + tail * 1024;

    --pBuffer->size;
}

// Removes the oldest entry from the ring only if its scheduled time has been reached.
// Returns true and fills the outputs on success. On false, the ring is unchanged.
// *outPacket points into the ring's own storage. Emptiness is not checked here.
static bool bufferPop(Blam::Network::PacketBuffer* pBuffer, UINT8** outPacket, UINT32* outPacketSize, Blam::Endpoint* endpoint)
{
    const UINT32 tail = (pBuffer->index - pBuffer->size) & BufferMask;

    LARGE_INTEGER current;
    QueryPerformanceCounter(&current);

    // Not yet due: leave it queued.
    if (pBuffer->times[tail].QuadPart > current.QuadPart)
    {
        return false;
    }

    *endpoint = pBuffer->endpoints[tail];
    *outPacketSize = pBuffer->packetSizes[tail];
    *outPacket = pBuffer->packets + tail * 1024;

    --pBuffer->size;

    return true;
}

bool Blam::Network::Init()
{
    WORD wsaVersion = 0x202;
    WSADATA wsaData;

    if (WSAStartup(wsaVersion, &wsaData))
    {
        Blam::LogEvent("WSAStartup failed");

        return false;
    }

    return true;
}

// Resets *sock, opens its UDP socket and, on success, allocates its send and receive
// packet rings from `allocator`. Returns false if the socket could not be opened.
bool Blam::Network::Start(Socket* sock, LinearAllocator *allocator)
{
    *sock = {};

    const bool opened = Internal::socket(&sock->sock);

    if (opened)
    {
        sock->sendBuf = packetBuffer(allocator);
        sock->recvBuf = packetBuffer(allocator);
    }

    return opened;
}

// Flushes every packet still in the send ring, regardless of its scheduled time,
// then closes the underlying socket. Individual send failures are ignored.
void Blam::Network::Close(Socket* sock)
{
    UINT8* pending;
    UINT32 pendingSize;
    Endpoint target;

    while (sock->sendBuf.size != 0)
    {
        bufferForcePop(&sock->sendBuf, &pending, &pendingSize, &target);
        Internal::socket_send(&sock->sock, pending, pendingSize, &target);
    }

    Internal::socket_close(&sock->sock);
}

// Binds the socket to `endpoint`; on success, enables the receive path in Receive().
bool Blam::Network::Bind(Socket* sock, Endpoint* endpoint)
{
    if (!Internal::socket_bind(&sock->sock, endpoint))
    {
        return false;
    }

    sock->receive = 1;
    return true;
}

// Queues `packet` for sending to `endpoint`; actual transmission happens in Receive()/Close().
// If the send ring is full, its oldest packet is sent immediately to make room.
// Always returns true.
bool Blam::Network::Send(Socket* sock, UINT8* packet, UINT32 packetSize, Endpoint* endpoint)
{
    auto* queue = &sock->sendBuf;

    if (bufferIsFull(queue))
    {
        Blam::LogEvent("Packet Send Buffer is full.. this is not good!");

        UINT8* oldest;
        UINT32 oldestSize;
        Endpoint oldestDst;

        bufferForcePop(queue, &oldest, &oldestSize, &oldestDst);
        Internal::socket_send(&sock->sock, oldest, oldestSize, &oldestDst);
    }

    bufferPush(queue, packet, packetSize, endpoint);
    return true;
}

// Pumps the socket once:
//   1. sends every queued outgoing packet whose scheduled time has passed,
//   2. drains datagrams from the OS into the receive ring (using `buffer` as scratch),
//   3. copies the oldest due received packet into `buffer`.
// Returns true when a packet was delivered (size in *outPacketSize, sender in *outFrom).
// Returns false when receiving is not yet enabled or no received packet is due.
bool Blam::Network::Receive(Socket* sock, UINT8* buffer, UINT32 bufferSize, UINT32* outPacketSize, Endpoint* outFrom)
{
    UINT8* slot;
    UINT32 slotSize;
    Endpoint peer;

    // 1) Flush due outgoing packets. Having sent from the socket also enables receiving.
    for (;;)
    {
        const bool due = sock->sendBuf.size && bufferPop(&sock->sendBuf, &slot, &slotSize, &peer);
        if (!due)
        {
            break;
        }

        const bool sent = Internal::socket_send(&sock->sock, slot, slotSize, &peer);
        assert(sent);

        sock->receive = 1;
    }

    // Receiving is enabled only by Bind() or by a send above.
    if (!sock->receive)
    {
        return false;
    }

    if (bufferIsFull(&sock->recvBuf))
    {
        Blam::LogEvent("Packet Recv Buffer is full.. this is not good!");
    }

    // 2) Pull datagrams until the ring is full or socket_receive reports nothing more.
    for (;;)
    {
        if (bufferIsFull(&sock->recvBuf))
        {
            break;
        }
        if (!Internal::socket_receive(&sock->sock, buffer, bufferSize, &slotSize, &peer))
        {
            break;
        }
        bufferPush(&sock->recvBuf, buffer, slotSize, &peer);
    }

    // 3) Deliver the oldest due packet, if any.
    if (!sock->recvBuf.size)
    {
        return false;
    }
    if (!bufferPop(&sock->recvBuf, &slot, outPacketSize, outFrom))
    {
        return false;
    }

    memcpy(buffer, slot, *outPacketSize);
    return true;
}

// Receives and dispatches every packet currently available on `sock`.
// Each packet is read into socketBuffer (capacity bufferSize bytes); its leading
// 16-bit id selects the handler. Packets too short to carry an id are skipped.
void Blam::Network::HandleReceive(Socket* sock, UINT8* socketBuffer, UINT32 bufferSize, UINT32* outPacketSize, Endpoint* outFrom)
{
    // Pass the caller's real capacity so Receive never writes past the end of socketBuffer.
    while (Blam::Network::Receive(sock, socketBuffer, bufferSize, outPacketSize, outFrom))
    {
        if (*outPacketSize < sizeof(UINT16))
        {
            continue;
        }

        // ReadShort takes a pointer-to-pointer (presumably advancing it — confirm), so
        // read through a local cursor. socketBuffer must keep pointing at the start of
        // the caller's buffer for the next Receive() call.
        UINT8* cursor = socketBuffer;
        UINT16 packetid = Reader::ReadShort(&cursor);

        switch (ServerMessages(packetid))
        {
        case ServerHello:
            Blam::LogEvent("Received ServerHelloMessage!");
            break;
        default:
            break;
        }
    }
}

// Reports the file-level isConnected flag.
bool Blam::Network::IsConnected()
{
    const bool connected = isConnected;
    return connected;
}