#' @title Mapping the dimension reduction plot
#' @description Creates a scatterplot given two dimensions from a data
#'  dimension reduction tool (e.g. tSNE) output.
#' @param x Numeric matrix or a \linkS4class{SingleCellExperiment} object
#'  with the matrix located in the assay slot under \code{useAssay}. Each
#'  row of the matrix will be plotted as a separate facet.
#' @param reducedDimName The name of the dimension reduction slot in
#'  \code{reducedDimNames(x)} if \code{x} is a
#'  \linkS4class{SingleCellExperiment} object. Ignored if both \code{dim1} and
#'  \code{dim2} are set.
#' @param dim1 Numeric vector. First dimension from data dimension
#'  reduction output.
#' @param dim2 Numeric vector. Second dimension from data dimension
#'  reduction output.
#' @param useAssay A string specifying which \link{assay}
#'  slot to use if \code{x} is a
#'  \linkS4class{SingleCellExperiment} object. Default "counts".
#' @param altExpName The name for the \link{altExp} slot
#'  to use. Default "featureSubset".
#' @param size Numeric. Sets size of point on plot. Default 1.
#' @param xlab Character vector. Label for the x-axis. Default 'Dimension_1'.
#' @param ylab Character vector. Label for the y-axis. Default 'Dimension_2'.
#' @param limits Passed to \link{scale_colour_gradient2}. The range
#'  of color scale.
#' @param colorLow Character. A color available from `colors()`.
#'  The color will be used to signify the lowest values on the scale.
#'  Default "blue4".
#' @param colorMid Character. A color available from `colors()`.
#'  The color will be used to signify the midpoint on the scale. Default
#'  "grey90".
#' @param colorHigh Character. A color available from `colors()`.
#'  The color will be used to signify the highest values on the scale.
#'  Default "firebrick1".
#' @param midpoint Numeric. The value indicating the midpoint of the
#' diverging color scheme. If \code{NULL}, defaults to the mean
#' with 10 percent of values trimmed. Default \code{0}.
#' @param varLabel Character vector. Title for the color legend.
#' @param ncol Integer. Passed to \link[ggplot2]{facet_wrap}. Specify the
#'  number of columns for facet wrap.
#' @param headers Character vector. If `NULL`, the corresponding rownames are
#'  used as labels. Otherwise, these headers are used to label the genes.
#' @param decreasing logical. Specifies the order of plotting the points.
#'  If \code{FALSE}, the points will be plotted in increasing order where
#'  the points with largest values will be on top. \code{TRUE} otherwise.
#'  If \code{NULL}, no sorting is performed. Points will be plotted in their
#'  current order in \code{x}. Default \code{FALSE}.
#' @return The plot as a ggplot object
#' @export
setGeneric(
  "plotDimReduceGrid",
  function(x, reducedDimName, dim1 = NULL, dim2 = NULL,
           useAssay = "counts", altExpName = "featureSubset",
           size = 1, xlab = "Dimension_1", ylab = "Dimension_2",
           limits = c(-2, 2),
           colorLow = "blue4", colorMid = "grey90", colorHigh = "firebrick1",
           midpoint = 0, varLabel = NULL, ncol = NULL, headers = NULL,
           decreasing = FALSE) {
    # Hand off to the registered method; the methods defined in this file
    # are keyed on the class of `x`.
    standardGeneric("plotDimReduceGrid")
  }
)


#' @rdname plotDimReduceGrid
#' @examples
#' data(sceCeldaCG)
#' sce <- celdaTsne(sceCeldaCG)
#' plotDimReduceGrid(x = sce,
#'   reducedDimName = "celda_tSNE",
#'   xlab = "Dimension1",
#'   ylab = "Dimension2",
#'   varLabel = "tSNE")
#' @export
setMethod(
  "plotDimReduceGrid",
  signature(x = "SingleCellExperiment"),
  function(x,
           reducedDimName,
           dim1 = NULL,
           dim2 = NULL,
           useAssay = "counts",
           altExpName = "featureSubset",
           size = 1,
           xlab = "Dimension_1",
           ylab = "Dimension_2",
           limits = c(-2, 2),
           colorLow = "blue4",
           colorMid = "grey90",
           colorHigh = "firebrick1",
           midpoint = 0,
           varLabel = NULL,
           ncol = NULL,
           headers = NULL,
           decreasing = FALSE) {
    # The reduced dimensions are read from the alternative experiment named
    # by `altExpName`; the values to plot are read from the `useAssay` assay
    # of `x` itself.
    subsetExp <- SingleCellExperiment::altExp(x, altExpName)
    values <- SummarizedExperiment::assay(x, i = useAssay)

    # Only consult the stored reduction for coordinates the caller left
    # unset, so `reducedDimName` is never touched when both are supplied.
    if (is.null(dim1) || is.null(dim2)) {
      coords <- SingleCellExperiment::reducedDim(subsetExp, reducedDimName)
      if (is.null(dim1)) {
        dim1 <- coords[, 1]
      }
      if (is.null(dim2)) {
        dim2 <- coords[, 2]
      }
    }

    .plotDimReduceGrid(
      dim1 = dim1, dim2 = dim2, matrix = values,
      size = size, xlab = xlab, ylab = ylab,
      limits = limits,
      colorLow = colorLow, colorMid = colorMid, colorHigh = colorHigh,
      midpoint = midpoint, varLabel = varLabel,
      ncol = ncol, headers = headers, decreasing = decreasing
    )
  }
)


#' @rdname plotDimReduceGrid
#' @examples
#' library(SingleCellExperiment)
#' data(sceCeldaCG)
#' sce <- celdaTsne(sceCeldaCG)
#' plotDimReduceGrid(x = counts(sce),
#'   dim1 = reducedDim(altExp(sce), "celda_tSNE")[, 1],
#'   dim2 = reducedDim(altExp(sce), "celda_tSNE")[, 2],
#'   xlab = "Dimension1",
#'   ylab = "Dimension2",
#'   varLabel = "tSNE")
#' @export
setMethod(
  "plotDimReduceGrid",
  signature(x = "ANY"),
  function(x,
           dim1,
           dim2,
           size = 1,
           xlab = "Dimension_1",
           ylab = "Dimension_2",
           limits = c(-2, 2),
           colorLow = "blue4",
           colorMid = "grey90",
           colorHigh = "firebrick1",
           midpoint = 0,
           varLabel = NULL,
           ncol = NULL,
           headers = NULL,
           decreasing = FALSE) {
    # Coerce whatever was supplied to a base matrix up front, then let the
    # shared internal plotting routine do the rest.
    values <- as.matrix(x)

    .plotDimReduceGrid(
      dim1 = dim1, dim2 = dim2, matrix = values,
      size = size, xlab = xlab, ylab = ylab,
      limits = limits,
      colorLow = colorLow, colorMid = colorMid, colorHigh = colorHigh,
      midpoint = midpoint, varLabel = varLabel,
      ncol = ncol, headers = headers, decreasing = decreasing
    )
  }
)


