Browse code

corrected update, still problems with convergence

chakalakka authored on 01/10/2014 07:49:16
Showing 6 changed files

... ...
@@ -1,9 +1,7 @@
1 1
 findCNVs <- function(binned.data, ID, eps=0.001, init="standard", max.time=-1, max.iter=-1, num.trials=1, eps.try=NULL, num.threads=1, output.if.not.converged=FALSE, filter.reads=TRUE) {
2 2
 
3 3
 	## Intercept user input
4
-	use.states <- 0:6	# set fixed
5 4
 	IDcheck <- ID  #trigger error if not defined
6
-	if (check.nonnegative.integer.vector(use.states)!=0) stop("argument 'use.states' expects a vector of non-negative integers")
7 5
 	if (check.positive(eps)!=0) stop("argument 'eps' expects a positive numeric")
8 6
 	if (check.integer(max.time)!=0) stop("argument 'max.time' expects an integer")
9 7
 	if (check.integer(max.iter)!=0) stop("argument 'max.iter' expects an integer")
... ...
@@ -14,21 +12,19 @@ findCNVs <- function(binned.data, ID, eps=0.001, init="standard", max.time=-1, m
14 12
 	if (check.positive.integer(num.threads)!=0) stop("argument 'num.threads' expects a positive integer")
15 13
 	if (check.logical(output.if.not.converged)!=0) stop("argument 'output.if.not.converged' expects a logical (TRUE or FALSE)")
16 14
 
17
-
18 15
 	war <- NULL
19 16
 	if (is.null(eps.try)) eps.try <- eps
20 17
 
21 18
 	## Assign variables
22
-# 	state.labels # assigned globally outside this function
23
-	use.state.labels <- state.labels[use.states+1]
24
-	numstates <- length(use.states)
19
+# 	state.labels # assigned in global.R
20
+	numstates <- length(state.labels)
25 21
 	numbins <- length(binned.data)
26 22
 	reads <- mcols(binned.data)$reads
27 23
 	iniproc <- which(init==c("standard","random","empiric")) # transform to int
28 24
 
29 25
 	# Check if there are reads in the data, otherwise HMM will blow up
30 26
 	if (!any(reads!=0)) {
31
-		stop("All reads in data are zero. No univariate HMM done.")
27
+		stop("All reads in data are zero. No HMM done.")
32 28
 	}
33 29
 
34 30
 	# Filter high reads out, makes HMM faster
... ...
@@ -37,9 +33,10 @@ findCNVs <- function(binned.data, ID, eps=0.001, init="standard", max.time=-1, m
37 33
 		mask <- reads > read.cutoff
38 34
 		reads[mask] <- read.cutoff
39 35
 		numfiltered <- length(which(mask))
40
-		if (numfiltered > 0) {
41
-			warning(paste("There are very high read counts in your data (probably artificial). Replaced read counts > ",read.cutoff," (99.99% quantile) by ",read.cutoff," in ",numfiltered," bins. Set option 'filter.reads=FALSE' to disable this filtering.", sep=""))
42
-		}
36
+		cat(paste0("Replaced read counts > ",read.cutoff," (99.99% quantile) by ",read.cutoff," in ",numfiltered," bins. Set option 'filter.reads=FALSE' to disable this filtering.\n"))
37
+# 		if (numfiltered > 0) {
38
+# 			warning(paste("There are very high read counts in your data (probably artificial). Replaced read counts > ",read.cutoff," (99.99% quantile) by ",read.cutoff," in ",numfiltered," bins. Set option 'filter.reads=FALSE' to disable this filtering.", sep=""))
39
+# 		}
43 40
 	}
44 41
 	
45 42
 	
... ...
@@ -51,7 +48,6 @@ findCNVs <- function(binned.data, ID, eps=0.001, init="standard", max.time=-1, m
51 48
 			reads = as.integer(reads), # int* O
52 49
 			num.bins = as.integer(numbins), # int* T
53 50
 			num.states = as.integer(numstates), # int* N
54
-			use.states = as.integer(use.states), # int* statelabels
55 51
 			size = double(length=numstates), # double* size
56 52
 			prob = double(length=numstates), # double* prob
57 53
 			num.iterations = as.integer(max.iter), #  int* maxiter
... ...
@@ -62,6 +58,7 @@ findCNVs <- function(binned.data, ID, eps=0.001, init="standard", max.time=-1, m
62 58
 			proba = double(length=numstates), # double* proba
63 59
 			loglik = double(length=1), # double* loglik
64 60
 			weights = double(length=numstates), # double* weights
61
+			distr.type = as.integer(state.distributions), # int* distr_type
65 62
 			ini.proc = as.integer(iniproc), # int* iniproc
66 63
 			size.initial = double(length=numstates), # double* initial_size
67 64
 			prob.initial = double(length=numstates), # double* initial_prob
... ...
@@ -73,22 +70,7 @@ findCNVs <- function(binned.data, ID, eps=0.001, init="standard", max.time=-1, m
73 70
 			read.cutoff = as.integer(read.cutoff) # int* read_cutoff
74 71
 		)
75 72
 
76
-		names(hmm$weights) <- use.state.labels
77 73
 		hmm$eps <- eps.try
78
-		hmm$A <- matrix(hmm$A, ncol=hmm$num.states, byrow=TRUE)
79
-		rownames(hmm$A) <- use.state.labels
80
-		colnames(hmm$A) <- use.state.labels
81
-		hmm$distributions <- cbind(size=hmm$size, prob=hmm$prob, mu=fmean(hmm$size,hmm$prob), variance=fvariance(hmm$size,hmm$prob))
82
-		rownames(hmm$distributions) <- use.state.labels
83
-		# Treat 'null-mixed' separately
84
-			hmm$distributions[2,'mu'] <- (1-hmm$distributions[2,'prob'])/hmm$distributions[2,'prob']
85
-			hmm$distributions[2,'variance'] <- hmm$distributions[2,'mu']/hmm$distributions[2,'prob']
86
-			hmm$distributions[2,'size'] <- NA
87
-		hmm$A.initial <- matrix(hmm$A.initial, ncol=hmm$num.states, byrow=TRUE)
88
-		rownames(hmm$A.initial) <- use.state.labels
89
-		colnames(hmm$A.initial) <- use.state.labels
90
-		hmm$distributions.initial <- cbind(size=hmm$size.initial, prob=hmm$prob.initial, mu=fmean(hmm$size.initial,hmm$prob.initial), variance=fvariance(hmm$size.initial,hmm$prob.initial))
91
-		rownames(hmm$distributions.initial) <- use.state.labels
92 74
 		if (num.trials > 1) {
93 75
 			if (hmm$loglik.delta > hmm$eps) {
94 76
 				warning("HMM did not converge in trial run ",i_try,"!\n")
... ...
@@ -111,7 +93,6 @@ findCNVs <- function(binned.data, ID, eps=0.001, init="standard", max.time=-1, m
111 93
 			reads = as.integer(reads), # int* O
112 94
 			num.bins = as.integer(numbins), # int* T
113 95
 			num.states = as.integer(numstates), # int* N
114
-			use.states = as.integer(use.states), # int* statelabels
115 96
 			size = double(length=numstates), # double* size
116 97
 			prob = double(length=numstates), # double* prob
117 98
 			num.iterations = as.integer(max.iter), #  int* maxiter
... ...
@@ -122,9 +103,10 @@ findCNVs <- function(binned.data, ID, eps=0.001, init="standard", max.time=-1, m
122 103
 			proba = double(length=numstates), # double* proba
123 104
 			loglik = double(length=1), # double* loglik
124 105
 			weights = double(length=numstates), # double* weights
106
+			distr.type = as.integer(state.distributions), # int* distr_type
125 107
 			ini.proc = as.integer(iniproc), # int* iniproc
126
-			size.initial = as.vector(hmm$distributions[,'size']), # double* initial_size
127
-			prob.initial = as.vector(hmm$distributions[,'prob']), # double* initial_prob
108
+			size.initial = as.vector(hmm$size), # double* initial_size
109
+			prob.initial = as.vector(hmm$prob), # double* initial_prob
128 110
 			A.initial = as.vector(hmm$A), # double* initial_A
129 111
 			proba.initial = as.vector(hmm$proba), # double* initial_proba
130 112
 			use.initial.params = as.logical(1), # bool* use_initial_params
... ...
@@ -136,27 +118,30 @@ findCNVs <- function(binned.data, ID, eps=0.001, init="standard", max.time=-1, m
136 118
 
137 119
 	# Add useful entries
138 120
 	hmm$ID <- ID
139
-	names(hmm$weights) <- use.state.labels
121
+	names(hmm$weights) <- state.labels
140 122
 	hmm$coordinates <- data.frame(as.character(seqnames(binned.data)), start(ranges(binned.data)), end(ranges(binned.data)))
141 123
 	names(hmm$coordinates) <- coordinate.names
142 124
 	hmm$seqlengths <- seqlengths(binned.data)
143 125
 	class(hmm) <- class.aneufinder.hmm
144
-	hmm$states <- factor(use.state.labels, levels=use.state.labels)[hmm$states+1]
126
+	hmm$states <- factor(state.labels, levels=state.labels)[hmm$states+1]
145 127
 	hmm$eps <- eps
146 128
 	hmm$A <- matrix(hmm$A, ncol=hmm$num.states, byrow=TRUE)
147
-	rownames(hmm$A) <- use.state.labels
148
-	colnames(hmm$A) <- use.state.labels
149
-	hmm$distributions <- cbind(size=hmm$size, prob=hmm$prob, mu=fmean(hmm$size,hmm$prob), variance=fvariance(hmm$size,hmm$prob))
150
-	rownames(hmm$distributions) <- use.state.labels
129
+	rownames(hmm$A) <- state.labels
130
+	colnames(hmm$A) <- state.labels
131
+	hmm$distributions <- data.frame(type=state.distributions, size=hmm$size, prob=hmm$prob, mu=fmean(hmm$size,hmm$prob), variance=fvariance(hmm$size,hmm$prob))
132
+	rownames(hmm$distributions) <- state.labels
151 133
 	# Treat 'null-mixed' separately
152
-		hmm$distributions[2,'mu'] <- (1-hmm$distributions[2,'prob'])/hmm$distributions[2,'prob']
153
-		hmm$distributions[2,'variance'] <- hmm$distributions[2,'mu']/hmm$distributions[2,'prob']
154
-		hmm$distributions[2,'size'] <- NA
134
+	if ('null-mixed' %in% state.labels) {
135
+		hmm$distributions['null-mixed','mu'] <- (1-hmm$distributions['null-mixed','prob'])/hmm$distributions['null-mixed','prob']
136
+		hmm$distributions['null-mixed','variance'] <- hmm$distributions['null-mixed','mu']/hmm$distributions['null-mixed','prob']
137
+		hmm$distributions['null-mixed','size'] <- NA
138
+	}
155 139
 	hmm$A.initial <- matrix(hmm$A.initial, ncol=hmm$num.states, byrow=TRUE)
156
-	rownames(hmm$A.initial) <- use.state.labels
157
-	colnames(hmm$A.initial) <- use.state.labels
158
-	hmm$distributions.initial <- cbind(size=hmm$size.initial, prob=hmm$prob.initial, mu=fmean(hmm$size.initial,hmm$prob.initial), variance=fvariance(hmm$size.initial,hmm$prob.initial))
159
-	rownames(hmm$distributions.initial) <- use.state.labels
140
+	rownames(hmm$A.initial) <- state.labels
141
+	colnames(hmm$A.initial) <- state.labels
142
+	hmm$distributions.initial <- data.frame(type=state.distributions, size=hmm$size.initial, prob=hmm$prob.initial, mu=fmean(hmm$size.initial,hmm$prob.initial), variance=fvariance(hmm$size.initial,hmm$prob.initial))
143
+	rownames(hmm$distributions.initial) <- state.labels
144
+	hmm$distributions.initial['nullsomy',2:5] <- c(0,1,0,0)
160 145
 	hmm$filter.reads <- filter.reads
161 146
 
162 147
 	# Delete redundant entries
... ...
@@ -166,7 +151,7 @@ findCNVs <- function(binned.data, ID, eps=0.001, init="standard", max.time=-1, m
166 151
 	hmm$prob.initial <- NULL
167 152
 	hmm$use.initial.params <- NULL
168 153
 	hmm$read.cutoff <- NULL
169
-	hmm$use.states <- NULL
154
+	hmm$distr.type <- NULL
170 155
 
171 156
 	# Issue warnings
172 157
 	if (num.trials == 1) {
... ...
@@ -2,6 +2,7 @@
2 2
 # Some global variables that can be used in all functions
3 3
 # =======================================================
4 4
 state.labels <- c("nullsomy","null-mixed","monosomy","disomy","trisomy","tetrasomy","multisomy")
5
+state.distributions <- factor(c('delta','dgeom','dnbinom','dnbinom','dnbinom','dnbinom','dnbinom'), levels=c('delta','dgeom','dnbinom'))
5 6
 coordinate.names <- c("chrom","start","end")
6 7
 binned.data.names <- c(coordinate.names,"reads")
7 8
 class.aneufinder.hmm <- "aneufinder.hmm"
... ...
@@ -70,13 +70,17 @@ plot.distribution <- function(model, state=NULL, chrom=NULL, start=NULL, end=NUL
70 70
 	x <- 0:rightxlim
71 71
 	distributions <- list(x)
72 72
 
73
-	# zero-inflation
74
-	distributions[[length(distributions)+1]] <- c(weights[1],rep(0,length(x)-1))
75
-	# geometric
76
-	distributions[[length(distributions)+1]] <- weights[2] * dgeom(x, model$distributions[2,'prob'])
77
-	# negative binomials
78
-	for (istate in 3:numstates) {
79
-		distributions[[length(distributions)+1]] <- weights[istate] * dnbinom(x, model$distributions[istate,'size'], model$distributions[istate,'prob'])
73
+	for (istate in 1:nrow(model$distributions)) {
74
+		if (model$distributions[istate,'type']=='delta') {
75
+			# zero-inflation
76
+			distributions[[length(distributions)+1]] <- c(weights[istate],rep(0,length(x)-1))
77
+		} else if (model$distributions[istate,'type']=='dgeom') {
78
+			# geometric
79
+			distributions[[length(distributions)+1]] <- weights[istate] * dgeom(x, model$distributions[istate,'prob'])
80
+		} else if (model$distributions[istate,'type']=='dnbinom') {
81
+			# negative binomials
82
+			distributions[[length(distributions)+1]] <- weights[istate] * dnbinom(x, model$distributions[istate,'size'], model$distributions[istate,'prob'])
83
+		}
80 84
 	}
81 85
 	distributions <- as.data.frame(distributions)
82 86
 	names(distributions) <- c("x",state.labels)
... ...
@@ -7,14 +7,14 @@
7 7
 // This function takes parameters from R, creates a univariate HMM object, creates the distributions, runs the Baum-Welch and returns the result to R.
8 8
 // ===================================================================================================================================================
9 9
 extern "C" {
10
-void R_univariate_hmm(int* O, int* T, int* N, int* statelabels, double* size, double* prob, int* maxiter, int* maxtime, double* eps, int* states, double* A, double* proba, double* loglik, double* weights, int* iniproc, double* initial_size, double* initial_prob, double* initial_A, double* initial_proba, bool* use_initial_params, int* num_threads, int* error, int* read_cutoff)
10
+void R_univariate_hmm(int* O, int* T, int* N, double* size, double* prob, int* maxiter, int* maxtime, double* eps, int* states, double* A, double* proba, double* loglik, double* weights, int* distr_type, int* iniproc, double* initial_size, double* initial_prob, double* initial_A, double* initial_proba, bool* use_initial_params, int* num_threads, int* error, int* read_cutoff)
11 11
 {
12 12
 
13 13
 	// Define logging level
14 14
 // 	FILE* pFile = fopen("chromStar.log", "w");
15 15
 // 	Output2FILE::Stream() = pFile;
16 16
  	FILELog::ReportingLevel() = FILELog::FromString("ERROR");
17
-//  	FILELog::ReportingLevel() = FILELog::FromString("DEBUG2");
17
+//  	FILELog::ReportingLevel() = FILELog::FromString("DEBUG1");
18 18
 
19 19
 	// Parallelization settings
20 20
 	omp_set_num_threads(*num_threads);
... ...
@@ -76,6 +76,7 @@ void R_univariate_hmm(int* O, int* T, int* N, int* statelabels, double* size, do
76 76
 	Rprintf("data mean = %g, data variance = %g\n", mean, variance);		
77 77
 	
78 78
 	// Go through all states of the hmm and assign the density functions
79
+	// This loop assumes that the negative binomial states come last and are consecutive
79 80
 	double imean, ivariance;
80 81
 	for (int i_state=0; i_state<*N; i_state++)
81 82
 	{
... ...
@@ -91,29 +92,41 @@ void R_univariate_hmm(int* O, int* T, int* N, int* statelabels, double* size, do
91 92
 			if (*iniproc == 1)
92 93
 			{
93 94
 				// Simple initialization based on data mean, assumed to be the disomic mean
94
-				imean = mean/2 * statelabels[i_state];
95
-				ivariance = variance/2 * statelabels[i_state];
95
+				if (distr_type[i_state] == 1) { }
96
+				else if (distr_type[i_state] == 2) { }
97
+				else if (distr_type[i_state] == 3)
98
+				{
99
+					for (int ii_state=i_state; ii_state<*N; ii_state++)
100
+					{
101
+						imean = mean/2 * (ii_state-i_state+1);
102
+						ivariance = imean * 5;
103
+// 						ivariance = variance/2 * (i_state-1);
104
+						// Calculate r and p from mean and variance
105
+						initial_size[ii_state] = pow(imean,2)/(ivariance-imean);
106
+						initial_prob[ii_state] = imean/ivariance;
107
+					}
108
+					break;
109
+				}
96 110
 			}
97 111
 
98
-			// Calculate r and p from mean and variance
99
-			initial_size[i_state] = pow(imean,2)/(ivariance-imean);
100
-			initial_prob[i_state] = imean/ivariance;
101
-
102 112
 		}
113
+	}
103 114
 
104
-		if (i_state == 0)
115
+	for (int i_state=0; i_state<*N; i_state++)
116
+	{
117
+		if (distr_type[i_state] == 1)
105 118
 		{
106
-			FILE_LOG(logDEBUG1) << "Using only zeros for state " << i_state;
119
+			FILE_LOG(logDEBUG1) << "Using delta distribution for state " << i_state;
107 120
 			ZeroInflation *d = new ZeroInflation(O, *T); // delete is done inside ~ScaleHMM()
108 121
 			hmm->densityFunctions.push_back(d);
109 122
 		}
110
-		else if (i_state == 1)
123
+		else if (distr_type[i_state] == 2)
111 124
 		{
112 125
 			FILE_LOG(logDEBUG1) << "Using geometric distribution for state " << i_state;
113 126
 			Geometric *d = new Geometric(O, *T, 0.9); // delete is done inside ~ScaleHMM()
114 127
 			hmm->densityFunctions.push_back(d);
115 128
 		}
116
-		else if (i_state >= 2)
129
+		else if (distr_type[i_state] == 3)
117 130
 		{
118 131
 			FILE_LOG(logDEBUG1) << "Using negative binomial for state " << i_state;
119 132
 			NegativeBinomial *d = new NegativeBinomial(O, *T, initial_size[i_state], initial_prob[i_state]); // delete is done inside ~ScaleHMM()
... ...
@@ -166,7 +179,7 @@ void R_univariate_hmm(int* O, int* T, int* N, int* statelabels, double* size, do
166 179
 		{
167 180
 			posterior_per_t[iN] = hmm->get_posterior(iN, t);
168 181
 		}
169
-		states[t] = statelabels[argMax(posterior_per_t, *N)];
182
+		states[t] = argMax(posterior_per_t, *N);
170 183
 	}
171 184
 
172 185
 	FILE_LOG(logDEBUG1) << "Return parameters";
... ...
@@ -325,8 +325,10 @@ void NegativeBinomial::update(double* weights)
325 325
 void NegativeBinomial::update_constrained(double** weights, int fromState, int toState)
326 326
 {
327 327
 	FILE_LOG(logDEBUG2) << __PRETTY_FUNCTION__;
328
+	FILE_LOG(logDEBUG1) << "r = "<<this->size << ", p = "<<this->prob;
328 329
 	double eps = 1e-4, kmax;
329 330
 	double numerator, denominator, rhere, dr, Fr, dFrdr, DigammaR, DigammaRplusDR;
331
+	double logp = log(this->prob);
330 332
 	// Update prob (p)
331 333
 	numerator=denominator=0.0;
332 334
 // 	clock_t time, dtime;
... ...
@@ -339,8 +341,7 @@ void NegativeBinomial::update_constrained(double** weights, int fromState, int t
339 341
 			denominator+=weights[i+fromState][t]*(this->size*(i+1)+this->obs[t]);
340 342
 		}
341 343
 	}
342
-	this->prob = numerator/denominator; // Update of size (r) is now done with updated prob
343
-	double logp = log(this->prob);
344
+	this->prob = numerator/denominator; // Update of size (r) is now done with old prob
344 345
 // 	dtime = clock() - time;
345 346
 // 	FILE_LOG(logDEBUG1) << "updateP(): "<<dtime<< " clicks";
346 347
 	// Update of size (r) with Newton Method
... ...
@@ -370,13 +371,13 @@ void NegativeBinomial::update_constrained(double** weights, int fromState, int t
370 371
 				{
371 372
 					if(this->obs[t]==0)
372 373
 					{
373
-						Fr+=weights[i+fromState][t]*logp;
374
+						Fr+=weights[i+fromState][t]*(i+1)*logp;
374 375
 						//dFrdr+=0;
375 376
 					}
376 377
 					if(this->obs[t]!=0)
377 378
 					{
378
-						Fr+=weights[i+fromState][t]*(logp-DigammaR+DigammaRplusX[(int)obs[t]]);
379
-						dFrdr+=weights[i+fromState][t]/((i+1)*dr)*(DigammaR-DigammaRplusDR+DigammaRplusDRplusX[(int)obs[t]]-DigammaRplusX[(int)obs[t]]);
379
+						Fr+=weights[i+fromState][t]*(i+1)*(logp-DigammaR+DigammaRplusX[(int)obs[t]]);
380
+						dFrdr+=weights[i+fromState][t]/dr*(i+1)*(DigammaR-DigammaRplusDR+DigammaRplusDRplusX[(int)obs[t]]-DigammaRplusX[(int)obs[t]]);
380 381
 					}
381 382
 				}
382 383
 				if(fabs(Fr)<eps)
... ...
@@ -405,13 +406,13 @@ void NegativeBinomial::update_constrained(double** weights, int fromState, int t
405 406
 					DigammaRplusDRplusX = digamma((i+1)*(rhere+dr)+this->obs[t]); // boost::math::digamma<>(rhere+dr+this->obs[ti]);
406 407
 					if(this->obs[t]==0)
407 408
 					{
408
-						Fr+=weights[i+fromState][t]*logp;
409
+						Fr+=weights[i+fromState][t]*(i+1)*logp;
409 410
 						//dFrdr+=0;
410 411
 					}
411 412
 					if(this->obs[t]!=0)
412 413
 					{
413
-						Fr+=weights[i+fromState][t]*(logp-DigammaR+DigammaRplusX);
414
-						dFrdr+=weights[i+fromState][t]/((i+1)*dr)*(DigammaR-DigammaRplusDR+DigammaRplusDRplusX-DigammaRplusX);
414
+						Fr+=weights[i+fromState][t]*(i+1)*(logp-DigammaR+DigammaRplusX);
415
+						dFrdr+=weights[i+fromState][t]/dr*(i+1)*(DigammaR-DigammaRplusDR+DigammaRplusDRplusX-DigammaRplusX);
415 416
 					}
416 417
 				}
417 418
 			}
... ...
@@ -425,6 +426,8 @@ void NegativeBinomial::update_constrained(double** weights, int fromState, int t
425 426
 	}
426 427
 	this->size = rhere;
427 428
 	FILE_LOG(logDEBUG1) << "r = "<<this->size << ", p = "<<this->prob;
429
+	this->mean = this->fmean(this->size, this->prob);
430
+	this->variance = this->fvariance(this->size, this->prob);
428 431
 
429 432
 // 	dtime = clock() - time;
430 433
 // 	FILE_LOG(logDEBUG1) << "updateR(): "<<dtime<< " clicks";
... ...
@@ -699,7 +702,7 @@ void Geometric::calc_densities(double* dens)
699 702
 	}
700 703
 } 
701 704
 
702
-void Geometric::update(double* weight)
705
+void Geometric::update(double* weights)
703 706
 {
704 707
 	FILE_LOG(logDEBUG2) << __PRETTY_FUNCTION__;
705 708
 	double numerator, denominator;
... ...
@@ -707,8 +710,8 @@ void Geometric::update(double* weight)
707 710
 	numerator=denominator=0.0;
708 711
 	for (int t=0; t<this->T; t++)
709 712
 	{
710
-		numerator+=weight[t];
711
-		denominator+=weight[t]*(1+this->obs[t]);
713
+		numerator+=weights[t];
714
+		denominator+=weights[t]*(1+this->obs[t]);
712 715
 	}
713 716
 	this->prob = numerator/denominator;
714 717
 	FILE_LOG(logDEBUG1) << "p = "<<this->prob;
... ...
@@ -256,45 +256,39 @@ void ScaleHMM::baumWelch(int* maxiter, int* maxtime, double* eps)
256 256
 
257 257
 		clock_t clocktime = clock(), dtime;
258 258
 
259
-// 		// Update distribution of state 1 (null-mixed)
260
-// 		this->densityFunctions[1]->update(this->gamma[1]);
261
-// 		// Update distribution of state 2 (monosomic)
262
-// 		int iN=2;
263
-// 		this->densityFunctions[iN]->update_constrained(this->gamma, iN, this->N);
264
-// 		double mean1 = this->densityFunctions[iN]->get_mean();
265
-// 		double variance1 = this->densityFunctions[iN]->get_variance();
266
-// 		// Set others as multiples
259
+// 		// Update all distributions independently
267 260
 // 		for (int iN=0; iN<this->N; iN++)
268 261
 // 		{
269
-// 			if (iN!=1 and iN!=2)
270
-// 			{
271
-// 				this->densityFunctions[iN]->set_mean(mean1 * (iN-1));
272
-// 				this->densityFunctions[iN]->set_variance(variance1 * (iN-1));
273
-// 			}
274
-// 			FILE_LOG(logDEBUG1) << "mean(state="<<iN<<") = " << this->densityFunctions[iN]->get_mean();
262
+// 			this->densityFunctions[iN]->update(this->gamma[iN]);
275 263
 // 		}
276
-		
277
-		// Update distribution of state 3 (disomic)
278
-		this->densityFunctions[3]->update(this->gamma[3]);
279
-		double mean2 = this->densityFunctions[3]->get_mean();
280
-		double variance2 = this->densityFunctions[3]->get_variance();
281
-		// Update distribution of state 1 (null-mixed)
282
-		this->densityFunctions[1]->update(this->gamma[1]);
283
-// 		// Set mean of state 1 (null-mixed) to half of monosomic
284
-// 		this->densityFunctions[1]->set_mean(mean2/2 / 2);
285
-// 		this->densityFunctions[1]->set_variance(variance2/2 / 2);
286
-		// Set the others as multiples of disomic
264
+
265
+		// Update distribution of state 'null-mixed' and 'monosomic', set others as multiples of 'monosomic'
266
+		// This loop assumes that the negative binomial states come last and are consecutive
287 267
 		for (int iN=0; iN<this->N; iN++)
288 268
 		{
289
-			if (iN!=1 and iN!=3)
269
+			if (this->densityFunctions[iN]->get_name() == ZERO_INFLATION) {}
270
+			if (this->densityFunctions[iN]->get_name() == GEOMETRIC)
290 271
 			{
291
-				this->densityFunctions[iN]->set_mean(mean2/2 * (iN-1));
292
-				this->densityFunctions[iN]->set_variance(variance2/2 * (iN-1));
272
+				this->densityFunctions[iN]->update(this->gamma[iN]);
273
+			}
274
+			if (this->densityFunctions[iN]->get_name() == NEGATIVE_BINOMIAL)
275
+			{
276
+				FILE_LOG(logDEBUG1) << "mean(state="<<iN<<") = " << this->densityFunctions[iN]->get_mean() << ", var(state="<<iN<<") = " << this->densityFunctions[iN]->get_variance();
277
+				this->densityFunctions[iN]->update_constrained(this->gamma, iN, this->N);
278
+				double mean1 = this->densityFunctions[iN]->get_mean();
279
+				double variance1 = this->densityFunctions[iN]->get_variance();
280
+				FILE_LOG(logDEBUG1) << "mean(state="<<iN<<") = " << this->densityFunctions[iN]->get_mean() << ", var(state="<<iN<<") = " << this->densityFunctions[iN]->get_variance();
281
+				// Set others as multiples
282
+				for (int jN=iN+1; jN<this->N; jN++)
283
+				{
284
+					this->densityFunctions[jN]->set_mean(mean1 * (jN-iN+1));
285
+					this->densityFunctions[jN]->set_variance(variance1 * (jN-iN+1));
286
+				FILE_LOG(logDEBUG1) << "mean(state="<<jN<<") = " << this->densityFunctions[jN]->get_mean() << ", var(state="<<jN<<") = " << this->densityFunctions[jN]->get_variance();
287
+				}
288
+				break;
293 289
 			}
294 290
 		}
295
-
296
-
297
-
291
+			
298 292
 		dtime = clock() - clocktime;
299 293
 	 	FILE_LOG(logDEBUG) << "updating distributions: " << dtime << " clicks";
300 294
 		R_CheckUserInterrupt();