Browse code

new R Interface working

Tom Sherman authored on 02/07/2018 00:58:55
Showing 24 changed files

... ...
@@ -51,7 +51,6 @@ Collate:
51 51
     'CoGAPS.R'
52 52
     'GWCoGAPS.R'
53 53
     'RcppExports.R'
54
-    'SMatrix.R'
55 54
     'binaryA.R'
56 55
     'calcCoGAPSStat.R'
57 56
     'calcGeneGSStat.R'
... ...
@@ -1,6 +1,6 @@
1 1
 # Generated by roxygen2: do not edit by hand
2 2
 
3
-export()
3
+export(CoGAPS)
4 4
 export(GWCoGAPS)
5 5
 export(GWCoGapsFromCheckpoint)
6 6
 export(MergeResultsWithSCE)
... ...
@@ -1,10 +1,11 @@
1 1
 #' @include class-CogapsParams.R
2 2
 NULL
3 3
 
4
-#' CoGAPS
5
-#' @name CoGAPS Matrix Factorization Algorithm
4
+#' CoGAPS Matrix Factorization Algorithm
5
+#' @export 
6 6
 #' @docType methods
7 7
 #' @rdname CoGAPS-methods
8
+#'
8 9
 #' @description calls the C++ MCMC code and performs Bayesian
9 10
 #' matrix factorization returning the two matrices that reconstruct
10 11
 #' the data matrix
... ...
@@ -15,34 +16,52 @@ NULL
15 16
 #' @param uncertainty uncertainty matrix (same supported types as data)
16 17
 #' @param fixedMatrix data for fixing the values of either the A or P matrix;
17 18
 #'  used in conjunction with whichMatrixFixed (see CogapsParams)
18
-#' @param checkpointFile name of the checkpoint file
19
+#' @param checkpointInFile name of the checkpoint file
19 20
 #' @param ... keeps backwards compatibility with arguments from older versions
20 21
 #' @return CogapsResult object
21 22
 #' @examples
22 23
 #' # Running from R object
23 24
 #' data(GIST)
24 25
 #' resultA <- CoGAPS(GIST.D)
26
+#'
25 27
 #' # Running from file name
26 28
 #' gist_path <- system.file("extdata/GIST.mtx", package="CoGAPS")
27 29
 #' resultB <- CoGAPS(gist_path)
30
+#'
28 31
 #' # Setting Parameters
29 32
 #' params <- new("CogapsParams")
30 33
 #' params <- setParam(params, "nPatterns", 5)
31 34
 #' resultC <- CoGAPS(GIST.D, params)
32 35
 #' @importFrom methods new
33
-#' @export
34 36
 setGeneric("CoGAPS", function(data, params=new("CogapsParams"),
35
-uncertainty=NULL, fixedMatrix=NULL, checkpointFile=NULL, ...)
37
+uncertainty=NULL, fixedMatrix=matrix(0), checkpointInFile="", ...)
36 38
 {
37 39
     # parse parameters from ...
38 40
     params <- parseOldParams(params, list(...))
39 41
     params <- parseDirectParams(params, list(...))
40 42
 
43
+    # check that if fixedMatrix is set, whichMatrixFixed is set as well
44
+    if (!all(fixedMatrix == matrix(0)) & (!params@whichMatrixFixed %in% c('A', 'P')))
45
+        stop("fixedMatrix passed, but whichMatrixFixed not set") 
46
+
47
+    # check if uncertainty is null
48
+    if (is.null(uncertainty) & class(data) == "character")
49
+        uncertainty <- ""
50
+    else if (is.null(uncertainty))
51
+        uncertainty <- matrix(0)
52
+
53
+    # check if we're running from a checkpoint
54
+    if (nchar(checkpointInFile) > 0)
55
+    {
56
+        message(paste("Running CoGAPS from a checkpoint, all parameters",
57
+            "besides data and uncertainty will be ignored"))
58
+    }
59
+
41 60
     # call method
42 61
     gapsReturnList <- standardGeneric("CoGAPS")
43 62
 
44 63
     # convert list to CogapsResult object
45
-    return(CogapsResult(
64
+    return(new("CogapsResult",
46 65
         Amean       = gapsReturnList$Amean,
47 66
         Asd         = gapsReturnList$Asd,
48 67
         Pmean       = gapsReturnList$Pmean,
... ...
@@ -50,114 +69,91 @@ uncertainty=NULL, fixedMatrix=NULL, checkpointFile=NULL, ...)
50 69
         seed        = gapsReturnList$seed,
51 70
         meanChiSq   = gapsReturnList$meanChiSq,
52 71
         diagnostics = gapsReturnList$diagnostics
53
-    )) 
72
+    ))
54 73
 })
55 74
 
56 75
 #' @rdname CoGAPS-methods
57 76
 #' @aliases CoGAPS
58 77
 #' @importFrom tools file_ext
59 78
 setMethod("CoGAPS", signature(data="character", params="CogapsParams"),
60
-function(data, params, uncertainty, fixedMatrix, checkpointFile, ...)
79
+function(data, params, uncertainty, fixedMatrix, checkpointInFile, ...)
61 80
 {
62 81
     # check file extension
63 82
     if (!(file_ext(data) %in% c("tsv", "csv", "mtx")))
64 83
         stop("unsupported file extension for data")
65 84
 
66 85
     # check uncertainty matrix
67
-    if (!is.null(uncertainty))
68
-    {
69
-        if (class(uncertainty) != "character")
70
-            stop("uncertainty must be same data type as data (file name)")
71
-        if (!(file_ext(uncertainty) %in% c("tsv", "csv", "mtx")))
72
-            stop("unsupported file extension for uncertainty")
73
-    }
86
+    if (class(uncertainty) != "character")
87
+        stop("uncertainty must be same data type as data (file name)")
88
+    if (nchar(uncertainty) > 0 & !(file_ext(uncertainty) %in% c("tsv", "csv", "mtx")))
89
+        stop("unsupported file extension for uncertainty")
74 90
 
75 91
     # call C++ function
76
-    cogaps_cpp_from_file(data, uncertainty, params@nPatterns,
77
-        params@nIterations, params@outputFrequency, params@seed, params@alphaA,
78
-        params@alphaP, params@maxGibbsMassA, params@maxGibbsMassP,
79
-        params@messages, params@singleCell, params@checkpointOutFile,
80
-        params@nCores)
92
+    cogaps_cpp_from_file(data, params, uncertainty, fixedMatrix, checkpointInFile)
81 93
 })
82 94
 
83 95
 #' @rdname CoGAPS-methods
84 96
 #' @aliases CoGAPS
85 97
 setMethod("CoGAPS", signature(data="matrix", params="CogapsParams"),
86
-function(data, params, uncertainty, fixedMatrix, checkpointFile, ...)
98
+function(data, params, uncertainty, fixedMatrix, checkpointInFile, ...)
87 99
 {
88 100
     # check matrix
89
-    if (!is.null(uncertainty) & class(uncertainty) != "matrix")
101
+    if (class(uncertainty) != "matrix")
90 102
         stop("uncertainty must be same data type as data (matrix)")
91
-    checkDataMatrix(data, uncertainty)
103
+    checkDataMatrix(data, uncertainty, params)
92 104
 
93 105
     # call C++ function
94
-    cogaps_cpp(data, uncertainty, params@nPatterns,
95
-        params@nIterations, params@outputFrequency, params@seed, params@alphaA,
96
-        params@alphaP, params@maxGibbsMassA, params@maxGibbsMassP,
97
-        params@messages, params@singleCell, params@checkpointOutFile,
98
-        params@nCores)
106
+    cogaps_cpp(data, params, uncertainty, fixedMatrix, checkpointInFile)
99 107
 })
100 108
 
101 109
 #' @rdname CoGAPS-methods
102 110
 #' @aliases CoGAPS
103 111
 setMethod("CoGAPS", signature(data="data.frame", params="CogapsParams"),
104
-function(data, params, uncertainty, fixedMatrix, checkpointFile, ...)
112
+function(data, params, uncertainty, fixedMatrix, checkpointInFile, ...)
105 113
 {
106 114
     # check matrix
107
-    if (!is.null(uncertainty) & class(uncertainty) != "matrix")
115
+    if (class(uncertainty) != "matrix")
108 116
         stop("uncertainty must be matrix when data is a data.frame")
109
-    checkDataMatrix(data.matrix(data), uncertainty)
117
+    checkDataMatrix(data.matrix(data), uncertainty, params)
110 118
 
111 119
     # call C++ function
112
-    cogaps_cpp(data.matrix(data), uncertainty, params@nPatterns,
113
-        params@nIterations, params@outputFrequency, params@seed, params@alphaA,
114
-        params@alphaP, params@maxGibbsMassA, params@maxGibbsMassP,
115
-        params@messages, params@singleCell, params@checkpointOutFile,
116
-        params@nCores)
120
+    cogaps_cpp(data.matrix(data), params, uncertainty, fixedMatrix, checkpointInFile)
117 121
 })
118 122
 
119 123
 #' @rdname CoGAPS-methods
120 124
 #' @aliases CoGAPS
121 125
 #' @importClassesFrom SummarizedExperiment SummarizedExperiment
122 126
 setMethod("CoGAPS", signature(data="SummarizedExperiment", params="CogapsParams"),
123
-function(data, params, uncertainty, fixedMatrix, checkpointFile, ...)
127
+function(data, params, uncertainty, fixedMatrix, checkpointInFile, ...)
124 128
 {
125 129
     # extract count matrix
126 130
     countMatrix = assay(data, "counts")
127 131
 
128 132
     # check matrix
129
-    if (!is.null(uncertainty) & class(uncertainty) != "matrix")
133
+    if (class(uncertainty) != "matrix")
130 134
         stop("uncertainty must be matrix when data is a SummarizedExperiment")
131
-    checkDataMatrix(countMatrix, uncertainty)
135
+    checkDataMatrix(countMatrix, uncertainty, params)
132 136
 
133 137
     # call C++ function
134
-    cogaps_cpp(countMatrix, uncertainty, params@nPatterns,
135
-        params@nIterations, params@outputFrequency, params@seed, params@alphaA,
136
-        params@alphaP, params@maxGibbsMassA, params@maxGibbsMassP,
137
-        params@messages, params@singleCell, params@checkpointOutFile,
138
-        params@nCores)
138
+    cogaps_cpp(countMatrix, params, uncertainty, fixedMatrix, checkpointInFile)
139 139
 })
140 140
 
141 141
 #' @rdname CoGAPS-methods
142 142
 #' @aliases CoGAPS
143 143
 #' @importClassesFrom SingleCellExperiment SingleCellExperiment
144 144
 setMethod("CoGAPS", signature(data="SingleCellExperiment", params="CogapsParams"),
145
-function(data, params, uncertainty, fixedMatrix, checkpointFile, ...)
145
+function(data, params, uncertainty, fixedMatrix, checkpointInFile, ...)
146 146
 {
147 147
     # extract count matrix
148 148
     countMatrix = assay(data, "counts")
149 149
 
150 150
     # check matrix
151
-    if (!is.null(uncertainty) & class(uncertainty) != "matrix")
151
+    if (class(uncertainty) != "matrix")
152 152
         stop("uncertainty must be matrix when data is a SingleCellExperiment")
153
-    checkDataMatrix(countMatrix, uncertainty)
153
+    checkDataMatrix(countMatrix, uncertainty, params)
154 154
 
155 155
     # call C++ function
156
-    cogaps_cpp(countMatrix, uncertainty, params@nPatterns,
157
-        params@nIterations, params@outputFrequency, params@seed, params@alphaA,
158
-        params@alphaP, params@maxGibbsMassA, params@maxGibbsMassP,
159
-        params@messages, params@singleCell, params@checkpointOutFile,
160
-        params@nCores)
156
+    cogaps_cpp(countMatrix, params, uncertainty, fixedMatrix, checkpointInFile)
161 157
 })
162 158
 
163 159
 #' Information About Package Compilation
... ...
@@ -171,4 +167,17 @@ function(data, params, uncertainty, fixedMatrix, checkpointFile, ...)
171 167
 buildReport <- function()
172 168
 {
173 169
     getBuildReport_cpp()
170
+}
171
+
172
+#' Check that provided data is valid
173
+#'
174
+#' @param data data matrix
175
+#' @param uncertainty uncertainty matrix
176
+#' @return throws an error if data has problems
177
+checkDataMatrix <- function(data, uncertainty, params)
178
+{
179
+    if (sum(data < 0) > 0 | sum(uncertainty < 0) > 0)
180
+        stop("negative values in data and/or uncertainty matrix")
181
+    if (nrow(data) == params@nPatterns || ncol(data) == params@nPatterns)
182
+        stop("nPatterns must be less than dimensions of data")
174 183
 }
