#' @title Calculate perplexity of all models in a celdaList, with count
#'  resampling
#' @description Calculates the perplexity of each model's cluster assignments
#'  given the provided count matrix, as well as resamplings of that count
#'  matrix, providing a distribution of perplexities and a better sense of the
#'  quality of a given K/L choice. Use \code{plotGridSearchPerplexity} to
#'  visualize the result.
#' @param counts Integer matrix. Rows represent features and columns represent
#'  cells. This matrix should be the same as the one used to generate the
#'  models stored in `celdaList`.
#' @param celdaList Object of class 'celdaList'.
#' @param resample Integer. The number of times to resample the counts matrix
#'  for evaluating perplexity. Default 5.
#' @param seed Integer. Passed to \link[withr]{with_seed}. For reproducibility,
#'  a default value of 12345 is used. If NULL, no calls to
#'  \link[withr]{with_seed} are made.
#' @return celdaList. Returns the provided `celdaList` with its `perplexity`
#'  slot set to a matrix of perplexities (one row per model, one column per
#'  resample) and a `mean_perplexity` column added to its run parameters.
#' @examples
#' data(celdaCGSim, celdaCGGridSearchRes)
#' celdaCGGridSearchRes <- resamplePerplexity(
#'   celdaCGSim$counts,
#'   celdaCGGridSearchRes)
#' plotGridSearchPerplexity(celdaCGGridSearchRes)
#' @export
resamplePerplexity <- function(counts,
    celdaList,
    resample = 5,
    seed = 12345) {

    # Deferred computation so the same call can run with or without a fixed
    # RNG seed.
    computePerplexity <- function() {
        .resamplePerplexity(counts = counts,
            celdaList = celdaList,
            resample = resample)
    }

    if (!is.null(seed)) {
        # with_seed() evaluates the code under `seed` and returns its value.
        return(with_seed(seed, computePerplexity()))
    }

    computePerplexity()
}


.resamplePerplexity <- function(counts,
    celdaList,
    resample = 5) {
    # Worker for resamplePerplexity(): for each of `resample` resampled count
    # matrices, compute the perplexity of every model in the celdaList, store
    # the models x resamples matrix in the `perplexity` slot, and add the
    # per-model mean as `mean_perplexity` to runParams.

    if (!methods::is(celdaList, "celdaList")) {
        stop("celdaList parameter was not of class celdaList.")
    }
    if (!isTRUE(is.numeric(resample))) {
        stop("Provided resample parameter was not numeric.")
    }
    # Without this, resample < 1 produced seq(0) == c(1, 0) and an opaque
    # "subscript out of bounds" error on the perplexity matrix.
    if (length(resample) != 1L || is.na(resample) || resample < 1) {
        stop("Provided resample parameter must be a single number >= 1.")
    }

    models <- resList(celdaList)
    # matrix() truncates a fractional ncol; use the same count for the loop.
    nResample <- as.integer(resample)

    perpRes <- matrix(NA, nrow = length(models), ncol = nResample)
    for (j in seq_len(nResample)) {
        # One resampled matrix per column, shared by all models so they are
        # compared on identical data.
        newCounts <- .resampleCountMatrix(counts)
        for (i in seq_along(models)) {
            perpRes[i, j] <- perplexity(counts, models[[i]], newCounts)
        }
    }
    celdaList@perplexity <- perpRes

    ## Add mean perplexity to runParams
    perpMean <- apply(perpRes, 1, mean)
    celdaList@runParams$mean_perplexity <- perpMean

    return(celdaList)
}


