Skip to content

Commit

Permalink
make plot_corr() work on mids objects
Browse files Browse the repository at this point in the history
  • Loading branch information
hanneoberman committed Oct 16, 2024
1 parent 41222e3 commit 015dc64
Showing 1 changed file with 50 additions and 29 deletions.
79 changes: 50 additions & 29 deletions R/plot_corr.R
Original file line number Diff line number Diff line change
Expand Up @@ -39,45 +39,58 @@ plot_corr <-
if (is.matrix(data) && ncol(data) > 1) {
data <- as.data.frame(data)
}
verify_data(data = data, df = TRUE)
verify_data(data = data, df = TRUE, imp = TRUE)
vrb <- rlang::enexpr(vrb)
vrb_matched <- match_vrb(vrb, names(data))
vrbs_in_data <- if (mice::is.mids(data)) {
names(data$imp)
} else {
names(data)
}
vrb_matched <- match_vrb(vrb, vrbs_in_data)
if (length(vrb_matched) < 2) {
cli::cli_abort("The number of variables should be two or more to compute correlations.")
}
# check if any column is constant
constants <- apply(data[, vrb_matched], MARGIN = 2, function(x) {
all(is.na(x)) || max(x, na.rm = TRUE) == min(x, na.rm = TRUE)
})
if (any(constants)) {
vrb_matched <- vrb_matched[!constants]
cli::cli_inform(
c(
"No correlations computed for variable(s):",
" " = paste(names(constants[which(constants)]), collapse = ", "),
"i" = "Correlations are undefined for constants."
if (is.data.frame(data)) {
# for data: check if any column is constant
constants <- apply(data[, vrb_matched], MARGIN = 2, function(x) {
all(is.na(x)) || max(x, na.rm = TRUE) == min(x, na.rm = TRUE)
})
if (any(constants)) {
vrb_matched <- vrb_matched[!constants]
cli::cli_inform(
c(
"No correlations computed for variable(s):",
" " = paste(names(constants[which(constants)]), collapse = ", "),
"i" = "Correlations are undefined for constants."
)
)
)
}
# compute correlations
corr <- stats::cov2cor(stats::cov(
data.matrix(data[, vrb_matched]),
use = "pairwise.complete.obs"
))
}
if (mice::is.mids(data)) {
# check constatnts etc.
imps <- mice::complete(data, "all")
corrs <- purrr::map(imps, ~ {
stats::cor(.x)
})
corr <- Reduce("+", corrs) / length(corrs)
}
# compute correlations
p <- length(vrb_matched)
corrs <- data.frame(
long <- data.frame(
vrb = rep(vrb_matched, each = p),
prd = vrb_matched,
corr = matrix(
round(stats::cov2cor(
stats::cov(data.matrix(data[, vrb_matched]), use = "pairwise.complete.obs")
), 2),
nrow = p * p,
byrow = TRUE
)
corr = matrix(round(corr, 2), nrow = p * p, byrow = TRUE)
)
if (!diagonal) {
corrs[corrs$vrb == corrs$prd, "corr"] <- NA
long[long$vrb == long$prd, "corr"] <- NA
}
# create plot
gg <-
ggplot2::ggplot(corrs,
ggplot2::ggplot(long,
ggplot2::aes(
x = .data$prd,
y = .data$vrb,
Expand All @@ -95,18 +108,26 @@ plot_corr <-
limits = c(-1, 1)
) +
theme_minimice()
lab_x <- "Imputation model predictor"
if (mice::is.mids(data)) {
lab_y <- "Column name"
lab_note <- "*pooled across imputations"
} else {
lab_y <- "Column name"
lab_note <- "*pairwise complete observations"
}
if (caption) {
gg <- gg +
ggplot2::labs(
x = "Imputation model predictor",
y = "Variable to impute",
x = lab_x,
y = lab_y,
fill = "Correlation*
",
caption = "*pairwise complete observations"
caption = lab_note
)
} else {
gg <- gg +
ggplot2::labs(x = "Imputation model predictor", y = "Variable to impute", fill = "Correlation")
ggplot2::labs(x = lab_x, y = lab_y, fill = "Correlation")
}
if (label) {
gg <-
Expand Down

0 comments on commit 015dc64

Please sign in to comment.