#include "pch.h"

// Placed after the precompiled header: with /Yu, MSVC skips everything before it.
#pragma comment(lib, "Winhttp.lib")

#include "asyncwebsocketclient.h"

#include <algorithm>
#include <string>
#include <utility>

// Library-specific error codes returned by the WebSocket wrapper. They start 1000 past
// WINHTTP_ERROR_LAST so they lie above WinHTTP's own error range. Every expansion is
// fully parenthesised so the macros act as single values inside larger expressions
// (e.g. `-CODE` or `2 * CODE`).
#define WEBSOCKET_ERROR_FIRST                       ((WINHTTP_ERROR_LAST) + 1000)
#define WEBSOCKET_ERROR_INVALID_HANDLE              (WEBSOCKET_ERROR_FIRST + 1)
#define WEBSOCKET_ERROR_FAILED_OPERATION            (WEBSOCKET_ERROR_FIRST + 2)
#define WEBSOCKET_ERROR_CLOSING_ACTIVE_CONNECTION   (WEBSOCKET_ERROR_FIRST + 3)
#define WEBSOCKET_ERROR_INVALID_PARAMETER           (WEBSOCKET_ERROR_FIRST + 4)
#define WEBSOCKET_ERROR_EMPTY_SEND_BUFFER           (WEBSOCKET_ERROR_FIRST + 5)
#define WEBSOCKET_ERROR_NOT_CONNECTED               (WEBSOCKET_ERROR_FIRST + 6)

// Custom callback-status value; not referenced elsewhere in this file.
#define WINHTTP_CALLBACK_STATUS_DEFAULT             0x3389

namespace WinHttpWebSocketClient {
	// Registry of live clients keyed by their WinHTTP WebSocket handle. The exported
	// client_* functions and WebSocketCallback look clients up here.
	// NOTE(review): the session is asynchronous, so WebSocketCallback presumably runs on
	// WinHTTP worker threads while client_* calls modify this map without a lock —
	// confirm callers serialize access.
	std::map<HINTERNET, std::shared_ptr<WebSocketClient>>WinHttpWebSocketClient::clients;

	// WinHTTP status callback for WebSocket handles (installed by EnableCallBack).
	// Finds the client registered for `hInternet`, records the notification code and
	// forwards completion/error notifications to the matching handler.
	//
	// hInternet                 - handle the notification refers to (registry key).
	// dwContext                 - context value attached to the handle (not used here).
	// dwInternetStatus          - WINHTTP_CALLBACK_STATUS_* code.
	// lpvStatusInformation      - status payload: WINHTTP_WEB_SOCKET_STATUS for read/write
	//                             completions, WINHTTP_ASYNC_RESULT-compatible data for errors.
	// dwStatusInformationLength - payload size (not used here).
	void WebSocketCallback(HINTERNET hInternet, DWORD_PTR dwContext, DWORD dwInternetStatus, LPVOID lpvStatusInformation, DWORD dwStatusInformationLength)
	{
		// One lookup, and a local shared_ptr copy, so the client object stays alive for
		// the whole dispatch and operator[] can never insert an empty entry.
		std::shared_ptr<WebSocketClient> client;
		{
			auto it = WinHttpWebSocketClient::clients.find(hInternet);
			if (it == WinHttpWebSocketClient::clients.end())
				return;
			client = it->second;
		}
		if (!client)
			return;

		client->OnCallBack(dwInternetStatus);

		switch (dwInternetStatus)
		{
		case WINHTTP_CALLBACK_STATUS_CLOSE_COMPLETE:
			client->OnClose();
			break;

		case WINHTTP_CALLBACK_STATUS_WRITE_COMPLETE:
			// Payload-carrying notifications are skipped if WinHTTP supplies no payload.
			if (lpvStatusInformation != NULL)
				client->OnSendComplete(static_cast<const WINHTTP_WEB_SOCKET_STATUS*>(lpvStatusInformation)->dwBytesTransferred);
			break;

		case WINHTTP_CALLBACK_STATUS_READ_COMPLETE:
			if (lpvStatusInformation != NULL)
			{
				const WINHTTP_WEB_SOCKET_STATUS* info = static_cast<const WINHTTP_WEB_SOCKET_STATUS*>(lpvStatusInformation);
				client->OnReadComplete(info->dwBytesTransferred, info->eBufferType);
			}
			break;

		case WINHTTP_CALLBACK_STATUS_REQUEST_ERROR:
			if (lpvStatusInformation != NULL)
				client->OnError(static_cast<const WINHTTP_ASYNC_RESULT*>(lpvStatusInformation));
			break;

		default:
			break;
		}
	}

