/* Copyright (c) 2009 NICTA 
 * All rights reserved. 
 * 
 * The contents of this file are subject to the Mozilla Public License 
 * Version 1.1 (the "License"); you may not use this file except in 
 * compliance with the License. You may obtain a copy of the License at 
 * http://www.mozilla.org/MPL/ 
 * 
 * Software distributed under the License is distributed on an "AS IS" 
 * basis, WITHOUT WARRANTY OF ANY KIND, either express or implied. See the 
 * License for the specific language governing rights and limitations 
 * under the License. 
 * 
 * Authors: Xinhua Zhang (xinhua.zhang81@gmail.com)
 *
 * Created: (19/05/2009) 
 *
 * Last Updated: (20/05/2009)   
 */

#ifndef _NESTEROV07_CPP_
#define _NESTEROV07_CPP_

#include "common.hpp"
#include "Nesterov07.hpp"
#include "timer.hpp"
#include "configuration.hpp"
#include "loss.hpp"
#include "l2_simplex_projection.hpp"
#include "l2_box_projection.hpp"
#include "hingeloss.hpp"

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <cstdlib>
#include <fstream>
#include <iostream>
#include <sstream>

using namespace std;


/**  Constructor.
 *
 *   Installs the built-in defaults, sizes every per-example work vector with
 *   loss->getNumExample() rows and the weight buffer with loss->getNumFeature()
 *   rows, and then lets ConfirmProgramParameters() apply the configuration.
 *
 *   @param model  model forwarded to the CSolver base
 *   @param loss   hinge loss forwarded to the CSolver base; also queried for
 *                 the number of examples and features
 */
CNesterov07::CNesterov07(CModel *model, CHingeLoss* loss)
   : CSolver(model, loss),
     verbosity(0),
     maxNumOfIter(10000),
     lambda(1.0),
     checkpointPrefix("model.checkpoint"),
     checkpointInterval(0),
     bias_value(0.0),
     checkpointMode(KEEP_LATEST),
     // work vectors holding one entry per training example
     offset(loss->getNumExample(), 1, SML::DENSE),
     grad_y(loss->getNumExample(), 1, SML::DENSE),
     grad_t(loss->getNumExample(), 1, SML::DENSE),
     cumulantGrad(loss->getNumExample(), 1, SML::DENSE),
     xk(loss->getNumExample(), 1, SML::DENSE),
     TLy(loss->getNumExample(), 1, SML::DENSE),
     zk(loss->getNumExample(), 1, SML::DENSE),
     ytemp(loss->getNumExample(), 1, SML::DENSE),
     // weight buffer holding one entry per feature
     cur_w(loss->getNumFeature(), 1, SML::DENSE)
{
   // Solver bookkeeping: nothing accumulated yet, best objectives start at a huge value.
   A = 0.0;
   minDualObj = minPrimalObj = 1e99;
   num_example = loss->getNumExample();

   // Apply the user's configuration on top of the defaults set above.
   ConfirmProgramParameters();
}


/**  Destructor.
 *
 *   Empty: performs no explicit cleanup.
 */
CNesterov07::~CNesterov07() {}

/**  Debugging aid: verify that a vector of dual variables is feasible.
 *
 *   The check is compiled out by default, in which case this function does
 *   nothing. Build with NESTEROV07_CHECK_VAR_VALID defined to enable it.
 *   When enabled, the process is terminated with EXIT_FAILURE as soon as
 *     - an entry of var lies more than 1e-6 outside [lower_bound[i], upper_bound[i]], or
 *     - useBias is set and the inner product <var, sigma> differs from z by more than 1e-5.
 *
 *   @param var          column vector to validate (one entry per element of lower_bound)
 *   @param lower_bound  per-entry lower bounds; its size determines how many entries are checked
 *   @param upper_bound  per-entry upper bounds
 *   @param sigma        coefficients of the linear constraint
 *   @param z            right-hand side of the linear constraint
 *   @param varName      label printed in the out-of-bound message
 */
