src/GapsRunner.h
888a1fd4
 #ifndef __COGAPS_GAPS_RUNNER_H__
 #define __COGAPS_GAPS_RUNNER_H__
49a5b154
 
7bbb1f81
 #include "GapsParameters.h"
8f9d4161
 #include "GapsResult.h"
74377719
 #include "GapsStatistics.h"
bd6604c5
 #include "gibbs_sampler/GibbsSampler.h"
 #include "gibbs_sampler/DenseGibbsSampler.h"
 #include "gibbs_sampler/SparseGibbsSampler.h"
 
 #include <string>
74377719
 
dfc276e0
 // boost time helpers
 //
 // `bpt` is a short alias for boost::posix_time; the runner stores its start
 // time as a bpt::ptime (see mStartTime).
 // bpt_now() expands to bpt::microsec_clock::local_time(), i.e. the current
 // local time as a bpt::ptime with sub-second (microsecond clock) resolution.
 // NOTE(review): this is a function-like macro defined in a header, so it is
 // visible in every translation unit that includes GapsRunner.h.
 #include <boost/date_time/posix_time/posix_time.hpp>
 namespace bpt = boost::posix_time;
 #define bpt_now() bpt::microsec_clock::local_time()
 
bd6604c5
 // forward declarations
 class AbstractGapsRunner;
 
 ///////////////////////////// RAII wrapper /////////////////////////////////////
 
 // This is the class that is exposed to the top-level CoGAPS routine - all 
 // aspects of CoGAPS can be managed through this class. The class itself is
 // just a lightweight wrapper around an abstract interface, which allows for
 // multiple types of GapsRunner to be declared. Which implementation is used
 // depends on the parameters passed to the GapsRunner constructor.
