R/SpatialCPie.R
4f5c32ee
 #' @importFrom data.table as.data.table
e61be5d9
 #' @importFrom digest sha1
148ebd95
 #' @importFrom dplyr
7aacac3f
 #' arrange filter first group_by inner_join mutate n rename select summarize
 #' ungroup
148ebd95
 #' @importFrom ggiraph geom_point_interactive
 #' @importFrom ggplot2
358ab48f
 #' aes_ aes_string coord_fixed element_blank geom_segment ggplot ggtitle guides
 #' guide_legend
148ebd95
 #' labs
4293f8a1
 #' theme theme_bw theme_minimal
148ebd95
 #' scale_color_manual scale_fill_manual scale_size
 #' scale_x_continuous scale_y_continuous
 #' @importFrom grid unit
375d964b
 #' @importFrom methods is
148ebd95
 #' @importFrom purrr
b43b0a4d
 #' %>% %||% accumulate array_branch lift invoke keep map map_dbl map_int partial
 #' reduce transpose
148ebd95
 #' @importFrom readr read_file write_file
358ab48f
 #' @importFrom rlang !! := .data sym
148ebd95
 #' @importFrom shiny debounce observeEvent reactive
4293f8a1
 #' @importFrom shinyjs hideElement
d8d552aa
 #' @importFrom shinyWidgets radioGroupButtons materialSwitch
87eb5c45
 #' @importFrom stats dist kmeans setNames sd
d438c8dc
 #' @importFrom SummarizedExperiment assay
37dd1b03
 #' @importFrom tibble column_to_rownames rownames_to_column
 #' @importFrom tidyr gather separate spread unite
148ebd95
 #' @importFrom tidyselect everything quo
7aacac3f
 #' @importFrom tools toTitleCase
148ebd95
 #' @importFrom utils head tail
 #' @importFrom utils str
 #' @importFrom zeallot %<-%
 "_PACKAGE"
 
 
 ## The names below are only ever referenced through non-standard evaluation
 ## (pipes, data.table expressions, destructuring assignment). Registering them
 ## as global variables keeps R CMD check from flagging them as undefined
 ## (ref: https://stackoverflow.com/a/12429344)
 globalVariables(c(
     ".", "cluster", "count", "name", "otherMargin",
     "resolution", "spot", "xcoord", "ycoord"
 ))
 
 
63973e57
 #' Logsumexp
 #'
 #' Computes `log(sum(exp(xs)))` by factoring the largest element out of the
 #' sum before exponentiating, so that large magnitudes do not overflow.
 #'
 #' Adapted from https://stat.ethz.ch/pipermail/r-help/2011-February/269205.html
 #' @param xs input vector
 #' @return log of summed exponentials. If the largest element of `xs` is
 #' infinite, that element is returned as is.
 #' @keywords internal
 .logsumexp <- function(xs) {
     idx <- which.max(xs)
     top <- xs[idx]
     ## Shifting by an infinite maximum would evaluate `Inf - Inf` (or
     ## `-Inf - -Inf`) and turn the result into NaN. In that case the sum is
     ## fully determined by the maximum, so return it directly.
     if (length(top) == 1L && is.infinite(top)) {
         return(top)
     }
     log1p(sum(exp(xs[-idx] - top))) + top
 }
 
 
6b8dc425
 #' Likeness score
 #'
 #' Applies a softmax to the negated, scaled distances: each distance `d[i]`
 #' is mapped to `exp(-c * d[i])` divided by the sum of that quantity over all
 #' elements of `d`.
 #' @param d distance vector.
 #' @param c log multiplier.
 #' @return vector of scores.
 #' @keywords internal
 .likeness <- function(
     d,
     c = 1.0
 ) {
     logits <- -c * d
     logNormalizer <- .logsumexp(logits)
     exp(logits - logNormalizer)
 }
 
87eb5c45
 #' Z-score
 #'
 #' @param xs vector of observations
 #' @return `xs`, z-normalized. if all elements of `xs` are equal, a vector of
 #'     zeros will be returned instead. if the standard deviation is undefined
 #'     (e.g., `xs` has a single element), the centered values are returned
 #'     without rescaling.
 #' @keywords internal
 .zscore <- function(xs) {
     std <- sd(xs)
     ## `sd` is NA for vectors with fewer than two elements (and for vectors
     ## containing NA); `NA == 0.0` is not a valid `if` condition, so check for
     ## NA first and fall back to unit scale
     std <- if (is.na(std) || std == 0.0) 1 else std
     (xs - mean(xs)) / std
 }
 
148ebd95
 
a72b7c87
 #' Maximize overlap
 #'
8059fda2
 #' @param xss list of lists of labels.
b43b0a4d
 #' @return `xss`, relabeled so as to maximize the overlap between labels in
8059fda2
 #' consecutive label lists.
a72b7c87
 #' @keywords internal