void CNesterov07::checkVarValid(TheMatrix &var, 
                        ublas::vector<double> &lower_bound, 
                        ublas::vector<double> &upper_bound,
                        ublas::vector<double> &sigma, double z, char * varName)
{
#ifdef NESTEROV07_CHECK_VAR_VALID
    double innerProd = 0.0;
    const int numEntries = lower_bound.size();
    for (int i = 0; i < numEntries; i ++)
    {
        double entry;
        var.Get(i, entry);
        if (entry < lower_bound[i] - 1e-6 || entry > upper_bound[i] + 1e-6)
        {
            std::cout<<varName<<"["<<i<<"] out of bound: "<< entry <<std::endl;
            exit(EXIT_FAILURE);
        }
        innerProd += entry * sigma[i];
    }
    // The linear constraint is only checked when a bias term is in use.
    if (useBias && fabs(innerProd - z) > 1e-5)
    {
        std::cout<<"Linear constraint violated."<<std::endl;
        exit(EXIT_FAILURE);
    }
#else
    // Check disabled: silence unused-parameter warnings.
    (void) var;
    (void) lower_bound;
    (void) upper_bound;
    (void) sigma;
    (void) z;
    (void) varName;
#endif
}
        

/**  Start training/learning a model w.r.t. the loss object (and the data supplied to it).
 *
 */
