Browse code

small bug fix

jokergoo authored on 02/10/2018 14:36:00
Showing11 changed files

... ...
@@ -25,7 +25,8 @@ AnnotationFunction = setClass("AnnotationFunction",
25 25
 		subset_rule = "list",
26 26
 		subsetable = "logical",
27 27
 		data_scale = "numeric",
28
-		extended = "ANY"
28
+		extended = "ANY",
29
+		show_name = "logical"
29 30
 	),
30 31
 	prototype = list(
31 32
 		fun_name = "",
... ...
@@ -35,7 +36,8 @@ AnnotationFunction = setClass("AnnotationFunction",
35 36
 		subsetable = FALSE,
36 37
 		data_scale = c(0, 1),
37 38
 		n = 0,
38
-		extended = unit(c(0, 0, 0, 0), "mm")
39
+		extended = unit(c(0, 0, 0, 0), "mm"),
40
+		show_name = TRUE
39 41
 	)
40 42
 )
41 43
 
... ...
@@ -173,7 +175,7 @@ anno_width_and_height = function(which, width = NULL, height = NULL,
173 175
 # are all subsettable.
174 176
 AnnotationFunction = function(fun, fun_name = "", which = c("column", "row"), 
175 177
 	var_import = list(), n = 0, data_scale = c(0, 1), subset_rule = list(), 
176
-	subsetable = FALSE, width = NULL, height = NULL) {
178
+	subsetable = FALSE, show_name = TRUE, width = NULL, height = NULL) {
177 179
 
178 180
 	which = match.arg(which)[1]
179 181
 
... ...
@@ -190,6 +192,8 @@ AnnotationFunction = function(fun, fun_name = "", which = c("column", "row"),
190 192
 	anno@width = anno_size$width
191 193
 	anno@height = anno_size$height
192 194
 
195
+	anno@show_name = show_name
196
+
193 197
 	anno@n = n
194 198
 	anno@data_scale = data_scale
195 199
 
... ...
@@ -61,7 +61,8 @@ anno_empty = function(which = c("column", "row"), border = TRUE, width = NULL, h
61 61
 		subset_rule = list(),
62 62
 		subsetable = TRUE,
63 63
 		height = anno_size$height,
64
-		width = anno_size$width
64
+		width = anno_size$width,
65
+		show_name = FALSE
65 66
 	)
66 67
 	return(anno) 
67 68
 }
... ...
@@ -762,7 +763,7 @@ update_anno_extend = function(anno, axis_grob, axis_param) {
762 763
 # 	add_points = TRUE, pt_gp = gpar(col = 5:6), pch = c(1, 16))
763 764
 # draw(anno, test = "matrix")
764 765
 anno_lines = function(x, which = c("column", "row"), border = TRUE, gp = gpar(), 
765
-	add_points = FALSE, pch = 16, size = unit(2, "mm"), pt_gp = gpar(), ylim = NULL, 
766
+	add_points = FALSE, smooth = FALSE, pch = 16, size = unit(2, "mm"), pt_gp = gpar(), ylim = NULL, 
766 767
 	extend = 0.05, axis = TRUE, axis_param = default_axis_param(which),
767 768
 	width = NULL, height = NULL) {
768 769
 
... ...
@@ -822,18 +823,34 @@ anno_lines = function(x, which = c("column", "row"), border = TRUE, gp = gpar(),
822 823
 		pushViewport(viewport(xscale = data_scale, yscale = c(0.5, n+0.5)))
823 824
 		if(is.matrix(value)) {
824 825
 			for(i in seq_len(ncol(value))) {
825
-				grid.lines(value[index, i], n - seq_along(index) + 1, gp = subset_gp(gp, i), 
826
-					default.units = "native")
826
+				x = n - seq_along(index) + 1
827
+				y = value[index, i]
828
+				if(smooth) {
829
+					fit = loess(y ~ x)
830
+					x2 = seq(x[1], x[length(x)], length = 100)
831
+					y2 = predict(fit, x2)
832
+					grid.lines(y2, x2, gp = subset_gp(gp, i), default.units = "native")
833
+				} else {
834
+					grid.lines(y, x, gp = subset_gp(gp, i), default.units = "native")
835
+				}
827 836
 				if(add_points) {
828
-					grid.points(value[index, i], n - seq_along(index) + 1, gp = subset_gp(pt_gp, i), 
837
+					grid.points(y, x, gp = subset_gp(pt_gp, i), 
829 838
 						default.units = "native", pch = pch[i], size = size[i])
830 839
 				}
831 840
 			}
832 841
 		} else {
833
-			grid.lines(value[index, i], n - seq_along(index) + 1, gp = gp, 
834
-				default.units = "native")
842
+			x = n - seq_along(index) + 1
843
+			y = value[index]
844
+			if(smooth) {
845
+				fit = loess(y ~ x)
846
+				x2 = seq(x[1], x[length(x)], length = 100)
847
+				y2 = predict(fit, x2)
848
+				grid.lines(y2, x2, gp = gp, default.units = "native")
849
+			} else {
850
+				grid.lines(y, x, gp = gp, default.units = "native")
851
+			}
835 852
 			if(add_points) {
836
-				grid.points(value[index], n - seq_along(index) + 1, gp = gp, default.units = "native", 
853
+				grid.points(y, x, gp = gp, default.units = "native", 
837 854
 					pch = pch[index], size = size[index])
838 855
 			}
839 856
 		}
... ...
@@ -853,15 +870,32 @@ anno_lines = function(x, which = c("column", "row"), border = TRUE, gp = gpar(),
853 870
 		pushViewport(viewport(yscale = data_scale, xscale = c(0.5, n+0.5)))
854 871
 		if(is.matrix(value)) {
855 872
 			for(i in seq_len(ncol(value))) {
856
-				grid.lines(seq_along(index), value[index, i], gp = subset_gp(gp, i), 
857
-					default.units = "native")
873
+				x = seq_along(index)
874
+				y = value[index, i]
875
+				if(smooth) {
876
+					fit = loess(y ~ x)
877
+					x2 = seq(x[1], x[length(x)], length = 100)
878
+					y2 = predict(fit, x2)
879
+					grid.lines(x2, y2, gp = subset_gp(gp, i), default.units = "native")
880
+				} else {
881
+					grid.lines(x, y, gp = subset_gp(gp, i), default.units = "native")
882
+				}
858 883
 				if(add_points) {
859
-					grid.points(seq_along(index), value[index, i], gp = subset_gp(pt_gp, i), 
884
+					grid.points(x, y, gp = subset_gp(pt_gp, i), 
860 885
 						default.units = "native", pch = pch[i], size = size[i])
861 886
 				}
862 887
 			}
863 888
 		} else {
864
-			grid.lines(seq_along(index), value[index], gp = gp, default.units = "native")
889
+			x = seq_along(index)
890
+			y = value[index]
891
+			if(smooth) {
892
+				fit = loess(y ~ x)
893
+				x2 = seq(x[1], x[length(x)], length = 100)
894
+				y2 = predict(fit, x2)
895
+				grid.lines(x2, y2, gp = gp, default.units = "native")
896
+			} else {
897
+				grid.lines(x, y, gp = gp, default.units = "native")
898
+			}
865 899
 			if(add_points) {
866 900
 				grid.points(seq_along(index), value[index], gp = pt_gp, default.units = "native", 
867 901
 					pch = pch[index], size = size[index])
... ...
@@ -891,7 +925,8 @@ anno_lines = function(x, which = c("column", "row"), border = TRUE, gp = gpar(),
891 925
 		height = anno_size$height,
892 926
 		n = n,
893 927
 		data_scale = data_scale,
894
-		var_import = list(value, gp, border, pch, size, pt_gp, axis, axis_param, axis_grob, data_scale, add_points)
928
+		var_import = list(value, gp, border, pch, size, pt_gp, axis, axis_param, 
929
+			axis_grob, data_scale, add_points, smooth)
895 930
 	)
896 931
 
897 932
 	anno@subset_rule$gp = subset_vector
... ...
@@ -968,15 +1003,17 @@ anno_barplot = function(x, baseline = 0, which = c("column", "row"), border = TR
968 1003
 	if(!is.null(ylim)) data_scale = ylim
969 1004
 	if(baseline == "min") {
970 1005
 		data_scale = data_scale + c(0, extend)*(data_scale[2] - data_scale[1])
1006
+		baseline = min(x)
971 1007
 	} else if(baseline == "max") {
972 1008
 		data_scale = data_scale + c(-extend, 0)*(data_scale[2] - data_scale[1])
1009
+		baseline = max(x)
973 1010
 	} else {
974 1011
 		if(is.numeric(baseline)) {
975 1012
 			if(baseline == 0 && all(rowSums(x) == 1)) {
976 1013
 				data_scale = c(0, 1)
977 1014
 			} else if(baseline <= min(x)) {
978 1015
 				data_scale = c(baseline, extend*(data_scale[2] - baseline) + data_scale[2])
979
-			} else if(baseline >= rowSums(x)) {
1016
+			} else if(baseline >= max(x)) {
980 1017
 				data_scale = c(-extend*(baseline - data_scale[1]) + data_scale[1], baseline)
981 1018
 			} else {
982 1019
 				data_scale = data_scale + c(-extend, extend)*(data_scale[2] - data_scale[1])
... ...
@@ -1824,7 +1861,8 @@ anno_text = function(x, which = c("column", "row"), gp = gpar(),
1824 1861
 		width = width,
1825 1862
 		height = height,
1826 1863
 		n = n,
1827
-		var_import = list(value, gp, just, rot, location)
1864
+		var_import = list(value, gp, just, rot, location),
1865
+		show_name = FALSE
1828 1866
 	)
1829 1867
 
1830 1868
 	anno@subset_rule$value = subset_vector
... ...
@@ -2465,6 +2503,8 @@ anno_mark = function(at, labels, which = c("column", "row"),
2465 2503
 		width = unit(1, "npc")
2466 2504
 	}
2467 2505
 
2506
+	# a map between row index and positions
2507
+	# pos_map = 
2468 2508
 	row_fun = function(index) {
2469 2509
 		n = length(index)
2470 2510
 		# adjust at and labels
... ...
@@ -2558,7 +2598,9 @@ anno_mark = function(at, labels, which = c("column", "row"),
2558 2598
 		width = width,
2559 2599
 		height = height,
2560 2600
 		n = -1,
2561
-		var_import = list(at, labels2index, at2labels, link_gp, labels_gp, padding, side, link_width, extend)
2601
+		var_import = list(at, labels2index, at2labels, link_gp, labels_gp, padding, 
2602
+			side, link_width, extend),
2603
+		show_name = FALSE
2562 2604
 	)
2563 2605
 
2564 2606
 	anno@subset_rule$at = subset_by_intersect
... ...
@@ -767,7 +767,8 @@ Heatmap = function(matrix, col, name,
767 767
         ),
768 768
 
769 769
         layout_index = NULL,
770
-        graphic_fun_list = list()
770
+        graphic_fun_list = list(),
771
+        initialized = FALSE
771 772
     )
772 773
 
773 774
     .Object@heatmap_param$width = width
... ...
@@ -1397,6 +1398,10 @@ setMethod(f = "prepare",
1397 1398
     signature = "Heatmap",
1398 1399
     definition = function(object, process_rows = TRUE, process_columns = TRUE) {
1399 1400
 
1401
+    if(object@layout$initialized) {
1402
+        return(object)
1403
+    }
1404
+    
1400 1405
     if(process_rows) {
1401 1406
         object = make_row_cluster(object)
1402 1407
     }
... ...
@@ -24,6 +24,10 @@ setMethod(f = "make_layout",
24 24
     signature = "Heatmap",
25 25
     definition = function(object) {
26 26
 
27
+    if(object@layout$initialized) {
28
+        return(object)
29
+    }
30
+
27 31
     # position of each row-slice
28 32
     row_gap = object@matrix_param$row_gap
29 33
     column_gap = object@matrix_param$column_gap
... ...
@@ -432,6 +436,8 @@ setMethod(f = "make_layout",
432 436
     object@heatmap_param$width_is_absolute_unit = is_abs_unit(object@heatmap_param$width) 
433 437
     object@heatmap_param$height_is_absolute_unit = is_abs_unit(object@heatmap_param$height) 
434 438
     
439
+    object@layout$initialized = TRUE
440
+
435 441
     return(object)
436 442
 })
437 443
 
... ...
@@ -133,6 +133,9 @@ HeatmapAnnotation = function(...,
133 133
     called_args = names(arg_list)
134 134
     anno_args = setdiff(called_args, fun_args)
135 135
     if(any(anno_args == "")) stop("annotations should have names.")
136
+    if(is.null(called_args)) {
137
+    	stop_wrap("It seems you are putting only one argument to the function. If it is a simple vector annotation, specify it as HeatmapAnnotation(name = value). If it is a data frame annotation, specify it as HeatmapAnnotation(df = value)")
138
+    }
136 139
 
137 140
     ##### pull all annotation to `anno_value_list`####
138 141
     if("df" %in% called_args) {
... ...
@@ -563,8 +566,8 @@ setMethod(f = "draw",
563 566
 
564 567
     if(test2) {
565 568
     	grid.newpage()
566
-    	if(which == "column") pushViewport(viewport(width = unit(1, "npc") - unit(4, "cm"), height = object@height))
567
-    	if(which == "row") pushViewport(viewport(height = unit(1, "npc") - unit(4, "cm"), width = object@width))
569
+    	if(which == "column") pushViewport(viewport(width = unit(1, "npc") - unit(2, "cm"), height = object@height))
570
+    	if(which == "row") pushViewport(viewport(height = unit(1, "npc") - unit(2, "cm"), width = object@width))
568 571
     } else {
569 572
 		pushViewport(viewport(...))
570 573
 	}
... ...
@@ -579,7 +582,12 @@ setMethod(f = "draw",
579 582
 			}
580 583
 		})
581 584
 		len = len[!is.na(len)]
582
-		if(length(len)) index = seq_len(len[1])
585
+		if(length(len)) {
586
+			index = seq_len(len[1])
587
+		} 
588
+		if(!length(index)) {
589
+			stop("Cannot infer the number of observations of the annotation.")
590
+		}
583 591
     }
584 592
 
585 593
 	if(which == "column") {
... ...
@@ -607,10 +615,10 @@ setMethod(f = "draw",
607 615
 	}
608 616
 	if(test2) {
609 617
         grid.text(test, y = unit(1, "npc") + unit(2, "mm"), just = "bottom")
610
-        grid.rect(unit(0, "npc") - object@extended[2], unit(0, "npc") - object@extended[1], 
611
-            width = unit(1, "npc") + object@extended[2] + object@extended[4],
612
-            height = unit(1, "npc") + object@extended[1] + object@extended[3],
613
-            just = c("left", "bottom"), gp = gpar(fill = "transparent", col = "red", lty = 2))
618
+        # grid.rect(unit(0, "npc") - object@extended[2], unit(0, "npc") - object@extended[1], 
619
+        #     width = unit(1, "npc") + object@extended[2] + object@extended[4],
620
+        #     height = unit(1, "npc") + object@extended[1] + object@extended[3],
621
+        #     just = c("left", "bottom"), gp = gpar(fill = "transparent", col = "red", lty = 2))
614 622
     }
615 623
 	upViewport()
616 624
 })
... ...
@@ -523,6 +523,16 @@ setMethod(f = "draw_heatmap_list",
523 523
             } else if(inherits(ht, "HeatmapAnnotation")) {
524 524
                 # calcualte the position of the heatmap body
525 525
                 pushViewport(viewport(y = max_bottom_component_height, height = unit(1, "npc") - max_top_component_height - max_bottom_component_height, just = c("bottom")))
526
+                # if(length(ht) == 1 & n_slice > 1) {
527
+                #     if(inherits(ht@anno_list[[1]], "AnnotationFunction")) {
528
+                #         if(identical(ht@anno_list[[1]]@fun@fun_name, "anno_mark")) {
529
+                #             # adjust pos_map var
530
+                #             fun = ht@anno_list[[1]]@fun
531
+                #             draw(fun, index = ht_main@row_order)
532
+                #             next
533
+                #         }
534
+                #     }
535
+                # }
526 536
                 for(j in seq_len(n_slice)) {
527 537
                     draw(ht, index = ht_main@row_order_list[[j]], y = slice_y[j], height = slice_height[j], just = slice_just[2], k = j, n = n_slice)
528 538
                 }
... ...
@@ -229,6 +229,8 @@ SingleAnnotation = function(name, value, col, fun,
229 229
         if(inherits(fun, "AnnotationFunction")) {
230 230
             anno_fun_extend = fun@extended
231 231
             if(verbose) qqcat("@{name}: annotation is a AnnotationFunction object\n")
232
+
233
+            show_name = fun@show_name
232 234
         } else {
233 235
             fun = AnnotationFunction(fun = fun)
234 236
             anno_fun_extend = fun@extended
... ...
@@ -752,7 +752,7 @@ horizontal_continuous_legend_body = function(at, labels = at, col_fun,
752 752
 # draw(pd, test = "two legends")
753 753
 # pd = packLegend(lgd1, lgd2, direction = "horizontal")
754 754
 # draw(pd, test = "two legends packed horizontally")
755
-packLegend = function(..., row_gap = unit(2, "mm"), column_gap = unit(2, "mm"),
755
+packLegend = function(...,gap = unit(2, "mm"), row_gap = unit(2, "mm"), column_gap = unit(2, "mm"),
756 756
 	direction = c("vertical", "horizontal"),
757 757
 	max_width = NULL, max_height = NULL, list = NULL) {
758 758
 
... ...
@@ -771,6 +771,16 @@ packLegend = function(..., row_gap = unit(2, "mm"), column_gap = unit(2, "mm"),
771 771
 		lgd
772 772
 	})
773 773
 	direction = match.arg(direction)
774
+	if(direction == "vertical") {
775
+		if(missing(row_gap)) {
776
+			row_gap = gap
777
+		}
778
+	}
779
+	if(direction == "horizontal") {
780
+		if(missing(column_gap)) {
781
+			column_gap = gap
782
+		}
783
+	}
774 784
 	if(length(row_gap) != 1) {
775 785
 		stop("Length of `row_gap` must be one.")
776 786
 	}
... ...
@@ -279,7 +279,12 @@ rep.list = function(x, n) {
279 279
 
280 280
 
281 281
 list_component = function() {
282
-    vp_name = grid.ls(viewports = TRUE, grobs = FALSE, print = FALSE)$name
282
+    vp = grid.ls(viewports = TRUE, grob = FALSE, flatten = FALSE, print = FALSE)
283
+    vp = unlist(vp)
284
+    attributes(vp) = NULL
285
+    vp = vp[!grepl("^\\d+$", vp)]
286
+    vp = vp[!grepl("GRID.VP", vp)]
287
+    unique(vp)
283 288
 }
284 289
 
285 290
 # == title
... ...
@@ -361,13 +366,13 @@ dev.null = function(...) {
361 366
 stop_wrap = function (...) {
362 367
     x = paste0(...)
363 368
     x = paste(strwrap(x), collapse = "\n")
364
-    stop(x)
369
+    stop(x, call. = FALSE)
365 370
 }
366 371
 
367 372
 warning_wrap = function (...) {
368 373
     x = paste0(...)
369 374
     x = paste(strwrap(x), collapse = "\n")
370
-    warning(x)
375
+    warning(x, call. = FALSE)
371 376
 }
372 377
 
373 378
 message_wrap = function (...) {
... ...
@@ -123,6 +123,18 @@ draw(anno, test = "matrix")
123 123
 anno = anno_lines(cbind(c(1:5, 1:5), c(5:1, 5:1)), gp = gpar(col = 2:3),
124 124
 	add_points = TRUE, pt_gp = gpar(col = 5:6), pch = c(1, 16))
125 125
 draw(anno, test = "matrix")
126
+anno = anno_lines(sort(rnorm(10)), height = unit(2, "cm"), smooth = TRUE, add_points = TRUE)
127
+draw(anno, test = "anno_lines, smooth")
128
+anno = anno_lines(cbind(sort(rnorm(10)), sort(rnorm(10), decreasing = TRUE)), 
129
+	height = unit(2, "cm"), smooth = TRUE, add_points = TRUE, gp = gpar(col = 2:3))
130
+draw(anno, test = "anno_lines, smooth, matrix")
131
+
132
+anno = anno_lines(sort(rnorm(10)), width = unit(2, "cm"), smooth = TRUE, add_points = TRUE, which = "row")
133
+draw(anno, test = "anno_lines, smooth, by row")
134
+anno = anno_lines(cbind(sort(rnorm(10)), sort(rnorm(10), decreasing = TRUE)), 
135
+	width = unit(2, "cm"), smooth = TRUE, add_points = TRUE, gp = gpar(col = 2:3), which = "row")
136
+draw(anno, test = "anno_lines, smooth, matrix, by row")
137
+
126 138
 
127 139
 
128 140
 ###### test anno_text #######
... ...
@@ -246,4 +258,5 @@ anno = anno_mark(at = c(1:4, 20, 60, 97:100), labels = month.name[1:10], which =
246 258
 Heatmap(m, cluster_rows = F, cluster_columns = F) + rowAnnotation(mark = anno)
247 259
 Heatmap(m) + rowAnnotation(mark = anno)
248 260
 
249
-
261
+ht_list = Heatmap(m, cluster_rows = F, cluster_columns = F) + rowAnnotation(mark = anno)
262
+draw(ht_list, row_split = c(rep("a", 95), rep("b", 5)))
... ...
@@ -64,3 +64,8 @@ draw(ha, test = "complex annotations on row")
64 64
 
65 65
 ## test row annotation with no heatmap
66 66
 rowAnnotation(foo = 1:10, bar = anno_points(10:1))
67
+
68
+
69
+HeatmapAnnotation(1:10)
70
+
71
+HeatmapAnnotation(data.frame(1:10))