#include <math.h>
#include <stdio.h>
#include <stdlib.h>
#include <algorithm>
#include <iostream>
#include <boost/numeric/ublas/io.hpp>

#include "l2_simplex_projection.hpp"

//#define MYDEBUG
//#define MYDEBUG_VERBOSE

CDiagQPSolver::~CDiagQPSolver()
{
    // Release everything the constructor allocated: the sentinel head of the
    // unset list, the hinge pool and the piece array. The three allocations are
    // independent, so the release order does not matter.
    // NOTE(review): the class owns raw arrays; unless the header disables
    // copy construction/assignment, a copied solver would double-delete them.
    delete unsetList;
    delete [] dataPool;
    delete [] pieceList;
}


/*
 * Set up the solver for
 *     min  1/2 * sum_i dsq[i] * (x[i] - offset[i])^2
 *     s.t. lower_bound[i] <= x[i] <= upper_bound[i],   sum_i sigma[i] * x[i] = _z
 *
 * For a multiplier value t the unconstrained minimiser of component i is
 * x[i] = t * sigma[i] / dsq[i] + offset[i] (see Solve). Piece i stores:
 *   - x_low / x_high: the values of t at which x[i] reaches its bounds (ordered
 *     by the sign of sigma[i]);
 *   - y_low / y_high: sigma[i] * (x[i] - offset[i]) at those bounds;
 *   - slope = sigma[i]^2 / dsq[i], the slope between the two hinges.
 * The root of sum_i y_i(t) = z (with z shifted by -sum_i sigma[i]*offset[i])
 * gives the optimal multiplier.
 *
 * NOTE: only raw pointers into the caller's vectors are kept, so those vectors
 * must outlive the solver and must not be resized while it is in use.
 * Assumes dsq[i] > 0 and sigma[i] != 0 for every i. A zero sigma divides by
 * zero below (and trips the asserts in debug builds).
 * Exits the process if some lower bound exceeds its upper bound, or if the
 * equality constraint cannot be met within the box.
 */
CDiagQPSolver::CDiagQPSolver(int _n, vector<double> & _lower_bound, vector<double> & _upper_bound,
                            vector<double> & _dsq, vector<double> & _sigma, vector<double> & _offset, 
                            double _z) : n(_n)
{
    // Range of sum_i sigma[i]*x[i] over the box, used for the feasibility test.
    double min_sum = 0.0, max_sum = 0.0;

    dsq = &(_dsq.data()[0]);
    offset = &(_offset.data()[0]);
    sigma = &(_sigma.data()[0]);

    lower_bound = &(_lower_bound.data()[0]);
    upper_bound = &(_upper_bound.data()[0]);
    pieceList = new CPiece [n];
    dataPool = new double [2*n];   // every piece contributes two hinges initially
    z = _z;
    ori_z = _z;

    double *ptr = dataPool;
    
    for (int i = 0; i < n; i ++)
    {
        if (lower_bound[i] > upper_bound[i])
        {
            std::cout << "Lower bounds exceeds upper bound for index " << i << std::endl;
            exit(1);
        }

        pieceList[i].index = i;
        // Solve/evaluateObj/processPiece require every piece to start in
        // status 0 (both hinges still candidates), so do not rely on CPiece's
        // default constructor for this.
        pieceList[i].status = 0;

        pieceList[i].slope = sigma[i] * sigma[i] / dsq[i];
        if (sigma[i] > 0)
        {
            pieceList[i].x_low = dsq[i] * (lower_bound[i] - offset[i]) / sigma[i];
            pieceList[i].x_high = dsq[i] * (upper_bound[i] - offset[i]) / sigma[i];
            pieceList[i].y_low = sigma[i] * (lower_bound[i] - offset[i]);
            pieceList[i].y_high = sigma[i] * (upper_bound[i] - offset[i]);
            min_sum += sigma[i] * lower_bound[i];
            max_sum += sigma[i] * upper_bound[i];
        }
        else
        {
            // A negative sigma reverses the order: the upper bound maps to the
            // smaller multiplier value.
            pieceList[i].x_low = dsq[i] * (upper_bound[i] - offset[i]) / sigma[i];
            pieceList[i].x_high = dsq[i] * (lower_bound[i] - offset[i]) / sigma[i];
            pieceList[i].y_low = sigma[i] * (upper_bound[i] - offset[i]);
            pieceList[i].y_high = sigma[i] * (lower_bound[i] - offset[i]);
            max_sum += sigma[i] * lower_bound[i];
            min_sum += sigma[i] * upper_bound[i];
        }

        // Consistency: both hinge points lie on the line y = slope * x.
        assert(fabs(pieceList[i].y_low - pieceList[i].x_low * pieceList[i].slope) < 1e-5);
        assert(fabs(pieceList[i].y_high - pieceList[i].x_high * pieceList[i].slope) < 1e-5);

        pieceList[i].next = pieceList + i + 1;
        *(ptr++) = pieceList[i].x_low;
        *(ptr++) = pieceList[i].x_high;
        z -= sigma[i] * offset[i];   // shift the target into the y-coordinates

#ifdef MYDEBUG_VERBOSE
        pieceList[i].Print();
#endif
        
    }

#ifdef MYDEBUG_VERBOSE
    std::cout << "z = " << z << std::endl;
#endif

    if (min_sum > _z || max_sum < _z)
    {
        std::cout << "Feasible region is empty.  Quitting." << std::endl;
        exit(1);
    }

    // unsetList is a sentinel head; its ->next chains all unresolved pieces.
    unsetList = new CPiece;
    if (n > 0)
    {
        pieceList[n - 1].next = 0;     // terminate the chain built in the loop
        unsetList->next = pieceList;
    }
    else
    {
        // No pieces: avoid writing pieceList[-1] and leave the list empty so
        // Solve never dereferences the zero-length array.
        unsetList->next = 0;
    }
}