	// Builds a disconnected client: no WinHTTP handles, zeroed counters, an allocated
	// receive staging buffer and an empty frame queue.
	WebSocketClient::WebSocketClient(VOID)
	{
		// Handles are opened later by Initialize()/Connect().
		hSession = NULL;
		hConnect = NULL;
		hRequest = NULL;
		hWebSocket = NULL;

		// Bookkeeping starts from zero.
		ErrorCode = 0;
		initialized = 0;
		bytesTX = 0;
		bytesRX = 0;
		completed_websocket_operation = 0;

		// Staging buffer for asynchronous receives and the queue of completed frames.
		rxBuffer.resize(BUFFERSIZE);
		rxBufferType = {};
		frames = new std::queue<Frame>();

		status = ENUM_WEBSOCKET_STATE::CLOSED;
	}

	// Delegates all cleanup to Free().
	// NOTE(review): this class owns a raw `frames` pointer (new in the constructor,
	// delete in Free), so copying an instance would double-delete it — consider
	// deleting the copy operations in the class declaration.
	WebSocketClient::~WebSocketClient(VOID)
	{
		this->Free();
	}

	// Closes the per-connection WinHTTP handles (WebSocket, request, connection) and
	// returns the client to the CLOSED state with zeroed counters and an empty frame
	// queue. The session handle is kept so the client can reconnect; Free() closes it.
	//
	// reset_error - when true, also clears the last recorded error code.
	VOID WebSocketClient::Reset(bool reset_error)
	{
		if (reset_error)
			ErrorCode = 0;

		if (hWebSocket)
		{
			// Remove the status callback before closing the handle it is attached to.
			WinHttpSetStatusCallback(hWebSocket, NULL, WINHTTP_CALLBACK_FLAG_ALL_COMPLETIONS, 0);
			WinHttpCloseHandle(hWebSocket);
			hWebSocket = NULL;
		}

		// Close these independently of hWebSocket: Connect() can fail after opening them
		// but before the upgrade produced a WebSocket handle, and they must not leak.
		if (hRequest)
		{
			WinHttpCloseHandle(hRequest);
			hRequest = NULL;
		}
		if (hConnect)
		{
			WinHttpCloseHandle(hConnect);
			hConnect = NULL;
		}

		initialized = 0;
		bytesTX = 0;
		bytesRX = 0;
		status = ENUM_WEBSOCKET_STATE::CLOSED;

		// Drop any frames that were received but not yet read.
		if (frames)
		{
			while (!frames->empty())
				frames->pop();
		}
	}

	// Opens the asynchronous WinHTTP session used by all later calls.
	// Returns NO_ERROR on success; otherwise the error reported by WinHttpOpen (or
	// ERROR_INVALID_HANDLE if none was reported). The result is also stored in ErrorCode.
	DWORD WebSocketClient::Initialize(VOID)
	{
		hSession = WinHttpOpen(L"MyApp", WINHTTP_ACCESS_TYPE_NO_PROXY, WINHTTP_NO_PROXY_NAME, WINHTTP_NO_PROXY_BYPASS, WINHTTP_FLAG_ASYNC);
		if (hSession == NULL)
		{
			const DWORD lastError = GetLastError();
			ErrorCode = (lastError != NO_ERROR) ? lastError : ERROR_INVALID_HANDLE;
		}
		else
		{
			// Success: do not report an error left over from an earlier operation.
			ErrorCode = NO_ERROR;
		}

		return ErrorCode;
	}



