#Evaluating FPR

#Benchmark Marker genes for starTracer
library(Seurat)
library(splatter)
library(scater)
library(SingleCellExperiment)
# dplyr supplies %>% and the data-frame verbs used throughout this script.
# It is attached last so that its verbs (e.g. desc()) take precedence over
# same-named generics attached by the Bioconductor packages above.
library(dplyr)

#      1. Create Simulated Dataset ------
# Simulate 1000 cells in a single batch, split into five equally likely
# groups, with a fixed seed so the benchmark is reproducible.
params <- newSplatParams()
params <- setParam(params, "seed", 114514)
params <- setParam(params, "group.prob", rep(0.2, 5))
# Location / scale of splatter's differential-expression factor distribution.
params <- setParam(params, "de.facLoc", 3)
params <- setParam(params, "de.facScale", 0.2)
sim.sce <- splatSimulate(params, batchCells = c(1000), method = "groups")
sim.sce <- logNormCounts(sim.sce)

# Convert to Seurat and run the usual preprocessing / embedding steps.
sim.seurat <- as.Seurat(sim.sce, counts = "counts", data = "logcounts")
sim.seurat <- NormalizeData(sim.seurat)
sim.seurat <- ScaleData(sim.seurat)
sim.seurat <- FindVariableFeatures(sim.seurat, nfeatures = 2000)
sim.seurat <- RunPCA(sim.seurat, verbose = FALSE)

# NOTE(review): FindPCNum() is not defined in this file; it must already be
# available in the session. Its result is passed straight to `dims`, so it is
# presumably a vector of PC indices (e.g. 1:k) rather than a single count --
# confirm.
pc.num <- FindPCNum(sim.seurat)

# RunPCA() above stores its result under the default reduction name "pca";
# reduction names are case-sensitive, so "PCA" would not be found.
sim.seurat <- RunUMAP(sim.seurat, reduction = "pca", dims = pc.num)
sim.seurat <- FindNeighbors(sim.seurat, reduction = "pca", dims = pc.num)
sim.seurat <- FindClusters(sim.seurat)

# Use the simulated group labels (not the computed clusters) as identities.
Idents(sim.seurat) <- "Group"
DefaultAssay(sim.seurat) <- "originalexp"
DimPlot(sim.seurat, group.by = "Group", label = TRUE)


#       2. Find Ground Truth marker genes ------
# Build the simulated ground-truth marker table for every group.
#
# The number of groups K is the number of distinct values in
# colData(sce)$Group. For each group the function reads the rowData columns
# DEFacGroup1..DEFacGroupK and compares, gene by gene, the group's own factor
# with the factors of all other groups.
#
# Args:
#   sce: a SingleCellExperiment whose rowData holds the DEFacGroup<i> columns
#        as well as Gene and BaseGeneMean.
#
# Returns:
#   A list named group_1..group_K. Each element is a tibble with one row per
#   gene and the columns gene, gene_mean, fc, mean_score, median_score,
#   max_score, direction and de_value.
find_marker_genes <- function(sce) {
  stopifnot(is(sce, "SingleCellExperiment"))

  group_ids <- seq_len(length(unique(colData(sce)$Group)))
  de_factors <- as.matrix(rowData(sce)[, paste0("DEFacGroup", group_ids)])
  gene_ids <- seq_len(nrow(de_factors))

  # Label one DE factor: above 1 is "up", exactly 1 is "background",
  # anything else is "down".
  classify_factor <- function(factor_value) {
    if (factor_value > 1) {
      return("up")
    }
    if (factor_value == 1) {
      return("background")
    }
    "down"
  }

  per_group <- lapply(group_ids, function(grp) {
    other_groups <- setdiff(group_ids, grp)
    own_factor <- unname(de_factors[, grp])

    # One vector per gene: log of (own factor / factor of each other group).
    log_ratios <- lapply(gene_ids, function(g) {
      log(own_factor[[g]] / de_factors[g, other_groups])
    })
    mean_score <- vapply(log_ratios, mean, numeric(1))

    tibble::tibble(
      gene = rowData(sce)$Gene,
      gene_mean = rowData(sce)$BaseGeneMean,
      # FIXME: Kept for backwards compatibility (duplicate of mean_score).
      fc = mean_score,
      mean_score = mean_score,
      median_score = vapply(log_ratios, median, numeric(1)),
      max_score = vapply(log_ratios, max, numeric(1)),
      direction = vapply(own_factor, classify_factor, character(1),
                         USE.NAMES = FALSE),
      de_value = own_factor
    )
  })

  names(per_group) <- paste0("group_", group_ids)
  per_group
}

