首页 > 解决方案 > R函数中的子集

问题描述

我有一个生成表格图形的函数:

plot_covariate_means_by_ntile <- function(.df, .ntile = "ntile", n_top = 10, directory) {
  .df <- as.data.frame(.df)
  covariate_names <- covariate_names
  #.df[, .ntile] <- as.factor(.df[, .ntile])
  .df[, .ntile] <- as_factor(.df[, .ntile], levels = "both")
    
  # Regress each covariate on ntile/subgroup assignment to means p
  cov_means <- lapply(covariate_names, function(covariate) {
    lm_robust(as.formula(paste0(covariate, " ~ 0 + ", .ntile)), data = .df, se_type = "stata")
  })
  
  # Extract the mean and standard deviation of each covariate per ntile/subgroup
  cov_table <- lapply(cov_means, function(cov_mean) {
    means <- as.data.frame(t(coef(summary(cov_mean))[,c("Estimate", "Std. Error")]))
    means
  })
  
  # Preparation to color the chart
  temp_standardized <- sapply(seq_along(covariate_names), function(j) {
    covariate_name <- covariate_names[j]
    .mean <- mean(.df[, covariate_name], na.rm = TRUE)
    .sd <- sd(.df[, covariate_name], na.rm = TRUE)
    m <- as.matrix(round(signif(cov_table[[j]], digits=4), 3))
    .standardized <- (m["Estimate",] - .mean) / .sd
    .standardized
  })
  
  colnames(temp_standardized) <- covariate_names
  ordering <- order(apply(temp_standardized, MARGIN = 2, function(x) {.range <- range(x); abs(.range[2] - .range[1])}), decreasing = TRUE)
  
  # fwrite(tibble::rownames_to_column(as.data.frame(t(temp_standardized)[ordering,])), 
  #        paste0(directory$data, "/covariate_standardized_means_by_", .ntile, ".csv"))
  
  
  color_scale <- max(abs(c(max(temp_standardized, na.rm = TRUE), min(temp_standardized, na.rm = TRUE))))
  color_scale <- color_scale * c(-1,1)
  max_std_dev <- floor(max(color_scale))
  breaks <- -max_std_dev:max_std_dev
  labels <- c(" ", breaks, " ")
  breaks <- c(min(color_scale), breaks, max(color_scale))
  
  
  # Little trick to display the standard errors
  table <- lapply(seq_along(covariate_names), function(j) {
    covariate_name <- covariate_names[j]
    .mean <- mean(.df[, covariate_name], na.rm = TRUE)
    .sd <- sd(.df[, covariate_name], na.rm = TRUE)
    m <- as.matrix(round(signif(cov_table[[j]], digits=4), 3))
    .standardized <- (m["Estimate",] - .mean) / .sd
    return(data.frame(covariate = covariate_name, 
                      group = c(1,2,5) , 
                      estimate = m["Estimate",], std.error = m["Std. Error",], 
                      standardized = .standardized))
  })
  # table <- do.call(rbind, table)
  table <- rbindlist(table)
  
  setnames(table, "group", .ntile)
  table[, covariate := factor(covariate, levels = rev(covariate_names[ordering]), ordered = TRUE)]
  
  table[covariate %in% head(covariate_names[ordering], n_top)] %>%
    mutate(info = paste0(estimate, "\n(", std.error, ")")) %>%
    ggplot(aes_string(x = .ntile, y = "covariate")) +
    # Add coloring
    geom_raster(aes(fill = standardized)
                , alpha = 0.9
    ) +
    scale_fill_distiller(palette = "RdBu",
                         direction = 1,
                         breaks = breaks,
                         labels = labels,
                         limits = color_scale,
                         name = "Standard\nDeviation on\nNormalized\nDistribution"
    ) +
    # add numerics
    geom_text(aes(label = info), size=2.1) +
    # reformat
    labs(title = paste0("Covariate averages within ", ifelse(tolower(.ntile) == "leaf", .ntile, "Assigned Group")),
         y = "within covariate") +
    scale_x_continuous(position = "top") #+
  #cowplot::theme_minimal_hgrid(16)
}

但输出显示所有 5 列,我希望它只显示 1、2 和 5。

在此处输入图像描述

我可以调整线路

groups = 1:ncol(m)

但是随后错误地标记了组,第三列实际上是第 5 组:

在此处输入图像描述

有什么方法可以调整函数以显示正确的列和正确的标签?

标签: rfunctionggplot2

解决方案


也许您可以使用 facet_wrap 作为解决方法?

library(tidyverse)
data.frame(X = rep(1:5, each = 25),
           Y = rep(factor(rev(LETTERS[-26]),
                          levels = rev(LETTERS[-26])), 5),
           Z = rnorm(125, 5, 1)) %>%
  mutate(X = ifelse(X %in% c(1,2,5), X, NA)) %>%
  na.omit() %>% 
  ggplot(aes(x = X, y = Y, fill = Z)) +
  geom_raster() +
  facet_wrap(~X, ncol=3, scales="free_x") +
  theme_minimal() +
  theme(axis.text.x = element_blank())

示例-geom_raster.png

我试图找出一个使用scale_x_discrete(例如类似的东西scale_x_discrete(limits = c("1", "2", "5"), breaks = c("1", "2", "5")))的解决方案,它“感觉”它可以工作,但我放弃了——也许是值得追求的东西。


推荐阅读