#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstdio>
#include <numeric>
#include <thread>
#include <vector>
#include <Windows.h>

// Benchmark configuration: worker thread count and lock acquisitions per worker.
static constexpr size_t NUM_THREADS = 2;
static constexpr size_t NUM_LOOPS = 10000;

// State shared by all worker threads.
CRITICAL_SECTION cs;                // the lock under test
long sharedCounter = 0;             // bumped once per acquisition, while holding cs
std::vector<size_t> lockHistory;    // tid of each acquirer in acquisition order (appended while holding cs)
HANDLE readyEvents[NUM_THREADS];    // one start-rendezvous event per worker
float sinTab[NUM_LOOPS];            // scratch output for the busy work done outside the lock


void Run(size_t tid)
{
	// Signal our event, wait for others
	SetEvent(readyEvents[tid]);
	WaitForMultipleObjects(NUM_THREADS, readyEvents, TRUE, INFINITE);
	for (size_t i = 0; i != NUM_LOOPS; ++i)
	{
		EnterCriticalSection(&cs);
	   ++sharedCounter;
		lockHistory.push_back(tid);
		LeaveCriticalSection(&cs);
		for (size_t j = 0; j != 10; ++j) // 300
		{
			sinTab[j] = sinf(j) * cosf(j);
		}
	}
	printf("%zd done\n", tid);
}

int main()
{
#if 0
	InitializeCriticalSection(&cs);
#else
	// 2000 = default (ie. InitializeCriticalSection)
	// 33556432
	InitializeCriticalSectionAndSpinCount(&cs, 1);
#endif
	for (size_t i = 0; i != NUM_THREADS; ++i)
	{
		readyEvents[i] = CreateEvent(NULL, TRUE, FALSE, nullptr);
	}

	std::vector<std::thread> threads;
	lockHistory.reserve(NUM_THREADS * NUM_LOOPS);
	for (size_t i = 0; i != NUM_THREADS; ++i)
	{
		threads.push_back(std::thread(Run, i));
		DWORD_PTR dw = SetThreadAffinityMask(threads.back().native_handle(), (1ull << i));
	}

	for (auto&& th : threads)
	{
		th.join();
	}

	assert(lockHistory.size() == NUM_THREADS * NUM_LOOPS);
	assert(sharedCounter == NUM_THREADS * NUM_LOOPS);

	size_t prevOwner = lockHistory.front();
	std::vector<size_t> streakLengths;
	size_t currStreakLen = 1;
	for (size_t i = 1; i != lockHistory.size(); ++i)
	{
		if (prevOwner != lockHistory[i])
		{
			streakLengths.push_back(currStreakLen);
			currStreakLen = 1;
		}
		else
		{
			++currStreakLen;
		}
		prevOwner = lockHistory[i];
	}
	streakLengths.push_back(currStreakLen);
	const size_t avgStreakLen = std::accumulate(streakLengths.begin(), streakLengths.end(), 0ull) / streakLengths.size();
	const size_t numOwnershipChanges = streakLengths.size() - 1;
	std::sort(streakLengths.begin(), streakLengths.end());
	printf("Num ownership changes: %zd, avg streak: %zd\n", numOwnershipChanges, avgStreakLen);
	printf("Min streak: %zd, max streak: %zd\n", streakLengths.front(), streakLengths.back());

	for (auto&& e : readyEvents)
	{
		CloseHandle(e);
	}
	DeleteCriticalSection(&cs);

	return 0;
}
