/* Copyright (c) 2009 NICTA 
 * All rights reserved. 
 * 
 * The contents of this file are subject to the Mozilla Public License 
 * Version 1.1 (the "License"); you may not use this file except in 
 * compliance with the License. You may obtain a copy of the License at 
 * http://www.mozilla.org/MPL/ 
 * 
 * Software distributed under the License is distributed on an "AS IS" 
 * basis, WITHOUT WARRANTY OF ANY KIND, either express or implied. See the 
 * License for the specific language governing rights and limitations 
 * under the License. 
 * 
 * Authors: Xinhua Zhang (xinhua.zhang81@gmail.com)
 *
 * Created: (19/05/2009) 
 *
 * Last Updated: (20/05/2009)   
 */

#ifndef _NESTEROV03_CPP_
#define _NESTEROV03_CPP_

#include "common.hpp"
#include "Nesterov03.hpp"
#include "timer.hpp"
#include "configuration.hpp"
#include "loss.hpp"
#include "l2_simplex_projection.hpp"
#include "l2_box_projection.hpp"
#include "hingeloss.hpp"

#include <fstream>
#include <sstream>


using namespace std;


/**  Constructor.
 *
 *   Sizes every work vector from the example/feature counts reported by the
 *   loss, loads the solver settings via ConfirmProgramParameters(), and then
 *   derives the dual-variable scale, the rescaled constant L and the box
 *   bounds [0, varScaleFactor03] used for every dual variable.
 *
 *   @param model  model object, forwarded to the CSolver base class
 *   @param loss   hinge loss providing the problem dimensions and the labels
 */
CNesterov03::CNesterov03(CModel *model, CHingeLoss* loss)
   : CSolver(model, loss),
     verbosity(0),
     maxNumOfIter(10000),
     lambda(1.0),
     checkpointPrefix("model.checkpoint"),
     checkpointInterval(100000),
     checkpointMode(KEEP_LATEST),
     offset(loss->getNumExample(), 1, SML::DENSE),
     curGrad(loss->getNumExample(), 1, SML::DENSE),
     beta(loss->getNumExample(), 1, SML::DENSE),
     alphak(loss->getNumExample(), 1, SML::DENSE),
     wk(loss->getNumFeature(), 1, SML::DENSE),
     dsq1(loss->getNumExample()),
     dsq2(loss->getNumExample()),
     lower_bound(loss->getNumExample()),
     upper_bound(loss->getNumExample()),
     sigma(loss->getNumExample()),
     solution(loss->getNumExample()),
     offset_vec(loss->getNumExample()),
     best_alpha(loss->getNumExample(), 1, SML::DENSE),
     best_w(loss->getNumFeature(), 1, SML::DENSE)
{
   // Values found in the configuration replace the defaults chosen above.
   ConfirmProgramParameters();

   num_example = loss->getNumExample();
   num_feature = loss->getNumFeature();

   // Sentinels so that the first computed objective always becomes the best one.
   maxDualObj = -1e99;
   minPrimalObj = 1e99;

   // Scale of the dual variables; it is also the upper box bound set below.
   varScaleFactor03 = 1.0 / lambda / num_example / vscaleRaw;

   // Rescale the user-supplied L to the scaled dual variables.
   L /= lambda * pow(num_example * varScaleFactor03, 2);

   _loss->setVarScaleFactor03(varScaleFactor03);
   _loss->getLabels(sigma);

   int idx = 0;
   while(idx < num_example)
   {
       dsq2(idx) = 1.0;
       dsq1(idx) = 1.0;
       lower_bound(idx) = 0.0;
       upper_bound(idx) = varScaleFactor03;
       ++idx;
   }
   std::cout<<"upper bound = " << upper_bound(0) << std::endl;
}




/**  Destructor. The body is intentionally empty: no explicit cleanup is
 *   performed here.
 */
CNesterov03::~CNesterov03() {}


/**  Fill var with a starting point for the dual variables.
 *
 *   With n_small the size of the smaller label class, every example whose
 *   label is positive receives n_small * varScaleFactor03 / (2 * numPos) and
 *   every other example n_small * varScaleFactor03 / (2 * numNeg). The filled
 *   vector is then validated with checkVarValid() against the box bounds and,
 *   when the bias is in use, the linear constraint with right-hand side z.
 *
 *   @param var  [write] vector receiving one value per example
 */
