/* Copyright (c) 2009 NICTA
 * All rights reserved.  
 * 
 * The contents of this file are subject to the Mozilla Public License 
 * Version 1.1 (the "License"); you may not use this file except in 
 * compliance with the License. You may obtain a copy of the License at 
 * http://www.mozilla.org/MPL/ 
 * 
 * Software distributed under the License is distributed on an "AS IS" 
 * basis, WITHOUT WARRANTY OF ANY KIND, either express or implied. See the 
 * License for the specific language governing rights and limitations 
 * under the License. 
 * 
 * Authors: Xinhua Zhang (xinhua.zhang81@gmail.com)
 *
 * Created: (19/05/2009) 
 *
 * Last Updated: (20/05/2009)   
 */

#ifndef _NESTEROV83_CPP_
#define _NESTEROV83_CPP_

#include "common.hpp"
#include "Nesterov83.hpp"
#include "timer.hpp"
#include "configuration.hpp"
#include "loss.hpp"
#include "l2_simplex_projection.hpp"
#include "l2_box_projection.hpp"
#include "hingeloss.hpp"

#include <fstream>
#include <sstream>


using namespace std;


/**  Constructor.
 *
 *   Forwards the model and loss to CSolver, installs the default settings
 *   (verbosity 0, at most 10000 iterations, lambda 1.0, checkpoint prefix
 *   "model.checkpoint", checkpoint interval 100000, KEEP_LATEST) and
 *   constructs every work vector with (loss->getNumExample(), 1, SML::DENSE).
 *   The defaults are then overridden from the Configuration singleton by
 *   ConfirmProgramParameters(), which throws CBMRMException on an invalid
 *   value or when Nesterov83.L is missing (L has no default here).
 *
 *   @param model  model handed to the CSolver base class
 *   @param loss   hinge loss; must be non-null because it is dereferenced
 *                 in the initializer list
 */
CNesterov83::CNesterov83(CModel *model, CHingeLoss* loss)
   : CSolver(model, loss),
     verbosity(0),
     maxNumOfIter(10000),
     lambda(1.0),
     checkpointPrefix("model.checkpoint"),
     checkpointInterval(100000),
     checkpointMode(KEEP_LATEST),
     offset(loss->getNumExample(), 1, SML::DENSE),
     curGrad(loss->getNumExample(), 1, SML::DENSE),
     cumulantGrad(loss->getNumExample(), 1, SML::DENSE),
     xk(loss->getNumExample(), 1, SML::DENSE),
     yk(loss->getNumExample(), 1, SML::DENSE),
     zk(loss->getNumExample(), 1, SML::DENSE),
     bestyk(loss->getNumExample(), 1, SML::DENSE)
{
   // Read user overrides first; this is also where L gets its value.
   ConfirmProgramParameters();
   num_example = loss->getNumExample();
   // Sentinel "worse than anything": Train() keeps the smallest dual objective seen.
   minDualObj = 1e99;
}


/**  Destructor.
 *
 *   Performs no explicit cleanup of its own.
 */
CNesterov83::~CNesterov83() { /* intentionally empty */ }



/**  Start training/learning a model w.r.t. the loss object (and the data supplied to it).
 *
 */
