Browse code

fixed several bugs

jokergoo authored on 18/10/2018 11:39:39
Showing11 changed files

... ...
@@ -2823,3 +2823,63 @@ anno_summary = function(which = c("column", "row"), border = TRUE, bar_width = 0
2823 2823
 	return(anno)
2824 2824
 }
2825 2825
 
2826
+# == title
2827
+# Block annotation
2828
+#
2829
+# == param
2830
+# -gp Graphic parameters
2831
+# -labels
2832
+# -labels_gp
2833
+# -labels_rot
2834
+# -which
2835
+# -width
2836
+# -height
2837
+#
2838
+anno_block = function(gp = gpar(), labels = NULL, labels_gp = gpar(), labels_rot = ifelse(which == "row", 90, 0),
2839
+	which = c("column", "row"), width = NULL, height = NULL) {
2840
+
2841
+	if(is.null(.ENV$current_annotation_which)) {
2842
+		which = match.arg(which)[1]
2843
+	} else {
2844
+		which = .ENV$current_annotation_which
2845
+	}
2846
+	if(length(labels)) {
2847
+		if(which == "column") {
2848
+			height = grobHeight(textGrob(labels, rot = labels_rot, gp = labels_gp))
2849
+			height = convertHeight(height, "mm") + unit(5, "mm")
2850
+		} else {
2851
+			width = grobWidth(textGrob(labels, rot = labels_rot, gp = labels_gp))
2852
+			width = convertWidth(width, "mm") + unit(5, "mm")
2853
+		}
2854
+	}
2855
+
2856
+	anno_size = anno_width_and_height(which, width, height, unit(5, "mm"))
2857
+	
2858
+	fun = function(index, k, n) {
2859
+		gp = subset_gp(recycle_gp(gp, n), k)
2860
+		
2861
+		grid.rect(gp = gp)
2862
+		if(length(labels)) {
2863
+			if(length(labels) != n) {
2864
+				stop_wrap("Length of `labels` should be as same as number of slices.")
2865
+			}
2866
+			label = labels[k]
2867
+			labels_gp = subset_gp(recycle_gp(labels_gp, n), k)
2868
+			grid.text(label, gp = labels_gp, rot = labels_rot)
2869
+		}
2870
+	}
2871
+
2872
+	anno = AnnotationFunction(
2873
+		fun = fun,
2874
+		n = NA,
2875
+		fun_name = "anno_block",
2876
+		which = which,
2877
+		var_import = list(gp, labels, labels_gp, labels_rot),
2878
+		subset_rule = list(),
2879
+		subsetable = TRUE,
2880
+		height = anno_size$height,
2881
+		width = anno_size$width,
2882
+		show_name = FALSE
2883
+	)
2884
+	return(anno) 
2885
+}
... ...
@@ -395,8 +395,8 @@ setMethod(f = "make_layout",
395 395
             layout_size$row_anno_left_width,
396 396
             layout_size$row_anno_right_width
397 397
         ) + object@matrix_param$width
