From e8b0d0b27a5640e26a6fa279d38a0c5333dab960 Mon Sep 17 00:00:00 2001
From: pepijnvink <pepijnvink@gmail.com>
Date: Thu, 16 Nov 2023 13:19:17 +0100
Subject: [PATCH] add option to make square

---
 R/plot_miss.R                     | 11 ++++-
 R/plot_variance.R                 | 71 +++++++++++++++++++++++++++++++
 man/plot_miss.Rd                  |  3 ++
 man/plot_variance.Rd              | 29 +++++++++++++
 tests/testthat/test-plot_miss.R.R |  6 +--
 5 files changed, 114 insertions(+), 6 deletions(-)
 create mode 100644 R/plot_variance.R
 create mode 100644 man/plot_variance.Rd

diff --git a/R/plot_miss.R b/R/plot_miss.R
index 53b935e0..2bd70218 100644
--- a/R/plot_miss.R
+++ b/R/plot_miss.R
@@ -5,6 +5,7 @@
 #' @param border Logical indicating whether borders should be present between tiles.
 #' @param row.breaks Optional numeric input specifying the number of breaks to be visualized on the y axis.
 #' @param ordered Logical indicating whether rows should be ordered according to their pattern.
+#' @param square  Logical indicating whether the plot tiles should be squares, defaults to squares.
 #'
 #' @return An object of class [ggplot2::ggplot].
 #'
@@ -17,6 +18,7 @@ plot_miss <-
            vrb = "all",
            border = FALSE,
            row.breaks = nrow(data),