	// Starts a graceful WebSocket close handshake.
	//
	// close_status - close code sent to the peer.
	// reason       - optional NUL-terminated reason text (may be NULL), sent as raw bytes.
	// Returns NO_ERROR when the close was started, WEBSOCKET_ERROR_NOT_CONNECTED when
	// there is no WebSocket handle, or the WinHttpWebSocketClose error. The result is
	// also stored in ErrorCode.
	DWORD WebSocketClient::Close(WINHTTP_WEB_SOCKET_CLOSE_STATUS close_status, CHAR* reason)
	{
		if (hWebSocket == NULL)
		{
			// Nothing to close. Leave the state alone: a CLOSING state that no callback
			// will ever clear would make every later Connect() refuse to run.
			ErrorCode = WEBSOCKET_ERROR_NOT_CONNECTED;
			return ErrorCode;
		}

		// Enter CLOSING before the call: in async mode the close-complete notification
		// (which resets the state to CLOSED) may arrive before this call returns.
		const auto previous = status;
		status = ENUM_WEBSOCKET_STATE::CLOSING;

		const DWORD reasonLen = (reason == NULL) ? 0 : static_cast<DWORD>(strlen(reason));
		ErrorCode = WinHttpWebSocketClose(hWebSocket, static_cast<USHORT>(close_status), static_cast<PVOID>(reason), reasonLen);

		// The close never started, so no completion will reset the state; restore it.
		if (ErrorCode != NO_ERROR)
			status = previous;

		return ErrorCode;
	}

	DWORD WebSocketClient::Connect(const WCHAR* host, const INTERNET_PORT port, const DWORD secure)
	{

		if ((status != ENUM_WEBSOCKET_STATE::CLOSED))
		{
			ErrorCode = Close(WINHTTP_WEB_SOCKET_CLOSE_STATUS::WINHTTP_WEB_SOCKET_SUCCESS_CLOSE_STATUS,NULL);
			return WEBSOCKET_ERROR_CLOSING_ACTIVE_CONNECTION;
		}

		status = ENUM_WEBSOCKET_STATE::CONNECTING;


		// Return 0 for success
		if (hSession == NULL)
		{
			ErrorCode = Initialize();
			if (ErrorCode)
			{
				Reset(false);
				return ErrorCode;
			}
		}

		// Cracked URL variable pointers
		URL_COMPONENTS UrlComponents;

		// Create cracked URL buffer variables
		std::unique_ptr <WCHAR> scheme(new WCHAR[0x20]);
		std::unique_ptr <WCHAR> hostName(new WCHAR[0x100]);
		std::unique_ptr <WCHAR> urlPath(new WCHAR[0x1000]);


		DWORD dwFlags = 0;
		if (secure)
			dwFlags |= WINHTTP_FLAG_SECURE;

		if (scheme == NULL || hostName == NULL || urlPath == NULL) {
			ErrorCode = ERROR_NOT_ENOUGH_MEMORY;
			Reset();
			return ErrorCode;
		}

		// Clear error's
		ErrorCode = 0;

		// Setup UrlComponents structure
		memset(&UrlComponents, 0, sizeof(URL_COMPONENTS));
		UrlComponents.dwStructSize = sizeof(URL_COMPONENTS);
		UrlComponents.dwSchemeLength = -1;
		UrlComponents.dwHostNameLength = -1;
		UrlComponents.dwUserNameLength = -1;
		UrlComponents.dwPasswordLength = -1;
		UrlComponents.dwUrlPathLength = -1;
		UrlComponents.dwExtraInfoLength = -1;

		// Get the individual parts of the url
		if (!WinHttpCrackUrl(host, NULL, 0, &UrlComponents))
		{
			// Handle error
			ErrorCode = GetLastError();
			Reset();
			return ErrorCode;
		}

		// Copy cracked URL hostName & UrlPath to buffers so they are separated
		if (wcsncpy_s(scheme.get(), 0x20, UrlComponents.lpszScheme, UrlComponents.dwSchemeLength) != 0 ||
			wcsncpy_s(hostName.get(), 0x100, UrlComponents.lpszHostName, UrlComponents.dwHostNameLength) != 0 ||
			wcsncpy_s(urlPath.get(), 0x1000, UrlComponents.lpszUrlPath, UrlComponents.dwUrlPathLength) != 0)
		{
			ErrorCode = GetLastError();
			Reset(false);
			return ErrorCode;
		}

		if (port == 0) {
			if ((_wcsicmp(scheme.get(), L"wss") == 0) || (_wcsicmp(scheme.get(), L"https") == 0)) {
				UrlComponents.nPort = INTERNET_DEFAULT_HTTPS_PORT;
			}
			else if ((_wcsicmp(scheme.get(), L"ws") == 0) || (_wcsicmp(scheme.get(), L"http")) == 0) {
				UrlComponents.nPort = INTERNET_DEFAULT_HTTP_PORT;
			}
			else {
				ErrorCode = ERROR_INVALID_PARAMETER;
				Reset(false);
				return ErrorCode;
			}
		}
		else
			UrlComponents.nPort = port;


		// Call the WinHttp Connect method
		hConnect = WinHttpConnect(hSession, hostName.get(), UrlComponents.nPort, 0);
		if (!hConnect)
		{
			// Handle error
			ErrorCode = GetLastError();
			Reset(false);
			return ErrorCode;
		}



		// Create a HTTP request
		hRequest = WinHttpOpenRequest(hConnect, L"GET", urlPath.get(), NULL, WINHTTP_NO_REFERER, WINHTTP_DEFAULT_ACCEPT_TYPES, dwFlags);
		if (!hRequest)
		{
			// Handle error
			ErrorCode = GetLastError();
			Reset(false);
			return ErrorCode;
		}

		// Set option for client certificate

		if (!WinHttpSetOption(hRequest, WINHTTP_OPTION_CLIENT_CERT_CONTEXT, WINHTTP_NO_CLIENT_CERT_CONTEXT, 0))
		{
			// Handle error
			ErrorCode = GetLastError();
			Reset(false);
			return ErrorCode;
		}


		// Add WebSocket upgrade to our HTTP request
#pragma prefast(suppress:6387, "WINHTTP_OPTION_UPGRADE_TO_WEB_SOCKET does not take any arguments.")
		if (!WinHttpSetOption(hRequest, WINHTTP_OPTION_UPGRADE_TO_WEB_SOCKET, 0, 0))
		{
			// Handle error
			ErrorCode = GetLastError();
			Reset(false);
			return ErrorCode;
		}

		// Send the WebSocket upgrade request.
		if (!WinHttpSendRequest(hRequest, WINHTTP_NO_ADDITIONAL_HEADERS, 0, 0, 0, 0, 0))
		{
			// Handle error
			ErrorCode = GetLastError();
			Reset(false);
			return ErrorCode;
		}

		// Receive response from the server
		if (!WinHttpReceiveResponse(hRequest, 0))
		{
			// Handle error
			ErrorCode = GetLastError();
			Reset(false);
			return ErrorCode;
		}

		// Finally complete the upgrade
		hWebSocket = WinHttpWebSocketCompleteUpgrade(hRequest, NULL);
		if (hWebSocket == 0)
		{
			// Handle error
			ErrorCode = GetLastError();
			Reset(false);
			return ErrorCode;
		}

		status = ENUM_WEBSOCKET_STATE::CONNECTED;

		// Return should be zero
		return ErrorCode;
	}