398
-        if(nr_slice > 1) {
399
-            object@heatmap_param$width = object@heatmap_param$width + sum(row_gap[seq_len(nr_slice-1)])
398
+        if(nc_slice > 1) {
399
+            object@heatmap_param$width = object@heatmap_param$width + sum(column_gap[seq_len(nc_slice-1)])
400 400
         }
401 401
     } else {
402 402
         object@heatmap_param$width = unit(1, "npc")
... ...
@@ -423,8 +423,8 @@ setMethod(f = "make_layout",
423 423
             layout_size$column_anno_bottom_height,
424 424
             layout_size$column_names_bottom_height
425 425
         ) + object@matrix_param$height
426
-        if(nc_slice > 1) {
427
-            object@heatmap_param$height = object@heatmap_param$height + sum(column_gap[seq_len(nc_slice-1)])
426
+        if(nr_slice > 1) {
427
+            object@heatmap_param$height = object@heatmap_param$height + sum(row_gap[seq_len(nr_slice-1)])
428 428
         }
429 429
     } else {
430 430
         object@heatmap_param$height = unit(1, "npc")
... ...
@@ -150,7 +150,7 @@ HeatmapAnnotation = function(...,
150 150
     anno_args = setdiff(called_args, fun_args)
151 151
     if(any(anno_args == "")) stop("annotations should have names.")
152 152
     if(is.null(called_args)) {
153
-    	stop_wrap("It seems you are putting only one argument to the function. If it is a simple vector annotation, specify it as HeatmapAnnotation(name = value). If it is a data frame annotation, specify it as HeatmapAnnotation(df = value)")
153
+    	stop_wrap("It seems you are putting only one argument to the function. If it is a simple vector annotation or a function annotation (e.g. anno_*()), specify it as HeatmapAnnotation(name = value). If it is a data frame annotation, specify it as HeatmapAnnotation(df = value)")
154 154
     }
155 155
 
156 156
     ##### pull all annotation to `anno_value_list`####
... ...
@@ -479,28 +479,30 @@ setMethod(f = "draw",
479 479
         heatmap_width = heatmap_width
480 480
     )
481 481
 
482
-    layout = grid.layout(nrow = length(HEATMAP_LIST_LAYOUT_COLUMN_COMPONENT), 
483
-        ncol = length(HEATMAP_LAYOUT_ROW_COMPONENT), 
484
-        widths = component_width(object), 
485
-        heights = component_height(object))
486 482
     ht_list_width = sum(component_width(object)) + padding[2] + padding[4]
487 483
     ht_list_height = sum(component_height(object)) + padding[1] + padding[3]
488 484
 
489 485
     if(is_abs_unit(ht_list_width)) {
490
-        ht_list_width = unit(round(convertWidth(ht_list_width, "mm", valueOnly = TRUE)), "mm")
486
+        ht_list_width = unit(ceiling(convertWidth(ht_list_width, "mm", valueOnly = TRUE)), "mm")
491 487
         qqcat("Since all heatmaps/annotations have absolute units, the total width of the plot is @{ht_list_width}\n")
492 488
         w = ht_list_width
493 489
     } else {
494 490
         w = unit(1, "npc")
495 491
     }
496 492
     if(is_abs_unit(ht_list_height)) {
497
-        ht_list_height = unit(round(convertHeight(ht_list_height, "mm", valueOnly = TRUE)), "mm")
493
+        ht_list_height = unit(ceiling(convertHeight(ht_list_height, "mm", valueOnly = TRUE)), "mm")
498 494
         qqcat("Since all heatmaps/annotations have absolute units, the total height of the plot is @{ht_list_height}\n")
499 495
         h = ht_list_height
500 496
     } else {
501 497
         h = unit(1, "npc")
502 498
     }
503 499
 
500
+    layout = grid.layout(nrow = length(HEATMAP_LIST_LAYOUT_COLUMN_COMPONENT), 
501
+        ncol = length(HEATMAP_LAYOUT_ROW_COMPONENT), 
502
+        widths = component_width(object), 
503
+        heights = component_height(object))
504
+    
505
+
504 506
     pushViewport(viewport(name = "global", width = w, height = h))
505 507
     pushViewport(viewport(layout = layout, name = "global_layout", x = padding[2], y = padding[1], width = unit(1, "npc") - padding[2] - padding[4],
506 508
         height = unit(1, "npc") - padding[1] - padding[3], just = c("left", "bottom")))
... ...
@@ -435,6 +435,28 @@ setMethod(f = "adjust_heatmap_list",
435 435
     }
436 436
 
437 437
 
438
+    adjust_annotation_extension = object@ht_list_param$adjust_annotation_extension
439
+
440
+    # the padding of the heatmap list should be recorded because if the total wdith of e.g. heatmap body
441
+    # is a fixed value, the width should added by the padding
442
+    padding = unit(c(0, 0, 0, 0), "mm")
443
+    if(adjust_annotation_extension) {
444
+        if(object@layout$row_anno_max_bottom_extended[[1]] > object@layout$max_bottom_component_height[[1]]) {
445
+            padding[1] = object@layout$row_anno_max_bottom_extended - object@layout$max_bottom_component_height
446
+        }
447
+        if(object@layout$column_anno_max_left_extended[[1]] > object@layout$max_left_component_width[[1]]) {
448
+            padding[2] = object@layout$column_anno_max_left_extended - object@layout$max_left_component_width + GLOBAL_PADDING[2]
449
+        }
450
+            
451
+        if(object@layout$row_anno_max_top_extended[[1]] > object@layout$max_top_component_height[[1]]) {
452
+            padding[3] = object@layout$row_anno_max_top_extended - object@layout$max_top_component_height + GLOBAL_PADDING[3]
453
+        }
454
+        if(object@layout$column_anno_max_right_extended[[1]] > object@layout$max_right_component_width[[1]]) {
455
+            padding[4] = object@layout$column_anno_max_right_extended - object@layout$max_right_component_width + GLOBAL_PADDING[4]
456
+        }
457
+    }
458
+    object@layout$heatmap_list_padding = padding
459
+
438 460
     return(object)
439 461
 })
440 462
 
... ...
@@ -463,22 +485,7 @@ setMethod(f = "draw_heatmap_list",
463 485
     ht_gap = object@ht_list_param$ht_gap
464 486
     adjust_annotation_extension = object@ht_list_param$adjust_annotation_extension
465 487
 
466
-    padding = unit(c(0, 0, 0, 0), "mm")
467
-    if(adjust_annotation_extension) {
468
-        if(object@layout$row_anno_max_bottom_extended[[1]] > object@layout$max_bottom_component_height[[1]]) {
469
-            padding[1] = object@layout$row_anno_max_bottom_extended - object@layout$max_bottom_component_height
470
-        }
471
-        if(object@layout$column_anno_max_left_extended[[1]] > object@layout$max_left_component_width[[1]]) {
472
-            padding[2] = object@layout$column_anno_max_left_extended - object@layout$max_left_component_width + GLOBAL_PADDING[2]
473
-        }
474
-            
475
-        if(object@layout$row_anno_max_top_extended[[1]] > object@layout$max_top_component_height[[1]]) {
476
-            padding[3] = object@layout$row_anno_max_top_extended - object@layout$max_top_component_height + GLOBAL_PADDING[3]
477
-        }
478
-        if(object@layout$column_anno_max_right_extended[[1]] > object@layout$max_right_component_width[[1]]) {
479
-            padding[4] = object@layout$column_anno_max_right_extended - object@layout$max_right_component_width + GLOBAL_PADDING[4]
480
-        }
481
-    }
488
+    padding = object@layout$heatmap_list_padding
482 489
 
483 490
     pushViewport(viewport(x = padding[2], y = padding[1], width = unit(1, "npc") - padding[2] - padding[4],
484 491
         height = unit(1, "npc") - padding[1] - padding[3], just = c("left", "bottom")))
... ...
@@ -523,7 +530,7 @@ setMethod(f = "draw_heatmap_list",
523 530
             } else {
524 531
                 x = unit(0, "npc")
525 532
             }
526
-            
533
+
527 534
             pushViewport(viewport(x = x, y = unit(0, "npc"), width = heatmap_width[i], just = c("left", "bottom"), name = paste0("heatmap_", object@ht_list[[i]]@name)))
528 535
             if(inherits(ht, "Heatmap")) {
529 536
                 draw(ht, internal = TRUE)
... ...
@@ -1068,7 +1068,7 @@ setMethod(f = "component_width",
1068 1068
                     }
1069 1069
                 })))