void CNesterov03::init_alpha(TheMatrix &var)
{
    // Count the examples on each side of the label sign.
    int numPos = 0;
    int numNeg = 0;
    for(int k = 0; k < num_example; ++k)
    {
        if (sigma[k] > 0)
            ++numPos;
        else
            ++numNeg;
    }

    const int n_small = (numNeg < numPos) ? numNeg : numPos;
    const double v_pos = n_small * varScaleFactor03 / (2.0 * numPos);
    const double v_neg = n_small * varScaleFactor03 / (2.0 * numNeg);

    for(int k = 0; k < num_example; ++k)
        var.Set(k, (sigma[k] > 0) ? v_pos : v_neg);

    checkVarValid(var, lower_bound, upper_bound, sigma, z, "beta0");
}


/**  Abort the program unless var satisfies its constraints.
 *
 *   Every entry must lie in [lower_bound[i] - 1e-6, upper_bound[i] + 1e-6].
 *   When useBias is set, the inner product of var with sigma must also match
 *   z to within 1e-5. On any violation a message is written to std::cout and
 *   the process terminates with EXIT_FAILURE.
 *
 *   Both tests are phrased as "must be inside the tolerance" rather than
 *   "must not be outside", so a NaN entry (for which every comparison is
 *   false) is reported as a violation instead of slipping through.
 *
 *   @param var          vector to check; one entry per bound is read
 *   @param lower_bound  per-entry lower limits; its size sets the entry count
 *   @param upper_bound  per-entry upper limits
 *   @param sigma        coefficients of the linear constraint
 *   @param z            right-hand side of the linear constraint
 *   @param varName      name used in the out-of-bound message
 */
void CNesterov03::checkVarValid(TheMatrix &var, 
                                ublas::vector<double> &lower_bound, 
                                ublas::vector<double> &upper_bound,
                                ublas::vector<double> &sigma, double z, char * varName)
{
    const double boxTol = 1e-6;
    const double eqTol = 1e-5;
    const int count = lower_bound.size();
    double innerProd = 0.0;

    for (int i = 0; i < count; i ++)
    {
        double value;
        var.Get(i, value);

        // Positive test: false for out-of-range values and for NaN alike.
        const bool insideBox = (value >= lower_bound[i] - boxTol)
                            && (value <= upper_bound[i] + boxTol);
        if (!insideBox)
        {
            std::cout<<varName<<"["<<i<<"] out of bound: "<< value <<std::endl;
            exit(EXIT_FAILURE);
        }
        innerProd += value * sigma[i];
    }

    // Negated "<=" so that a NaN inner product also counts as a violation.
    if (useBias && !(fabs(innerProd - z) <= eqTol))
    {
        std::cout<<"Linear constraint violated."<<std::endl;
        exit(EXIT_FAILURE);
    }
}



/**  Map a weight vector to a dual vector.
 *
 *   The loss first turns wk into one value v_i per example
 *   (aux_map_w_to_alpha). Each value is converted into the offset
 *   (1 - v_i) / (num_example * varScaleFactor03 * mu), and the inner QP solver
 *   is run on these offsets: CDiagQPSolver (with sigma and z) when the bias
 *   is in use, CDiagBoxQPSolver otherwise. The solver output is copied into
 *   result.
 *
 *   @param wk      weight vector handed to the loss
 *   @param result  [write] receives the num_example solver outputs
 *   @param mu      scalar dividing the offsets; Train() passes its current mu
 */
void CNesterov03::map_w_to_alpha(TheMatrix &wk, TheMatrix &result, double mu)
{
    const double inv_nsmu = 1.0 / (num_example * varScaleFactor03 * mu);

    // Per-example values produced by the loss, copied into the plain buffer.
    TheMatrix lossOut(num_example, 1, SML::DENSE);
    _loss->aux_map_w_to_alpha(wk, lossOut);
    lossOut.GetRow(0, num_example, &(offset_vec[0]));

    for (int k = 0; k < num_example; ++k)
    {
        const double raw = offset_vec[k];
        offset_vec[k] = (1.0 - raw) * inv_nsmu;
    }

    if (!useBias)
    {
        CDiagBoxQPSolver boxSolver(num_example, lower_bound, upper_bound, offset_vec);
        boxSolver.Solve(solution);
    }
    else
    {
        CDiagQPSolver qpSolver(num_example, lower_bound, upper_bound,
            dsq1, sigma, offset_vec, z);
        qpSolver.Solve(solution);
    }

    result.SetRow(0, num_example, &(solution.data()[0]));
}


/**  Map a dual vector to a new dual vector through a gradient step.
 *
 *   Evaluates the dual gradient at alpha (the objective value returned
 *   alongside it is discarded), forms alpha + gradient / L in curGrad, and
 *   runs the inner QP solver with that vector as its offset: CDiagQPSolver
 *   (with sigma and z) when the bias is in use, CDiagBoxQPSolver otherwise.
 *   The solver output is copied into result.
 *
 *   Side effect: curGrad and offset_vec are overwritten.
 *
 *   @param alpha   point at which the dual gradient is evaluated
 *   @param result  [write] receives the num_example solver outputs
 */