175 184
\ No newline at end of file
... ...
@@ -149,7 +149,7 @@ runInitialPhase <- function(simulationName, allDataSets, nFactor, ...)
149 149
 
150 150
         # run CoGAPS without any fixed patterns
151 151
         cptFileName <- paste(simulationName, "_initial_cpt_", i, ".out", sep="")
152
-        CoGAPS(sampleD, sampleS, nFactor=nFactor, seed=nut[i],
152
+        CoGAPS(sampleD, uncertainty=sampleS, nFactor=nFactor, seed=nut[i],
153 153
             checkpointFile=cptFileName, ...)
154 154
     }
155 155
     return(initialResult)
... ...
@@ -202,7 +202,7 @@ runFinalPhase <- function(simulationName, allDataSets, consensusPatterns, nCores
202 202
 
203 203
         # run CoGAPS with fixed patterns
204 204
         cptFileName <- paste(simulationName, "_final_cpt_", i, ".out", sep="")
205
-        CoGAPS(sampleD, sampleS, fixedPatterns=consensusPatterns,
205
+        CoGAPS(sampleD, uncertainty=sampleS, fixedMatrix=consensusPatterns,
206 206
             nFactor=nFactorFinal, seed=nut[i], checkpointFile=cptFileName,
207 207
             whichMatrixFixed='P', ...)
208 208
 
... ...
@@ -1,12 +1,12 @@
1 1
 # Generated by using Rcpp::compileAttributes() -> do not edit by hand
2 2
 # Generator token: 10BE3573-1514-4C36-9D1C-5A225CD40393
3 3
 
4
-cogaps_cpp_from_file <- function(data, unc, nPatterns, maxIterations, outputFrequency, seed, alphaA, alphaP, maxGibbsMassA, maxGibbsMassP, messages, singleCell, checkpointOutFile, nCores) {
5
-    .Call('_CoGAPS_cogaps_cpp_from_file', PACKAGE = 'CoGAPS', data, unc, nPatterns, maxIterations, outputFrequency, seed, alphaA, alphaP, maxGibbsMassA, maxGibbsMassP, messages, singleCell, checkpointOutFile, nCores)
4
+cogaps_cpp_from_file <- function(data, params, unc, fixedMatrix, checkpointInFile) {
5
+    .Call('_CoGAPS_cogaps_cpp_from_file', PACKAGE = 'CoGAPS', data, params, unc, fixedMatrix, checkpointInFile)
6 6
 }
7 7
 
8
-cogaps_cpp <- function(data, unc, nPatterns, maxIterations, outputFrequency, seed, alphaA, alphaP, maxGibbsMassA, maxGibbsMassP, messages, singleCell, checkpointOutFile, nCores) {
9
-    .Call('_CoGAPS_cogaps_cpp', PACKAGE = 'CoGAPS', data, unc, nPatterns, maxIterations, outputFrequency, seed, alphaA, alphaP, maxGibbsMassA, maxGibbsMassP, messages, singleCell, checkpointOutFile, nCores)
8
+cogaps_cpp <- function(data, params, unc, fixedMatrix, checkpointInFile) {
9
+    .Call('_CoGAPS_cogaps_cpp', PACKAGE = 'CoGAPS', data, params, unc, fixedMatrix, checkpointInFile)
10 10
 }
11 11
 
12 12
 getBuildReport_cpp <- function() {
13 13
deleted file mode 100644
... ...
@@ -1,75 +0,0 @@
1
-CoGAPS <- function(D, S=NULL, nFactor=7, nEquil=1000, nSample=1000, nOutputs=1000,
2
-                   nSnapshots=0, alphaA=0.01, alphaP=0.01, maxGibbmassA=100, maxGibbmassP=100,
3
-                   seed=-1, messages=TRUE, singleCellRNASeq=FALSE, whichMatrixFixed='N',
4
-                   fixedPatterns=matrix(0), checkpointInterval=0, 
5
-                   checkpointFile="gaps_checkpoint.out", nCores=1, ...)
6
-{
7
-  #returns a default uncertainty matrix of 0.1*D dataset if it's greater than 0.1
8
-  if(!is.null(S))
9
-    S <- pmax(0.1*D, 0.1)
10
-  
11
-  # get v2 arguments
12
-  oldArgs <- list(...)
13
-  if (!is.null(oldArgs$nOutR))
14
-    nOutputs <- oldArgs$nOutR
15
-  if (!is.null(oldArgs$max_gibbmass_paraA))
16
-    maxGibbmassA <- oldArgs$max_gibbmass_paraA
17
-  if (!is.null(oldArgs$max_gibbmass_paraP))
18
-    maxGibbmassP <- oldArgs$max_gibbmass_paraP
19
-  if (!is.null(oldArgs$sampleSnapshots) & is.null(oldArgs$numSnapshots))
20
-    nSnapshots <- 100
21
-  if (!is.null(oldArgs$sampleSnapshots) & !is.null(oldArgs$numSnapshots))
22
-    nSnapshots <- oldArgs$numSnapshots
23
-  if (missing(D) & !is.null(oldArgs$data))
24
-    D <- oldArgs$data
25
-  if (missing(S) & !is.null(oldArgs$unc))
26
-    S <- oldArgs$unc
27
-  
28
-  # get pump arguments - hidden for now from user
29
-  pumpThreshold <- "unique"
30
-  nPumpSamples <- 0
31
-  if (!is.null(list(...)$pumpThreshold))
32
-    pumpThreshold <- list(...)$pumpThreshold
33
-  if (!is.null(list(...)$nPumpSamples))
34
-    pumpThreshold <- list(...)$nPumpSamples
35
-  
36
-  # check arguments
37
-  if (class(D) != "matrix" | class(S) != "matrix")
38
-    stop('D and S must be matrices')
39
-  if (any(D < 0) | any(S < 0))
40
-    stop('D and S matrix must be non-negative')
41
-  if (nrow(D) != nrow(S) | ncol(D) != ncol(S))
42
-    stop('D and S matrix have different dimensions')
43
-  if (whichMatrixFixed == 'A' & nrow(fixedPatterns) != nrow(D))
44
-    stop('invalid number of rows for fixedPatterns')
45
-  if (whichMatrixFixed == 'A' & ncol(fixedPatterns) > nFactor)
46
-    stop('invalid number of columns for fixedPatterns')
47
-  if (whichMatrixFixed == 'P' & nrow(fixedPatterns) > nFactor)
48
-    stop('invalid number of rows for fixedPatterns')
49
-  if (whichMatrixFixed == 'P' & ncol(fixedPatterns) != ncol(D))
50
-    stop('invalid number of columns for fixedPatterns')
51
-  thresholdEnum <- c("unique", "cut")
52
-  
53
-  # get seed
54
-  if (seed < 0)
55
-  {
56
-    # TODO get time in milliseconds
57
-    seed <- 0
58
-  }
59
-  
60
-  # run algorithm with call to C++ code
61
-  result <- cogaps_cpp(D, S, nFactor, nEquil, nEquil/10, nSample, nOutputs,
62
-                       nSnapshots, alphaA, alphaP, maxGibbmassA, maxGibbmassP, seed, messages,
63
-                       singleCellRNASeq, whichMatrixFixed, fixedPatterns, checkpointInterval,
64
-                       checkpointFile, which(thresholdEnum==pumpThreshold), nPumpSamples,
65
-                       nCores)
66
-  
67
-  # label matrices and return list
68
-  patternNames <- paste('Patt', 1:nFactor, sep='')
69
-  rownames(result$Amean) <- rownames(result$Asd) <- rownames(D)
70
-  colnames(result$Amean) <- colnames(result$Asd) <- patternNames
71
-  rownames(result$Pmean) <- rownames(result$Psd) <- patternNames
72
-  colnames(result$Pmean) <- colnames(result$Psd) <- colnames(D)
73
-  return(v2CoGAPS(result, ...)) # backwards compatible with v2
74
-}
75
-  
... ...
@@ -12,7 +12,7 @@ setClass("CogapsParams", slots = c(
12 12
     maxGibbsMassP = "numeric",
13 13
     seed = "numeric",
14 14
     messages = "logical",
15
-    singleCellRNASeq = "logical",
15
+    singleCell = "logical",
16 16
     whichMatrixFixed = "character",
17 17
     checkpointInterval = "numeric",
18 18
     checkpointOutFile = "character", 
... ...
@@ -36,7 +36,7 @@ setMethod("initialize", "CogapsParams",
36 36
         .Object@maxGibbsMassP <- 100
37 37
         .Object@seed <- getMilliseconds(as.POSIXlt(Sys.time()))
38 38
         .Object@messages <- TRUE
39
-        .Object@singleCellRNASeq <- FALSE
39
+        .Object@singleCell <- FALSE
40 40
         .Object@whichMatrixFixed <- "N"
41 41
         .Object@checkpointInterval <- 0
42 42
         .Object@checkpointOutFile <- "gaps_checkpoint.out"
... ...
@@ -125,9 +125,9 @@ setGeneric("parseDirectParams", function(object, args)
125 125
 setMethod("setParam", signature(object="CogapsParams"),
126 126
     function(object, whichParam, value)
127 127
     {
128
-        slot(params, whichParam) <- value
129
-        validObject(params)
130
-        return(params)
128
+        slot(object, whichParam) <- value
129
+        validObject(object)
130
+        return(object)
131 131
     }
132 132
 )
133 133
 
... ...
@@ -136,7 +136,7 @@ setMethod("setParam", signature(object="CogapsParams"),
136 136
 setMethod("getParam", signature(object="CogapsParams"),
137 137
     function(object, whichParam)
138 138
     {
139
-        slot(params, whichParam)
139
+        slot(object, whichParam)
140 140
     }
141 141
 )
142 142
 
... ...
@@ -145,25 +145,40 @@ setMethod("getParam", signature(object="CogapsParams"),
145 145
 setMethod("parseOldParams", signature(object="CogapsParams"),
146 146
     function(object, oldArgs)
147 147
     {
148
-        if (!is.null(oldArgs$nFactor))
149
-            params@nPatterns <- oldArgs$nFactor
150
-        if (!is.null(oldArgs$nEquil))
151
-            params@nIterations <- oldArgs$nEquil
152
-        if (!is.null(oldArgs$nSample))
153
-            params@nIterations <- oldArgs$nSample
154
-        if (!is.null(oldArgs$nOutputs))
155
-            params@outputFrequency <- nOutputs
156
-        if (!is.null(oldArgs$maxGibbmassA))
157
-            params@maxGibbsMassA <- oldArgs$maxGibbmassA
158
-        if (!is.null(oldArgs$maxGibbmassP))
159
-            params@maxGibbsMassA <- oldArgs$maxGibbmassP
148
+        helper <- function(arg, params, newArg)
149
+        {
150
+            if (!is.null(oldArgs[[arg]]))
151
+            {
152
+                warning(paste("parameter", arg, "is deprecated, it will still",
153
+                    "work, but setting", newArg, "in the params object is the",
154
+                    "preferred method"))
155
+                params <- setParam(params, newArg, oldArgs[[arg]])
156
+                oldArgs[[arg]] <- NULL
157
+            }            
158
+            return(params)
159
+        }
160 160
 
161
-        if (!is.null(oldArgs$nSnapshots))
161
+        object <- helper("nFactor", object, "nPatterns")
162
+        object <- helper("nIter", object, "nIterations")
163
+        object <- helper("nEquil", object, "nIterations")
164
+        object <- helper("nSample", object, "nIterations")
165
+        object <- helper("nOutR", object, "outputFrequency")
166
+        object <- helper("nOutput", object, "outputFrequency")
167
+        object <- helper("maxGibbmassA", object, "maxGibbsMassA")
168
+        object <- helper("max_gibbmass_paraA", object, "maxGibbsMassA")
169
+        object <- helper("maxGibbmassP", object, "maxGibbsMassP")
170
+        object <- helper("max_gibbmass_paraP", object, "maxGibbsMassP")
171
+        object <- helper("checkpointFile", object, "checkpointOutFile")
172
+        object <- helper("singleCellRNASeq", object, "singleCell")
173
+
174
+        if (!is.null(oldArgs$nSnapshots) | !is.null(oldArgs$sampleSnapshots) | !is.null(oldArgs$numSnapshots))
162 175
             warning("snapshots not currently supported in release build")
163 176
         if (!is.null(oldArgs$fixedPatterns))
164 177
             stop("pass fixed matrix in with 'fixedMatrix' argument")
165
-        
166
-        return(params)
178
+        if (!is.null(oldArgs$S))
179
+            stop("pass uncertainty matrix in with 'uncertainty', not 'S'")
180
+
181
+        return(object)
167 182
     }
168 183
 )
169 184
 
... ...
@@ -174,11 +189,11 @@ setMethod("parseDirectParams", signature(object="CogapsParams"),
174 189
     {
175 190
         for (s in slotNames(object))
176 191
         {
177
-            if (!is.null(args[s]))
192
+            if (!is.null(args[[s]]))
178 193
             {
179
-                params <- setParam(object, s, args[s])
194
+                object <- setParam(object, s, args[[s]])
180 195
             }
181 196
         }
182
-        return(params)
197
+        return(object)
183 198
     }
184 199
 )
185 200
\ No newline at end of file
... ...
@@ -188,7 +188,7 @@ sc_runFinalPhase <- function(simulationName, allDataSets, consensusAs, nCores, .
188 188
 
189 189
         # run CoGAPS with fixed patterns
190 190
         cptFileName <- paste(simulationName, "_final_cpt_", i, ".out", sep="")
191
-        CoGAPS(sampleD, sampleS, fixedPatterns=consensusAs,
191
+        CoGAPS(sampleD, uncertainty=sampleS, fixedMatrix=consensusAs,
192 192
             nFactor=nFactorFinal, seed=nut[i], checkpointFile=cptFileName,
193 193
             whichMatrixFixed='A', singleCellRNASeq=TRUE, ...)
194 194
 
... ...
@@ -1,33 +1,34 @@
1 1
 % Generated by roxygen2: do not edit by hand
2 2
 % Please edit documentation in R/CoGAPS.R
3 3
 \docType{methods}
4
-\name{CoGAPS Matrix Factorization Algorithm}
4
+\name{CoGAPS}
5 5
 \alias{CoGAPS}
6
-\alias{CoGAPS Matrix Factorization Algorithm}
7 6
 \alias{CoGAPS,SingleCellExperiment,CogapsParams-method}
8 7
 \alias{CoGAPS,SummarizedExperiment,CogapsParams-method}
9 8
 \alias{CoGAPS,character,CogapsParams-method}
10 9
 \alias{CoGAPS,data.frame,CogapsParams-method}
11 10
 \alias{CoGAPS,matrix,CogapsParams-method}
12
-\title{CoGAPS}
13
-\format{An object of class \code{NULL} of length 0.}
11
+\title{CoGAPS Matrix Factorization Algorithm}
14 12
 \usage{
13
+CoGAPS(data, params = new("CogapsParams"), uncertainty = NULL,
14
+  fixedMatrix = matrix(0), checkpointInFile = "", ...)
15
+
15 16
 \S4method{CoGAPS}{character,CogapsParams}(data, params = new("CogapsParams"),
16
-  uncertainty = NULL, fixedMatrix = NULL, checkpointFile = NULL, ...)
17
+  uncertainty = NULL, fixedMatrix = matrix(0), checkpointInFile = "", ...)
17 18
 
18 19
 \S4method{CoGAPS}{matrix,CogapsParams}(data, params = new("CogapsParams"),
19
-  uncertainty = NULL, fixedMatrix = NULL, checkpointFile = NULL, ...)
20
+  uncertainty = NULL, fixedMatrix = matrix(0), checkpointInFile = "", ...)
20 21
 
21 22
 \S4method{CoGAPS}{data.frame,CogapsParams}(data, params = new("CogapsParams"),
22
-  uncertainty = NULL, fixedMatrix = NULL, checkpointFile = NULL, ...)
23
+  uncertainty = NULL, fixedMatrix = matrix(0), checkpointInFile = "", ...)
23 24
 
24 25
 \S4method{CoGAPS}{SummarizedExperiment,CogapsParams}(data,
25
-  params = new("CogapsParams"), uncertainty = NULL, fixedMatrix = NULL,
26
-  checkpointFile = NULL, ...)
26
+  params = new("CogapsParams"), uncertainty = NULL,
27
+  fixedMatrix = matrix(0), checkpointInFile = "", ...)
27 28
 
28 29
 \S4method{CoGAPS}{SingleCellExperiment,CogapsParams}(data,
29
-  params = new("CogapsParams"), uncertainty = NULL, fixedMatrix = NULL,
30
-  checkpointFile = NULL, ...)
30
+  params = new("CogapsParams"), uncertainty = NULL,
31
+  fixedMatrix = matrix(0), checkpointInFile = "", ...)
31 32
 }
32 33
 \arguments{
33 34
 \item{data}{File name or R object (see details for supported types)}
... ...
@@ -39,7 +40,7 @@
39 40
 \item{fixedMatrix}{data for fixing the values of either the A or P matrix;
40 41
 used in conjunction with whichMatrixFixed (see CogapsParams)}
41 42
 
42
-\item{checkpointFile}{name of the checkpoint file}
43
+\item{checkpointInFile}{name of the checkpoint file}
43 44
 
44 45
 \item{...}{keeps backwards compatibility with arguments from older versions}
45 46
 }
... ...
@@ -59,13 +60,14 @@ Currently, raw count matrices are the only supported R object. For
59 60
 # Running from R object
60 61
 data(GIST)
61 62
 resultA <- CoGAPS(GIST.D)
63
+
62 64
 # Running from file name
63 65
 gist_path <- system.file("extdata/GIST.mtx", package="CoGAPS")
64 66
 resultB <- CoGAPS(gist_path)
67
+
65 68
 # Setting Parameters
66 69
 params <- new("CogapsParams")
67 70
 params <- setParam(params, "nPatterns", 5)
68 71
 resultC <- CoGAPS(GIST.D, params)
69 72
 }
70
-\keyword{datasets}
71 73
 
72 74
new file mode 100644
... ...
@@ -0,0 +1,20 @@
1
+% Generated by roxygen2: do not edit by hand
2
+% Please edit documentation in R/CoGAPS.R
3
+\name{checkDataMatrix}
4
+\alias{checkDataMatrix}
5
+\title{Check that provided data is valid}
6
+\usage{
7
+checkDataMatrix(data, uncertainty, params)
8
+}
9
+\arguments{
10
+\item{data}{data matrix}
11
+
12
+\item{uncertainty}{uncertainty matrix}
13
+}
14
+\value{
15
+throws an error if data has problems
16
+}
17
+\description{
18
+Check that provided data is valid
19
+}
20
+
... ...
@@ -59,6 +59,7 @@ public:
59 59
     // don't have C++11 and don't want to add another dependency on boost,
60 60
     // so no template tricks
61 61
 
62
+    friend Archive& operator<<(Archive &ar, char val)     { return writeToArchive(ar, val); }
62 63
     friend Archive& operator<<(Archive &ar, bool val)     { return writeToArchive(ar, val); }
63 64
     friend Archive& operator<<(Archive &ar, int val)      { return writeToArchive(ar, val); }
64 65
     friend Archive& operator<<(Archive &ar, unsigned val) { return writeToArchive(ar, val); }
... ...
@@ -68,6 +69,7 @@ public:
68 69
     friend Archive& operator<<(Archive &ar, double val)   { return writeToArchive(ar, val); }
69 70
     friend Archive& operator<<(Archive &ar, boost::random::mt11213b val) { return writeToArchive(ar, val); }
70 71
 
72
+    friend Archive& operator>>(Archive &ar, char &val)     { return readFromArchive(ar, val); }
71 73
     friend Archive& operator>>(Archive &ar, bool &val)     { return readFromArchive(ar, val); }
72 74
     friend Archive& operator>>(Archive &ar, int &val)      { return readFromArchive(ar, val); }
73 75
     friend Archive& operator>>(Archive &ar, unsigned &val) { return readFromArchive(ar, val); }
... ...
@@ -76,28 +78,6 @@ public:
76 78
     friend Archive& operator>>(Archive &ar, float &val)    { return readFromArchive(ar, val); }
77 79
     friend Archive& operator>>(Archive &ar, double &val)   { return readFromArchive(ar, val); }
78 80
     friend Archive& operator>>(Archive &ar, boost::random::mt11213b &val) { return readFromArchive(ar, val); }
79
-
80
-/*
81
-    friend Archive& operator>>(Archive &ar, uint64_t &val)
82
-    {
83
-        return readPrimitiveFromArchive(ar, val);
84
-    }
85
-
86
-    template<typename T>
87
-    friend Archive& operator<<(Archive &ar, T val)
88
-    {
89
-        ar.mStream.write(reinterpret_cast<char*>(&val), sizeof(T)); // NOLINT
90
-        return ar;
91
-    }
92
-
93
-
94
-    template<typename T>
95
-    friend Archive& operator>>(Archive &ar, T &val)
96
-    {
97
-        ar.mStream.read(reinterpret_cast<char*>(&val), sizeof(T)); // NOLINT
98
-        return ar;
99
-    }
100
-*/
101 81
 };
102 82
 
103 83
 #endif
... ...
@@ -32,73 +32,91 @@ static Rcpp::NumericMatrix createRMatrix(const Matrix &mat)
32 32
     return rmat;
33 33
 }
34 34
 
35
-static bool nonNullUncertainty(const RowMatrix &mat)
35
+static bool isNull(const RowMatrix &mat)
36 36
 {
37
-    return mat.nRow() > 1 || mat.nCol() > 1;
37
+    return mat.nRow() == 1 && mat.nCol() == 1;
38 38
 }
39 39
 
40
-static bool nonNullUncertainty(const std::string &path)
40
+static bool isNull(const std::string &path)
41 41
 {
42
-    return !path.empty();
42
+    return path.empty();
43 43
 }
44 44
 
45 45
 template <class T>
46
-static Rcpp::List cogapsRun(const T &data, const T &unc, unsigned nPatterns,
47
-unsigned maxIter, unsigned outputFrequency, unsigned seed, float alphaA,
48
-float alphaP, float maxGibbsMassA, float maxGibbsMassP, bool messages,
49
-bool singleCell, unsigned nCores)
46
+static Rcpp::List cogapsRun(const T &data, const Rcpp::S4 &params, const T &unc,
47
+const RowMatrix &fixedMatrix, const std::string &checkpointInFile)
50 48
 {
51
-    GapsDispatcher dispatcher(seed);
49
+    GapsDispatcher dispatcher;
52 50
 
53
-    dispatcher.setNumPatterns(nPatterns);
54
-    dispatcher.setMaxIterations(maxIter);
55
-    dispatcher.setOutputFrequency(outputFrequency);
56
-    
57
-    dispatcher.setAlpha(alphaA, alphaP);
58
-    dispatcher.setMaxGibbsMass(maxGibbsMassA, maxGibbsMassP);
51
+    // check if we're initializing with a checkpoint or not
52
+    if (!isNull(checkpointInFile))
53
+    {
54
+        dispatcher.initialize(data, checkpointInFile);
55
+    }
56
+    else
57
+    {
58
+        dispatcher.initialize(data, params.slot("nPatterns"), params.slot("seed"));
59 59
 
60
-    dispatcher.printMessages(messages);
61
-    dispatcher.singleCell(singleCell);
62
-    dispatcher.setNumCoresPerSet(nCores);
63
-    
64
-    dispatcher.loadData(data);
60
+        // set optional parameters
61
+        dispatcher.setMaxIterations(params.slot("nIterations"));
62
+        dispatcher.setOutputFrequency(params.slot("outputFrequency"));
63
+
64
+        dispatcher.setSparsity(params.slot("alphaA"), params.slot("alphaP"),
65
+            params.slot("singleCell"));
66
+        
67
+        dispatcher.setMaxGibbsMass(params.slot("maxGibbsMassA"),
68
+            params.slot("maxGibbsMassP"));
65 69
 
66
-    if (nonNullUncertainty(unc))
70
+        dispatcher.printMessages(params.slot("messages"));
71
+
72
+        // check if running with a fixed matrix
73
+        if (!isNull(fixedMatrix))
74
+        {
75
+            dispatcher.setFixedMatrix(params.slot("whichMatrixFixed"), fixedMatrix);
76
+        }
77
+    }
78
+
79
+    // set the uncertainty matrix
80
+    if (!isNull(unc))
67 81
     {
68 82
         dispatcher.setUncertainty(unc);
69 83
     }
84
+    
85
+    // set parameters that aren't saved in the checkpoint
86
+    dispatcher.setNumCoresPerSet(params.slot("nCores"));
87
+    dispatcher.setCheckpointInterval(params.slot("checkpointInterval"));
88
+    dispatcher.setCheckpointOutFile(params.slot("checkpointOutFile"));
70 89
 
90
+    // run the dispatcher and return the GapsResult in an R list
71 91
     GapsResult result(dispatcher.run());
92
+    GAPS_ASSERT(result.meanChiSq > 0.f);
72 93
     return Rcpp::List::create(
73 94
         Rcpp::Named("Amean") = createRMatrix(result.Amean),
74 95
         Rcpp::Named("Pmean") = createRMatrix(result.Pmean),
75 96
         Rcpp::Named("Asd") = createRMatrix(result.Asd),
76
-        Rcpp::Named("Psd") = createRMatrix(result.Psd)
97
+        Rcpp::Named("Psd") = createRMatrix(result.Psd),
98
+        Rcpp::Named("seed") = result.seed,
99
+        Rcpp::Named("meanChiSq") = result.meanChiSq,
100
+        Rcpp::Named("diagnostics") = Rcpp::List::create()
77 101
     );
78 102
 }
79 103
 
80 104
 // [[Rcpp::export]]
81
-Rcpp::List cogaps_cpp_from_file(const std::string &data, const std::string &unc,
82
-unsigned nPatterns, unsigned maxIterations, unsigned outputFrequency,
83
-uint32_t seed, float alphaA, float alphaP, float maxGibbsMassA,
84
-float maxGibbsMassP, bool messages, bool singleCell,
85
-const std::string &checkpointOutFile, unsigned nCores)
105
+Rcpp::List cogaps_cpp_from_file(const std::string &data, const Rcpp::S4 &params,
106
+const std::string &unc, const Rcpp::NumericMatrix &fixedMatrix,
107
+const std::string &checkpointInFile)
86 108
 {
87
-    return cogapsRun(data, unc, nPatterns, maxIterations, outputFrequency, seed,
88
-        alphaA, alphaP, maxGibbsMassA, maxGibbsMassP, messages, singleCell,
89
-        nCores);
109
+    return cogapsRun(data, params, unc, convertRMatrix(fixedMatrix),
110
+        checkpointInFile);
90 111
 }
91 112
 
92 113
 // [[Rcpp::export]]
93
-Rcpp::List cogaps_cpp(const Rcpp::NumericMatrix &data,
94
-const Rcpp::NumericMatrix &unc, unsigned nPatterns, unsigned maxIterations,
95
-unsigned outputFrequency, uint32_t seed, float alphaA, float alphaP,
96
-float maxGibbsMassA, float maxGibbsMassP, bool messages, bool singleCell,
97
-const std::string &checkpointOutFile, unsigned nCores)
114
+Rcpp::List cogaps_cpp(const Rcpp::NumericMatrix &data, const Rcpp::S4 &params,
115
+const Rcpp::NumericMatrix &unc, const Rcpp::NumericMatrix &fixedMatrix,
116
+const std::string &checkpointInFile)
98 117
 {
99
-    return cogapsRun(convertRMatrix(data), convertRMatrix(unc), nPatterns,
100
-        maxIterations, outputFrequency, seed, alphaA, alphaP, maxGibbsMassA,
101
-        maxGibbsMassP, messages, singleCell, nCores);
118
+    return cogapsRun(convertRMatrix(data), params, convertRMatrix(unc),
119
+        convertRMatrix(fixedMatrix), checkpointInFile);
102 120
 }
103 121
 
104 122
 // [[Rcpp::export]]
... ...
@@ -6,37 +6,72 @@
6 6
     #include <omp.h>
7 7
 #endif
8 8
 
9
-/*
10
-static std::vector< std::vector<unsigned> > sampleIndices(unsigned n, unsigned nSets)
9
+GapsDispatcher::GapsDispatcher() : mSeed(0), mNumPatterns(3),
10
+    mMaxIterations(1000), mNumCoresPerSet(1), mPrintMessages(true),
11
+    mCheckpointsCreated(0), mPhase('C'), mCheckpointInterval(0),
12
+    mCheckpointOutFile("gaps_checkpoint.out"), mInitialized(false)
13
+{}
14
+
15
+GapsDispatcher::~GapsDispatcher()
11 16
 {
12
-    unsigned setSize = n / nSets;
13
-    std::vector< std::vector<unsigned> > sampleIndices;
14
-    std::vector<unsigned> toBeSampled;
15
-    for (unsigned i = 0; i < n; ++i)
17
+    for (unsigned i = 0; i < mRunners.size(); ++i)
16 18
     {
17
-        toBeSampled.push_back(i);
19
+        delete mRunners[i];
18 20
     }
21
+}
19 22
 
20
-    for (unsigned i = 0; i < (nSets - 1); ++i)
21
-    {
22
-        sampleIndices.push_back(gaps::random::sample(toBeSampled, setSize));
23
-    }
23
+void GapsDispatcher::setMaxIterations(unsigned n)
24
+{
25
+    mMaxIterations = n;
26
+}
24 27
 
25
-    GAPS_ASSERT(!toBeSampled.empty());
28
+void GapsDispatcher::printMessages(bool print)
29
+{
30
+    mPrintMessages = print;
31
+    mRunners[0]->printMessages(print);
32
+}
26 33
 
27
-    sampleIndices.push_back(toBeSampled);
28
-    return sampleIndices;
34
+void GapsDispatcher::setOutputFrequency(unsigned n)
35
+{
36
+    mRunners[0]->setOutputFrequency(n);
29 37
 }
30
-*/
31 38
 
32
-void GapsDispatcher::runOneCycle(unsigned k)
39
+void GapsDispatcher::setSparsity(float alphaA, float alphaP, bool singleCell)
40
+{
41
+    mRunners[0]->setSparsity(alphaA, alphaP, singleCell);
42
+}
43
+
44
+void GapsDispatcher::setMaxGibbsMass(float maxA, float maxP)
45
+{
46
+    mRunners[0]->setMaxGibbsMass(maxA, maxP);
47
+}
48
+
49
+void GapsDispatcher::setFixedMatrix(char which, const RowMatrix &mat)
50
+{
51
+    mRunners[0]->setFixedMatrix(which, mat);
52
+}
53
+
54
+void GapsDispatcher::setNumCoresPerSet(unsigned n)
55
+{
56
+    mNumCoresPerSet = n;
57
+}
58
+
59
+void GapsDispatcher::setCheckpointInterval(unsigned n)
33 60
 {
34
-    GAPS_ASSERT(mDataIsLoaded);
35
-    mRunners[0]->run(k, mOutputFrequency, mPrintMessages, mNumCoresPerSet);
61
+    mCheckpointInterval = n;
62
+}
63
+
64
+void GapsDispatcher::setCheckpointOutFile(const std::string &path)
65
+{
66
+    mCheckpointOutFile = path;
36 67
 }
37 68
 
38 69
 GapsResult GapsDispatcher::run()
39 70
 {
71
+    GAPS_ASSERT(mInitialized);
72
+    GAPS_ASSERT(mPhase == 'C' || mPhase == 'S');
73
+
74
+    // calculate appropriate number of cores if compiled with openmp
40 75
     #ifdef __GAPS_OPENMP__
41 76
     if (mPrintMessages)
42 77
     {
... ...
@@ -47,55 +82,81 @@ GapsResult GapsDispatcher::run()
47 82
     }
48 83
     #endif
49 84
 
50
-    GapsResult result(mRunners[0]->nRow(), mRunners[0]->nCol());
51
-    GAPS_ASSERT(mDataIsLoaded);
52
-
53
-    if (mPrintMessages)
85
+    // this switch allows for the algorithm to be interruptible
86
+    switch (mPhase)
54 87
     {
55
-        gaps_printf("Calibration Phase\n");
88
+        case 'C':
89
+            if (mPrintMessages)
90
+            {
91
+                gaps_printf("Calibration Phase\n");
92
+            }
93
+            runOneCycle(mMaxIterations);
94
+            mPhase = 'S';
95
+
96
+        case 'S':
97
+            if (mPrintMessages)
98
+            {
99
+                gaps_printf("Sampling Phase\n");
100
+            }
101
+            mRunners[0]->startSampling();
102
+            runOneCycle(mMaxIterations);
103
+            break;
56 104
     }
57
-    runOneCycle(mMaxIterations);
58 105
 
59
-    if (mPrintMessages)
60
-    {
61
-        gaps_printf("Sampling Phase\n");
62
-    }
63
-    mRunners[0]->startSampling();
64
-    runOneCycle(mMaxIterations);
65
-
66
-    result.Amean = mRunners[0]->AMean();
67
-    result.Pmean = mRunners[0]->PMean();
68
-    result.Asd = mRunners[0]->AStd();
69
-    result.Psd = mRunners[0]->PStd();
70
-    result.seed = mSeed;
106
+    // extract useful information from runners
107
+    GapsResult result(mRunners[0]->nRow(), mRunners[0]->nCol(), mSeed);
108
+    result.Amean = mRunners[0]->Amean();
109
+    result.Pmean = mRunners[0]->Pmean();
110
+    result.Asd = mRunners[0]->Asd();
111
+    result.Psd = mRunners[0]->Psd();
112
+    result.meanChiSq = mRunners[0]->meanChiSq();
71 113
     return result;
72 114
 }
73 115
 
74
-void GapsDispatcher::loadData(const RowMatrix &D)
116
+void GapsDispatcher::createCheckpoint() const
75 117
 {
76
-    mRunners.push_back(new GapsRunner(D, mNumPatterns, mAlphaA,
77
-        mAlphaP, mMaxGibbsMassA, mMaxGibbsMassP, mSingleCell));
78
-    mDataIsLoaded = true;
118
+    Archive ar(mCheckpointOutFile, ARCHIVE_WRITE);
119
+
120
+    gaps::random::save(ar);
121
+    ar << mSeed << mNumPatterns << mMaxIterations << mPrintMessages <<
122
+        mCheckpointsCreated << mPhase;
123
+
124
+    ar << *mRunners[0];
79 125
 }
80 126
 
81
-void GapsDispatcher::setUncertainty(const RowMatrix &S)
127
+void GapsDispatcher::runOneCycle(unsigned k)
82 128
 {
83
-    GAPS_ASSERT(mDataIsLoaded);
84
-    mRunners[0]->setUncertainty(S);
129
+    unsigned nCheckpoints = mCheckpointInterval > 0 ? k / mCheckpointInterval : 0;
130
+    while (mCheckpointsCreated < nCheckpoints)
131
+    {
132
+        mRunners[0]->run(mCheckpointInterval, mNumCoresPerSet);
133
+        createCheckpoint();
134
+        ++mCheckpointsCreated;
135
+    }
136
+    mRunners[0]->run(k - mCheckpointInterval * mCheckpointsCreated, mNumCoresPerSet);
137
+    mCheckpointsCreated = 0; // reset checkpoint count for next cycle
85 138
 }
86 139
 
87
-void GapsDispatcher::loadData(const std::string &pathToData)
140
+/*
141
+static std::vector< std::vector<unsigned> > sampleIndices(unsigned n, unsigned nSets)
88 142
 {
89
-    gaps_printf("Loading Data...");
90
-    gaps_flush();
91
-    mRunners.push_back(new GapsRunner(pathToData, mNumPatterns, mAlphaA,
92
-        mAlphaP, mMaxGibbsMassA, mMaxGibbsMassP, mSingleCell));
93
-    mDataIsLoaded = true;
94
-    gaps_printf("Done!\n");
143
+    unsigned setSize = n / nSets;
144
+    std::vector< std::vector<unsigned> > sampleIndices;
145
+    std::vector<unsigned> toBeSampled;
146
+    for (unsigned i = 0; i < n; ++i)
147
+    {
148
+        toBeSampled.push_back(i);
149
+    }
150
+
151
+    for (unsigned i = 0; i < (nSets - 1); ++i)
152
+    {
153
+        sampleIndices.push_back(gaps::random::sample(toBeSampled, setSize));
154
+    }
155
+
156
+    GAPS_ASSERT(!toBeSampled.empty());
157
+
158
+    sampleIndices.push_back(toBeSampled);
159
+    return sampleIndices;
95 160
 }
161
+*/
96 162
 
97
-void GapsDispatcher::setUncertainty(const std::string &pathToMatrix)
98
-{
99
-    GAPS_ASSERT(mDataIsLoaded);
100
-    mRunners[0]->setUncertainty(pathToMatrix);
101
-}
102 163
\ No newline at end of file
... ...
@@ -16,12 +16,14 @@ struct GapsResult
16 16
     float meanChiSq;
17 17
     uint32_t seed;
18 18
 
19
-    GapsResult(unsigned nrow, unsigned ncol) : Amean(nrow, ncol),
20
-        Asd(nrow, ncol), Pmean(nrow, ncol), Psd(nrow, ncol), meanChiSq(0.f),
21
-        seed(0)
19
+    GapsResult(unsigned nrow, unsigned ncol, uint32_t rngSeed) :
20
+        Amean(nrow, ncol), Asd(nrow, ncol), Pmean(nrow, ncol), Psd(nrow, ncol),
21
+        meanChiSq(0.f), seed(rngSeed)
22 22
     {}
23 23
 
24 24
     void writeCsv(const std::string &path);
25
+    void writeTsv(const std::string &path);
26
+    void writeGct(const std::string &path);
25 27
 };
26 28
 
27 29
 // should be agnostic to external caller (R/Python/CLI)
... ...
@@ -29,86 +31,99 @@ class GapsDispatcher
29 31
 {
30 32
 private:
31 33
 
32
-    unsigned mNumPatterns;
33
-    unsigned mMaxIterations;
34
-    unsigned mOutputFrequency;
35 34
     uint32_t mSeed;
35
+    unsigned mNumPatterns;
36 36
 
37
-    float mAlphaA;
38
-    float mAlphaP;
39
-    float mMaxGibbsMassA;
40
-    float mMaxGibbsMassP;
41
-
37
+    unsigned mMaxIterations;
38
+    unsigned mNumCoresPerSet;
42 39
     bool mPrintMessages;
43
-    bool mSingleCell;
44 40
 
45
-    bool mDataIsLoaded;    
41
+    unsigned mCheckpointsCreated;
42
+    char mPhase; // 'C' for calibration, 'S' for sample
46 43
 
47
-    unsigned mNumCoresPerSet;
44
+    unsigned mCheckpointInterval;
45
+    std::string mCheckpointOutFile;
48 46
 
47
+    bool mInitialized;    
49 48
     std::vector<GapsRunner*> mRunners;
50 49
 
51
-    char mFixedMatrix;
52
-
53 50
     void runOneCycle(unsigned k);
51
+    void createCheckpoint() const;
54 52
 
55 53
     GapsDispatcher(const GapsDispatcher &p); // don't allow copies
56 54
     GapsDispatcher& operator=(const GapsDispatcher &p); // don't allow copies
57 55
 
56
+    template <class DataType>
57
+    void loadData(const DataType &data);
58
+
58 59
 public:
59 60
 
60
-    explicit GapsDispatcher(uint32_t seed=0) : mNumPatterns(3),
61
-        mMaxIterations(1000), mOutputFrequency(250), mSeed(seed), mAlphaA(0.01),
62
-        mAlphaP(0.01), mMaxGibbsMassA(100.f), mMaxGibbsMassP(100.f),
63
-        mPrintMessages(true), mSingleCell(false), mDataIsLoaded(false),
64
-        mNumCoresPerSet(1), mFixedMatrix('N')
65
-    {
66
-        gaps::random::setSeed(mSeed);
67
-    }
68
-
69
-    ~GapsDispatcher()
70
-    {
71
-        for (unsigned i = 0; i < mRunners.size(); ++i)
72
-        {
73
-            delete mRunners[i];
74
-        }
75
-    }
76
-
77
-    void setNumPatterns(unsigned n)     { mNumPatterns = n; }
78
-    void setMaxIterations(unsigned n)   { mMaxIterations = n; }
79
-    void setOutputFrequency(unsigned n) { mOutputFrequency = n; }
80
-    void setNumCoresPerSet(unsigned n)  { mNumCoresPerSet = n; }
81
-
82
-    void printMessages(bool print) { mPrintMessages = print; }
83
-    void singleCell(bool sc) { mSingleCell = sc; }
84
-
85
-    void setAlpha(float alphaA, float alphaP)
86
-    {
87
-        mAlphaA = alphaA;
88
-        mAlphaP = alphaP;
89
-    }
90
-
91
-    void setMaxGibbsMass(float maxA, float maxP)
92
-    {
93
-        mMaxGibbsMassA = maxA;
94
-        mMaxGibbsMassP = maxP;
95
-    }
96
-
97
-    void setFixedMatrix(char which, const RowMatrix &mat)
98
-    {
99
-        mFixedMatrix = which;
100
-        mRunners[0]->setFixedMatrix(which, mat);
101
-    }
102
-    
103
-    void loadCheckpointFile(const std::string &pathToCptFile);
61
+    GapsDispatcher();
62
+    ~GapsDispatcher();
63
+
64
+    template <class DataType>
65
+    void initialize(const DataType &data, unsigned nPatterns, uint32_t seed=0);
66
+
67
+    template <class DataType>
68
+    void initialize(const DataType &data, const std::string &cptFile);
104 69
 
105
-    void setUncertainty(const RowMatrix &S);
106
-    void setUncertainty(const std::string &pathToMatrix);
70
+    template <class DataType>
71
+    void setUncertainty(const DataType &unc);
107 72
 
108
-    void loadData(const RowMatrix &D);
109
-    void loadData(const std::string &pathToData);
73
+    void setMaxIterations(unsigned n);
74
+    void printMessages(bool print);
75
+    void setOutputFrequency(unsigned n);
76
+    void setSparsity(float alphaA, float alphaP, bool singleCell);
77
+    void setMaxGibbsMass(float maxA, float maxP);
78
+    void setFixedMatrix(char which, const RowMatrix &mat);
79
+    
80
+    void setNumCoresPerSet(unsigned n);
81
+    void setCheckpointInterval(unsigned n);
82
+    void setCheckpointOutFile(const std::string &path);
110 83
     
111 84
     GapsResult run();
112 85
 };
113 86
 
87
+template <class DataType>
88
+void GapsDispatcher::initialize(const DataType &data, unsigned nPatterns,
89
+uint32_t seed)
90
+{
91
+    mSeed = seed;
92
+    mNumPatterns = nPatterns;
93
+    gaps::random::setSeed(mSeed);
94
+    
95
+    loadData(data);
96
+    mInitialized = true;
97
+}
98
+
99
+template <class DataType>
100
+void GapsDispatcher::initialize(const DataType &data, const std::string &cptFile)
101
+{
102
+    Archive ar(cptFile, ARCHIVE_READ);
103
+    gaps::random::load(ar);
104
+
105
+    ar >> mSeed >> mNumPatterns >> mMaxIterations >> mPrintMessages >>
106
+        mCheckpointsCreated >> mPhase;
107
+
108
+    loadData(data);
109
+    ar >> *mRunners[0];
110
+    mInitialized = true;
111
+}
112
+
113
+template <class DataType>
114
+void GapsDispatcher::setUncertainty(const DataType &unc)
115
+{
116
+    GAPS_ASSERT(mInitialized);
117
+    mRunners[0]->setUncertainty(unc);
118
+}
119
+
120
+template <class DataType>
121
+void GapsDispatcher::loadData(const DataType &data)
122
+{
123
+    gaps_printf("Loading Data...");
124
+    gaps_flush();
125
+    mRunners.push_back(new GapsRunner(data, mNumPatterns));
126
+    gaps_printf("Done!\n");
127
+}
128
+
114 129
 #endif // __COGAPS_GAPS_DISPATCHER_H__
115 130
\ No newline at end of file
... ...
@@ -8,64 +8,51 @@
8 8
 #include <Rcpp.h>
9 9
 #endif
10 10
 
11
-#include <algorithm>
11
+void GapsRunner::printMessages(bool print)
12
+{
13
+    mPrintMessages = print;
14
+}
12 15
 
13
-GapsRunner::GapsRunner(const RowMatrix &data, unsigned nPatterns, float alphaA,
14
-float alphaP, float maxGibbsMassA, float maxGibbsMassP, bool singleCell)
15
-    :
16
-mNumRows(data.nRow()), mNumCols(data.nCol()),
17
-mASampler(data, nPatterns, alphaA, maxGibbsMassA, singleCell),
18
-mPSampler(data, nPatterns, alphaP, maxGibbsMassP, singleCell),
19
-mStatistics(data.nRow(), data.nCol(), nPatterns),
20
-mSamplePhase(false), mNumUpdatesA(0), mNumUpdatesP(0)
16
+void GapsRunner::setOutputFrequency(unsigned n)
21 17
 {
22
-    mASampler.sync(mPSampler);
23
-    mPSampler.sync(mASampler);
18
+    mOutputFrequency = n;
24 19
 }
25 20
 
26
-GapsRunner::GapsRunner(const std::string &pathToData, unsigned nPatterns,
27
-float alphaA, float alphaP, float maxGibbsMassA, float maxGibbsMassP,
28
-bool singleCell)
29
-    :
30
-mNumRows(FileParser(pathToData).nRow()), mNumCols(FileParser(pathToData).nCol()),
31
-mASampler(pathToData, nPatterns, alphaA, maxGibbsMassA, singleCell),
32
-mPSampler(pathToData, nPatterns, alphaP, maxGibbsMassP, singleCell),
33
-mStatistics(mNumRows, mNumCols, nPatterns),
34
-mSamplePhase(false), mNumUpdatesA(0), mNumUpdatesP(0)
21
+void GapsRunner::setSparsity(float alphaA, float alphaP, bool singleCell)
35 22
 {
36
-    mASampler.sync(mPSampler);
37
-    mPSampler.sync(mASampler);
23
+    mASampler.setSparsity(alphaA, singleCell);
24
+    mPSampler.setSparsity(alphaP, singleCell);
25
+}
26
+
27
+void GapsRunner::setMaxGibbsMass(float maxA, float maxP)
28
+{
29
+    mASampler.setMaxGibbsMass(maxA);
30
+    mPSampler.setMaxGibbsMass(maxP);
38 31
 }
39 32
 
40 33
 void GapsRunner::setFixedMatrix(char which, const RowMatrix &mat)
41 34
 {
35
+    mFixedMatrix = which;
42 36
     if (which == 'A')
43 37
     {
44 38
         mASampler.setMatrix(ColMatrix(mat));
39
+        mASampler.recalculateAPMatrix();
40
+        mPSampler.sync(mASampler);
45 41
     }
46 42
     else if (which == 'P')
47 43
     {
48 44
         mPSampler.setMatrix(mat);
45
+        mPSampler.recalculateAPMatrix();
46
+        mASampler.sync(mPSampler);
49 47
     }
50 48
 }
51 49
 
52
-void GapsRunner::setUncertainty(const std::string &pathToMatrix)
50
+void GapsRunner::startSampling()
53 51
 {
54
-    mASampler.setUncertainty(pathToMatrix);
55
-    mPSampler.setUncertainty(pathToMatrix);
52
+    mSamplePhase = true;
56 53
 }
57 54
 
58
-void GapsRunner::displayStatus(unsigned outFreq, unsigned current, unsigned total)
59
-{
60
-    if (outFreq > 0 && ((current + 1) % outFreq) == 0)
61
-    {
62
-        gaps_printf("%d of %d, Atoms:%lu(%lu) Chi2 = %.2f\n", current + 1,
63
-            total, mASampler.nAtoms(), mPSampler.nAtoms(), mASampler.chi2());
64
-    }
65
-}
66
-
67
-void GapsRunner::run(unsigned nIter, unsigned outputFreq, bool printMessages,
68
-unsigned nCores)
55
+void GapsRunner::run(unsigned nIter, unsigned nCores)
69 56
 {
70 57
     for (unsigned i = 0; i < nIter; ++i)
71 58
     {
... ...
@@ -85,9 +72,9 @@ unsigned nCores)
85 72
         unsigned nP = gaps::random::poisson(std::max(mPSampler.nAtoms(), 10ul));
86 73
         updateSampler(nA, nP, nCores);
87 74
 
88
-        if (printMessages)
75
+        if (mPrintMessages)
89 76
         {
90
-            displayStatus(outputFreq, i, nIter);
77
+            displayStatus(i, nIter);
91 78
         }
92 79
 
93 80
         if (mSamplePhase)
... ...
@@ -97,19 +84,87 @@ unsigned nCores)
97 84
     }
98 85
 }
99 86
 
87
+unsigned GapsRunner::nRow() const
88
+{
89
+    return mASampler.dataRows();
90
+}
91
+
92
+unsigned GapsRunner::nCol() const
93
+{
94
+    return mASampler.dataCols();
95
+}
96
+
97
+ColMatrix GapsRunner::Amean() const
98
+{
99
+    return mStatistics.Amean();
100
+}
101
+
102
+RowMatrix GapsRunner::Pmean() const
103
+{
104
+    return mStatistics.Pmean();
105
+}
106
+
107
+ColMatrix GapsRunner::Asd() const
108
+{
109
+    return mStatistics.Asd();
110
+}
111
+
112
+RowMatrix GapsRunner::Psd() const
113
+{
114
+    return mStatistics.Psd();
115
+}
116
+
117
+float GapsRunner::meanChiSq() const
118
+{
119
+    return mStatistics.meanChiSq(mASampler);
120
+}
121
+
122
+Archive& operator<<(Archive &ar, GapsRunner &runner)
123
+{
124
+    ar << runner.mASampler << runner.mPSampler << runner.mStatistics <<
125
+        runner.mPrintMessages << runner.mOutputFrequency <<
126
+        runner.mFixedMatrix << runner.mSamplePhase << runner.mNumUpdatesA <<
127
+        runner.mNumUpdatesP;
128
+    return ar;
129
+}
130
+
131
+Archive& operator>>(Archive &ar, GapsRunner &runner)
132
+{
133
+    ar >> runner.mASampler >> runner.mPSampler >> runner.mStatistics >>
134
+        runner.mPrintMessages >> runner.mOutputFrequency >>
135
+        runner.mFixedMatrix >> runner.mSamplePhase >> runner.mNumUpdatesA >>
136
+        runner.mNumUpdatesP;
137
+    return ar;
138
+}
139
+
100 140
 void GapsRunner::updateSampler(unsigned nA, unsigned nP, unsigned nCores)
101 141
 {
102
-    mNumUpdatesA += nA;
103
-    mASampler.update(nA, nCores);
104
-    mPSampler.sync(mASampler);
142
+    if (mFixedMatrix != 'A')
143
+    {
144
+        mNumUpdatesA += nA;
145
+        mASampler.update(nA, nCores);
146
+        if (mFixedMatrix != 'P')
147
+        {
148
+            mPSampler.sync(mASampler);
149
+        }
150
+    }
105 151
 
106
-    mNumUpdatesP += nP;
107
-    mPSampler.update(nP, nCores);
108
-    mASampler.sync(mPSampler);
152
+    if (mFixedMatrix != 'P')
153
+    {
154
+        mNumUpdatesP += nP;
155
+        mPSampler.update(nP, nCores);
156
+        if (mFixedMatrix != 'A')
157
+        {
158
+            mASampler.sync(mPSampler);
159
+        }
160
+    }
109 161
 }
110 162
 
111
-void GapsRunner::setUncertainty(const RowMatrix &S)
163
+void GapsRunner::displayStatus(unsigned current, unsigned total)
112 164
 {
113
-    mASampler.setUncertainty(S);
114
-    mPSampler.setUncertainty(S);
165
+    if (mOutputFrequency > 0 && ((current + 1) % mOutputFrequency) == 0)
166
+    {
167
+        gaps_printf("%d of %d, Atoms:%lu(%lu) Chi2 = %.2f\n", current + 1,
168
+            total, mASampler.nAtoms(), mPSampler.nAtoms(), mASampler.chi2());
169
+    }
115 170
 }
... ...
@@ -9,50 +9,71 @@
9 9
 class GapsRunner
10 10
 {
11 11
 private:
12
-
13
-    unsigned mNumRows;
14
-    unsigned mNumCols;
15
-
12
+    
16 13
     AmplitudeGibbsSampler mASampler;
17 14
     PatternGibbsSampler mPSampler;
18 15
     GapsStatistics mStatistics;
19 16
 
17
+    bool mPrintMessages;
18
+    unsigned mOutputFrequency;
19
+    char mFixedMatrix;
20 20
     bool mSamplePhase;
21 21
 
22 22
     unsigned mNumUpdatesA;
23 23
     unsigned mNumUpdatesP;
24 24
 
25 25
     void updateSampler(unsigned nA, unsigned nP, unsigned nCores);
26
+    void displayStatus(unsigned current, unsigned total);
26 27
 
27 28
 public:
28 29
 
29
-    GapsRunner(const RowMatrix &data, unsigned nPatterns, float alphaA,
30
-        float alphaP, float maxGibbsMassA, float maxGibbsMassP,
31
-        bool singleCell);
30
+    template <class DataType>
31
+    GapsRunner(const DataType &data, unsigned nPatterns);
32 32
 
33
-    GapsRunner(const std::string &pathToData, unsigned nPatterns, float alphaA,
34
-        float alphaP, float maxGibbsMassA, float maxGibbsMassP,
35
-        bool singleCell);
36
-    
37
-    void run(unsigned nIter, unsigned outputFreq, bool printMessages,
38
-        unsigned nCores);
33
+    template <class DataType>
34
+    void setUncertainty(const DataType &unc);
39 35
 
40
-    unsigned nRow() const { return mNumRows; }
41
-    unsigned nCol() const { return mNumCols; }
36
+    void printMessages(bool print);
37
+    void setOutputFrequency(unsigned n);
38
+    void setSparsity(float alphaA, float alphaP, bool singleCell);
39
+    void setMaxGibbsMass(float maxA, float maxP);
42 40
 
43
-    void setUncertainty(const RowMatrix &S);
44
-    void setUncertainty(const std::string &pathToMatrix);
41
+    void setFixedMatrix(char which, const RowMatrix &mat);
45 42
 
46
-    void startSampling() { mSamplePhase = true; }
43
+    void startSampling();
47 44
 
48
-    void displayStatus(unsigned outFreq, unsigned current, unsigned total);
45
+    void run(unsigned nIter, unsigned nCores);
49 46
 
50
-    void setFixedMatrix(char which, const RowMatrix &mat);
47
+    unsigned nRow() const;
48
+    unsigned nCol() const;
51 49
 
52
-    ColMatrix AMean() const { return mStatistics.AMean(); }
53
-    RowMatrix PMean() const { return mStatistics.PMean(); }
54
-    ColMatrix AStd() const { return mStatistics.AStd(); }
55
-    RowMatrix PStd() const { return mStatistics.PStd(); }
50
+    ColMatrix Amean() const;
51
+    RowMatrix Pmean() const;
52
+    ColMatrix Asd() const;
53
+    RowMatrix Psd() const;
54
+    float meanChiSq() const;
55
+
56
+    // serialization
57
+    friend Archive& operator<<(Archive &ar, GapsRunner &runner);
58
+    friend Archive& operator>>(Archive &ar, GapsRunner &runner);
56 59
 };
57 60
 
61
+template <class DataType>
62
+GapsRunner::GapsRunner(const DataType &data, unsigned nPatterns)
63
+    :
64
+mASampler(data, nPatterns), mPSampler(data, nPatterns),
65
+mStatistics(mASampler.dataRows(), mPSampler.dataCols(), nPatterns),
66
+mSamplePhase(false), mNumUpdatesA(0), mNumUpdatesP(0), mFixedMatrix('N')
67
+{
68
+    mASampler.sync(mPSampler);
69
+    mPSampler.sync(mASampler);
70
+}
71
+
72
+template <class DataType>
73
+void GapsRunner::setUncertainty(const DataType &unc)
74
+{
75
+    mASampler.setUncertainty(unc);
76
+    mPSampler.setUncertainty(unc);
77
+}
78
+
58 79
 #endif // __COGAPS_GAPS_RUNNER_H__
59 80
\ No newline at end of file
... ...
@@ -1,11 +1,13 @@
1 1
 #include "GapsStatistics.h"
2 2
 #include "math/Algorithms.h"
3 3
 
4
-GapsStatistics::GapsStatistics(unsigned nRow, unsigned nCol, unsigned nFactor, PumpThreshold t)
5
-    : mAMeanMatrix(nRow, nFactor), mAStdMatrix(nRow, nFactor),
6
-        mPMeanMatrix(nFactor, nCol), mPStdMatrix(nFactor, nCol),
7
-        mStatUpdates(0), mNumPatterns(nFactor), mPumpMatrix(nRow, nCol),
8
-        mPumpThreshold(t), mPumpStatUpdates(0)
4
+GapsStatistics::GapsStatistics(unsigned nRow, unsigned nCol, unsigned nPatterns,
5
+PumpThreshold t)
6
+    :
7
+mAMeanMatrix(nRow, nPatterns), mAStdMatrix(nRow, nPatterns),
8
+mPMeanMatrix(nPatterns, nCol), mPStdMatrix(nPatterns, nCol),
9
+mStatUpdates(0), mNumPatterns(nPatterns), mPumpMatrix(nRow, nCol),
10
+mPumpThreshold(t), mPumpStatUpdates(0)
9 11
 {}
10 12
 
11 13
 void GapsStatistics::update(const AmplitudeGibbsSampler &ASampler,
... ...
@@ -29,6 +31,37 @@ const PatternGibbsSampler &PSampler)
29 31
     }
30 32
 }