1070 1070
             if(is_abs_unit(width)) {
1071
-                width + sum(object@ht_list_param$ht_gap) - object@ht_list_param$ht_gap[length(object@ht_list_param$ht_gap)]
1071
+                width + sum(object@ht_list_param$ht_gap) - object@ht_list_param$ht_gap[length(object@ht_list_param$ht_gap)] + object@layout$heatmap_list_padding[2] + object@layout$heatmap_list_padding[4]
1072 1072
             } else {
1073 1073
                 unit(1, "null") 
1074 1074
             }
... ...
@@ -1095,7 +1095,7 @@ setMethod(f = "component_width",
1095 1095
                 if(convertWidth(width, "mm", valueOnly = TRUE) == 0) {
1096 1096
                     unit(1, "null")
1097 1097
                 } else {
1098
-                    width
1098
+                    width + object@layout$heatmap_list_padding[2] + object@layout$heatmap_list_padding[4]
1099 1099
                 }
1100 1100
             } else {
1101 1101
                 unit(1, "null") 
... ...
@@ -1146,7 +1146,7 @@ setMethod(f = "component_height",
1146 1146
                     }
1147 1147
                 })))
1148 1148
             if(is_abs_unit(height)) {
1149
-                height + sum(object@ht_list_param$ht_gap) - object@ht_list_param$ht_gap[length(object@ht_list_param$ht_gap)]
1149
+                height + sum(object@ht_list_param$ht_gap) - object@ht_list_param$ht_gap[length(object@ht_list_param$ht_gap)] + object@layout$heatmap_list_padding[1] + object@layout$heatmap_list_padding[3]
1150 1150
             } else {
1151 1151
                 unit(1, "null") 
1152 1152
             }