#' @title Visualize perplexity of a list of celda models
#' @description Visualize perplexity of every model in a celdaList, by unique
#'  K/L combinations. Dispatches to \code{plotGridSearchPerplexity<class>},
#'  where \code{<class>} is the class of the first model in the celdaList
#'  (e.g. \code{plotGridSearchPerplexitycelda_CG}).
#' @param celdaList Object of class 'celdaList'.
#' @param sep Numeric. Breaks in the x axis of the resulting plot.
#' @return A ggplot plot object showing perplexity as a function of clustering
#'  parameters.
#' @examples
#' data(celdaCGSim, celdaCGGridSearchRes)
#' ## Run various combinations of parameters with 'celdaGridSearch'
#' celdaCGGridSearchRes <- resamplePerplexity(
#'   celdaCGSim$counts,
#'   celdaCGGridSearchRes)
#' plotGridSearchPerplexity(celdaCGGridSearchRes)
#' @export
plotGridSearchPerplexity <- function(celdaList, sep = 1) {
    # The model type of the first result selects the plotting method.
    modelClass <- as.character(class(resList(celdaList)[[1]]))
    plotFunction <- paste0("plotGridSearchPerplexity", modelClass)
    do.call(plotFunction, args = list(celdaList, sep))
}


#' @title Plot perplexity as a function of K and L from celda_CG models
#' @description This function plots perplexity as a function of the cell/gene
#'  (K/L) clusters as generated by celdaGridSearch(). Each resampled
#'  perplexity is drawn as a jittered point and the mean perplexity of each
#'  K/L combination is drawn as a line. When more than one K is present, K is
#'  on the x axis and colour encodes L; otherwise L is on the x axis and
#'  colour encodes K.
#' @param celdaList Object of class 'celdaList' on which `resamplePerplexity`
#'  has already been run.
#' @param sep Numeric. Step between labelled breaks on the x axis of the
#'  resulting plot.
#' @return A ggplot plot object showing perplexity as a function of clustering
#'  parameters.
#' @examples
#' data(celdaCGSim, celdaCGGridSearchRes)
#' celdaCGGridSearchRes <- resamplePerplexity(
#'   celdaCGSim$counts,
#'   celdaCGGridSearchRes
#' )
#' plotGridSearchPerplexity(celdaCGGridSearchRes)
#' @export
plotGridSearchPerplexitycelda_CG <- function(celdaList, sep) {
    params <- runParams(celdaList)
    if (!all(c("K", "L") %in% colnames(params))) {
        stop("runParams(celdaList) needs K and L columns.")
    }
    perp <- celdaPerplexity(celdaList)
    # length() == 0 catches both a NULL slot and an empty (0 x 0) matrix;
    # is.null() alone lets an empty matrix through to a less informative
    # error later in aggregate().
    if (length(perp) == 0) {
        stop("No perplexity measurements available. First run",
            " 'resamplePerplexity' with celdaList object.")
    }

    # Flatten the models x resamples matrix row by row: each runParams row is
    # repeated once per resample next to the matching perplexity value.
    ix1 <- rep(seq_len(nrow(perp)), each = ncol(perp))
    ix2 <- rep(seq_len(ncol(perp)), nrow(perp))
    df <- data.frame(params[ix1, , drop = FALSE],
        perplexity = perp[cbind(ix1, ix2)])
    df$K <- as.factor(df$K)
    df$L <- as.factor(df$L)

    lMeansByK <- stats::aggregate(df$perplexity, by = list(df$K, df$L),
        FUN = mean)
    colnames(lMeansByK) <- c("K", "L", "mean_perplexity")
    lMeansByK$K <- as.factor(lMeansByK$K)
    lMeansByK$L <- as.factor(lMeansByK$L)

    if (nlevels(df$K) > 1) {
        plot <- ggplot2::ggplot(df,
            ggplot2::aes_string(x = "K", y = "perplexity")) +
            ggplot2::geom_jitter(height = 0, width = 0.1,
                ggplot2::aes_string(color = "L")) +
            ggplot2::scale_color_discrete(name = "L") +
            ggplot2::geom_path(data = lMeansByK, ggplot2::aes_string(x = "K",
                y = "mean_perplexity", group = "L", color = "L")) +
            ggplot2::ylab("Perplexity") +
            ggplot2::xlab("K") +
            ggplot2::scale_x_discrete(breaks = seq(
                min(params$K),
                max(params$K), sep)) +
            ggplot2::theme_bw()
    } else {
        plot <-
            ggplot2::ggplot(df,
                ggplot2::aes_string(x = "L", y = "perplexity")) +
            ggplot2::geom_jitter(height = 0, width = 0.1,
                ggplot2::aes_string(color = "K")) +
            ggplot2::scale_color_discrete(name = "K") +
            ggplot2::geom_path(data = lMeansByK,
                ggplot2::aes_string(x = "L", y = "mean_perplexity", group = "K",
                    color = "K")) +
            ggplot2::ylab("Perplexity") +
            ggplot2::xlab("L") +
            ggplot2::scale_x_discrete(breaks = seq(min(params$L),
                max(params$L), sep)) +
            ggplot2::theme_bw()
    }

    return(plot)
}