# Internal plotting routine used by the plotDimReduceGrid methods and by
# .plotDimReduceFeature. Draws one facet per row of `matrix`; within a facet
# every column of `matrix` is a point at (dim1, dim2) colored by its value on
# a diverging low/mid/high gradient.
#
# Arguments:
#   dim1, dim2  Coordinates aligned with the columns of `matrix`. Points where
#               either coordinate is NA are dropped.
#   matrix      Values to plot; each row becomes a facet.
#   size        Point size passed to ggplot2::geom_point().
#   xlab, ylab  Names given to the coordinate columns. They are parsed by
#               ggplot2::aes_string(), so they must be valid R expressions.
#   limits, colorLow, colorMid, colorHigh, midpoint
#               Passed to ggplot2::scale_colour_gradient2(). A NULL midpoint
#               is replaced by the 10 percent trimmed mean of the values.
#   varLabel    Legend title; underscores are shown as spaces. NULL is passed
#               through unchanged.
#   ncol        Number of facet columns, passed to ggplot2::facet_wrap().
#   headers     Optional facet labels, one per row of `matrix`, in row order.
#   decreasing  Sort direction of the values within each facet, which decides
#               the drawing order of the points. NULL keeps the input order.
#
# Returns a ggplot object. Stops if `headers` is given with a length that
# differs from the number of rows of `matrix`.
#' @importFrom reshape2 melt
.plotDimReduceGrid <- function(dim1,
                               dim2,
                               matrix,
                               size,
                               xlab,
                               ylab,
                               limits,
                               colorLow,
                               colorMid,
                               colorHigh,
                               midpoint,
                               varLabel,
                               ncol,
                               headers,
                               decreasing) {
  # Every row of `matrix` becomes one facet, so custom labels have to pair up
  # one-to-one with the rows. Checking here gives a readable error instead of
  # a failure (or silent mislabelling) inside `names<-` further down.
  if (!is.null(headers) && length(headers) != NROW(matrix)) {
    stop("'headers' should have one element for each row of the matrix (",
         NROW(matrix),
         ") but has length ",
         length(headers))
  }

  df <-
    data.frame(dim1, dim2, t(as.data.frame(matrix)), check.names = FALSE)
  naIx <- is.na(dim1) | is.na(dim2)
  df <- df[!naIx, ]

  # Long format: one row per (point, facet) pair.
  m <- reshape2::melt(df, id.vars = c("dim1", "dim2"))
  colnames(m) <- c(xlab, ylab, "facet", "Expression")

  if (!is.null(decreasing)) {
    m <- m[order(m$facet, m$Expression, decreasing = decreasing), ]
  }

  if (is.null(midpoint)) {
    # Ignore missing values so a single NA does not wipe out the midpoint.
    midpoint <- mean(m$Expression, trim = 0.1, na.rm = TRUE)
  }

  # gsub() would turn NULL into character(0); keep NULL so that the legend
  # title is simply omitted.
  if (!is.null(varLabel)) {
    varLabel <- gsub("_", " ", varLabel)
  }

  g <- ggplot2::ggplot(m, ggplot2::aes_string(x = xlab, y = ylab)) +
    ggplot2::geom_point(
      mapping = ggplot2::aes_string(color = "Expression"),
      stat = "identity",
      size = size
    ) +
    ggplot2::theme_bw() +
    ggplot2::scale_colour_gradient2(
      limits = limits,
      low = colorLow,
      high = colorHigh,
      mid = colorMid,
      midpoint = midpoint,
      name = varLabel
    ) +
    ggplot2::theme(
      strip.background = ggplot2::element_blank(),
      panel.grid.major = ggplot2::element_blank(),
      panel.grid.minor = ggplot2::element_blank(),
      panel.spacing = ggplot2::unit(0, "lines"),
      panel.background = ggplot2::element_blank(),
      axis.line = ggplot2::element_line(colour = "black")
    )

  # facet_wrap() treats ncol = NULL as "choose automatically", so `ncol` can
  # be forwarded as is; only the labeller depends on `headers`.
  if (is.null(headers)) {
    g <- g + ggplot2::facet_wrap(~ facet, ncol = ncol)
  } else {
    names(headers) <- levels(m$facet)
    g <- g + ggplot2::facet_wrap(~ facet,
                                 labeller = ggplot2::as_labeller(headers),
                                 ncol = ncol)
  }

  return(g)
}


#' @title Plotting feature expression on a dimension reduction plot
#' @description Create a scatterplot for each row of a normalized gene
#'  expression matrix where x and y axis are from a data dimension
#'  reduction tool. The cells are colored by expression of
#'  the specified feature.
#' @param x Numeric matrix or a \linkS4class{SingleCellExperiment} object
#'  with the matrix located in the assay slot under \code{useAssay}. Rows
#'  represent features and columns represent cells.
#' @param features Character vector. Features in the rownames of counts to plot.
#' @param reducedDimName The name of the dimension reduction slot in
#'  \code{reducedDimNames(x)} if \code{x} is a
#'  \linkS4class{SingleCellExperiment} object. If \code{NULL}, then both
#'  \code{dim1} and \code{dim2} need to be set. Default \code{NULL}.
#' @param displayName Character. The column name of
#'  \code{rowData(x)} that specifies the display names for
#'  the features. Default \code{NULL}, which displays the row names. Only works
#'  if \code{x} is a \linkS4class{SingleCellExperiment} object. Overwrites
#'  \code{headers}.
#' @param dim1 Integer or numeric vector. If \code{reducedDimName} is supplied,
#'  then, this will be used as an index to determine which dimension will be
#'  plotted on the x-axis. If \code{reducedDimName} is not supplied, then this
#'  should be a vector which will be plotted on the x-axis. Default \code{1}.
#' @param dim2 Integer or numeric vector. If \code{reducedDimName} is supplied,
#'  then, this will be used as an index to determine which dimension will be
#'  plotted on the y-axis. If \code{reducedDimName} is not supplied, then this
#'  should be a vector which will be plotted on the y-axis. Default \code{2}.
#' @param headers Character vector. If \code{NULL}, the corresponding
#'  rownames are used as labels. Otherwise, these headers are used to label
#'  the features. Only works if \code{displayName} is \code{NULL} and
#'  \code{exactMatch} is \code{TRUE}.
#' @param useAssay A string specifying which \link{assay}
#'  slot to use if \code{x} is a
#'  \linkS4class{SingleCellExperiment} object. Default "counts".
#' @param altExpName The name for the \link{altExp} slot
#'  to use. Default "featureSubset".
#' @param normalize Logical. Whether to normalize the columns of `counts`.
#'  Default \code{FALSE}.
#' @param zscore Logical. Whether to scale each feature to have a mean 0
#' and standard deviation of 1. Default \code{TRUE}.
#' @param exactMatch Logical. Whether an exact match or a partial match using
#'  \code{grep()} is used to look up the feature in the rownames of the counts
#'   matrix. Default TRUE.
#' @param trim Numeric vector. Vector of length two that specifies the lower
#'  and upper bounds for the data. This threshold is applied after row scaling.
#'  Set to NULL to disable. Default \code{c(-2, 2)}.
#' @param limits Passed to \link{scale_colour_gradient2}. The range
#'  of color scale.
#' @param size Numeric. Sets size of point on plot. Default 0.5.
#' @param xlab Character vector. Label for the x-axis. If \code{reducedDimName}
#' is used, then this will be set to the column name of the first dimension of
#' that object. Default "Dimension_1".
#' @param ylab Character vector. Label for the y-axis. If \code{reducedDimName}
#' is used, then this will be set to the column name of the second dimension of
#' that object. Default "Dimension_2".
#' @param colorLow Character. A color available from `colors()`. The color
#'  will be used to signify the lowest values on the scale.
#' @param colorMid Character. A color available from `colors()`. The color
#'  will be used to signify the midpoint on the scale.
#' @param colorHigh Character. A color available from `colors()`. The color
#'  will be used to signify the highest values on the scale.
#' @param midpoint Numeric. The value indicating the midpoint of the
#' diverging color scheme. If \code{NULL}, defaults to the mean
#' with 10 percent of values trimmed. Default \code{0}.
#' @param ncol Integer. Passed to \link[ggplot2]{facet_wrap}. Specify the
#'  number of columns for facet wrap.
#' @param decreasing logical. Specifies the order of plotting the points.
#'  If \code{FALSE}, the points will be plotted in increasing order where
#'  the points with largest values will be on top. \code{TRUE} otherwise.
#'  If \code{NULL}, no sorting is performed. Points will be plotted in their
#'  current order in \code{x}. Default \code{FALSE}.
#' @return The plot as a ggplot object
#' @export
setGeneric(
  "plotDimReduceFeature",
  function(x, features, reducedDimName = NULL, displayName = NULL,
           dim1 = NULL, dim2 = NULL, headers = NULL,
           useAssay = "counts", altExpName = "featureSubset",
           normalize = FALSE, zscore = TRUE, exactMatch = TRUE,
           trim = c(-2, 2), limits = c(-2, 2), size = 0.5,
           xlab = NULL, ylab = NULL,
           colorLow = "blue4", colorMid = "grey90", colorHigh = "firebrick1",
           midpoint = 0, ncol = NULL, decreasing = FALSE) {
    # Hand off to the registered method; the methods defined in this file
    # are keyed on the class of `x`.
    standardGeneric("plotDimReduceFeature")
  }
)