void CNesterov03::map_alpha_to_alpha(TheMatrix &alpha, TheMatrix &result)
{
    double ignoredObj;   // objective value is not needed here

    _loss->ComputeDualLossAndGradient03(ignoredObj, curGrad, alpha, lambda);

    // curGrad <- alpha + curGrad / L
    curGrad.Scale(1.0 / L);
    curGrad.Add(alpha);
    curGrad.GetRow(0, num_example, &(offset_vec[0]));

    if (!useBias)
    {
        CDiagBoxQPSolver boxSolver(num_example, lower_bound, upper_bound, offset_vec);
        boxSolver.Solve(solution);
    }
    else
    {
        CDiagQPSolver qpSolver(num_example, lower_bound, upper_bound,
            dsq1, sigma, offset_vec, z);
        qpSolver.Solve(solution);
    }

    result.SetRow(0, num_example, &(solution.data()[0]));
}




/**  Start training/learning a model w.r.t. the loss object (and the data supplied to it).
 *
 *   Setup: z = 0, mu = 2 * L, beta from init_alpha(), wk from
 *   map_alpha_to_w(beta) and alphak from map_alpha_to_alpha(beta).
 *
 *   Each iteration, with tau = 2 / (iter + 3) for the 0-based index iter:
 *     1. beta   <- tau * map_w_to_alpha(wk, mu) + (1 - tau) * alphak
 *     2. wk     <- (1 - tau) * wk + tau * map_alpha_to_w(beta)
 *     3. alphak <- map_alpha_to_alpha(beta)
 *     4. mu     <- (1 - tau) * mu
 *   (steps 1 and 2 assume TheMatrix::ScaleAdd(a, x) adds a * x to the
 *   receiver -- TODO confirm against the matrix class).
 *
 *   The dual objective is then evaluated at alphak and the primal objective
 *   via ComputePrimalObj(); best_alpha / best_w keep the iterates recorded
 *   when the dual objective reached a new maximum and the primal objective a
 *   new minimum, respectively.
 *
 *   The loop stops once CheckTermination() returns non-zero. It is given the
 *   number of iterations completed so far, so at most maxNumOfIter iterations
 *   are executed (and at least one, since the check follows the loop body).
 */
void CNesterov03::Train()
{
    CTimer totalTime;             // total runtime of the training
    CTimer innerSolverTime;       // time for inner optimization (e.g., QP or LP)
    CTimer lossAndGradientTime;   // time for loss and gradient computation
    // NOTE(review): the two timers above are never started in this function;
    // they are only handed to DisplayAfterTrainingInfo().

    unsigned int iter = 0;        // 0-based index inside the loop body, count of completed iterations after it
    double mu, dualObj, primalObj;
    int exitFlag;

    TheMatrix temp_fea(num_feature, 1, SML::DENSE);
    
    z = 0.0;
    mu = 2.0 * L;

    init_alpha(beta);

    map_alpha_to_w (beta, wk);
    map_alpha_to_alpha(beta, alphak);
    
    // start training
    totalTime.Start();

    do
    {
        double tau = 2.0 / (iter + 3.0);

        // beta <- tau * map_w_to_alpha(wk) + (1 - tau) * alphak
        map_w_to_alpha(wk, beta, mu);
        beta.Scale(tau);
        beta.ScaleAdd(1.0 - tau, alphak);

        // wk <- (1 - tau) * wk + tau * map_alpha_to_w(beta)
        map_alpha_to_w(beta, temp_fea);
        wk.Scale(1.0 - tau);
        wk.ScaleAdd(tau, temp_fea);

        map_alpha_to_alpha(beta, alphak);

        mu *= 1.0 - tau;

        _loss->ComputeDualLossAndGradient03(dualObj, curGrad, alphak, lambda);

        // Keep the dual iterate with the largest dual objective seen so far.
        if(maxDualObj < dualObj)    {
            maxDualObj = dualObj;
            best_alpha.Assign(alphak);
        }

        primalObj = _loss->ComputePrimalObj(best_alpha, lambda, temp_fea, 0, false, mu);
        if (minPrimalObj > primalObj)
        {
            minPrimalObj = primalObj;
            best_w.Assign(wk);
        }
        

        // Display details of each iteration
        DisplayIterationInfo(iter, totalTime.CurrentCPUTotal(), dualObj, primalObj);
   
       // Checkpointing is currently disabled.
//       SaveCheckpointModel(iter);
      
        // Count the iteration just finished BEFORE testing the limit, so the
        // loop runs maxNumOfIter times rather than maxNumOfIter + 1. tau and
        // the displayed index above still use the 0-based value.
        iter ++;

        // Check if termination criteria satisfied
        exitFlag = CheckTermination(iter);  

    } while(!exitFlag);

    totalTime.Stop();
    
    // Display after-training details
    DisplayAfterTrainingInfo(iter, lossAndGradientTime, innerSolverTime, totalTime);
}


