#include "example.h"

#ifdef USE_MINGW_STD_THREAD
	#include <mingw-std-threads/mingw.thread.h>
#else
	#include <thread>
#endif

#include <atomic>
#include <chrono>
#include <iostream>
#include <numeric>
#include <vector>

#ifdef ENABLE_DEBUG
#include <cstdarg>
#include <cstdio>
#endif

#ifdef ENABLE_DEBUG
const int LOG_BUFFER_SIZE = 2048;

/// Debug-only printf-style logger.
///
/// Formats `format` with the trailing arguments and writes
/// "[DEBUG] [<ctx>] <message>" followed by std::endl to std::cout.
/// Messages longer than LOG_BUFFER_SIZE - 1 characters are truncated.
///
/// @param ctx     label printed in the second bracket (e.g. the caller name).
/// @param format  printf-style format string; the variadic arguments must
///                match its conversion specifiers.
///
/// The format buffer is local to each call, so concurrent callers cannot
/// overwrite each other's message text (whole lines may still interleave
/// on std::cout).
inline void DLOG(const char* ctx, const char* format, ...) {
	char log_buffer[LOG_BUFFER_SIZE];

	va_list args;
	va_start(args, format);
	// vsnprintf already reserves room for the terminating NUL, so pass the full size.
	const int written = std::vsnprintf(log_buffer, sizeof(log_buffer), format, args);
	va_end(args);

	if (written < 0) {
		// Encoding error: buffer contents are unspecified, so print an empty message instead.
		log_buffer[0] = '\0';
	}

	std::cout << "[DEBUG] [" << ctx << "] " << log_buffer << std::endl;
}
#else
	// In non-debug builds DLOG(...) expands to nothing; its arguments are not evaluated.
	#define DLOG(...)
#endif

/// Returns a + b.
/// NOTE(review): signed overflow is undefined behaviour in C++; callers must
/// keep the result within the range of int.
EXAMPLE_API int add(int a, int b) noexcept {
	const int total = a + b;
	return total;
}

/// Returns a - b.
/// NOTE(review): signed overflow is undefined behaviour in C++; callers must
/// keep the result within the range of int.
EXAMPLE_API int sub(int a, int b) noexcept {
	const int difference = a - b;
	return difference;
}

/// Returns std::thread::hardware_concurrency() converted to int.
/// The standard allows that call to return 0 when the value is not
/// computable, so callers should treat 0 as "unknown".
EXAMPLE_API int num_hardware_concurrency() noexcept {
	const unsigned int reported_threads = std::thread::hardware_concurrency();
	return static_cast<int>(reported_threads);
}

/// Sums arr[0..num_elem) on the calling thread and prints the elapsed
/// wall-clock time (milliseconds) to std::cout.
///
/// @param arr       pointer to at least `num_elem` ints; not modified.
/// @param num_elem  number of elements; values <= 0 yield 0.
/// @return          the int sum (overflow is not detected).
EXAMPLE_API int single_threaded_sum(const int arr[], int num_elem) {
	const auto begin_time = std::chrono::steady_clock::now();

	// A non-positive count means there is nothing to add (and arr + num_elem would be invalid).
	const int total = (num_elem > 0) ? std::accumulate(arr, arr + num_elem, 0) : 0;

	const std::chrono::duration<double, std::milli> exec_time = std::chrono::steady_clock::now() - begin_time;
	std::cout << "elapsed time: " << exec_time.count() << "ms" << std::endl;
	return total;
}

