//---------------------------------------------------------------------------

#include <iostream>
#include <string>
#include "classifier.h"


using namespace std;


// Scan statistics shared by the functions in this file (internal linkage).
namespace
{
    unsigned evals = 0u;           // weak hypotheses evaluated (incremented in TClassifier::evalSample)
    unsigned samplesScanned = 0u;  // windows passed to TClassifier::evalSample
    float avgFeatures = 0.0f;      // evals / samplesScanned, stored at the end of scanImagePyramid()
}

///////////////////////////////////////////////////////////////////////////////
// Decision stump weak classifier

int TDecisionStump::loadFromXML(const xmlNodePtr root)
{
    if (!root) return 0; // Null pointer given

    xmlNodePtr wcNode = getNode("DecisionStumpWeakHypothesis", root);
    if (!wcNode)
	{
		return 0;
	}

    bool ok = true; // Flag to signalize all parameters are present

    // Read parameters
    if (!getAttr(alpha, "alpha", wcNode)) ok = false;
    if (!getAttr(threshold, "threshold", wcNode)) ok = false;
    if (!getAttr(parity, "parity", wcNode)) ok = false;

    if (!ok)
	{
		return 0;
	}

    if (feature) delete feature;
    
    feature = createContFeature(wcNode);

    if (!feature) return 0;
    
	return 1;
}


/////////////////////////////////////////////////////////////////////////////
// Sample evaluation
int TClassifier::evalSample(const TSampleImage & sample)
{
    float response = 0.0f;
    StageList::iterator stg = wbStage.begin();
    ThrList::iterator thr = wbThreshold.begin();
	
	++samplesScanned;
	while (stg != wbStage.end())
	{
		const TDecisionStump & hypothesis = **stg;
        response += hypothesis.eval(sample);

		++evals;

		// Test wald thresholds
		if (response < thr->first) return 0;
		if (response > thr->second) return 1;

		++thr, ++stg;
	}
	// No decision achieved yet - decide using threshold T.
    return (response > T) ? 1 : 0;
}


/// Slide the classifier window over one frame and record positive windows.
///
/// Windows of the classifier's `size` are placed every shiftX/shiftY pixels.
/// Windows whose sample.stdDev() is below 10 are skipped as low-contrast.
/// Each accepted window is written to detections[first], detections[first+1],
/// ..., with its coordinates multiplied by `scale` and truncated to int.
///
/// @param frame       frame to scan
/// @param detections  pre-sized output buffer; entries from index `first` on
///                    are overwritten, and its size is never changed
/// @param first       index of the first free slot in `detections`
/// @param scale       factor applied to window coordinates before storing
///                    (presumably the pyramid level factor; see scanImagePyramid)
/// @return number of detections written by this call. Scanning stops as soon
///         as the buffer is full.
int TClassifier::scanFrame(TFrame & frame, std::vector<TRect> & detections, unsigned first, float scale)
{
	// Output buffer already full - nothing can be recorded.
	if (first >= detections.size())
		return 0;

	unsigned det = first;	// next free slot in `detections`
	unsigned total = 0;		// detections recorded by this call

	// Go through all positions.
	// NOTE(review): the '<' bounds mean the window flush with the right/bottom
	// frame edge is never evaluated. Confirm whether TSampleImage needs that
	// extra pixel (e.g. an integral-image border) before changing to '<='.
	// Assumes shiftX/shiftY are positive; otherwise these loops never end.
	for (int y = 0; y < frame.size().h - size.h; y+=shiftY)
	{
		for (int x = 0; x < frame.size().w - size.w; x+=shiftX)
		{
			// get sample image
			TRect sampleArea(x, y, size.w, size.h);
			TSampleImage sample(frame, sampleArea);

			// We are not interested in subwindows with low contrast - skip them
			if (sample.stdDev() < 10) continue;

			// When positive response encountered - add the area to detection list,
			// mapped through `scale` (truncated to int)
			if (this->evalSample(sample))
			{
				detections[det++] = TRect(int(x*scale), int(y*scale), int(size.w*scale), int(size.h*scale));
				++total;
			}

			// When list is full - exit
			if (det >= detections.size())
				return total;
		}
	}

	return total;
}