148ebd95
 ## Relabels every label vector in `xss` so that consecutive vectors share as
 ## many labels as possible. Works in three steps: (1) compute, for each pair of
 ## consecutive vectors, a map from labels of one to labels of the other by
 ## solving an assignment problem on their overlap counts; (2) chain those maps
 ## from the first vector onward; (3) apply the chained maps to the labels.
 .maximizeOverlap <- function(
b43b0a4d
     xss
148ebd95
 ) {
b43b0a4d
     ## Given two label vectors, returns a named list produced from the
     ## solution of the assignment problem on their overlap matrix.
     maximumOverlap <- function(xs, ys) {
         ## Build the overlap matrix: one row per unique label in `xs`, one
         ## column per unique label in `ys`; each entry counts the positions at
         ## which both labels occur together
         setNames(nm = sort(unique(xs))) %>%
             map(function(x)
                 setNames(nm = sort(unique(ys))) %>%
                     map_dbl(function(y) sum(`*`(xs == x, ys == y)))
             ) %>%
             invoke(rbind, .) %>%
             (function(overlaps) {
                 all <- union(rownames(overlaps), colnames(overlaps))
                 n <- length(all)
 
                 ## Zero-pad overlap matrix so that all labels are represented in
                 ## both the to and from dimensions
                 ## NOTE(review): the padding rows/columns all have length `n`,
                 ## which only fits if at least one dimension of `overlaps`
                 ## already has size `n` -- presumably guaranteed by the
                 ## conversion to natural numbers below; confirm
4c653096
                 paddedOverlaps <-
                     overlaps %>%
584808bf
                     rbind(do.call(
                         rbind,
                         rep(list(rep(0, n)), n - nrow(overlaps))
                     )) %>%
                     cbind(do.call(
                         cbind,
                         rep(list(rep(0, n)), n - ncol(overlaps))
                     ))
b43b0a4d
                 ## The padded rows/columns come out with empty names; name them
                 ## after the labels missing from the respective dimension
                 rownames(paddedOverlaps)[rownames(paddedOverlaps) == ""] <-
                     setdiff(all, rownames(paddedOverlaps))
                 colnames(paddedOverlaps)[colnames(paddedOverlaps) == ""] <-
                     setdiff(all, colnames(paddedOverlaps))
 
                 ## Solve the assignment problem to maximize the overlap
                 ## (the matrix is negated before being passed to the solver)
                 lpSolve::lp.assign(-paddedOverlaps)$solution %>%
                     array_branch(2) %>%
                     map(~colnames(paddedOverlaps)[which.max(.)]) %>%
                     setNames(nm = rownames(paddedOverlaps))
             })
     }
 
83ffb18b
     ## Convert cluster labels to natural numbers
     ## (stored as character strings; the names of each vector are kept)
     xss <- map(
         xss,
         function(x) setNames(as.character(as.integer(as.factor(x))), names(x))
     )
 
b43b0a4d
     ## Compute reassignment map between each label pair
     ## (pairs element i of `xss` with element i + 1)
     reassignments <-
         list(unname(head(xss, -1)), unname(tail(xss, -1))) %>%
a72b7c87
         transpose %>%
b43b0a4d
         map(lift(maximumOverlap))
 
     ## Sync reassignments by propagating them forward
     ## Each step rewrites the values of the current map through the previous
     ## map where a matching name exists. The seed is the identity map over the
     ## labels of the first vector.
     ## NOTE(review): every step result is wrapped in a one-element list and
     ## read back via `prev[[1]]`, while the code below indexes the accumulated
     ## elements directly by label -- this presumably relies on `accumulate`
     ## unwrapping one list level in its output; confirm against the purrr
     ## version in use
1e933e9b
     reassignments <- accumulate(
         reassignments,
         function(prev, cur) {
             list(lapply(cur, function(x) {
                 if (x %in% names(prev[[1]])) prev[[1]][[x]]
                 else x
             }))
         },
         .init = list(setNames(nm = unique(xss[[1]])))
     )
a72b7c87
 
b43b0a4d
     ## Apply reassignments
     ## (element-wise lookup of every label in the map of its own vector; the
     ## result carries the names of `xss`)
1e933e9b
     list(xss, reassignments) %>%
         transpose() %>%
         map(lift(function(xs, reassignment) {
             vapply(xs, function(x) reassignment[[x]], character(1))
         })) %>%
b43b0a4d
         setNames(names(xss))
6b8dc425
 }
 
148ebd95
 