31 33
 
34
+ColMatrix GapsStatistics::Amean() const
35
+{
36
+    return mAMeanMatrix / mStatUpdates;
37
+}
38
+
39
+ColMatrix GapsStatistics::Asd() const
40
+{
41
+    return gaps::algo::computeStdDev(mAStdMatrix, mAMeanMatrix,
42
+        mStatUpdates);
43
+}
44
+
45
+RowMatrix GapsStatistics::Pmean() const
46
+{
47
+    return mPMeanMatrix / mStatUpdates;
48
+}
49
+
50
+RowMatrix GapsStatistics::Psd() const
51
+{
52
+    return gaps::algo::computeStdDev(mPStdMatrix, mPMeanMatrix,
53
+        mStatUpdates);
54
+}
55
+
56
+float GapsStatistics::meanChiSq(const AmplitudeGibbsSampler &ASampler) const
57
+{
58
+    ColMatrix A = mAMeanMatrix / mStatUpdates;
59
+    RowMatrix P = mPMeanMatrix / mStatUpdates;
60
+    RowMatrix M(gaps::algo::matrixMultiplication(A, P));
61
+    return 2.f * gaps::algo::loglikelihood(ASampler.mDMatrix, ASampler.mSMatrix,
62
+        M);
63
+}
64
+
32 65
 static unsigned geneThreshold(const ColMatrix &rankMatrix, unsigned pat)
