#include "GapsRunner.h"
#include "utils/Archive.h"

#ifdef __GAPS_R_BUILD__
#include <Rcpp.h>
#endif

#ifdef __GAPS_OPENMP__
#include <omp.h>
#endif

#include <cmath>
#include <cstdio>
#include <limits>
#include <string>

// boost time helpers
#include <boost/date_time/posix_time/posix_time.hpp>
namespace bpt = boost::posix_time;
#define bpt_now() bpt::microsec_clock::local_time()

// for conditionally printing status messages
#define GAPS_MESSAGE(b, m) \
    do { \
        if (b) { \
            gaps_printf(m); \
        } \
    } while(0)

// forward declaration
template <class Sampler, class DataType>
static GapsResult runCoGAPSAlgorithm(const DataType &data, GapsParameters &params,
    const DataType &uncertainty, GapsRandomState *randState);

////////////////////////////////////////////////////////////////////////////////

// helper function, this dispatches the correct run function depending
// on the type of GibbsSampler needed for the given parameters
template <class DataType>
static GapsResult run_helper(const DataType &data, GapsParameters &params,
const DataType &uncertainty, GapsRandomState *randState)
{
    // fetch parameters from checkpoint - some are used in initialization
    // (e.g. which sampler type to construct), so they must be read before
    // dispatching; processCheckpoint() later re-reads the full archive
    // NOTE(review): `ar >> params` overwrites params before
    // processCheckpoint() re-tests params.useCheckPoint - confirm the archived
    // parameters preserve useCheckPoint/checkpointFile, otherwise the sampler
    // state would never be restored
    if (params.useCheckPoint)
    {
        Archive ar(params.checkpointFile, ARCHIVE_READ);
        ar >> params;
        ar >> *randState;
    }

    if (params.useSparseOptimization)
    {
        return runCoGAPSAlgorithm<SparseGibbsSampler>(data, params,
            uncertainty, randState);
    }
    return runCoGAPSAlgorithm<DenseGibbsSampler>(data, params,
        uncertainty, randState);
}

// these two functions are the top-level functions exposed to the C++
// code that is being wrapped by any given language

GapsResult gaps::run(const Matrix &data, GapsParameters &params,
const Matrix &uncertainty, GapsRandomState *randState)
{
    return run_helper(data, params, uncertainty, randState);
}

GapsResult gaps::run(const std::string &data, GapsParameters &params,
const std::string &uncertainty, GapsRandomState *randState)
{
    return run_helper(data, params, uncertainty, randState);
}

////////////////////////////////////////////////////////////////////////////////

// sum coef * log(i) for i = 1 to total, fit coef from number of atoms
// approximates sum of number of atoms (stirling approx to factorial)
// this should be proportional to total number of updates
// note: for current <= 1 or total == 0 (or nAtoms == 0) the result can be
// NaN, infinite, zero or negative - callers must not divide by it blindly
static double estimatedNumUpdates(double current, double total, double nAtoms)
{
    double coef = nAtoms / std::log(current);
    return coef * std::log(std::sqrt(2 * total * gaps::pi)) +
        total * coef * std::log(total) - total * coef;
}

// Estimated fraction of the total work (calibration + sampling) completed at
// iteration `iter` of phase `phase` ('C' or 'S'). The result is not
// guaranteed to be finite or positive for the very first iterations.
template <class Sampler>
static double estimatedPercentComplete(const GapsParameters &params,
const Sampler &ASampler, const Sampler &PSampler, char phase, unsigned iter)
{
    double nIter = static_cast<double>(iter);
    double nAtomsA = static_cast<double>(ASampler.nAtoms());
    double nAtomsP = static_cast<double>(PSampler.nAtoms());
    
    // the sampling phase comes after a full calibration phase
    if (phase == 'S')
    {
        nIter += params.nIterations;
    }

    double totalIter = 2.0 * static_cast<double>(params.nIterations);

    double estimatedCompleted = estimatedNumUpdates(nIter, nIter, nAtomsA) +
        estimatedNumUpdates(nIter, nIter, nAtomsP);

    double estimatedTotal = estimatedNumUpdates(nIter, totalIter, nAtomsA) +
        estimatedNumUpdates(nIter, totalIter, nAtomsP);

    return estimatedCompleted / estimatedTotal;
}