9e90795e
 #' Tidy assignments
 #'
 #' Prepends a "root" resolution, relabels the clusters with
 #' \code{.maximizeOverlap}, and stacks all resolutions into one long table.
 #' @param assignments list of assignment vectors.
 #' @return a \code{\link[base]{data.frame}} with the columns `unit`, `name`,
 #' `resolution`, and `cluster`, relabeled so that the overlap between
 #' consecutive assignment vectors is maximized.
 #' @keywords internal
 .tidyAssignments <- function(
     assignments
 ) {
     if (length(assignments) == 0) {
         stop("Need at least one resolution")
     }
 
     ## Prepend a "root" resolution in which every unit carries the label 1
     firstAssignment <- assignments[[1]]
     rootAssignment <- setNames(
         rep(1, length(firstAssignment)),
         nm = names(firstAssignment)
     )
     assignments <- c(list("root" = rootAssignment), assignments)
 
     ## Make labels comparable across consecutive resolutions
     message("Maximizing label overlap in consecutive resolutions")
     relabeled <- .maximizeOverlap(assignments)
 
     ## Builds the table for a single resolution; the row names of the frame
     ## (presumably the unit names carried by `xs`) are moved to a `unit` column
     toFrame <- function(res, xs) {
         frame <- data.frame(
             name = sprintf("resolution %s, cluster %s", res, xs),
             resolution = res,
             cluster = xs,
             stringsAsFactors = TRUE
         )
         tibble::rownames_to_column(frame, "unit")
     }
 
     ## Stack the per-resolution tables
     list(names(relabeled), relabeled) %>%
         transpose() %>%
         map(lift(toFrame)) %>%
         invoke(rbind, .)
 }
 
 
 #' Compute cluster colors
 #'
 #' Computes colors so that dissimilar clusters are far away in color space.
 #' The clusters are projected onto their first two principal components,
 #' which are rescaled to the range [-100, 100] and used as the second and
 #' third LAB coordinates; the first LAB coordinate is fixed at 50.
 #' @param clusterMeans matrix of size `(n, K)` representing the `n` feature
 #' means for each of the `K` clusters.
 #' @return vector of cluster colors.
 #' @keywords internal
 .computeClusterColors <- function(
     clusterMeans
 ) {
     ## `rank.` is spelled out in full; `rank = 2` only worked through partial
     ## argument matching
     clusterLoadings <- stats::prcomp(
         t(clusterMeans),
         rank. = 2,
         center = TRUE
     )$x
     minLoading <- apply(clusterLoadings, 2, min)
     maxLoading <- apply(clusterLoadings, 2, max)
 
     ## Min-max scale each component; the small constant avoids division by
     ## zero when a component is constant across clusters
     clusterColors <- cbind(
         50,
         200 * t(
             (t(clusterLoadings) - minLoading)
             / (maxLoading - minLoading + 1e-10))
         - 100
     )
 
     colorspace::LAB(clusterColors) %>%
         colorspace::hex(fixup = TRUE)
 }
 
 
 #' Preprocess data
 #'
 #' Preprocesses input data for \code{\link{.makeServer}}.
 #' @param counts count matrix. `rownames` should correspond to genes and
 #' `colnames` should correspond to spot coordinates.
 #' @param margin which margin of the count matrix to cluster. Valid values are
 #' `c("spot", "sample", "gene", "feature")`.
 #' @param resolutions vector of resolutions to cluster.
 #' @param assignmentFunction function to compute cluster assignments. The
 #' function should have the following signature: integer (number of clusters) ->
 #' (m, n) feature matrix -> m-length vector (cluster assignment of each data
 #' point).
 #' @param coordinates optional \code{\link[base]{data.frame}} with pixel
 #' coordinates for each spot. `rownames` should correspond to the `colnames` of
 #' `counts` and the columns `x` and `y` should specify the pixel coordinates of
 #' the spots.
 #' @return list with the following elements:
 #' - `$assignments`: tidy assignments
 #' - `$counts`: counts in long format (columns `gene`, `spot`, and `count`)
 #' - `$means`: cluster means
 #' - `$scores`: cluster scores for each spot in each resolution
 #' - `$colors`: cluster colors
 #' - `$coordinates`: spot coordinates, either from `coordinates` or parsed from
 #' the `colnames` of `counts`
 #' - `$featureName`: name of the clustered feature (the "opposite" of `margin`)
 #' @keywords internal
 .preprocessData <- function(
     counts,
     margin,
     resolutions,
     assignmentFunction,
     coordinates = NULL
 ) {
     ## Normalize `margin` to either "spot" or "gene" and derive the name of
     ## the other margin
     spotNames <- c("spot", "sample")
     geneNames <- c("gene", "feature")
     c(margin, otherMargin) %<-% {
         if (margin %in% spotNames) list("spot", "gene")
         else if (margin %in% geneNames) list("gene", "spot")
         else stop(sprintf(
             "invalid margin '%s' (must be one of: %s)",
             margin,
             paste(c(spotNames, geneNames), collapse = ", ")
         ))
     }
 
     spots <- colnames(counts)
     if (!is.null(coordinates)) {
         ## Keep only spots that have coordinates. `drop = FALSE` keeps
         ## `counts` two-dimensional (with its column names) even if only a
         ## single spot remains.
         spots <- intersect(spots, rownames(coordinates))
         counts <- counts[, spots, drop = FALSE]
     } else {
         ## No coordinates given: parse them from spot names of the form
         ## "<x>x<y>"
         c(xcoord, ycoord) %<-% {
             strsplit(spots, "x") %>%
                 transpose %>%
                 map(as.numeric)
         }
         coordinates <- as.data.frame(cbind(x = xcoord, y = ycoord))
         rownames(coordinates) <- spots
     }
 
     ## Cluster at every resolution. Spots are clustered on the transposed
     ## counts; genes are clustered on log-transformed counts that are
     ## normalized per spot and z-scored per gene.
     assignments <-
         resolutions %>%
         map(function(r) {
             message(sprintf("Clustering resolution %s", deparse(r)))
             assignmentFunction(
                 r,
                 if (margin == "spot") t(counts)
                 else {
                     log(as.matrix(counts) + 1) %>%
                         prop.table(margin = 2) %>%
                         apply(1, .zscore) %>%
                         t()
                 }
             )
         }) %>%
         setNames(resolutions) %>%
         .tidyAssignments() %>%
         rename(!! sym(margin) := .data$unit)
 
     ## Long format: one row per (gene, spot) pair
     longCounts <-
         counts %>%
         as.data.frame() %>%
         rownames_to_column("gene") %>%
         gather("spot", "count", -.data$gene)
 
     ## Mean count per cluster and element of the other margin
     clusterMeans <-
         assignments %>%
         inner_join(longCounts, by = margin) %>%
         group_by(
             .data$name,
             .data$resolution,
             .data$cluster,
             !! sym(otherMargin)
         ) %>%
         summarize(mean = mean(.data$count)) %>%
         ungroup()
 
     ## Colors are computed from the matrix of cluster means (one column per
     ## cluster, one row per element of the other margin)
     colors <-
         clusterMeans %>%
         select(
             .data$name,
             .data$mean,
             !! sym(otherMargin)
         ) %>%
         spread(.data$name, .data$mean) %>%
         as.data.frame() %>%
         column_to_rownames(otherMargin) %>%
         .computeClusterColors()
 
     message("Scoring spot-cluster affinity")
     scores <-
         if (margin == "spot") {
             ## Score each spot by the likeness of its RMS distance to every
             ## cluster mean, with distances normalized to sum to one within
             ## each (resolution, spot) group
             countsAndMeans <-
                 longCounts %>%
                 inner_join(clusterMeans, by = "gene")
             distances <- as.data.table(countsAndMeans)[,
                 .(distance = sqrt(mean((count - mean) ^ 2))),
                 by = .(resolution, cluster, spot, name)
             ]
             distances %>%
                 group_by(.data$resolution, .data$spot) %>%
                 mutate(
                     score = .likeness(.data$distance / sum(.data$distance),
                     c = 40.
                 )) %>%
                 ungroup() %>%
                 select(-.data$distance)
         } else {
             ## Score each spot by the mean normalized expression of the genes
             ## in every cluster; counts are log-transformed and normalized
             ## first per spot and then per gene
             normalizedCounts <-
                 longCounts %>%
                 mutate(count = log(.data$count + 1)) %>%
                 group_by(.data$spot) %>%
                 mutate(count = .data$count / sum(.data$count)) %>%
                 group_by(.data$gene) %>%
                 mutate(count = .data$count / sum(.data$count)) %>%
                 ungroup()
             assignments %>%
                 inner_join(normalizedCounts, by = "gene") %>%
                 group_by(
                     .data$resolution,
                     .data$spot,
                     .data$cluster,
                     .data$name
                 ) %>%
                 summarize(score = mean(.data$count)) %>%
                 ungroup()
         }
 
     ## Rescale so that the best cluster of each spot has score one
     normalizedScores <-
         scores %>%
         group_by(.data$resolution, .data$spot) %>%
         mutate(score = .data$score / max(.data$score)) %>%
         ungroup()
 
     list(
         assignments = assignments %>% rename(unit = !! sym(margin)),
         counts = longCounts,
         means = clusterMeans,
         scores = normalizedScores,
         colors = colors,
         coordinates = coordinates,
         featureName = otherMargin
     )
 }
 
 