// Write `count` doubles to fp, one per line. %.17g keeps enough significant
// digits for every IEEE double to round-trip exactly when read back.
static void writeDoubleColumn(FILE *fp, const double *values, int count)
{
    for (int i = 0; i < count; i++)
        fprintf(fp, "%.17g\n", values[i]);
}

/*
 * Dump the problem data to a text file, one value per line, in this order:
 *   n, the original (unshifted) right-hand side z, then n values each of
 *   lower_bound, upper_bound, dsq, sigma, offset.
 * Values are written at full double precision so the exact problem can be
 * reconstructed. If the file cannot be opened, an error is reported on stderr
 * and nothing is written.
 */
void CDiagQPSolver::dumpToFile(char * filename)
{
    FILE *fp = filename ? fopen(filename, "w") : 0;
    if (!fp)
    {
        fprintf(stderr, "dumpToFile: cannot open '%s' for writing\n",
                filename ? filename : "(null)");
        return;
    }

    fprintf(fp, "%d\n", n);
    fprintf(fp, "%.17g\n", ori_z);
    writeDoubleColumn(fp, lower_bound, n);
    writeDoubleColumn(fp, upper_bound, n);
    writeDoubleColumn(fp, dsq, n);
    writeDoubleColumn(fp, sigma, n);
    writeDoubleColumn(fp, offset, n);
    fclose(fp);
}


/*
 * Evaluate the piecewise-linear constraint function f at the trial multiplier
 * median_x. (Despite the name, this is not the QP objective.)
 *   f(median_x) = fixedSum + median_x * cumulateDsq
 *               + sum over unset pieces of their contribution at median_x.
 * fixedSum and cumulateDsq hold the contributions of pieces already resolved
 * as saturated or linear. The rules per unset piece follow its status:
 *   0 - clamp between y_low and y_high, otherwise linear;
 *   1 - only the lower hinge remains: y_low below it, linear above;
 *   2 - only the upper hinge remains: y_high above it, linear below.
 * Returns f(median_x). An empty unset list is handled and yields the resolved
 * part only.
 */
double CDiagQPSolver::evaluateObj()
{
    double result = fixedSum + median_x * cumulateDsq;
    double innerDsq = 0.0;   // total slope of unset pieces in their linear region

    for (CPiece *ptr_cur = unsetList->next; ptr_cur; ptr_cur = ptr_cur->next)
    {
        switch (ptr_cur->status)
        {
        case 0:
            if (median_x <= ptr_cur->x_low)
                result += ptr_cur->y_low;
            else if (median_x >= ptr_cur->x_high)
                result += ptr_cur->y_high;
            else
                innerDsq += ptr_cur->slope;
            break;
        case 1:
            if (median_x <= ptr_cur->x_low)
                result += ptr_cur->y_low;
            else
                innerDsq += ptr_cur->slope;
            break;
        case 2:
            if (median_x >= ptr_cur->x_high)
                result += ptr_cur->y_high;
            else
                innerDsq += ptr_cur->slope;
            break;
        }
    }

    return result + median_x * innerDsq;
}