namespace
{
    // a duration broken down into hours, minutes and seconds
    struct HoursMinutesSeconds
    {
        unsigned hours;
        unsigned minutes;
        unsigned seconds;
    };
}

// split a whole number of seconds into hours, minutes and seconds
static HoursMinutesSeconds splitSeconds(unsigned totalSeconds)
{
    HoursMinutesSeconds hms;
    hms.hours = totalSeconds / 3600;
    hms.minutes = (totalSeconds % 3600) / 60;
    hms.seconds = totalSeconds % 60;
    return hms;
}

// Convert a number of seconds stored in a double to unsigned. A plain
// static_cast is undefined behaviour for NaN, negative or too-large values,
// so the value is clamped to [0, UINT_MAX] with NaN mapped to 0.
static unsigned secondsToUnsigned(double seconds)
{
    if (!(seconds > 0.0)) // also true for NaN
    {
        return 0;
    }
    const unsigned maxValue = std::numeric_limits<unsigned>::max();
    if (seconds >= static_cast<double>(maxValue))
    {
        return maxValue;
    }
    return static_cast<unsigned>(seconds);
}

// Print a one-line progress report every params.outputFrequency iterations
// (when messages are enabled): iteration, atom counts, chi-squared, and the
// elapsed / estimated total run time.
template <class Sampler>
static void displayStatus(const GapsParameters &params, const Sampler &ASampler,
const Sampler &PSampler, bpt::ptime startTime, char phase, unsigned iter)
{
    if (params.printMessages && params.outputFrequency > 0
    && ((iter + 1) % params.outputFrequency) == 0)
    {
        bpt::time_duration diff = bpt_now() - startTime;
        double perComplete = estimatedPercentComplete(params, ASampler,
            PSampler, phase, iter);

        double nSecondsCurrent = static_cast<double>(diff.total_seconds());

        // the completion estimate can be NaN or non-positive for the first
        // iterations (log(0)/log(1) in estimatedNumUpdates) or with no atoms;
        // fall back to the elapsed time rather than dividing by it
        double nSecondsTotal = (std::isfinite(perComplete) && perComplete > 0.0)
            ? nSecondsCurrent / perComplete
            : nSecondsCurrent;

        HoursMinutesSeconds elapsed = splitSeconds(secondsToUnsigned(nSecondsCurrent));
        HoursMinutesSeconds total = splitSeconds(secondsToUnsigned(nSecondsTotal));

        gaps_printf("%d of %d, Atoms: %lu(%lu), ChiSq: %.0f, Time: %02d:%02d:%02d / %02d:%02d:%02d\n",
            iter + 1, params.nIterations, ASampler.nAtoms(), PSampler.nAtoms(),
            PSampler.chiSq(), elapsed.hours, elapsed.minutes, elapsed.seconds,
            total.hours, total.minutes, total.seconds);
        gaps_flush();
    }
}

// One Gibbs step: update A then sync it into P, then update P and sync it
// back into A. A matrix marked as fixed is neither updated nor synced into.
template <class Sampler>
static void updateSampler(const GapsParameters &params, Sampler &ASampler,
Sampler &PSampler, unsigned nA, unsigned nP)
{
    if (!params.useFixedMatrix || params.whichFixedMatrix != 'A')
    {
        ASampler.update(nA, params.maxThreads);
        if (!params.useFixedMatrix || params.whichFixedMatrix != 'P')
        {
            PSampler.sync(ASampler, params.maxThreads);
        }
    }

    if (!params.useFixedMatrix || params.whichFixedMatrix != 'P')
    {
        PSampler.update(nP, params.maxThreads);
        if (!params.useFixedMatrix || params.whichFixedMatrix != 'A')
        {
            ASampler.sync(PSampler, params.maxThreads);
        }
    }
}