d1b5ab40
 #' SVG barplot
 #'
 #' Renders one inline SVG element per observation, consisting of a gray bar
 #' whose width is proportional to the observation (the largest observation
 #' spans 70 percent of the element), the name of the observation, and its
 #' value printed with two decimals. The elements are joined with newlines.
 #' @param xs named vector with observations
 #' @return \code{\link{character}} SVG barplot
 #' @keywords internal
 .SVGBarplot <- function(xs) {
     ## Bar widths in percent. If the largest observation is zero, dividing by
     ## it would put "NaN%" into the markup, so draw empty bars instead.
     largest <- max(xs)
     widths <- 70 * (
         if (isTRUE(largest == 0)) rep(0, length(xs))
         else xs / largest
     )
 
     template <- paste0(
         "<svg width=\"20em\" height=\"1.5em\">",
         "<rect width=\"%f%%\" height=\"1.5em\" ",
         "style=\"fill:rgb(125,125,125)\"></rect>",
         ## Attributes have to be separated by whitespace, hence the trailing
         ## space after each `fill` attribute
         "<text x=\"2%%\" y=\"50%%\" fill=\"black\" ",
         "dominant-baseline=\"central\">%s</text>",
         "<text x=\"%f%%\" y=\"50%%\" fill=\"white\" ",
         "dominant-baseline=\"central\" >%.2f</text>",
         "</svg>"
     )
     bars <- sprintf(template, widths, names(xs), widths + 2, xs)
 
     ## Paste the per-observation elements together, separated by newlines
     do.call(paste, c(as.list(unname(bars)), list(sep = "\n")))
 }
 
 
6b8dc425
 #' Array pie plot
 #'
 #' @param scores \code{\link[base]{data.frame}} with cluster scores for each
 #' spot containing the columns `"spot"`, `"name"`, and `"score"`.
 #' @param coordinates \code{\link[base]{data.frame}} with `rownames` matching
 #' those in `scores` and columns `"x"` and `"y"` specifying the plotting
 #' position of each observation.
 #' @param counts optional \code{\link[base]{data.frame}} with the columns
 #' `"spot"`, `"gene"`, and `"count"`. If given, the tooltip of each spot lists
 #' its highest-count genes.
 #' @param image a \code{\link[grid]{grid.grob}} to use as background to the
 #' plots.
 #' @param scoreMultiplier log multiplication factor applied to the score vector.
 #' @param spotScale pie chart size.
 #' @param spotOpacity pie chart opacity.
 #' @param numTopGenes number of genes to show in the tooltip of each spot.
 #' Only used if `counts` is given.
 #' @return \code{\link[ggplot2]{ggplot}} object of the pie plot.
 #' @keywords internal
 .arrayPlot <- function(
     scores,
     coordinates,
     counts = NULL,
     image = NULL,
     scoreMultiplier = 1.0,
     spotScale = 1,
     spotOpacity = 1,
     numTopGenes = 5
 ) {
     spots <- intersect(scores$spot, rownames(coordinates))
 
     ## Pie radius: half the smallest distance between two spots, scaled by
     ## `spotScale`. Only the plotting position enters the distance, so any
     ## additional columns of `coordinates` are left out.
     r <- spotScale * min(dist(coordinates[spots, c("x", "y")])) / 2
 
     ## Plot limits: coordinate range padded by three radii on every side
     c(ymin, ymax) %<-% range(coordinates$y)
     c(xmin, xmax) %<-% range(coordinates$x)
     c(ymin, xmin) %<-% { c(ymin, xmin) %>% map(~. - 3 * r) }
     c(ymax, xmax) %<-% { c(ymax, xmax) %>% map(~. + 3 * r) }
 
     if (!is.null(image)) {
         ## Clamp the limits to the image and crop the image to them
         ymin <- max(ymin, 1)
         ymax <- min(ymax, nrow(image$raster))
         xmin <- max(xmin, 1)
         xmax <- min(xmax, ncol(image$raster))
 
         image$raster <- image$raster[ymin:ymax, xmin:xmax]
         annotation <- ggplot2::annotation_custom(image, -Inf, Inf, -Inf, Inf)
     } else {
         annotation <- NULL
     }
 
     ## Mirror the y axis within the plot limits
     coordinates$y <- ymax - coordinates$y + ymin
 
     df <-
         coordinates %>%
         rownames_to_column("spot") %>%
         inner_join(scores, by="spot") %>%
         mutate(score = .data$score ^ scoreMultiplier) %>%
         mutate(tooltip = .data$spot)
 
     if (!is.null(counts)) {
         ## Append a bar plot of the `numTopGenes` highest-count genes of each
         ## spot to its tooltip
         topGenes <-
             counts %>%
             group_by(.data$spot) %>%
             mutate(rank = rank(-.data$count, ties.method = "first")) %>%
             filter(.data$rank <= numTopGenes) %>%
             arrange(-.data$count) %>%
             summarize(topGenes = paste(
                 .SVGBarplot(setNames(.data$count, .data$gene))
             ))
         df <-
             df %>%
             inner_join(topGenes, by = "spot") %>%
             mutate(tooltip = paste(sep = "<br />",
                 .data$tooltip,
                 .data$topGenes
             )) %>%
             select(-.data$topGenes)
     }
 
     ggplot() +
         annotation +
         geom_scatterpie_interactive(
             mapping = ggplot2::aes_string(
                 x0 = "x", y0 = "y", r = "r", amount = "score", fill = "name",
                 tooltip = "tooltip"
             ),
             data = df,
             alpha = spotOpacity,
             n = 64
         ) +
         coord_fixed() +
         scale_x_continuous(expand = c(0, 0), limits = c(xmin, xmax)) +
         scale_y_continuous(expand = c(0, 0), limits = c(ymin, ymax)) +
         theme_minimal() +
         theme(
             axis.text = element_blank(),
             axis.title = element_blank(),
             axis.ticks = element_blank(),
             panel.grid = element_blank()
         )
 }
 
