#include <cstddef>
#include <cstdio>
#include <cstdlib>
#include <cassert>
#include <algorithm>
#include <time.h>

/**
 * Square matrix of T, stored in one heap buffer in row-major order.
 *
 * The object owns its buffer: copies are deep, assignment uses copy-and-swap
 * and the destructor releases the storage. Besides element-wise addition and
 * subtraction the class offers two multiplications, a plain triple loop
 * (mul_ref_dot) and one level of Strassen's scheme (mul_Strassen_0);
 * operator* chooses between them.
 *
 * Size preconditions are checked with assert() only, so they are not
 * enforced when NDEBUG is defined.
 */
template <typename T>
class Matrix
{
private:
	size_t size; // number of rows, equal to the number of columns
	T *data; // [size * size], row-major order, owned by this object
public:
	// Empty 0x0 matrix without any storage.
	Matrix()
		: size(0), data(0)
	{}

	// size x size matrix whose elements are value-initialized (zero for
	// arithmetic T). Call randomize() to fill it with random values.
	Matrix(size_t size)
		: size(size), data(new T[size * size]())
	{}

	// Overwrites every element with rand() / RAND_MAX - 0.5, visiting the
	// elements in row-major order. Seeding (srand) is left to the caller.
	Matrix &randomize()
	{
		const size_t count = size * size;

		for (size_t i = 0; i < count; ++i) {
			data[i] = rand() / T(RAND_MAX) - T(.5);
		}

		return *this;
	}

	// Copies a size x size window out of a foreign buffer: element (y, x) is
	// read from r_data[y * stride_y + x * stride_x]. Used by getSubMatrix()
	// to cut a quadrant out of a larger row-major matrix.
	Matrix(size_t size, size_t stride_x, size_t stride_y, const T *r_data)
		: size(size), data(new T[size * size])
	{
		for (size_t y = 0; y < size; ++y) {
			const T *src_row = r_data + y * stride_y;
			T *dst_row = data + y * size;

			for (size_t x = 0; x < size; ++x) {
				dst_row[x] = src_row[x * stride_x];
			}
		}
	}

	// Deep copy.
	Matrix(const Matrix &matrix)
		: size(matrix.size), data(new T[matrix.size * matrix.size])
	{
		std::copy(matrix.data, matrix.data + size * size, data);
	}

	~Matrix()
	{
		delete[] data; // deleting a null pointer is a no-op
	}

	// Number of rows (equal to the number of columns).
	size_t dimension() const
	{
		return size;
	}

	// Element access by (row, column); the indices are only checked by assert().
	T &operator ()(size_t row, size_t col)
	{
		assert(row < size && col < size);

		return data[row * size + col];
	}

	const T &operator ()(size_t row, size_t col) const
	{
		assert(row < size && col < size);

		return data[row * size + col];
	}

	// Prints the matrix to stdout between "[" and "]", one row per line,
	// rows separated by ";". Elements are converted to double for "%+f".
	void print() const
	{
		printf("[\n");
		for (size_t y = 0; y < size; ++y) {
			for (size_t x = 0; x < size; ++x) {
				// the cast keeps the varargs call well-defined for any arithmetic T
				printf("%+f ", static_cast<double>(data[y * size + x]));
			}
			if (y + 1 < size)
				printf(";\n");
		}
		printf("]\n");
	}

	// Exchanges size and buffer pointer of the two matrices.
	friend void swap(Matrix& first, Matrix& second) // nothrow
	{
		using std::swap;

		swap(first.size, second.size);
		swap(first.data, second.data);
	}

	// Copy-and-swap assignment: r is already a private copy of the right-hand
	// side, so taking over its buffer is all that is left to do.
	Matrix& operator =(Matrix r)
	{
		swap(*this, r);

		return *this;
	}

	// Element-wise addition; both matrices must have the same size.
	Matrix& operator +=(const Matrix &r)
	{
		assert(this->size == r.size);

		const size_t count = size * size;

		for (size_t i = 0; i < count; ++i) {
			data[i] += r.data[i];
		}

		return *this;
	}

