/* Copyright (c) 2009 NICTA 
 * All rights reserved. 
 * 
 * The contents of this file are subject to the Mozilla Public License 
 * Version 1.1 (the "License"); you may not use this file except in 
 * compliance with the License. You may obtain a copy of the License at 
 * http://www.mozilla.org/MPL/ 
 * 
 * Software distributed under the License is distributed on an "AS IS" 
 * basis, WITHOUT WARRANTY OF ANY KIND, either express or implied. See the 
 * License for the specific language governing rights and limitations 
 * under the License. 
 * 
 * Authors: Xinhua Zhang (xinhua.zhang81@gmail.com)
 *
 * Created: (19/05/2009) 
 *
 * Last Updated: (20/05/2009)   
 */

#ifndef _Nesterov07_HPP_
#define _Nesterov07_HPP_

#include <string>
#include "common.hpp"
#include "solver.hpp"
#include "bmrminnersolver.hpp"
#include "model.hpp"
#include "loss.hpp"
#include "sml.hpp"
#include "timer.hpp"
#include "hingeloss.hpp"


/** Auxiliary record pairing a real value with an integer label.
 *  NOTE(review): in this header it is referenced only by a commented-out
 *  member declaration inside CNesterov07 (an array of pointers to it);
 *  presumably it supported bias computation — confirm against the
 *  implementation before relying on it.
 */
struct aux_bias {
    double cross_value;  // NOTE(review): meaning is defined by its users — confirm
    int    label;        // integer label paired with cross_value
};

/** Solver class CNesterov07.
 *
 *  Trains the given model (read/write) against a hinge loss (read-only)
 *  through the CSolver interface; Train() runs the optimisation.
 *
 *  NOTE(review): the earlier description of this class ("iteratively builds
 *  up a convex lower-bound of the objective function, and performs
 *  minimization on the lower-bound") matches the BMRM solver family. The
 *  state kept below (a, A, Ltil, gamma_d, gamma_u) rather suggests an
 *  accelerated gradient scheme in the style of Nesterov (2007) with an
 *  adaptive Lipschitz estimate — confirm against the implementation.
 */
class CNesterov07 : public CSolver
{
public:
    /** Construct the solver.
     *  @param model [r/w] model object
     *  @param loss  [read] hinge-loss object
     */
    CNesterov07(CModel* model, CHingeLoss* loss);

    /** Destructor. */
    virtual ~CNesterov07();

    /** Run training. */
    virtual void Train();

protected:
    // Declaration order below is deliberate: it fixes object layout and the
    // order in which the constructor initializes members.

    // ---- Iteration work vectors ----
    // NOTE(review): their exact roles are defined in the implementation;
    // the names suggest gradients and iterates of the method — confirm there.
    TheMatrix grad_y;
    TheMatrix grad_t;
    TheMatrix cumulantGrad;
    TheMatrix xk;
    TheMatrix TLy;
    TheMatrix zk;
    TheMatrix offset;
    TheMatrix ytemp;
    TheMatrix cur_w;

    /** Presumably the number of training examples. */
    unsigned int num_example;

    // ---- Objective values ----
    // NOTE(review): the min* members presumably hold the smallest value
    // observed so far — confirm in the implementation.
    double dualObj;
    double minDualObj;
    double primalObj;
    double minPrimalObj;

    // ---- Step-size / Lipschitz-estimate state ----
    // NOTE(review): gamma_u / gamma_d look like increase / decrease factors
    // applied to the estimate Ltil — confirm in the implementation.
    double a;
    double A;
    double Ltil;
    double gamma_d;
    double gamma_u;

    /** Presumably: whether a bias term is learned. */
    bool useBias;

    // NOTE(review): meaning is not evident from this header.
    double vscaleRaw;

    /** Verbosity level */
    int verbosity;

    // Disabled declaration, kept for reference:
    // double find_bias(ublas::vector<double> &disc, ublas::vector<double> &sigma);

    /** Presumably the value of the bias term. */
    double bias_value;

    // ---- Test-time evaluation settings and result ----
    // NOTE(review): how these are filled and used is defined in the
    // implementation.
    std::string testConfFilename;
    std::string testSummaryFilename;
    std::string modelFilename;
    std::string predExeFilename;
    double test_accuracy;

    // Disabled declaration, kept for reference:
    // struct aux_bias ** aux_buf_pointer;

    /** Maximum number of CNesterov07 iterations */
    unsigned int maxNumOfIter;

    /** Tolerance, presumably consulted by the termination test. */
    double epsilonTol;

    /** Regularization constant */
    double lambda;

    /** Prefix for intermediate/checkpoint model files */
    std::string checkpointPrefix;

    /** Number of iterations between checkpoint saves */
    unsigned int checkpointInterval;

    /** Selected type of checkpoint (see CHECKPOINT_MODE) */
    unsigned int checkpointMode;

    /** Checkpoint modes.
     *  KEEP_ALL    -- keep a checkpoint model after every $checkpointInterval$
     *                 iterations
     *  KEEP_LATEST -- keep only the latest model
     */
    enum CHECKPOINT_MODE {KEEP_ALL, KEEP_LATEST};

    /** Validate that all provided program parameters are usable,
     *  e.g. lambda must be > 0.
     */
    virtual void ConfirmProgramParameters();

    /** Display information about the current iteration */
    virtual void DisplayIterationInfo(unsigned int iter, double curTime);

    /** Save the model every #checkpointInterval iterations */
    virtual void SaveCheckpointModel(unsigned int iter);

    /** Check the termination criteria.
     *  NOTE(review): the meaning of the int return value is defined by the
     *  implementation — confirm there.
     */
    virtual int CheckTermination(unsigned int iter, double dual_gap);

    /** Display information after training is done */
    virtual void DisplayAfterTrainingInfo(unsigned int iter,
                                          CTimer& lossAndGradientTime,
                                          CTimer& innerSolverTime,
                                          CTimer& totalTime);

    /** NOTE(review): presumably checks var against lower_bound / upper_bound.
     *  The roles of sigma and z, and what happens on failure, are defined in
     *  the implementation; varName is presumably used in diagnostics.
     *  varName is a non-const char*, so a string literal cannot be passed
     *  in C++11 and later. Switching to const char* requires updating the
     *  out-of-header definition at the same time.
     */
    void checkVarValid(TheMatrix &var,
                       ublas::vector<double> &lower_bound,
                       ublas::vector<double> &upper_bound,
                       ublas::vector<double> &sigma,
                       double z,
                       char * varName);
};

#endif