148ebd95
 
bf43ae5c
 #' Cluster graph
 #'
 #' @param assignments \code{\link[base]{data.frame}} with columns `"name"`,
 #' `"resolution"`, and `"cluster"`.
 #' @param clusterMeans \code{\link[base]{data.frame}} with columns `"name"`,
 #' `"resolution"`, `"cluster"`, `featureName`, and `"mean"`.
 #' @param featureName \code{\link[base]{character}} with the name of the
 #' clustered feature.
 #' @param transitionProportions how to compute the transition proportions.
 #' Possible values are:
 #' - `"From"`: based on the total number of assignments in the lower-resolution
 #' cluster
 #' - `"To"`: based on the total number of assignments in the higher-resolution
 #' cluster
 #'
 #' Any other value raises an error.
 #' @param transitionLabels \code{\link[base]{logical}} specifying whether to
 #' show edge labels.
 #' @param transitionThreshold hide edges with transition proportions below this
 #' threshold.
 #' @param numTopFeatures \code{\link[base]{integer}} specifying the number of
 #' features to show in the hover tooltips.
 #' @return \code{\link[ggplot2]{ggplot}} object of the cluster graph.
 #' @keywords internal
 .clusterGraph <- function(
     assignments,
     clusterMeans,
     featureName,
     transitionProportions = "To",
     transitionLabels = FALSE,
     transitionThreshold = 0.0,
     numTopFeatures = 10
 ) {
     ## Column by which transition counts are normalized. `deparse` is used
     ## for the error message since `str` prints its argument and returns
     ## NULL, which would leave the message empty.
     transitionSym <-
         if (transitionProportions == "To") "toNode"
         else if (transitionProportions == "From") "node"
         else stop(sprintf(
             "Invalid value `transitionProportions`: %s",
             paste(deparse(transitionProportions), collapse = "")
         ))
 
     ## Resolutions are replaced by their numeric codes so that consecutive
     ## resolutions differ by one
     data <-
         assignments %>%
         mutate(resolution = as.numeric(.data$resolution)) %>%
         rename(node = .data$name)
 
     ## Edges connect the cluster of each unit to its cluster in the next
     ## resolution; vertices are the clusters, sized by number of units
     graph <- igraph::graph_from_data_frame(
         d = data %>%
             mutate(toResolution = .data$resolution + 1) %>%
             (function(x) inner_join(
                 x,
                 x %>%
                     select(
                         everything(),
                         toCluster = .data$cluster,
                         toNode = .data$node
                     ),
                 by = c("unit", "toResolution" = "resolution")
             )) %>%
             group_by(
                 .data$node,
                 .data$toNode,
                 .data$cluster,
                 .data$toCluster
             ) %>%
             summarize(transCount = n()) %>%
             group_by(!! sym(transitionSym)) %>%
             mutate(transProp = .data$transCount / sum(.data$transCount)) %>%
             ungroup() %>%
 
             group_by(.data$toNode) %>%
             filter(
                 .data$transProp == max(.data$transProp)
                 | .data$transProp > transitionThreshold
             ) %>%
             ungroup() %>%
             # ^ filter edges with transition proportions (weights) below
             #   threshold but always keep the incident edge with the highest
             #   weight (since the graph would become disconnected if that edge
             #   also were removed)
 
             select(.data$node, .data$toNode, everything()),
         vertices = data %>%
             group_by(.data$node, .data$resolution, .data$cluster) %>%
             summarize(size = n())
     )
 
     ## Vertex positions from a tree layout, with the two layout columns
     ## named "y" and "x" (in that order)
     vertices <- cbind(
         igraph::layout_as_tree(graph, flip.y = FALSE) %>%
             `colnames<-`(c("y", "x")),
         igraph::get.vertex.attribute(graph) %>%
             as.data.frame(stringsAsFactors = FALSE)
     )
 
     ## Edge table with the positions of both end points attached
     edges <- c(
         igraph::get.edgelist(graph) %>%
             array_branch(2) %>%
             `names<-`(c("from", "to")),
         igraph::get.edge.attribute(graph)
     ) %>%
         as.data.frame(stringsAsFactors = FALSE) %>%
         inner_join(
             vertices %>%
                 select(.data$name, .data$x, .data$y),
             by = c("from" = "name")
         ) %>%
         inner_join(
             vertices %>%
                 select(.data$name, xend = .data$x, yend = .data$y),
             by = c("to" = "name")
         )
 
     ## One label per resolution (resolution code 1 is left out), centered
     ## horizontally on its vertices and placed slightly above the topmost
     ## vertex
     resolutionLabels <-
         vertices %>%
         select(.data$resolution, .data$x, .data$y) %>%
         filter(.data$resolution != 1) %>%
         mutate(ymin = min(.data$y), ymax = max(.data$y)) %>%
         group_by(.data$resolution) %>%
         summarize(
             x = mean(.data$x),
             y = first(.data$ymax) +
                 0.1 * (first(.data$ymax) - first(.data$ymin))
         ) %>%
         mutate(
             label = as.character(
                 levels(assignments$resolution)[.data$resolution])
         )
 
     ## Tooltip per cluster: bar plot of the `numTopFeatures` features with
     ## the highest means
     tooltips <-
         clusterMeans %>%
         mutate(name = as.character(.data$name)) %>%
         group_by(.data$name) %>%
         mutate(rank = rank(-.data$mean, ties.method = "first")) %>%
         filter(.data$rank <= numTopFeatures) %>%
         arrange(-.data$mean) %>%
         summarize(tooltip = .SVGBarplot(setNames(
             mean,
             nm = !! sym(featureName)
         )))
     vertices <-
         vertices %>%
         inner_join(tooltips, by = "name") %>%
         mutate(tooltip = paste(sep = "<br />",
             toTitleCase(.data$name),
             sprintf("Size: %d", .data$size),
             .data$tooltip
         ))
 
     ggplot() +
         geom_segment(
             aes_string(
                 "x", "y",
                 xend = "xend", yend = "yend",
                 alpha = "transProp"
             ),
             col = "black",
             data = edges
         ) +
         geom_point_interactive(
             aes_(
                 ~x, ~y,
                 color = ~name,
                 size = ~size,
                 tooltip = ~tooltip
             ),
             data = vertices %>% filter(.data$resolution != 1)
         ) +
         {
             if (isTRUE(transitionLabels))
                 ggrepel::geom_label_repel(
                     aes_(
                         x = ~(x + xend) / 2,
                         y = ~(y + yend) / 2,
                         color =
                             if (transitionProportions == "To") ~as.factor(to)
                             else ~as.factor(from),
                         label = ~round(transProp, 2)
                     ),
                     data = edges,
                     show.legend = FALSE
                 )
             else NULL
         } +
         ggplot2::geom_text(
             aes_string("x", "y", label = "label"),
             data = resolutionLabels
         ) +
         labs(alpha = "Proportion", color = "Cluster") +
         scale_size(guide = "none", range = c(2, 7)) +
         scale_x_continuous(expand = c(0.1, 0.1)) +
         ## "none" instead of the deprecated `FALSE`, matching `scale_size`
         guides(alpha = "none", color = "none") +
         theme_bw() +
         theme(
             axis.ticks.x = element_blank(),
             axis.ticks.y = element_blank(),
             axis.title.x = element_blank(),
             axis.title.y = element_blank(),
             axis.text.x = element_blank(),
             axis.text.y = element_blank(),
             panel.grid = element_blank(),
             panel.border = element_blank()
         )
 }
 