// Classify one piece of the unset list against the current trial multiplier
// median_x. `larger` must already hold whether f(median_x) > z, where f is the
// function evaluated by evaluateObj(). f is nondecreasing when every slope is
// >= 0, so larger == true means the root lies below median_x, and
// larger == false means it lies above it.
//
// Piece status codes:
//   0 - neither hinge (x_low, x_high) has been ruled out yet;
//   1 - the root is known to lie below x_high, so only x_low still matters;
//   2 - the root is known to lie above x_low, so only x_high still matters.
//
// Side effects on the solver state:
//   - a piece saturated at one bound for every remaining candidate root adds
//     its y_low or y_high to fixedSum;
//   - a piece inside its linear region for every remaining candidate root adds
//     its slope to cumulateDsq;
//   - each hinge that can still bracket the root is appended to dataPool at
//     dataPoolIndex, forming the candidate set for the next median.
//
// Returns true if the piece is still unset (keep it in the unset list), or
// false if it has been resolved (the caller unlinks it).
bool CDiagQPSolver::processPiece(CPiece * ptr_cur)
{
    switch (ptr_cur->status)
    {
    case 0:
        if (larger)
        {
            // Root lies below median_x.
            if (median_x <= ptr_cur->x_low)  // root < x_low: saturated at y_low, delete from unset
            {
                fixedSum += ptr_cur->y_low;
                return false;
            }
            else if (median_x <= ptr_cur->x_high)
            {
                // Root < median_x <= x_high: the upper hinge can no longer matter.
                ptr_cur->status = 1;
                dataPool[dataPoolIndex ++] = ptr_cur->x_low;
            }
            else
            {
                // Both hinges lie below median_x; keep both as candidates.
                dataPool[dataPoolIndex ++] = ptr_cur->x_low;
                dataPool[dataPoolIndex ++] = ptr_cur->x_high;
            }
        }
        else
        {
            // Root lies above median_x.
            if (median_x >= ptr_cur->x_high)  // root > x_high: saturated at y_high, delete from unset
            {
                fixedSum += ptr_cur->y_high;
                return false;
            }
            else if (median_x >= ptr_cur->x_low)
            {
                // x_low <= median_x < root: the lower hinge can no longer matter.
                ptr_cur->status = 2;
                dataPool[dataPoolIndex ++] = ptr_cur->x_high;
            }
            else
            {
                // Both hinges lie above median_x; keep both as candidates.
                dataPool[dataPoolIndex ++] = ptr_cur->x_low;
                dataPool[dataPoolIndex ++] = ptr_cur->x_high;
            }
        }
        break;

    case 1:
        // Only x_low is undecided; above it the piece is linear.
        if (larger)
        {
            if (median_x <= ptr_cur->x_low)  // root < x_low: saturated at y_low, delete from unset
            {
                fixedSum += ptr_cur->y_low;
                return false;
            }
        }
        else
        {
            if (median_x >= ptr_cur->x_low)  // root > x_low: linear region, delete from unset
            {
                cumulateDsq += ptr_cur->slope;
                return false;
            }
        }
        // x_low is still on the root's side of median_x.
        dataPool[dataPoolIndex ++] = ptr_cur->x_low;
        break;

    case 2:
        // Only x_high is undecided; below it the piece is linear.
        if (larger)
        {
            if (median_x <= ptr_cur->x_high)  // root < x_high: linear region, delete from unset
            {
                cumulateDsq += ptr_cur->slope;
                return false;
            }
        }
        else
        {
            if (median_x >= ptr_cur->x_high)  // root > x_high: saturated at y_high, delete from unset
            {
                fixedSum += ptr_cur->y_high;
                return false;
            }
        }
        // x_high is still on the root's side of median_x.
        dataPool[dataPoolIndex ++] = ptr_cur->x_high;
    }

    return true;
}