#' @rdname plotDimReduceFeature
#' @examples
#' data(sceCeldaCG)
#' sce <- celdaTsne(sceCeldaCG)
#' plotDimReduceFeature(x = sce,
#'   reducedDimName = "celda_tSNE",
#'   normalize = TRUE,
#'   features = c("Gene_98", "Gene_99"),
#'   exactMatch = TRUE)
#' @export
setMethod(
  "plotDimReduceFeature",
  signature(x = "SingleCellExperiment"),
  function(x,
           features,
           reducedDimName,
           displayName = NULL,
           dim1 = 1,
           dim2 = 2,
           headers = NULL,
           useAssay = "counts",
           altExpName = "featureSubset",
           normalize = FALSE,
           zscore = TRUE,
           exactMatch = TRUE,
           trim = c(-2, 2),
           limits = c(-2, 2),
           size = 0.5,
           xlab = NULL,
           ylab = NULL,
           colorLow = "blue4",
           colorMid = "grey90",
           colorHigh = "firebrick1",
           midpoint = 0,
           ncol = NULL,
           decreasing = FALSE) {
    # Coordinates and axis labels are resolved from the alternative
    # experiment; the values to plot come from the `useAssay` assay of `x`.
    altExp <- SingleCellExperiment::altExp(x, altExpName)
    counts <- SummarizedExperiment::assay(x, i = useAssay)
    reddim <- .processReducedDim(
      x = altExp,
      reducedDimName = reducedDimName,
      dim1 = dim1,
      dim2 = dim2,
      xlab = xlab,
      ylab = ylab
    )

    if (!is.null(displayName)) {
      # Features are looked up in the `displayName` column of rowData(x),
      # and that same column supplies the facet labels, replacing any
      # user-supplied `headers`.
      featuresIx <- retrieveFeatureIndex(features,
                                         x,
                                         by = displayName,
                                         exactMatch = exactMatch)
      displayNames <- SummarizedExperiment::rowData(x)[[displayName]]
      headers <- displayNames[featuresIx]
    } else {
      featuresIx <- retrieveFeatureIndex(features,
                                         counts,
                                         by = "rownames",
                                         exactMatch = exactMatch)

      if (!is.null(headers)) {
        if (length(headers) != length(features)) {
          # Collapse the vectors so the message stays readable instead of
          # running all elements together.
          stop("Headers ",
               paste(headers, collapse = ", "),
               " should be the same length as features ",
               paste(features, collapse = ", "))
        }

        # Headers are dropped (with a warning) when partial matching is
        # requested.
        if (isFALSE(exactMatch)) {
          warning("exactMatch is FALSE. headers will not be used!")
          headers <- NULL
        }
      }
    }

    g <- .plotDimReduceFeature(
      dim1 = reddim$dim1,
      dim2 = reddim$dim2,
      counts = counts,
      features = features,
      headers = headers,
      normalize = normalize,
      zscore = zscore,
      featuresIx = featuresIx,
      trim = trim,
      limits = limits,
      size = size,
      xlab = reddim$xlab,
      ylab = reddim$ylab,
      colorLow = colorLow,
      colorMid = colorMid,
      colorHigh = colorHigh,
      midpoint = midpoint,
      ncol = ncol,
      decreasing = decreasing
    )
    return(g)
  }
)


#' @rdname plotDimReduceFeature
#' @examples
#' library(SingleCellExperiment)
#' data(sceCeldaCG)
#' sce <- celdaTsne(sceCeldaCG)
#' plotDimReduceFeature(x = counts(sce),
#'   dim1 = reducedDim(altExp(sce), "celda_tSNE")[, 1],
#'   dim2 = reducedDim(altExp(sce), "celda_tSNE")[, 2],
#'   normalize = TRUE,
#'   features = c("Gene_98", "Gene_99"),
#'   exactMatch = TRUE)
#' @export
setMethod(
  "plotDimReduceFeature",
  signature(x = "ANY"),
  function(x,
           features,
           dim1,
           dim2,
           headers = NULL,
           normalize = FALSE,
           zscore = TRUE,
           exactMatch = TRUE,
           trim = c(-2, 2),
           limits = c(-2, 2),
           size = 0.5,
           xlab = "Dimension_1",
           ylab = "Dimension_2",
           colorLow = "blue4",
           colorMid = "grey90",
           colorHigh = "firebrick1",
           midpoint = 0,
           ncol = NULL,
           decreasing = FALSE) {
    x <- as.matrix(x)

    if (!is.null(headers)) {
      if (length(headers) != length(features)) {
        # Collapse the vectors so the message stays readable instead of
        # running all elements together.
        stop("Headers ",
             paste(headers, collapse = ", "),
             " should be the same length as features ",
             paste(features, collapse = ", "))
      }

      # Headers are dropped (with a warning) when partial matching is
      # requested.
      if (isFALSE(exactMatch)) {
        warning("exactMatch is FALSE. headers will not be used!")
        headers <- NULL
      }
    }

    # Features are looked up in the row names of the matrix.
    featuresIx <- retrieveFeatureIndex(features,
                                       x,
                                       by = "rownames",
                                       exactMatch = exactMatch)

    g <- .plotDimReduceFeature(
      dim1 = dim1,
      dim2 = dim2,
      counts = x,
      features = features,
      headers = headers,
      normalize = normalize,
      zscore = zscore,
      featuresIx = featuresIx,
      trim = trim,
      limits = limits,
      size = size,
      xlab = xlab,
      ylab = ylab,
      colorLow = colorLow,
      colorMid = colorMid,
      colorHigh = colorHigh,
      midpoint = midpoint,
      ncol = ncol,
      decreasing = decreasing
    )
    return(g)
  }
)