148ebd95
 
a72b7c87
 #' SpatialCPie server
 #'
 #' Builds the shiny server function. The server renders one cluster graph and
 #' one array plot per resolution (the "root" resolution is left out), and
 #' stops the app with the current plots and cluster assignments when the
 #' `done` input fires.
8059fda2
 #' @param assignments \code{\link[base]{data.frame}} with cluster assignments
 #' containing the columns `"unit"` (name of the observational unit; either a
 #' gene name or a spot name), `"resolution"`, `"cluster"`, and `"name"` (a
 #' unique identifier of the (resolution, cluster) pair).
7aacac3f
 #' @param clusterMeans \code{\link[base]{data.frame}} with columns `"name"`,
 #' `"resolution"`, `"cluster"`, `featureName`, and `"mean"`.
 #' @param counts passed on as the `counts` argument of `.arrayPlot`, which uses
 #' it to build the per-spot tooltips.
8059fda2
 #' @param scores \code{\link[base]{data.frame}} with cluster scores for each
 #' spot in each resolution containing the columns `"spot"`, `"resolution"`,
 #' `"cluster"`, `"name"`, and `"score"`.
 #' @param colors vector of colors for each cluster. Names should match the
 #' `"name"` columns of the `assignments` and `scores`.
148ebd95
 #' @param image background image for the array plots, passed to
8059fda2
 #' \code{\link[grid]{rasterGrob}}. If `NULL`, the `showImage` and
 #' `spotOpacity` controls are hidden.
 #' @param coordinates \code{\link[base]{data.frame}} with `rownames` matching
 #' the \code{\link[base]{names}} in `scores` and columns `"x"` and `"y"`
 #' specifying the plotting position of each observation.
7aacac3f
 #' @param featureName \code{\link[base]{character}} with the name of the
 #' clustered feature.