/*
 * Find the optimal multiplier by median-of-hinges bracketing, then write the
 * primal solution into `solution`.
 *
 * Each iteration selects the median hinge of dataPool (std::nth_element) and
 * evaluates f there. It stops early if |f - z| < 1e-5. Otherwise every unset
 * piece is reclassified by processPiece, which unlinks resolved pieces and
 * refills dataPool with the remaining candidate hinges.
 *
 * The final component is x[i] = median_x * sigma[i] / dsq[i] + offset[i],
 * clamped to [lower_bound[i], upper_bound[i]].
 *
 * NOTE: the search consumes the unset list, the piece statuses and dataPool,
 * so Solve is meant to be called once per constructed solver.
 * Assumes `solution` already holds at least n entries (it is written with
 * operator(), not resized). TODO confirm against callers.
 */
void CDiagQPSolver::Solve(vector<double> & solution)
{
    CPiece *ptr_pre, *ptr_cur;

    cumulateDsq = 0.0;
    fixedSum = 0.0;
    dataPoolIndex = 2 * n;

    while (unsetList->next)
    {
        // first get the median of current hinge set
        int median_index = (dataPoolIndex - 1) / 2;
        std::nth_element(dataPool, dataPool + median_index, dataPool + dataPoolIndex);
        
        median_x = dataPool[median_index];
        double median_fx = evaluateObj();

        if (fabs(median_fx - z) < 1e-5)     {
            break;
        }

        larger = median_fx > z;

        // Delete all the pieces that have been set
        ptr_pre = unsetList;
        ptr_cur = unsetList->next;
        dataPoolIndex = 0;

        while (ptr_cur)
        {
            if (processPiece(ptr_cur)) 
            {
                ptr_pre = ptr_cur;
                ptr_cur = ptr_cur->next;                     
            }
            else    // set already, kick out of the unset list
            {
                ptr_pre->next = ptr_cur->next;
                ptr_cur = ptr_pre->next;
            }
        }
    }

    // Every piece is resolved: fixedSum + median_x * cumulateDsq = z.
    // If no piece ended in its linear region (cumulateDsq == 0), f is flat on
    // the final bracket and the division would give Inf/NaN, pushing every
    // component to a bound or to NaN. This can happen when rounding kept
    // |f(median_x) - z| above the 1e-5 quick-exit tolerance. In that case keep
    // the last evaluated median, at which every piece already sits on the side
    // it was fixed to.
    if (!unsetList->next && cumulateDsq != 0.0)
    {
        median_x = (z - fixedSum) / cumulateDsq;
    }

    // Compute the ultimate solution
    double func_val = 0.0;
    for (int i = 0; i < n; i ++)
    {
        if (median_x <= pieceList[i].x_low)
        {
            if (sigma[i] > 0)
                solution(i) = lower_bound[i];
            else
                solution(i) = upper_bound[i];
        }
        else if (median_x >= pieceList[i].x_high)
        {
            if (sigma[i] > 0)
                solution(i) = upper_bound[i];
            else
                solution(i) = lower_bound[i];
        }
        else
            solution(i) = median_x * sigma[i] / dsq[i] + offset[i];

        func_val += dsq[i] * (solution(i) - offset[i]) * (solution(i) - offset[i]);
    }

#ifdef MYDEBUG
    std::cout << "Solution:\n";
    std::cout << solution << std::endl;

    std::cout << "Final median_x = " << median_x << std::endl;  
    check_result(solution);
    std::cout << "Minimum function value = " << func_val / 2.0 << std::endl;    
#endif

}


/*
 * Diagnostic: print two residuals for the computed solution.
 *  1. f(median_x) - z, where f sums every piece's contribution (clamped to
 *     y_low / y_high outside its hinges, slope * median_x between them) and z
 *     is the offset-shifted target.
 *  2. sum_i solution(i) * sigma[i] - ori_z, the violation of the original
 *     linear equality constraint.
 */
void CDiagQPSolver::check_result(vector<double> & solution)
{
    double pieceSum = 0.0;
    for (CPiece *piece = pieceList; piece < pieceList + n; ++piece)
    {
        double contribution;
        if (median_x <= piece->x_low)
            contribution = piece->y_low;
        else if (median_x >= piece->x_high)
            contribution = piece->y_high;
        else
            contribution = median_x * piece->slope;
        pieceSum += contribution;
    }
    printf("Discrepancy of inner piecewise linear function root finding = %g\n", pieceSum - z);

    double constraintSum = 0.0;
    for (int k = 0; k < n; ++k)
        constraintSum += solution(k) * sigma[k];
    printf("Discrepancy of linear constraint = %g\n", constraintSum - ori_z);
}