void CNesterov07::Train()
{
    CTimer totalTime;             // total runtime of the training
    CTimer innerSolverTime;       // time for inner optimization (e.g., QP or LP)
    CTimer lossAndGradientTime;   // time for loss and gradient computation

    unsigned int iter = 0;        // iteration count
    double varScaleFactor = num_example / vscaleRaw;

    int exitFlag, i;

    TheMatrix &w = _model->GetW();
    ublas::vector<double> dsq1(num_example), dsq2(num_example);
    ublas::vector<double> lower_bound(num_example), upper_bound(num_example), sigma(num_example);
    ublas::vector<double> solution(num_example), offset_vec(num_example);
    double z = 0.0;
    double useless, temp1, temp2;

    _loss->setVarScaleFactor(varScaleFactor);
    _loss->getLabels(sigma);

    for(i = 0; i < num_example; i ++)
    {
        dsq1(i) = dsq2(i) = 1.0;
        lower_bound(i) = 0.0;
        upper_bound(i) = 1.0 / lambda / varScaleFactor;
    }
    std::cout<<"upper bound = " << upper_bound(0) << std::endl;

    int numPos = 0, numNeg = 0;
    for(i = 0; i < num_example; i ++)
    {
        if (sigma[i] > 0)    numPos ++;
        else                 numNeg ++;
    }
    int n_small = numPos > numNeg ? numNeg : numPos;
    double v_pos = n_small / varScaleFactor / (2.0 * numPos * lambda);
    double v_neg = n_small / varScaleFactor / (2.0 * numNeg * lambda);
    v_neg = v_pos = upper_bound(0);
//     v_neg = v_pos = 0.0;

    for(i = 0; i < num_example; i ++)
    {
        if (sigma[i] > 0)    xk.Set(i, v_pos);
        else                 xk.Set(i, v_neg);
    }

    checkVarValid(xk, lower_bound, upper_bound, sigma, z, "xk");
    zk.Assign(xk);
    cumulantGrad.Assign(xk);

    // start training
    totalTime.Start();

    do
    {
//         std::cout<<"Current Ltil = "<<Ltil<<std::endl;
        do 
        {
            a = (1.0 + sqrt(1 + 2 * A * Ltil)) / Ltil;
            ytemp.Zero();
            ytemp.ScaleAdd(A / (A + a), xk);
            ytemp.ScaleAdd(a / (A + a), zk);
            checkVarValid(ytemp, lower_bound, upper_bound, sigma, z, "ytemp");

//             std::cout<<"Showing ytemp\n";   ytemp.Print();

            lossAndGradientTime.Start();
            _loss->ComputeDualLossAndGradient(useless, grad_y, ytemp);
            lossAndGradientTime.Stop();

            offset.Assign(ytemp);
            offset.ScaleAdd(-1.0 / Ltil, grad_y);
            offset.GetRow(0, num_example, &(offset_vec[0]));

            if (useBias)
            {
                CDiagQPSolver innerSolver1(num_example, lower_bound, upper_bound,
                                            dsq1, sigma, offset_vec, z);

                innerSolver1.Solve(solution);
            }
            else
            {
                CDiagBoxQPSolver innerSolver1(num_example, lower_bound, upper_bound, offset_vec);
                innerSolver1.Solve(solution);
            }

            TLy.SetRow(0, num_example, &(solution.data()[0]));

            checkVarValid(TLy, lower_bound, upper_bound, sigma, z, "TLy");

            //////////////////
            // Now let's do a further mathematical check
//             TheMatrix tmpMat(num_example, 1, SML::DENSE);
//             tmpMat.Assign(curGrad);
//             tmpMat.ScaleAdd(-Ltil, ytemp);
//             tmpMat.ScaleAdd(Ltil, TLy);
//             std::cout<<" tmpMat.norm2 = " << tmpMat.Norm2() << std::endl;
//             std::cin.get();














            //////////////////

            lossAndGradientTime.Start();
            _loss->ComputeDualLossAndGradient(useless, grad_t, TLy);
            lossAndGradientTime.Stop();

            double ytempNorm2 = ytemp.Norm2();

            offset.Assign(grad_t);

            offset.Minus(grad_y);
            offset.ScaleAdd(Ltil, ytemp);
            offset.ScaleAdd(-Ltil, TLy);

            ytemp.Minus(TLy);
            ytemp.Dot(offset, temp1);
            temp2 = offset.Norm2();
//             std::cout << "inner prod = " << temp1 << ", curGrad.norm2 = " << temp2 
//                       << ", ytemp.norm2 = " << ytempNorm2 << ", TLy norm2 = " << TLy.Norm2()
//                       << "\nA = " << A << ", a = " << a 
//                       << ", Ltil = " << Ltil <<std::endl;
//             std::cout << "-----------------------------\n";
            if (temp1 >= temp2 * temp2 / Ltil)
                break;
            Ltil *= gamma_u;
//             std::cout<<".";
        } while (true);

//         std::cout<<"One loop ends\n";
//         std::cin.get();
        A += a;
        xk.Assign(TLy);
        checkVarValid(xk, lower_bound, upper_bound, sigma, z, "xk_later");

        Ltil /= gamma_d;

        cumulantGrad.ScaleAdd(-a, grad_t);
        cumulantGrad.GetRow(0, num_example, &(offset_vec[0]));

        if (useBias)
        {
            CDiagQPSolver innerSolver1(num_example, lower_bound, upper_bound,
                                        dsq1, sigma, offset_vec, z);

            innerSolver1.Solve(solution);
        }
        else
        {
            CDiagBoxQPSolver innerSolver1(num_example, lower_bound, upper_bound, offset_vec);
            innerSolver1.Solve(solution);
        }

        zk.SetRow(0, num_example, &(solution.data()[0]));

        checkVarValid(zk, lower_bound, upper_bound, sigma, z, "zk");

        _loss->ComputeDualLossAndGradient(dualObj, grad_t, xk);
        dualObj *= lambda;
        
         if(minDualObj > dualObj)    {
             minDualObj = dualObj;
         }

        xk.Scale(varScaleFactor * 1.0 / num_example);
        primalObj = _loss->ComputePrimalObj(xk, lambda, cur_w, numPos, useBias, bias_value);
        xk.Scale(num_example * 1.0 / varScaleFactor);

        if (primalObj < minPrimalObj)
        {
            // w is always set to the minimizer of primal
            minPrimalObj = primalObj;
            w.Assign(cur_w);
        }

        // Display details of each iteration
        DisplayIterationInfo(iter, totalTime.CurrentCPUTotal());

        // Save model obtained in previous iteration
        SaveCheckpointModel(iter);

        //      Check if termination criteria satisfied
        exitFlag = CheckTermination(iter, fabs(primalObj + dualObj));  

        iter ++;

    } while(!exitFlag);

    totalTime.Stop();

    // Display after-training details
    DisplayAfterTrainingInfo(iter, lossAndGradientTime, innerSolverTime, totalTime);
}