# Internal routine behind the plotDimReduceFeature methods. Prepares the
# rows of `counts` selected by `featuresIx` and hands them to
# .plotDimReduceGrid for plotting.
#
# Arguments:
#   dim1, dim2   Coordinates forwarded to .plotDimReduceGrid.
#   counts       Matrix of values; rows are selected with `featuresIx`.
#   features     Requested features; only checked for being non-NULL here.
#   headers      Facet labels forwarded to .plotDimReduceGrid.
#   normalize    If TRUE, `counts` is passed through normalizeCounts() with
#                `transformationFun = sqrt` before rows are selected.
#   zscore       If TRUE, every selected row is centered and scaled with
#                scale().
#   featuresIx   Row indices into `counts`; NA entries are dropped.
#   trim         NULL, or two bounds to which the values are clamped after
#                scaling.
#   limits, size, xlab, ylab, colorLow, colorMid, colorHigh, midpoint, ncol,
#   decreasing   Forwarded unchanged to .plotDimReduceGrid.
#
# Returns the value of .plotDimReduceGrid. Stops when `features` is NULL or
# when `trim` is neither NULL nor of length two.
.plotDimReduceFeature <- function(dim1,
                                  dim2,
                                  counts,
                                  features,
                                  headers,
                                  normalize,
                                  zscore,
                                  featuresIx,
                                  trim,
                                  limits,
                                  size,
                                  xlab,
                                  ylab,
                                  colorLow,
                                  colorMid,
                                  colorHigh,
                                  midpoint,
                                  ncol,
                                  decreasing) {
  # Perform checks
  if (is.null(features)) {
    stop("at least one feature is required to create a plot")
  }

  # Normalize on the full matrix, before any rows are dropped.
  if (isTRUE(normalize)) {
    counts <- normalizeCounts(counts, transformationFun = sqrt)
  }

  # After normalization, features can be selected. Indices that are NA are
  # discarded first.
  featuresIx <- featuresIx[stats::complete.cases(featuresIx)]
  counts <- as.matrix(counts[featuresIx, , drop = FALSE])

  # Scale/zscore data if needed. scale() works on columns, hence the double
  # transpose to standardize each row.
  varLabel <- "Expression"
  if (isTRUE(zscore)) {
    counts <- t(scale(t(counts)))
    varLabel <- "Scaled\nExpression"
  }

  if (!is.null(trim)) {
    if (length(trim) != 2) {
      stop("'trim' should be a 2 element vector ",
           "specifying the lower and upper boundaries")
    }
    # Accept the bounds in either order, then clamp the values to them.
    trim <- sort(trim)
    counts[counts < trim[1]] <- trim[1]
    counts[counts > trim[2]] <- trim[2]
  }

  .plotDimReduceGrid(
    dim1 = dim1,
    dim2 = dim2,
    matrix = counts,
    size = size,
    xlab = xlab,
    ylab = ylab,
    limits = limits,
    colorLow = colorLow,
    colorMid = colorMid,
    colorHigh = colorHigh,
    varLabel = varLabel,
    midpoint = midpoint,
    ncol = ncol,
    headers = headers,
    decreasing = decreasing
  )
}


#' @title Plotting Celda module probability on a
#'  dimension reduction plot
#' @description Create a scatterplot for each row of a normalized
#'  gene expression matrix where x and y axis are from a data
#'  dimension reduction tool.
#'  The cells are colored by the module probability.
#' @param x Numeric matrix or a \linkS4class{SingleCellExperiment} object
#'  with the matrix located in the assay slot under \code{useAssay}. Rows
#'  represent features and columns represent cells.
#' @param reducedDimName The name of the dimension reduction slot in
#'  \code{reducedDimNames(x)} if \code{x} is a
#'  \linkS4class{SingleCellExperiment} object. Ignored if both \code{dim1} and
#'  \code{dim2} are set.
#' @param dim1 Integer or numeric vector. If \code{reducedDimName} is supplied,
#'  then, this will be used as an index to determine which dimension will be
#'  plotted on the x-axis. If \code{reducedDimName} is not supplied, then this
#'  should be a vector which will be plotted on the x-axis. Default \code{1}.
#' @param dim2 Integer or numeric vector. If \code{reducedDimName} is supplied,
#'  then, this will be used as an index to determine which dimension will be
#'  plotted on the y-axis. If \code{reducedDimName} is not supplied, then this
#'  should be a vector which will be plotted on the y-axis. Default \code{2}.
#' @param useAssay A string specifying which \link{assay}
#'  slot to use if \code{x} is a
#'  \linkS4class{SingleCellExperiment} object. Default "counts".
#' @param altExpName The name for the \link{altExp} slot
#'  to use. Default "featureSubset".
#' @param celdaMod Celda object of class "celda_G" or "celda_CG". Used only if
#'  \code{x} is a matrix object.
#' @param modules Character vector. Module(s) from celda model to be plotted.
#'  e.g. c("1", "2").
#' @param size Numeric. Sets size of point on plot. Default 0.5.
#' @param xlab Character vector. Label for the x-axis. Default "Dimension_1".
#' @param ylab Character vector. Label for the y-axis. Default "Dimension_2".
#' @param colorLow Character. A color available from `colors()`.
#'  The color will be used to signify the lowest values on the scale.
#' @param rescale Logical.
#'  Whether rows of the matrix should be rescaled to [0, 1]. Default TRUE.
#' @param limits Passed to \link{scale_colour_gradient}. The range
#'  of color scale.
#' @param colorHigh Character. A color available from `colors()`.
#'  The color will be used to signify the highest values on the scale.
#' @param ncol Integer. Passed to \link[ggplot2]{facet_wrap}. Specify the
#'  number of columns for facet wrap.
#' @param decreasing logical. Specifies the order of plotting the points.
#'  If \code{FALSE}, the points will be plotted in increasing order where
#'  the points with largest values will be on top. \code{TRUE} otherwise.
#'  If \code{NULL}, no sorting is performed. Points will be plotted in their
#'  current order in \code{x}. Default \code{FALSE}.
#' @return The plot as a ggplot object
#' @export
setGeneric(
  "plotDimReduceModule",
  function(x, reducedDimName,
           useAssay = "counts", altExpName = "featureSubset",
           celdaMod, modules = NULL,
           dim1 = NULL, dim2 = NULL,
           size = 0.5, xlab = NULL, ylab = NULL,
           rescale = TRUE, limits = c(0, 1),
           colorLow = "grey90", colorHigh = "firebrick1",
           ncol = NULL, decreasing = FALSE) {
    # Hand off to the registered method; the methods defined in this file
    # are keyed on the class of `x`.
    standardGeneric("plotDimReduceModule")
  }
)


