#' Make windows merging restriction fragments
#'
#' Use a set of contiguous restriction fragments to generate windows containing
#' a fixed number of fragments (\code{n_frags}). Windows are built separately
#' for upstream and downstream fragments, starting from the upstream fragment
#' with the largest start coordinate and the downstream fragment with the
#' smallest start coordinate.
#' @param input Input object containing the restriction fragments. Should be
#' class UMI4C (rowRanges will be extracted) or class GRanges. The fragments
#' are expected to carry the metadata columns \code{id_contact} and
#' \code{position} (with values \code{"upstream"} / \code{"downstream"}).
#' @param n_frags Number of fragments to use for generating the windows. This
#' should include restriction fragments with 0 counts (Default: 8).
#' @param sliding Numeric indicating the factor for generating sliding windows.
#' If set to 1 (default) will use fixed windows. If set to > 0 and < 1,
#' consecutive windows are shifted by \code{n_frags * sliding} fragments.
#' Any other value raises an error.
#' @return A GRanges object containing the windows of merged restriction
#' fragments.
#' @export
#' @examples
#' data("ex_ciita_umi4c")
#' 
#' # Without sliding windows
#' win_frags <- makeWindowFragments(ex_ciita_umi4c, n_frags=30, sliding=1)
#' win_frags
#' 
#' # With sliding windows (n_frags*sliding)
#' win_frags <- makeWindowFragments(ex_ciita_umi4c, n_frags=30, sliding=0.5)
#' win_frags
makeWindowFragments <- function(input, 
                                n_frags=8,
                                sliding=1) {
  # is() tests the class of the object itself. isClass() only reports whether
  # a class *definition* exists, so it cannot be used to dispatch on `input`.
  if (is(input, "UMI4C")) {
    frags <- rowRanges(input)
  } else if (is(input, "GRanges")) {
    frags <- input
  } else {
    stop("Input object should be of class 'UMI4C' or 'GRanges'")
  }
  
  frags <- frags[order(frags)]
  
  # Anchor fragments: the upstream fragment with the largest start and the
  # downstream fragment with the smallest start.
  is_up <- frags$position == "upstream"
  is_down <- frags$position == "downstream"
  start_upstream <- frags$id_contact[
    is_up & start(frags) == max(start(frags[is_up]))
  ]
  start_downstream <- frags$id_contact[
    is_down & start(frags) == min(start(frags[is_down]))
  ]
  
  # `sliding` is a scalar setting, so the range test uses the scalar `&&`.
  if (sliding == 1) {
    window_frags <- .makeFixedWindow(frags, n_frags,
                                     start_upstream, start_downstream)
  } else if (sliding > 0 && sliding < 1) {
    window_frags <- .makeSlidingWindow(frags, n_frags,
                                       start_upstream, start_downstream,
                                       sliding)
  } else {
    stop("Sliding value should be 1 (no sliding) or > 0 & < 1 for sliding.")
  }
  
  return(window_frags)
}

## Build fixed (non-overlapping) windows of `n_frags` consecutive fragments.
##
## frags            Sorted GRanges of restriction fragments with an
##                  `id_contact` metadata column.
## n_frags          Number of fragments per window.
## start_upstream   id_contact of the fragment where upstream windows start
##                  (windows extend towards index 1).
## start_downstream id_contact of the fragment where downstream windows start
##                  (windows extend towards the last fragment).
##
## Returns a sorted GRanges with one range per window; the last window on each
## side is truncated at the first / last fragment.
.makeFixedWindow <- function(frags,
                             n_frags,
                             start_upstream,
                             start_downstream) {
  # Exact lookup of the anchor fragments. grep() would interpret the id as a
  # regular expression and also return ids that merely contain it, yielding
  # several indices and breaking seq().
  anchor_up <- match(start_upstream[1], frags$id_contact)
  anchor_down <- match(start_downstream[1], frags$id_contact)
  if (is.na(anchor_up) || is.na(anchor_down)) {
    stop("Could not find the upstream/downstream starting fragment in ",
         "'frags$id_contact'.")
  }
  
  ## Upstream: up1 holds the last fragment of each window, up2 the first one
  up1 <- seq(anchor_up, 1, by=-n_frags)
  up2 <- c(up1[-1]+1, 1)
  frags_upstream <- GRanges(seqnames = unique(seqnames(frags)),
                            ranges = IRanges(start=start(frags[up2]), end=end(frags[up1])),
                            mcols= data.frame("id"=paste0("window_UP_", seq_along(up1)),
                                              "position"="upstream"))
  
  ## Downstream: dwn1 holds the first fragment of each window, dwn2 the last
  dwn1 <- seq(anchor_down, length(frags), by=n_frags)
  dwn2 <- c(dwn1[-1]-1, length(frags))
  frags_downstream <- GRanges(seqnames = unique(seqnames(frags)),
                              ranges = IRanges(start=start(frags[dwn1]), end=end(frags[dwn2])),
                              mcols= data.frame("id"=paste0("window_DOWN_", seq_along(dwn1)),
                                                "position"="downstream"))
  
  ## Merge both sides and sort by genomic position
  window_frags <- c(frags_upstream, frags_downstream)
  window_frags <- window_frags[order(window_frags)]
  
  return(window_frags)
}