/**   Validate program parameters set in Configuration.
 *
 *    Copies the "Nesterov07.*" options (plus "Model.modelFile" when
 *    checkpointing is enabled) from the Configuration singleton into the
 *    matching members and echoes L0, vscaleRaw, gamma_d and gamma_u to stdout.
 *
 *    Optional, current member value kept when absent: verbosity, maxNumOfIter,
 *    lambda, checkpointInterval, checkpointPrefix, checkpointMode.
 *    Required: vscaleRaw, gamma_u, gamma_d, L0, bias, epsilonTol. The first
 *    three have no default set in this file, so leaving them out would mean
 *    training with indeterminate values.
 *    Required once checkpointInterval is given: testConfFilename,
 *    testSummaryFilename, predExeFilename and Model.modelFile.
 *
 *    @throws CBMRMException if a required option is missing or a value is out of range
 */
void CNesterov07::ConfirmProgramParameters()
{
   Configuration &config = Configuration::GetInstance();  // make sure configuration file is read before this!
   
   if(config.IsSet("Nesterov07.verbosity")) 
      verbosity = config.GetInt("Nesterov07.verbosity");
   
   if(config.IsSet("Nesterov07.maxNumOfIter")) 
   {
      // Validate on a signed local so the range test does not depend on the
      // signedness of the member.
      const int requestedMaxIter = config.GetInt("Nesterov07.maxNumOfIter");
      if(requestedMaxIter <= 0)
         throw CBMRMException("Nesterov07.maxNumOfIter must be > 0\n","CNesterov07::ConfirmProgramParameters()");
      maxNumOfIter = requestedMaxIter;
   }
   
   if(config.IsSet("Nesterov07.lambda"))           
   {
      lambda = config.GetDouble("Nesterov07.lambda");
      if(lambda <= 0)
         throw CBMRMException("Nesterov07.lambda must be > 0\n","CNesterov07::ConfirmProgramParameters()");
   }

   if(config.IsSet("Nesterov07.vscaleRaw"))           
   {
       vscaleRaw = config.GetDouble("Nesterov07.vscaleRaw");
       if(vscaleRaw <= 0)
           throw CBMRMException("Nesterov07.vscaleRaw must be > 0\n","CNesterov07::ConfirmProgramParameters()");
   }
   else
       throw CBMRMException("Nesterov07.vscaleRaw must be set\n","CNesterov07::ConfirmProgramParameters()");

   // NOTE(review): Train() multiplies Ltil by gamma_u until its acceptance test
   // holds, so a value <= 1 may keep that loop from terminating - confirm
   // whether a stricter lower bound is wanted here.
   if(config.IsSet("Nesterov07.gamma_u"))           
   {
       gamma_u = config.GetDouble("Nesterov07.gamma_u");
       if(gamma_u <= 0)
           throw CBMRMException("Nesterov07.gamma_u must be > 0\n","CNesterov07::ConfirmProgramParameters()");
   }
   else
       throw CBMRMException("Nesterov07.gamma_u must be set\n","CNesterov07::ConfirmProgramParameters()");

   if(config.IsSet("Nesterov07.gamma_d"))           
   {
       gamma_d = config.GetDouble("Nesterov07.gamma_d");
       if(gamma_d <= 0)
           throw CBMRMException("Nesterov07.gamma_d must be > 0\n","CNesterov07::ConfirmProgramParameters()");
   }
   else
       throw CBMRMException("Nesterov07.gamma_d must be set\n","CNesterov07::ConfirmProgramParameters()");

   // L0 is the starting value of Ltil.
   if(config.IsSet("Nesterov07.L0"))           
   {
       Ltil = config.GetDouble("Nesterov07.L0");
       if(Ltil <= 0)
           throw CBMRMException("Nesterov07.L0 must be > 0\n","CNesterov07::ConfirmProgramParameters()");
   }
   else
       throw CBMRMException("Nesterov07.L0 must set\n","CNesterov07::ConfirmProgramParameters()");

   if(config.IsSet("Nesterov07.bias"))           
       useBias = config.GetBool("Nesterov07.bias");
   else
       throw CBMRMException("Nesterov07.bias must set\n","CNesterov07::ConfirmProgramParameters()");

   if(config.IsSet("Nesterov07.epsilonTol")) 
       epsilonTol = config.GetDouble("Nesterov07.epsilonTol");
   else
       throw CBMRMException("Nesterov07.epsilonTol must set\n","CNesterov07::ConfirmProgramParameters()");

   if(config.IsSet("Nesterov07.checkpointInterval")) 
   {
       // Same reasoning as for maxNumOfIter: range-check before storing.
       const int requestedInterval = config.GetInt("Nesterov07.checkpointInterval");
       if(requestedInterval < 1)
           throw CBMRMException("Nesterov07.checkpointInterval must be a positive integer!\n","CNesterov07::ConfirmProgramParameters()");
       checkpointInterval = requestedInterval;

       // Checkpointing runs an external prediction program, so everything it
       // needs must be configured as well.
       if(!config.IsSet("Nesterov07.testConfFilename"))
           throw CBMRMException("Nesterov07.testConfFilename must be set!\n","CNesterov07::ConfirmProgramParameters()");

       if(!config.IsSet("Nesterov07.testSummaryFilename"))
           throw CBMRMException("Nesterov07.testSummaryFilename must be set!\n","CNesterov07::ConfirmProgramParameters()");

       if(!config.IsSet("Model.modelFile"))
           throw CBMRMException("Model.modelFile must be set!\n","CNesterov07::ConfirmProgramParameters()");

       if(!config.IsSet("Nesterov07.predExeFilename"))
           throw CBMRMException("Nesterov07.predExeFilename must be set!\n","CNesterov07::ConfirmProgramParameters()");

       testConfFilename = config.GetString("Nesterov07.testConfFilename");
       testSummaryFilename = config.GetString("Nesterov07.testSummaryFilename");
       modelFilename = config.GetString("Model.modelFile");
       predExeFilename = config.GetString("Nesterov07.predExeFilename");
   }

   if(config.IsSet("Nesterov07.checkpointPrefix")) 
       checkpointPrefix = config.GetString("Nesterov07.checkpointPrefix");

   if(config.IsSet("Nesterov07.checkpointMode")) 
   {
       // Any value other than "LATEST" or "ALL" leaves the current mode untouched.
       string mode = config.GetString("Nesterov07.checkpointMode");
       if(mode == "LATEST")
           checkpointMode = KEEP_LATEST;
       else if(mode == "ALL")
           checkpointMode = KEEP_ALL;
   } 

   std::cout << "Got L0 = " << Ltil << std::endl;
   std::cout << "Got vscaleRaw = " <<  vscaleRaw << std::endl;
   std::cout << "Got gamma_d = " << gamma_d << ", gamma_u = " << gamma_u << std::endl;
}