33 66
 {
34 67
     float cutRank = rankMatrix.nRow();
... ...
@@ -127,37 +160,6 @@ const PatternGibbsSampler &PSampler)
127 160
     patternMarkers(ASampler.mMatrix, PSampler.mMatrix, mPumpMatrix);
128 161
 }
129 162
 
130
-ColMatrix GapsStatistics::AMean() const
131
-{
132
-    return mAMeanMatrix / mStatUpdates;
133
-}
134
-
135
-ColMatrix GapsStatistics::AStd() const
136
-{
137
-    return gaps::algo::computeStdDev(mAStdMatrix, mAMeanMatrix,
138
-        mStatUpdates);
139
-}
140
-
141
-RowMatrix GapsStatistics::PMean() const
142
-{
143
-    return mPMeanMatrix / mStatUpdates;
144
-}
145
-
146
-RowMatrix GapsStatistics::PStd() const
147
-{
148
-    return gaps::algo::computeStdDev(mPStdMatrix, mPMeanMatrix,
149
-        mStatUpdates);
150
-}
151
-
152
-float GapsStatistics::meanChiSq(const AmplitudeGibbsSampler &ASampler) const
153
-{
154
-    ColMatrix A = mAMeanMatrix / mStatUpdates;
155
-    RowMatrix P = mPMeanMatrix / mStatUpdates;
156
-    RowMatrix M(gaps::algo::matrixMultiplication(A, P));
157
-    return 2.f * gaps::algo::loglikelihood(ASampler.mDMatrix, ASampler.mSMatrix,
158
-        M);
159
-}
160
-
161 163
 RowMatrix GapsStatistics::pumpMatrix() const