## Build sliding (overlapping) windows of `n_frags` consecutive fragments,
## shifting each window by `n_frags * sliding` fragments.
##
## frags            Sorted GRanges of restriction fragments with an
##                  `id_contact` metadata column.
## n_frags          Number of fragments per window.
## start_upstream   id_contact of the fragment where upstream windows start
##                  (windows extend towards index 1).
## start_downstream id_contact of the fragment where downstream windows start
##                  (windows extend towards the last fragment).
## sliding          Fraction of `n_frags` used as step between windows.
##
## Returns a sorted GRanges with one range per window; windows reaching past
## the first / last fragment are clipped to it.
.makeSlidingWindow <- function(frags,
                               n_frags,
                               start_upstream,
                               start_downstream,
                               sliding) {
  # Exact lookup of the anchor fragments. grep() would interpret the id as a
  # regular expression and also return ids that merely contain it, yielding
  # several indices and breaking seq().
  anchor_up <- match(start_upstream[1], frags$id_contact)
  anchor_down <- match(start_downstream[1], frags$id_contact)
  if (is.na(anchor_up) || is.na(anchor_down)) {
    stop("Could not find the upstream/downstream starting fragment in ",
         "'frags$id_contact'.")
  }
  
  # NOTE(review): n_frags * sliding can be fractional, which produces
  # non-integer fragment indices below -- confirm how GRanges subsetting
  # treats them.
  step <- n_frags*sliding
  
  ## Upstream: up1 holds the last fragment of each window, up2 the first one
  up1 <- seq(anchor_up, 1, by=-step)
  up2 <- up1 - n_frags + 1
  up2[up2<=0] <- 1
  
  frags_upstream <- GRanges(seqnames = unique(seqnames(frags)),
                            ranges = IRanges(start=start(frags[up2]), end=end(frags[up1])),
                            mcols= data.frame("id"=paste0("window_UP_", seq_along(up1)),
                                              "position"="upstream"))
  
  ## Downstream: dwn1 holds the first fragment of each window, dwn2 the last
  dwn1 <- seq(anchor_down, length(frags), by=step)
  dwn2 <- dwn1 + n_frags - 1
  dwn2[dwn2 > length(frags)] <- length(frags)
  
  frags_downstream <- GRanges(seqnames = unique(seqnames(frags)),
                              ranges = IRanges(start=start(frags[dwn1]), end=end(frags[dwn2])),
                              mcols= data.frame("id"=paste0("window_DOWN_", seq_along(dwn1)),
                                                "position"="downstream"))
  
  ## Merge both sides and sort by genomic position
  window_frags <- c(frags_upstream, frags_downstream)
  window_frags <- window_frags[order(window_frags)]
  
  return(window_frags)
}