#' @rdname plotDimReduceModule
#' @examples
#' data(sceCeldaCG)
#' sce <- celdaTsne(sceCeldaCG)
#' plotDimReduceModule(x = sce,
#'   reducedDimName = "celda_tSNE",
#'   modules = c("1", "2"))
#' @export
setMethod(
  "plotDimReduceModule",
  signature(x = "SingleCellExperiment"),
  function(x,
           reducedDimName,
           useAssay = "counts",
           altExpName = "featureSubset",
           modules = NULL,
           dim1 = 1,
           dim2 = 2,
           size = 0.5,
           xlab = NULL,
           ylab = NULL,
           rescale = TRUE,
           limits = c(0, 1),
           colorLow = "grey90",
           colorHigh = "firebrick1",
           ncol = NULL,
           decreasing = FALSE) {
    # Resolve the plotting coordinates and axis labels from the reduced
    # dimension stored in the alternative experiment.
    subsetExp <- SingleCellExperiment::altExp(x, altExpName)
    coords <- .processReducedDim(
      x = subsetExp,
      reducedDimName = reducedDimName,
      dim1 = dim1,
      dim2 = dim2,
      xlab = xlab,
      ylab = ylab
    )

    # Factorize `x`, requesting only the "proportion" matrices.
    fm <- factorizeMatrix(
      x,
      useAssay = useAssay,
      altExpName = altExpName,
      type = "proportion"
    )

    .plotDimReduceModule(
      dim1 = coords$dim1, dim2 = coords$dim2,
      factorized = fm,
      modules = modules,
      rescale = rescale,
      limits = limits,
      size = size,
      xlab = coords$xlab, ylab = coords$ylab,
      colorLow = colorLow, colorHigh = colorHigh,
      ncol = ncol,
      decreasing = decreasing
    )
  }
)


#' @rdname plotDimReduceModule
#' @examples
#' library(SingleCellExperiment)
#' data(sceCeldaCG, celdaCGMod)
#' sce <- celdaTsne(sceCeldaCG)
#' plotDimReduceModule(x = counts(sce),
#'   dim1 = reducedDim(altExp(sce), "celda_tSNE")[, 1],
#'   dim2 = reducedDim(altExp(sce), "celda_tSNE")[, 2],
#'   celdaMod = celdaCGMod,
#'   modules = c("1", "2"))
#' @export
setMethod(
  "plotDimReduceModule",
  signature(x = "ANY"),
  function(x,
           celdaMod,
           modules = NULL,
           dim1,
           dim2,
           size = 0.5,
           xlab = "Dimension_1",
           ylab = "Dimension_2",
           rescale = TRUE,
           limits = c(0, 1),
           colorLow = "grey90",
           colorHigh = "firebrick1",
           ncol = NULL,
           decreasing = FALSE) {
    # Factorize the matrix with the supplied celda model first, then resolve
    # the coordinates and axis labels (no reducedDimName is passed here).
    fm <- factorizeMatrix(x = x, celdaMod = celdaMod)
    coords <- .processReducedDim(
      x = x,
      dim1 = dim1,
      dim2 = dim2,
      xlab = xlab,
      ylab = ylab
    )

    .plotDimReduceModule(
      dim1 = coords$dim1, dim2 = coords$dim2,
      factorized = fm,
      modules = modules,
      rescale = rescale,
      limits = limits,
      size = size,
      xlab = coords$xlab, ylab = coords$ylab,
      colorLow = colorLow, colorHigh = colorHigh,
      ncol = ncol,
      decreasing = decreasing
    )
  }
)


# Internal plotting engine behind plotDimReduceModule().
#
# Draws one facet per module, placing every cell at (dim1, dim2) and coloring
# it by the module's value taken from 'factorized$proportions$cell'.
#
# dim1, dim2   Coordinate vectors with one entry per column of the matrix.
# factorized   List holding a modules-by-cells matrix in
#              '$proportions$cell'.
# modules      Module labels to plot, compared against the matrix row names
#              after every "L" has been removed from them. NULL plots all
#              modules. Facets follow the order given here.
# rescale      If TRUE, every module (row) is min-max scaled before plotting.
# limits       Limits passed to ggplot2::scale_colour_gradient().
# size         Point size.
# xlab, ylab   Axis labels; also used as column names of the plotted data.
# colorLow, colorHigh  Colors for the two ends of the gradient.
# ncol         Number of facet columns, or NULL for the facet_wrap default.
# decreasing   Passed to order() to decide which points are drawn last;
#              NULL leaves the rows in their original order.
#
# Returns a ggplot object. Stops if none of 'modules' exists in the matrix.
.plotDimReduceModule <- function(dim1,
                                 dim2,
                                 factorized,
                                 modules,
                                 rescale,
                                 limits,
                                 size,
                                 xlab,
                                 ylab,
                                 colorLow,
                                 colorHigh,
                                 ncol,
                                 decreasing) {
  moduleProbs <- factorized$proportions$cell

  if (isTRUE(rescale)) {
    # A vector of length nrow() recycles down the columns, so element i is
    # applied to every entry of row i.
    rowMin <- apply(moduleProbs, 1, min)
    moduleProbs <- moduleProbs - rowMin
    rowMax <- apply(moduleProbs, 1, max)
    # A constant row has range 0; keep it at 0 instead of producing 0 / 0.
    rowMax[which(rowMax == 0)] <- 1
    moduleProbs <- moduleProbs / rowMax
    varLabel <- "Scaled Probability"
  } else {
    varLabel <- "Probability"
  }

  rownames(moduleProbs) <- gsub("L", "", rownames(moduleProbs))
  if (!is.null(modules)) {
    # Keep only requested modules that exist, in the order they were asked
    # for. match() must look up the requested labels in the row names (not
    # the other way round) to yield valid, correctly ordered row indices.
    requested <- unique(modules)
    requested <- requested[requested %in% rownames(moduleProbs)]
    if (length(requested) < 1) {
      stop("All modules selected do not exist in the model.")
    }
    moduleProbs <-
      moduleProbs[match(requested, rownames(moduleProbs)), , drop = FALSE]
  }

  rownames(moduleProbs) <- paste0("L", rownames(moduleProbs))

  df <-
    data.frame(dim1, dim2, t(as.data.frame(moduleProbs)), check.names = FALSE)
  # Cells without a complete pair of coordinates cannot be drawn
  naIx <- is.na(dim1) | is.na(dim2)
  df <- df[!naIx, ]

  m <- reshape2::melt(df, id.vars = c("dim1", "dim2"))
  colnames(m) <- c(xlab, ylab, "facet", "Expression")

  if (!is.null(decreasing)) {
    # Row order controls the drawing order of overlapping points
    m <- m[order(m$facet, m$Expression, decreasing = decreasing), ]
  }

  g <-
    ggplot2::ggplot(m, ggplot2::aes_string(x = xlab, y = ylab)) +
    ggplot2::geom_point(stat = "identity",
                        size = size,
                        ggplot2::aes_string(color = "Expression")) +
    # 'ncol = NULL' is the facet_wrap() default, so a single call covers
    # both the user-specified and the automatic layout.
    ggplot2::facet_wrap(~ facet, ncol = ncol) +
    ggplot2::theme_bw() +
    ggplot2::scale_colour_gradient(
      limits = limits,
      low = colorLow,
      high = colorHigh,
      name = varLabel
    ) +
    ggplot2::theme(
      strip.background = ggplot2::element_blank(),
      panel.grid.major = ggplot2::element_blank(),
      panel.grid.minor = ggplot2::element_blank(),
      panel.spacing = grid::unit(0, "lines"),
      panel.background = ggplot2::element_blank(),
      axis.line = ggplot2::element_line(colour = "black")
    )

  return(g)
}