	// Full teardown: per-connection handles via Reset(), then the session handle, then
	// the frame queue. Every pointer is nulled, so calling it again is harmless.
	VOID  WebSocketClient::Free(VOID)
	{
		Reset();

		if (hSession != NULL)
		{
			WinHttpCloseHandle(hSession);
		}
		hSession = NULL;

		auto* const queue = frames;
		frames = NULL;
		delete queue;
	}

	// Attaches this client to its WebSocket handle. It stores `this` as the handle's
	// context value and installs WebSocketCallback for all completion notifications.
	// Returns NO_ERROR on success, WEBSOCKET_ERROR_INVALID_HANDLE when there is no
	// WebSocket handle, or the WinHTTP error. The result is also stored in ErrorCode.
	DWORD WebSocketClient::EnableCallBack(VOID)
	{
		if (hWebSocket == NULL)
		{
			ErrorCode = WEBSOCKET_ERROR_INVALID_HANDLE;
			return ErrorCode;
		}

		// WINHTTP_OPTION_CONTEXT_VALUE reads a DWORD_PTR from the buffer, so pass the
		// address of a variable holding the pointer. Passing `this` itself would make
		// WinHTTP copy the first bytes of this object instead.
		DWORD_PTR context = reinterpret_cast<DWORD_PTR>(this);
		if (!WinHttpSetOption(hWebSocket, WINHTTP_OPTION_CONTEXT_VALUE, &context, sizeof(context)))
		{
			ErrorCode = GetLastError();
			return ErrorCode;
		}

		// WINHTTP_STATUS_CALLBACK uses the CALLBACK calling convention. Route through a
		// trampoline declared that way instead of casting WebSocketCallback, whose default
		// convention differs on 32-bit builds.
		struct CallbackTrampoline
		{
			static VOID CALLBACK Invoke(HINTERNET hInternet, DWORD_PTR dwContext, DWORD dwInternetStatus, LPVOID lpvStatusInformation, DWORD dwStatusInformationLength)
			{
				WebSocketCallback(hInternet, dwContext, dwInternetStatus, lpvStatusInformation, dwStatusInformationLength);
			}
		};

		if (WinHttpSetStatusCallback(hWebSocket, &CallbackTrampoline::Invoke, WINHTTP_CALLBACK_FLAG_ALL_COMPLETIONS, 0) == WINHTTP_INVALID_STATUS_CALLBACK)
		{
			ErrorCode = GetLastError();
			return ErrorCode;
		}

		return ErrorCode;
	}