	Matrix operator +(const Matrix &r) const
	{
		Matrix sum(*this);

		sum += r;

		return sum;
	}

	// Element-wise subtraction; both matrices must have the same size.
	Matrix& operator -=(const Matrix &r)
	{
		assert(size == r.size);

		const size_t count = size * size;

		for (size_t i = 0; i < count; ++i) {
			data[i] -= r.data[i];
		}

		return *this;
	}

	Matrix operator -(const Matrix &r) const
	{
		Matrix difference(*this);

		difference -= r;

		return difference;
	}

	// Unary plus: returns a copy.
	Matrix operator +() const
	{
		return *this;
	}

	// Copies the (size/2) x (size/2) row-major block r_data into this
	// matrix's buffer, starting at index data_offset and using `size` (the
	// parameter, not the member) as the destination row length. Nothing is
	// bounds-checked: the caller must make sure the buffer is large enough.
	Matrix &merge1(size_t size, const T *r_data, size_t data_offset)
	{
		const size_t size_2 = size / 2;

		for (size_t y = 0; y < size_2; ++y) {
			for (size_t x = 0; x < size_2; ++x) {
				data[y * size + x + data_offset] = r_data[y * size_2 + x];
			}
		}

		return *this;
	}

	// Replaces this matrix by the block matrix [s1 s2; s3 s4]. All four
	// quadrants must have the same size; the result is twice as large.
	Matrix& merge4(const Matrix &s1, const Matrix &s2, const Matrix &s3, const Matrix &s4)
	{
		const size_t half = s1.size;

		assert(s2.size == half);
		assert(s3.size == half);
		assert(s4.size == half);

		// Assemble into a separate matrix and swap it in at the end. The old
		// buffer stays untouched while the quadrants are read, so a quadrant
		// may alias *this and a failed allocation leaves *this unchanged.
		Matrix merged(half * 2);
		const size_t full = merged.size;

		merged.merge1(full, s1.data, 0);
		merged.merge1(full, s2.data, half);
		merged.merge1(full, s3.data, half * full);
		merged.merge1(full, s4.data, half * full + half);

		swap(*this, merged);

		return *this;
	}

	// Quadrant selector for getSubMatrix(); S1..S4 and S11..S22 are two
	// spellings of the same four values.
	enum submatrix {
		S1 = 0,
		S2 = 1,
		S3 = 2,
		S4 = 3,
		S11 = 0,
		S12 = 1,
		S21 = 2,
		S22 = 3
	};

	// Returns a copy of one (size/2) x (size/2) quadrant: S1 top-left,
	// S2 top-right, S3 bottom-left, S4 bottom-right. For an odd size the
	// last row and column are not part of any quadrant. Aborts on a value
	// outside the enum.
	Matrix getSubMatrix(enum submatrix s) const
	{
		const size_t half = size / 2;

		switch (s) {
			case S1:
				return Matrix(half, 1, size, data);
			case S2:
				return Matrix(half, 1, size, data + half);
			case S3:
				return Matrix(half, 1, size, data + size * half);
			case S4:
				return Matrix(half, 1, size, data + size * half + half);
		}

		abort();
	}

	// Matrix product, *this = A, r = B. Both operands must have the same size.
	Matrix operator *(const Matrix &r) const
	{
		assert(size == r.size);

		// Small matrices are cheaper with the plain loop. Odd sizes cannot be
		// split into four equal quadrants, so Strassen would lose the last
		// row and column.
		if (size <= 32 || size % 2 != 0) {
			return mul_ref_dot(r);
		}

		return mul_Strassen_0(r);
	}

	// Iterative algorithm: every result element is the dot product of a row
	// of *this and a column of r.
	Matrix mul_ref_dot(const Matrix &r) const
	{
		assert(size == r.size);

		Matrix c(size);

		for (size_t i = 0; i < size; ++i) {
			for (size_t j = 0; j < size; ++j) {
				T sum = T(0);
				for (size_t k = 0; k < size; ++k) {
					sum += data[i * size + k] * r.data[k * size + j];
				}
				c.data[i * size + j] = sum;
			}
		}

		return c;
	}