#' @title Plot perplexity as a function of K from celda_C models
#' @description Plots perplexity as a function of the cell (K) clusters as
#'   generated by celdaGridSearch(). Each resampled perplexity is drawn as a
#'   jittered point and the mean perplexity per K is connected by a line.
#' @param celdaList Object of class 'celdaList' on which `resamplePerplexity`
#'   has already been run.
#' @param sep Numeric. Step between labelled breaks on the x axis of the
#'   resulting plot.
#' @return A ggplot plot object showing perplexity as a function of clustering
#'   parameters.
#' @examples
#' data(celdaCGSim, celdaCGGridSearchRes)
#' celdaCGGridSearchRes <- resamplePerplexity(
#'   celdaCGSim$counts,
#'   celdaCGGridSearchRes
#' )
#' plotGridSearchPerplexity(celdaCGGridSearchRes)
#' @export
plotGridSearchPerplexitycelda_C <- function(celdaList, sep) {
    params <- runParams(celdaList)
    if (!("K" %in% colnames(params))) {
        stop("runParams(celdaList) needs the column K.")
    }
    perp <- celdaPerplexity(celdaList)
    # length() == 0 catches both a NULL slot and an empty (0 x 0) matrix;
    # is.null() alone lets an empty matrix through to a less informative
    # error later in aggregate().
    if (length(perp) == 0) {
        stop("No perplexity measurements available. First run",
            " 'resamplePerplexity' with celdaList object.")
    }

    # Flatten the models x resamples matrix row by row: each runParams row is
    # repeated once per resample next to the matching perplexity value.
    ix1 <- rep(seq_len(nrow(perp)), each = ncol(perp))
    ix2 <- rep(seq_len(ncol(perp)), nrow(perp))
    df <- data.frame(params[ix1, , drop = FALSE],
        perplexity = perp[cbind(ix1, ix2)])
    df$K <- as.factor(df$K)

    meansByK <- stats::aggregate(df$perplexity, by = list(df$K), FUN = mean)
    colnames(meansByK) <- c("K", "mean_perplexity")
    meansByK$K <- as.factor(meansByK$K)

    plot <-
        ggplot2::ggplot(df, ggplot2::aes_string(x = "K", y = "perplexity")) +
        ggplot2::geom_jitter(height = 0, width = 0.1) +
        ggplot2::geom_path(data = meansByK,
            ggplot2::aes_string(x = "K", y = "mean_perplexity", group = 1)) +
        ggplot2::ylab("Perplexity") +
        ggplot2::xlab("K") +
        ggplot2::scale_x_discrete(breaks = seq(min(params$K),
            max(params$K), sep)) +
        ggplot2::theme_bw()

    return(plot)
}