	// Posts a receive into pBuffer (up to pLength bytes). OnReadComplete returns the
	// state to CONNECTED when the read-complete notification arrives.
	// Returns NO_ERROR when the receive was started, WEBSOCKET_ERROR_NOT_CONNECTED when
	// there is no WebSocket handle, or the WinHttpWebSocketReceive error. The result is
	// also stored in ErrorCode.
	DWORD WebSocketClient::Receive(PVOID pBuffer, DWORD pLength, DWORD* bytesRead, WINHTTP_WEB_SOCKET_BUFFER_TYPE* pBufferType)
	{
		if (hWebSocket == NULL)
		{
			ErrorCode = WEBSOCKET_ERROR_NOT_CONNECTED;
			return ErrorCode;
		}

		// Switch to POLLING before the call: the completion may be delivered (setting
		// CONNECTED) before WinHttpWebSocketReceive returns.
		const auto previous = status;
		status = ENUM_WEBSOCKET_STATE::POLLING;

		ErrorCode = WinHttpWebSocketReceive(hWebSocket, pBuffer, pLength, bytesRead, pBufferType);

		// A receive that failed to start produces no completion; undo the transition.
		if (ErrorCode != NO_ERROR)
			status = previous;

		return ErrorCode;
	}

	// Write-complete handler: records the number of bytes sent and marks the
	// connection as CONNECTED again.
	VOID WebSocketClient::OnSendComplete(const DWORD sent)
	{
		status = ENUM_WEBSOCKET_STATE::CONNECTED;
		bytesTX = sent;
	}

	// Returns the connection state as last recorded by this client.
	ENUM_WEBSOCKET_STATE WebSocketClient::Status(VOID)
	{
		const auto current = this->status;
		return current;
	}

	// Posts a send of dwLength bytes from pBuffer as a frame of type bufferType.
	// OnSendComplete returns the state to CONNECTED when the write-complete
	// notification arrives.
	// Returns NO_ERROR when the send was started, WEBSOCKET_ERROR_NOT_CONNECTED when
	// there is no WebSocket handle, or the WinHttpWebSocketSend error. The result is
	// also stored in ErrorCode.
	DWORD WebSocketClient::Send(WINHTTP_WEB_SOCKET_BUFFER_TYPE bufferType, void* pBuffer, DWORD dwLength)
	{
		if (hWebSocket == NULL)
		{
			ErrorCode = WEBSOCKET_ERROR_NOT_CONNECTED;
			return ErrorCode;
		}

		// Switch to SENDING before the call: the completion may be delivered (setting
		// CONNECTED) before WinHttpWebSocketSend returns.
		const auto previous = status;
		status = ENUM_WEBSOCKET_STATE::SENDING;

		ErrorCode = WinHttpWebSocketSend(hWebSocket, bufferType, pBuffer, dwLength);

		// A send that failed to start produces no completion; undo the transition.
		if (ErrorCode != NO_ERROR)
			status = previous;

		return ErrorCode;
	}

