//==============================================================================
/*! \file
 * OpenMesh Toolkit for mesh analysis    \n
 * Copyright (c) 2010 by Rostislav Hulik     \n
 *
 * Author:  Rostislav Hulik, rosta.hulik@gmail.com  \n
 * Date:    2010/11/21                          \n
 *
 * This file is part of software developed for support of Rostislav Hulik's dissertation thesis at dcgm-robotics@FIT group.
 *
 * This file is free software: you can redistribute it and/or modify
 * it under the terms of the GNU Lesser General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 * 
 * This file is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU Lesser General Public License for more details.
 * 
 * You should have received a copy of the GNU Lesser General Public License
 * along with this file.  If not, see <http://www.gnu.org/licenses/>.
 * 
 * Module description:
 * - Module rasterizes vertex neighbourhood on a tangent raster
 * - Result is sent to connected MDSTk channel
 */

#pragma OPENCL EXTENSION cl_khr_byte_addressable_store : enable 
#pragma OPENCL EXTENSION cl_khr_local_int32_base_atomics : enable 


// variable constants
// side length of the square output matrix; the kernel keeps MATRIX_SIZE * MATRIX_SIZE cells per vertex
#define MATRIX_SIZE 16
// number of faces copied into local memory and ray traced per outer-loop iteration of the kernel
#define TRIANGLES_PER_BATCH 448

// number of work-items the kernel assumes per work-group; the kernel maps WORKGROUPSIZE consecutive global ids to one vertex
#define WORKGROUPSIZE 64

// Copyright 2001, softSurfer (www.softsurfer.com)
// This code may be freely used and modified for any purpose
// providing that this copyright notice is included with it.
// SoftSurfer makes no warranty for this code, and cannot be held
// liable for any real or imagined damage resulting from its use.
// Users of this code must verify correctness for their application.

// Assume that classes are already given for the objects:
//    Point and Vector with
//        coordinates {float x, y, z;}
//        operators for:
//            == to test equality
//            != to test inequality
//            (Vector)0 = (0,0,0)         (null vector)
//            Point  = Point +/- Vector
//            Vector = Point - Point
//            Vector = Scalar * Vector    (scalar product)
//            Vector = Vector * Vector    (cross product)
//    Line and Ray and Segment with defining points {Point P0, P1;}
//        (a Line is infinite, Rays and Segments start at P0)
//        (a Ray extends beyond P1, but a Segment ends at P1)
//    Plane with a point and a normal {Point V0; Vector n;}
//    Triangle with defining vertices {Point V0, V1, V2;}
//    Polyline and Polygon with n vertices {int n; Point *V;}
//        (a Polygon has V[n]=V[0])
//===================================================================

#define SMALL_NUM  0.00000001f // anything that avoids division overflow

// intersect_RayTriangle(): intersect the line carrying a ray with a 3D triangle
//    Input:  ray origin Ro and direction Rv; triangle vertices TV0, TV1, TV2
//    Output: *I = intersection of the ray's line with the triangle plane
//                 (written only when the ray is not parallel to the plane)
//    Return: -1 = triangle is degenerate (its normal is exactly zero)
//             1 = *I lies inside the triangle
//             2 = ray is parallel to the plane and its origin lies in the plane
//             3 = ray is parallel to the plane and disjoint from it
//             4 = *I is outside the triangle (first parametric coordinate not in [0, 1])
//             5 = *I is outside the triangle (second coordinate negative, or s + t > 1)
//    Note:   the sign of the ray parameter is not tested, so intersections
//            behind Ro are reported exactly like those in front of it.
inline int intersect_RayTriangle( float3 Ro, float3 Rv, float3 TV0, float3 TV1, float3 TV2, float3 * I )
{
    // triangle edge vectors and (unnormalized) plane normal
    const float3 edgeU  = TV1 - TV0;
    const float3 edgeV  = TV2 - TV0;
    const float3 normal = cross(edgeU, edgeV);

    // a zero normal means the triangle collapsed to a segment or a point
    if ( all(isequal(normal, (float3)0 )))
        return -1;

    // numerator and denominator of the ray parameter at the plane
    const float3 toOrigin = Ro - TV0;
    const float  numer    = -dot(normal, toOrigin);
    const float  denom    = dot(normal, Rv);

    // ray (almost) parallel to the triangle plane
    if (fabs(denom) < SMALL_NUM)
        return (numer == 0) ? 2 : 3;

    // intersection point of the ray's line with the triangle plane
    const float r = numer / denom;
    *I = Ro + r * Rv;

    // express the hit point in the (edgeU, edgeV) basis of the triangle
    const float3 w  = *I - TV0;
    const float  uu = dot(edgeU, edgeU);
    const float  uv = dot(edgeU, edgeV);
    const float  vv = dot(edgeV, edgeV);
    const float  wu = dot(w, edgeU);
    const float  wv = dot(w, edgeV);
    const float  D  = uv * uv - uu * vv;

    // first parametric coordinate
    const float s = (uv * wv - vv * wu) / D;
    if (s < 0.0f || s > 1.0f)
        return 4;

    // second parametric coordinate
    const float t = (uv * wu - uu * wv) / D;
    if (t < 0.0f || (s + t) > 1.0f)
        return 5;

    return 1;
}

