#include "umf.h"
#include "util/umfdebug.h"
#include "util/draw.h"
#include "util/chromakey.h"
#include "util/mask_tracker.h"
#include "util/grid_util.h"
#include "util/kalman_filter.h"
#include "util/median_filter.h"
#include <Eigen/Dense>
#include <Eigen/StdVector>
#include <sstream>
#include <iostream>


/**
 * @mainpage UMF detector
 *
 * @section umf_usg Usage
 *  requires Eigen3 and optionally OpenCV
 *
 * @section umf_gath Gallery and Theory
 *  The main process goes like this:
 *
 *  -# the input images come from OpenCV, FireWire, etc.; see the stream factories \link umf::StreamFactory \endlink
 *  \image html 0_source.png
 *  -# sparse scanlines and edge detection \link umf::EdgelDetector::detectEdges \endlink
 *  \image html 1_scanline.png
 *  -# use edges along scanlines as seeds for edgel detection (scanlines 200) \link umf::EdgelDetector::findEdgel \endlink
 *  \image html 2_edgelgrow.png
 *  \image html 3_edgels.png
 *  -# separate into two groups \link umf::GridDetector::separateTwoGroups \endlink
 *  \image html 4_groups.png
 *  -# find the vanishing point of both groups and filter out outliers using RANSAC \link umf::GridDetector::findVanish \endlink
 *  -# detect pencils of lines \link umf::GridDetector::detectMesh \endlink
 *  \image html 5_mesh.png
 *  -# extract edge directions in RGB \link umf::EdgeDirDetector \endlink
 *  \image html 6_edgedir.png
 *  -# match position using decision tree \link umf::DecisionTree \endlink, \link umf::Model \endlink
 *  -# camera pose estimation \link umf::Model \endlink
 *  -# extras, such as iterative refinement that tries to detect more of the map (for larger maps)
 *
 */