# Ground-truth marker table per simulated group.
true_marker_genes <- find_marker_genes(sim.sce)


#      3.Seurat Findallmarkers ------
# Wilcoxon test of every identity class against the rest on the "data" slot
# of the "originalexp" assay.
seurat_markergenes <- FindAllMarkers(sim.seurat,
                                     assay = "originalexp",
                                     slot = "data",
                                     test.use = "wilcox")

# Keep significant hits only, then assign each gene to a single cluster: the
# one in which it has the largest avg_log2FC (first row per gene after
# sorting by gene and descending fold change).
frame.seurat <- seurat_markergenes[, c("cluster", "gene", "avg_log2FC", "p_val_adj")] %>%
  subset(p_val_adj < 0.05) %>%
  arrange(gene, desc(avg_log2FC)) %>%
  distinct(gene, .keep_all = TRUE) %>%
  arrange(cluster, desc(avg_log2FC))

# Top 50 genes per cluster by fold change. slice_max() keeps ties, like the
# superseded top_n() it replaces.
frame.seurat <- frame.seurat %>%
  group_by(cluster) %>%
  slice_max(avg_log2FC, n = 50) %>%
  ungroup()
table(frame.seurat$cluster) # check whether every cluster has 50 genes
seurat_markergenes <- frame.seurat

#      4.starTracer Searchmarkers ------
library(starTracer)
startracer_markergenes <- searchMarker(sim.seurat)


#      5. Confusion matrix ------
# Number of top-ranked markers compared per group.
N <- 50

# Ground truth: up-regulated genes with a base mean above 0.1, ranked by fc.
true_markers <- lapply(true_marker_genes, function(x) {
  x %>%
    subset(gene_mean > 0.1) %>%
    subset(direction == "up") %>%
    slice_max(fc, n = N)
})

# Seurat: significant genes per cluster, ranked by avg_log2FC.
seurat_markers <- seurat_markergenes %>%
  group_split(cluster, .keep = TRUE) %>%
  lapply(function(x) {
    x %>%
      subset(p_val_adj < 0.05) %>%
      slice_max(avg_log2FC, n = N)
  })

# starTracer: rows of para_frame per max.X value, ordered by n and then by
# descending del_MI.
starTracer_markers <- startracer_markergenes$para_frame %>%
  group_split(max.X) %>%
  lapply(function(x) {
    x %>%
      arrange(n, desc(del_MI)) %>%
      slice_head(n = N)
  })
starTracer_markers[[1]]

# Per-group share of the n reported slots that are not ground-truth markers.
#
# Args:
#   called: list of per-group marker tables (each with a `gene` column).
#   truth:  list of per-group ground-truth tables, matched to `called` by
#           position.
#   n:      number of slots per group. The denominator is always `n`, so a
#           method that returns fewer than `n` genes has its missing slots
#           counted as false positives.
#
# Returns:
#   A numeric vector with one value per group.
#
# NOTE(review): groups are matched by list position. group_split() drops
# empty groups, which would shift the alignment, so differing lengths are
# rejected up front. Identical ordering across methods is still assumed.
top_n_fpr <- function(called, truth, n = N) {
  stopifnot(length(called) == length(truth))
  vapply(seq_along(truth), function(i) {
    true_positives <- length(intersect(called[[i]]$gene, truth[[i]]$gene))
    (n - true_positives) / n
  }, numeric(1))
}

#FPR for startracer
FPR1 <- top_n_fpr(starTracer_markers, true_markers)
median(FPR1)

#FPR for seurat
FPR2 <- top_n_fpr(seurat_markers, true_markers)
median(FPR2)

# ggsave() does not create missing directories by default, so make sure the
# output folder exists before any figure is written.
dir.create("./Figures", showWarnings = FALSE, recursive = TRUE)

# Long-format table: one row per (group, method) with the FPR in `value`.
plot.df <- data.frame(Group = levels(sim.seurat),
                      starTracer_FPR = FPR1,
                      seurat_FPR = FPR2) %>%
  reshape2::melt(id.vars = "Group")

p <- ggplot(plot.df, aes(x = Group, y = value, color = variable,
                         group = variable, shape = variable)) +
  geom_line() +
  geom_point(size = 3) +
  theme_bw()