# Labeling code adapted from Seurat (https://github.com/satijalab/seurat)
#' @title Plotting the cell labels on a dimension reduction plot
#' @description Create a scatterplot of cells where the x and y axes are
#'  from a data dimension reduction tool.
#'  The cells are colored by "celda_cell_cluster" column in
#'   \code{colData(altExp(x, altExpName))} if \code{x} is a
#'   \linkS4class{SingleCellExperiment} object, or \code{x} if \code{x} is
#'   an integer vector of cell cluster labels.
#' @param x Integer vector of cell cluster labels or a
#'  \linkS4class{SingleCellExperiment} object
#'  containing cluster labels for each cell in \code{"celda_cell_cluster"}
#'  column in \code{colData(altExp(x, altExpName))}.
#' @param reducedDimName The name of the dimension reduction slot in
#'  \code{reducedDimNames(x)} if \code{x} is a
#'  \linkS4class{SingleCellExperiment} object. Ignored if both \code{dim1} and
#'  \code{dim2} are set.
#' @param altExpName The name for the \link{altExp} slot
#'  to use. Default "featureSubset".
#' @param dim1 Integer or numeric vector. If \code{reducedDimName} is supplied,
#'  then, this will be used as an index to determine which dimension will be
#'  plotted on the x-axis. If \code{reducedDimName} is not supplied, then this
#'  should be a vector which will be plotted on the x-axis. Default \code{1}.
#' @param dim2 Integer or numeric vector. If \code{reducedDimName} is supplied,
#'  then, this will be used as an index to determine which dimension will be
#'  plotted on the y-axis. If \code{reducedDimName} is not supplied, then this
#'  should be a vector which will be plotted on the y-axis. Default \code{2}.
#' @param size Numeric. Sets size of point on plot. Default \code{0.5}.
#' @param xlab Character vector. Label for the x-axis. Default \code{NULL}.
#' @param ylab Character vector. Label for the y-axis. Default \code{NULL}.
#' @param specificClusters Numeric vector.
#'  Only color cells in the specified clusters.
#'  All other cells will be grey.
#'  If NULL, all clusters will be colored. Default \code{NULL}.
#' @param labelClusters Logical. Whether the cluster labels are plotted.
#'  Default FALSE.
#' @param groupBy Character vector. Contains sample labels for each cell.
#'  If NULL, all samples will be plotted together. Default NULL.
#' @param labelSize Numeric. Sets size of label if labelClusters is TRUE.
#'  Default 3.5.
#' @return The plot as a ggplot object
#' @importFrom ggrepel geom_text_repel
#' @export
setGeneric("plotDimReduceCluster",
           function(x, reducedDimName, altExpName = "featureSubset",
                    dim1 = NULL, dim2 = NULL, size = 0.5,
                    xlab = NULL, ylab = NULL,
                    specificClusters = NULL, labelClusters = FALSE,
                    groupBy = NULL, labelSize = 3.5) {
             # No work is done here: control passes to whichever method
             # S4 dispatch selects for the supplied arguments.
             standardGeneric("plotDimReduceCluster")
           })


#' @rdname plotDimReduceCluster
#' @examples
#' data(sceCeldaCG)
#' sce <- celdaTsne(sceCeldaCG)
#' plotDimReduceCluster(x = sce,
#'   reducedDimName = "celda_tSNE",
#'   specificClusters = c(1, 2, 3))
#' @export
setMethod("plotDimReduceCluster",
          signature(x = "SingleCellExperiment"),
          function(x, reducedDimName, altExpName = "featureSubset",
                   dim1 = 1, dim2 = 2, size = 0.5,
                   xlab = NULL, ylab = NULL,
                   specificClusters = NULL, labelClusters = FALSE,
                   groupBy = NULL, labelSize = 3.5) {
            # Both the cluster labels and the reduced dimensions are read
            # from the alternative experiment.
            altSce <- SingleCellExperiment::altExp(x, altExpName)
            cellData <- SummarizedExperiment::colData(altSce)

            if (!("celda_cell_cluster" %in% colnames(cellData))) {
              stop("Must have column 'celda_cell_cluster' in",
                   " colData(altExp(x, altExpName))!")
            }
            clusterLabels <- cellData[["celda_cell_cluster"]]

            # Pull the requested columns of the reduced dimension matrix
            # together with their axis labels.
            coords <- .processReducedDim(x = altSce,
                                         reducedDimName = reducedDimName,
                                         dim1 = dim1,
                                         dim2 = dim2,
                                         xlab = xlab,
                                         ylab = ylab)

            .plotDimReduceCluster(dim1 = coords$dim1,
                                  dim2 = coords$dim2,
                                  cluster = clusterLabels,
                                  size = size,
                                  xlab = coords$xlab,
                                  ylab = coords$ylab,
                                  specificClusters = specificClusters,
                                  labelClusters = labelClusters,
                                  groupBy = groupBy,
                                  labelSize = labelSize)
          })


#' @rdname plotDimReduceCluster
#' @examples
#' library(SingleCellExperiment)
#' data(sceCeldaCG, celdaCGMod)
#' sce <- celdaTsne(sceCeldaCG)
#' plotDimReduceCluster(x = celdaClusters(celdaCGMod)$z,
#'   dim1 = reducedDim(altExp(sce), "celda_tSNE")[, 1],
#'   dim2 = reducedDim(altExp(sce), "celda_tSNE")[, 2],
#'   specificClusters = c(1, 2, 3))
#' @export
setMethod("plotDimReduceCluster",
          signature(x = "vector"),
          function(x, dim1, dim2, size = 0.5,
                   xlab = "Dimension_1", ylab = "Dimension_2",
                   specificClusters = NULL, labelClusters = FALSE,
                   groupBy = NULL, labelSize = 3.5) {
            # Here 'x' is itself the vector of cluster labels; the helper
            # checks that the coordinate vectors have the same length.
            coords <- .processReducedDim(x = x,
                                         dim1 = dim1,
                                         dim2 = dim2,
                                         xlab = xlab,
                                         ylab = ylab)

            .plotDimReduceCluster(dim1 = coords$dim1,
                                  dim2 = coords$dim2,
                                  cluster = x,
                                  size = size,
                                  xlab = coords$xlab,
                                  ylab = coords$ylab,
                                  specificClusters = specificClusters,
                                  labelClusters = labelClusters,
                                  groupBy = groupBy,
                                  labelSize = labelSize)
          })