#' Combine UMI4C fragments 
#' 
#' Combine the UMI4C fragments that overlap a given set of \code{query_regions}.
#' For every query region overlapped by at least one fragment, the raw UMI
#' counts of the overlapping fragments are summed per sample. Query regions
#' without overlapping fragments are not included in the output.
#' @param umi4c UMI4C object as generated by \code{\link{makeUMI4C}} or the
#' \code{UMI4C} constructor.
#' @param query_regions \code{GRanges} object containing the coordinates of the 
#' genomic regions for combining restriction fragments. Its first metadata
#' column is used as row names of the returned object.
#' @return \code{UMI4C} object with rowRanges corresponding to query_regions and
#' assay containing the sum of raw UMI counts at each specified \code{query_region}.
#' @export
#' @examples 
#' data("ex_ciita_umi4c")
#' 
#' wins <- makeWindowFragments(ex_ciita_umi4c)
#' umi_comb <- combineUMI4C(ex_ciita_umi4c, wins)
combineUMI4C <- function(umi4c,
                         query_regions) {
  query_regions <- query_regions[order(query_regions)]
  
  counts_mat <- assay(umi4c)
  frag_ranges <- rowRanges(umi4c)
  
  hits <- findOverlaps(frag_ranges, query_regions)
  
  # Fragment ids (first metadata column of the fragments) grouped by the query
  # region they overlap; they are used below as row names into the assay.
  frag_ids <- split(mcols(frag_ranges)[queryHits(hits), 1], subjectHits(hits))
  
  # drop = FALSE keeps a matrix even with a single fragment or a single
  # sample, so colSums() always yields one summed value per sample.
  summed <- lapply(frag_ids, function(ids) {
    colSums(counts_mat[ids, , drop = FALSE])
  })
  mat_final <- do.call(rbind, summed)
  
  # split() orders its groups by increasing region index; select the regions
  # in that same order so ranges, row names and summed counts stay aligned.
  region_idx <- sort(unique(subjectHits(hits)))
  
  umi4c_comb <- UMI4C(colData=colData(umi4c),
                      rowRanges=query_regions[region_idx],
                      assays=SimpleList(umi = mat_final),
                      metadata=metadata(umi4c))
  rownames(umi4c_comb) <- mcols(query_regions)[region_idx, 1]
  
  return(umi4c_comb)
}

#' UMI4Cats object to DDS object.
#' 
#' Transforms an UMI4C object to a DDS object, using the \code{umi} assay as
#' count data and carrying over the column data, row ranges and metadata. The
#' first metadata column of the row ranges is renamed to \code{id}.
#' @inheritParams differentialNbinomWaldTestUMI4C
#' @return DDS object.
#' @import GenomicRanges
UMI4C2dds <- function(umi4c,
                      design = ~condition){
  
  stopifnot(is(umi4c, "UMI4C"))
  
  # transform UMI4Cats object to DDS, honouring the design supplied by the
  # caller (previously the formula was hard-coded and `design` was ignored)
  dds <- DESeq2::DESeqDataSetFromMatrix(
    countData = assays(umi4c)$umi,
    colData = colData(umi4c),
    rowRanges = rowRanges(umi4c),
    metadata = metadata(umi4c),
    design = design)
  
  # Set the row ranges explicitly and name the first metadata column "id"
  rowRanges(dds) <- rowRanges(umi4c)
  colnames(mcols(rowRanges(dds)))[1] <- "id"
  
  return(dds)
}

#' DDS object to UMI4Cats object.
#' 
#' Extracts the DESeq2 results and the (optionally normalized) counts from a
#' DDS object produced by \code{nbinomWaldTestUMI4C} and stores them in the
#' \code{results} slot of the supplied UMI4C object.
#' @inheritParams differentialNbinomWaldTestUMI4C
#' @param dds DDS object as generated by \code{nbinomWaldTestUMI4C} 
#' with the DESeq2 Wald Test results
#' @return UMI4C object with the DESeq2 Wald Test results.
dds2UMI4C <- function(umi4c,
                      dds,
                      normalized = TRUE,
                      padj_method = "fdr",
                      padj_threshold = 0.05) {
  
  wald <- DESeq2::results(dds, pAdjustMethod = padj_method)
  
  # Keep result columns 5, 2 and 6 (presumably pvalue, log2FoldChange and
  # padj in the standard DESeq2 results layout -- confirm against DESeq2).
  stats <- data.frame(wald[, c(5, 2, 6)])
  stats$query_id <- rownames(stats)
  
  # Flag regions whose adjusted p-value passes the threshold; rows with a
  # missing padj stay FALSE because which() skips NA comparisons.
  stats$sign <- FALSE
  stats$sign[which(stats$padj <= padj_threshold)] <- TRUE
  
  # Count table with the region id moved to the first column
  count_tbl <- as.data.frame(counts(dds, normalized = normalized))
  count_tbl$query_id <- rownames(count_tbl)
  n_col <- ncol(count_tbl)
  count_tbl <- count_tbl[, c(n_col, seq_len(n_col - 1))]
  
  # Final column order: query_id, the three DESeq2 columns, then sign
  test_results <- S4Vectors::SimpleList(
    test = "DESeq2 Test based on the Negative Binomial distribution",
    ref = DESeq2::resultsNames(dds)[2],
    padj_threshold = padj_threshold,
    results = stats[, c(4, 1, 2, 3, 5)],
    query = rowRanges(dds),
    counts = count_tbl
  )
  umi4c@results <- test_results
  
  return(umi4c)
}