/**   Read and validate the program parameters set in Configuration.
 *
 *    Optional keys (the current member value is kept when absent):
 *      Nesterov03.verbosity, Nesterov03.maxNumOfIter (>= 0),
 *      Nesterov03.lambda (> 0), Nesterov03.checkpointInterval (>= 1),
 *      Nesterov03.checkpointPrefix,
 *      Nesterov03.checkpointMode ("LATEST" or "ALL"; other text is ignored).
 *    Optional with a default applied here:
 *      Nesterov03.vscaleRaw (> 0), 1.0 when absent.
 *    Mandatory keys:
 *      Nesterov03.L (> 0), Nesterov03.bias.
 *
 *    Integer settings are range-checked in a signed local before they are
 *    stored, so a negative value is rejected regardless of the signedness of
 *    the member that finally holds it.
 *
 *    @throws CBMRMException if a mandatory key is missing or a value is out
 *            of range
 */
void CNesterov03::ConfirmProgramParameters()
{
   Configuration &config = Configuration::GetInstance();  // make sure configuration file is read before this!
   
   if(config.IsSet("Nesterov03.verbosity")) 
      verbosity = config.GetInt("Nesterov03.verbosity");
   
   if(config.IsSet("Nesterov03.maxNumOfIter")) 
   {
      const int requestedIter = config.GetInt("Nesterov03.maxNumOfIter");
      if(requestedIter < 0)
         throw CBMRMException("Nesterov03.maxNumOfIter must be >= 0\n","CNesterov03::ConfirmProgramParameters()");
      maxNumOfIter = requestedIter;
   }
   
   if(config.IsSet("Nesterov03.lambda"))           
   {
      lambda = config.GetDouble("Nesterov03.lambda");
      if(lambda <= 0)
         throw CBMRMException("Nesterov03.lambda must be > 0\n","CNesterov03::ConfirmProgramParameters()");
   }

   if(config.IsSet("Nesterov03.vscaleRaw"))           
   {
       vscaleRaw = config.GetDouble("Nesterov03.vscaleRaw");
       if(vscaleRaw <= 0)
           throw CBMRMException("Nesterov03.vscaleRaw must be > 0\n","CNesterov03::ConfirmProgramParameters()");
   }
   else
   {
       // vscaleRaw is printed below and used as a divisor by the constructor,
       // so it must not stay uninitialised. 1.0 leaves that division a no-op.
       vscaleRaw = 1.0;
   }

   if(config.IsSet("Nesterov03.L"))           
   {
       L = config.GetDouble("Nesterov03.L");
       if(L <= 0)
           throw CBMRMException("Nesterov03.L must be > 0\n","CNesterov03::ConfirmProgramParameters()");
   }
   else
       throw CBMRMException("Nesterov03.L must set\n","CNesterov03::ConfirmProgramParameters()");

   if(config.IsSet("Nesterov03.bias"))           
       useBias = config.GetBool("Nesterov03.bias");
   else
       throw CBMRMException("Nesterov03.bias must set\n","CNesterov03::ConfirmProgramParameters()");

   if(config.IsSet("Nesterov03.checkpointInterval")) 
   {
       const int requestedInterval = config.GetInt("Nesterov03.checkpointInterval");
       if(requestedInterval < 1)
           throw CBMRMException("Nesterov03.checkpointInterval must be a positive integer!\n","CNesterov03::ConfirmProgramParameters()");
       checkpointInterval = requestedInterval;
   }

   if(config.IsSet("Nesterov03.checkpointPrefix")) 
       checkpointPrefix = config.GetString("Nesterov03.checkpointPrefix");

   if(config.IsSet("Nesterov03.checkpointMode")) 
   {
       string mode = config.GetString("Nesterov03.checkpointMode");
       if(mode == "LATEST")
           checkpointMode = KEEP_LATEST;
       else if(mode == "ALL")
           checkpointMode = KEEP_ALL;
   }   

   std::cout << "Got L = " << L << std::endl;
   std::cout << "Got vscaleRaw = " <<  vscaleRaw << std::endl;
}