// vectorSignedAngle(): signed angle between two vectors
//    Input:  vec1, vec2 = vectors to measure between
//            ref        = reference axis that decides the sign
//    Return: angle in radians; negative when cross(vec1, vec2) points against ref,
//            positive otherwise; 0 when either vector has zero length
inline float vectorSignedAngle(float3 vec1, float3 vec2, float3 ref)
{
	float vec1Len = length(vec1);
	float vec2Len = length(vec2);

	// the angle is undefined for a null vector; report 0
	if ( (vec1Len == 0.0f) || (vec2Len == 0.0f) )
		return 0.0f;

	// the normal of the two vectors, compared against ref, gives the sign
	float3 normal = cross(vec1, vec2);
	float sign = ( dot(normal, ref) < 0 ) ? -1.0f : 1.0f;

	// rounding can push the cosine of (anti)parallel vectors slightly outside
	// [-1, 1], where acos() yields NaN; clamp it back into the valid domain
	float cosine = clamp( dot(vec1, vec2) / (vec1Len * vec2Len), -1.0f, 1.0f );

	return acos(cosine) * sign;
}

// matMul(): multiplies the top three rows of mat1 by mat2
//    Input:  mat1, mat2 = matrices stored as four float4 rows
//    Output: rows 0..2 of *res receive rows 0..2 of mat1 * mat2;
//            row 3 of *res is left untouched
//    Note:   every result row is computed before anything is stored, so res
//            may refer to the same storage as mat2 (the callers in this file do that).
inline void matMul(const float4 mat1[4], const float4 mat2[4], float4 (* res)[4])
{
	// gather the four columns of mat2 once
	const float4 colX = (float4)(mat2[0].x, mat2[1].x, mat2[2].x, mat2[3].x);
	const float4 colY = (float4)(mat2[0].y, mat2[1].y, mat2[2].y, mat2[3].y);
	const float4 colZ = (float4)(mat2[0].z, mat2[1].z, mat2[2].z, mat2[3].z);
	const float4 colW = (float4)(mat2[0].w, mat2[1].w, mat2[2].w, mat2[3].w);

	// element (r, c) of the product is column c of mat2 dotted with row r of mat1
	float4 product[3];
	for (int r = 0; r < 3; r++)
	{
		product[r].x = dot(colX, mat1[r]);
		product[r].y = dot(colY, mat1[r]);
		product[r].z = dot(colZ, mat1[r]);
		product[r].w = dot(colW, mat1[r]);
	}

	// store only after all inputs have been read
	for (int r = 0; r < 3; r++)
	{
		(*res)[r] = product[r];
	}
}

// transformTo2DLinear(): applies only the linear part of a matrix to a vector
//    Each result component is the dot product of the vector with the xyz part
//    of the corresponding matrix row; the .w (translation) column is ignored.
inline float3 transformTo2DLinear(float3 vector, float4 (* matrix)[4])
{
	float3 transformed;

	transformed.x = dot(vector, (*matrix)[0].xyz);
	transformed.y = dot(vector, (*matrix)[1].xyz);
	transformed.z = dot(vector, (*matrix)[2].xyz);

	return transformed;
}

// vertexTransform(): transforms a point by the first three rows of a matrix
//    The point is extended to (x, y, z, 1), so the .w column of each row acts
//    as a translation. Row 3 of the matrix is not used.
inline float3 vertexTransform(float3 vector, float4 (* matrix)[4])
{
	const float4 homogeneous = (float4)(vector, 1.0f);
	float3 transformed;

	transformed.x = dot( (*matrix)[0], homogeneous );
	transformed.y = dot( (*matrix)[1], homogeneous );
	transformed.z = dot( (*matrix)[2], homogeneous );

	return transformed;
}