namespace umf {

template <int NCHAN>
UMFDetector<NCHAN>::UMFDetector(int flags)
{
    // The chroma-key mask is allocated lazily by detect() the first time it is needed.
    this->filterMask = NULL;
    // Number of subwindows stacked over the image height (see getSubwindowOffsets).
    this->subWindowVerticalCount = 3;
    // Detection options (UMF_FLAG_* bits) consulted by detect()/detectPosition().
    this->flags = flags;
}

template <int NCHAN>
UMFDetector<NCHAN>::~UMFDetector()
{
    // Release the lazily allocated chroma-key mask; deleting a null pointer is a no-op.
    // NOTE(review): the mask is owned through a raw pointer - unless the header disables
    // copying, a copied detector would delete the same mask twice. Confirm in umf.h.
    delete this->filterMask;
    this->filterMask = NULL;
}

/**
 * Compute the top-left corners of the square subwindows used by the subwindow search.
 *
 * The window side is imgSize[1]/subWindowVerticalCount (160 px for a 640x480 image with the
 * default count of 3). The offsets of a regular grid come first, followed by a second grid
 * shifted by half a window in both directions so that neighbouring windows overlap. Windows
 * in the last row/column may extend past the image border.
 *
 * @param imgSize image size as (width, height)
 * @param offsets [out] replaced with the subwindow top-left corners
 * @param subwindowSize [out] size of every subwindow; set to (0, 0) with no offsets when the
 *        image is too small for the configured split (or the count is not positive)
 */
template <int NCHAN>
void UMFDetector<NCHAN>::getSubwindowOffsets(const Eigen::Vector2i &imgSize,
                                             std::vector<Eigen::Vector2i> &offsets,
                                             Eigen::Vector2i &subwindowSize)
{
    offsets.clear();

    const int windowSize = (this->subWindowVerticalCount > 0) ? imgSize[1]/this->subWindowVerticalCount : 0;
    if(windowSize <= 0)
    {
        // A zero window size would make the grid counts below divide by zero,
        // so report no subwindows at all.
        subwindowSize = Eigen::Vector2i(0, 0);
        return;
    }
    const int windowSizeHalf = windowSize/2;
    subwindowSize = Eigen::Vector2i(windowSize, windowSize);

    // ceil(len / windowSize): windows starting at 0, ws, 2ws, ... that begin inside the image
    const int gridWidth = (imgSize[0] > 0) ? (imgSize[0] + windowSize - 1)/windowSize : 0;
    const int gridHeight = (imgSize[1] + windowSize - 1)/windowSize;
    // ceil(len / windowSize - 0.5): half-shifted windows that begin inside the image,
    // computed exactly in integers as the count of i >= 0 with (2i + 1) * ws < 2 * len
    const int gridMidWidth = (2*imgSize[0] > windowSize) ? (2*imgSize[0] + windowSize - 1)/(2*windowSize) : 0;
    const int gridMidHeight = (2*imgSize[1] > windowSize) ? (2*imgSize[1] + windowSize - 1)/(2*windowSize) : 0;

    offsets.reserve(gridWidth*gridHeight + gridMidWidth*gridMidHeight);

    //normal grid
    for(int j = 0; j < gridHeight; j++)
    {
        for(int i = 0; i < gridWidth; i++)
        {
            offsets.push_back(Eigen::Vector2i(i*windowSize, j*windowSize));
        }
    }

    //overlap - the same grid shifted by half a window
    for(int j = 0; j < gridMidHeight; j++)
    {
        for(int i = 0; i < gridMidWidth; i++)
        {
            offsets.push_back(Eigen::Vector2i(i*windowSize + windowSizeHalf, j*windowSize + windowSizeHalf));
        }
    }
}

/**
 * Map an image point into marker (model) coordinates using the line pencils found by the
 * grid detector in the most recently processed window.
 *
 * @param refLocation matched location; its c/r offsets and rotation are applied to the result
 * @param imgPos point in image coordinates
 * @param modelPos [out] corresponding position in marker coordinates
 * NOTE(review): front()/back() are used unchecked, so both pencils are assumed to be non-empty.
 */
template <int NCHAN>
void UMFDetector<NCHAN>::getPointPosition(Location &refLocation, Eigen::Vector2f &imgPos, Eigen::Vector2f &modelPos)
{
    std::vector<Eigen::Vector3f> &rowLines = this->gridDetect.getPencil(0);
    std::vector<Eigen::Vector3f> &colLines = this->gridDetect.getPencil(1);

    // The cross product of two homogeneous lines is their intersection point: these are the
    // common points of each pencil, and the line through both of them.
    Eigen::Vector3f rowVanish = rowLines.front().cross(rowLines.back());
    Eigen::Vector3f colVanish = colLines.front().cross(colLines.back());
    Eigen::Vector3f horizon = rowVanish.cross(colVanish);

    // Lines joining the (homogeneous) image point with each pencil's common point.
    Eigen::Vector3f point(imgPos[0], imgPos[1], 1);
    Eigen::Vector3f pointRowLine = rowVanish.cross(point);
    Eigen::Vector3f pointColLine = colVanish.cross(point);

    float pencilRowK = lineGetKPsI(rowLines.size(), horizon, rowLines.front(), rowLines.back());
    float pencilColK = lineGetKPsI(colLines.size(), horizon, colLines.front(), colLines.back());

    float pointRowK = lineGetKPsI(1.f, horizon, rowLines.front(), pointRowLine);
    float pointColK = lineGetKPsI(1.f, horizon, colLines.front(), pointColLine);

    float colOffset = pencilColK/pointColK;
    float rowOffset = pencilRowK/pointRowK;

    modelPos[0] = colOffset + refLocation.c;
    modelPos[1] = rowOffset + refLocation.r;
    changeBackLocationf(modelPos, refLocation.rotation, this->model.getMarker()->w , this->model.getMarker()->h, 0, 0);
}

/**
 * Detect the marker and map a set of image points into marker coordinates, smoothing the
 * result over time with per-point median and Kalman filters.
 *
 * @param image input image
 * @param imgPos image points to map
 * @param modelPos [out] filtered marker-space positions, one per entry of imgPos (grown if too short)
 * @return true when a subwindow produced a match
 */
template <int NCHAN> template<class T>
bool UMFDetector<NCHAN>::detectPosition(Image<T, NCHAN> *image, std::vector<Eigen::Vector2f> &imgPos, std::vector<Eigen::Vector2f> &modelPos)
{
    UMFDebug *dbg = UMFDSingleton::Instance();

    const unsigned int POINT_COUNT = imgPos.size();

#ifdef UMF_DEBUG_TIMING
    int ovid = dbg->logEventStart();

    int logid = dbg->logEventStart();
#endif

    this->edgelDetect.detectEdges(image, NULL, false);

    //Orientation filter
    if(this->flags & UMF_FLAG_ORIENTATION)
    {
        OrientationFilter &track = this->edgelDetect.getOrientationFilter();
        track.enable(true);
        //track.filterPoints(image, this->edgelDetect.getEdges());
    }

    this->edgelDetect.findEdgels(image, NULL, false);

    //Orientation filter 2
    if(this->flags & UMF_FLAG_ORIENTATION)
    {
        OrientationFilter &track = this->edgelDetect.getOrientationFilter();
        track.filterEdgels(this->edgelDetect.getEdgels());
    }


#ifdef UMF_DEBUG_TIMING
    dbg->logEventEnd(logid, "EE");
    logid = dbg->logEventStart();
#endif
    //global pass first
    Location bestLoc;
    int bestCount = -1;
    Eigen::Vector2i poffset(0, 0);
    Eigen::Vector2i psize(image->width, image->height);
    std::vector<Eigen::Vector2f> bestPosition(imgPos.size());

    this->model.setUseCornerSearch(false);
    int pcount = this->processSubWindow(image, poffset, psize, bestLoc, NULL, false);
    bool success = pcount != -1;
    //NOTE(review): this override discards the global-pass result, so the branch below that
    //uses it never runs and positions always come from the subwindow search (when enabled).
    //Kept deliberately - removing it changes which pass produces the positions.
    success = false;

#ifdef UMF_DEBUG_TIMING
    dbg->logEvent(1, "COUNT");
    dbg->logEventEnd(logid, "FULL");
#endif

    if(success)
    {
        for(unsigned int i = 0; i < POINT_COUNT; i++)
        {
            this->getPointPosition(bestLoc, imgPos[i], bestPosition[i]);
        }
    }
    else if((this->flags & UMF_FLAG_SUBWINDOWS) != 0)
    {
#ifdef UMF_DEBUG_TIMING
        logid = dbg->logEventStart();
#endif
        //processing in subwindows
        int bestSubWindow = -1;

        std::vector<Eigen::Vector2i> offsets;
        Eigen::Vector2i wSize;
        this->getSubwindowOffsets(Eigen::Vector2i(image->width, image->height), offsets, wSize);
#ifdef UMF_DEBUG_TIMING
        dbg->logEvent(offsets.size(), "COUNT");
#endif
        for(unsigned int subI = 0; subI < offsets.size(); subI++)
        {
            Location loc;
            //debug visualisation is enabled for subwindow index 5 only
            int pcount = this->processSubWindow(image, offsets[subI], wSize, loc, NULL, subI == 5);

            if(pcount > bestCount)
            {
                bestCount = pcount;
                bestSubWindow = subI;
                bestLoc = loc;

                //the grid detector still holds this subwindow's pencils, so map the points now
                for(unsigned int i = 0; i < POINT_COUNT; i++)
                {
                    this->getPointPosition(bestLoc, imgPos[i], bestPosition[i]);
                }
            }
        }
#ifdef UMF_DEBUG_TIMING
        dbg->logEventEnd(logid, "SUB");
#endif

        success = bestSubWindow != -1;
    }

    //Temporal smoothing. The filters are function-local statics: they are shared by every
    //detector of this instantiation, keep their state between calls and are not synchronized.
    static std::vector<MedianFilter<float, 2> > filter(imgPos.size());
    typedef SimpleState<2> mstate_t;
    typedef SimpleState<4> sstate_t;
    typedef ConstantProcess<sstate_t> process_t;
    static std::vector< KalmanFilter<sstate_t, process_t>, Eigen::aligned_allocator<KalmanFilter<sstate_t, process_t> > > kf(imgPos.size());
    static bool kfinited = false;
    //per-point measurement variance; points past the end of the table reuse its last entry
    const float kfMeasurementVariance[2] = {2.0, 1.0};
    const unsigned int KF_VARIANCE_COUNT = sizeof(kfMeasurementVariance)/sizeof(kfMeasurementVariance[0]);
    const double dt = 0.5;
    const float medianThreshold = 10.f;

    //the statics are sized on the first call - grow them if more points are requested now
    if(filter.size() < POINT_COUNT)
    {
        filter.resize(POINT_COUNT);
    }
    if(kf.size() < POINT_COUNT)
    {
        kf.resize(POINT_COUNT);
        kfinited = false; //re-seed the Kalman filters so the new ones start from a measurement
    }

    if(success)
    {
        if(modelPos.size() < POINT_COUNT)
        {
            modelPos.resize(POINT_COUNT);
        }

        for(unsigned int i = 0; i < POINT_COUNT; i++)
        {
            modelPos[i] = bestPosition[i];
            Eigen::Vector2f med = filter[i].filter(bestPosition[i]);

            //fall back to the median when the raw estimate is too far from it
            if((med - modelPos[i]).norm() > medianThreshold)
            {
                modelPos[i] = med;
            }
        }

        if(!kfinited)
        {
            kfinited = true;
            std::vector<float> processModelVariance(imgPos.size(), 0.1f);
            if(!processModelVariance.empty())
            {
                processModelVariance[0] = 5e-4f;
            }
            for(unsigned int i = 0; i < POINT_COUNT; i++)
            {
                //state is (x, y, vx, vy), seeded at the current position with zero velocity
                kf[i].state.x = Eigen::Vector4d(modelPos[i][0], modelPos[i][1], 0, 0);

                kf[i].processModel.sigma = sstate_t::VecState::Constant(processModelVariance[i]);
                //transition: position += velocity, velocity unchanged
                kf[i].processModel.jacobian << 1, 0, 1, 0,
                        0, 1, 0, 1,
                        0, 0, 1, 0,
                        0, 0, 0, 1;
            }
        }

        for(unsigned int i = 0; i < POINT_COUNT; i++)
        {
            kf[i].predict(dt);

            AbsoluteMeasurement<mstate_t, sstate_t> meas;
            meas.measurement = modelPos[i].template cast<double>();
            const unsigned int varIdx = (i < KF_VARIANCE_COUNT) ? i : KF_VARIANCE_COUNT - 1;
            meas.covariance = Eigen::Vector2d::Constant(kfMeasurementVariance[varIdx]).asDiagonal();

            kf[i].correct(meas);

            modelPos[i][0] = kf[i].state.x[0];
            modelPos[i][1] = kf[i].state.x[1];
        }
    }

#ifdef UMF_DEBUG_TIMING
    dbg->logEventEnd(ovid, "OVRL");
#endif
    return success;
}


/**
 * Detect the marker in an image and compute either a homography or the camera pose.
 *
 * Edgels are detected once. The whole image is tried first; if that fails and
 * UMF_FLAG_SUBWINDOWS is set, the overlapping subwindows from getSubwindowOffsets are tried
 * and the one with the most correspondences wins.
 *
 * @param image input image
 * @param timeout detection budget checked through checkTimeout(); values <= 0 disable it
 * @return true when a location was matched and the homography / pose computation succeeded
 * @throws DetectionTimeoutException when the budget runs out before a usable result exists
 *         (via checkTimeout's throwing form, or when the subwindow search finds nothing in time)
 */
template <int NCHAN> template<class T>
bool UMFDetector<NCHAN>::detect(Image<T, NCHAN> *image,  float timeout)
{
    UMFDebug *dbg = UMFDSingleton::Instance();
    this->detectionTimer.start();

#ifdef UMF_DEBUG_TIMING
    int ovid = dbg->logEventStart();
#endif

    if((this->flags & UMF_FLAG_CHROMAKEY) && image->channels == 3)
    {
        //the mask must match the current image; reallocate it when the resolution changes
        if(this->filterMask != NULL &&
                (this->filterMask->width != image->width || this->filterMask->height != image->height))
        {
            delete this->filterMask;
            this->filterMask = NULL;
        }
        if(this->filterMask == NULL)
        {
            this->filterMask = new ImageGray(image->width, image->height);
        }
        getChromeMask(image, this->filterMask);
    }


    //EDGEL DETECTION
#ifdef UMF_DEBUG_TIMING
    int logid = dbg->logEventStart();
#endif

    this->edgelDetect.detectEdges(image, this->filterMask, false);

    //MASK TRACKING
    if(this->flags & UMF_FLAG_TRACK_POS)
    {
        MaskTracker &track = this->model.getMaskTracker();
        track.filterPoints(this->edgelDetect.getEdges());

        //track.show();
    }

    this->edgelDetect.findEdgels(image, this->filterMask, false);

    //throws when out of time
    this->checkTimeout(timeout);

#ifdef UMF_DEBUG_TIMING
    dbg->logEventEnd(logid, "EDGEL");
#endif

    bool success = false;

    this->model.setUseSubPixel(this->flags & UMF_FLAG_SUBPIXEL);

    //global pass first
    Location goodLoc;
    Eigen::Vector2i poffset(0, 0);
    Eigen::Vector2i psize(image->width, image->height);
    int pcount = this->processSubWindow(image, poffset, psize, goodLoc, filterMask, false);
#ifdef UMF_DEBUG_TIMING
    dbg->logEvent(1, "COUNT");
#endif
    success = pcount != -1;


    if(success == false && (this->flags & UMF_FLAG_SUBWINDOWS) != 0)
    {
        //check for timeout (throws when out of time)
        this->checkTimeout(timeout);

        //processing in subwindows
        int bestSubWindow = -1;
        CorrespondenceSet bestSet;
        Location bestLoc;

        //generate offsets and process every subwindow, keeping the one with most correspondences
        std::vector<Eigen::Vector2i> offsets;
        Eigen::Vector2i wSize;
        this->getSubwindowOffsets(Eigen::Vector2i(image->width, image->height),
                                  offsets,
                                  wSize);

#ifdef UMF_DEBUG_TIMING
        dbg->logEvent(offsets.size(), "COUNT");
#endif
        for(unsigned int subI = 0; subI < offsets.size(); subI++)
        {
            Location loc;
            int pcount = this->processSubWindow(image, offsets[subI], wSize, loc, filterMask, false);

            if(pcount > (int) bestSet.size())
            {
                bestSet = this->model.getCorrespondences();
                bestSubWindow = subI;
                bestLoc = loc;
            }

            //out of time: keep the best result so far, or give up if there is none
            if(checkTimeout(timeout, false))
            {
                if(bestSubWindow != -1)
                {
                    break;
                } else {
                    throw DetectionTimeoutException();
                }
            }
        }

        success = bestSubWindow != -1;

        if(success){
            this->model.setCorrespondences(bestSet);
            goodLoc = bestLoc;
        }

    }


    if(success)
    {
        if((this->flags & UMF_FLAG_HOMOGRAPHY) != 0)
        {
            success = this->model.computeHomography(false);
        } else {
            //iterative refinement only when requested and there is time left; use the
            //non-throwing check so an already found pose is not discarded on timeout
            const bool refine = ((bool)(this->flags & UMF_FLAG_ITER_REFINE)) && !checkTimeout(timeout, false);
            success = this->model.computeCameraPosition(image,
                                                        this->edgeDirDetect.getCols(),
                                                        this->edgeDirDetect.getRows(),
                                                        goodLoc,
                                                        refine, false);
        }
    }

    if(this->flags & UMF_FLAG_TRACK_POS)
    {
        if(success)
        {
            this->model.updateMaskTracker(Eigen::Vector2i(image->width, image->height), this->flags );
        } else {
            MaskTracker &track = this->model.getMaskTracker();
            track.disable();
        }
    }


#ifdef UMF_DEBUG_TIMING
    dbg->logEventEnd(ovid, "OVRL");
#endif

    return success;
}


/**
 * Run grid detection, edge-direction extraction and model matching on one window.
 *
 * @param image input image
 * @param offset top-left corner of the window
 * @param size window size; a window covering the whole image uses all edgels directly,
 *        otherwise only edgels with at least one endpoint inside the window are used
 * @param loc [out] location filled in by model.matchModel
 * @param mask optional mask forwarded to the edge-direction extraction (may be NULL)
 * @param show enable debug visualisation in the sub-steps
 * @return number of correspondences of the match, or -1 when grid detection or matching fails
 */
template <int NCHAN> template<class T>
int UMFDetector<NCHAN>::processSubWindow(Image<T, NCHAN> *image, Eigen::Vector2i &offset, Eigen::Vector2i &size, Location &loc, ImageGray *mask, bool show)
{
    UMFDebug *dbg = UMFDSingleton::Instance();

    bool success = false;

    const Eigen::Vector2f offsetF = offset.template cast<float>();
    const Eigen::Vector2f sizeF = size.template cast<float>();

    //center the grid detector's transform on the window and scale it by 2/min(width, height)
    //(presumably mapping the shorter side to a [-1, 1] range - confirm in GridDetector)
    this->gridDetect.setTransformCenter(Eigen::Vector2i(offset[0] + size[0]/2, offset[1] + size[1]/2));
    this->gridDetect.setTransformScale(2.0f/(float) (std::min)((float) size[0], (float) size[1]));

    //GRID detect
    if(size[0] >= image->width && size[1] >= image->height)
    {
        success = this->gridDetect.detect(this->edgelDetect.getEdgels(), show);
    }
    else {
        //keep the edgels with at least one endpoint inside [offset, offset + size)
        std::vector<Edgel> edgels;
        const std::vector<Edgel> &allEdgels = this->edgelDetect.getEdgels();
        for(std::vector<Edgel>::const_iterator edgelIt = allEdgels.begin(); edgelIt != allEdgels.end(); ++edgelIt)
        {
            const Eigen::Vector2f diff1 = edgelIt->endPoints[0] - offsetF;
            const Eigen::Vector2f diff2 = edgelIt->endPoints[1] - offsetF;
            const bool firstInside = (diff1.array() >= 0).all() && (diff1.array() < sizeF.array()).all();
            const bool secondInside = (diff2.array() >= 0).all() && (diff2.array() < sizeF.array()).all();
            if(firstInside || secondInside)
            {
                edgels.push_back(*edgelIt);
            }
        }
        success = this->gridDetect.detect(edgels, show);
    }

    if(!success)
    {
#ifdef UMF_DEBUG_TIMING
        dbg->logEvent(0, "X");
#endif
        return -1;
    }


#ifdef UMF_DEBUG_TIMING
    int logid = dbg->logEventStart();
#endif

    this->edgeDirDetect.extract(image, this->gridDetect.getPencil(0), this->gridDetect.getPencil(1), mask, show);

#ifdef UMF_DEBUG_TIMING
    dbg->logEventEnd(logid, "X");
#endif

    success = this->model.matchModel(image, this->edgeDirDetect.getCols(), this->edgeDirDetect.getRows(),
                                     this->edgeDirDetect.getEdgeDirections(),
                                     this->edgeDirDetect.getExtractionPoints(), loc, show);

    //cast explicitly: mixing size() and -1 in one ternary would make the result unsigned
    return success ? static_cast<int>(this->model.getCorrespondences().size()) : -1;
}


/**
 * @brief load a marker from a string
 * @tparam NCHAN detector using this number of channels
 * @param markerStr the string containing the marker (NULL is rejected)
 * @return whether the marker was successfully loaded; on success it is passed to
 *         model.addMarker (which presumably takes ownership - confirm in Model)
 *
 *The expected marker format:
 * width
 * height
 * kernelSize;mask
 * typeCode[;optionally colors in hex - #aabbcc;#1144dd]
 * data...........
 *
 * Only unmasked (mask == 0), non-torus markers are supported. Colors are read only when the
 * type code marks the marker as colored. Data values are decimal, each optionally followed by
 * one separator character (whitespace before a value is skipped). Exactly width*height values
 * are required; a truncated or malformed data section is rejected.
 */
template<int NCHAN>
bool UMFDetector<NCHAN>::loadMarker(const char* markerStr)
{
    if(markerStr == NULL)
    {
        return false;
    }

    std::stringstream dataStream(markerStr, std::stringstream::in);

    if(dataStream.fail())
    {
        return false;
    }

    int width = 0;
    int height = 0;
    int ksize = 2;
    int code = 0;
    int mask = 0;

    char comma = ';';

    dataStream >> width;
    dataStream >> height;
    dataStream >> ksize;
    //"kernelSize;mask" - read the separator verbatim, then the mask right after it
    dataStream >> std::noskipws >> comma >> mask >> std::skipws;

    if(dataStream.fail() || width <= 0 || height <= 0)
    {
        std::cerr << "Error - malformed marker header!" << std::endl;
        return false;
    }

    if(mask != 0)
    {
        std::cerr << "Error - unsupported format with a mask!" << std::endl;
        return false;
    }

    dataStream >> code;
    if(dataStream.fail())
    {
        std::cerr << "Error - malformed marker header!" << std::endl;
        return false;
    }

    MarkerType mType; mType.decode(code);

    if(mType.torus)
    {
        std::cerr << "Error - unsupported torus format!" << std::endl;
        return false;
    }

    if(mType.range <= 0)
    {
        std::cerr << "Error - invalid marker color range!" << std::endl;
        return false;
    }

    std::vector< Eigen::Matrix<unsigned char, NCHAN, 1> > colors(mType.range);

    //TODO should somehow handle greenscreen markers too if only one channel - get the mapping from somewhere
    if(mType.color)
    {
        if(NCHAN != 1 && NCHAN != 3)
        {
            //no idea how to map colors
            return false;
        }

        const int CHANNELS = 3;
        std::vector<unsigned char> rgb(mType.range*CHANNELS);

        for(int i = 0; i < mType.range; i++)
        {
            unsigned long colorHex = 0;
            //";#rrggbb" - separator, hash, then the hex value
            dataStream >> std::noskipws >> comma >> comma /*hash*/ >> std::hex >> colorHex;
            rgb[i*CHANNELS + 0] = (colorHex >> 16) & 255;
            rgb[i*CHANNELS + 1] = (colorHex >> 8) & 255;
            rgb[i*CHANNELS + 2] = (colorHex) & 255;
        }
        //std::hex and std::noskipws are sticky - restore them so the data is read as decimal
        dataStream >> std::dec >> std::skipws;

        if(NCHAN == 1)
        {
            for(int i = 0; i < mType.range; i++)
            {
                colors[i](0) = (unsigned char)(0.299f*rgb[i*CHANNELS + 0] + 0.587f*rgb[i*CHANNELS + 1] + 0.114f*rgb[i*CHANNELS + 2]); //convert to grayscale
            }
        }
        else
        {
            for(int i = 0; i < mType.range; i++)
            {
                colors[i](0) = rgb[i*CHANNELS + 0];
                colors[i](1) = rgb[i*CHANNELS + 1];
                colors[i](2) = rgb[i*CHANNELS + 2];
            }
        }
    } else {
        //grayscale - nothing to read, spread the levels evenly over 0..255
        //(a single-level marker gets level 0 instead of dividing by zero)
        const int step = (mType.range > 1) ? 255/(mType.range - 1) : 0;
        for(int i = 0 ; i < mType.range; i++)
        {
            colors[i].setOnes(); colors[i] *= i*step;
        }
    }

    //read data
    typedef std::vector<unsigned short>::size_type size_type;
    const size_type expectedSize = static_cast<size_type>(width) * static_cast<size_type>(height);

    std::vector<unsigned short> data;
    data.reserve(expectedSize);

    comma = '\0';

    //stop at the first value that cannot be parsed, so a truncated or malformed field is
    //rejected by the size check below instead of being padded with zeros
    bool readFailed = false;
    for(int row = 0; row < height && !readFailed; row++)
    {
        for(int i = 0; i < width; i++)
        {
            short val = 0;
            if(!(dataStream >> std::skipws >> val))
            {
                readFailed = true;
                break;
            }
            data.push_back((unsigned short) val);
            //separator; it may be missing after the very last value (end of string)
            dataStream >> std::noskipws >> comma;
        }
    }

    if(data.size() != expectedSize)
    {
        std::cerr << "Non matching data size." << data.size() << " " << expectedSize << std::endl;
        return false;
    }

    Marker<NCHAN> *marker = new Marker<NCHAN>(height, width, ksize, colors);

    if(marker->setField(data) != true)
    {
        delete marker;
        return false;
    }

    this->model.addMarker(marker);

    return true;
}

/**
 * Check whether the detection started by detectionTimer.start() has exceeded its budget.
 *
 * @param timeout limit compared with the value returned by detectionTimer.stop();
 *        any value that is not > 0 disables the check
 * @param shouldThrow throw DetectionTimeoutException instead of returning true when exceeded
 * @return true when the budget is exceeded (and shouldThrow is false), false otherwise
 */
template <int NCHAN>
bool UMFDetector<NCHAN>::checkTimeout(float timeout, bool shouldThrow)
{
    // Negated comparisons keep NaN behaving like "disabled" / "not exceeded".
    if(!(timeout > 0))
    {
        return false;
    }

    const double elapsed = this->detectionTimer.stop();
    if(!(elapsed > timeout))
    {
        return false;
    }

    if(shouldThrow)
    {
        throw DetectionTimeoutException();
    }
    return true;
}

// Explicit instantiations. The member templates are defined in this source file, so the
// combinations listed here are the ones emitted for other translation units to link against:
// NCHAN = 1 with ImageGray and NCHAN = 3 with ImageRGB.
template UMFDetector<1>::UMFDetector(int flags);
template UMFDetector<3>::UMFDetector(int flags);
template UMFDetector<1>::~UMFDetector();
template UMFDetector<3>::~UMFDetector();

template bool UMFDetector<1>::detect(ImageGray *img, float timeout);
template bool UMFDetector<3>::detect(ImageRGB *img, float timeout);

template bool UMFDetector<1>::detectPosition(ImageGray *image, std::vector<Eigen::Vector2f> &imgPos, std::vector<Eigen::Vector2f> &modelPos);
template bool UMFDetector<3>::detectPosition(ImageRGB *image, std::vector<Eigen::Vector2f> &imgPos, std::vector<Eigen::Vector2f> &modelPos);


template bool UMFDetector<1>::loadMarker(const char* markerStr);
template bool UMFDetector<3>::loadMarker(const char* markerStr);

}