void CNesterov83::Train()
{
    CTimer totalTime;             // total runtime of the training
    CTimer innerSolverTime;       // time for inner optimization (e.g., QP or LP)
    CTimer lossAndGradientTime;   // time for loss and gradient computation

    unsigned int iter = 0;        // iteration count
    double varScaleFactor = num_example * 2;

    int exitFlag, i;

    TheMatrix &w = _model->GetW();
    ublas::vector<double> dsq1(num_example), dsq2(num_example);
    ublas::vector<double> lower_bound(num_example), upper_bound(num_example), sigma(num_example);
    ublas::vector<double> solution(num_example), offset_vec(num_example);
    double z = 0.0;
    double useless, temp;

    _loss->setVarScaleFactor(varScaleFactor);
    _loss->getLabels(sigma);
    cumulantGrad.Zero();

    for(i = 0; i < num_example; i ++)
    {
        dsq1(i) = dsq2(i) = 1.0;
        lower_bound(i) = 0.0;
        upper_bound(i) = 1.0 / lambda / varScaleFactor;
    }
    std::cout<<"upper bound = " << upper_bound(0) << std::endl;

    int numPos = 0, numNeg = 0;
    for(i = 0; i < num_example; i ++)
    {
        if (sigma[i] > 0)    numPos ++;
        else                 numNeg ++;
    }
    int n_small = numPos > numNeg ? numNeg : numPos;
    double v_pos = n_small / varScaleFactor / (2.0 * numPos * lambda);
    double v_neg = n_small / varScaleFactor / (2.0 * numNeg * lambda);
//    v_neg = v_pos = 0.0;
    for(i = 0; i < num_example; i ++)
    {
        if (sigma[i] > 0)    xk.Set(i, v_pos);
        else                 xk.Set(i, v_neg);
    }

    useless = 0.0;
    for (i = 0; i < num_example; i ++)
    {
        xk.Get(i, temp);
        useless += temp * sigma[i];
    }
    if (fabs(useless - 0.0) > 1e-5)
    {
        std::cout<<"incorrect initialization"<<std::endl;
        exit(EXIT_FAILURE);
    }

    // start training
    totalTime.Start();

    do
    {
        lossAndGradientTime.Start();
        _loss->ComputeDualLossAndGradient(dualObj, curGrad, xk);
        lossAndGradientTime.Stop();

//      printf("Current Gradient = \n");    curGrad.Print();    printf("================================\n");

       // Minimize piecewise linear lower bound R_t
        innerSolverTime.Start();

        offset.Assign(xk);
        offset.ScaleAdd(-1.0 / L, curGrad);
        offset.GetRow(0, num_example, &(offset_vec[0]));
        CDiagQPSolver innerSolver1(num_example, lower_bound, upper_bound,
                                    dsq1, sigma, offset_vec, z);
        innerSolver1.Solve(solution);
        yk.SetRow(0, num_example, &(solution.data()[0]));

        useless = 0.0;
        for (i = 0; i < num_example; i ++)  {
            yk.Get(i, temp);
            useless += temp * sigma[i];
        }
        if (fabs(useless - 0.0) > 1e-5) {
            std::cout<<"incorrect solution yk"<<std::endl;
            exit(EXIT_FAILURE);
        }

        cumulantGrad.ScaleAdd((iter + 1.0) / 2.0, curGrad);
        offset.Assign(cumulantGrad);
        offset.Scale(-1.0 / L);
        offset.GetRow(0, num_example, &(offset_vec[0]));

        CDiagQPSolver innerSolver2(num_example, lower_bound, upper_bound,
                                       dsq2, sigma, offset_vec, z);
        innerSolver2.Solve(solution);
        zk.SetRow(0, num_example, &(solution.data()[0]));

//       if (iter == 1375) {
//           std::cout<<"Showing zk:\n";       zk.Print(); 
//           std::cout<<"Showing offset:\n";   std::cout<<offset_vec<<std::endl;
//       }
        useless = 0.0;
        for (i = 0; i < num_example; i ++)  {
            zk.Get(i, temp);
            useless += temp * sigma[i];
        }
        if (fabs(useless - 0.0) > 1e-5)   {
            std::cout << "incorrect solution zk " << useless << std::endl;
            innerSolver2.dumpToFile("qp_def.dat");
            exit(EXIT_FAILURE);
        }

        xk.Zero();
        xk.ScaleAdd(2.0 / (iter + 3.0), zk);
        xk.ScaleAdd((iter + 1.0) / (iter + 3.0), yk);

//       printf("xk = \n");     xk.Print();     printf("==========================\n");

        innerSolverTime.Stop();

        _loss->ComputeDualLossAndGradient(dualObj, curGrad, yk);
        dualObj *= lambda;

        if(minDualObj > dualObj)    {
            minDualObj = dualObj;
            bestyk.Assign(yk);
        }

        // Display details of each iteration
        DisplayIterationInfo(iter, totalTime.CurrentCPUTotal());
   
       // Save model obtained in previous iteration
//       SaveCheckpointModel(iter);
      
//      Check if termination criteria satisfied
        exitFlag = CheckTermination(iter);  

        iter ++;

    } while(!exitFlag);

    totalTime.Stop();
    
    // Display after-training details
    DisplayAfterTrainingInfo(iter, lossAndGradientTime, innerSolverTime, totalTime);
}


/**   Validate program parameters set in Configuration.
 */
void CNesterov83::ConfirmProgramParameters()
{
   Configuration &config = Configuration::GetInstance();  // make sure configuration file is read before this!
   
   if(config.IsSet("Nesterov83.verbosity")) 
      verbosity = config.GetInt("Nesterov83.verbosity");
   
   if(config.IsSet("Nesterov83.maxNumOfIter")) 
   {
      maxNumOfIter = config.GetInt("Nesterov83.maxNumOfIter");
      if(maxNumOfIter < 0)
         throw CBMRMException("Nesterov83.maxNumOfIter must be > 0\n","CNesterov83::ConfirmProgramParameters()");
   }
   
   if(config.IsSet("Nesterov83.lambda"))           
   {
      lambda = config.GetDouble("Nesterov83.lambda");
      if(lambda <= 0)
         throw CBMRMException("Nesterov83.lambda must be > 0\n","CNesterov83::ConfirmProgramParameters()");
   }

   if(config.IsSet("Nesterov83.L"))           
   {
       L = config.GetDouble("Nesterov83.L");
       if(L <= 0)
           throw CBMRMException("Nesterov83.L must be > 0\n","CNesterov83::ConfirmProgramParameters()");
   }
   else
       throw CBMRMException("Nesterov83.L must set\n","CNesterov83::ConfirmProgramParameters()");

   if(config.IsSet("Nesterov83.checkpointInterval")) 
   {
       checkpointInterval = config.GetInt("Nesterov83.checkpointInterval");
       if(checkpointInterval < 1)
           throw CBMRMException("Nesterov83.checkpointInterval must be a positive integer!\n","CNesterov83::ConfirmProgramParameters()");
   }

   if(config.IsSet("Nesterov83.checkpointPrefix")) 
       checkpointPrefix = config.GetString("Nesterov83.checkpointPrefix");

   if(config.IsSet("Nesterov83.checkpointMode")) 
   {
       string mode = config.GetString("Nesterov83.checkpointMode");
       if(mode == "LATEST")
           checkpointMode = KEEP_LATEST;
       if(mode == "ALL")
           checkpointMode = KEEP_ALL;
   }   
}