// computeMatrices(): builds the transformation from mesh space to raster (matrix) space
//    Input:  matrixlength, resolution = the applied scale factor is resolution / matrixlength,
//                                       the final x/y shift is resolution / 2 - 0.5
//            normal    = vector that is rotated onto the +Z axis
//            direction = vector whose XY projection (after the normal alignment) is rotated onto +X;
//                        a zero vector results in no extra rotation
//            origin    = point that is moved to the coordinate origin before rotating
//    Output: *forward = shift * scale * rotZ(direction) * rotY * rotZ(normal) * translate(-origin)
//    Note:   *inverse is not written by this function.
inline void computeMatrices(float4 (* forward)[4], float4 (* inverse)[4], float matrixlength, float resolution, float3 normal, float3 direction, float3 origin)
{
	const float scale = resolution / matrixlength;
	const float centerShift = resolution / 2.0f - 0.5f;

	float4 stage[4];   // scratch matrix reused for every step below
	float3 aux;
	float angle, c, s;

	// 1) move the origin point to (0, 0, 0)
	(*forward)[0] = (float4)(1,0,0,-origin.x);
	(*forward)[1] = (float4)(0,1,0,-origin.y);
	(*forward)[2] = (float4)(0,0,1,-origin.z);
	(*forward)[3] = (float4)(0,0,0,1);

	// 2) rotate about Z so that the XY projection of the normal lies on the X axis
	aux = (float3)(normal.xy, 0.0f);
	angle = vectorSignedAngle((float3)(1,0,0), aux, (float3)(0,0,1));
	c = cos(-angle);
	s = sin(-angle);
	stage[0] = (float4)(c,-s,0,0);
	stage[1] = (float4)(s, c,0,0);
	stage[2] = (float4)(0,0,1,0);
	stage[3] = (float4)(0,0,0,1);
	matMul(stage, *forward, forward);

	// 3) rotate about Y so that the (already rotated) normal points along Z
	aux = transformTo2DLinear(normal, forward);
	angle = vectorSignedAngle((float3)(0,0,1), aux, (float3)(0,1,0));
	c = cos(-angle);
	s = sin(-angle);
	stage[0] = (float4)(c,0,s,0);
	stage[1] = (float4)(0,1,0,0);
	stage[2] = (float4)(-s,0,c,0);
	stage[3] = (float4)(0,0,0,1);
	matMul(stage, *forward, forward);

	// 4) rotate about Z so that the XY projection of the direction lies on the X axis
	aux = transformTo2DLinear(direction, forward);
	aux.z = 0.0f;
	angle = vectorSignedAngle((float3)(1,0,0), aux, (float3)(0,0,1));
	c = cos(-angle);
	s = sin(-angle);
	stage[0] = (float4)(c,-s,0,0);
	stage[1] = (float4)(s, c,0,0);
	stage[2] = (float4)(0,0,1,0);
	stage[3] = (float4)(0,0,0,1);
	matMul(stage, *forward, forward);

	// 5) uniform scale
	stage[0] = (float4)(scale,0,0,0);
	stage[1] = (float4)(0,scale,0,0);
	stage[2] = (float4)(0,0,scale,0);
	stage[3] = (float4)(0,0,0,1);
	matMul(stage, *forward, forward);

	// 6) shift x and y so that the origin ends up at the centre of the raster
	stage[0] = (float4)(1,0,0,centerShift);
	stage[1] = (float4)(0,1,0,centerShift);
	stage[2] = (float4)(0,0,1,0);
	stage[3] = (float4)(0,0,0,1);
	matMul(stage, *forward, forward);
}

 // omc_single(): rasterizes the mesh as seen from one vertex into a MATRIX_SIZE x MATRIX_SIZE matrix
 //    One work-group of WORKGROUPSIZE work-items cooperates on one vertex / one output matrix.
 //    Input:  vertices    = two float4 per vertex: position at [2 * i], normal at [2 * i + 1]
 //            faces       = vertex indices of each triangle in .x, .y, .z
 //            vertexCount = currently unused
 //            faceCount   = number of faces; only full batches of TRIANGLES_PER_BATCH are
 //                          processed (faceCount / TRIANGLES_PER_BATCH iterations), any remainder is ignored
 //    Output: matrices    = MATRIX_SIZE * MATRIX_SIZE floats per vertex; each cell holds -1 when no
 //                          triangle was hit, otherwise the smallest distance from the cell's ray
 //                          origin to a hit point
 __kernel void omc_single(__global const float4 * vertices,
						  __global const int4 * faces, // multiple of WORKGROUPSIZE !
						  __global float * matrices,
						  int vertexCount, int faceCount)
{
	__local float  localMatrix[MATRIX_SIZE * MATRIX_SIZE];
	__local int4   localFaces[TRIANGLES_PER_BATCH];
	__local float4 localVerticies[TRIANGLES_PER_BATCH * 3];

	float4 forwardTransform[4]; // converts triangles to matrix space
	float4 inverseTransform[4]; // handed to computeMatrices, which does not fill it in yet

	size_t gid = get_global_id(0);
	size_t lid = get_local_id(0);

	// what vertex/matrix are we at
	int currentVertex = gid / WORKGROUPSIZE;
	float3 currentVertexPosition = vertices[currentVertex * 2].xyz;
	float3 currentVertexNormal = vertices[currentVertex * 2 + 1].xyz;

	// calculate transformation matrices
	computeMatrices(&forwardTransform, &inverseTransform, 2, 3, currentVertexNormal, (float3)(0,0,0), currentVertexPosition); // TODO parameters !!!!!!!!!!!!!!!

	// clear output matrix array - every work-item clears the cells it will later own
	for (int i = 0; i < (MATRIX_SIZE * MATRIX_SIZE) / WORKGROUPSIZE; i++)
	{
		localMatrix[i * WORKGROUPSIZE + lid] = -1;
	}

	// outer loop - process one batch of triangles per iteration
	for (int o = 0; o < faceCount / TRIANGLES_PER_BATCH; o++)
	{ // FOR ALL TRIANGLES
		// group copy face data from global memory
		int srcIndex = TRIANGLES_PER_BATCH * o;         // TODO ---------------------- staggered read optimization to avoid simultaneous access
		event_t copyEvent = async_work_group_copy(localFaces,
											   &faces[srcIndex],
											   (size_t)TRIANGLES_PER_BATCH,
												0);
		wait_group_events(1, &copyEvent);

		// transform triangles into matrix view space
		for (int t = 0; t < TRIANGLES_PER_BATCH / WORKGROUPSIZE; t++)
		{
			// each work-item transforms one triangle - WORKGROUPSIZE way parallelism
			int index = WORKGROUPSIZE * t + lid;
			int4 face = localFaces[index];

			float3 V1 = vertices[face.x * 2].xyz;
			localVerticies[index * 3 + 0] = (float4) (vertexTransform(V1, &forwardTransform), 0);

			float3 V2 = vertices[face.y * 2].xyz;
			localVerticies[index * 3 + 1] = (float4) (vertexTransform(V2, &forwardTransform), 0);

			float3 V3 = vertices[face.z * 2].xyz;
			localVerticies[index * 3 + 2] = (float4) (vertexTransform(V3, &forwardTransform), 0);
		}

		// all transformed triangles must be visible before anybody ray traces them
		barrier(CLK_LOCAL_MEM_FENCE);

		// triangles transformed, now every matrix cell ray traces all the triangles
		for (int i = 0; i < (MATRIX_SIZE * MATRIX_SIZE) / WORKGROUPSIZE; i++)
		{ // FOR EACH MATRIX CELL
			int index = i * WORKGROUPSIZE + lid;

			// calculate origin point of the ray in the matrix
			int x = index % MATRIX_SIZE;
			int y = index / MATRIX_SIZE;

			// single-precision literals: double constants would require fp64 support on the device
			float3 origin = (float3)(x/50.0f,y/50.0f,0); // !!!

			// test ALL triangles in local memory
			for (int j = 0; j < TRIANGLES_PER_BATCH; j++)
			{ // FOR EACH LOCAL TRIANGLE
				float3 intersection;

				// every work-item starts at a different triangle - LDS memory access optimization
				int triangle = (j + lid) % TRIANGLES_PER_BATCH;

				int rayTracing = intersect_RayTriangle( origin, (float3)(0,0,1) , // ray defined as point and vector
																localVerticies[triangle * 3 + 0].xyz, // triangle P1 P2 P3
																localVerticies[triangle * 3 + 1].xyz,
																localVerticies[triangle * 3 + 2].xyz,
																&intersection );

				// keep the minimum of the stored distance and the new distance;
				// intersection is only valid (and only read) when the ray hit the triangle
				if (rayTracing == 1)
				{
					float hitDistance = distance(origin, intersection);
					float stored = localMatrix[index];
					localMatrix[index] = (stored == -1) ? hitDistance : min(stored, hitDistance);
				}
			}
		}

		// the next iteration overwrites localVerticies; wait until every
		// work-item has finished ray tracing the current batch
		barrier(CLK_LOCAL_MEM_FENCE);
	}
	//---------------------------------------------------------------
	// async_work_group_copy does not synchronize its source data, so make sure
	// all work-items have finished writing localMatrix before it is copied
	barrier(CLK_LOCAL_MEM_FENCE);

	// copy the matrix back into global memory
	event_t copyEvent = async_work_group_copy( &matrices[currentVertex * (MATRIX_SIZE*MATRIX_SIZE)],
									   localMatrix,
									   MATRIX_SIZE * MATRIX_SIZE,
									   0);
	wait_group_events(1, &copyEvent);

}