+           square = TRUE,
            ordered = FALSE) {
     # input processing
     if (is.matrix(data) && ncol(data) > 1) {
@@ -40,7 +42,7 @@ plot_miss <-
     if(ordered){
       # extract md.pattern matrix
       mdpat <- mice::md.pattern(data, plot = FALSE) %>%
-        head(., -1)
+        utils::head(., -1)
       # save frequency of patterns
       freq.pat <- rownames(mdpat) %>%
         as.numeric()
@@ -103,13 +105,18 @@ plot_miss <-
         fill = "",
         alpha = ""
       ) +
-      ggplot2::coord_cartesian(expand = FALSE) +
       theme_minimice()
+    # additional arguments
     if(border){
       gg <- gg + ggplot2::geom_tile(color = "black")
     } else{
       gg <- gg + ggplot2::geom_tile()
     }
+    if (square) {
+      gg <- gg + ggplot2::coord_fixed(expand = FALSE)
+    } else {
+      gg <- gg + ggplot2::coord_cartesian(expand = FALSE)
+    }
     if(ordered){
       gg <- gg +
         ggplot2::theme(axis.text.y = ggplot2::element_blank(),
diff --git a/R/plot_variance.R b/R/plot_variance.R
new file mode 100644
index 00000000..ff63b6ab
--- /dev/null
+++ b/R/plot_variance.R
@@ -0,0 +1,71 @@
+#' Plot the scaled between imputation variance for every cell as a heatmap
+#'
+#' This function plots the cell-level between imputation variance. The function
+#' scales the variances column-wise, without centering cf. `base::scale(center = FALSE)`
+#' and plots the data image as a heatmap. Darker red cells indicate more variance,
+#' lighter cells indicate less variance. White cells represent observed cells or unobserved cells with zero between
+#' imputation variance.
+#'
+#' @param data A package `mice` generated multiply imputed data set of class
+#' `mids`. Non-`mids` objects that have not been generated with `mice::mice()`
+#' can be converted through a pipeline with `mice::as.mids()`.
+#' @param grid Logical indicating whether grid lines should be displayed.
+#'
+#' @return An object of class `ggplot`.
+#' @examples
+#' imp <- mice::mice(mice::nhanes, printFlag = FALSE)
+#' plot_variance(imp)
+#' @export
+plot_variance <- function(data, grid = TRUE) {
+  # input checks; verify_data() is a helper defined outside this file
+  verify_data(data, imp = TRUE)
+  if (data$m < 2) {
+    cli::cli_abort(
+      c(
+        "The between imputation variance cannot be computed if there are fewer than two imputations (m < 2).",
+        "i" = "Please provide an object with 2 or more imputations"
+      )
+    )
+  }
+  # tile outline colour: black grid lines, or NA when grid is FALSE
+  tile_outline <- if (grid) "black" else NA
+  # variance of every cell across imputations (factors go through as.numeric)
+  cell_var <- mice::complete(data, "long") %>%
+    dplyr::mutate(dplyr::across(where(is.factor), as.numeric)) %>%
+    dplyr::select(-.imp) %>%
+    dplyr::group_by(.id) %>%
+    dplyr::summarise(dplyr::across(dplyr::everything(), stats::var)) %>%
+    dplyr::ungroup()
+  # scale each column, then reshape to one row per (row, column) cell
+  tiles <- cell_var %>%
+    dplyr::mutate(dplyr::across(.cols = -.id, ~ scale_above_zero(.))) %>%
+    tidyr::pivot_longer(cols = -.id)
+  # legend title: ends in a line break plus spaces, kept verbatim
+  legend_title <- "Imputation variability*
+      "
+  gg <- ggplot2::ggplot(tiles, ggplot2::aes(name, .id, fill = value)) +
+    ggplot2::geom_tile(color = tile_outline) +
+    ggplot2::scale_fill_gradient(low = "white", high = mice::mdc(2)) +
+    ggplot2::labs(
+      x = "Column name",
+      y = "Row number",
+      fill = legend_title,
+      caption = "*scaled cell-level between imputation variance"
+    ) +
+    ggplot2::scale_x_discrete(position = "top", expand = c(0, 0)) +
+    ggplot2::scale_y_continuous(trans = "reverse", expand = c(0, 0)) +
+    theme_minimice()
+
+  # with the grid off, draw a border around the panel instead
+  if (!grid) {
+    gg <- gg + ggplot2::theme(panel.border = ggplot2::element_rect(fill = NA))
+  }
+  gg
+}
+
+# Scale the non-zero, non-missing entries of x with scale(center = FALSE); zeros and NAs are left untouched. Returns a matrix.
+scale_above_zero <- function(x) {
+  x <- as.matrix(x)
+  nz <- which(x != 0) # which() drops NA; an NA subscript would make `[<-` fail
+  replace(x, nz, scale(x[nz], center = FALSE))
+}
diff --git a/man/plot_miss.Rd b/man/plot_miss.Rd
index 246d1d4b..c1103c89 100644
--- a/man/plot_miss.Rd
+++ b/man/plot_miss.Rd
@@ -9,6 +9,7 @@ plot_miss(
   vrb = "all",
   border = FALSE,
   row.breaks = nrow(data),
+  square = TRUE,
   ordered = FALSE
 )
 }
@@ -21,6 +22,8 @@ plot_miss(
 
 \item{row.breaks}{Optional numeric input specifying the number of breaks to be visualized on the y axis.}
 
+\item{square}{Logical indicating whether the plot tiles should be squares, defaults to squares.}
+
 \item{ordered}{Logical indicating whether rows should be ordered according to their pattern.}
 }
 \value{
diff --git a/man/plot_variance.Rd b/man/plot_variance.Rd
new file mode 100644
index 00000000..ade26899
--- /dev/null
+++ b/man/plot_variance.Rd
@@ -0,0 +1,29 @@
+% Generated by roxygen2: do not edit by hand
+% Please edit documentation in R/plot_variance.R
+\name{plot_variance}
+\alias{plot_variance}
+\title{Plot the scaled between imputation variance for every cell as a heatmap}
+\usage{
+plot_variance(data, grid = TRUE)
+}
+\arguments{
+\item{data}{A package \code{mice} generated multiply imputed data set of class
+\code{mids}. Non-\code{mids} objects that have not been generated with \code{mice::mice()}
+can be converted through a pipeline with \code{mice::as.mids()}.}
+
+\item{grid}{Logical indicating whether grid lines should be displayed.}
+}
+\value{
+An object of class \code{ggplot}.
+}
+\description{
+This function plots the cell-level between imputation variance. The function
+scales the variances column-wise, without centering cf. \code{base::scale(center = FALSE)}
+and plots the data image as a heatmap. Darker red cells indicate more variance,
+lighter cells indicate less variance. White cells represent observed cells or unobserved cells with zero between
+imputation variance.
+}
+\examples{
+imp <- mice::mice(mice::nhanes, printFlag = FALSE)
+plot_variance(imp)
+}
diff --git a/tests/testthat/test-plot_miss.R.R b/tests/testthat/test-plot_miss.R.R
index 2fd60495..d08d30d8 100644
--- a/tests/testthat/test-plot_miss.R.R
+++ b/tests/testthat/test-plot_miss.R.R
@@ -4,7 +4,7 @@ dat <- mice::nhanes
 # tests
 test_that("plot_miss produces plot", {
   expect_s3_class(plot_miss(dat), "ggplot")
-  expect_s3_class(plot_miss(dat), "ggplot")
+  expect_s3_class(plot_miss(dat, border = TRUE, ordered = TRUE, row.breaks = 25, square = TRUE), "ggplot")
   expect_s3_class(plot_miss(cbind(dat, "testvar" = NA)), "ggplot")
 })
 
@@ -17,10 +17,8 @@ test_that("plot_miss works with different inputs", {
 
 
 test_that("plot_miss with incorrect argument(s)", {
-  expect_output(plot_miss(na.omit(dat)))
+  expect_s3_class(plot_miss(na.omit(dat)), "ggplot") # rows with missing values removed: a ggplot is still expected
   expect_error(plot_miss("test"))
   expect_error(plot_miss(dat, vrb = "test"))
-  expect_error(plot_miss(dat, cluster = "test"))
   expect_error(plot_miss(cbind(dat, .x = NA)))
-  expect_error(plot_miss(dat, npat = "test"))
 })