162 164
 {
163 165
     unsigned denom = mPumpStatUpdates != 0 ? mPumpStatUpdates : 1.f;
... ...
@@ -28,24 +28,25 @@ private:
28 28
 
29 29
 public:
30 30
 
31
-    GapsStatistics(unsigned nRow, unsigned nCol, unsigned nFactor, PumpThreshold t=PUMP_CUT);
32
-
33
-    ColMatrix AMean() const;
34
-    RowMatrix PMean() const;
35
-    ColMatrix AStd() const;
36
-    RowMatrix PStd() const;
37
-
38
-    float meanChiSq(const AmplitudeGibbsSampler &ASampler) const;
31
+    GapsStatistics(unsigned nRow, unsigned nCol, unsigned nPatterns,
32
+        PumpThreshold t=PUMP_CUT);
39 33
 
40 34
     void update(const AmplitudeGibbsSampler &ASampler,
41 35
         const PatternGibbsSampler &PSampler);
42 36
 
37
+    ColMatrix Amean() const;
38
+    RowMatrix Pmean() const;
39
+    ColMatrix Asd() const;
40
+    RowMatrix Psd() const;
41
+
42
+    float meanChiSq(const AmplitudeGibbsSampler &ASampler) const;
43
+
44
+    // PUMP statistics
43 45
     void updatePump(const AmplitudeGibbsSampler &ASampler,
44 46
         const PatternGibbsSampler &PSampler);
45 47
 
46 48
     RowMatrix pumpMatrix() const;
47 49
     RowMatrix meanPattern();
48
-
49 50
     void patternMarkers(ColMatrix normedA, RowMatrix normedP, ColMatrix &statMatrix);
50 51
 
51 52
     // serialization
... ...
@@ -3,21 +3,15 @@
3 3
 
4 4
 /******************** AmplitudeGibbsSampler Implementation ********************/
5 5
 
6
-AmplitudeGibbsSampler::AmplitudeGibbsSampler(const RowMatrix &D, 
7
-unsigned nFactor, float alpha, float maxGibbsMass, bool singleCell)
8
-    :
9
-GibbsSampler(D, D.nRow(), nFactor, nFactor, alpha, maxGibbsMass, singleCell)
6
+void AmplitudeGibbsSampler::sync(PatternGibbsSampler &sampler)
10 7
 {
11
-    mQueue.setDimensionSize(mBinSize, mNumCols);
8
+    mOtherMatrix = &(sampler.mMatrix);
9
+    mAPMatrix = sampler.mAPMatrix;
12 10
 }
13 11
 
14
-AmplitudeGibbsSampler::AmplitudeGibbsSampler(const std::string &pathToData,
15
-unsigned nFactor, float alpha, float maxGibbsMass, bool singleCell)
16
-    :
17
-GibbsSampler(pathToData, FileParser(pathToData).nRow(), nFactor, nFactor, alpha, 
18
-maxGibbsMass, singleCell)
12
+void AmplitudeGibbsSampler::recalculateAPMatrix()
19 13
 {
20
-    mQueue.setDimensionSize(mBinSize, mNumCols);
14
+    mAPMatrix = gaps::algo::matrixMultiplication(mMatrix, *mOtherMatrix);
21 15
 }
22 16
 
23 17
 unsigned AmplitudeGibbsSampler::getRow(uint64_t pos) const
... ...
@@ -41,12 +35,6 @@ bool AmplitudeGibbsSampler::canUseGibbs(unsigned r1, unsigned c1, unsigned r2, u
41 35
     return canUseGibbs(r1, c1) || canUseGibbs(r2, c2);
42 36
 }
43 37
 
44
-void AmplitudeGibbsSampler::sync(PatternGibbsSampler &sampler)
45
-{
46
-    mOtherMatrix = &(sampler.mMatrix);
47
-    mAPMatrix = sampler.mAPMatrix;
48
-}
49
-
50 38
 void AmplitudeGibbsSampler::updateAPMatrix(unsigned row, unsigned col, float delta)
51 39
 {
52 40
     const float *other = mOtherMatrix->rowPtr(col);
... ...
@@ -79,8 +67,9 @@ AlphaParameters AmplitudeGibbsSampler::alphaParameters(unsigned row, unsigned co
79 67
 AlphaParameters AmplitudeGibbsSampler::alphaParameters(unsigned r1, unsigned c1,
80 68
 unsigned r2, unsigned c2)
81 69
 {
82
-    if (r1 == r2)
70
+    if (r1 == r2) // TODO should this ever happen
83 71
     {
72
+        GAPS_ASSERT(false);
84 73
         return gaps::algo::alphaParameters(mDMatrix.nCol(), mDMatrix.rowPtr(r1),
85 74
             mSMatrix.rowPtr(r1), mAPMatrix.rowPtr(r1), mOtherMatrix->rowPtr(c1),
86 75
             mOtherMatrix->rowPtr(c2));
... ...
@@ -98,8 +87,9 @@ float AmplitudeGibbsSampler::computeDeltaLL(unsigned row, unsigned col, float ma
98 87
 float AmplitudeGibbsSampler::computeDeltaLL(unsigned r1, unsigned c1, float m1,
99 88
 unsigned r2, unsigned c2, float m2)
100 89
 {
101
-    if (r1 == r2)
90
+    if (r1 == r2) // TODO should this ever happen
102 91
     {
92
+        GAPS_ASSERT(false);
103 93
         return gaps::algo::deltaLL(mDMatrix.nCol(), mDMatrix.rowPtr(r1),
104 94
             mSMatrix.rowPtr(r1), mAPMatrix.rowPtr(r1), mOtherMatrix->rowPtr(c1),
105 95
             m1, mOtherMatrix->rowPtr(c2), m2);
... ...
@@ -109,21 +99,15 @@ unsigned r2, unsigned c2, float m2)
109 99
 
110 100
 /********************* PatternGibbsSampler Implementation *********************/
111 101
 
112
-PatternGibbsSampler::PatternGibbsSampler(const RowMatrix &D,
113
-unsigned nFactor, float alpha, float maxGibbsMass, bool singleCell)
114
-    :
115
-GibbsSampler(D, nFactor, D.nCol(), nFactor, alpha, maxGibbsMass, singleCell)
102
+void PatternGibbsSampler::recalculateAPMatrix()
116 103
 {
117
-    mQueue.setDimensionSize(mBinSize, mNumRows);
104
+    mAPMatrix = gaps::algo::matrixMultiplication(*mOtherMatrix, mMatrix);
118 105
 }
119 106
 
120
-PatternGibbsSampler::PatternGibbsSampler(const std::string &pathToData,
121
-unsigned nFactor, float alpha, float maxGibbsMass, bool singleCell)
122
-    :
123
-GibbsSampler(pathToData, nFactor, FileParser(pathToData).nCol(), nFactor, alpha,
124
-maxGibbsMass, singleCell)
107
+void PatternGibbsSampler::sync(AmplitudeGibbsSampler &sampler)
125 108
 {
126
-    mQueue.setDimensionSize(mBinSize , mNumRows);
109
+    mOtherMatrix = &(sampler.mMatrix);
110
+    mAPMatrix = sampler.mAPMatrix;
127 111
 }
128 112
 
129 113
 unsigned PatternGibbsSampler::getRow(uint64_t pos) const
... ...
@@ -147,12 +131,6 @@ bool PatternGibbsSampler::canUseGibbs(unsigned r1, unsigned c1, unsigned r2, uns
147 131
     return canUseGibbs(r1, c1) || canUseGibbs(r2, c2);
148 132
 }
149 133
 
150
-void PatternGibbsSampler::sync(AmplitudeGibbsSampler &sampler)
151
-{
152
-    mOtherMatrix = &(sampler.mMatrix);
153
-    mAPMatrix = sampler.mAPMatrix;
154
-}
155
-
156 134
 void PatternGibbsSampler::updateAPMatrix(unsigned row, unsigned col, float delta)
157 135
 {
158 136
     const float *other = mOtherMatrix->colPtr(row);
... ...
@@ -185,8 +163,9 @@ AlphaParameters PatternGibbsSampler::alphaParameters(unsigned row, unsigned col)
185 163
 AlphaParameters PatternGibbsSampler::alphaParameters(unsigned r1, unsigned c1,
186 164
 unsigned r2, unsigned c2)
187 165
 {
188
-    if (c1 == c2)
166
+    if (c1 == c2) // TODO should this ever happen
189 167
     {
168
+        GAPS_ASSERT(false);
190 169
         return gaps::algo::alphaParameters(mDMatrix.nRow(), mDMatrix.colPtr(c1),
191 170
             mSMatrix.colPtr(c1), mAPMatrix.colPtr(c1), mOtherMatrix->colPtr(r1),
192 171
             mOtherMatrix->colPtr(r2));
... ...
@@ -204,8 +183,9 @@ float PatternGibbsSampler::computeDeltaLL(unsigned row, unsigned col, float mass
204 183
 float PatternGibbsSampler::computeDeltaLL(unsigned r1, unsigned c1, float m1,
205 184
 unsigned r2, unsigned c2, float m2)
206 185
 {
207
-    if (c1 == c2)
186
+    if (c1 == c2) // TODO should this ever happen
208 187
     {
188
+        GAPS_ASSERT(false);
209 189
         return gaps::algo::deltaLL(mDMatrix.nRow(), mDMatrix.colPtr(c1),
210 190
             mSMatrix.colPtr(c1), mAPMatrix.colPtr(c1), mOtherMatrix->colPtr(r1),
211 191
             m1, mOtherMatrix->colPtr(r2), m2);
... ...
@@ -11,13 +11,12 @@
11 11
 
12 12
 #include <algorithm>
13 13
 
14
-// forward declarations needed for friend classes
14
+// forward declarations needed for friend classes/functions
15
+
15 16
 class AmplitudeGibbsSampler;
16 17
 class PatternGibbsSampler;
17 18
 class GapsStatistics;
18 19
 
19
-/************************** GIBBS SAMPLER INTERFACE **************************/
20
-
21 20
 template <class T, class MatA, class MatB>
22 21
 class GibbsSampler;
23 22
 
... ...
@@ -27,6 +26,8 @@ Archive& operator<<(Archive &ar, GibbsSampler<T, MatA, MatB> &samp);
27 26
 template <class T, class MatA, class MatB>
28 27
 Archive& operator>>(Archive &ar, GibbsSampler<T, MatA, MatB> &samp);
29 28
 
29
+/*************************** GIBBS SAMPLER INTERFACE **************************/
30
+
30 31
 template <class T, class MatA, class MatB>
31 32
 class GibbsSampler
32 33
 {
... ...
@@ -36,12 +37,13 @@ private:
36 37
 
37 38
 protected:
38 39
 
39
-    MatA mMatrix;
40
-    MatB* mOtherMatrix;
41 40
     MatB mDMatrix;
42 41
     MatB mSMatrix;
43 42
     MatB mAPMatrix;
44 43
 
44
+    MatA mMatrix;
45
+    MatB* mOtherMatrix;
46
+
45 47
     ProposalQueue mQueue;
46 48
     AtomicDomain mDomain;
47 49
 
... ...
@@ -60,6 +62,9 @@ protected:
60 62
 
61 63
     void processProposal(const AtomicProposal &prop);
62 64
 
65
+    void addMass(uint64_t pos, float mass, unsigned row, unsigned col);
66
+    void removeMass(uint64_t pos, float mass, unsigned row, unsigned col);
67
+
63 68
     void birth(uint64_t pos, unsigned row, unsigned col);
64 69
     void death(uint64_t pos, float mass, unsigned row, unsigned col);
65 70
     void move(uint64_t src, float mass, uint64_t dest, unsigned r1, unsigned c1,
... ...
@@ -67,40 +72,43 @@ protected:
67 72
     void exchange(uint64_t p1, float m1, uint64_t p2, float m2, unsigned r1,
68 73
         unsigned c1, unsigned r2, unsigned c2);
69 74
 
70
-    float gibbsMass(unsigned row, unsigned col, float mass);
71
-    float gibbsMass(unsigned r1, unsigned c1, float m1, unsigned r2,
72
-        unsigned c2, float m2);
73
-
74
-    void addMass(uint64_t pos, float mass, unsigned row, unsigned col);
75
-    void removeMass(uint64_t pos, float mass, unsigned row, unsigned col);
76 75
     bool updateAtomMass(uint64_t pos, float mass, float delta);
77
-
78 76
     void acceptExchange(uint64_t p1, float m1, float d1, uint64_t p2, float m2,
79 77
         float d2, unsigned r1, unsigned c1, unsigned r2, unsigned c2);
80 78
 
79
+    float gibbsMass(unsigned row, unsigned col, float mass);
80
+    float gibbsMass(unsigned r1, unsigned c1, float m1, unsigned r2,
81
+        unsigned c2, float m2);
82
+
81 83
 public:
82 84
 
83
-    GibbsSampler(const RowMatrix &D, unsigned nrow, unsigned ncol,
84
-        unsigned nPatterns, float alpha, float maxGibbsMass, bool singleCell);
85
+    template <class DataType>
86
+    GibbsSampler(const DataType &data, bool amp, unsigned nPatterns);
85 87
 
86
-    GibbsSampler(const std::string &pathToData, unsigned nrow, unsigned ncol,
87
-        unsigned nPatterns, float alpha, float maxGibbsMass, bool singleCell);
88
+    template <class DataType>
89
+    void setUncertainty(const DataType &unc);
90
+    
91
+    void setSparsity(float alpha, bool singleCell);
92
+    void setMaxGibbsMass(float max);
93
+    void setAnnealingTemp(float temp);
88 94
 
89
-    void setUncertainty(const RowMatrix &S);
90
-    void setUncertainty(const std::string &path);
95
+    void setMatrix(const MatA &mat);
91 96
 
92 97
     void update(unsigned nSteps, unsigned nCores);
93
-    void setAnnealingTemp(float temp);
94
-    float getAvgQueue() const;
95
-    
98
+
99
+    unsigned dataRows() const;
100
+    unsigned dataCols() const;
101
+
96 102
     float chi2() const;
97 103
     uint64_t nAtoms() const;
98 104
 
99
-    void setMatrix(const MatA &mat);
105
+    #ifdef GAPS_DEBUG
106
+    float getAvgQueue() const;
107
+    #endif
100 108
 
101 109
     // serialization
102
-    friend Archive& operator<< <T, MatA, MatB> (Archive &ar, GibbsSampler &samp);
103
-    friend Archive& operator>> <T, MatA, MatB> (Archive &ar, GibbsSampler &samp);
110
+    friend Archive& operator<< <T, MatA, MatB> (Archive &ar, GibbsSampler &sampler);
111
+    friend Archive& operator>> <T, MatA, MatB> (Archive &ar, GibbsSampler &sampler);
104 112
 };
105 113
 
106 114
 class AmplitudeGibbsSampler : public GibbsSampler<AmplitudeGibbsSampler, ColMatrix, RowMatrix>
... ...
@@ -116,16 +124,6 @@ private:
116 124
     bool canUseGibbs(unsigned r1, unsigned c1, unsigned r2, unsigned c2) const;
117 125
     void updateAPMatrix(unsigned row, unsigned col, float delta);
118 126
 
119
-public:
120
-
121
-    AmplitudeGibbsSampler(const RowMatrix &D, unsigned nFactor,
122
-        float alpha=0.f, float maxGibbsMass=0.f, bool singleCell=false);
123
-
124
-    AmplitudeGibbsSampler(const std::string &pathToData, unsigned nFactor,
125
-    float alpha=0.f, float maxGibbsMass=0.f, bool singleCell=false);
126
-
127
-    void sync(PatternGibbsSampler &sampler);
128
-
129 127
     AlphaParameters alphaParameters(unsigned row, unsigned col);
130 128
     AlphaParameters alphaParameters(unsigned r1, unsigned c1, unsigned r2,
131 129
         unsigned c2);
... ...
@@ -133,6 +131,14 @@ public:
133 131
     float computeDeltaLL(unsigned row, unsigned col, float mass);
134 132
     float computeDeltaLL(unsigned r1, unsigned c1, float m1, unsigned r2,
135 133
         unsigned c2, float m2);
134
+
135
+public:
136
+
137
+    template <class DataType>
138
+    AmplitudeGibbsSampler(const DataType &data, unsigned nPatterns);
139
+
140
+    void sync(PatternGibbsSampler &sampler);
141
+    void recalculateAPMatrix();
136 142
 };
137 143
 
138 144
 class PatternGibbsSampler : public GibbsSampler<PatternGibbsSampler, RowMatrix, ColMatrix>
... ...
@@ -148,16 +154,6 @@ private:
148 154
     bool canUseGibbs(unsigned r1, unsigned c1, unsigned r2, unsigned c2) const;
149 155
     void updateAPMatrix(unsigned row, unsigned col, float delta);
150 156
 
151
-public:
152
-
153
-    PatternGibbsSampler(const RowMatrix &D, unsigned nFactor,
154
-        float alpha=0.f, float maxGibbsMass=0.f, bool singleCell=false);
155
-
156
-    PatternGibbsSampler(const std::string &pathToData, unsigned nFactor,
157
-        float alpha=0.f, float maxGibbsMass=0.f, bool singleCell=false);
158
-
159
-    void sync(AmplitudeGibbsSampler &sampler);
160
-
161 157
     AlphaParameters alphaParameters(unsigned row, unsigned col);
162 158
     AlphaParameters alphaParameters(unsigned r1, unsigned c1, unsigned r2,
163 159
         unsigned c2);
... ...
@@ -165,62 +161,93 @@ public:
165 161
     float computeDeltaLL(unsigned row, unsigned col, float mass);
166 162
     float computeDeltaLL(unsigned r1, unsigned c1, float m1, unsigned r2,
167 163
         unsigned c2, float m2);
164
+
165
+public:
166
+
167
+    template <class DataType>
168
+    PatternGibbsSampler(const DataType &data, unsigned nPatterns);
169
+
170
+    void sync(AmplitudeGibbsSampler &sampler);
171
+    void recalculateAPMatrix();
168 172
 };
169 173
 
170
-/******************* IMPLEMENTATION OF TEMPLATED FUNCTIONS *******************/
174
+/******************** IMPLEMENTATION OF TEMPLATED FUNCTIONS *******************/
171 175
 
172
-template <class T, class MatA, class MatB>
173
-GibbsSampler<T, MatA, MatB>::GibbsSampler(const RowMatrix &D,
174
-unsigned nrow, unsigned ncol, unsigned nPatterns, float alpha,
175
-float maxGibbsMass, bool singleCell)
176
+template <class DataType>
177
+AmplitudeGibbsSampler::AmplitudeGibbsSampler(const DataType &data, unsigned nPatterns)
176 178
     :
177
-mMatrix(nrow, ncol), mOtherMatrix(NULL), mDMatrix(D),
178
-mSMatrix(mDMatrix.pmax(0.1f)), mAPMatrix(D.nRow(), D.nCol()),
179
-mQueue(nrow * ncol, alpha), mLambda(0.f), mMaxGibbsMass(maxGibbsMass),
180
-mAnnealingTemp(0.f), mNumRows(nrow), mNumCols(ncol), mAvgQueue(0.f),
181
-mNumQueues(0.f)
179
+GibbsSampler(data, true, nPatterns)
182 180
 {
183
-    mBinSize = std::numeric_limits<uint64_t>::max()
184
-        / static_cast<uint64_t>(mNumRows * mNumCols);
185
-    uint64_t remain = std::numeric_limits<uint64_t>::max()
186
-        % static_cast<uint64_t>(mNumRows * mNumCols);
187
-    mQueue.setDomainSize(std::numeric_limits<uint64_t>::max() - remain);
188
-    mDomain.setDomainSize(std::numeric_limits<uint64_t>::max() - remain);
181
+    mQueue.setDimensionSize(mBinSize, mNumCols);
182
+}
189 183
 
190
-    float meanD = singleCell ? gaps::algo::nonZeroMean(mDMatrix) :
191
-        gaps::algo::mean(mDMatrix);
192
-    mLambda = alpha * std::sqrt(nPatterns / meanD);
193
-    mMaxGibbsMass = maxGibbsMass / mLambda;
184
+template <class DataType>
185
+PatternGibbsSampler::PatternGibbsSampler(const DataType &data, unsigned nPatterns)
186
+    :
187
+GibbsSampler(data, false, nPatterns)
188
+{
189
+    mQueue.setDimensionSize(mBinSize, mNumRows);
194 190
 }
195 191
 
196 192
 template <class T, class MatA, class MatB>
197
-GibbsSampler<T, MatA, MatB>::GibbsSampler(const std::string &pathToData,
198
-unsigned nrow, unsigned ncol, unsigned nPatterns, float alpha,
199
-float maxGibbsMass, bool singleCell)
193
+template <class DataType>
194
+GibbsSampler<T, MatA, MatB>::GibbsSampler(const DataType &data,
195
+bool amp, unsigned nPatterns)
200 196
     :
201
-mMatrix(nrow, ncol), mOtherMatrix(NULL), mDMatrix(pathToData),
202
-mSMatrix(mDMatrix.pmax(0.1f)), mAPMatrix(mDMatrix.nRow(), mDMatrix.nCol()),
203
-mQueue(nrow * ncol, alpha), mLambda(0.f), mMaxGibbsMass(maxGibbsMass),
204
-mAnnealingTemp(0.f), mNumRows(nrow), mNumCols(ncol), mAvgQueue(0.f),
205
-mNumQueues(0.f)
197
+mDMatrix(data), mSMatrix(mDMatrix.pmax(0.1f, 0.1f)), 
198
+mAPMatrix(mDMatrix.nRow(), mDMatrix.nCol()),
199
+mMatrix(amp ? mDMatrix.nRow() : nPatterns, amp ? nPatterns : mDMatrix.nCol()),
200
+mOtherMatrix(NULL), mQueue(mMatrix.nRow() * mMatrix.nCol()), mLambda(0.f),
201
+mMaxGibbsMass(100.f), mAnnealingTemp(1.f), mNumRows(mMatrix.nRow()),
202
+mNumCols(mMatrix.nCol()), mAvgQueue(0.f), mNumQueues(0.f)
206 203
 {
204
+    // calculate atomic domain size
207 205
     mBinSize = std::numeric_limits<uint64_t>::max()
208 206
         / static_cast<uint64_t>(mNumRows * mNumCols);
209
-    uint64_t remain = std::numeric_limits<uint64_t>::max()
210
-        % static_cast<uint64_t>(mNumRows * mNumCols);
211
-    mQueue.setDomainSize(std::numeric_limits<uint64_t>::max() - remain);
212
-    mDomain.setDomainSize(std::numeric_limits<uint64_t>::max() - remain);
207
+    mQueue.setDomainSize(mBinSize * mNumRows * mNumCols);
208
+    mDomain.setDomainSize(mBinSize * mNumRows * mNumCols);
209
+
210
+    // default sparsity parameters
211
+    setSparsity(0.01, false);
212
+}
213
+
214
+template <class T, class MatA, class MatB>
215
+template <class DataType>
216
+void GibbsSampler<T, MatA, MatB>::setUncertainty(const DataType &unc)
217
+{
218
+    mSMatrix = MatB(unc);
219
+}
220
+
221
+template <class T, class MatA, class MatB>
222
+void GibbsSampler<T, MatA, MatB>::setSparsity(float alpha, bool singleCell)
223
+{
224
+    mQueue.setAlpha(alpha);
213 225
 
214 226
     float meanD = singleCell ? gaps::algo::nonZeroMean(mDMatrix) :
215 227
         gaps::algo::mean(mDMatrix);
228
+
229
+    unsigned nPatterns = mDMatrix.nRow() == mMatrix.nRow() ? mMatrix.nCol() :
230
+        mMatrix.nRow();
231
+
216 232
     mLambda = alpha * std::sqrt(nPatterns / meanD);
217
-    mMaxGibbsMass = maxGibbsMass / mLambda;
218 233
 }
219 234
 
220 235
 template <class T, class MatA, class MatB>
221
-T* GibbsSampler<T, MatA, MatB>::impl()
236
+void GibbsSampler<T, MatA, MatB>::setMaxGibbsMass(float max)
222 237
 {
223
-    return static_cast<T*>(this);
238
+    mMaxGibbsMass = max;
239
+}
240
+
241
+template <class T, class MatA, class MatB>
242
+void GibbsSampler<T, MatA, MatB>::setAnnealingTemp(float temp)
243
+{
244
+    mAnnealingTemp = temp;
245
+}
246
+
247
+template <class T, class MatA, class MatB>
248
+void GibbsSampler<T, MatA, MatB>::setMatrix(const MatA &mat)
249
+{   
250
+    mMatrix = mat;
224 251
 }
225 252
 
226 253
 template <class T, class MatA, class MatB>
... ...
@@ -229,13 +256,19 @@ void GibbsSampler<T, MatA, MatB>::update(unsigned nSteps, unsigned nCores)
229 256
     unsigned n = 0;
230 257
     while (n < nSteps)
231 258
     {
259
+        // populate queue, prepare domain for this queue
232 260
         mQueue.populate(mDomain, nSteps - n);
233