/**   Print the progress of one outer iteration to stdout.
 *
 *    verbosity <= 0 : one dot per iteration, plus the iteration number whenever
 *                     it is a multiple of 100
 *    verbosity == 1 : one line with the dual/primal objectives, their gap, the
 *                     best objectives so far and Ltil
 *    verbosity >= 2 : the same line followed by the elapsed time
 *
 *    @param iter     index of the iteration being reported
 *    @param curTime  elapsed time to print at verbosity >= 2
 */
void CNesterov07::DisplayIterationInfo(unsigned int iter, double curTime)
{
   // std::fabs rather than unqualified abs: depending on which headers are in
   // scope, abs() may pick the int overload and truncate the value.
   const double dualityGap = std::fabs(dualObj + primalObj);

   if(verbosity <= 0) 
   {
      printf(".");
      if(iter%100 == 0) 
         printf("%u",iter);
   }
   else if(verbosity == 1)
   {
      printf("#%u   dual_obj %.6e, pobj = %.6e, dgap = %.6e, min_dobj = %.6e, min_pobj = %.6e, Ltil = %lg\n", 
             iter, dualObj, primalObj, dualityGap, minDualObj, minPrimalObj, Ltil);      
   }
   else
   {
      printf("#%u   dual_obj %.6e, pobj = %.6e, dgap = %.6e, min_dobj = %.6e, min_pobj = %.6e, Ltil = %lg  time %0.6e\n", 
             iter, dualObj, primalObj, dualityGap, minDualObj, minPrimalObj, Ltil, curTime); 
   } 
   fflush(stdout);
}