	// Thin wrapper over WinHttpWebSocketQueryCloseStatus for this client's WebSocket
	// handle. Its result is stored in ErrorCode and returned.
	DWORD WebSocketClient::QueryCloseStatus(USHORT* pusStatus, PVOID pvReason, DWORD dwReasonLength, DWORD* pdwReasonLengthConsumed)
	{
		const DWORD result = WinHttpWebSocketQueryCloseStatus(hWebSocket, pusStatus, pvReason, dwReasonLength, pdwReasonLengthConsumed);
		ErrorCode = result;
		return result;
	}

	// The WinHTTP WebSocket handle, or NULL when not connected.
	HINTERNET WebSocketClient::WebSocketHandle(VOID)
	{
		return this->hWebSocket;
	}

	// Copies the oldest queued frame into pBuffer and reports its buffer type.
	//
	// pBuffer     - destination buffer; never written past pLength bytes.
	// pLength     - capacity of pBuffer in bytes.
	// pBufferType - receives the type of the data copied (may be NULL).
	// When the frame does not fit, only the first pLength bytes are copied. The unread
	// tail stays queued for the next Read(), and the copied part is reported as the
	// matching *_FRAGMENT_BUFFER_TYPE. Does nothing when the queue is empty.
	VOID WebSocketClient::Read(BYTE* pBuffer, DWORD pLength, WINHTTP_WEB_SOCKET_BUFFER_TYPE* pBufferType)
	{
		if (frames == NULL || frames->empty())
			return;

		Frame& frame = frames->front();
		const size_t available = frame.frame_buffer.size();
		const size_t copied = (pBuffer == NULL) ? 0 : (available < pLength ? available : static_cast<size_t>(pLength));

		if (copied > 0)
			std::copy_n(frame.frame_buffer.begin(), copied, pBuffer);

		if (copied == available)
		{
			// Whole frame delivered.
			if (pBufferType != NULL)
				*pBufferType = frame.frame_type;
			frames->pop();
			return;
		}

		// Partial delivery: the remainder keeps the original type and completes the message.
		if (pBufferType != NULL)
		{
			switch (frame.frame_type)
			{
			case WINHTTP_WEB_SOCKET_UTF8_MESSAGE_BUFFER_TYPE:
				*pBufferType = WINHTTP_WEB_SOCKET_UTF8_FRAGMENT_BUFFER_TYPE;
				break;
			case WINHTTP_WEB_SOCKET_BINARY_MESSAGE_BUFFER_TYPE:
				*pBufferType = WINHTTP_WEB_SOCKET_BINARY_FRAGMENT_BUFFER_TYPE;
				break;
			default:
				*pBufferType = frame.frame_type;
				break;
			}
		}
		frame.frame_buffer.erase(frame.frame_buffer.begin(), frame.frame_buffer.begin() + copied);
		frame.frame_size -= static_cast<DWORD>(copied);
	}

	// Read-complete handler: records the byte count and buffer type, marks the
	// connection CONNECTED, and queues the received bytes (from rxBuffer) as a frame.
	//
	// read       - number of bytes WinHTTP placed in rxBuffer.
	// buffertype - WinHTTP buffer type of those bytes.
	VOID WebSocketClient::OnReadComplete(const DWORD read, const WINHTTP_WEB_SOCKET_BUFFER_TYPE buffertype)
	{
		bytesRX = read;
		rxBufferType = buffertype;
		status = ENUM_WEBSOCKET_STATE::CONNECTED;

		// The queue is gone after Free(); nothing to store into.
		if (frames == NULL)
			return;

		// Never read past the staging buffer, whatever count was reported.
		const size_t count = (read < rxBuffer.size()) ? static_cast<size_t>(read) : rxBuffer.size();

		Frame frame;
		frame.frame_buffer.insert(frame.frame_buffer.end(), rxBuffer.data(), rxBuffer.data() + count);
		frame.frame_type = buffertype;
		frame.frame_size = static_cast<DWORD>(count);
		frames->push(std::move(frame));   // move: avoid copying the payload into the queue
	}

	// Size in bytes of the oldest queued frame, or 0 when nothing is queued
	// (including after Free() has released the queue).
	DWORD WebSocketClient::ReadAvailable(VOID)
	{
		if (frames == NULL || frames->empty())
			return 0;

		return frames->front().frame_size;
	}