/**  Print a progress line for one iteration and flush stdout.
 *
 *   verbosity <= 0 : a dot per iteration, plus the iteration number on every
 *                    100th iteration;
 *   verbosity == 1 : primal objective, dual objective and their difference;
 *   verbosity >= 2 : the same, followed by curTime.
 *
 *   The iteration number is unsigned and is therefore printed with %u.
 *
 *   @param iter       iteration number to display
 *   @param curTime    time value shown at verbosity >= 2
 *   @param dualObj    dual objective value
 *   @param primalObj  primal objective value
 */
void CNesterov03::DisplayIterationInfo(unsigned int iter, double curTime, double dualObj, double primalObj)
{
   if(verbosity <= 0) 
   {
      printf(".");
      if(iter%100 == 0) 
         printf("%u",iter);
   }
   else if(verbosity == 1)
      printf("#%u   primal_obj = %.6e\tdual_obj = %.6e\tdual_gap = %.6e\n", iter, primalObj, dualObj, 
                                primalObj - dualObj);      
   else
   {
      // verbosity >= 2
      printf("#%u   primal_obj = %.6e\tdual_obj = %.6e\tdual_gap = %.6e\ttime %0.6e\n", 
             iter, primalObj, dualObj, primalObj - dualObj, curTime); 
   } 
   fflush(stdout);
}


/**  Save the model when iter is a multiple of checkpointInterval.
 *
 *   In KEEP_LATEST mode the model is always saved under checkpointPrefix;
 *   in any other mode it is saved under "<checkpointPrefix>.<iter>".
 *
 *   @param iter  iteration number, used for the interval test and the suffix
 */
void CNesterov03::SaveCheckpointModel(unsigned int iter)
{
   // Nothing to do between checkpoints.
   if(iter % checkpointInterval != 0)
      return;

   if(checkpointMode != KEEP_LATEST)
   {
      std::ostringstream fileName;
      fileName << checkpointPrefix << "." << iter;
      _model->Save(fileName.str());
      return;
   }

   _model->Save(checkpointPrefix);
}


/**  Decide whether training has to stop.
 *
 *   @param iter  iteration count compared against maxNumOfIter
 *   @return 0 to continue; 5 (after printing a status message) once iter has
 *           reached maxNumOfIter
 */
int CNesterov03::CheckTermination(unsigned int iter)
{
   int status = 0;

   if(iter >= maxNumOfIter)
   {
      printf("\nProgram status: Exceeded maximum number of iterations (%d)!\n", 
             maxNumOfIter);
      status = 5;
   }

   return status;
}



/**  Print the end-of-training summary.
 *
 *   Shows an optional legend (verbosity >= 1), the iteration count, the
 *   largest dual objective recorded (maxDualObj), the 1-, 2- and
 *   infinity-norms of best_alpha, and the totals of the supplied timers.
 *
 *   @param iter                 iteration count to report
 *   @param lossAndGradientTime  timer reported as "loss and gradient"
 *   @param innerSolverTime      timer reported as "solver"
 *   @param totalTime            timer reported as CPU and wall-clock total
 */
void CNesterov03::DisplayAfterTrainingInfo(unsigned int iter, CTimer& lossAndGradientTime,
                                      CTimer& innerSolverTime, CTimer& totalTime)
{
   // legends
   if(verbosity >= 1) 
   {
      printf("\n[Legends]\n");
      if(verbosity > 1)
         printf("dual_obj: dual objective function value");
   }
   
   double norm1, norm2, norminf;
   best_alpha.Norm1(norm1);
   best_alpha.Norm2(norm2);
   best_alpha.NormInf(norminf);
   
   // The backslash is escaped: an unescaped "\a" is the alert (BEL) character.
   printf("\nNote: the final \\alpha is the w_t where J(w_t) is the smallest.\n");
   printf("No. of iterations:  %u\n",iter);
   // maxDualObj holds the largest dual objective seen, hence "Max".
   printf("Max Dual obj. val.: %.6e\n",  maxDualObj);
   printf("|a|_1:            %.6e\n",norm1);
   printf("|a|_2:            %.6e\n",norm2);
   printf("|a|_oo:           %.6e\n",norminf);
   
   // display timing profile
   printf("\nCPU seconds in:\n");
   printf("1. loss and gradient: %8.2f\n", lossAndGradientTime.CPUTotal());
   printf("2. solver:            %8.2f\n", innerSolverTime.CPUTotal()); 
   printf("               Total: %8.2f\n", totalTime.CPUTotal());
   printf("Wall-clock total:     %8.2f\n", totalTime.WallclockTotal());
}

#endif
