From c602d43f49689560ac37278e39c61775c165a204 Mon Sep 17 00:00:00 2001 From: Oscar Dowson Date: Sat, 19 Oct 2024 19:09:22 +1300 Subject: [PATCH] Install torch in GitHub actions (#136) --- .github/workflows/ci.yml | 13 +++++++++++++ .github/workflows/documentation.yml | 5 ++++- docs/make.jl | 7 ++----- docs/src/tutorials/pytorch.jl | 17 +++++++++++++---- test/Project.toml | 1 - test/test_PythonCall.jl | 13 ++++++++----- 6 files changed, 40 insertions(+), 16 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 0414e87..95fc321 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -22,8 +22,21 @@ jobs: - version: '1' os: ubuntu-latest arch: x64 + env: + JULIA_CONDAPKG_BACKEND: "Null" + JULIA_PYTHONCALL_EXE: "python3" steps: - uses: actions/checkout@v4 + # Install pytorch + - name: Set up Python + uses: actions/setup-python@v2 + with: + python-version: '3.10' + - name: Install pytorch + run: | + pip3 install torch --index-url https://download.pytorch.org/whl/cpu + pip3 install numpy + # Install Julia - uses: julia-actions/setup-julia@v2 with: version: ${{ matrix.version }} diff --git a/.github/workflows/documentation.yml b/.github/workflows/documentation.yml index bf050c5..6052a76 100644 --- a/.github/workflows/documentation.yml +++ b/.github/workflows/documentation.yml @@ -15,6 +15,8 @@ jobs: env: GKSwstype: nul DATADEPS_ALWAYS_ACCEPT: true + JULIA_CONDAPKG_BACKEND: "Null" + JULIA_PYTHONCALL_EXE: "python3" steps: - uses: actions/checkout@v4 # Install pytorch @@ -24,7 +26,8 @@ jobs: python-version: '3.10' - name: Install pytorch run: | - pip3 install torch --index-url https://download.pytorch.org/whl/cpu + pip3 install torch --index-url https://download.pytorch.org/whl/cpu + pip3 install numpy # Install Julia - uses: julia-actions/setup-julia@latest with: diff --git a/docs/make.jl b/docs/make.jl index bd6f4ca..b4309f5 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -37,9 +37,6 @@ function _literate_directory(dir) rm(filename) end for filename in _file_list(dir, dir, ".jl") - if endswith(filename, "pytorch.jl") - continue # Skip for now - end # `include` the file to test it before `#src` lines are removed. It is # in a testset to isolate local variables between files. Test.@testset "$(filename)" begin @@ -80,14 +77,14 @@ Documenter.makedocs(; "manual/Flux.md", "manual/GLM.md", "manual/Lux.md", - # "manual/PyTorch.md", + "manual/PyTorch.md", ], "Tutorials" => [ "tutorials/student_enrollment.md", "tutorials/decision_trees.md", "tutorials/mnist.md", "tutorials/mnist_lux.md", - # "tutorials/pytorch.md", + "tutorials/pytorch.md", "tutorials/gaussian.md", ], "Developers" => ["developers/design_principles.md"], diff --git a/docs/src/tutorials/pytorch.jl b/docs/src/tutorials/pytorch.jl index 3c772a7..daf0f40 100644 --- a/docs/src/tutorials/pytorch.jl +++ b/docs/src/tutorials/pytorch.jl @@ -18,15 +18,24 @@ # over how to link Julia to an existing Python environment. For example, if you # have an existing Python installation (with PyTorch installed), and it is # available in the current conda environment, set: - -ENV["JULIA_CONDAPKG_BACKEND"] = "Current" - +# +# ```julia +# ENV["JULIA_CONDAPKG_BACKEND"] = "Current" +# ``` +# # before importing PythonCall.jl. If the Python installation can be found on # the path and it is not in a conda environment, set: - +# # ```julia # ENV["JULIA_CONDAPKG_BACKEND"] = "Null" # ``` +# +# If `python` is not on your path, you may additionally need to set +# `JULIA_PYTHONCALL_EXE`, for example, to: +# +# ```julia +# ENV["JULIA_PYTHONCALL_EXE"] = "python3" +# ``` # ## Required packages diff --git a/test/Project.toml b/test/Project.toml index d851cbd..d292d27 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -1,7 +1,6 @@ [deps] AbstractGPs = "99985d1d-32ba-4be9-9821-2ec096f28918" CSV = "336ed68f-0bac-5ca0-87d4-7b16caf5d00b" -CondaPkg = "992eb4ea-22a4-4c89-a5bb-47a3300528ab" DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" DecisionTree = "7806a523-6efd-50cb-b5f6-3fa6f1930dbb" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" diff --git a/test/test_PythonCall.jl b/test/test_PythonCall.jl index baaa873..cc2ef18 100644 --- a/test/test_PythonCall.jl +++ b/test/test_PythonCall.jl @@ -17,11 +17,14 @@ import PythonCall is_test(x) = startswith(string(x), "test_") function runtests() - try - PythonCall.pyimport("torch") - catch - @warn("Skipping PythonCall tests because we cannot import PyTorch.") - return + # If we're running the tests locally, allow skipping Python tests + if get(ENV, "CI", "false") == "false" + try + PythonCall.pyimport("torch") + catch + @warn("Skipping PythonCall tests because we cannot import PyTorch.") + return + end end @testset "$name" for name in filter(is_test, names(@__MODULE__; all = true)) getfield(@__MODULE__, name)()