Skip to content

Commit

Permalink
functionality to plot difference between observed and imputed data
Browse files Browse the repository at this point in the history
  • Loading branch information
hanneoberman committed Oct 16, 2024
1 parent 015dc64 commit 145b367
Showing 1 changed file with 90 additions and 53 deletions.
143 changes: 90 additions & 53 deletions R/plot_corr.R
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#'
#' @param data A dataset of class `data.frame`, `tibble`, or `matrix`.
#' @param vrb String, vector, or unquoted expression with variable name(s), default is "all".
#' @param diff Logical indicating whether the difference between the observed and imputed data is plotted.
#' @param label Logical indicating whether correlation values should be displayed.
#' @param square Logical indicating whether the plot tiles should be squares.
#' @param diagonal Logical indicating whether the correlation of each variable with itself should be displayed.
Expand Down Expand Up @@ -31,103 +32,139 @@
plot_corr <-
function(data,
vrb = "all",
diff = FALSE,
label = FALSE,
square = TRUE,
diagonal = FALSE,
rotate = FALSE,
caption = TRUE) {
# process inputs
if (is.matrix(data) && ncol(data) > 1) {
data <- as.data.frame(data)
}
verify_data(data = data, df = TRUE, imp = TRUE)
if (diff && !mice::is.mids(data)) {
cli::cli_abort("Difference in correlations can only be computed with imputed data.")
}
vrb <- rlang::enexpr(vrb)
vrbs_in_data <- if (mice::is.mids(data)) {
names(data$imp)
if (mice::is.mids(data)) {
imp <- TRUE
mids <- data
data <- data$data
} else {
names(data)
imp <- FALSE
}
vrbs_in_data <- names(data)
vrb_matched <- match_vrb(vrb, vrbs_in_data)
# check if any column is constant
constants <- apply(data[, vrb_matched], MARGIN = 2, function(x) {
all(is.na(x)) || max(x, na.rm = TRUE) == min(x, na.rm = TRUE)
})
if (any(constants)) {
vrb_matched <- vrb_matched[!constants]
cli::cli_inform(
c(
"No correlations computed for variable(s):",
" " = paste(names(constants[which(constants)]), collapse = ", "),
"i" = "Correlations are undefined for constants."
)
)
}
if (length(vrb_matched) < 2) {
cli::cli_abort("The number of variables should be two or more to compute correlations.")
}
if (is.data.frame(data)) {
# for data: check if any column is constant
constants <- apply(data[, vrb_matched], MARGIN = 2, function(x) {
all(is.na(x)) || max(x, na.rm = TRUE) == min(x, na.rm = TRUE)
})
if (any(constants)) {
vrb_matched <- vrb_matched[!constants]
cli::cli_inform(
c(
"No correlations computed for variable(s):",
" " = paste(names(constants[which(constants)]), collapse = ", "),
"i" = "Correlations are undefined for constants."
)
)
}
# compute correlations
corr <- stats::cov2cor(stats::cov(
data.matrix(data[, vrb_matched]),
use = "pairwise.complete.obs"
))
# create plot labels
lab_x <- "Imputation model predictor"
lab_y <- "Imputation target"
lab_fill <- "Correlation*
"
if (!imp) {
lab_note <- "*pairwise complete observations"
}
if (mice::is.mids(data)) {
# check constatnts etc.
imps <- mice::complete(data, "all")
if (imp) {
lab_note <- "*pooled across imputations"
}
if (diff) {
lab_fill <- "Difference in correlations*
"
lab_note <- "*observed minus imputed (pooled across imputations)"
}
# compute correlations
if (!imp | diff) {

Check warning on line 93 in R/plot_corr.R

View workflow job for this annotation

GitHub Actions / lint

file=R/plot_corr.R,line=93,col=14,[vector_logic_linter] Conditional expressions require scalar logical operators (&& and ||)
corr <- stats::cov2cor(stats::cov(data.matrix(data[, vrb_matched]), use = "pairwise.complete.obs"))
}
if (imp) {
imps <- mice::complete(mids, "all")
corrs <- purrr::map(imps, ~ {
stats::cor(.x)
stats::cov2cor(stats::cov(data.matrix(.x[, vrb_matched]), use = "pairwise.complete.obs"))
})
corr <- Reduce("+", corrs) / length(corrs)
if (diff) {
corr <- corr - (Reduce("+", corrs) / length(corrs))
} else {
corr <- Reduce("+", corrs) / length(corrs)
}
}
# convert correlations into plotting object
p <- length(vrb_matched)
long <- data.frame(
vrb = rep(vrb_matched, each = p),
prd = vrb_matched,
corr = matrix(round(corr, 2), nrow = p * p, byrow = TRUE)
)
if (!diff) {
long$text <- long$corr
}
if (diff) {
long$text <- sprintf("%+.2f", long$corr)
long$text[long$corr < 0] <- long$corr[long$corr < 0]
}
if (!diagonal) {
long[long$vrb == long$prd, "corr"] <- NA
long[long$vrb == long$prd, "text"] <- ""
}
# create plot
gg <-
ggplot2::ggplot(long,
ggplot2::aes(
x = .data$prd,
y = .data$vrb,
label = .data$corr,
label = .data$text,
fill = .data$corr
)) +
ggplot2::geom_tile(color = "black", alpha = 0.6) +
ggplot2::scale_x_discrete(limits = vrb_matched, position = "top") +
ggplot2::scale_y_discrete(limits = rev(vrb_matched)) +
ggplot2::scale_fill_gradient2(
low = ggplot2::alpha("deepskyblue", 0.6),
mid = "lightyellow",
high = ggplot2::alpha("orangered", 0.6),
na.value = "grey90",
limits = c(-1, 1)
) +
ggplot2::labs(
x = lab_x,
y = lab_y,
fill = lab_fill,
caption = lab_note
) +
theme_minimice()
lab_x <- "Imputation model predictor"
if (mice::is.mids(data)) {
lab_y <- "Column name"
lab_note <- "*pooled across imputations"
} else {
lab_y <- "Column name"
lab_note <- "*pairwise complete observations"
}
if (caption) {
# edit plot to match function arguments
if (!diff) {
gg <- gg +
ggplot2::labs(
x = lab_x,
y = lab_y,
fill = "Correlation*
",
caption = lab_note
ggplot2::scale_fill_gradient2(
low = ggplot2::alpha("deepskyblue", 0.6),
mid = "lightyellow",
high = ggplot2::alpha("orangered", 0.6),
na.value = "grey90",
limits = c(-1, 1)
)
} else {
}
if (diff) {
gg <- gg + ggplot2::scale_fill_gradient2(
low = ggplot2::alpha("deepskyblue", 1),
mid = "lightyellow",
high = ggplot2::alpha("orangered", 1),
na.value = "grey90",
limits = c(-2, 2)
)
}
if (!caption) {
lab_fill <- substring(lab_fill, 1, nchar(lab_fill) - 2)
gg <- gg +
ggplot2::labs(x = lab_x, y = lab_y, fill = "Correlation")
ggplot2::labs(fill = lab_fill, caption = NULL)
}
if (label) {
gg <-
Expand Down

0 comments on commit 145b367

Please sign in to comment.