/**   Print a progress line for one iteration and flush stdout.
 *
 *    verbosity <= 0 : a dot per call, plus the iteration number whenever it
 *                     is a multiple of 100;
 *    verbosity == 1 : iteration number and current dual objective;
 *    verbosity >= 2 : the same, followed by the elapsed time.
 *
 *    @param iter     iteration number to display
 *    @param curTime  elapsed time to display (only shown for verbosity >= 2)
 */
void CNesterov83::DisplayIterationInfo(unsigned int iter, double curTime)
{
   // iter is unsigned, so it is printed with %u rather than %d.
   if(verbosity <= 0) 
   {
      printf(".");
      if(iter%100 == 0) 
         printf("%u",iter);
   }
   else if(verbosity == 1)
      printf("#%u   dual_obj %.6e\n", iter, dualObj);      
   else
   {
      printf("#%u   dual_obj %.6e  time %0.6e\n", 
             iter, dualObj, curTime); 
   } 
   fflush(stdout);
}


/**   Write a checkpoint of the model every checkpointInterval iterations.
 *
 *    Nothing happens unless iter is a multiple of checkpointInterval.  In
 *    KEEP_LATEST mode the model is saved under checkpointPrefix itself; in
 *    any other mode it is saved under "<checkpointPrefix>.<iter>".
 *
 *    @param iter  current iteration number
 */
void CNesterov83::SaveCheckpointModel(unsigned int iter)
{
   // Not a checkpoint iteration: nothing to do.
   if(iter % checkpointInterval != 0)
      return;

   // Single rolling checkpoint file.
   if(checkpointMode == KEEP_LATEST)
   {
      _model->Save(checkpointPrefix);
      return;
   }

   // One file per checkpoint, tagged with the iteration number.
   std::ostringstream taggedName;
   taggedName << checkpointPrefix << "." << iter;
   _model->Save(taggedName.str());
}


/**   Decide whether the training loop has to stop.
 *
 *    The only criterion is the iteration budget: once iter has reached
 *    maxNumOfIter a status message is printed and a non-zero flag returned.
 *
 *    @param iter  iteration counter to compare against maxNumOfIter
 *    @return      0 to continue, 5 when the iteration limit has been reached
 */
int CNesterov83::CheckTermination(unsigned int iter)
{
   const int continueTraining = 0;
   const int iterationLimitReached = 5;

   // Still within budget: keep going.
   if(iter < maxNumOfIter)
      return continueTraining;

   printf("\nProgram status: Exceeded maximum number of iterations (%d)!\n", 
          maxNumOfIter);
   return iterationLimitReached;
}



/**   Print the end-of-training summary.
 *
 *    Shows an optional legend (verbosity >= 1), the iteration count, the
 *    smallest dual objective recorded in minDualObj, the 1-, 2- and
 *    infinity-norms of bestyk, and the timing profile.
 *
 *    @param iter                 iteration count to report
 *    @param lossAndGradientTime  timer for loss/gradient evaluation
 *    @param innerSolverTime      timer for the inner solver
 *    @param totalTime            timer for the whole training run
 */
void CNesterov83::DisplayAfterTrainingInfo(unsigned int iter, CTimer& lossAndGradientTime,
                                      CTimer& innerSolverTime, CTimer& totalTime)
{
   // legends
   if(verbosity >= 1) 
   {
      printf("\n[Legends]\n");
      if(verbosity > 1)
         printf("dual_obj: dual objective function value");
   }
   
   double norm1 = 0.0, norm2 = 0.0, norminf = 0.0;
   bestyk.Norm1(norm1);
   bestyk.Norm2(norm2);
   bestyk.NormInf(norminf);
   
   // The backslash is doubled on purpose: a bare "\a" inside a string
   // literal is the BEL control character, not the start of "\alpha".
   printf("\nNote: the final \\alpha is the w_t where J(w_t) is the smallest.\n");
   // iter is unsigned, hence %u.
   printf("No. of iterations:  %u\n",iter);
   printf("Min Dual obj. val.: %.6e\n",  minDualObj);
   printf("|a|_1:            %.6e\n",norm1);
   printf("|a|_2:            %.6e\n",norm2);
   printf("|a|_oo:           %.6e\n",norminf);
   
   // display timing profile
   printf("\nCPU seconds in:\n");
   printf("1. loss and gradient: %8.2f\n", lossAndGradientTime.CPUTotal());
   printf("2. solver:            %8.2f\n", innerSolverTime.CPUTotal()); 
   printf("               Total: %8.2f\n", totalTime.CPUTotal());
   printf("Wall-clock total:     %8.2f\n", totalTime.WallclockTotal());
}

#endif
