From 145b367157a8165501e9aedd6ffd5d39b89b2e64 Mon Sep 17 00:00:00 2001 From: hanneoberman Date: Wed, 16 Oct 2024 14:55:02 +0200 Subject: [PATCH] functionality to plot difference between observed and imputed data --- R/plot_corr.R | 143 +++++++++++++++++++++++++++++++------------------- 1 file changed, 90 insertions(+), 53 deletions(-) diff --git a/R/plot_corr.R b/R/plot_corr.R index e50a3b6..0772afc 100644 --- a/R/plot_corr.R +++ b/R/plot_corr.R @@ -2,6 +2,7 @@ #' #' @param data A dataset of class `data.frame`, `tibble`, or `matrix`. #' @param vrb String, vector, or unquoted expression with variable name(s), default is "all". +#' @param diff Logical indicating whether the difference in correlations between the observed and imputed data (observed minus imputed, pooled across imputations) is plotted instead of the correlations themselves. Only available for imputed data, default is `FALSE`. #' @param label Logical indicating whether correlation values should be displayed. #' @param square Logical indicating whether the plot tiles should be squares. #' @param diagonal Logical indicating whether the correlation of each variable with itself should be displayed. 
@@ -31,62 +32,95 @@ plot_corr <- function(data, vrb = "all", + diff = FALSE, label = FALSE, square = TRUE, diagonal = FALSE, rotate = FALSE, caption = TRUE) { + # process inputs if (is.matrix(data) && ncol(data) > 1) { data <- as.data.frame(data) } verify_data(data = data, df = TRUE, imp = TRUE) + if (diff && !mice::is.mids(data)) { + cli::cli_abort("Difference in correlations can only be computed with imputed data.") + } vrb <- rlang::enexpr(vrb) - vrbs_in_data <- if (mice::is.mids(data)) { - names(data$imp) + if (mice::is.mids(data)) { + imp <- TRUE + mids <- data + data <- data$data } else { - names(data) + imp <- FALSE } + vrbs_in_data <- names(data) vrb_matched <- match_vrb(vrb, vrbs_in_data) + # check if any column is constant + constants <- apply(data[, vrb_matched], MARGIN = 2, function(x) { + all(is.na(x)) || max(x, na.rm = TRUE) == min(x, na.rm = TRUE) + }) + if (any(constants)) { + vrb_matched <- vrb_matched[!constants] + cli::cli_inform( + c( + "No correlations computed for variable(s):", + " " = paste(names(constants[which(constants)]), collapse = ", "), + "i" = "Correlations are undefined for constants." + ) + ) + } if (length(vrb_matched) < 2) { cli::cli_abort("The number of variables should be two or more to compute correlations.") } - if (is.data.frame(data)) { - # for data: check if any column is constant - constants <- apply(data[, vrb_matched], MARGIN = 2, function(x) { - all(is.na(x)) || max(x, na.rm = TRUE) == min(x, na.rm = TRUE) - }) - if (any(constants)) { - vrb_matched <- vrb_matched[!constants] - cli::cli_inform( - c( - "No correlations computed for variable(s):", - " " = paste(names(constants[which(constants)]), collapse = ", "), - "i" = "Correlations are undefined for constants." 
- ) - ) - } - # compute correlations - corr <- stats::cov2cor(stats::cov( - data.matrix(data[, vrb_matched]), - use = "pairwise.complete.obs" - )) + # create plot labels + lab_x <- "Imputation model predictor" + lab_y <- "Imputation target" + lab_fill <- "Correlation* + " + if (!imp) { + lab_note <- "*pairwise complete observations" } - if (mice::is.mids(data)) { - # check constatnts etc. - imps <- mice::complete(data, "all") + if (imp) { + lab_note <- "*pooled across imputations" + } + if (diff) { + lab_fill <- "Difference in correlations* +" + lab_note <- "*observed minus imputed (pooled across imputations)" + } + # compute correlations + if (!imp | diff) { + corr <- stats::cov2cor(stats::cov(data.matrix(data[, vrb_matched]), use = "pairwise.complete.obs")) + } + if (imp) { + imps <- mice::complete(mids, "all") corrs <- purrr::map(imps, ~ { - stats::cor(.x) + stats::cov2cor(stats::cov(data.matrix(.x[, vrb_matched]), use = "pairwise.complete.obs")) }) - corr <- Reduce("+", corrs) / length(corrs) + if (diff) { + corr <- corr - (Reduce("+", corrs) / length(corrs)) + } else { + corr <- Reduce("+", corrs) / length(corrs) + } } + # convert correlations into plotting object p <- length(vrb_matched) long <- data.frame( vrb = rep(vrb_matched, each = p), prd = vrb_matched, corr = matrix(round(corr, 2), nrow = p * p, byrow = TRUE) ) + if (!diff) { + long$text <- long$corr + } + if (diff) { + long$text <- sprintf("%+.2f", long$corr) + long$text[long$corr < 0] <- long$corr[long$corr < 0] + } if (!diagonal) { long[long$vrb == long$prd, "corr"] <- NA + long[long$vrb == long$prd, "text"] <- "" } # create plot gg <- @@ -94,40 +128,43 @@ plot_corr <- ggplot2::aes( x = .data$prd, y = .data$vrb, - label = .data$corr, + label = .data$text, fill = .data$corr )) + ggplot2::geom_tile(color = "black", alpha = 0.6) + ggplot2::scale_x_discrete(limits = vrb_matched, position = "top") + ggplot2::scale_y_discrete(limits = rev(vrb_matched)) + - ggplot2::scale_fill_gradient2( - low = 
ggplot2::alpha("deepskyblue", 0.6), - mid = "lightyellow", - high = ggplot2::alpha("orangered", 0.6), - na.value = "grey90", - limits = c(-1, 1) - ) + + ggplot2::labs( + x = lab_x, + y = lab_y, + fill = lab_fill, + caption = lab_note + ) + theme_minimice() - lab_x <- "Imputation model predictor" - if (mice::is.mids(data)) { - lab_y <- "Column name" - lab_note <- "*pooled across imputations" - } else { - lab_y <- "Column name" - lab_note <- "*pairwise complete observations" - } - if (caption) { + # edit plot to match function arguments + if (!diff) { gg <- gg + - ggplot2::labs( - x = lab_x, - y = lab_y, - fill = "Correlation* - ", - caption = lab_note + ggplot2::scale_fill_gradient2( + low = ggplot2::alpha("deepskyblue", 0.6), + mid = "lightyellow", + high = ggplot2::alpha("orangered", 0.6), + na.value = "grey90", + limits = c(-1, 1) ) - } else { + } + if (diff) { + gg <- gg + ggplot2::scale_fill_gradient2( + low = ggplot2::alpha("deepskyblue", 1), + mid = "lightyellow", + high = ggplot2::alpha("orangered", 1), + na.value = "grey90", + limits = c(-2, 2) + ) + } + if (!caption) { + lab_fill <- substring(lab_fill, 1, nchar(lab_fill) - 2) gg <- gg + - ggplot2::labs(x = lab_x, y = lab_y, fill = "Correlation") + ggplot2::labs(fill = lab_fill, caption = NULL) } if (label) { gg <-