8059fda2
 #' @return server function, to be passed to \code{\link[shiny]{shinyApp}}.
a72b7c87
 #' @keywords internal
 .makeServer <- function(
37dd1b03
     assignments,
7aacac3f
     clusterMeans,
d1b5ab40
     counts,
37dd1b03
     scores,
a72b7c87
     colors,
148ebd95
     image,
7aacac3f
     coordinates,
     featureName
a72b7c87
 ) {
     ## All resolutions except the artificial "root" get an array plot
37dd1b03
     resolutions <-
         levels(assignments$resolution) %>%
e61be5d9
         keep(~. != "root")
148ebd95
 
a72b7c87
     function(input, output, session) {
         ## Without a background image the image-related controls are hidden
4293f8a1
         if (is.null(image)) {
             hideElement("showImage")
             hideElement("spotOpacity")
         }
 
a72b7c87
         ###
         ## INPUTS
         ## Each input is wrapped in a reactive; the slider-like ones are
         ## additionally passed through `debounce(1000)`
         edgeProportions <- reactive({ input$edgeProportions })
c03a725a
         edgeThreshold   <- reactive({ input$edgeThreshold   }) %>% debounce(1000)
a72b7c87
         edgeLabels      <- reactive({ input$edgeLabels      })
c03a725a
         scoreMultiplier <- reactive({ input$scoreMultiplier }) %>% debounce(1000)
a72b7c87
         showImage       <- reactive({ input$showImage       })
c03a725a
         spotOpacity     <- reactive({ input$spotOpacity     }) %>% debounce(1000)
         spotSize        <- reactive({ input$spotSize        }) %>% debounce(1000)
a72b7c87
 
         ###
bf43ae5c
         ## CLUSTER GRAPH
         ## Kept as a reactive so that the same plot object can be both
         ## rendered and exported
         clusterGraph <- reactive({
             p <- .clusterGraph(
148ebd95
                 assignments,
7aacac3f
                 clusterMeans,
a72b7c87
                 transitionProportions = edgeProportions(),
4293f8a1
                 transitionLabels = edgeLabels(),
7aacac3f
                 transitionThreshold = edgeThreshold(),
                 featureName = featureName
a72b7c87
             ) +
                 scale_color_manual(values = colors)
         })
 
bf43ae5c
         output$clusterGraph <- ggiraph::renderGirafe({
b58b02c6
             plot <- ggiraph::girafe_options(
bf43ae5c
                 x = ggiraph::girafe(ggobj = clusterGraph()),
4293f8a1
                 ggiraph::opts_toolbar(saveaspng = FALSE)
a72b7c87
             )
             plot
         })
 
         ###
         ## ARRAY PLOT
         ## Output id (and variable name of the reactive) of the array plot for
         ## resolution `r`, derived from a hash of `r`
e61be5d9
         arrayName <- function(r) sprintf("array%s", sha1(r))
4293f8a1
 
37dd1b03
         for (r in resolutions) {
             ## Insert the output placeholder for this resolution into the
             ## element with id "array"
4293f8a1
             shiny::insertUI("#array", "beforeEnd",
                 shiny::div(class = "array", "data-resolution" = r,
                     ggiraph::girafeOutput(arrayName(r)) %>%
                     shinycssloaders::withSpinner()
                 ),
                 immediate = TRUE
             )
a72b7c87
             ## We evaluate the below block in a new frame (with anonymous
37dd1b03
             ## function call) in order to protect the value of `r`, which
a72b7c87
             ## will have changed when the reactive expressions are
             ## evaluated.
             (function() {
37dd1b03
                 r_ <- r
d8d0ce3b
                 scores_ <-
                     scores %>%
358ab48f
                     filter(.data$resolution == r_)
                 ## Store the plot reactive under the name `arrayName(r_)` in
                 ## the calling frame (the server function), where it is later
                 ## looked up by name via `eval(call(...))`
4293f8a1
                 assign(envir = parent.frame(), arrayName(r_), reactive(
a72b7c87
                     .arrayPlot(
358ab48f
                         scores = scores_ %>%
                             select(.data$spot, .data$name, .data$score),
148ebd95
                         coordinates = coordinates,
d1b5ab40
                         counts = counts,
a72b7c87
                         image =
148ebd95
                             if (!is.null(image) && !is.null(coordinates) &&
4293f8a1
                                     showImage())
148ebd95
                                 grid::rasterGrob(
                                     image,
a72b7c87
                                     width = unit(1, "npc"),
                                     height = unit(1, "npc"),
                                     interpolate = TRUE
                                 )
                             else NULL,
                         ## The input value is used as a base-2 exponent
a0404626
                         scoreMultiplier = 2 ** scoreMultiplier(),
a72b7c87
                         spotScale = spotSize() / 5,
4293f8a1
                         spotOpacity = spotOpacity()
a72b7c87
                     ) +
d8d0ce3b
                         guides(fill = guide_legend(title = "Cluster")) +
                         scale_fill_manual(
                             values = colors,
                             labels = unique(scores_$cluster)
4293f8a1
                         )
a72b7c87
                 ))
                 ## Render the stored reactive as an interactive widget
4293f8a1
                 output[[arrayName(r_)]] <- ggiraph::renderGirafe(
a72b7c87
                     {
4293f8a1
                         ggiraph::girafe_options(
                             x = ggiraph::girafe(
842ae933
                                 ggobj = eval(call(arrayName(r_))),
                                 xml_reader_options = list(options = "HUGE")),
4293f8a1
                             ggiraph::opts_toolbar(saveaspng = FALSE),
                             ggiraph::opts_zoom(max = 5)
                         )
                     }
a72b7c87
                 )
             })()
         }
 
         ###
         ## EXPORT
         ## Value handed back by the app: the assignments (without the `name`
         ## column), the cluster graph, and the array plot of every resolution
795be519
         outputs <- reactive({
             list(
358ab48f
                 clusters = assignments %>% select(-.data$name),
bf43ae5c
                 clusterGraph = clusterGraph(),
                 arrayPlot = lapply(
4293f8a1
                     setNames(nm = resolutions),
                     function(x) eval(call(arrayName(x)))
a72b7c87
                 )
795be519
             )
a72b7c87
         })
795be519
         observeEvent(input$done, shiny::stopApp(returnValue = outputs()))
a72b7c87
     }
 }
 