49a5b154
 class GapsRunner
 {
7bbb1f81
 public:
 
     template <class DataType>
     GapsRunner(const DataType &data, const GapsParameters &params);
 
bd6604c5
     ~GapsRunner();
 
7bbb1f81
     template <class DataType>
     void setUncertainty(const DataType &unc, const GapsParameters &params);
 
     GapsResult run();
 
49a5b154
 private:
bd6604c5
 
     AbstractGapsRunner *mRunner;
 
     GapsRunner(const GapsRunner &p); // don't allow copies
     GapsRunner& operator=(const GapsRunner &p); // don't allow copies    
 };
 
 ///////////////////////// Abstract Interface ///////////////////////////////////
 
 // This class is the abstract interface that any implementation of GapsRunner
 // must satisfy. It provides a factory method that will create the appropiate
 // derived class depending on the parameters passed in.
 class AbstractGapsRunner
 {
 public:
 
     AbstractGapsRunner(const GapsParameters &params);
     virtual ~AbstractGapsRunner() {}
 
     template <class DataType>
     static AbstractGapsRunner* create(const DataType &data, const GapsParameters &params);
 
     // can't use template with virtual function
     virtual void setUncertainty(const Matrix &unc, const GapsParameters &params) = 0;
     virtual void setUncertainty(const std::string &unc, const GapsParameters &params) = 0;
 
     GapsResult run();
 
 protected:
 
49a5b154
     GapsStatistics mStatistics;
 
7bbb1f81
     mutable GapsRng mRng;
 
e13559eb
     std::string mCheckpointOutFile;
b9cc8323
 
dfc276e0
     bpt::ptime mStartTime;
6668db1c
 
7bbb1f81
     unsigned mCurrentIteration;
     unsigned mMaxIterations;
     unsigned mMaxThreads;
     unsigned mOutputFrequency;
     unsigned mCheckpointInterval;
e13559eb
     unsigned mNumPatterns;
     unsigned mNumUpdatesA;
     unsigned mNumUpdatesP;
7bbb1f81
     uint32_t mSeed;
6594bb4f
 
7bbb1f81
     bool mPrintMessages;
     bool mPrintThreadUsage;
 
     char mPhase;
     char mFixedMatrix;
e13559eb
         
     void runOnePhase();
731c4313
     double estimatedPercentComplete() const;
e13559eb
     void displayStatus();
     void createCheckpoint();
bd6604c5
 
     virtual float chiSq() const = 0;
     virtual float meanChiSq() const = 0;
     virtual unsigned nAtoms(char which) const = 0;
     virtual void setAnnealingTemp(float temp) = 0;
     virtual void updateStatistics() = 0;
     virtual Archive& readSamplers(Archive &ar) = 0;
     virtual Archive& writeSamplers(Archive &ar) = 0;
     virtual void updateSampler(unsigned nA, unsigned nP) = 0;
 };
 
 ///////////////////// GapsRunner Implementations ///////////////////////////////
 
 // Concrete runner backed by a pair of DenseGibbsSampler objects, one for the
 // A matrix and one for the P matrix. AbstractGapsRunner::create returns this
 // implementation when params.useSparseOptimization is not set.
 class DenseGapsRunner : public AbstractGapsRunner
 {
 public:
 
     template <class DataType>
     DenseGapsRunner(const DataType &data, const GapsParameters &params);
 
     ~DenseGapsRunner() {}
 
     // uncertainty overloads required by AbstractGapsRunner
     virtual void setUncertainty(const Matrix &unc, const GapsParameters &params);
     virtual void setUncertainty(const std::string &unc, const GapsParameters &params);
 
 private:
 
     // the A sampler is declared (and therefore constructed) before P
     DenseGibbsSampler mASampler;
     DenseGibbsSampler mPSampler;
 
     // implementations of the AbstractGapsRunner hooks
     virtual float chiSq() const;
     virtual float meanChiSq() const;
     virtual unsigned nAtoms(char which) const;
     virtual void setAnnealingTemp(float temp);
     virtual void updateStatistics();
     virtual Archive& readSamplers(Archive &ar);
     virtual Archive& writeSamplers(Archive &ar);
     virtual void updateSampler(unsigned nA, unsigned nP);
 };
 
 // Concrete runner backed by a pair of SparseGibbsSampler objects, one for the
 // A matrix and one for the P matrix. AbstractGapsRunner::create returns this
 // implementation when params.useSparseOptimization is set.
 class SparseGapsRunner : public AbstractGapsRunner
 {
 public:
 
     template <class DataType>
     SparseGapsRunner(const DataType &data, const GapsParameters &params);
 
     ~SparseGapsRunner() {}
 
     // uncertainty overloads required by AbstractGapsRunner
     virtual void setUncertainty(const Matrix &unc, const GapsParameters &params);
     virtual void setUncertainty(const std::string &unc, const GapsParameters &params);
 
 private:
 
     // the A sampler is declared (and therefore constructed) before P
     SparseGibbsSampler mASampler;
     SparseGibbsSampler mPSampler;
 
     // implementations of the AbstractGapsRunner hooks
     virtual float chiSq() const;
     virtual float meanChiSq() const;
     virtual unsigned nAtoms(char which) const;
     virtual void setAnnealingTemp(float temp);
     virtual void updateStatistics();
     virtual Archive& readSamplers(Archive &ar);
     virtual Archive& writeSamplers(Archive &ar);
     virtual void updateSampler(unsigned nA, unsigned nP);
 };
 
bd6604c5
 /////////////////////// GapsRunner - templated functions ///////////////////////
 
2cdb0256
 template <class DataType>
7bbb1f81
 GapsRunner::GapsRunner(const DataType &data, const GapsParameters &params)
bd6604c5
     : mRunner(AbstractGapsRunner::create(data, params))
 {}
 
 template <class DataType>
 void GapsRunner::setUncertainty(const DataType &unc, const GapsParameters &params)
 {
     mRunner->setUncertainty(unc, params);
 }
 
 /////////////////// AbstractGapsRunner - templated functions ///////////////////
 
 // Factory: params.useSparseOptimization selects the implementation. The
 // runner is allocated with `new` and ownership passes to the caller.
 template <class DataType>
 AbstractGapsRunner* AbstractGapsRunner::create(const DataType &data,
 const GapsParameters &params)
 {
     if (!params.useSparseOptimization)
     {
         return new DenseGapsRunner(data, params);
     }
     return new SparseGapsRunner(data, params);
 }
 
 //////////////////// DenseGapsRunner - templated functions /////////////////////
