From 015dc649fd7eb24ff05bf6a024879e83b3ef21e5 Mon Sep 17 00:00:00 2001 From: hanneoberman Date: Wed, 16 Oct 2024 12:36:23 +0200 Subject: [PATCH] make `plot_corr()` work on `mids` objects --- R/plot_corr.R | 79 ++++++++++++++++++++++++++++++++------------------- 1 file changed, 50 insertions(+), 29 deletions(-) diff --git a/R/plot_corr.R b/R/plot_corr.R index 2d290e0..e50a3b6 100644 --- a/R/plot_corr.R +++ b/R/plot_corr.R @@ -39,45 +39,58 @@ plot_corr <- if (is.matrix(data) && ncol(data) > 1) { data <- as.data.frame(data) } - verify_data(data = data, df = TRUE) + verify_data(data = data, df = TRUE, imp = TRUE) vrb <- rlang::enexpr(vrb) - vrb_matched <- match_vrb(vrb, names(data)) + vrbs_in_data <- if (mice::is.mids(data)) { + names(data$imp) + } else { + names(data) + } + vrb_matched <- match_vrb(vrb, vrbs_in_data) if (length(vrb_matched) < 2) { cli::cli_abort("The number of variables should be two or more to compute correlations.") } - # check if any column is constant - constants <- apply(data[, vrb_matched], MARGIN = 2, function(x) { - all(is.na(x)) || max(x, na.rm = TRUE) == min(x, na.rm = TRUE) - }) - if (any(constants)) { - vrb_matched <- vrb_matched[!constants] - cli::cli_inform( - c( - "No correlations computed for variable(s):", - " " = paste(names(constants[which(constants)]), collapse = ", "), - "i" = "Correlations are undefined for constants." + if (is.data.frame(data)) { + # for data: check if any column is constant + constants <- apply(data[, vrb_matched], MARGIN = 2, function(x) { + all(is.na(x)) || max(x, na.rm = TRUE) == min(x, na.rm = TRUE) + }) + if (any(constants)) { + vrb_matched <- vrb_matched[!constants] + cli::cli_inform( + c( + "No correlations computed for variable(s):", + " " = paste(names(constants[which(constants)]), collapse = ", "), + "i" = "Correlations are undefined for constants." + ) ) - ) + } + # compute correlations + corr <- stats::cov2cor(stats::cov( + data.matrix(data[, vrb_matched]), + use = "pairwise.complete.obs" + )) + } + if (mice::is.mids(data)) { + # check constatnts etc. + imps <- mice::complete(data, "all") + corrs <- purrr::map(imps, ~ { + stats::cor(.x) + }) + corr <- Reduce("+", corrs) / length(corrs) } - # compute correlations p <- length(vrb_matched) - corrs <- data.frame( + long <- data.frame( vrb = rep(vrb_matched, each = p), prd = vrb_matched, - corr = matrix( - round(stats::cov2cor( - stats::cov(data.matrix(data[, vrb_matched]), use = "pairwise.complete.obs") - ), 2), - nrow = p * p, - byrow = TRUE - ) + corr = matrix(round(corr, 2), nrow = p * p, byrow = TRUE) ) if (!diagonal) { - corrs[corrs$vrb == corrs$prd, "corr"] <- NA + long[long$vrb == long$prd, "corr"] <- NA } # create plot gg <- - ggplot2::ggplot(corrs, + ggplot2::ggplot(long, ggplot2::aes( x = .data$prd, y = .data$vrb, @@ -95,18 +108,26 @@ plot_corr <- limits = c(-1, 1) ) + theme_minimice() + lab_x <- "Imputation model predictor" + if (mice::is.mids(data)) { + lab_y <- "Column name" + lab_note <- "*pooled across imputations" + } else { + lab_y <- "Column name" + lab_note <- "*pairwise complete observations" + } if (caption) { gg <- gg + ggplot2::labs( - x = "Imputation model predictor", - y = "Variable to impute", + x = lab_x, + y = lab_y, fill = "Correlation* ", - caption = "*pairwise complete observations" + caption = lab_note ) } else { gg <- gg + - ggplot2::labs(x = "Imputation model predictor", y = "Variable to impute", fill = "Correlation") + ggplot2::labs(x = lab_x, y = lab_y, fill = "Correlation") } if (label) { gg <-