// Write a checkpoint every params.checkpointInterval iterations (never when
// running on a data subset). The previous checkpoint is renamed to a backup
// while the new one is written, then the backup is deleted.
template <class Sampler>
static void createCheckpoint(const GapsParameters &params, Sampler &ASampler,
Sampler &PSampler, const GapsRandomState *randState, const GapsStatistics &stats,
const GapsRng &rng, char phase, unsigned iter)
{
    if (params.checkpointInterval > 0
    && ((iter + 1) % params.checkpointInterval) == 0 && !params.subsetData)
    {
        // create backup file (best effort - there may be no previous file)
        std::rename(params.checkpointOutFile.c_str(),
            (params.checkpointOutFile + ".backup").c_str());

        // create checkpoint file
        Archive ar(params.checkpointOutFile, ARCHIVE_WRITE);
        ar << params;
        ar << *randState;
        ar << ASampler << PSampler << stats << phase << iter << rng;

        // delete backup file
        std::remove((params.checkpointOutFile + ".backup").c_str());

        ASampler.extraInitialization();
        PSampler.extraInitialization();
    }
}

// When resuming from a checkpoint, restore parameters, random state, both
// samplers, statistics, current phase/iteration and the rng from the archive.
template <class Sampler>
static void processCheckpoint(GapsParameters &params, Sampler &ASampler,
Sampler &PSampler, GapsRandomState *randState, GapsStatistics &stats,
GapsRng &rng, char &phase, unsigned &currentIter)
{
    // check if running from checkpoint, get all saved data
    if (params.useCheckPoint)
    {
        Archive ar(params.checkpointFile, ARCHIVE_READ);
        ar >> params;
        ar >> *randState;
        ar >> ASampler >> PSampler >> stats >> phase >> currentIter >> rng;
    }
}

// Run the remaining iterations of one phase ('C' calibration or 'S'
// sampling), starting at currentIter; currentIter is advanced in place so
// that it reflects progress for checkpointing.
template <class Sampler>
static void runOnePhase(const GapsParameters &params, Sampler &ASampler,
Sampler &PSampler, GapsStatistics &stats, const GapsRandomState *randState,
GapsRng &rng, bpt::ptime startTime, char phase, unsigned &currentIter)
{
    for (; currentIter < params.nIterations; ++currentIter)
    {
        #ifdef __GAPS_R_BUILD__
        Rcpp::checkUserInterrupt();
        #endif
        
        createCheckpoint(params, ASampler, PSampler, randState, stats, rng,
            phase, currentIter);

        // set annealing temperature in calibration phase
        if (phase == 'C')
        {        
            float temp = static_cast<float>(2 * currentIter)
                / static_cast<float>(params.nIterations);
            ASampler.setAnnealingTemp(gaps::min(1.f, temp));
            PSampler.setAnnealingTemp(gaps::min(1.f, temp));
        }
    
        // number of updates per iteration is poisson 
        unsigned nA = rng.poisson(gaps::max(ASampler.nAtoms(), 10));
        unsigned nP = rng.poisson(gaps::max(PSampler.nAtoms(), 10));
        updateSampler(params, ASampler, PSampler, nA, nP);

        if (phase == 'S')
        {
            stats.update(ASampler, PSampler);
        }
        displayStatus(params, ASampler, PSampler, startTime, phase, currentIter);
    }
}

// If a matrix is being held fixed, validate its dimensions and load it into
// the corresponding sampler.
template <class Sampler>
static void processFixedMatrix(const GapsParameters &params, Sampler &ASampler,
Sampler &PSampler)
{
    // check if we're fixing a matrix
    if (params.useFixedMatrix)
    {
        // the dimension checks must live inside the cases: statements placed
        // before the first case label of a switch are never executed
        switch (params.whichFixedMatrix)
        {
            case 'A' :
                GAPS_ASSERT(params.fixedMatrix.nCol() == params.nPatterns);
                GAPS_ASSERT(params.fixedMatrix.nRow() == params.nGenes);
                ASampler.setMatrix(params.fixedMatrix);
                break;
            case 'P' :
                GAPS_ASSERT(params.fixedMatrix.nCol() == params.nPatterns);
                GAPS_ASSERT(params.fixedMatrix.nRow() == params.nSamples);
                PSampler.setMatrix(params.fixedMatrix);
                break;
            default: break; // 'N' for none
        }
    }
}