void CNesterov07::SaveCheckpointModel(unsigned int iter)
{
    if(checkpointInterval > 0 && (iter == 1 || iter % checkpointInterval == 0))
    {
         _model->bias = bias_value;
         _model->Save(modelFilename);

         string cmd = predExeFilename + " " + testConfFilename;
         system(cmd.c_str());

         ifstream res_file;
         res_file.open(testSummaryFilename.c_str(), ifstream::in);
         res_file >> test_accuracy;
         res_file.close();

         printf("#%d   test_acc %.6e\n", iter, test_accuracy);   

//       if(checkpointMode == KEEP_LATEST)
//          _model->Save(checkpointPrefix);
//       else 
//       {
//          ostringstream oss;
//          oss << checkpointPrefix << "." << iter;
//          _model->Save(oss.str());
//       }
   }
}


/**   Decide whether training should stop.
 *
 *    The iteration limit is tested before the duality gap, so it wins when both
 *    conditions hold.
 *
 *    @param iter      index of the iteration that has just finished
 *    @param dual_gap  current duality gap
 *    @return 5 if iter has reached maxNumOfIter, 2 if dual_gap < epsilonTol,
 *            0 if training should continue
 */
int CNesterov07::CheckTermination(unsigned int iter, double dual_gap)
{
   int status = 0;   // 0 = keep going

   if(iter >= maxNumOfIter)
   {
      printf("\nProgram status: Exceeded maximum number of iterations (%d)!\n", 
             maxNumOfIter);
      status = 5;
   }
   else if(dual_gap < epsilonTol)
   {
      printf("\nProgram status: Converged. (dual_gap < epsilonTol : %.6e < %.6e)", 
             dual_gap, epsilonTol);
      status = 2;
   }

   return status;
}



/**   Print the end-of-training summary to stdout.
 *
 *    Shows the legend (verbosity >= 1), the number of iterations, the best dual
 *    and primal objective values, three norms of the model's weight vector and
 *    the timing profile.
 *
 *    @param iter                 number of iterations performed
 *    @param lossAndGradientTime  timer accumulated around loss/gradient evaluations
 *    @param innerSolverTime      timer accumulated around the inner solver
 *    @param totalTime            timer covering the whole training run
 */
void CNesterov07::DisplayAfterTrainingInfo(unsigned int iter, CTimer& lossAndGradientTime,
                                      CTimer& innerSolverTime, CTimer& totalTime)
{
   // legends
   if(verbosity >= 1) 
   {
      printf("\n[Legends]\n");
      if(verbosity > 1)
         printf("dual_obj: dual objective function value");
   }
   
   double norm1, norm2, norminf;
   TheMatrix &w = _model->GetW();

   w.Norm1(norm1);
   w.Norm2(norm2);
   w.NormInf(norminf);
   
   // The backslash is escaped: a bare "\a" is the BEL control character, so the
   // unescaped form printed a bell followed by "lpha".
   printf("\nNote: the final \\alpha is the w_t where J(w_t) is the smallest.\n");
   printf("No. of iterations:  %u\n",iter);
   printf("Min Dual obj. val.: %.6e\n",  minDualObj);
   printf("Min Primal obj. val.: %.6e\n",  minPrimalObj);
   printf("|a|_1:            %.6e\n",norm1);
   printf("|a|_2:            %.6e\n",norm2);
   printf("|a|_oo:           %.6e\n",norminf);
   
   // display timing profile
   printf("\nCPU seconds in:\n");
   printf("1. loss and gradient: %8.2f\n", lossAndGradientTime.CPUTotal());
   printf("2. solver:            %8.2f\n", innerSolverTime.CPUTotal()); 
   printf("               Total: %8.2f\n", totalTime.CPUTotal());
   printf("Wall-clock total:     %8.2f\n", totalTime.WallclockTotal());
}


// int aux_less_than(const void * a, const void * b)
// {
//     struct aux_bias * ptr1 = (struct aux_bias *) a;
//     struct aux_bias * ptr2 = (struct aux_bias *) b;
// 
//     return ptr1->cross_value < ptr2->cross_value ? 1 : -1;
// }

#endif