// Load complete classifier from XML structure
// <WaldBoostClassifier>
// <stage>
// </stage>
// </WaldBoostClassifier>
/// Load the complete WaldBoost classifier from a <WaldBoostClassifier> node.
///
/// Reads the window size from the "sizeX"/"sizeY" attributes. It then builds
/// one TDecisionStump per <stage> child, together with that stage's
/// (negT, posT) thresholds. A missing threshold defaults to -1E10 / +1E10,
/// which effectively disables early rejection or acceptance for that stage.
/// Stages whose weak hypothesis cannot be loaded are skipped with a warning.
///
/// Any previously loaded stages are released first.
///
/// @param root  <WaldBoostClassifier> node; null or a differently named node
///              is rejected
/// @return 1 on success (status set to true), 0 otherwise (status left false).
///         NOTE: succeeds even if no stage could be loaded.
int TClassifier::loadFromXML(xmlNodePtr root)
{
    status = false; // reset status flag

    // Same guard as TDecisionStump::loadFromXML: never dereference a null node.
    if (!root)
        return 0;

    if (xmlStrcmp(root->name, BAD_CAST "WaldBoostClassifier") != 0)
    {
		cout << "No wb classifier!" << endl;
        return 0;
    }

    // wbStage holds the TDecisionStump objects allocated with new below.
    // Free them before clearing, so that loading into an already populated
    // classifier does not leak the old stages.
    for (StageList::iterator it = wbStage.begin(); it != wbStage.end(); ++it)
        delete *it;
    wbStage.clear();
    wbThreshold.clear();

    if (!getAttr(size.w, "sizeX", root)) return 0;
    if (!getAttr(size.h, "sizeY", root)) return 0;

    for (xmlNodePtr stageNode = root->children; stageNode; stageNode = stageNode->next)
    {
        // we are interested only in "stage" nodes
        if (xmlStrcmp(stageNode->name, BAD_CAST "stage") != 0)
            continue;

        // defaults for waldboost thresholds (used when the attribute is absent)
        float posT =  1E10f;
        float negT = -1E10f;
        getAttr(posT, "posT", stageNode);
        getAttr(negT, "negT", stageNode);

        TDecisionStump * wh = new TDecisionStump();
        if (wh->loadFromXML(stageNode))
        {
            // All stage params are present - add to list; ownership moves to wbStage
            wbStage.push_back(wh);
            wbThreshold.push_back(pair<float,float>(negT, posT));
        }
        else
        {
            cerr << "WARNING: cannot construct weak hypothesis" << endl;
            delete wh;
        }
    }

    status = true; // all loaded ok
    return 1;
}


//------------------------------------------------------------------------------

/// Parse an XML file and construct the WaldBoost classifier it contains.
///
/// Expected layout (each node located with getNode()):
///   <ROOT> ... <TWaldBoostLearner> <WaldBoostClassifier/> </TWaldBoostLearner> ... </ROOT>
///
/// @param filename  path of the XML file
/// @return a newly allocated classifier owned by the caller, or 0 in any of
///         these cases: the file cannot be parsed, the learner or classifier
///         node is missing, or the classifier reports !ok()
TClassifier * loadStrongClassifier(std::string filename)
{
    xmlInitParser();

    xmlDocPtr cXml = xmlParseFile(filename.c_str());
    if (!cXml) return 0;

    xmlNodePtr cNode = getNode("TWaldBoostLearner", xmlDocGetRootElement(cXml));
    if (!cNode)
    {
        cout << "No learner found" << endl;
        xmlFreeDoc(cXml);   // nothing references the document yet
        return 0;
    }

    cNode = getNode("WaldBoostClassifier", cNode);
    if (!cNode)
    {
        cout << "No classifier found" << endl;
        xmlFreeDoc(cXml);   // nothing references the document yet
        return 0;
    }

    // Try to construct strong classifier
    TClassifier * classifier = new TClassifier(cNode);
    if (!classifier->ok())
    {
        cerr << "Cannot initialize classifier" << endl;
        delete classifier;
        xmlFreeDoc(cXml);   // the classifier that read the document is gone
        return 0;
    }

    // NOTE(review): the document is intentionally not freed on success.
    // It is not visible here whether createContFeature() keeps pointers into
    // the XML tree. If it only copies attribute values, call xmlFreeDoc(cXml)
    // here too. xmlCleanupParser() is deliberately not called: it tears down
    // libxml2's global state for the whole process.
    return classifier;
}


//------------------------------------------------------------------------------

/// Scan every level of an image pyramid and collect detections.
///
/// 1. Resets the file-level evals/samplesScanned counters.
/// 2. Scans levels 0..levels()-1 in order, stopping once `detections` is full.
/// 3. Stores the mean number of weak hypotheses evaluated per scanned sample,
///    which getAverageFeaturesPerSample() later returns.
///
/// @param pyramid     pyramid to scan (dereferenced without a null check)
/// @param classifier  classifier to apply (dereferenced without a null check)
/// @param detections  pre-sized output buffer, filled from index 0
/// @return total number of detections written
int scanImagePyramid(TImagePyramid * pyramid, TClassifier * classifier, std::vector<TRect> & detections)
{
	evals = 0;
	samplesScanned = 0;

	unsigned found = 0;
	for (unsigned level = 0; level < pyramid->levels(); ++level)
	{
		// New detections are appended after those from previous levels.
		const unsigned added = classifier->scanFrame(pyramid->frame(level), detections, found, pyramid->getLevelFactor(level));
		found += added;

		if (found >= detections.size())
			break;	// output buffer full - no room for further levels
	}

	if (evals > 0)
		avgFeatures = static_cast<float>(evals) / samplesScanned;
	else
		avgFeatures = 0.0f;

	return found;
}


/// Mean number of weak hypotheses evaluated per scanned sample.
/// The value is stored by the most recent scanImagePyramid() call; it is 0
/// before any scan, or when no weak hypothesis was evaluated.
float getAverageFeaturesPerSample() { return avgFeatures; }


/// Number of windows passed to TClassifier::evalSample() since the counter
/// was last reset by scanImagePyramid(). Windows skipped by scanFrame() as
/// low-contrast are not counted.
unsigned getSamplesCount() { return samplesScanned; }