//EXAMPLE_API int multi_threaded_sum_v1(const int* const& arr, int num_elem) {
//	auto start = std::chrono::steady_clock::now();
//
//	std::vector<std::pair<int, int>> arr_indexes;
//
//	const int num_max_threads = std::thread::hardware_concurrency() == 0 ? 2 : std::thread::hardware_concurrency();
//	const int chunk_work_size = num_elem / num_max_threads;
//
//	std::atomic<int> shared_total_sum(0);
//
//	// a lambda function that accepts input vector by copy
//	auto worker_func = [&shared_total_sum](const std::vector<int> arr) {
//		int local_sum = 0;
//		for (auto elem_val : arr) {
//			local_sum += elem_val;
//		}
//		shared_total_sum += local_sum;
//	};
//
//	DLOG("multi_threaded_sum_v1", "chunk_work_size=%d", chunk_work_size);
//	DLOG("multi_threaded_sum_v1", "num_max_threads=%d", num_max_threads);
//
//	std::vector<std::thread> threads;
//	threads.reserve(num_max_threads);
//
//	for (int i=0; i<num_max_threads; ++i) {
//		int start = i * chunk_work_size;
//		// also check if there's remaining to piggyback works into the last chunk
//		int end = (i == num_max_threads-1) && (start + chunk_work_size < num_elem-1) ? num_elem : start+chunk_work_size;
//
//		std::vector<int> copy_data;
//		copy_data.reserve(end-start+1);
//		for (std::size_t i=start; i<end; ++i) {
//			copy_data.push_back(arr[i]);
//		}
//
//		threads.emplace_back(worker_func, copy_data);	// note that we pass vector by copy, no need std::ref()
//	}
//
//	DLOG("multi_threaded_sum_v1", "thread_size=%d", threads.size());
//
//	for (auto& th : threads) {
//		th.join();
//	}
//
//	std::chrono::duration<double, std::milli> exec_time = std::chrono::steady_clock::now() - start;
//	std::cout << "elapsed time: " << exec_time.count() << "ms" << std::endl;
//
//	return shared_total_sum;
//}

/// Sums arr[0..num_elem) by splitting the range into contiguous chunks and
/// summing each chunk on its own std::thread.
///
/// Thread count is std::thread::hardware_concurrency() (2 when that reports
/// 0), capped at `num_elem` so no thread is started without work. The last
/// chunk always extends to `num_elem`, absorbing the division remainder.
/// The elapsed wall-clock time (milliseconds) is printed to std::cout.
///
/// @param arr       pointer to at least `num_elem` ints; read concurrently, not modified.
/// @param num_elem  number of elements; values <= 0 yield 0.
/// @return          the int sum (overflow is not detected, as in single_threaded_sum).
/// @throws std::system_error if a worker thread cannot be started; threads
///         already started are joined before the exception propagates.
EXAMPLE_API int multi_threaded_sum_v2(const int arr[], int num_elem) {
	const auto start_time = std::chrono::steady_clock::now();

	// hardware_concurrency() may return 0 when the value is not computable.
	const unsigned int hw_threads = std::thread::hardware_concurrency();
	int num_max_threads = (hw_threads == 0) ? 2 : static_cast<int>(hw_threads);
	// Never start more threads than elements, and none at all for an empty range.
	if (num_max_threads > num_elem) {
		num_max_threads = (num_elem > 0) ? num_elem : 0;
	}
	const int chunk_work_size = (num_max_threads > 0) ? num_elem / num_max_threads : 0;

	std::atomic<int> shared_total_sum(0);

	// Sums data[first, last) into a local and publishes it with a single
	// atomic add, so threads don't contend on the shared counter per element.
	auto worker_func = [&shared_total_sum](const int* data, int first, int last) {
		int local_sum = 0;
		for (int i = first; i < last; ++i) {
			local_sum += data[i];
		}
		shared_total_sum += local_sum;
	};

	DLOG("multi_threaded_sum_v2", "chunk_work_size=%d", chunk_work_size);
	DLOG("multi_threaded_sum_v2", "num_max_threads=%d", num_max_threads);

	std::vector<std::thread> threads;
	threads.reserve(static_cast<std::size_t>(num_max_threads));

	try {
		for (int i = 0; i < num_max_threads; ++i) {
			const int chunk_begin = i * chunk_work_size;
			// The last chunk always runs to num_elem so the remainder is never dropped.
			const int chunk_end = (i == num_max_threads - 1) ? num_elem : chunk_begin + chunk_work_size;
			threads.emplace_back(worker_func, arr, chunk_begin, chunk_end);
		}
	} catch (...) {
		// Destroying a joinable std::thread calls std::terminate, so join the
		// threads that did start before letting the exception escape.
		for (auto& th : threads) {
			th.join();
		}
		throw;
	}

	// %d expects an int; size() is size_t, so convert explicitly.
	DLOG("multi_threaded_sum_v2", "thread_size=%d", static_cast<int>(threads.size()));

	for (auto& th : threads) {
		th.join();
	}

	std::chrono::duration<double, std::milli> exec_time = std::chrono::steady_clock::now() - start_time;
	std::cout << "elapsed time: " << exec_time.count() << "ms" << std::endl;

	return shared_total_sum;
}