// Clamp params.maxThreads to the number of threads OpenMP reports as
// available (only when compiled with OpenMP) and optionally report it.
// params is taken by reference so the clamped value is actually used by
// the subsequent sampler updates.
static void calculateNumberOfThreads(GapsParameters &params)
{
    // calculate appropriate number of threads if compiled with openmp
    #ifdef __GAPS_OPENMP__
    unsigned availableThreads = omp_get_max_threads();
    params.maxThreads = gaps::min(availableThreads, params.maxThreads);
    if (params.printMessages && params.printThreadUsage)
    {
        gaps_printf("Running on %d out of %d available threads\n",
            params.maxThreads, availableThreads);
    }
    #else
    static_cast<void>(params);
    #endif
}

// read in the uncertainty matrix if one is provided; the A sampler sees the
// data transposed, so its transpose/subset flags are flipped
template <class Sampler, class DataType>
static void processUncertainty(const GapsParameters &params, Sampler &ASampler,
Sampler &PSampler, const DataType &uncertainty)
{
    if (!uncertainty.empty())
    {
        ASampler.setUncertainty(uncertainty, !params.transposeData,
            !params.subsetGenes, params);
        PSampler.setUncertainty(uncertainty, params.transposeData,
            params.subsetGenes, params);
    }
}

// here is the CoGAPS algorithm
template <class Sampler, class DataType>
static GapsResult runCoGAPSAlgorithm(const DataType &data, GapsParameters &params,
const DataType &uncertainty, GapsRandomState *randState)
{
    // check if running in debug mode
    #ifdef GAPS_DEBUG
    GAPS_MESSAGE(params.printMessages, "Running in debug mode\n");
    #endif

    // load data into gibbs samplers
    // we transpose the data in the A sampler so that the update step
    // is symmetrical for each sampler, this simplifies the code
    // within the sampler, note the subsetting genes/samples flag must be
    // flipped if we are flipping the transpose flag
    GAPS_MESSAGE(params.printMessages, "Loading Data...");
    Sampler ASampler(data, !params.transposeData, !params.subsetGenes,
        params.alphaA, params.maxGibbsMassA, params, randState);
    Sampler PSampler(data, params.transposeData, params.subsetGenes,
        params.alphaP, params.maxGibbsMassP, params, randState);

    processUncertainty(params, ASampler, PSampler, uncertainty);
    processFixedMatrix(params, ASampler, PSampler);
    GAPS_MESSAGE(params.printMessages, "Done!\n");

    // these variables will get overwritten by checkpoint if provided
    GapsStatistics stats(params.nGenes, params.nSamples, params.nPatterns);
    GapsRng rng(randState);
    char phase = 'C';
    unsigned currentIter = 0;
    processCheckpoint(params, ASampler, PSampler, randState, stats, rng,
        phase, currentIter);
    calculateNumberOfThreads(params);

    // sync samplers and run any additional initialization needed
    ASampler.sync(PSampler);
    PSampler.sync(ASampler);
    ASampler.extraInitialization();
    PSampler.extraInitialization();

    // record start time
    bpt::ptime startTime = bpt_now();

    // fallthrough through phases, allows algorithm to be resumed in any phase
    GAPS_ASSERT(phase == 'C' || phase == 'S');
    switch (phase)
    {
        case 'C':
            GAPS_MESSAGE(params.printMessages, "-- Calibration Phase --\n");
            runOnePhase(params, ASampler, PSampler, stats, randState, rng,
                startTime, phase, currentIter);
            phase = 'S';
            currentIter = 0;
            // fall through - calibration is always followed by sampling

        case 'S':
            GAPS_MESSAGE(params.printMessages, "-- Sampling Phase --\n");
            runOnePhase(params, ASampler, PSampler, stats, randState, rng,
                startTime, phase, currentIter);
    }

    // get result
    GapsResult result(stats);
    result.meanChiSq = stats.meanChiSq(PSampler);
    return result;
}