#' @title Plots contamination on UMAP coordinates
#' @description A scatter plot of the UMAP dimensions generated by DecontX with
#' cells colored by the estimated percentage of contamination.
#' @param x Either a \linkS4class{SingleCellExperiment} with \code{decontX}
#' results stored in \code{metadata(x)$decontX} or the result from running
#' decontX on a count matrix.
#' @param batch Character. Batch of cells to plot. If \code{NULL}, then
#' the first batch in the list of estimates will be selected.
#' Default \code{NULL}.
#' @param colorScale Character vector. Contains the color spectrum to be passed
#' to \code{scale_colour_gradientn} from package 'ggplot2'. The scale limits
#' are fixed to 0 and 1. Default
#' c("blue","green","yellow","orange","red").
#' @param size Numeric. Size of points in the scatterplot. Default 1.
#' @return Returns a \code{ggplot} object. Cells with a missing value in either
#' UMAP coordinate are not plotted.
#' @author Shiyi Yang, Joshua Campbell
#' @seealso See \code{\link{decontX}} for a full example of how to estimate
#' and plot contamination.
#' @export
plotDecontXContamination <- function(x,
                                     batch = NULL,
                                     colorScale = c(
                                       "blue",
                                       "green",
                                       "yellow",
                                       "orange",
                                       "red"
                                     ),
                                     size = 1) {
  # The per-batch estimates live in the metadata of a SingleCellExperiment,
  # or directly in the list returned by decontX for a count matrix.
  estimates <- if (inherits(x, "SingleCellExperiment")) {
    S4Vectors::metadata(x)$decontX$estimates
  } else {
    x$estimates
  }
  if (is.null(estimates)) {
    stop("decontX estimates not found. Estimates will be found in
          'metadata(x)$decontX$estimates' if 'x' is a
          SingleCellExperiment or 'x$estimates' if decontX was run
          on a count matrix. Are you sure 'x' is output from decontX?")
  }

  # Validate the requested batch before selecting one
  available <- names(estimates)
  if (!is.null(batch) && !(batch %in% available)) {
    stop(
      "'", batch, "' is not one of the batches in 'x'. Batches available",
      " for plotting: '", paste(available, collapse = ","), "'"
    )
  }
  selected <- if (is.null(batch)) available[1] else batch

  batchEstimates <- estimates[[selected]]
  umap <- batchEstimates$UMAP
  axisNames <- colnames(umap)

  # Keep only cells with both UMAP coordinates present
  plotData <- data.frame(
    umap,
    "Contamination" = batchEstimates$contamination
  )
  keep <- !is.na(umap[, 1]) & !is.na(umap[, 2])
  plotData <- plotData[keep, ]

  basePlot <- ggplot2::ggplot(
    plotData,
    ggplot2::aes_string(x = axisNames[1], y = axisNames[2])
  )
  basePlot +
    ggplot2::geom_point(
      stat = "identity",
      size = size,
      ggplot2::aes_string(color = "Contamination")
    ) +
    ggplot2::theme_bw() +
    ggplot2::scale_colour_gradientn(
      colors = colorScale,
      name = "Contamination",
      limits = c(0, 1)
    ) +
    ggplot2::theme(
      panel.grid.major = ggplot2::element_blank(),
      panel.grid.minor = ggplot2::element_blank(),
      axis.text = ggplot2::element_text(size = 15),
      axis.title = ggplot2::element_text(size = 15)
    )
}


#' @title Plots percentage of cells in cell types expressing markers
#' @description Generates a barplot that shows the percentage of
#' cells within clusters or cell types that have detectable levels
#' of given marker genes. Can be used to view the expression of
#' marker genes in different cell types before and after
#' decontamination with \code{\link{decontX}}.
#' @param x Either a \linkS4class{SingleCellExperiment} or
#' a matrix-like object of counts.
#' @param markers List. A named list indicating the marker genes
#' for each cell type of
#' interest. Multiple markers can be supplied for each cell type. For example,
#' \code{list(Tcell_Markers=c("CD3E", "CD3D"),
#' Bcell_Markers=c("CD79A", "CD79B", "MS4A1")}
#' would specify markers for human T-cells and B-cells.
#' A cell will be considered
#' "positive" for a cell type if it has a count greater than or equal to
#' \code{threshold} for at least one of the marker genes in the list.
#' Markers that cannot be found in \code{x} are skipped; an error is raised
#' if none of them are found.
#' @param groupClusters List. A named list that allows
#' cell cluster labels coded in
#' \code{z} to be regrouped and renamed on the fly. For example,
#' \code{list(Tcells=c(1, 2), Bcells=7)} would recode
#' clusters 1 and 2 to "Tcells"
#' and cluster 7 to "Bcells". Note that if this is
#' used, clusters in \code{z} not found
#' in \code{groupClusters} will be excluded from the barplot.
#' Default \code{NULL}.
#' @param assayName Character vector. Name(s) of the assay(s) to
#' plot if \code{x} is a
#' \linkS4class{SingleCellExperiment}. If more than one assay
#' is listed, then side-by-side barplots will be generated, in the order
#' given here. Default \code{c("counts", "decontXcounts")}.
#' @param z Character, Integer, or Vector. Indicates the cluster labels
#' for each cell.
#' If \code{x} is a \linkS4class{SingleCellExperiment} and \code{z = NULL},
#' then the cluster labels from \code{\link{decontX}} will be retrieved from
#' the \code{colData} of \code{x} (i.e. \code{colData(x)$decontX_clusters}).
#' If \code{z} is a single character or integer,
#' then that column will be retrieved
#' from \code{colData} of \code{x}. (i.e. \code{colData(x)[,z]}). If \code{x}
#' is a counts matrix, then \code{z} will need
#' to be a vector the same length as
#' the number of columns in \code{x} that indicate
#' the cluster to which each cell
#' belongs. Default \code{NULL}.
#' @param threshold Numeric. Markers greater than or equal to this value will
#' be considered detected in a cell. Default 1.
#' @param exactMatch Boolean. Whether to only identify exact matches
#' for the markers or to identify partial matches using \code{\link{grep}}. See
#' \code{\link{retrieveFeatureIndex}} for more details. Default \code{TRUE}.
#' @param by Character. Where to search for the markers if \code{x} is a
#' \linkS4class{SingleCellExperiment}. See \code{\link{retrieveFeatureIndex}}
#' for more details. If \code{x} is a matrix,
#' then this must be set to \code{"rownames"}. Default \code{"rownames"}.
#' @param ncol Integer. Number of columns to make in the plot.
#' Default \code{round(sqrt(length(markers)))}.
#' @param labelBars Boolean. Whether to display percentages above each bar.
#' Default \code{TRUE}.
#' @param labelSize Numeric. Size of the percentage labels in the barplot.
#' Default 3.
#' @return Returns a \code{ggplot} object.
#' @author Shiyi Yang, Joshua Campbell
#' @seealso See \code{\link{decontX}} for a full example of how to estimate
#' and plot contamination.
#' @export
plotDecontXMarkerPercentage <- function(x, markers, groupClusters = NULL,
                                        assayName = c(
                                          "counts",
                                          "decontXcounts"
                                        ),
                                        z = NULL, threshold = 1,
                                        exactMatch = TRUE, by = "rownames",
                                        ncol = round(sqrt(length(markers))),
                                        labelBars = TRUE, labelSize = 3) {
  cellTypeLabels <- percent <- NULL # fix check note

  legend <- "none"
  # Check that list arguments are named
  if (!is(markers, "list") || is.null(names(markers))) {
    stop("'markers' needs to be a named list.")
  }

  temp <- .processPlotDecontXMarkerInupt(
    x = x,
    z = z,
    markers = markers,
    groupClusters = groupClusters,
    by = by,
    exactMatch = exactMatch
  )
  x <- temp$x
  z <- temp$z
  geneMarkerIndex <- temp$geneMarkerIndex
  geneMarkerCellTypeIndex <- temp$geneMarkerCellTypeIndex
  groupClusters <- temp$groupClusters
  xlab <- temp$xlab

  if (length(geneMarkerIndex) == 0) {
    stop("None of the genes in 'markers' were found in 'x'.")
  }

  # 'z' holds positions into 'groupClusters' (integer codes, or factor codes
  # whose levels are names(groupClusters)). Re-code it to consecutive
  # integers 1..K over the groups that actually contain cells, so the
  # per-group columns computed below are labelled 1..K, line up with the
  # cluster sizes counted by tabulate(), and index 'groupLabels' directly.
  # Passing the raw labels instead mis-maps e.g. 0-based or non-contiguous
  # numeric cluster labels once melt() converts them back to numbers.
  zCodes <- as.integer(z)
  presentGroups <- sort(unique(zCodes))
  zCompact <- match(zCodes, presentGroups)
  groupLabels <- names(groupClusters)[presentGroups]

  if (inherits(x, "SingleCellExperiment")) {
    # One percentage table per assay in 'assayName', stacked so the assays
    # can be drawn as side-by-side bars
    markerSce <- x[geneMarkerIndex, ]
    df.list <- lapply(assayName, function(a) {
      counts <- SummarizedExperiment::assay(markerSce, a)
      df <- .calculateDecontXBarplotPercent(
        counts,
        zCompact,
        geneMarkerCellTypeIndex,
        threshold
      )
      cbind(df, assay = a)
    })
    df <- do.call(rbind, df.list)
    # Keep the assays in the order the caller listed them
    df$assay <- factor(df$assay, levels = unique(assayName))
    if (length(assayName) > 1) {
      legend <- "right"
    }
  } else {
    ## If 'x' is matrix, then calculate percentages directly.
    ## drop = FALSE keeps a matrix even when a single marker gene matched.
    counts <- x[geneMarkerIndex, , drop = FALSE]
    df <- .calculateDecontXBarplotPercent(
      counts,
      zCompact,
      geneMarkerCellTypeIndex,
      threshold
    )
    # 'df' has no 'assay' column here, so the fill = "assay" mapping below is
    # resolved from this local value (a single fill level; legend hidden).
    assay <- "red3"
    legend <- "none"
  }

  # Build data.frame for ggplots
  df <- cbind(df, cellTypeLabels = groupLabels[df$cellType])
  df$cellTypeLabels <- factor(df$cellTypeLabels,
    levels = names(groupClusters)
  )
  df <- cbind(df, markerLabels = names(markers)[df$markers])
  df$markerLabels <- factor(df$markerLabels, levels = names(markers))

  plt <- ggplot2::ggplot(df, ggplot2::aes_string(
    x = "cellTypeLabels",
    y = "percent", fill = "assay"
  )) +
    ggplot2::geom_bar(
      stat = "identity",
      position = ggplot2::position_dodge2(width = 0.9, preserve = "single")
    ) +
    ggplot2::xlab(xlab) +
    ggplot2::ylab("Percentage of cells expressing markers") +
    ggplot2::facet_wrap(~markerLabels, ncol = ncol) +
    ggplot2::theme(
      panel.background = ggplot2::element_rect(
        fill = "white",
        color = "grey"
      ),
      panel.grid = ggplot2::element_line("grey"),
      legend.position = legend,
      legend.key = ggplot2::element_rect(
        fill = "white",
        color = "white"
      ),
      panel.grid.minor = ggplot2::element_blank(),
      panel.grid.major = ggplot2::element_blank(),
      text = ggplot2::element_text(size = 10),
      axis.text.x = ggplot2::element_text(
        size = 8, angle = 45,
        hjust = 1
      ),
      axis.text.y = ggplot2::element_text(size = 9),
      legend.key.size = grid::unit(8, "mm"),
      legend.text = ggplot2::element_text(size = 10),
      strip.text.x = ggplot2::element_text(size = 10)
    )

  if (isTRUE(labelBars)) {
    plt <- plt + ggplot2::geom_text(ggplot2::aes(
      x = cellTypeLabels,
      y = percent + 2.5,
      label = percent
    ),
    position = ggplot2::position_dodge2(width = 0.9, preserve = "single"),
    size = labelSize
    )
  }
  return(plt)
}


#' @title Plots expression of marker genes before and after decontamination
#' @description Generates a violin plot that shows the counts of marker
#' genes in cells across specific clusters or cell types. Can be used to view
#' the expression of marker genes in different cell types before and after
#' decontamination with \code{\link{decontX}}.
#' @param x Either a \linkS4class{SingleCellExperiment}
#' or a matrix-like object of counts.
#' @param markers Character Vector or List. A character vector
#' or list of character vectors
#' with the names of the marker genes of interest. Markers that cannot be
#' found in \code{x} are skipped; an error is raised if none are found.
#' @param groupClusters List. A named list that allows
#' cell cluster labels coded in
#' \code{z} to be regrouped and renamed on the fly. For example,
#' \code{list(Tcells=c(1, 2), Bcells=7)} would recode clusters
#' 1 and 2 to "Tcells"
#' and cluster 7 to "Bcells". Note that if this is used, clusters
#' in \code{z} not found
#' in \code{groupClusters} will be excluded. Cell types are shown in the
#' order they are listed. Default \code{NULL}.
#' @param assayName Character vector. Name(s) of the assay(s) to
#' plot if \code{x} is a
#' \linkS4class{SingleCellExperiment}. If more than one assay is listed, then
#' side-by-side violin plots will be generated, in the order given here.
#' Default \code{c("counts", "decontXcounts")}.
#' @param z Character, Integer, or Vector.
#' Indicates the cluster labels for each cell.
#' If \code{x} is a \linkS4class{SingleCellExperiment} and \code{z = NULL},
#' then the cluster labels from \code{\link{decontX}} will be retrieved from
#' the \code{colData} of \code{x} (i.e. \code{colData(x)$decontX_clusters}).
#' If \code{z} is a single character or integer, then that column will be
#' retrieved from \code{colData} of \code{x}. (i.e. \code{colData(x)[,z]}).
#' If \code{x} is a counts matrix, then \code{z} will need to be a vector
#' the same length as the number of columns in \code{x} that indicate
#' the cluster to which each cell belongs. Default \code{NULL}.
#' @param exactMatch Boolean. Whether to only identify exact matches
#' for the markers or to identify partial matches using \code{\link{grep}}.
#' See \code{\link{retrieveFeatureIndex}} for more details.
#' Default \code{TRUE}.
#' @param by Character. Where to search for the markers if \code{x} is a
#' \linkS4class{SingleCellExperiment}. See \code{\link{retrieveFeatureIndex}}
#' for more details. If \code{x} is a matrix, then this must be set to
#' \code{"rownames"}. Default \code{"rownames"}.
#' @param log1p Boolean. Whether to apply the function \code{log1p} to the data
#' before plotting. This function will add a pseudocount of 1 and then log
#' transform the expression values. Default \code{FALSE}.
#' @param ncol Integer. Number of columns to make in the plot.
#' Default \code{NULL}.
#' @param plotDots Boolean. If \code{TRUE}, the
#'  expression of features will be plotted as points in addition to the violin
#'  curve. Default \code{FALSE}.
#' @param dotSize Numeric. Size of points if \code{plotDots = TRUE}.
#' Default \code{0.1}.
#' @return Returns a \code{ggplot} object.
#' @author Shiyi Yang, Joshua Campbell
#' @seealso See \code{\link{decontX}} for a full example of how to estimate
#' and plot contamination.
#' @export
plotDecontXMarkerExpression <- function(x, markers, groupClusters = NULL,
                                        assayName = c(
                                          "counts",
                                          "decontXcounts"
                                        ),
                                        z = NULL, exactMatch = TRUE,
                                        by = "rownames", log1p = FALSE,
                                        ncol = NULL,
                                        plotDots = FALSE, dotSize = 0.1) {
  Expression <- Marker <- Cell_Type <- Cluster <- NULL # fix check note
  legend <- "none"

  # Remember whether the caller asked for regrouped cell types: the helper
  # below always returns a non-NULL 'groupClusters' (one group per cluster
  # when none was given).
  useCellTypes <- !is.null(groupClusters)

  temp <- .processPlotDecontXMarkerInupt(
    x = x,
    z = z,
    markers = markers,
    groupClusters = groupClusters,
    by = by,
    exactMatch = exactMatch
  )
  x <- temp$x
  z <- temp$z
  geneMarkerIndex <- temp$geneMarkerIndex
  groupClusters <- temp$groupClusters

  if (length(geneMarkerIndex) == 0) {
    stop("None of the genes in 'markers' were found in 'x'.")
  }

  if (inherits(x, "SingleCellExperiment")) {
    # Long-format expression table for each assay in 'assayName', stacked
    # so the assays can be drawn as side-by-side violins
    markerSce <- x[geneMarkerIndex, ]
    df.list <- lapply(assayName, function(a) {
      counts <- SummarizedExperiment::assay(markerSce, a)
      df <- reshape2::melt(as.matrix(counts),
        varnames = c("Marker", "Cell"),
        value.name = "Expression"
      )
      cbind(df, assay = a)
    })
    df <- do.call(rbind, df.list)
    if (length(assayName) > 1) {
      legend <- "right"
    }
  } else {
    ## If 'x' is matrix-like, melt the marker rows directly. drop = FALSE
    ## keeps a matrix when a single marker matched; as.matrix() handles
    ## sparse or data.frame inputs.
    counts <- x[geneMarkerIndex, , drop = FALSE]
    df <- reshape2::melt(as.matrix(counts),
      varnames = c("Marker", "Cell"),
      value.name = "Expression"
    )
    legend <- "none"
  }

  # melt() unrolls each matrix column-major (all markers of cell 1, then
  # cell 2, ...) and the per-assay tables are stacked in order, so cluster
  # labels can be attached by position. Looking cells up by name is fragile
  # because melt() may convert numeric-looking cell names to numbers.
  df$Cluster <- rep(rep(z, each = length(geneMarkerIndex)),
    length.out = nrow(df)
  )

  ylab <- "Expression"
  if (isTRUE(log1p)) {
    df$Expression <- log1p(df$Expression)
    ylab <- "Expression (log1p)"
  }

  # 'df' has a lower-case 'assay' column (absent for matrix input), so the
  # fill aesthetic is supplied from this local; NULL means no fill mapping.
  Assay <- if (is.null(df$assay)) {
    NULL
  } else {
    factor(df$assay, levels = unique(assayName))
  }

  if (useCellTypes) {
    # Store the ordered factor in 'df' itself: ggplot looks up data columns
    # before local variables, so a local factor would be ignored.
    df$Cell_Type <- factor(names(groupClusters)[df$Cluster],
      levels = names(groupClusters)
    )
    plt <- ggplot2::ggplot(df, ggplot2::aes(
      x = Cell_Type,
      y = Expression,
      fill = Assay
    )) +
      ggplot2::facet_wrap(~ Cell_Type + Marker,
        scales = "free",
        labeller = ggplot2::label_context,
        ncol = ncol
      )
  } else {
    plt <- ggplot2::ggplot(df, ggplot2::aes(
      x = Cluster,
      y = Expression,
      fill = Assay
    )) +
      ggplot2::facet_wrap(~ Cluster + Marker,
        scales = "free",
        labeller = ggplot2::label_context,
        ncol = ncol
      )
  }
  plt <- plt + ggplot2::geom_violin(
    trim = TRUE,
    scale = "width"
  ) +
    ggplot2::theme_bw() + ggplot2::theme(
      axis.text.x = ggplot2::element_blank(),
      axis.ticks.x = ggplot2::element_blank(),
      axis.title.x = ggplot2::element_blank(),
      strip.text = ggplot2::element_text(size = 8),
      panel.grid = ggplot2::element_blank(),
      legend.position = legend
    ) + ggplot2::ylab(ylab)

  if (isTRUE(plotDots)) {
    plt <- plt + ggplot2::geom_jitter(height = 0, size = dotSize)
  }

  return(plt)
}



.processPlotDecontXMarkerInupt <- function(x, z, markers, groupClusters,
                                           by, exactMatch) {
  # Shared input processing for the decontX marker plots.
  #
  # Resolves the per-cell cluster labels 'z', optionally regroups/renames
  # clusters via 'groupClusters' (dropping cells in no group), and locates
  # the marker genes in 'x'.
  #
  # Returns a list with:
  #   x                        'x', restricted to grouped cells if
  #                            'groupClusters' was given
  #   z                        integer positions into names(groupClusters)
  #                            when 'groupClusters' was given, otherwise a
  #                            factor whose levels are names(groupClusters)
  #   geneMarkerIndex          row indices of the matched marker genes
  #   geneMarkerCellTypeIndex  for each matched gene, the position in
  #                            'markers' of the set it came from
  #   groupClusters            the given list, or one group per cluster level
  #   xlab                     "Cell types" or "Clusters"
  isSCE <- inherits(x, "SingleCellExperiment")

  # Process z and convert to a factor
  if (is.null(z) && isSCE) {
    cn <- colnames(SummarizedExperiment::colData(x))
    if (!("decontX_clusters" %in% cn)) {
      stop("'decontX_clusters' not found in 'colData(x)'. Make sure you have
           run 'decontX' or supply 'z' directly.")
    }
    z <- SummarizedExperiment::colData(x)$decontX_clusters
  } else if (length(z) == 1 && isSCE) {
    cd <- SummarizedExperiment::colData(x)
    # A single 'z' names a colData column, or (as documented) gives its index
    found <- if (is.numeric(z)) {
      z >= 1 && z <= ncol(cd)
    } else {
      z %in% colnames(cd)
    }
    if (!isTRUE(found)) {
      stop("'", z, "' not found in 'colData(x)'.")
    }
    z <- cd[, z]
  } else if (length(z) != ncol(x)) {
    stop("If 'x' is a SingleCellExperiment, then 'z' needs to be",
          " a single character or integer specifying the column in",
          " 'colData(x)'. Alternatively to specify the cell cluster",
          " labels directly as a vector, the length of 'z' needs to",
          " be the same as the number of columns in 'x'. This is",
          " required if 'x' is a matrix.")
  }
  if (!is.factor(z)) {
    z <- as.factor(z)
  }

  if (!is.null(groupClusters)) {
    if (!is(groupClusters, "list") || is.null(names(groupClusters))) {
      stop("'groupClusters' needs to be a named list.")
    }

    # Check that groupClusters are found in 'z'
    cellMappings <- unlist(groupClusters)
    notFound <- cellMappings[!(cellMappings %in% z)]
    if (length(notFound) > 0) {
      stop(
        "'groupClusters' not found in 'z': ",
        paste(notFound, collapse = ",")
      )
    }

    labels <- rep(NA_character_, ncol(x))
    for (i in seq_along(groupClusters)) {
      # [[i]] extracts the vector of cluster labels; matching against the
      # one-element list groupClusters[i] would compare with its deparsed
      # form (e.g. "c(1, 2)") and miss groups with several clusters.
      labels[z %in% groupClusters[[i]]] <- names(groupClusters)[i]
    }
    keep <- !is.na(labels)
    # drop = FALSE keeps a matrix even if only one cell remains
    x <- x[, keep, drop = FALSE]
    z <- as.integer(factor(labels[keep], levels = names(groupClusters)))
    xlab <- "Cell types"
  } else {
    groupClusters <- levels(z)
    names(groupClusters) <- levels(z)
    xlab <- "Clusters"
  }

  # Find index of each feature in 'x'; each gene remembers which marker set
  # (position in 'markers') it belongs to
  geneMarkerCellTypeIndex <- rep(seq_along(markers), lengths(markers))
  geneMarkerIndex <- retrieveFeatureIndex(unlist(markers),
    x,
    by = by,
    removeNA = FALSE,
    exactMatch = exactMatch
  )

  # Remove genes that did not match
  matched <- !is.na(geneMarkerIndex)
  geneMarkerCellTypeIndex <- geneMarkerCellTypeIndex[matched]
  geneMarkerIndex <- geneMarkerIndex[matched]

  return(list(
    x = x,
    z = z,
    geneMarkerIndex = geneMarkerIndex,
    geneMarkerCellTypeIndex = geneMarkerCellTypeIndex,
    groupClusters = groupClusters,
    xlab = xlab
  ))
}


.calculateDecontXBarplotPercent <- function(counts,
                                            z,
                                            geneMarkerCellTypeIndex,
                                            threshold) {
  # Percentage of cells in each group of 'z' that are positive for each
  # marker set. A cell is positive for a set when at least one of the set's
  # genes (rows sharing a value in 'geneMarkerCellTypeIndex') has a count
  # >= 'threshold'. Returns a long data.frame with columns 'markers',
  # 'cellType' and 'percent' (rounded to whole percentages).

  # Wrap the counts as a DelayedArray and flag detected entries
  detected <- DelayedArray::DelayedArray(counts) >= threshold

  # Per marker set and cell: number of detected genes ("+ 0" turns the
  # logical matrix numeric so it can be summed)
  hitsPerSet <- DelayedArray::rowsum(detected + 0, geneMarkerCellTypeIndex)

  # Per marker set and group: number of positive cells
  positiveCells <- DelayedArray::colsum((hitsPerSet > 0) + 0, z)

  # Divide each group's column by that group's cell count
  groupSizes <- tabulate(z)
  percentages <- round(100 * sweep(positiveCells, 2, groupSizes, "/"))

  reshape2::melt(percentages,
    varnames = c("markers", "cellType"),
    value.name = "percent"
  )
}