Browse code

update code for adjusting oncoprint

Zuguang Gu authored on 28/10/2018 18:31:16
Showing 5 changed files

... ...
@@ -20,6 +20,7 @@
20 20
 # ht[, 1:5]
21 21
 # ht[1:5, 1:5]
22 22
 "[.Heatmap" = function(x, i, j) {
23
+    
23 24
     if(nargs() == 2) {
24 25
         subset_heatmap_by_row(x, i)
25 26
     } else {
... ...
@@ -46,6 +47,9 @@ subset_heatmap_by_row = function(ht, ind) {
46 47
         ht@row_names_param$anno = ht@row_names_param$anno[ind]
47 48
     }
48 49
     ht@row_names_param$gp = subset_gp(ht@row_names_param$gp, ind)
50
+    if(!is.null(ht@matrix_param$row_split)) {
51
+        ht@matrix_param$row_split = ht@matrix_param$row_split[ind, , drop = FALSE]
52
+    }
49 53
     if(length(ht@left_annotation)) {
50 54
         ht@left_annotation = ht@left_annotation[ind]
51 55
     }
... ...
@@ -66,6 +70,9 @@ subset_heatmap_by_column = function(ht, ind) {
66 70
         ht@column_names_param$anno = ht@column_names_param$anno[ind]
67 71
     }
68 72
     ht@column_names_param$gp = subset_gp(ht@column_names_param$gp, ind)
73
+    if(!is.null(ht@matrix_param$column_split)) {
74
+        ht@matrix_param$column_split = ht@matrix_param$column_split[, ind, drop = FALSE]
75
+    }
69 76
     if(length(ht@top_annotation)) {
70 77
         ht@top_annotation = ht@top_annotation[ind]
71 78
     }
... ...
@@ -818,7 +818,17 @@ setMethod(f = "add_heatmap",
818 818
 # ha
819 819
 c.HeatmapAnnotation = function(..., gap = unit(0, "mm")) {
820 820
 	anno_list = list(...)
821
+	if(length(anno_list) == 1) {
822
+		return(anno_list[[1]])
823
+	}
824
+	# remove NULL
825
+	anno_list = anno_list[ !sapply(anno_list, is.null) ]
826
+	if(length(anno_list) == 1) {
827
+		return(anno_list[[1]])
828
+	}
829
+
821 830
 	n = length(anno_list)
831
+
822 832
 	if(length(unique(sapply(anno_list, function(x) x@which))) != 1) {
823 833
 		stop_wrap("All annotations should be all row annotation or all column annotation.")
824 834
 	}
... ...
@@ -15,14 +15,20 @@
15 15
 # -col A vector of color for which names correspond to alteration types.
16 16
 # -top_annotation Annotation put on top of the oncoPrint. By default it is barplot which shows the number of genes with a certain alteration in each sample.
17 17
 # -right_annotation Annotation put on the right of the oncoPrint. By default it is barplot which shows the number of samples with a certain alteration in each gene.
18
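+# -left_annotation Annotation put on the left of the oncoPrint. By default it is ``NULL``; the percent values or the row names are appended to it when they are placed on the left side.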
18 19
 # -bottom_annotation Annotation put at the bottom of the oncoPrint.
19 20
 # -show_pct Whether to show percent values beside the oncoPrint (the side is set by ``pct_side``)?
20 21
 # -pct_gp Graphic parameters for percent values
21 22
 # -pct_digits Digits for the percent values.
22 23
 # -pct_side Side of the percent values to the oncoPrint, "left" or "right". It should be different from ``row_names_side`` when both percent values and row names are shown.
24
+# -row_labels Labels shown as the row names. By default they are the row names of the alteration matrix.
23 25
 # -show_row_names Whether show row names?
24 26
 # -row_names_side Side of the row names to the oncoPrint, "left" or "right". It should be different from ``pct_side`` when both percent values and row names are shown.
25 27
 # -row_names_gp Graphic parameters for the row names.
28
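+# -row_split A vector or a data frame by which the rows are split. It is subsetted together with the rows when ``remove_empty_rows`` is set.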
29
+# -column_labels Labels shown as the column names. By default they are the column names of the alteration matrix.
30
+# -column_names_gp Graphic parameters for the column names.
31
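+# -column_split A vector or a data frame by which the columns are split. It is subsetted together with the columns when ``remove_empty_columns`` is set.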
26 32
 # -row_order Order of rows. By default rows are sorted by the number of occurrence of the alterations.
27 33
 # -column_order Order of columns. By default the columns are sorted to show the mutual exclusivity of alterations.
28 34
 # -remove_empty_columns If there is no alteration in some samples, whether remove them on the oncoPrint?
... ...
@@ -50,18 +56,25 @@ oncoPrint = function(mat,
50 56
 	alter_fun_is_vectorized = NULL,
51 57
 	col, 
52 58
 
53
-	top_annotation = HeatmapAnnotation(column_barplot = anno_oncoprint_barplot()),
54
-	right_annotation = rowAnnotation(row_barplot = anno_oncoprint_barplot(
55
-			axis_param = list(side = "top", labels_rot = 0))),
59
+	top_annotation = HeatmapAnnotation(cbar = anno_oncoprint_barplot()),
60
+	right_annotation = rowAnnotation(rbar = anno_oncoprint_barplot()),
61
+	left_annotation = NULL,
56 62
 	bottom_annotation = NULL,
57 63
 
58 64
 	show_pct = TRUE, 
59 65
 	pct_gp = gpar(fontsize = 10), 
60 66
 	pct_digits = 0,
61 67
 	pct_side = "left",
68
+
69
+	row_labels = NULL,
62 70
 	show_row_names = TRUE,
63 71
 	row_names_side = "right",
64 72
 	row_names_gp = pct_gp,
73
+	row_split = NULL,
74
+
75
+	column_labels = NULL,
76
+	column_names_gp = gpar(fontsize = 10),
77
+	column_split = NULL,
65 78
 
66 79
 	row_order = NULL,
67 80
 	column_order = NULL,
... ...
@@ -72,13 +85,9 @@ oncoPrint = function(mat,
72 85
 	heatmap_legend_param = list(title = "Alterations"),
73 86
 	...) {
74 87
 
75
-	arg_list = list(...)
88
+	arg_list = as.list(match.call())[-1]
76 89
 	arg_names = names(arg_list)
77 90
 
78
-	oe = environment(anno_oncoprint_barplot)
79
-	environment(anno_oncoprint_barplot) = environment()
80
-	on.exit(environment(anno_oncoprint_barplot) <- oe)
81
-
82 91
 	# convert mat to mat_list
83 92
 	if(inherits(mat, "data.frame")) {
84 93
 		mat = as.matrix(mat)
... ...
@@ -269,15 +278,31 @@ oncoPrint = function(mat,
269 278
 		column_order = structure(seq_len(dim(arr)[2]), names = dimnames(arr)[[2]])[column_order]
270 279
 	}
271 280
 	names(column_order) = as.character(column_order)
281
+
282
+	l_non_empty_column = rowSums(apply(arr, c(2, 3), sum)) > 0
283
+	l_non_empty_row = rowSums(apply(arr, c(1, 3), sum)) > 0
284
+
285
+	if(is.null(row_labels)) row_labels = dimnames(arr)[[1]]
272 286
 	if(remove_empty_columns) {
273
-		l = rowSums(apply(arr, c(2, 3), sum)) > 0
274
-		arr = arr[, l, , drop = FALSE]
275
-		column_order = structure(seq_len(sum(l)), names = which(l))[as.character(intersect(column_order, which(l)))]
287
+		arr = arr[, l_non_empty_column, , drop = FALSE]
288
+		column_order = structure(seq_len(sum(l_non_empty_column)), names = which(l_non_empty_column))[as.character(intersect(column_order, which(l_non_empty_column)))]
289
+		if(!is.null(column_labels)) column_labels = column_labels[l_non_empty_column]
290
+		if(!is.null(column_split)) {
291
+			if(is.atomic(column_split)) column_split = data.frame(column_split)
292
+			column_split = column_split[l_non_empty_column, , drop = FALSE]
293
+		}
294
+		column_names_gp = subset_gp(column_names_gp, l_non_empty_column)
276 295
 	}
296
+	if(is.null(column_labels)) column_labels = dimnames(arr)[[2]]
277 297
 	if(remove_empty_rows) {
278
-		l = rowSums(apply(arr, c(1, 3), sum)) > 0
279
-		arr = arr[l, , , drop = FALSE]
280
-		row_order = structure(seq_len(sum(l)), names = which(l))[as.character(intersect(row_order, which(l)))]
298
+		arr = arr[l_non_empty_row, , , drop = FALSE]
299
+		row_order = structure(seq_len(sum(l_non_empty_row)), names = which(l_non_empty_row))[as.character(intersect(row_order, which(l_non_empty_row)))]
300
+		if(!is.null(row_labels)) row_labels = row_labels[l_non_empty_row]
301
+		if(!is.null(row_split)) {
302
+			if(is.atomic(row_split)) row_split = data.frame(row_split)
303
+			row_split = row_split[l_non_empty_row, , drop = FALSE]
304
+		}
305
+		row_names_gp = subset_gp(row_names_gp, l_non_empty_row)
281 306
 	}
282 307
 
283 308
 	# validate col
... ...
@@ -291,26 +316,68 @@ oncoPrint = function(mat,
291 316
 	pct = paste0(round(pct_num * 100, digits = pct_digits), "%")
292 317
 
293 318
 	### now the annotations
294
-	err = try(top_annotation <- eval(substitute(top_annotation)), silent = TRUE)
295
-	if(inherits(err, "try-error")) {
296
-		stop_wrap("find an error when executing top_annotation. ")
297
-	}
298
-	right_annotation = eval(substitute(right_annotation))
319
+	top_annotation = top_annotation
320
+	right_annotation = right_annotation
299 321
 
300
-	if("left_annotation" %in% arg_names) {
301
-		stop_wrap("'left_annotation' are not allowed to specify, you can add...")
322
+	if(show_pct && show_row_names) {
323
+		if(pct_side == row_names_side) {
324
+			stop_wrap("Percent values and row names should be at different side of the oncoPrint.")
325
+		}
302 326
 	}
303
-	left_annotation = NULL
327
+
304 328
 	if(show_pct) {
305
-		left_annotation = rowAnnotation(pct = anno_text(pct, just = "right", location = unit(1, "npc"), gp = pct_gp),
306
-			show_annotation_name = FALSE)
329
+		pct_ha = rowAnnotation(pct = anno_text(pct, just = "right", location = unit(1, "npc"), gp = pct_gp),
330
+				show_annotation_name = FALSE)
331
+		names(pct_ha) = paste0("pct_", random_str())
332
+	} else {
333
+		pct_ha = NULL
307 334
 	}
308 335
 	if(show_row_names) {
309
-		ha_row_names = rowAnnotation(rownames = anno_text(dimnames(arr)[[1]], gp = pct_gp, just = "left", location = unit(0, "npc")),
336
+		rn_ha = rowAnnotation(rownames = anno_text(row_labels, gp = pct_gp, just = "left", location = unit(0, "npc")),
310 337
 			show_annotation_name = FALSE)
311
-		right_annotation = c(ha_row_names, right_annotation, gap = unit(2, "mm"))
338
+		names(rn_ha) = paste0("rownames_", random_str())
339
+	} else {
340
+		rn_ha = NULL
341
+	}
342
+	
343
+	if(is.null(left_annotation)) {
344
+		if(pct_side == "left") {
345
+			left_annotation = pct_ha
346
+		}
347
+		if(row_names_side == "left") {
348
+			left_annotation = rn_ha
349
+		}
350
+	} else {
351
+		if(remove_empty_rows) {
352
+			left_annotation = left_annotation[l_non_empty_row, ]
353
+		}
354
+		if(pct_side == "left") {
355
+			left_annotation = c(left_annotation, pct_ha)
356
+		}
357
+		if(row_names_side == "left") {
358
+			left_annotation = c(left_annotation, rn_ha)
359
+		}
312 360
 	}
313 361
 
362
+	if(is.null(right_annotation)) {
363
+		if(pct_side == "right") {
364
+			right_annotation = pct_ha
365
+		}
366
+		if(row_names_side == "right") {
367
+			right_annotation = rn_ha
368
+		}
369
+	} else {
370
+		if(remove_empty_rows) {
371
+			right_annotation = right_annotation[l_non_empty_row, ]
372
+		}
373
+		if(pct_side == "right") {
374
+			right_annotation = c(pct_ha, right_annotation)
375
+		}
376
+		if(row_names_side == "right") {
377
+			right_annotation = c(rn_ha, right_annotation)
378
+		}
379
+	}
380
+	
314 381
 	#####################################################################
315 382
 	# the main matrix
316 383
 	pheudo = c(all_type, rep(NA, nrow(arr)*ncol(arr) - length(all_type)))
... ...
@@ -336,10 +403,14 @@ oncoPrint = function(mat,
336 403
 		heatmap_legend_param = heatmap_legend_param,
337 404
 		...
338 405
 	)
406
+	ht@heatmap_param$oncoprint_env = environment()
339 407
 
340 408
 	return(ht)
341 409
 }
342 410
 
411
+ONCOPRINT_ENV = new.env()
412
+ONCOPRINT_ENV$fun_env = NULL
413
+
343 414
 # == title
344 415
 # Unify a List of Matrix 
345 416
 #
... ...
@@ -389,7 +460,9 @@ unify_mat_list = function(mat_list, default = 0) {
389 460
 # == author
390 461
 # Zuguang Gu <z.gu@dkfz.de>
391 462
 #
392
-anno_oncoprint_barplot = function(type = all_type, which = c("column", "row"),
463
+anno_oncoprint_barplot = function(type = NULL, which = c("column", "row"),
464
+	bar_width = 0.6, axis = TRUE, 
465
+	axis_param = if(which == "column") default_axis_param("column") else list(side = "top", labels_rot = 0),
393 466
 	width = NULL, height = NULL, border = FALSE, ...) {
394 467
 
395 468
 	if(is.null(.ENV$current_annotation_which)) {
... ...
@@ -399,34 +472,74 @@ anno_oncoprint_barplot = function(type = all_type, which = c("column", "row"),
399 472
 	}
400 473
 
401 474
 	anno_size = anno_width_and_height(which, width, height, unit(2, "cm"))
402
-	# get variables fron oncoPrint() function
403
-	pf = parent.env(environment())
404
-	arr = pf$arr
405
-	all_type = pf$all_type
406
-	col = pf$col
407
-
408
-	type = type
409
-	all_type = intersect(all_type, type)
410
-	if(length(all_type) == 0) {
411
-		stop_wrap("find no overlap, check your `type` argument.")
412
-	}
413
-	arr = arr[, , all_type, drop = FALSE]
414
-	col = col[all_type]
415 475
 
416
-	if(which == "column") {
476
+	column_fun = function(index, k, n) {
477
+		pf = get("object", envir = parent.frame(7))@heatmap_param$oncoprint_env
478
+		arr = pf$arr
479
+		all_type = pf$all_type
480
+		col = pf$col
481
+
482
+		if(is.null(type)) type = all_type
483
+
484
+		all_type = intersect(all_type, type)
485
+		if(length(all_type) == 0) {
486
+			stop_wrap("find no overlap, check your `type` argument.")
487
+		}
488
+		arr = arr[, , all_type, drop = FALSE]
489
+		col = col[all_type]
490
+
417 491
 		count = apply(arr, c(2, 3), sum)
418 492
 		fun = anno_barplot(count, gp = gpar(fill = col, col = NA), which = "column",
419
-			baseline = 0, height = anno_size$height, border = border, ...)
420
-	} else {
493
+			baseline = 0, height = anno_size$height, border = border, bar_width = bar_width,
494
+			axis = axis, axis_param = axis_param)@fun
495
+		fun(index, k, n)
496
+	}
497
+	row_fun = function(index, k, n) {
498
+		pf = get("object", envir = parent.frame(7))@heatmap_param$oncoprint_env
499
+		arr = pf$arr
500
+		all_type = pf$all_type
501
+		col = pf$col
502
+
503
+		if(is.null(type)) type = all_type
504
+
505
+		all_type = intersect(all_type, type)
506
+		if(length(all_type) == 0) {
507
+			stop_wrap("find no overlap, check your `type` argument.")
508
+		}
509
+		arr = arr[, , all_type, drop = FALSE]
510
+		col = col[all_type]
511
+
421 512
 		count = apply(arr, c(1, 3), sum)
422 513
 		fun = anno_barplot(count, gp = gpar(fill = col, col = NA), which = "row",
423
-			baseline = 0, width = anno_size$width, border = border, ...)
514
+			baseline = 0, width = anno_size$width, border = border, bar_width = bar_width,
515
+			axis = axis, axis_param = axis_param)@fun
516
+		fun(index, k, n)
424 517
 	}
425 518
 	
426
-	fun@show_name = FALSE
427
-	return(fun)
428
-}
519
+	if(which == "row") {
520
+		fun = row_fun
521
+	} else if(which == "column") {
522
+		fun = column_fun
523
+	}
524
+
525
+	anno = AnnotationFunction(
526
+		fun = fun,
527
+		fun_name = "anno_oncoprint_barplot",
528
+		which = which,
529
+		width = anno_size$width,
530
+		height = anno_size$height,
531
+		var_import = list(border, type, bar_width, axis, axis_param, anno_size)
532
+	)
533
+		
534
+	anno@subsetable = TRUE
535
+	anno@show_name = FALSE
536
+
537
+	axis_param = validate_axis_param(axis_param, which)
538
+	axis_grob = if(axis) construct_axis_grob(axis_param, which, c(0, 100)) else NULL
539
+	anno@extended = update_anno_extend(anno, axis_grob, axis_param)
429 540
 
541
+	return(anno) 
542
+}
430 543
 
431 544
 guess_alter_fun_is_vectorized = function(alter_fun) {
432 545
 	n = 50
... ...
@@ -773,4 +773,7 @@ grid.boxplot = function(value, pos, outline = TRUE, box_width = 0.6,
773 773
     }
774 774
 }
775 775
 
776
+random_str = function() {
777
+    paste(sample(c(letters, LETTERS, 0:9), 8), collapse = "")
778
+}
776 779
 
777 780
new file mode 100644
... ...
@@ -0,0 +1,55 @@
1
+mat = read.table(textConnection(
2
+"s1,s2,s3
3
+g1,snv;indel,snv,indel
4
+g2,,snv;indel,snv
5
+g3,snv,,indel;snv"), row.names = 1, header = TRUE, sep = ",", stringsAsFactors = FALSE)
6
+mat = as.matrix(mat)
7
+
8
+get_type_fun = function(x) strsplit(x, ";")[[1]]
9
+
10
+alter_fun = list(
11
+    snv = function(x, y, w, h) grid.rect(x, y, w*0.9, h*0.9, 
12
+        gp = gpar(fill = col["snv"], col = NA)),
13
+    indel = function(x, y, w, h) grid.rect(x, y, w*0.9, h*0.4, 
14
+        gp = gpar(fill = col["indel"], col = NA))
15
+)
16
+
17
+col = c(snv = "red", indel = "blue")
18
+oncoPrint(mat, get_type = get_type_fun,
19
+    alter_fun = alter_fun, col = col)
20
+
21
+## turn off row names while turning on column names
22
+oncoPrint(mat, get_type = get_type_fun,
23
+    alter_fun = alter_fun, col = col, 
24
+    show_column_names = TRUE, show_row_names = FALSE)
25
+
26
+oncoPrint(mat, get_type = get_type_fun,
27
+    alter_fun = alter_fun, col = col, pct_side = "right", 
28
+    row_names_side = "left")
29
+
30
+oncoPrint(mat, get_type = get_type_fun,
31
+    alter_fun = alter_fun, col = col,
32
+    top_annotation = HeatmapAnnotation(column_barplot = anno_oncoprint_barplot())
33
+)
34
+
35
+oncoPrint(mat, get_type = get_type_fun,
36
+    alter_fun = alter_fun, col = col,
37
+    top_annotation = HeatmapAnnotation(
38
+    	column_barplot = anno_oncoprint_barplot(),
39
+    	foo = 1:3,
40
+    	annotation_name_side = "left")
41
+)
42
+
43
+oncoPrint(mat, get_type = get_type_fun,
44
+    alter_fun = alter_fun, col = col,
45
+    top_annotation = HeatmapAnnotation(
46
+    	cbar = anno_oncoprint_barplot(),
47
+    	foo1 = 1:3,
48
+    	annotation_name_side = "left"),
49
+    left_annotation = rowAnnotation(foo2 = 1:3),
50
+    right_annotation = rowAnnotation(cbar = anno_oncoprint_barplot(), foo3 = 1:3),
51
+)
52
+
53
+
54
+
55
+