# Internal plotting engine behind plotDimReduceCluster().
#
# Draws every cell at (dim1, dim2), colored by its cluster label.
#
# dim1, dim2        Coordinate vectors, one entry per cell.
# cluster           Cluster label of every cell.
# size              Point size.
# xlab, ylab        Axis labels; also used as column names of the plotted
#                   data.
# specificClusters  Clusters to keep in color; all other clusters are drawn
#                   in "gray92". NULL colors every cluster.
# labelClusters     If TRUE, the cluster label is printed at the median
#                   position of the cluster's cells.
# groupBy           Optional vector with one entry per cell; when given, the
#                   plot is faceted by it (column "Sample").
# labelSize         Text size of the cluster labels.
#
# Returns a ggplot object.
.plotDimReduceCluster <- function(dim1,
                                  dim2,
                                  cluster,
                                  size,
                                  xlab,
                                  ylab,
                                  specificClusters,
                                  labelClusters,
                                  groupBy,
                                  labelSize) {
  if (!is.null(groupBy)) {
    df <- data.frame(dim1, dim2, cluster, groupBy)
    colnames(df) <- c(xlab, ylab, "Cluster", "Sample")
  } else {
    df <- data.frame(dim1, dim2, cluster)
    colnames(df) <- c(xlab, ylab, "Cluster")
  }

  # Cells without a complete pair of coordinates cannot be drawn
  naIx <- is.na(dim1) | is.na(dim2)
  df <- df[!naIx, ]
  df[3] <- as.factor(df[[3]])

  # Build the palette and the grey-out mask from the same set of levels and
  # name the colors. Named values are matched to clusters by label, so a
  # cluster keeps its color even if another cluster lost all of its cells
  # to the NA filter above.
  allLevels <- levels(as.factor(cluster))
  clusterColors <- distinctColors(length(allLevels))
  names(clusterColors) <- allLevels

  if (!is.null(specificClusters)) {
    clusterColors[!allLevels %in% specificClusters] <- "gray92"
  }

  g <-
    ggplot2::ggplot(df, ggplot2::aes_string(x = xlab, y = ylab)) +
    ggplot2::geom_point(stat = "identity",
                        size = size,
                        ggplot2::aes_string(color = "Cluster")) +
    ggplot2::theme(
      panel.grid.major = ggplot2::element_blank(),
      panel.grid.minor = ggplot2::element_blank(),
      panel.background = ggplot2::element_blank(),
      axis.line = ggplot2::element_line(color = "black")
    ) +
    ggplot2::scale_color_manual(values = clusterColors) +
    ggplot2::guides(color =
                      ggplot2::guide_legend(override.aes = list(size = 1)))

  if (isTRUE(labelClusters)) {
    # Only clusters that still have plotted cells can be given a centroid
    shownClusters <- unique(cluster[!naIx])
    shownClusters <- shownClusters[!is.na(shownClusters)]

    if (length(shownClusters) > 0) {
      centroidList <- lapply(shownClusters, function(cl) {
        inCluster <- df[df$Cluster == cl, ]
        data.frame(median1 = stats::median(inCluster[, xlab]),
                   median2 = stats::median(inCluster[, ylab]),
                   x = cl)
      })
      centroid <- do.call(rbind, centroidList)
      colnames(centroid) <- c(xlab, ylab, "Cluster")

      # The invisible points give ggrepel something to push labels away from
      g <- g + ggplot2::geom_point(
        data = centroid,
        mapping = ggplot2::aes_string(x = xlab,
                                      y = ylab),
        size = 0,
        alpha = 0
      ) +
        ggrepel::geom_text_repel(
          data = centroid,
          mapping = ggplot2::aes_string(label = "Cluster"),
          size = labelSize,
          max.overlaps = Inf
        )
    }
  }

  if (!is.null(groupBy)) {
    g <- g + ggplot2::facet_wrap(~ Sample) +
      ggplot2::theme(strip.background = ggplot2::element_blank())
  }
  return(g)
}


#' @title Feature Expression Violin Plot
#' @description Outputs a violin plot for feature expression data.
#' @param x Numeric matrix or a \linkS4class{SingleCellExperiment} object
#'  with the matrix located in the assay slot under \code{useAssay}. Rows
#'  represent features and columns represent cells.
#' @param features Character vector. Uses these genes for plotting.
#' @param displayName Character. The column name of
#'  \code{rowData(x)} that specifies the display names for
#'  the features. Default \code{NULL}, which displays the row names. Only works
#'  if \code{x} is a \linkS4class{SingleCellExperiment} object.
#' @param useAssay A string specifying which \link{assay}
#'  slot to use if \code{x} is a
#'  \linkS4class{SingleCellExperiment} object. Default "counts".
#' @param altExpName The name for the \link{altExp} slot
#'  to use. Default "featureSubset".
#' @param celdaMod Celda object of class "celda_G" or "celda_CG". Used only if
#'  \code{x} is a matrix object.
#' @param exactMatch Logical. Whether an exact match or a partial match using
#'  \code{grep()} is used to look up the feature in the rownames of the counts
#'   matrix. Default \code{TRUE}.
#' @param plotDots Boolean. If \code{TRUE}, the
#'  expression of features will be plotted as points in addition to the violin
#'  curve. Default \code{TRUE}.
#' @param dotSize Numeric. Size of points if \code{plotDots = TRUE}.
#' Default \code{0.1}.
#' @return Violin plot for each feature, grouped by celda cluster
#' @export
setGeneric("plotCeldaViolin",
           function(x, celdaMod, features, displayName = NULL,
                    useAssay = "counts", altExpName = "featureSubset",
                    exactMatch = TRUE, plotDots = TRUE, dotSize = 0.1) {
             # No work is done here: control passes to whichever method
             # S4 dispatch selects for the supplied arguments.
             standardGeneric("plotCeldaViolin")
           })


#' @rdname plotCeldaViolin
#' @examples
#' data(sceCeldaCG)
#' plotCeldaViolin(x = sceCeldaCG, features = "Gene_1")
#' @export
setMethod("plotCeldaViolin",
          signature(x = "SingleCellExperiment"),
          function(x, features, displayName = NULL,
                   useAssay = "counts", altExpName = "featureSubset",
                   exactMatch = TRUE, plotDots = TRUE, dotSize = 0.1) {
            exprs <- SummarizedExperiment::assay(x, i = useAssay)
            clusterLabels <-
              as.integer(celdaClusters(x, altExpName = altExpName))

            # Features are looked up either in the row names of the assay
            # or in the 'displayName' column of rowData(x); the facet
            # titles come from the same source.
            useRowNames <- is.null(displayName)
            if (useRowNames) {
              featuresIx <- retrieveFeatureIndex(features,
                                                 exprs,
                                                 by = "rownames",
                                                 exactMatch = exactMatch)
              facetNames <- rownames(x)[featuresIx]
            } else {
              featuresIx <- retrieveFeatureIndex(features,
                                                 x,
                                                 by = displayName,
                                                 exactMatch = exactMatch)
              facetNames <-
                SummarizedExperiment::rowData(x)[featuresIx, displayName]
            }

            .plotCeldaViolin(counts = exprs,
                             cluster = clusterLabels,
                             features = features,
                             featuresIx = featuresIx,
                             rnames = facetNames,
                             exactMatch = exactMatch,
                             plotDots = plotDots,
                             dotSize = dotSize)
          })


