-
Notifications
You must be signed in to change notification settings - Fork 4
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Changed to support TensorFlow Probability
- Loading branch information
Jack Baker
committed
Sep 12, 2018
1 parent
fd274fc
commit 2f8ee97
Showing
37 changed files
with
288 additions
and
80 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,7 +1,7 @@ | ||
Package: sgmcmc | ||
Type: Package | ||
Title: Stochastic Gradient Markov Chain Monte Carlo | ||
Version: 0.2.2 | ||
Version: 0.2.3 | ||
Authors@R: c( | ||
person("Jack", "Baker", email = "[email protected]", role = c("aut", "cre", "cph")), | ||
person( "Christopher", "Nemeth", role = c("aut", "cph") ), | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
# Check tensorflow is installed. If it isn't throw an error. | ||
checkTFInstall <- function() { | ||
if ( !get("TF", envir = tf_status) ) { | ||
stop(tfErrorMsg(), call. = FALSE) | ||
} else if ( !get("TFP", envir = tf_status) ) { | ||
stop(tfpErrorMsg(), call. = FALSE) | ||
} | ||
} | ||
|
||
# If there is an error building the posterior print a hopefully more readable error message | ||
getPosteriorBuildError <- function(e) { | ||
stop(buildErrorMsg(e), call. = FALSE) | ||
} | ||
|
||
|
||
tfErrorMsg <- function() { | ||
msg <- "\nNo TensorFlow python installation found.\n" | ||
msg <- paste0(msg, "This can be installed using the installTF() function.\n") | ||
return(msg) | ||
} | ||
|
||
|
||
tfpErrorMsg <- function() { | ||
msg <- "\nNo TensorFlow Probability python installation found.\n" | ||
msg <- paste0(msg, "This can be installed using the installTF() function.\n") | ||
return(msg) | ||
} | ||
|
||
|
||
buildErrorMsg = function(e) { | ||
msg <- "Problem building log posterior estimate from supplied logLik and logPrior functions.\n\n" | ||
msg <- paste0(msg, "Python error output:\n", e) | ||
msg <- paste0(msg, "\n", | ||
"Check your tensorflow code specifying the logLik and logPrior functions is correct.\n") | ||
msg <- paste0(msg, "Ensure constants in logLik and logPrior functions are specified as ", | ||
"type float32 using \ntf$constant(.., dtype = tf$float32) -- see the tutorials for some examples.") | ||
return(msg) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,51 +1,52 @@ | ||
# Environment determining status of the TensorFlow Installation. | ||
# This allows a custom error message to be displayed. | ||
tf_status <- new.env() | ||
|
||
|
||
# Load TensorFlow Probability and add the contents to tf$distributions. | ||
.onLoad <- function(libname, pkgname) { | ||
# Check TensorFlow is installed and load tensorflow_probability | ||
# If either are not installed display a custom install message | ||
# Set tf$distributions to be tfp$distributions | ||
tryCatch(tfp <- loadTF(), | ||
error = function(e) { | ||
tfMissing() | ||
} | ||
) | ||
|
||
# Set default tf_status that everything is installed correctly. | ||
assign("TF", TRUE, envir = tf_status) | ||
assign("TFP", TRUE, envir = tf_status) | ||
# Check TensorFlow is installed. Update tf_status accordingly. | ||
checkTF() | ||
# If checkTF was not successful, return to avoid printing multiple messages | ||
if (!get("TF", envir = tf_status)) { | ||
return() | ||
} | ||
# Check TensorFlow Probability is installed, and load in. Update tf_status accordingly. | ||
tryCatch(loadTFP(), error = function(e) tfpMissing(e)) | ||
} | ||
|
||
|
||
# Build message if TensorFlow Probability missing | ||
tfMissing <- function() { | ||
message("\nNo TensorFlow or TensorFlow Probability python installation found.") | ||
message("This can be installed using the installTF() function.\n") | ||
# Set custom error message incase user still tries to use tf | ||
assign("on_error", function (e) error_fn(e), env = tf) | ||
# Check tensorflow installed by doing a dummy operation that will throw an error | ||
checkTF = function() { | ||
tryCatch(temp <- tf$constant(4), | ||
error = function (e) tfMissing()) | ||
} | ||
|
||
|
||
# Check TensorFlow installed and load TensorFlow probability | ||
loadTF = function() { | ||
# Check tensorflow installed by doing a dummy operation that will throw an error | ||
temp <- tf$constant(4) | ||
|
||
# Delay load tensorflow_probability as tfp using reticulate package. | ||
tfp <- reticulate::import("tensorflow_probability", delay_load = list( | ||
priority = 5, | ||
environment = "r-tensorflow" | ||
)) | ||
|
||
# Set tfp$distributions to be tf$distributions | ||
# Load tensorflow probability and assign distns to tf$distributions. | ||
# If this fails, print message and update tf_status | ||
loadTFP <- function() { | ||
import_opts <- list(priority = 5, environment = "r-tensorflow") | ||
tfp <- reticulate::import("tensorflow_probability", delay_load = import_opts) | ||
tf$distributions <- tfp$distributions | ||
} | ||
|
||
|
||
error_fn = function(e) { | ||
stop(tfErrorMsg(), call. = FALSE) | ||
# Build message if TensorFlow missing. Update tf_status | ||
tfMissing <- function() { | ||
message("\nNo TensorFlow python installation found.") | ||
message("This can be installed using the installTF() function.\n") | ||
assign("TF", FALSE, envir = tf_status) | ||
assign("TFP", FALSE, envir = tf_status) | ||
} | ||
|
||
|
||
# Build error message for TensorFlow configuration errors | ||
tfErrorMsg <- function() { | ||
message <- "Installation of TensorFlow or TensorFlow Probability not found.\n" | ||
message <- paste0(message, | ||
"These can be installed by running the installTF() function.") | ||
return(message) | ||
# Build message if TensorFlow Probability missing. Update tf_status | ||
tfpMissing <- function(e) { | ||
message("\nNo TensorFlow Probability python installation found.") | ||
message("This can be installed using the installTF() function.\n") | ||
assign("TFP", FALSE, envir = tf_status) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Oops, something went wrong.
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Oops, something went wrong.
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Oops, something went wrong.
Oops, something went wrong.