#' @title Plot perplexity as a function of L from a celda_G model
#' @description Plots perplexity as a function of the gene (L) clusters as
#'   generated by celdaGridSearch(). Each resampled perplexity is drawn as a
#'   jittered point and the mean perplexity per L is connected by a line.
#' @param celdaList Object of class 'celdaList' on which `resamplePerplexity`
#'   has already been run.
#' @param sep Numeric. Step between labelled breaks on the x axis of the
#'   resulting plot.
#' @return A ggplot plot object showing perplexity as a function of clustering
#'   parameters.
#' @examples
#' data(celdaCGSim, celdaCGGridSearchRes)
#' celdaCGGridSearchRes <- resamplePerplexity(
#'   celdaCGSim$counts,
#'   celdaCGGridSearchRes)
#' plotGridSearchPerplexity(celdaCGGridSearchRes)
#' @export
plotGridSearchPerplexitycelda_G <- function(celdaList, sep) {
    params <- runParams(celdaList)
    if (!("L" %in% colnames(params))) {
        stop("runParams(celdaList) needs the column L.")
    }
    perp <- celdaPerplexity(celdaList)
    # length() == 0 catches both a NULL slot and an empty (0 x 0) matrix.
    if (length(perp) == 0) {
        stop("No perplexity measurements available. First run",
            " 'resamplePerplexity' with celdaList object.")
    }

    # Flatten the models x resamples matrix row by row: each runParams row is
    # repeated once per resample next to the matching perplexity value.
    # drop = FALSE keeps a one-column runParams as a data.frame so the `L`
    # column name survives.
    ix1 <- rep(seq_len(nrow(perp)), each = ncol(perp))
    ix2 <- rep(seq_len(ncol(perp)), nrow(perp))
    df <- data.frame(params[ix1, , drop = FALSE],
        perplexity = perp[cbind(ix1, ix2)])
    df$L <- as.factor(df$L)

    meansByL <- stats::aggregate(df$perplexity, by = list(df$L), FUN = mean)
    colnames(meansByL) <- c("L", "mean_perplexity")
    meansByL$L <- as.factor(meansByL$L)

    plot <-
        ggplot2::ggplot(df, ggplot2::aes_string(x = "L", y = "perplexity")) +
        ggplot2::geom_jitter(height = 0, width = 0.1) +
        ggplot2::geom_path(data = meansByL,
            ggplot2::aes_string(x = "L", y = "mean_perplexity", group = 1)) +
        ggplot2::ylab("Perplexity") +
        ggplot2::xlab("L") +
        ggplot2::scale_x_discrete(breaks = seq(min(params$L),
            max(params$L), sep)) +
        ggplot2::theme_bw()

    return(plot)
}


# Resample a counts matrix for evaluating perplexity.
#
# Normalizes each column (cell) of `countMatrix` by its column sum to obtain a
# per-cell distribution over features, then draws one multinomial sample of
# the same total size from that distribution, independently for every cell.
# This is used by .resamplePerplexity() to evaluate the stability of the
# perplexity for a given K/L combination.
#
# Cells with a total of zero counts have no defined distribution (their
# normalized column is NaN); they are returned as all-zero columns instead of
# letting stats::rmultinom() fail on the NaN probability vector.
#
# @param countMatrix Matrix of counts. Rows are features, columns are cells.
# @return Integer matrix with the same dimensions as `countMatrix`; for integer
#  counts, each column sums to the corresponding column sum of `countMatrix`.
#  Dimnames are not carried over.
.resampleCountMatrix <- function(countMatrix) {
    colsums <- colSums(countMatrix)
    prob <- t(t(countMatrix) / colsums)
    nFeatures <- nrow(countMatrix)
    resample <- vapply(seq_len(ncol(countMatrix)), function(idx) {
        if (colsums[idx] == 0) {
            return(integer(nFeatures))
        }
        # as.integer() drops the n x 1 matrix shape and any feature names so
        # every column has the same, unnamed form.
        as.integer(stats::rmultinom(n = 1,
            size = colsums[idx],
            prob = prob[, idx]))
    }, integer(nFeatures))
    return(resample)
}