	// Strassen algorithm: one splitting step with seven half-size products
	// instead of eight. The half-size products go through operator*, which
	// decides for each of them whether to split again. Odd sizes fall back
	// to mul_ref_dot() because they cannot be split evenly.
	Matrix mul_Strassen_0(const Matrix &r) const
	{
		assert(size == r.size);

		if (size % 2 != 0) {
			return mul_ref_dot(r);
		}

		Matrix A11 = getSubMatrix(S11);
		Matrix A12 = getSubMatrix(S12);
		Matrix A21 = getSubMatrix(S21);
		Matrix A22 = getSubMatrix(S22);

		Matrix B11 = r.getSubMatrix(S11);
		Matrix B12 = r.getSubMatrix(S12);
		Matrix B21 = r.getSubMatrix(S21);
		Matrix B22 = r.getSubMatrix(S22);

		Matrix M0 = (A11 + A22) * (B11 + B22);
		Matrix M1 = (A21 + A22) * B11;
		Matrix M2 = A11 * (B12 - B22);
		Matrix M3 = A22 * (B21 - B11);
		Matrix M4 = (A11 + A12) * B22;
		Matrix M5 = (A21 - A11) * (B11 + B12);
		Matrix M6 = (A12 - A22) * (B21 + B22);

		Matrix C11 = M0 + M3 - M4 + M6;
		Matrix C12 = M2 + M4;
		Matrix C21 = M1 + M3;
		Matrix C22 = M0 - M1 + M2 + M5;

		// merge4() allocates the result storage itself
		Matrix C;
		C.merge4(C11, C12, C21, C22);

		return C;
	}
};

// Element type of the matrices instantiated in main().
#define TYPE double

/*
 * Benchmark driver: multiplies two random SIZE x SIZE matrices and prints the
 * time the multiplication took according to CLOCK_MONOTONIC.
 */
int main()
{
	/*
	 * size of matrices is 4096x4096
	 */
	const int SIZE = 4096;

	srand(42);

	/*
	 * fill matrices A and B with random numbers
	 */
	Matrix<TYPE> A(SIZE);
	A.randomize();

	Matrix<TYPE> B(SIZE);
	B.randomize();

	struct timespec start_ts;
	if (clock_gettime(CLOCK_MONOTONIC, &start_ts) != 0) {
		printf("[ERROR] clock_gettime\n");
		abort();
	}

	/*
	 * multiply A * B
	 */
	Matrix<TYPE> C = A * B;

	struct timespec stop_ts;
	if (clock_gettime(CLOCK_MONOTONIC, &stop_ts) != 0) {
		printf("[ERROR] clock_gettime\n");
		abort();
	}

	/*
	 * Elapsed nanoseconds. The two timestamps are subtracted field by field
	 * and widened to 64 bits before scaling, so nothing overflows on
	 * platforms where long or time_t has only 32 bits.
	 */
	const long long elapsed_signed =
		static_cast<long long>(stop_ts.tv_sec - start_ts.tv_sec) * 1000000000LL
		+ static_cast<long long>(stop_ts.tv_nsec - start_ts.tv_nsec);
	const unsigned long long elapsed_ns = static_cast<unsigned long long>(elapsed_signed);

	// both totals are rounded to the nearest unit
	const unsigned long long elapsed_secs = (elapsed_ns + 500000000ULL) / 1000000000ULL;
	const unsigned long long elapsed_usecs = (elapsed_ns + 500ULL) / 1000ULL;

	/*
	 * print time measurement
	 */
	printf("matrix size: %i\n", SIZE);
	printf("CLOCK_MONOTONIC: %llu secs in total; %llu microsecs in total; %f microsecs/element; %f microsecs/mul\n",
		elapsed_secs,
		elapsed_usecs,
		static_cast<double>(elapsed_usecs) / SIZE / SIZE,
		static_cast<double>(elapsed_usecs) / SIZE / SIZE / SIZE
	);

	/*
	 * print matrices
	 */
	// A.print();
	// B.print();
	// C.print();

	return 0;
}