#' @rdname plotCeldaViolin
#' @examples
#' data(celdaCGSim, celdaCGMod)
#' plotCeldaViolin(x = celdaCGSim$counts,
#'    celdaMod = celdaCGMod,
#'    features = "Gene_1")
#' @export
setMethod("plotCeldaViolin",
          signature(x = "ANY"),
          function(x, celdaMod, features,
                   exactMatch = TRUE, plotDots = TRUE, dotSize = 0.1) {
            # Work on a base matrix; cluster labels are the 'z' element
            # of the model's cluster assignments.
            countsMat <- as.matrix(x)
            clusterLabels <- celdaClusters(celdaMod)$z

            # Locate the requested features among the row names
            featuresIx <- retrieveFeatureIndex(features,
                                               countsMat,
                                               by = "rownames",
                                               exactMatch = exactMatch)
            facetNames <- rownames(countsMat)[featuresIx]

            .plotCeldaViolin(counts = countsMat,
                             cluster = clusterLabels,
                             features = features,
                             featuresIx = featuresIx,
                             rnames = facetNames,
                             exactMatch = exactMatch,
                             plotDots = plotDots,
                             dotSize = dotSize)
          })


# Internal plotting engine behind plotCeldaViolin().
#
# Builds one violin facet per selected feature, with one violin per cluster.
# 'featuresIx' selects the rows of 'counts' to plot and 'rnames' supplies
# the facet titles for those rows. 'features' and 'exactMatch' are accepted
# for interface symmetry with the callers but are not used in the body.
# Returns a ggplot object.
.plotCeldaViolin <- function(counts,
                             cluster,
                             features,
                             featuresIx,
                             rnames,
                             exactMatch = TRUE,
                             plotDots = TRUE,
                             dotSize = 0.1) {
  # Cells-by-features table of the selected rows, titled with 'rnames'
  selected <- as.matrix(counts[featuresIx, , drop = FALSE])
  rownames(selected) <- rnames
  wide <- cbind(cluster, as.data.frame(t(selected)))
  wide$cluster <- as.factor(wide$cluster)

  # Long format: one row per (cell, feature) pair
  long <- reshape2::melt(wide, id.vars = c("cluster"))
  colnames(long) <- c("Cluster", "Feature", "Expression")

  # One fill color per distinct cluster label
  nClusters <- length(unique(cluster))
  fillColors <- distinctColors(nClusters)

  violinAes <- ggplot2::aes_string(x = "Cluster",
                                   y = "Expression",
                                   fill = "Cluster")
  bareTheme <- ggplot2::theme(
    strip.background = ggplot2::element_blank(),
    panel.grid.major = ggplot2::element_blank(),
    panel.grid.minor = ggplot2::element_blank(),
    panel.spacing = grid::unit(0, "lines"),
    panel.background = ggplot2::element_blank(),
    axis.line = ggplot2::element_line(colour = "black")
  )

  p <- ggplot2::ggplot(long, violinAes) +
    ggplot2::facet_wrap(~ Feature) +
    ggplot2::geom_violin(trim = TRUE, scale = "width") +
    ggplot2::scale_fill_manual(values = fillColors) +
    bareTheme

  # Optionally overlay the individual cells, jittered horizontally only
  if (isTRUE(plotDots)) {
    p <- p + ggplot2::geom_jitter(height = 0, size = dotSize)
  }

  return(p)
}




# Resolve the x/y coordinates and axis labels used by the dimension
# reduction plotting functions.
#
# Two calling conventions are supported:
#  1. 'x' inherits from SingleCellExperiment and 'reducedDimName' is given:
#     'dim1' and 'dim2' (defaulting to 1 and 2) index the columns of
#     reducedDim(x, reducedDimName). The axis labels are taken from the
#     column names of that matrix; any 'xlab'/'ylab' passed in are replaced.
#  2. Otherwise, both 'dim1' and 'dim2' are given: they are used as the
#     coordinate vectors themselves. Their length must equal ncol(x) when
#     'x' inherits from "matrix" or "Matrix", and length(x) otherwise.
#     Missing labels fall back to "Dimension 1" / "Dimension 2".
#
# Returns a list with elements 'dim1', 'dim2', 'xlab' and 'ylab'.
# Stops with an explanatory message if neither convention applies or the
# lengths disagree.
.processReducedDim <- function(x,
                               reducedDimName = NULL,
                               dim1 = NULL,
                               dim2 = NULL,
                               xlab = NULL,
                               ylab = NULL) {
  # Scalar conditions: use the short-circuiting operators
  if (inherits(x, "SingleCellExperiment") && !is.null(reducedDimName)) {
    reddim <- SingleCellExperiment::reducedDim(x, reducedDimName)

    # Get dims to retrieve from redDim object
    if (is.null(dim1)) {
      dim1 <- 1
    }
    if (is.null(dim2)) {
      dim2 <- 2
    }

    res <- list(
      dim1 = reddim[, dim1],
      dim2 = reddim[, dim2],
      xlab = colnames(reddim)[dim1],
      ylab = colnames(reddim)[dim2]
    )

  } else if (!is.null(dim1) && !is.null(dim2)) {
    # Number of cells the coordinate vectors have to match
    if (inherits(x, c("matrix", "Matrix"))) {
      nExpected <- ncol(x)
    } else {
      nExpected <- length(x)
    }
    if (length(dim1) != nExpected) {
      stop("'dim1' needs to be the same length as 'x'.")
    }
    if (length(dim2) != nExpected) {
      stop("'dim2' needs to be the same length as 'x'.")
    }

    if (is.null(xlab)) {
      xlab <- "Dimension 1"
    }
    if (is.null(ylab)) {
      ylab <- "Dimension 2"
    }

    res <- list(
      dim1 = dim1,
      dim2 = dim2,
      xlab = xlab,
      ylab = ylab
    )

  } else {
    stop(
    "'x' can be supplied as a SingleCellExperiment along with ",
    "'reducedDimName' and 'dim1' and 'dim2' can be used to specify which ",
    "dimensions to plot on the x- and y-axis, respectively. Alternatively, ",
    "'x', 'dim1' and 'dim2' can be supplied as vectors of the same length ",
    "where 'dim1' is the x-axis, 'dim2', is the y-axis, and 'x' will be used ",
    "to color the points."
    )
  }

  return(res)
}