	// Close-complete handler: tears the connection down via Reset() with its
	// default argument.
	VOID WebSocketClient::OnClose(VOID)
	{
		this->Reset();
	}

	// Most recent error code recorded by this client (0 when none).
	DWORD WebSocketClient::LastError(VOID)
	{
		return this->ErrorCode;
	}

	// WINHTTP_CALLBACK_STATUS_* code of the most recent notification seen (see OnCallBack).
	DWORD WebSocketClient::LastOperation(VOID)
	{
		return this->completed_websocket_operation;
	}

	// Overwrites the recorded error code.
	VOID WebSocketClient::SetError(const DWORD errorcode)
	{
		this->ErrorCode = errorcode;
	}

	// Request-error handler: keeps the failing operation's error code, then tears the
	// connection down without clearing that code.
	VOID WebSocketClient::OnError(const WINHTTP_ASYNC_RESULT* result)
	{
		const DWORD failure = result->dwError;
		SetError(failure);
		Reset(false);
	}

	// Remembers the notification code of the latest status callback.
	VOID WebSocketClient::OnCallBack(const DWORD operation)
	{
		this->completed_websocket_operation = operation;
	}

	// Exported: frees the client registered under websocket_handle and removes it from
	// the registry. Unknown handles are ignored.
	VOID WEBSOCK_API client_reset(HINTERNET websocket_handle)
	{
		const auto entry = clients.find(websocket_handle);
		if (entry == clients.end())
			return;

		entry->second->Free();
		clients.erase(entry);
	}

	// Exported: creates a client, connects it to `url`, installs its status callback
	// and registers it under its WebSocket handle.
	//
	// url               - WebSocket URL to connect to.
	// port              - port to use, or 0 to let WebSocketClient::Connect choose it.
	// secure            - non-zero to request WINHTTP_FLAG_SECURE.
	// websocketp_handle - receives the WebSocket handle on success, NULL on failure.
	// Returns 0 on success. Otherwise returns WEBSOCKET_ERROR_INVALID_PARAMETER for NULL
	// arguments, or the error from connecting or enabling callbacks.
	DWORD  WEBSOCK_API client_connect( const WCHAR* url, INTERNET_PORT port, DWORD secure, HINTERNET* websocketp_handle)
	{
		if (websocketp_handle == NULL)
			return WEBSOCKET_ERROR_INVALID_PARAMETER;

		// Leave the out-parameter defined on every failure path.
		*websocketp_handle = NULL;

		if (url == NULL)
			return WEBSOCKET_ERROR_INVALID_PARAMETER;

		auto client = std::make_shared<WebSocketClient>();

		const DWORD connectError = client->Connect(url, port, secure);
		if (connectError != NO_ERROR)
			return connectError;   // client is released when it goes out of scope

		if (client->EnableCallBack() != NO_ERROR)
		{
			// Capture the error before Close() overwrites it.
			const DWORD callbackError = client->LastError();
			client->Close(WINHTTP_WEB_SOCKET_CLOSE_STATUS::WINHTTP_WEB_SOCKET_SUCCESS_CLOSE_STATUS);
			client->Free();
			return callbackError;
		}

		const HINTERNET handle = client->WebSocketHandle();
		clients[handle] = std::move(client);
		*websocketp_handle = handle;
		return NO_ERROR;
	}

	// Exported: starts a graceful close for the client registered under websocket_handle,
	// provided it still holds a WebSocket handle. Unknown handles are ignored.
	void WEBSOCK_API client_disconnect(HINTERNET websocket_handle)
	{
		const auto entry = clients.find(websocket_handle);
		if (entry == clients.end())
			return;

		WebSocketClient& client = *entry->second;
		if (client.WebSocketHandle() == NULL)
			return;

		client.Close(WINHTTP_WEB_SOCKET_CLOSE_STATUS::WINHTTP_WEB_SOCKET_SUCCESS_CLOSE_STATUS);
	}