ggsave(filename = "./Figures/startracer_FPR.pdf", plot = p,
       height = 5, width = 7)

p <- DimPlot(sim.seurat, label = TRUE)
ggsave(filename = "./Figures/startracer_Dimplot.pdf", plot = p,
       height = 5, width = 7)

# Extract the gene column of every per-group marker table and name the
# resulting list after the identity levels of sim.seurat.
marker_features <- function(markers) {
  features <- lapply(markers, function(x) x$gene)
  names(features) <- levels(sim.seurat)
  features
}

# DotPlot of a named per-group feature list with vertical x-axis labels.
marker_dotplot <- function(features, title) {
  DotPlot(sim.seurat, features = features) +
    ggtitle(title) +
    theme(axis.text.x = element_text(angle = 90))
}

features.list <- marker_features(starTracer_markers)
p <- marker_dotplot(features.list, "starTracer")
ggsave(filename = "./Figures/startracer_dotplot.pdf", plot = p,
       height = 5, width = 16)

features.list <- marker_features(seurat_markers)
p <- marker_dotplot(features.list, "Seurat")
ggsave(filename = "./Figures/seurat_dotplot.pdf", plot = p,
       height = 5, width = 16)

features.list <- marker_features(true_markers)
p <- marker_dotplot(features.list, "Simulated GroundTruth")
ggsave(filename = "./Figures/truth_dotplot.pdf", plot = p,
       height = 5, width = 16)


#Line plot 1-50
# For every cutoff k = 1..50 and every group, count how many of the top-k
# genes are shared between starTracer and the ground truth, between Seurat
# and the ground truth, and between starTracer and Seurat.
max_top_n <- 50
n_groups <- length(true_marker_genes)

# The filtering and ranking do not depend on k, so do them once and take
# head(k) of the ranked gene vectors afterwards.
true_ranked <- lapply(true_marker_genes, function(x) {
  x %>%
    subset(gene_mean > 0.1) %>%
    subset(direction == "up") %>%
    arrange(desc(fc))
})

seurat_ranked <- seurat_markergenes %>%
  group_split(cluster, .keep = TRUE) %>%
  lapply(function(x) {
    x %>%
      subset(p_val_adj < 0.05) %>%
      arrange(desc(avg_log2FC))
  })

star_ranked <- startracer_markergenes$para_frame %>%
  group_split(max.X) %>%
  lapply(function(x) {
    x %>%
      arrange(n, desc(del_MI))
  })

# Size of the overlap between the first k genes of two ranked gene vectors,
# for k = 1..max_top_n. Returns an integer vector of length max_top_n.
count_top_overlaps <- function(genes_a, genes_b) {
  vapply(seq_len(max_top_n), function(k) {
    length(intersect(head(genes_a, k), head(genes_b, k)))
  }, integer(1))
}

# One integer vector (indexed by k) per group. Groups are matched across
# methods by list position.
star_true <- lapply(seq_len(n_groups), function(j) {
  count_top_overlaps(star_ranked[[j]]$gene, true_ranked[[j]]$gene)
})
seurat_true <- lapply(seq_len(n_groups), function(j) {
  count_top_overlaps(seurat_ranked[[j]]$gene, true_ranked[[j]]$gene)
})
star_seurat <- lapply(seq_len(n_groups), function(j) {
  count_top_overlaps(star_ranked[[j]]$gene, seurat_ranked[[j]]$gene)
})

# ggsave() does not create missing directories by default.
dir.create("./Figures/starTracer", showWarnings = FALSE, recursive = TRUE)

for (j in seq_len(n_groups)) {
  #plot
  plot.df <- data.frame(star_true = star_true[[j]],
                        seurat_true = seurat_true[[j]],
                        star_seurat = star_seurat[[j]],
                        topN = seq_len(max_top_n))
  plot.df <- reshape2::melt(plot.df, id.vars = "topN")
  p <- ggplot(plot.df, aes(x = topN, y = value, group = variable,
                           color = variable)) +
    geom_line() +
    geom_point(size = 3) +
    ylab("Number of Intersected genes") +
    ggtitle(paste0("Group", j)) +
    theme_bw()

  ggsave(filename = paste0("./Figures/starTracer/topN_intersected_genes_Group", j, ".pdf"),
         plot = p, height = 6, width = 7)
}