7bbb1f81
 
bd6604c5
 // Builds the dense A and P samplers from `data` and installs a fixed matrix
 // if one was requested. When params.useCheckPoint is set, it then restores
 // runner and sampler state from params.checkpointFile. Finally it syncs the
 // samplers with each other and, after a restore, rebuilds the AP matrix.
 template <class DataType>
 DenseGapsRunner::DenseGapsRunner(const DataType &data,
 const GapsParameters &params)
     :
 AbstractGapsRunner(params),
 mASampler(data, !params.transposeData, !params.subsetGenes, params.alphaA, params.maxGibbsMassA, params),
 mPSampler(data, params.transposeData, params.subsetGenes, params.alphaP, params.maxGibbsMassP, params)
 {
     const bool resumeFromCheckpoint = params.useCheckPoint;
 
     // 'A' / 'P' install the fixed matrix into that sampler; any other value
     // ('N' for none) leaves both samplers untouched
     if (mFixedMatrix == 'A')
     {
         mASampler.setMatrix(params.fixedMatrix);
     }
     else if (mFixedMatrix == 'P')
     {
         mPSampler.setMatrix(params.fixedMatrix);
     }
 
     // overwrite with info from checkpoint file
     // NOTE(review): the extraction order must match the order used when the
     // checkpoint was written (writer not visible in this header)
     if (resumeFromCheckpoint)
     {
         Archive checkpoint(params.checkpointFile, ARCHIVE_READ);
         checkpoint >> mNumPatterns >> mSeed >> mMaxIterations >> mFixedMatrix
             >> mPhase >> mCurrentIteration >> mNumUpdatesA >> mNumUpdatesP
             >> mRng;
         // qualified call: this constructor is running, so the override that
         // would be dispatched to is this class's own
         DenseGapsRunner::readSamplers(checkpoint);
         GapsRng::load(checkpoint);
     }
 
     // sync each sampler against the other (A first, then P)
     mASampler.sync(mPSampler);
     mPSampler.sync(mASampler);
 
     // AP matrix not stored in checkpoint, so rebuild it after a restore
     if (resumeFromCheckpoint)
     {
         mASampler.recalculateAPMatrix();
         mPSampler.recalculateAPMatrix();
     }
 }
 
bd6604c5
 //////////////////// SparseGapsRunner - templated functions ////////////////////
 
86b53c65
 template <class DataType>
bd6604c5
 SparseGapsRunner::SparseGapsRunner(const DataType &data,
 const GapsParameters &params)
     :
 AbstractGapsRunner(params),
 mASampler(data, !params.transposeData, !params.subsetGenes, params.alphaA, params.maxGibbsMassA, params),
 mPSampler(data, params.transposeData, params.subsetGenes, params.alphaP, params.maxGibbsMassP, params)
2cdb0256
 {
bd6604c5
     switch (mFixedMatrix)
     {
         case 'A' : mASampler.setMatrix(params.fixedMatrix); break;
         case 'P' : mPSampler.setMatrix(params.fixedMatrix); break;
         default: break;
     }
 
     // overwrite with info from checkpoint file
     if (params.useCheckPoint)
     {
         Archive ar(params.checkpointFile, ARCHIVE_READ);
         ar >> mNumPatterns >> mSeed >> mMaxIterations >> mFixedMatrix >> mPhase
             >> mCurrentIteration >> mNumUpdatesA >> mNumUpdatesP >> mRng;
         readSamplers(ar);
         GapsRng::load(ar);
     }
 
     mASampler.sync(mPSampler);
     mPSampler.sync(mASampler);
2cdb0256
 }
 
bd6604c5
 #endif // __COGAPS_GAPS_RUNNER_H__