	// Exported: sends `length` bytes of `message` as a `buffertype` frame on the given
	// client. Returns the Send() result, or WEBSOCKET_ERROR_INVALID_HANDLE for a
	// NULL/unknown handle.
	DWORD WEBSOCK_API client_send(HINTERNET websocket_handle, WINHTTP_WEB_SOCKET_BUFFER_TYPE buffertype, BYTE* message, DWORD length)
	{
		if (websocket_handle != NULL)
		{
			const auto entry = clients.find(websocket_handle);
			if (entry != clients.end())
				return entry->second->Send(buffertype, message, length);
		}
		return WEBSOCKET_ERROR_INVALID_HANDLE;
	}


	// Exported: posts a receive into the client's own staging buffer (rxBuffer).
	// Returns the Receive() result, or WEBSOCKET_ERROR_INVALID_HANDLE for a NULL/unknown handle.
	DWORD WEBSOCK_API client_poll(HINTERNET websocket_handle)
	{
		if (websocket_handle == NULL)
			return WEBSOCKET_ERROR_INVALID_HANDLE;

		const auto entry = clients.find(websocket_handle);
		if (entry == clients.end())
			return WEBSOCKET_ERROR_INVALID_HANDLE;

		WebSocketClient& client = *entry->second;
		return client.Receive(client.rxBuffer.data(), static_cast<DWORD>(client.rxBuffer.size()), &client.bytesRX, &client.rxBufferType);
	}


	// Exported: pops the oldest received frame into `out` via WebSocketClient::Read.
	// Returns 0, or WEBSOCKET_ERROR_INVALID_HANDLE for a NULL/unknown handle.
	DWORD WEBSOCK_API client_read(HINTERNET websocket_handle, BYTE* out, DWORD out_size, WINHTTP_WEB_SOCKET_BUFFER_TYPE* buffertype)
	{
		if (websocket_handle == NULL)
			return WEBSOCKET_ERROR_INVALID_HANDLE;

		const auto entry = clients.find(websocket_handle);
		if (entry == clients.end())
			return WEBSOCKET_ERROR_INVALID_HANDLE;

		entry->second->Read(out, out_size, buffertype);
		return 0;
	}

	// Exported: last error recorded by the given client, or WEBSOCKET_ERROR_INVALID_HANDLE
	// for a NULL/unknown handle.
	DWORD WEBSOCK_API  client_lasterror(HINTERNET websocket_handle)
	{
		if (websocket_handle != NULL)
		{
			const auto entry = clients.find(websocket_handle);
			if (entry != clients.end())
				return entry->second->LastError();
		}
		return WEBSOCKET_ERROR_INVALID_HANDLE;
	}

	// Exported: size of the given client's oldest queued frame; 0 when nothing is
	// queued or the handle is NULL/unknown.
	DWORD WEBSOCK_API client_readable(HINTERNET websocket_handle)
	{
		if (websocket_handle == NULL)
			return 0;

		const auto entry = clients.find(websocket_handle);
		return (entry == clients.end()) ? 0 : entry->second->ReadAvailable();
	}

	// Exported: connection state of the given client; a value-initialized
	// ENUM_WEBSOCKET_STATE for a NULL/unknown handle.
	ENUM_WEBSOCKET_STATE WEBSOCK_API client_status(HINTERNET websocket_handle)
	{
		if (websocket_handle != NULL)
		{
			const auto entry = clients.find(websocket_handle);
			if (entry != clients.end())
				return entry->second->Status();
		}
		return ENUM_WEBSOCKET_STATE{};
	}

	// Exported: most recent WINHTTP_CALLBACK_STATUS_* code seen by the given client, or
	// WEBSOCKET_ERROR_INVALID_HANDLE for a NULL/unknown handle.
	DWORD WEBSOCK_API client_lastcallback_notification(HINTERNET websocket_handle)
	{
		if (websocket_handle == NULL)
			return WEBSOCKET_ERROR_INVALID_HANDLE;

		const auto entry = clients.find(websocket_handle);
		return (entry != clients.end()) ? entry->second->LastOperation() : WEBSOCKET_ERROR_INVALID_HANDLE;
	}

	// Exported: the WebSocket handle currently held by the given client, or NULL for a
	// NULL/unknown handle.
	HINTERNET WEBSOCK_API client_websocket_handle(HINTERNET websocket_handle)
	{
		if (websocket_handle == NULL)
			return NULL;

		const auto entry = clients.find(websocket_handle);
		if (entry == clients.end())
			return NULL;

		return entry->second->WebSocketHandle();
	}

}