... ...
@@ -1171,7 +1171,7 @@ setMethod(f = "component_height",
1171 1171
                 if(convertWidth(height, "mm", valueOnly = TRUE) == 0) {
1172 1172
                     unit(1, "null")
1173 1173
                 } else {
1174
-                    height
1174
+                    height + object@layout$heatmap_list_padding[1] + object@layout$heatmap_list_padding[3]
1175 1175
                 }
1176 1176
             } else {
1177 1177
                 unit(1, "null") 
... ...
@@ -232,7 +232,7 @@ SingleAnnotation = function(name, value, col, fun,
232 232
             anno_fun_extend = fun@extended
233 233
             if(verbose) qqcat("@{name}: annotation is a AnnotationFunction object\n")
234 234
 
235
-            show_name = fun@show_name
235
+            if(!fun@show_name) show_name = fun@show_name
236 236
         } else {
237 237
             fun = AnnotationFunction(fun = fun)
238 238
             anno_fun_extend = fun@extended
... ...
@@ -77,21 +77,37 @@ densityHeatmap = function(data,
77 77
 	column_names_rot = 90,
78 78
 
79 79
 	cluster_columns = FALSE,
80
+	clustering_distance_columns = "ks",
81
+	clustering_method_columns = "complete",
82
+
80 83
 	...) {
81 84
 
85
+	arg_list = list(...)
86
+	if(length(arg_list)) {
87
+		if(any(c("row_km", "row_split", "split", "km") %in% names(arg_list))) {
88
+			stop_wrap("density heatmaps do not allow row splitting.")
89
+		}
90
+		if(grepl("row", names(arg_list))) {
91
+			stop_wrap("density heatmaps do not allow to set rows.")
92
+		}
93
+	}
94
+
95
+	ylab = ylab
96
+	column_title = column_title
97
+
82 98
 	density_param$na.rm = TRUE
83 99
 
84
-	if(is.matrix(data)) {
85
-		density_list = apply(data, 2, function(x) do.call(density, c(list(x = x), density_param)))
86
-		quantile_list = apply(data, 2, quantile, na.rm = TRUE)
87
-		mean_value = apply(data, 2, mean, na.rm = TRUE)
88
-	} else if(is.data.frame(data) || is.list(data)) {
89
-		density_list = lapply(data, function(x) do.call(density, c(list(x = x), density_param)))
90
-		quantile_list = sapply(data, quantile, na.rm = TRUE)
91
-		mean_value = sapply(data, mean, na.rm = TRUE)
92
-	} else {
100
+	if(!is.matrix(data) && !is.data.frame(matrix) && !is.list(matrix)) {
93 101
 		stop("only matrix and list are allowed.")
94 102
 	}
103
+	if(is.matrix(data)) {
104
+		data2 = as.list(as.data.frame(data))
105
+		names(data2) = colnames(data)
106
+		data = data2
107
+	}
108
+	density_list = lapply(data, function(x) do.call(density, c(list(x = x), density_param)))
109
+	quantile_list = sapply(data, quantile, na.rm = TRUE)
110
+	mean_value = sapply(data, mean, na.rm = TRUE)
95 111
 
96 112
 	n = length(density_list)
97 113
 	nm = names(density_list)
... ...
@@ -113,6 +129,25 @@ densityHeatmap = function(data,
113 129
 	mat = as.matrix(as.data.frame(mat))
114 130
 	colnames(mat) = nm
115 131
 
132
+	if(cluster_columns) {
133
+		if(clustering_distance_columns == "ks") {
134
+			nc = length(data)
135
+		    d = matrix(NA, nrow = nc, ncol = nc)
136
+		    rownames(d) = colnames(d) = rownames(d)
137
+
138
+		    for(i in 2:nc) {
139
+		        for(j in 1:(nc-1)) {
140
+		            suppressWarnings(d[i, j] <- ks.test(data[[i]], data[[j]])$stat)
141
+		        }
142
+		    }
143
+
144
+		    d = as.dist(d)
145
+
146
+			hc = hclust(d, clustering_method_columns)
147
+			cluster_columns = hc
148
+		}
149
+	}
150
+
116 151
 	col = colorRamp2(seq(0, max(mat, na.rm = TRUE), length = length(col)), col, space = color_space)
117 152
 
118 153
 	bb = grid.pretty(c(min_x, max_x))
... ...
@@ -121,6 +156,8 @@ densityHeatmap = function(data,
121 156
 		column_title_gp = title_gp,
122 157
 		cluster_rows = FALSE, 
123 158
 		cluster_columns = cluster_columns,
159
+		clustering_distance_columns = clustering_distance_columns,
160
+		clustering_method_columns = clustering_method_columns,
124 161
 		column_names_side = column_names_side,
125 162
 		show_column_names = show_column_names,
126 163
 		column_names_max_height = column_names_max_height,
... ...
@@ -143,26 +180,40 @@ densityHeatmap = function(data,
143 180
 
144 181
 	post_fun = function(ht) {
145 182
 		column_order = column_order(ht)
183
+		if(!is.list(column_order)) {
184
+			column_order = list(column_order)
185
+		}
186
+		n_slice = length(column_order)
146 187
 
147 188
 		decorate_annotation(paste0("axis_", random_str), {
148 189
 			grid.text(ylab, x = grobHeight(textGrob(ylab, gp = ylab_gp)), rot = 90)
149
-		})
190
+		}, slice = 1)
191
+
192
+		for(i_slice in 1:n_slice) {
193
+			decorate_heatmap_body(paste0("density_", random_str), {
194
+				n = length(column_order[[i_slice]])
195
+				pushViewport(viewport(xscale = c(0.5, n + 0.5), yscale = c(min_x, max_x), clip = TRUE))
196
+				for(i in seq_len(5)) {
197
+					grid.lines(1:n, quantile_list[i, column_order[[i_slice]] ], default.units = "native", gp = gpar(lty = 2))
198
+				}
199
+				grid.lines(1:n, mean_value[ column_order[[i_slice]] ], default.units = "native", gp = gpar(lty = 2, col = "darkred"))
200
+				upViewport()
201
+			}, column_slice = i_slice)
202
+		}
150 203
 
151 204
 		decorate_heatmap_body(paste0("density_", random_str), {
152
-			pushViewport(viewport(xscale = c(0.5, n + 0.5), yscale = c(min_x, max_x), clip = TRUE))
153
-			for(i in seq_len(5)) {
154
-				grid.lines(1:n, quantile_list[i, column_order], default.units = "native", gp = gpar(lty = 2))
155
-			}
156
-			grid.lines(1:n, mean_value[column_order], default.units = "native", gp = gpar(lty = 2, col = "darkred"))
205
+			pushViewport(viewport(yscale = c(min_x, max_x), clip = FALSE))
206
+			grid.rect(gp = gpar(fill = NA))
207
+			grid.yaxis(gp = tick_label_gp)
157 208
 			upViewport()
158
-		})
209
+		}, column_slice = 1)
210
+
159 211
 		decorate_heatmap_body(paste0("density_", random_str), {
212
+			n = length(column_order[[n_slice]])
160 213
 			pushViewport(viewport(xscale = c(0.5, n + 0.5), yscale = c(min_x, max_x), clip = FALSE))
161
-			grid.rect(gp = gpar(fill = NA))
162
-			grid.yaxis(gp = tick_label_gp)
163 214
 
164 215
 			labels = c(rownames(quantile_list), "mean")
165
-			y = c(quantile_list[, column_order[n]], mean_value[column_order[n]])
216
+			y = c(quantile_list[, column_order[[n_slice]][n] ], mean_value[ column_order[[n_slice]][n] ])
166 217
 			od = order(y)
167 218
 			y = y[od]
168 219
 			labels = labels[od]
... ...
@@ -180,7 +231,7 @@ densityHeatmap = function(data,
180 231
 	        grid.segments(unit(1, "npc") + rep(link_width * (2/3), n2), h, unit(1, "npc") + rep(link_width, n2), h, default.units = "native")
181 232
 
182 233
 			upViewport()
183
-		})
234
+		}, column_slice = n_slice)
184 235
 	}
185 236
 
186 237
 	ht@heatmap_param$post_fun = post_fun
... ...
@@ -653,10 +653,10 @@ grid.boxplot = function(value, pos, outline = TRUE, box_width = 0.6,
653 653
                       default.units = "native", gp = gp)
654 654
         if(outline) {   
655 655
             l1 = value > boxplot_stats[5, 1]
656
-            if(sum(l1)) grid.points(x = pos, y = value[l1], 
656
+            if(sum(l1)) grid.points(x = rep(pos, sum(l1)), y = value[l1], 
657 657
                 default.units = "native", gp = gp, pch = pch, size = size)
658 658
             l2 = value < boxplot_stats[1, 1]
659
-            if(sum(l2)) grid.points(x = pos, y = value[l2], 
659
+            if(sum(l2)) grid.points(x = rep(pos, sum(l2)), y = value[l2], 
660 660
                 default.units = "native", gp = gp, pch = pch, size = size) 
661 661
         }
662 662
     } else {
... ...
@@ -681,10 +681,10 @@ grid.boxplot = function(value, pos, outline = TRUE, box_width = 0.6,
681 681
                       default.units = "native", gp = gp)
682 682
         if(outline) {   
683 683
             l1 = value > boxplot_stats[5, 1]
684
-            if(sum(l1)) grid.points(y = pos, x = value[l1], 
684
+            if(sum(l1)) grid.points(y = rep(pos, sum(l1)), x = value[l1], 
685 685
                 default.units = "native", gp = gp, pch = pch, size = size)
686 686
             l2 = value < boxplot_stats[1, 1]
687
-            if(sum(l2)) grid.points(y = pos, x = value[l2], 
687
+            if(sum(l2)) grid.points(y = rep(pos, sum(l2)), x = value[l2], 
688 688
                 default.units = "native", gp = gp, pch = pch, size = size) 
689 689
         }
690 690
     }
... ...
@@ -260,3 +260,18 @@ Heatmap(m) + rowAnnotation(mark = anno)
260 260
 
261 261
 ht_list = Heatmap(m, cluster_rows = F, cluster_columns = F) + rowAnnotation(mark = anno)
262 262
 draw(ht_list, row_split = c(rep("a", 95), rep("b", 5)))
263
+
264
+
265
+### anno_block
266
+
267
+anno = anno_block(gp = gpar(fill = 1:4))
268
+draw(anno, index = 1:10, k = 1, n = 4, test = "anno_block")
269
+draw(anno, index = 1:10, k = 2, n = 4, test = "anno_block")
270
+
271
+anno = anno_block(gp = gpar(fill = 1:4), labels = letters[1:4], labels_gp = gpar(col = "white"))
272
+draw(anno, index = 1:10, k = 2, n = 4, test = "anno_block")
273
+draw(anno, index = 1:10, k = 4, n = 4, test = "anno_block")
274
+draw(anno, index = 1:10, k = 2, n = 2, test = "anno_block")
275
+
276
+anno = anno_block(gp = gpar(fill = 1:4), labels = letters[1:4], labels_gp = gpar(col = "white"), which = "row")
277
+draw(anno, index = 1:10, k = 2, n = 4, test = "anno_block")
... ...
@@ -84,3 +84,10 @@ ha = HeatmapAnnotation(summary = anno_summary(gp = gpar(fill = 2:3), height = un
84 84
 v = rnorm(50)
85 85
 Heatmap(v, top_annotation = ha, width = unit(1, "cm"), split = split)
86 86
 
87
+
88
+
89
+### auto adjust
90
+m = matrix(rnorm(100), 10)
91
+Heatmap(m, top_annotation = HeatmapAnnotation(foo = 1:10), column_dend_height = unit(4, "cm")) +
92
+Heatmap(m, top_annotation = HeatmapAnnotation(bar = anno_points(1:10)),
93
+	cluster_columns = FALSE)