148ebd95
 
a72b7c87
 #' SpatialCPie UI
 #'
 #' Looks up the \code{www/default/index.html} file of the installed
 #' SpatialCPie package and renders it as an HTML template.
 #'
 #' @return SpatialCPie UI, to be passed to \code{\link[shiny]{shinyApp}}.
 #' @keywords internal
 .makeUI <- function() {
     ## Resolve the template location inside the installed package
     templatePath <- system.file(
         "www", "default", "index.html",
         package = "SpatialCPie"
     )
     shiny::htmlTemplate(templatePath)
 }
 
a72b7c87
 
4764ea87
 #' SpatialCPie App
 #'
 #' Runs \code{\link{.preprocessData}} on the forwarded arguments, feeds the
 #' result into \code{.makeServer}, and combines that server with the UI from
 #' \code{.makeUI} into a Shiny application.
 #'
 #' @param image background image.
 #' @param ... arguments passed to \code{\link{.preprocessData}}.
 #' @return SpatialCPie \code{\link[shiny]{shinyApp}} object.
 #' @keywords internal
 .makeApp <- function(image, ...) {
     prepared <- .preprocessData(...)
     ## Build the server first; every field comes from the preprocessed data
     ## except the background image, which is passed through unchanged.
     server <- .makeServer(
         assignments = prepared$assignments,
         clusterMeans = prepared$means,
         counts = prepared$counts,
         scores = prepared$scores,
         colors = prepared$colors,
         image = image,
         coordinates = prepared$coordinates,
         featureName = prepared$featureName
     )
     shiny::shinyApp(ui = .makeUI(), server = server)
 }
 
 
 #' Run SpatialCPie
 #'
 #' Runs the SpatialCPie gadget.
 #'
 #' If \code{counts} is a
 #' \code{\link[SummarizedExperiment]{SummarizedExperiment-class}} object, its
 #' assay is extracted and converted to a \code{\link[base]{data.frame}} before
 #' the application is built.
 #'
 #' @param counts gene count matrix or a
 #' \code{\link[SummarizedExperiment]{SummarizedExperiment-class}} object
 #' containing count values.
 #' @param image image to be used as background to the plot.
 #' @param spotCoordinates \code{\link[base]{data.frame}} with pixel coordinates.
 #' The rows should correspond to the columns (spatial areas) in the count file.
 #' @param margin which margin to cluster.
 #' @param resolutions \code{\link[base]{numeric}} vector specifying the
 #' clustering resolutions.
 #' @param assignmentFunction function to compute cluster assignments.
 #' @param view \code{\link[shiny]{viewer}} object. When \code{NULL}, the
 #' gadget is shown with \code{\link[shiny]{paneViewer}}.
 #' @return a list with the following items:
 #' - `"clusters"`: Cluster assignments (may differ from `assignments`)
 #' - `"clusterGraph"`: The cluster tree ggplot object
 #' - `"arrayPlot"`: The pie plot ggplot objects
 #' @export
 #' @examples
 #' if (interactive()) {
 #'     options(device.ask.default = FALSE)
 #'
 #'     ## Set up coordinate system
 #'     coordinates <- as.matrix(expand.grid(1:10, 1:10))
 #'
 #'     ## Generate data set with three distinct genes generated by three
 #'     ## distinct cell types
 #'     profiles <- diag(rep(1, 3)) + runif(9)
 #'     centers <- cbind(c(5, 2), c(2, 8), c(8, 2))
 #'     mixes <- apply(coordinates, 1, function(x) {
 #'         x <- exp(-colSums((centers - x) ^ 2) / 50)
 #'         x / sum(x)
 #'     })
 #'     means <- 100 * profiles %*% mixes
 #'     counts <- matrix(rpois(prod(dim(means)), means), nrow = nrow(profiles))
 #'     colnames(counts) <- apply(
 #'         coordinates,
 #'         1,
 #'         function(x) do.call(paste, c(as.list(x), list(sep = "x")))
 #'     )
 #'     rownames(counts) <- paste("gene", 1:nrow(counts))
 #'
 #'     ## Run SpatialCPie
 #'     runCPie(counts)
 #' }
 runCPie <- function(
     counts,
     image = NULL,
     spotCoordinates = NULL,
     margin = "spot",
     resolutions = 2:4,
     assignmentFunction = function(k, x) kmeans(x, centers = k)$cluster,
     view = NULL
 ) {
     ## Unwrap a SummarizedExperiment into a plain data frame of its assay
     if (is(counts, "SummarizedExperiment")) {
         counts <- as.data.frame(SummarizedExperiment::assay(counts))
     }
     ## Build the application before resolving the viewer
     gadget <- .makeApp(
         image = image,
         counts = counts,
         coordinates = spotCoordinates,
         margin = margin,
         resolutions = resolutions,
         assignmentFunction = assignmentFunction
     )
     ## Fall back to the pane viewer when the caller did not supply one
     viewer <- view %||% shiny::paneViewer()
     shiny::runGadget(app = gadget, viewer = viewer)
 }