diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index a65615f..1ef244d 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -3,7 +3,8 @@ on: push: branches: - main - tags: '*' + tags: + - '*' pull_request: concurrency: # Skip intermediate builds: always. @@ -17,14 +18,15 @@ jobs: strategy: fail-fast: false matrix: - version: - - '1.7' - - '1.8' - - '~1.9.0-0' - os: - - ubuntu-latest - arch: - - x64 + version: ['1.6', '1'] + os: [ubuntu-latest] + arch: [x64] + allow_failure: [false] + include: + - version: 'nightly' + os: ubuntu-latest + arch: x64 + allow_failure: true steps: - uses: actions/checkout@v2 - uses: julia-actions/setup-julia@v1 diff --git a/.gitignore b/.gitignore index 79280c3..2132e19 100644 --- a/.gitignore +++ b/.gitignore @@ -3,4 +3,5 @@ *.jl.mem /Manifest.toml /docs/build/ -docs/src/examples/*.md \ No newline at end of file +docs/src/examples/*.md +test/playground.jl \ No newline at end of file diff --git a/CITATION.bib b/CITATION.bib index a48d829..6e84fc2 100644 --- a/CITATION.bib +++ b/CITATION.bib @@ -2,7 +2,7 @@ @misc{ImplicitDifferentiation.jl author = {Guillaume Dalle, Mohamed Tarek and contributors}, title = {ImplicitDifferentiation.jl}, url = {https://github.com/gdalle/ImplicitDifferentiation.jl}, - version = {v0.3.0}, + version = {v0.4.0}, year = {2023}, - month = {3} + month = {4} } diff --git a/Project.toml b/Project.toml index 55ac717..4429ec2 100644 --- a/Project.toml +++ b/Project.toml @@ -1,44 +1,49 @@ name = "ImplicitDifferentiation" uuid = "57b37032-215b-411a-8a7c-41a003a55207" authors = ["Guillaume Dalle", "Mohamed Tarek and contributors"] -version = "0.3.0" +version = "0.4.0" [deps] +AbstractDifferentiation = "c29ec348-61ec-40c8-8164-b8c60e9d9f3d" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" Krylov = "ba0b0d4f-ebba-5204-a429-3ac8c609bfb7" LinearOperators = "5c8ed15e-5a4c-59e4-a42b-c7e8811fb125" +Requires = "ae029012-a4dd-5104-9daa-d747884805df" [compat] +AbstractDifferentiation = "0.5" ChainRulesCore = "1.14" +ForwardDiff = "0.10" Krylov = "0.8, 0.9" LinearOperators = "2.2" -julia = "1.7" +Requires = "1.3" +julia = "1.6" + +[extensions] +ImplicitDifferentiationChainRulesExt = "ChainRulesCore" +ImplicitDifferentiationForwardDiffExt = "ForwardDiff" [extras] Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a" -ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" -Convex = "f65535da-76fb-5f13-bab9-19810c17039a" -Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7" Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" -FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" -ForwardDiffChainRules = "c9556dd2-1aed-4cfe-8560-1557cf593001" JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b" JuliaFormatter = "98e50ef6-434e-11e9-1051-2b60c6c9e899" Krylov = "ba0b0d4f-ebba-5204-a429-3ac8c609bfb7" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" -LinearOperators = "5c8ed15e-5a4c-59e4-a42b-c7e8811fb125" -MathOptInterface = "b8f27783-ece8-5eb3-8dc8-9495eed66fee" -MathOptSetDistances = "3b969827-a86c-476c-9527-bb6f1a8fbad5" +NLsolve = "2774e3e8-f4cf-5e23-947b-6d7e65073b56" Optim = "429524aa-4258-5aef-a3af-852621145aeb" Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" -SCS = "c946c3f1-0d1f-5ce8-9dea-7daa1f7e2d13" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" Zygote 
= "e88e6eb3-aa80-5325-afca-941959d7151f" [targets] -test = ["Aqua", "ChainRulesCore", "ChainRulesTestUtils", "ComponentArrays", "Convex", "Distances", "Documenter", "FiniteDifferences", "ForwardDiff", "ForwardDiffChainRules", "JET", "JuliaFormatter", "LinearAlgebra", "LinearOperators", "MathOptInterface", "MathOptSetDistances", "Optim", "Pkg", "Random", "SCS", "SparseArrays", "Test", "Zygote"] +test = ["Aqua", "ChainRulesCore", "ChainRulesTestUtils", "Documenter", "ForwardDiff", "JET", "JuliaFormatter", "LinearAlgebra", "NLsolve", "Optim", "Pkg", "Random", "SparseArrays", "Test", "Zygote"] + +[weakdeps] +ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" +ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" diff --git a/README.md b/README.md index 92f7160..a5defc7 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,6 @@ # ImplicitDifferentiation.jl - +[![Stable](https://img.shields.io/badge/docs-stable-blue.svg)](https://gdalle.github.io/ImplicitDifferentiation.jl/stable) [![Dev](https://img.shields.io/badge/docs-dev-blue.svg)](https://gdalle.github.io/ImplicitDifferentiation.jl/dev) [![Build Status](https://github.com/gdalle/ImplicitDifferentiation.jl/actions/workflows/CI.yml/badge.svg?branch=main)](https://github.com/gdalle/ImplicitDifferentiation.jl/actions/workflows/CI.yml?query=branch%3Amain) [![Coverage](https://codecov.io/gh/gdalle/ImplicitDifferentiation.jl/branch/main/graph/badge.svg)](https://codecov.io/gh/gdalle/ImplicitDifferentiation.jl) diff --git a/docs/Manifest.toml b/docs/Manifest.toml index 3a05cd7..8da2095 100644 --- a/docs/Manifest.toml +++ b/docs/Manifest.toml @@ -2,7 +2,7 @@ julia_version = "1.8.5" manifest_format = "2.0" -project_hash = "bc626ac97bad3fec2d3c66010fe004ea07631115" +project_hash = "4c2333aa8abadf9a95e60e340f908dd68a4d2e49" [[deps.AMD]] deps = ["Libdl", "LinearAlgebra", "SparseArrays", "Test"] @@ -15,17 +15,18 @@ git-tree-sha1 = "574baf8110975760d391c710b6341da1afa48d8c" uuid = "a4c015fc-c6ff-483c-b24f-f7ea428134e9" version = "0.0.1" +[[deps.AbstractDifferentiation]] +deps = ["ChainRulesCore", "ExprTools", "LinearAlgebra", "Requires"] +git-tree-sha1 = "f83fd553acff1c6a7f5c4e6f5f2b5941d533cdc9" +uuid = "c29ec348-61ec-40c8-8164-b8c60e9d9f3d" +version = "0.5.2" + [[deps.AbstractFFTs]] deps = ["ChainRulesCore", "LinearAlgebra"] git-tree-sha1 = "16b6dbc4cf7caee4e1e75c49485ec67b667098a0" uuid = "621f4979-c628-5d54-868e-fcf4e3e8185c" version = "1.3.1" -[[deps.AbstractTrees]] -git-tree-sha1 = "faa260e4cb5aba097a73fab382dd4b5819d8ec8c" -uuid = "1520ce14-60c1-5f80-bbc7-55ef81b5835c" -version = "0.4.4" - [[deps.Adapt]] deps = ["LinearAlgebra", "Requires"] git-tree-sha1 = "cc37d689f599e8df4f464b2fa3870ff7db7492ef" @@ -48,24 +49,6 @@ uuid = "56f22d72-fd6d-98f1-02f0-08ddc0907c33" [[deps.Base64]] uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f" -[[deps.BenchmarkTools]] -deps = ["JSON", "Logging", "Printf", "Profile", "Statistics", "UUIDs"] -git-tree-sha1 = "d9a9701b899b30332bbcb3e1679c41cce81fb0e8" -uuid = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf" -version = "1.3.2" - -[[deps.BlockDiagonals]] -deps = ["ChainRulesCore", "FillArrays", "FiniteDifferences", "LinearAlgebra"] -git-tree-sha1 = "ffd635c19b56f50d1d4278d876219644299b5711" -uuid = "0a1fb500-61f7-11e9-3c65-f5ef3456f9f0" -version = "0.1.41" - -[[deps.Bzip2_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] -git-tree-sha1 = "19a35467a82e236ff51bc17a3a44b69ef35185a2" -uuid = "6e34b625-4abd-537c-b88f-471c36dfa7a0" -version = "1.0.8+0" - [[deps.CEnum]] git-tree-sha1 = "eb4cb44a499229b3b8426dcfb5dd85333951ff90" 
uuid = "fa961155-64e5-5f13-b03f-caf6b980ea82" @@ -90,27 +73,10 @@ uuid = "cdddcdb0-9152-4a09-a978-84456f9df70a" version = "1.10.1" [[deps.ChangesOfVariables]] -deps = ["ChainRulesCore", "LinearAlgebra", "Test"] -git-tree-sha1 = "485193efd2176b88e6622a39a246f8c5b600e74e" +deps = ["LinearAlgebra", "Test"] +git-tree-sha1 = "f84967c4497e0e1955f9a582c232b02847c5f589" uuid = "9e997f8a-9a97-42d5-a9f1-ce6bfc15e2c0" -version = "0.1.6" - -[[deps.CodecBzip2]] -deps = ["Bzip2_jll", "Libdl", "TranscodingStreams"] -git-tree-sha1 = "2e62a725210ce3c3c2e1a3080190e7ca491f18d7" -uuid = "523fee87-0ab8-5b00-afb7-3ecf72e48cfd" -version = "0.7.2" - -[[deps.CodecZlib]] -deps = ["TranscodingStreams", "Zlib_jll"] -git-tree-sha1 = "9c209fb7536406834aa938fb149964b985de6c83" -uuid = "944b1d66-785c-5afd-91f1-9de20f533193" -version = "0.7.1" - -[[deps.CommonSolve]] -git-tree-sha1 = "9441451ee712d1aec22edad62db1a9af3dc8d852" -uuid = "38540f10-b2f7-11e9-35d8-d573e4eb0ff2" -version = "0.2.3" +version = "0.1.7" [[deps.CommonSubexpressions]] deps = ["MacroTools", "Test"] @@ -129,24 +95,12 @@ deps = ["Artifacts", "Libdl"] uuid = "e66e0078-7015-5450-92f7-15fbd957f2ae" version = "1.0.1+0" -[[deps.ComponentArrays]] -deps = ["ArrayInterface", "ChainRulesCore", "ConstructionBase", "ForwardDiff", "GPUArrays", "LinearAlgebra", "RecursiveArrayTools", "Requires", "ReverseDiff", "SciMLBase", "StaticArrayInterface", "StaticArrays"] -git-tree-sha1 = "bbc7dd1b536d2950c2a91801283c609ef74e72fe" -uuid = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" -version = "0.13.9" - [[deps.ConstructionBase]] deps = ["LinearAlgebra"] git-tree-sha1 = "89a9db8d28102b094992472d333674bd1a83ce2a" uuid = "187b0558-2788-49d3-abe0-74a17ed4e7c9" version = "1.5.1" -[[deps.Convex]] -deps = ["AbstractTrees", "BenchmarkTools", "LDLFactorizations", "LinearAlgebra", "MathOptInterface", "OrderedCollections", "SparseArrays", "Test"] -git-tree-sha1 = "af4188609c0620ed4b0e4493ed416d3c8b2dadeb" -uuid = "f65535da-76fb-5f13-bab9-19810c17039a" -version = "0.15.3" - [[deps.DataAPI]] git-tree-sha1 = "e8119c1a33d267e16108be441a287a6981ba1630" uuid = "9a962f9c-6df0-11e9-0e5d-c546b8b5ee8a" @@ -179,12 +133,6 @@ git-tree-sha1 = "a4ad7ef19d2cdc2eff57abbbe68032b1cd0bd8f8" uuid = "b552c78f-8df3-52c6-915a-8e097449b14b" version = "1.13.0" -[[deps.DifferentiableFlatten]] -deps = ["ChainRulesCore", "LinearAlgebra", "NamedTupleTools", "OrderedCollections", "Requires", "SparseArrays"] -git-tree-sha1 = "f4dc2c1d994c7e2e602692a7dadd2ac79212c3a9" -uuid = "c78775a3-ee38-4681-b694-0504db4f5dc7" -version = "0.1.1" - [[deps.Distances]] deps = ["LinearAlgebra", "SparseArrays", "Statistics", "StatsAPI"] git-tree-sha1 = "49eba9ad9f7ead780bfb7ee319f962c811c6d3b2" @@ -212,11 +160,6 @@ deps = ["ArgTools", "FileWatching", "LibCURL", "NetworkOptions"] uuid = "f43a241f-c20a-4ad4-852c-f6b1247861c6" version = "1.6.0" -[[deps.EnumX]] -git-tree-sha1 = "bdb1942cd4c45e3c678fd11569d5cccd80976237" -uuid = "4e289a0a-7415-4d19-859d-a7e5c4648b56" -version = "1.0.4" - [[deps.ExprTools]] git-tree-sha1 = "c1d06d129da9f55715c6c212866f5b1bddc5fa00" uuid = "e2ba6199-217a-4e67-a87a-7c52f15ade04" @@ -254,23 +197,6 @@ git-tree-sha1 = "00e252f4d706b3d55a8863432e742bf5717b498d" uuid = "f6369f11-7733-5829-9624-2563aa707210" version = "0.10.35" -[[deps.ForwardDiffChainRules]] -deps = ["ChainRulesCore", "DifferentiableFlatten", "ForwardDiff", "MacroTools"] -git-tree-sha1 = "55d07c37b391e47a2dd254d59f7bc31cc85d74d5" -uuid = "c9556dd2-1aed-4cfe-8560-1557cf593001" -version = "0.2.0" - -[[deps.FunctionWrappers]] -git-tree-sha1 = 
"d62485945ce5ae9c0c48f124a84998d755bae00e" -uuid = "069b7b12-0de2-55c6-9aab-29f3d0a68a2e" -version = "1.1.3" - -[[deps.FunctionWrappersWrappers]] -deps = ["FunctionWrappers"] -git-tree-sha1 = "b104d487b34566608f8b4e1c39fb0b10aa279ff8" -uuid = "77dc65aa-8811-40c2-897b-53d922fa7daf" -version = "0.1.3" - [[deps.Future]] deps = ["Random"] uuid = "9fa8497b-333b-5362-9e8d-4d0656e87820" @@ -299,22 +225,11 @@ git-tree-sha1 = "0ade27f0c49cebd8db2523c4eeccf779407cf12c" uuid = "7869d1d1-7146-5819-86e3-90919afe41df" version = "0.4.9" -[[deps.IfElse]] -git-tree-sha1 = "debdd00ffef04665ccbb3e150747a77560e8fad1" -uuid = "615f187c-cbe4-4ef1-ba3b-2fcf58d6d173" -version = "0.1.1" - [[deps.ImplicitDifferentiation]] -deps = ["ChainRulesCore", "Krylov", "LinearOperators"] -git-tree-sha1 = "25ae10b4942b9405a0ca63524270ba8b6da8694f" +deps = ["AbstractDifferentiation", "ChainRulesCore", "Krylov", "LinearOperators", "Requires"] +path = ".." uuid = "57b37032-215b-411a-8a7c-41a003a55207" -version = "0.3.0" - -[[deps.IntelOpenMP_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] -git-tree-sha1 = "0cb9352ef2e01574eeebdb102948a58740dcaf83" -uuid = "1d5cc7b8-4909-519e-a0f8-d0f5ad9712d0" -version = "2023.1.0+0" +version = "0.4.0" [[deps.InteractiveUtils]] deps = ["Markdown"] @@ -322,9 +237,9 @@ uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240" [[deps.InverseFunctions]] deps = ["Test"] -git-tree-sha1 = "49510dfcb407e572524ba94aeae2fced1f3feb0f" +git-tree-sha1 = "6667aadd1cdee2c6cd068128b3d226ebc4fb0c67" uuid = "3587e190-3f89-42d0-90ee-14403ec27112" -version = "0.1.8" +version = "0.1.9" [[deps.IrrationalConstants]] git-tree-sha1 = "630b497eafcc20001bba38a4651b327dcfc491d2" @@ -372,12 +287,6 @@ git-tree-sha1 = "09b7505cc0b1cee87e5d4a26eea61d2e1b0dcd35" uuid = "dad2f222-ce93-54a1-a47d-0025e8a3acab" version = "0.0.21+0" -[[deps.Lazy]] -deps = ["MacroTools"] -git-tree-sha1 = "1370f8202dac30758f3c345f9909b97f53d87d3f" -uuid = "50d2b5c4-7a5e-59d5-8109-a42b560f39c0" -version = "0.15.1" - [[deps.LazyArtifacts]] deps = ["Artifacts", "Pkg"] uuid = "4af54fe1-eca0-43a8-85a7-787d91b784e3" @@ -435,12 +344,6 @@ version = "0.3.23" [[deps.Logging]] uuid = "56ddb016-857b-54e1-b83d-db4d58db5568" -[[deps.MKL_jll]] -deps = ["Artifacts", "IntelOpenMP_jll", "JLLWrappers", "LazyArtifacts", "Libdl", "Pkg"] -git-tree-sha1 = "2ce8695e1e699b68702c03402672a69f54b8aca9" -uuid = "856f044c-d86e-5d09-b602-aeab76dc8ba7" -version = "2022.2.0+0" - [[deps.MacroTools]] deps = ["Markdown", "Random"] git-tree-sha1 = "42324d08725e200c23d4dfb549e0d5d89dede2d2" @@ -451,18 +354,6 @@ version = "0.5.10" deps = ["Base64"] uuid = "d6f4376e-aef5-505a-96c1-9c027394607a" -[[deps.MathOptInterface]] -deps = ["BenchmarkTools", "CodecBzip2", "CodecZlib", "DataStructures", "ForwardDiff", "JSON", "LinearAlgebra", "MutableArithmetics", "NaNMath", "OrderedCollections", "Printf", "SnoopPrecompile", "SparseArrays", "SpecialFunctions", "Test", "Unicode"] -git-tree-sha1 = "3b38f6fbd62cbd61d8dbf625136d7b75478bf2c5" -uuid = "b8f27783-ece8-5eb3-8dc8-9495eed66fee" -version = "1.15.0" - -[[deps.MathOptSetDistances]] -deps = ["BlockDiagonals", "ChainRulesCore", "FillArrays", "LinearAlgebra", "MathOptInterface", "StaticArrays"] -git-tree-sha1 = "9511d5196e08b10e25ca361735fc523951204b4d" -uuid = "3b969827-a86c-476c-9527-bb6f1a8fbad5" -version = "0.2.7" - [[deps.MbedTLS_jll]] deps = ["Artifacts", "Libdl"] uuid = "c8ffd9c3-330d-5841-b78e-0817d7145fa1" @@ -481,39 +372,28 @@ uuid = "a63ad114-7e13-5084-954f-fe012c677804" uuid = "14a3606d-f60d-562e-9121-12d972cd8159" version = "2022.2.1" 
-[[deps.MutableArithmetics]] -deps = ["LinearAlgebra", "SparseArrays", "Test"] -git-tree-sha1 = "3295d296288ab1a0a2528feb424b854418acff57" -uuid = "d8a4904e-b15c-11e9-3269-09a3773c0cb0" -version = "1.2.3" - [[deps.NLSolversBase]] deps = ["DiffResults", "Distributed", "FiniteDiff", "ForwardDiff"] git-tree-sha1 = "a0b464d183da839699f4c79e7606d9d186ec172c" uuid = "d41bc354-129a-5804-8e4c-c37616107c6c" version = "7.8.3" +[[deps.NLsolve]] +deps = ["Distances", "LineSearches", "LinearAlgebra", "NLSolversBase", "Printf", "Reexport"] +git-tree-sha1 = "019f12e9a1a7880459d0173c182e6a99365d7ac1" +uuid = "2774e3e8-f4cf-5e23-947b-6d7e65073b56" +version = "4.5.1" + [[deps.NaNMath]] deps = ["OpenLibm_jll"] git-tree-sha1 = "0877504529a3e5c3343c6f8b4c0381e57e4387e4" uuid = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3" version = "1.0.2" -[[deps.NamedTupleTools]] -git-tree-sha1 = "90914795fc59df44120fe3fff6742bb0d7adb1d0" -uuid = "d9ec5142-1e00-5aa0-9d6a-321866360f50" -version = "0.14.3" - [[deps.NetworkOptions]] uuid = "ca575930-c2e3-43a9-ace4-1e988b2c1908" version = "1.2.0" -[[deps.OpenBLAS32_jll]] -deps = ["Artifacts", "CompilerSupportLibraries_jll", "JLLWrappers", "Libdl", "Pkg"] -git-tree-sha1 = "9c6c2ed4b7acd2137b878eb96c68e63b76199d0f" -uuid = "656ef2d0-ae68-5445-9ca0-591084a874a2" -version = "0.3.17+0" - [[deps.OpenBLAS_jll]] deps = ["Artifacts", "CompilerSupportLibraries_jll", "Libdl"] uuid = "4536629a-c528-5b80-bd46-f80d51c5b363" @@ -574,10 +454,6 @@ version = "1.3.0" deps = ["Unicode"] uuid = "de0858da-6303-5e67-8744-51eddeeeb8d7" -[[deps.Profile]] -deps = ["Printf"] -uuid = "9abbd945-dff8-562f-b5e8-e1ebf5ef1b79" - [[deps.REPL]] deps = ["InteractiveUtils", "Markdown", "Sockets", "Unicode"] uuid = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb" @@ -592,18 +468,6 @@ git-tree-sha1 = "9f0a1b71baaf7650f4fa8a1d168c7fb6ee41f0c9" uuid = "c1ae055f-0cd5-4b69-90a6-9a35b1a98df9" version = "0.1.0" -[[deps.RecipesBase]] -deps = ["SnoopPrecompile"] -git-tree-sha1 = "261dddd3b862bd2c940cf6ca4d1c8fe593e457c8" -uuid = "3cdcf5f2-1ef4-517c-9805-6587b60abb01" -version = "1.3.3" - -[[deps.RecursiveArrayTools]] -deps = ["Adapt", "ArrayInterface", "DocStringExtensions", "GPUArraysCore", "IteratorInterfaceExtensions", "LinearAlgebra", "RecipesBase", "Requires", "StaticArraysCore", "Statistics", "SymbolicIndexingInterface", "Tables"] -git-tree-sha1 = "140cddd2c457e4ebb0cdc7c2fd14a7fbfbdf206e" -uuid = "731186ca-8d62-57ce-b412-fbd966d074cd" -version = "2.38.3" - [[deps.Reexport]] git-tree-sha1 = "45e428421666073eab6f2da5c9d310d99bb12f9b" uuid = "189a3867-3050-52da-a836-e630ba90ab69" @@ -615,64 +479,16 @@ git-tree-sha1 = "838a3a4188e2ded87a4f9f184b4b0d78a1e91cb7" uuid = "ae029012-a4dd-5104-9daa-d747884805df" version = "1.3.0" -[[deps.ReverseDiff]] -deps = ["ChainRulesCore", "DiffResults", "DiffRules", "ForwardDiff", "FunctionWrappers", "LinearAlgebra", "LogExpFunctions", "MacroTools", "NaNMath", "Random", "SpecialFunctions", "StaticArrays", "Statistics"] -git-tree-sha1 = "a8d90f5bf4880df810a13269eb5e3e29f22cbd96" -uuid = "37e2e3b7-166d-5795-8a7a-e32c996b4267" -version = "1.14.5" - [[deps.Richardson]] deps = ["LinearAlgebra"] git-tree-sha1 = "e03ca566bec93f8a3aeb059c8ef102f268a38949" uuid = "708f8203-808e-40c0-ba2d-98a6953ed40d" version = "1.4.0" -[[deps.RuntimeGeneratedFunctions]] -deps = ["ExprTools", "SHA", "Serialization"] -git-tree-sha1 = "f139e81a81e6c29c40f1971c9e5309b09c03f2c3" -uuid = "7e49a35a-f44a-4d26-94aa-eba1b4ca6b47" -version = "0.5.6" - -[[deps.SCS]] -deps = ["MathOptInterface", "Requires", "SCS_GPU_jll", "SCS_MKL_jll", 
"SCS_jll", "SparseArrays"] -git-tree-sha1 = "05c1ed62a8d78827d0dd1a9fa04040a4a254bf08" -uuid = "c946c3f1-0d1f-5ce8-9dea-7daa1f7e2d13" -version = "1.1.4" - -[[deps.SCS_GPU_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl", "OpenBLAS32_jll"] -git-tree-sha1 = "6a61274837cfa050bd996910d347e876bef3a6b3" -uuid = "af6e375f-46ec-5fa0-b791-491b0dfa44a4" -version = "3.2.3+1" - -[[deps.SCS_MKL_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl", "MKL_jll"] -git-tree-sha1 = "15b887b4b1f747f98b22fba2225fe7cd26861cea" -uuid = "3f2553a9-4106-52be-b7dd-865123654657" -version = "3.2.3+0" - -[[deps.SCS_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl", "OpenBLAS32_jll"] -git-tree-sha1 = "e4902566d6207206c27fe6f45e8c2d28c34889df" -uuid = "f4f2fc5b-1d94-523c-97ea-2ab488bedf4b" -version = "3.2.3+0" - [[deps.SHA]] uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce" version = "0.7.0" -[[deps.SciMLBase]] -deps = ["ArrayInterface", "CommonSolve", "ConstructionBase", "Distributed", "DocStringExtensions", "EnumX", "FunctionWrappersWrappers", "IteratorInterfaceExtensions", "LinearAlgebra", "Logging", "Markdown", "Preferences", "RecipesBase", "RecursiveArrayTools", "Reexport", "RuntimeGeneratedFunctions", "SciMLOperators", "SnoopPrecompile", "StaticArraysCore", "Statistics", "SymbolicIndexingInterface", "Tables", "TruncatedStacktraces"] -git-tree-sha1 = "392d3e28b05984496af37100ded94dc46fa6c8de" -uuid = "0bca4576-84f4-4d90-8ffe-ffa030f20462" -version = "1.91.7" - -[[deps.SciMLOperators]] -deps = ["ArrayInterface", "DocStringExtensions", "Lazy", "LinearAlgebra", "Setfield", "SparseArrays", "StaticArraysCore", "Tricks"] -git-tree-sha1 = "e61e48ef909375203092a6e83508c8416df55a83" -uuid = "c0aeaf25-5076-4817-a8d5-81caf7dfa961" -version = "0.2.0" - [[deps.Serialization]] uuid = "9e88b42a-f829-5b0c-bbe9-9e923198166b" @@ -707,18 +523,6 @@ git-tree-sha1 = "ef28127915f4229c971eb43f3fc075dd3fe91880" uuid = "276daf66-3868-5448-9aa4-cd146d93841b" version = "2.2.0" -[[deps.Static]] -deps = ["IfElse"] -git-tree-sha1 = "08be5ee09a7632c32695d954a602df96a877bf0d" -uuid = "aedffcd0-7271-4cad-89d0-dc628f76c6d3" -version = "0.8.6" - -[[deps.StaticArrayInterface]] -deps = ["ArrayInterface", "Compat", "IfElse", "LinearAlgebra", "Requires", "SnoopPrecompile", "SparseArrays", "Static", "SuiteSparse"] -git-tree-sha1 = "33040351d2403b84afce74dae2e22d3f5b18edcb" -uuid = "0d7ed370-da01-4f52-bd93-41d350b8b718" -version = "1.4.0" - [[deps.StaticArrays]] deps = ["LinearAlgebra", "Random", "StaticArraysCore", "Statistics"] git-tree-sha1 = "63e84b7fdf5021026d0f17f76af7c57772313d99" @@ -756,12 +560,6 @@ version = "0.6.15" deps = ["Libdl", "LinearAlgebra", "Serialization", "SparseArrays"] uuid = "4607b0f0-06f3-5cda-b6b1-a6196a1729e9" -[[deps.SymbolicIndexingInterface]] -deps = ["DocStringExtensions"] -git-tree-sha1 = "f8ab052bfcbdb9b48fad2c80c873aa0d0344dfe5" -uuid = "2efcf032-c050-4f8e-a9bb-153293bab1f5" -version = "0.2.2" - [[deps.TOML]] deps = ["Dates"] uuid = "fa267f1f-6049-4f14-aa54-33bafae1ed76" @@ -794,23 +592,6 @@ git-tree-sha1 = "f2fd3f288dfc6f507b0c3a2eb3bac009251e548b" uuid = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f" version = "0.5.22" -[[deps.TranscodingStreams]] -deps = ["Random", "Test"] -git-tree-sha1 = "0b829474fed270a4b0ab07117dce9b9a2fa7581a" -uuid = "3bb67fe8-82b1-5028-8e26-92a6c54297fa" -version = "0.9.12" - -[[deps.Tricks]] -git-tree-sha1 = "aadb748be58b492045b4f56166b5188aa63ce549" -uuid = "410a4b4d-49e4-4fbc-ab6d-cb71b17b3775" -version = "0.1.7" - -[[deps.TruncatedStacktraces]] -deps = ["InteractiveUtils", "MacroTools", "Preferences"] 
-git-tree-sha1 = "7bc1632a4eafbe9bd94cf1a784a9a4eb5e040a91" -uuid = "781d530d-4396-4725-bb49-402e4bee1e77" -version = "1.3.0" - [[deps.UUIDs]] deps = ["Random", "SHA"] uuid = "cf7118a7-6976-5b1a-9a39-7adc72f591a4" diff --git a/docs/Project.toml b/docs/Project.toml index aaaef96..9720157 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -1,20 +1,13 @@ [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a" -ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" -Convex = "f65535da-76fb-5f13-bab9-19810c17039a" -Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7" Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" -FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" -ForwardDiffChainRules = "c9556dd2-1aed-4cfe-8560-1557cf593001" ImplicitDifferentiation = "57b37032-215b-411a-8a7c-41a003a55207" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Literate = "98b081ad-f1c9-55d3-8b20-4c87d4299306" -MathOptInterface = "b8f27783-ece8-5eb3-8dc8-9495eed66fee" -MathOptSetDistances = "3b969827-a86c-476c-9527-bb6f1a8fbad5" +NLsolve = "2774e3e8-f4cf-5e23-947b-6d7e65073b56" Optim = "429524aa-4258-5aef-a3af-852621145aeb" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" -SCS = "c946c3f1-0d1f-5ce8-9dea-7daa1f7e2d13" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" diff --git a/docs/make.jl b/docs/make.jl index 716f406..70576c8 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -46,7 +46,12 @@ for file in sort(readdir(EXAMPLES_DIR_MD)) end end -pages = ["Home" => "index.md", "API reference" => "api.md", "Examples" => example_pages] +pages = [ + "Home" => "index.md", + "API reference" => "api.md", + "Examples" => example_pages, + "FAQ" => "faq.md", +] format = Documenter.HTML(; prettyurls=get(ENV, "CI", "false") == "true", @@ -63,7 +68,8 @@ makedocs(; format=format, pages=pages, linkcheck=true, - strict=true, ) -deploydocs(; repo="github.com/gdalle/ImplicitDifferentiation.jl", devbranch="main") +deploydocs(; + repo="github.com/gdalle/ImplicitDifferentiation.jl", devbranch="main", push_preview=true +) diff --git a/docs/src/api.md b/docs/src/api.md index b6c9dc4..51d9867 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -9,4 +9,4 @@ ```@autodocs Modules = [ImplicitDifferentiation] -``` \ No newline at end of file +``` diff --git a/docs/src/faq.md b/docs/src/faq.md new file mode 100644 index 0000000..baafbdc --- /dev/null +++ b/docs/src/faq.md @@ -0,0 +1,46 @@ +# Frequently Asked Questions + +## Higher-dimensional arrays + +For simplicity, the examples only display functions that work on vectors. +However, arbitrary array sizes are supported. +Beware however, sparse arrays will be densified in the differentiation process. + +## Multiple inputs / outputs + +In this package, implicit functions can only take a single input array `x` and output a single output array `y` (plus the additional info `z`). +But sometimes, your forward pass or conditions may require multiple input arrays, say `a` and `b`: + +```julia +function f(a, b) + # do stuff + return y, z +end +``` + +In that case, you should gather the inputs inside a single `ComponentVector` from [ComponentArrays.jl](https://github.com/jonniedie/ComponentArrays.jl) and define a new method: + +```julia +f(x::ComponentVector) = f(x.a, x.b) +``` + +The same trick works for multiple outputs. 
+
+## Constrained optimization modeling
+
+To express constrained optimization problems as implicit functions, you might need differentiable projections or proximal operators to write the optimality conditions.
+See [_Efficient and modular implicit differentiation_](https://arxiv.org/abs/2105.15183) for precise formulations.
+
+In case these operators are too complicated to implement yourself, here are a few places you can look:
+
+- [MathOptSetDistances.jl](https://github.com/matbesancon/MathOptSetDistances.jl)
+- [ProximalOperators.jl](https://github.com/JuliaFirstOrder/ProximalOperators.jl)
+
+An alternative is differentiating through the KKT conditions, which is exactly what [DiffOpt.jl](https://github.com/jump-dev/DiffOpt.jl) does for JuMP models.
+
+## Which autodiff backends are supported?
+
+- Forward mode: ForwardDiff.jl
+- Reverse mode: all the packages compatible with [ChainRules.jl](https://github.com/JuliaDiff/ChainRules.jl)
+
+In the future, we would like to add [Enzyme.jl](https://github.com/EnzymeAD/Enzyme.jl) support.
diff --git a/docs/src/index.md b/docs/src/index.md
index f1f279e..6fb62af 100644
--- a/docs/src/index.md
+++ b/docs/src/index.md
@@ -4,7 +4,17 @@ CurrentModule = ImplicitDifferentiation
 
 # ImplicitDifferentiation.jl
 
-[ImplicitDifferentiation.jl](https://github.com/gdalle/ImplicitDifferentiation.jl) is a package for automatic differentiation of functions defined implicitly.
+[ImplicitDifferentiation.jl](https://github.com/gdalle/ImplicitDifferentiation.jl) is a package for automatic differentiation of functions defined implicitly, i.e., mappings
+
+```math
+x \in \mathbb{R}^n \longmapsto y(x) \in \mathbb{R}^m
+```
+
+whose output is defined by conditions
+
+```math
+F(x,y(x)) = 0 \in \mathbb{R}^m
+```
 
 ## Background
 
@@ -29,11 +39,8 @@ For the latest version, run this instead:
 julia> using Pkg; Pkg.add(url="https://github.com/gdalle/ImplicitDifferentiation.jl")
 ```
 
-Check out the API reference to know more about the main object defined here, [`ImplicitFunction`](@ref).
-The tutorials give you some ideas of real-life applications for our package.
-
 ## Related projects
 
 - [DiffOpt.jl](https://github.com/jump-dev/DiffOpt.jl): differentiation of convex optimization problems
-- [InferOpt.jl](https://github.com/axelparmentier/InferOpt.jl): differentiation of combinatorial optimization problems
+- [InferOpt.jl](https://github.com/axelparmentier/InferOpt.jl): approximate differentiation of combinatorial optimization problems
 - [NonconvexUtils.jl](https://github.com/JuliaNonconvex/NonconvexUtils.jl): contains the original implementation from which this package drew inspiration
diff --git a/examples/0_basic.jl b/examples/0_basic.jl
new file mode 100644
index 0000000..7a6fd44
--- /dev/null
+++ b/examples/0_basic.jl
@@ -0,0 +1,175 @@
+# # Basic use
+
+#=
+In this example, we demonstrate the basics of our package on a simple function that is not amenable to automatic differentiation.
+=#
+
+using ForwardDiff
+using ImplicitDifferentiation
+using LinearAlgebra
+using Random
+using Test #src
+using Zygote
+
+Random.seed!(63);
+
+# ## Why do we bother?
+
+#=
+[ForwardDiff.jl](https://github.com/JuliaDiff/ForwardDiff.jl) and [Zygote.jl](https://github.com/FluxML/Zygote.jl) are two prominent packages for automatic differentiation in Julia.
+While they are very generic, there are simple language constructs that they cannot differentiate through.
+Two classic culprits, both at work in the function below, are array mutation (which Zygote.jl cannot handle) and containers with a fixed element type like `Float64` (which break ForwardDiff.jl's dual numbers).
+=#
+
+function mysqrt(x::AbstractArray)
+    a = [0.0]
+    a[1] = first(x)
+    return sqrt.(x)
+end
+
+#=
+This is essentially the componentwise square root function but with an additional twist: `a::Vector{Float64}` is created internally, and its only element is replaced with the first element of `x`.
+We can check that it does what it's supposed to do.
+=#
+
+x = rand(2)
+mysqrt(x) ≈ sqrt.(x)
+@test mysqrt(x) ≈ sqrt.(x) #src
+
+#=
+Of course the Jacobian has an explicit formula.
+=#
+
+J = Diagonal(0.5 ./ sqrt.(x))
+
+#=
+However, things start to go wrong when we compute it with autodiff, due to the [limitations of ForwardDiff.jl](https://juliadiff.org/ForwardDiff.jl/stable/user/limitations/) and [those of Zygote.jl](https://fluxml.ai/Zygote.jl/stable/limitations/).
+=#
+
+try
+    ForwardDiff.jacobian(mysqrt, x)
+catch e
+    e
+end
+@test_throws MethodError ForwardDiff.jacobian(mysqrt, x) #src
+
+#=
+ForwardDiff.jl throws an error because it tries to call `mysqrt` with an array of dual numbers, and cannot use one of these numbers to fill `a` (which has element type `Float64`).
+=#
+
+try
+    Zygote.jacobian(mysqrt, x)
+catch e
+    e
+end
+@test_throws ErrorException Zygote.jacobian(mysqrt, x) #src
+
+#=
+Zygote.jl also throws an error because it cannot handle mutation.
+=#
+
+# ## Implicit function
+
+#=
+The first possible use of ImplicitDifferentiation.jl is to overcome the limitations of automatic differentiation packages by defining functions (and computing their derivatives) implicitly.
+An implicit function is a mapping
+```math
+x \in \mathbb{R}^n \longmapsto y(x) \in \mathbb{R}^m
+```
+whose output is defined by conditions
+```math
+F(x,y(x)) = 0 \in \mathbb{R}^m
+```
+We represent it using a type called `ImplicitFunction`, which you will see in action shortly.
+=#
+
+#=
+First, we define a `forward` pass corresponding to the function we consider.
+It returns the actual output $y(x)$ of the function, as well as additional information $z(x)$.
+Here we don't need any additional information, so we set it to $0$.
+Importantly, this forward pass _doesn't need to be differentiable_.
+=#
+
+function forward(x)
+    y = mysqrt(x)
+    z = 0
+    return y, z
+end
+
+#=
+Then we define `conditions` $F(x, y, z) = 0$ that the output $y(x)$ is supposed to satisfy.
+These conditions must be array-valued, with the same size as $y$, and take $z$ as an additional argument.
+And unlike the forward pass, _the conditions need to be differentiable_ with respect to $x$ and $y$.
+Here they are very obvious: the square of the square root should be equal to the original value.
+=#
+
+function conditions(x, y, z)
+    c = y .^ 2 .- x
+    return c
+end
+
+#=
+Finally, we construct a wrapper `implicit` around the previous objects.
+What does this wrapper do?
+=#
+
+implicit = ImplicitFunction(forward, conditions)
+
+#=
+When we call it as a function, it just falls back on `implicit.forward`, so unsurprisingly we get the same tuple $(y(x), z(x))$.
+=#
+
+(first ∘ implicit)(x) ≈ sqrt.(x)
+@test (first ∘ implicit)(x) ≈ sqrt.(x) #src
+
+#=
+And when we try to compute its Jacobian, the [implicit function theorem](https://en.wikipedia.org/wiki/Implicit_function_theorem) is applied in the background to circumvent the lack of differentiability of the forward pass.
+=#
+
+# ## Forward and reverse mode autodiff
+
+#=
+Now ForwardDiff.jl works seamlessly.
+=#
+
+ForwardDiff.jacobian(first ∘ implicit, x) ≈ J
+@test ForwardDiff.jacobian(first ∘ implicit, x) ≈ J #src
+
+#=
+And so does Zygote.jl. Hurray!
+=# + +Zygote.jacobian(first ∘ implicit, x)[1] ≈ J +@test Zygote.jacobian(first ∘ implicit, x)[1] ≈ J #src + +# ## Second derivative + +#= +We can even go higher-order by mixing the two packages (forward-over-reverse mode). +The only technical requirement is to switch the linear solver to something that can handle dual numbers: +=# + +linear_solver(A, b) = (Matrix(A) \ b, (solved=true,)) +implicit2 = ImplicitFunction(forward, conditions, linear_solver) + +#= +Then the Jacobian itself is differentiable. +=# + +h = rand(2) +J_Z(t) = Zygote.jacobian(first ∘ implicit2, x .+ t .* h)[1] +ForwardDiff.derivative(J_Z, 0) ≈ Diagonal((-0.25 .* h) ./ (x .^ 1.5)) +@test ForwardDiff.derivative(J_Z, 0) ≈ Diagonal((-0.25 .* h) ./ (x .^ 1.5)) #src + +# The following tests are not included in the docs #src + +X = rand(2, 3, 4) #src +JJ = Diagonal(0.5 ./ sqrt.(vec(X))) #src +@test (first ∘ implicit)(X) ≈ sqrt.(X) #src +@test ForwardDiff.jacobian(first ∘ implicit, X) ≈ JJ #src +@test Zygote.jacobian(first ∘ implicit, X)[1] ≈ JJ #src + +# Skipped because of https://github.com/JuliaDiff/ChainRulesTestUtils.jl/issues/232 #src +@testset verbose = true "ChainRulesTestUtils.jl" begin #src + @test_skip test_rrule(implicit, x) #src + @test_skip test_rrule(implicit, X) #src +end #src diff --git a/examples/1_unconstrained_optim.jl b/examples/1_unconstrained_optim.jl new file mode 100644 index 0000000..e95342f --- /dev/null +++ b/examples/1_unconstrained_optim.jl @@ -0,0 +1,108 @@ +# # Unconstrained optimization + +#= +In this example, we show how to differentiate through the solution of an unconstrained optimization problem: +```math +y(x) = \underset{y \in \mathbb{R}^m}{\mathrm{argmin}} ~ f(x, y) +``` +The optimality conditions are given by gradient stationarity: +```math +\nabla_2 f(x, y) = 0 +``` +=# + +using ForwardDiff +using ImplicitDifferentiation +using LinearAlgebra +using Optim +using Random +using Test #src +using Zygote + +Random.seed!(63); + +# ## Implicit function + +#= +To make verification easy, we minimize the following objective: +```math +f(x, y) = \lVert y \odot y - x \rVert^2 +``` +In this case, the optimization problem boils down to the componentwise square root function, but we implement it using a black box solver from [Optim.jl](https://github.com/JuliaNLSolvers/Optim.jl). +Note the presence of a keyword argument. +=# + +function mysqrt_optim(x; method) + f(y) = sum(abs2, y .^ 2 .- x) + y0 = ones(eltype(x), size(x)) + result = optimize(f, y0, method) + return Optim.minimizer(result) +end + +#= +First, we create the forward pass which returns the solution $y(x)$. +Remember that it should also return additional information $z(x)$, which is useless here. +=# +function forward_optim(x; method) + y = mysqrt_optim(x; method) + z = 0 + return y, z +end + +#= +Even though they are defined as a gradient, it is better to provide optimality conditions explicitly: that way we avoid nesting autodiff calls. +Remember, the conditions should accept three arguments to take additional information into account when needed. +Moreover, the forward pass and the conditions should accept the same set of keyword arguments. +=# + +function conditions_optim(x, y, z; method) + ∇₂f = 2 .* (y .^ 2 .- x) + return ∇₂f +end + +# We now have all the ingredients to construct our implicit function. 
+implicit_optim = ImplicitFunction(forward_optim, conditions_optim)
+
+# And indeed, it behaves as it should when we call it:
+
+x = rand(2)
+
+#-
+
+first(implicit_optim(x; method=LBFGS())) .^ 2
+@test first(implicit_optim(x; method=LBFGS())) .^ 2 ≈ x #src
+
+#=
+Let's see what the explicit Jacobian looks like.
+=#
+
+J = Diagonal(0.5 ./ sqrt.(x))
+
+# ## Forward mode autodiff
+
+ForwardDiff.jacobian(_x -> first(implicit_optim(_x; method=LBFGS())), x)
+@test ForwardDiff.jacobian(_x -> first(implicit_optim(_x; method=LBFGS())), x) ≈ J #src
+
+#=
+Unsurprisingly, the Jacobian matches the explicit formula `J`.
+In this instance, we could use ForwardDiff.jl directly on the solver, but it returns the wrong result (not sure why).
+=#
+
+ForwardDiff.jacobian(_x -> mysqrt_optim(_x; method=LBFGS()), x)
+
+# ## Reverse mode autodiff
+
+Zygote.jacobian(_x -> first(implicit_optim(_x; method=LBFGS())), x)[1]
+@test Zygote.jacobian(_x -> first(implicit_optim(_x; method=LBFGS())), x)[1] ≈ J #src
+
+#=
+Again, the Jacobian matches the explicit formula `J`.
+In this instance, we cannot use Zygote.jl directly on the solver (due to unsupported `try/catch` statements).
+=#
+
+try
+    Zygote.jacobian(_x -> mysqrt_optim(_x; method=LBFGS()), x)[1]
+catch e
+    e
+end
diff --git a/examples/1_unconstrained_optimization.jl b/examples/1_unconstrained_optimization.jl
deleted file mode 100644
index 9a9ceda..0000000
--- a/examples/1_unconstrained_optimization.jl
+++ /dev/null
@@ -1,147 +0,0 @@
-# # Unconstrained optimization
-
-#=
-In this example, we show how to differentiate through the solution of the following unconstrained optimization problem:
-```math
-\hat{y}(x) = \underset{y \in \mathbb{R}^m}{\mathrm{argmin}} ~ f(x, y)
-```
-The optimality conditions are given by gradient stationarity:
-```math
-F(x, \hat{y}(x)) = 0 \quad \text{with} \quad F(x,y) = \nabla_2 f(x, y) = 0
-```
-=#
-
-using ChainRulesTestUtils #src
-using ForwardDiff
-using ForwardDiffChainRules
-using ImplicitDifferentiation
-using LinearAlgebra #src
-using Optim
-using Random
-using Test #src
-using Zygote
-
-Random.seed!(63);
-
-# ## Implicit function wrapper
-
-#=
-To make verification easy, we minimize a quadratic objective
-```math
-f(x, y) = \lVert y - x \rVert^2
-```
-In this case, the optimization algorithm is very simple (the identity function does the job), but still we implement it using a black box solver from [Optim.jl](https://github.com/JuliaNLSolvers/Optim.jl) to show that it doesn't change the result.
-=#
-
-function dumb_identity(x::AbstractArray{Float64})
-    f(y) = sum(abs2, y - x)
-    y0 = zero(x)
-    res = optimize(f, y0, LBFGS())
-    y = Optim.minimizer(res)
-    return y
-end;
-
-#=
-On the other hand, optimality conditions should be provided explicitly whenever possible, so as to avoid nesting autodiff calls.
-=#
-
-zero_gradient(x, y) = 2(y - x);
-
-# We now have all the ingredients to construct our implicit function.
-
-implicit = ImplicitFunction(dumb_identity, zero_gradient);
-
-# Time to test!
-
-x = rand(3, 2)
-
-# Let's start by taking a look at the forward pass, which should be the identity function.
-
-implicit(x)
-
-# ## Why bother?
-
-# It is important to understand why implicit differentiation is necessary here. Indeed, our optimization solver alone doesn't support autodiff with ForwardDiff.jl (due to type constraints)
-
-try
-    ForwardDiff.jacobian(dumb_identity, x)
-catch e
-    e
-end
-
-# ... nor is it compatible with Zygote.jl (due to unsupported `try/catch` statements).
- -try - Zygote.jacobian(dumb_identity, x)[1] -catch e - e -end - -# ## Autodiff with Zygote.jl - -# If we use an autodiff package compatible with [ChainRules.jl](https://github.com/JuliaDiff/ChainRules.jl), such as [Zygote.jl](https://github.com/FluxML/Zygote.jl), implicit differentiation works out of the box. - -Zygote.jacobian(implicit, x)[1] - -# As expected, we recover the identity matrix as Jacobian. Strictly speaking, the Jacobian should be a 4D tensor, but it is flattened into a 2D matrix. - -# ## Autodiff with ForwardDiff.jl - -# If we want to use [ForwardDiff.jl](https://github.com/JuliaDiff/ForwardDiff.jl) instead, we run into a problem: custom chain rules are not directly translated into dual number dispatch. Luckily, [ForwardDiffChainRules.jl](https://github.com/ThummeTo/ForwardDiffChainRules.jl) provides us with a workaround. All we need to do is to apply the following macro: - -@ForwardDiff_frule (f::typeof(implicit))(x::AbstractArray{<:ForwardDiff.Dual}; kwargs...) - -# And then things work like a charm! - -ForwardDiff.jacobian(implicit, x) - -# ## Higher order differentiation - -h = rand(size(x)); - -# Assuming we need second-order derivatives, nesting calls to Zygote.jl is generally a bad idea. We can, however, nest calls to ForwardDiff.jl. - -D(x, h) = ForwardDiff.derivative(t -> implicit(x .+ t .* h), 0) -DD(x, h1, h2) = ForwardDiff.derivative(t -> D(x .+ t .* h2, h1), 0); - -#- - -try - DD(x, h, h) # fails -catch e - e -end - -# The only requirement is to switch to a linear solver that is compatible with dual numbers (which the default `gmres` from [Krylov.jl](https://github.com/JuliaSmoothOptimizers/Krylov.jl) is not). - -linear_solver2(A, b) = (Matrix(A) \ b, (solved=true,)) -implicit2 = ImplicitFunction(dumb_identity, zero_gradient, linear_solver2); -@ForwardDiff_frule (f::typeof(implicit2))(x::AbstractArray{<:ForwardDiff.Dual}; kwargs...) - -D2(x, h) = ForwardDiff.derivative(t -> implicit2(x .+ t .* h), 0) -DD2(x, h1, h2) = ForwardDiff.derivative(t -> D2(x .+ t .* h2, h1), 0); - -#- - -DD2(x, h, h) - -# The following tests are not included in the docs. 
#src
-
-@testset verbose = true "ForwardDiff.jl" begin #src
-    @test_throws MethodError ForwardDiff.jacobian(dumb_identity, x) #src
-    @test ForwardDiff.jacobian(implicit, x) == I #src
-    @test all(DD2(x, h, h) .≈ 0) #src
-end #src
-
-@testset verbose = true "Zygote.jl" begin #src
-    @test_throws Zygote.CompileError Zygote.jacobian(dumb_identity, x)[1] #src
-    @test Zygote.jacobian(implicit, x)[1] == I #src
-end #src
-
-@testset verbose = false "ChainRulesTestUtils.jl (forward)" begin #src
-    test_frule(implicit, x; check_inferred=true, rtol=1e-3) #src
-end #src
-
-@testset verbose = false "ChainRulesTestUtils.jl (reverse)" begin #src
-    test_rrule(implicit, x; check_inferred=true, rtol=1e-3) #src
-end #src
diff --git a/examples/2_nonlinear_solve.jl b/examples/2_nonlinear_solve.jl
new file mode 100644
index 0000000..c564bf4
--- /dev/null
+++ b/examples/2_nonlinear_solve.jl
@@ -0,0 +1,93 @@
+# # Nonlinear solve
+
+#=
+In this example, we show how to differentiate through the solution of a nonlinear system of equations:
+```math
+\text{find} \quad y(x) \quad \text{such that} \quad F(x, y(x)) = 0
+```
+The optimality conditions are pretty obvious:
+```math
+F(x, y) = 0
+```
+=#
+
+using ForwardDiff
+using ImplicitDifferentiation
+using LinearAlgebra
+using NLsolve
+using Random
+using Test #src
+using Zygote
+
+Random.seed!(63);
+
+# ## Implicit function
+
+#=
+To make verification easy, we solve the following system:
+```math
+F(x, y) = y \odot y - x = 0
+```
+In this case, the problem boils down to the componentwise square root function, but we implement it using a black box solver from [NLsolve.jl](https://github.com/JuliaNLSolvers/NLsolve.jl).
+=#
+
+function mysqrt_nlsolve(x; method)
+    F!(storage, y) = (storage .= y .^ 2 - x)
+    initial_y = ones(eltype(x), size(x))
+    result = nlsolve(F!, initial_y; method)
+    return result.zero
+end
+
+#-
+
+function forward_nlsolve(x; method)
+    y = mysqrt_nlsolve(x; method)
+    z = 0
+    return y, z
+end
+
+#-
+
+function conditions_nlsolve(x, y, z; method)
+    F = y .^ 2 .- x
+    return F
+end
+
+#-
+
+implicit_nlsolve = ImplicitFunction(forward_nlsolve, conditions_nlsolve)
+
+#-
+
+x = rand(2)
+
+#-
+
+first(implicit_nlsolve(x; method=:newton)) .^ 2
+@test first(implicit_nlsolve(x; method=:newton)) .^ 2 ≈ x #src
+
+#-
+
+J = Diagonal(0.5 ./ sqrt.(x))
+
+# ## Forward mode autodiff
+
+ForwardDiff.jacobian(_x -> first(implicit_nlsolve(_x; method=:newton)), x)
+@test ForwardDiff.jacobian(_x -> first(implicit_nlsolve(_x; method=:newton)), x) ≈ J #src
+
+#-
+
+ForwardDiff.jacobian(_x -> mysqrt_nlsolve(_x; method=:newton), x)
+
+# ## Reverse mode autodiff
+
+Zygote.jacobian(_x -> first(implicit_nlsolve(_x; method=:newton)), x)[1]
+@test Zygote.jacobian(_x -> first(implicit_nlsolve(_x; method=:newton)), x)[1] ≈ J #src
+
+#-
+
+try
+    Zygote.jacobian(_x -> mysqrt_nlsolve(_x; method=:newton), x)[1]
+catch e
+    e
+end
diff --git a/examples/2_sparse_linear_regression.jl b/examples/2_sparse_linear_regression.jl
deleted file mode 100644
index ff62afb..0000000
--- a/examples/2_sparse_linear_regression.jl
+++ /dev/null
@@ -1,117 +0,0 @@
-# # Sparse linear regression
-
-#=
-In this example, we show how to differentiate through the solution of the following constrained optimization problem:
-```math
-\hat{y}(x) = \underset{y \in \mathcal{C}}{\mathrm{argmin}} ~ f(x, y)
-```
-where ``\mathcal{C}`` is a closed convex set.
-The optimal solution can be found as the fixed point of the projected gradient algorithm for any step size ``\eta``.
This insight yields the following optimality conditions: -```math -F(x, \hat{y}(x)) = 0 \quad \text{with} \quad F(x,y) = \mathrm{proj}_{\mathcal{C}}(y - \eta \nabla_2 f(x, y)) - y -``` -=# - -using ComponentArrays -using Convex -using FiniteDifferences -using ImplicitDifferentiation -using MathOptInterface -using MathOptSetDistances -using Random -using SCS -using Test #src -using Zygote - -Random.seed!(63); - -# ## Introduction - -#= -We have a matrix of features $X \in \mathbb{R}^{n \times p}$ and a vector of targets $y \in \mathbb{R}^n$. - -In a linear regression setting $y \approx X \beta$, one way to ensure sparsity of the parameter $\beta \in \mathbb{R}^p$ is to select it within the $\ell_1$ ball $\mathcal{B}_1$: -```math -\hat{\beta}(X, y) = \underset{\beta}{\mathrm{argmin}} ~ \lVert y - X \beta \rVert_2^2 \quad \text{s.t.} \quad \lVert \beta \rVert_1 \leq 1 \tag{QP} -``` -We want to compute the derivatives of the optimal parameter wrt to the data: $\partial \hat{\beta} / \partial X$ and $\partial \hat{\beta} / \partial y$. - -Possible application: sensitivity analysis of $\hat{\beta}(X, y)$. -=# - -# ## Forward solver - -# The function $\hat{\beta}$ is computed with a disciplined convex solver thanks to `Convex.jl`. - -function lasso(X::AbstractMatrix, y::AbstractVector) - n, p = size(X) - β = Variable(p) - objective = sumsquares(X * β - y) - constraints = [norm(β, 1) <= 1.0] - problem = minimize(objective, constraints) - solve!(problem, SCS.Optimizer; silent_solver=true) - return Convex.evaluate(β) -end; - -# To comply with the requirements of ImplicitDifferentiation.jl, we need to provide the input arguments within a single array. We exploit [ComponentArrays.jl](https://github.com/jonniedie/ComponentArrays.jl) for that purpose. - -lasso(data::ComponentVector) = lasso(data.X, data.y); - -# ## Optimality conditions - -# We use [MathOptSetDistances.jl](https://github.com/matbesancon/MathOptSetDistances.jl) to compute the projection onto the unit $\ell_1$ ball. - -function proj_l1_ball(v::AbstractVector{R}) where {R<:Real} - distance = MathOptSetDistances.DefaultDistance() - cone = MathOptInterface.NormOneCone(length(v)) - ball = MathOptSetDistances.NormOneBall{R}(one(R), cone) - return projection_on_set(distance, v, ball) -end; - -# Since this projection uses mutation internally, it is not compatible with Zygote.jl. Thus, we need to specify that it should be differentiated with ForwardDiff.jl. - -function proj_grad_fixed_point(data, β) - grad = 2 * data.X' * (data.X * β - data.y) - return β - Zygote.forwarddiff(proj_l1_ball, β - grad) -end; - -# This is the last ingredient we needed to build a differentiable sparse linear regression. - -implicit = ImplicitFunction(lasso, proj_grad_fixed_point); - -# ## Testing - -n, p = 5, 7; -X, y = rand(n, p), rand(n); -data = ComponentVector(; X=X, y=y); - -# As expected, the forward pass returns a sparse solution - -round.(implicit(data); digits=4) - -# Note that implicit differentiation is necessary here because the convex solver breaks autodiff. - -try - Zygote.jacobian(lasso, data) -catch e - e -end - -# Meanwhile, our implicit wrapper makes autodiff work seamlessly. - -J = Zygote.jacobian(implicit, data)[1] - -# The number of columns of the Jacobian is explained by the following formula: - -prod(size(X)) + prod(size(y)) - -# We can validate the result using finite differences. - -J_ref = FiniteDifferences.jacobian(central_fdm(5, 1), lasso, data)[1] -sum(abs, J - J_ref) / prod(size(J)) - -# The following tests are not included in the docs. 
#src - -@testset verbose = true "FiniteDifferences.jl" begin #src - @test sum(abs, J - J_ref) / prod(size(J)) <= 1e-2 #src -end #src diff --git a/examples/3_fixed_points.jl b/examples/3_fixed_points.jl new file mode 100644 index 0000000..c6bde74 --- /dev/null +++ b/examples/3_fixed_points.jl @@ -0,0 +1,93 @@ +# # Fixed point + +#= +In this example, we show how to differentiate through the limit of a fixed point iteration: +```math +y \longmapsto T(x, y) +``` +The optimality conditions are pretty obvious: +```math +y = T(x, y) +``` +=# + +using ForwardDiff +using ImplicitDifferentiation +using LinearAlgebra +using Random +using Test #src +using Zygote + +Random.seed!(63); + +# ## Implicit function + +#= +To make verification easy, we consider [Heron's method](https://en.wikipedia.org/wiki/Methods_of_computing_square_roots#Heron's_method): +```math +T(x, y) = \frac{1}{2} \left(y + \frac{x}{y}\right) +``` +In this case, the fixed point algorithm boils down to the componentwise square root function, but we implement it manually. +=# + +function mysqrt_fixedpoint(x; iterations) + y = ones(eltype(x), size(x)) + for _ in 1:iterations + y .= 0.5 .* (y .+ x ./ y) + end + return y +end + +#- + +function forward_fixedpoint(x; iterations) + y = mysqrt_fixedpoint(x; iterations) + z = 0 + return y, z +end + +#- + +function conditions_fixedpoint(x, y, z; iterations) + T = 0.5 .* (y .+ x ./ y) + return T .- y +end + +#- + +implicit_fixedpoint = ImplicitFunction(forward_fixedpoint, conditions_fixedpoint) + +#- + +x = rand(2) + +#- + +first(implicit_fixedpoint(x; iterations=10)) .^ 2 +@test first(implicit_fixedpoint(x; iterations=10)) .^ 2 ≈ x #src + +#- + +J = Diagonal(0.5 ./ sqrt.(x)) + +# ## Forward mode autodiff + +ForwardDiff.jacobian(_x -> first(implicit_fixedpoint(_x; iterations=10)), x) +@test ForwardDiff.jacobian(_x -> first(implicit_fixedpoint(_x; iterations=10)), x) ≈ J #src + +#- + +ForwardDiff.jacobian(_x -> mysqrt_fixedpoint(_x; iterations=10), x) + +# ## Reverse mode autodiff + +Zygote.jacobian(_x -> first(implicit_fixedpoint(_x; iterations=10)), x)[1] +@test Zygote.jacobian(_x -> first(implicit_fixedpoint(_x; iterations=10)), x)[1] ≈ J #src + +#- + +try + Zygote.jacobian(_x -> mysqrt_fixedpoint(_x; iterations=10), x)[1] +catch e + e +end diff --git a/examples/3_optimal_transport.jl b/examples/3_optimal_transport.jl deleted file mode 100644 index b3abe59..0000000 --- a/examples/3_optimal_transport.jl +++ /dev/null @@ -1,170 +0,0 @@ -# # Optimal transport - -#= -In this example, we show how to differentiate through the solution of the entropy-regularized optimal transport problem. -=# - -using Distances -using FiniteDifferences -using ImplicitDifferentiation -using LinearAlgebra -using Random -using Test #src -using Zygote - -Random.seed!(63); - -#= -## Introduction - -Here we give a brief introduction to optimal transport, see the [book by Gabriel Peyré and Marco Cuturi](https://optimaltransport.github.io/book/) for more details. - -### Problem description - -Suppose we have a distribution of mass ``a \in \Delta^{n}`` over points ``x_1, ..., x_{n} \in \mathbb{R}^d`` (where ``\Delta`` denotes the probability simplex). -We want to transport it to a distribution ``b \in \Delta^{m}`` over points ``y_1, ..., y_{m} \in \mathbb{R}^d``. -The unit moving cost from ``x`` to ``y`` is proportional to the squared Euclidean distance ``c(x, y) = \lVert x - y \rVert_2^2``. - -A transportation plan can be described by a coupling ``p = \Pi(a, b)``, i.e. 
a probability distribution on the product space with the right marginals: -```math -\Pi(a, b) = \{p \in \Delta^{n \times m}: p \mathbf{1} = a, p^\top \mathbf{1} = b\} -``` -Let ``C \in \mathbb{R}^{n \times m}`` be the moving cost matrix, with ``C_{ij} = c(x_i, y_j)``. -The basic optimization problem we want to solve is a linear program: -```math -\hat{p}(C) = \underset{p \in \Pi(a, b)}{\mathrm{argmin}} ~ \sum_{i,j} p_{ij} C_{ij} -``` -In order to make it smoother, we add an entropic regularization term: - ```math -\hat{p}_{\varepsilon}(C) = \underset{p \in \Pi(a, b)}{\mathrm{argmin}} ~ \sum_{i,j} \left(p_{ij} C_{ij} + \varepsilon p_{ij} \log \frac{p_{ij}}{a_i b_j} \right) -``` - -### Sinkhorn algorithm - -To solve the regularized problem, we can use the Sinkhorn fixed point algorithm. -Let ``K \in \mathbb{R}^{n \times m}`` be the matrix defined by ``K_{ij} = \exp(-C_{ij} / \varepsilon)``. -Then the optimal coupling ``\hat{p}_{\varepsilon}(C)`` can be written as: -```math -\hat{p}_{\varepsilon}(C) = \mathrm{diag}(\hat{u}) ~ K ~ \mathrm{diag}(\hat{v}) \tag{1} -``` -where ``\hat{u}`` and ``\hat{v}`` are the fixed points of the following Sinkhorn iteration: -```math -u^{t+1} = \frac{a}{Kv^t} \qquad \text{and} \qquad v^{t+1} = \frac{b}{K^\top u^t} \tag{S} -``` - -The implicit function theorem can be used to differentiate ``\hat{u}`` and ``\hat{v}`` with respect to ``C``, ``a`` and/or ``b``. -This can be combined with automatic differentiation of Equation (1) to find the Jacobian -```math -J = \frac{\partial ~ \mathrm{vec}(\hat{p}_{\varepsilon}(C))}{\partial ~ \mathrm{vec}(C)} -``` -=# - -d = 10 -n = 3 -m = 4 - -X = rand(d, n) -Y = rand(d, m) - -a = fill(1 / n, n) -b = fill(1 / m, m) -C = pairwise(SqEuclidean(), X, Y; dims=2) - -ε = 1.0; -T = 100; - -# ## Forward solver - -# For technical reasons related to optimality checking, our Sinkhorn solver returns ``\hat{u}`` instead of ``\hat{p}_\varepsilon``. - -function sinkhorn(C; a, b, ε, T) - K = exp.(.-C ./ ε) - u = copy(a) - v = copy(b) - for t in 1:T - u = a ./ (K * v) - v = b ./ (K' * u) - end - return u -end - -function sinkhorn_efficient(C; a, b, ε, T) - K = exp.(.-C ./ ε) - u = copy(a) - v = copy(b) - for t in 1:T - mul!(u, K, v) - u .= a ./ u - mul!(v, K', u) - v .= b ./ v - end - return u -end - -# ## Optimality conditions - -# We simply used the fixed point equation $(\text{S})$. - -function sinkhorn_fixed_point(C, u; a, b, ε, T=nothing) - K = exp.(.-C ./ ε) - v = b ./ (K' * u) - return u .- a ./ (K * v) -end - -# We have all we need to build a differentiable Sinkhorn that doesn't require unrolling the fixed point iterations. - -implicit = ImplicitFunction(sinkhorn_efficient, sinkhorn_fixed_point); - -# ## Testing - -u1 = sinkhorn(C; a=a, b=b, ε=ε, T=T) -u2 = implicit(C; a=a, b=b, ε=ε, T=T) -u1 == u2 - -# First, let us check that the forward pass works correctly and returns a fixed point. - -all(iszero, sinkhorn_fixed_point(C, u1; a=a, b=b, ε=ε, T=T)) - -# Using the implicit function defined above, we can build an autodiff-compatible Sinkhorn which does not require backpropagating through the fixed point iterations: - -function transportation_plan_slow(C; a, b, ε, T) - K = exp.(.-C ./ ε) - u = sinkhorn(C; a=a, b=b, ε=ε, T=T) - v = b ./ (K' * u) - p = u .* K .* v' - return p -end; - -function transportation_plan_fast(C; a, b, ε, T) - K = exp.(.-C ./ ε) - u = implicit(C; a=a, b=b, ε=ε, T=T) - v = b ./ (K' * u) - p = u .* K .* v' - return p -end; - -# What does the transportation plan look like? 
-
-p1 = transportation_plan_slow(C; a=a, b=b, ε=ε, T=T)
-p2 = transportation_plan_fast(C; a=a, b=b, ε=ε, T=T)
-p1 == p2
-
-# Let us compare its Jacobian with the one obtained using finite differences.
-
-J1 = Zygote.jacobian(C -> transportation_plan_slow(C; a=a, b=b, ε=ε, T=T), C)[1]
-J2 = Zygote.jacobian(C -> transportation_plan_fast(C; a=a, b=b, ε=ε, T=T), C)[1]
-J_ref = FiniteDifferences.jacobian(
-    central_fdm(5, 1), C -> transportation_plan_slow(C; a=a, b=b, ε=ε, T=T), C
-)[1]
-
-sum(abs, J2 - J_ref) / prod(size(J_ref))
-
-# The following tests are not included in the docs. #src
-
-@testset verbose = true "FiniteDifferences.jl" begin #src
-    @test u1 == u2 #src
-    @test all(iszero, sinkhorn_fixed_point(C, u1; a=a, b=b, ε=ε, T=T)) #src
-    @test p1 == p2 #src
-    @test sum(abs, J1 - J_ref) / prod(size(J_ref)) < 1e-5 #src
-    @test sum(abs, J2 - J_ref) / prod(size(J_ref)) < 1e-5 #src
-end #src
diff --git a/examples/4_constrained_optim.jl b/examples/4_constrained_optim.jl
new file mode 100644
index 0000000..71a369f
--- /dev/null
+++ b/examples/4_constrained_optim.jl
@@ -0,0 +1,105 @@
+# # Constrained optimization
+
+#=
+In this example, we show how to differentiate through the solution of a constrained optimization problem:
+```math
+y(x) = \underset{y \in \mathbb{R}^m}{\mathrm{argmin}} ~ f(x, y) \quad \text{subject to} \quad g(x, y) \leq 0
+```
+The optimality conditions are a bit trickier than in the previous cases.
+We can project onto the feasible set $\mathcal{C}(x) = \{y: g(x, y) \leq 0 \}$ and exploit the fact that the solution is a fixed point of projected gradient descent with step size $\eta$:
+```math
+y = \mathrm{proj}_{\mathcal{C}(x)} (y - \eta \nabla_2 f(x, y))
+```
+=#
+
+using ForwardDiff
+using ImplicitDifferentiation
+using LinearAlgebra
+using Optim
+using Random
+using Test #src
+using Zygote
+
+Random.seed!(63);
+
+# ## Implicit function
+
+#=
+To make verification easy, we minimize the following objective:
+```math
+f(x, y) = \lVert y \odot y - x \rVert^2
+```
+on the hypercube $\mathcal{C}(x) = [0, 1]^n$.
+In this case, the optimization problem boils down to a thresholded componentwise square root function, but we implement it using a black box solver from [Optim.jl](https://github.com/JuliaNLSolvers/Optim.jl).
+=#
+
+function mysqrt_cstr_optim(x)
+    f(y) = sum(abs2, y .^ 2 - x)
+    lower = zeros(size(x))
+    upper = ones(size(x))
+    y0 = ones(eltype(x), size(x)) ./ 2
+    res = optimize(f, lower, upper, y0, Fminbox(GradientDescent()))
+    y = Optim.minimizer(res)
+    return y
+end
+
+#-
+
+function forward_cstr_optim(x)
+    y = mysqrt_cstr_optim(x)
+    z = 0
+    return y, z
+end
+
+#-
+
+function proj_hypercube(p)
+    return max.(0, min.(1, p))
+end
+
+function conditions_cstr_optim(x, y, z)
+    ∇₂f = 2 .* (y .^ 2 .- x)
+    η = 0.1
+    return y .- proj_hypercube(y .- η .* ∇₂f)
+end
+
+# We now have all the ingredients to construct our implicit function.
+
+implicit_cstr_optim = ImplicitFunction(forward_cstr_optim, conditions_cstr_optim)
+
+# And indeed, it behaves as it should when we call it:
+
+x = rand(2) .+ [0, 1]
+
+#=
+The second component of $x$ is $> 1$, so its square root will be thresholded to one, and the corresponding derivative will be $0$.
+=#
+
+(first ∘ implicit_cstr_optim)(x) .^ 2
+@test (first ∘ implicit_cstr_optim)(x) .^ 2 ≈ [x[1], 1] #src
+
+#-
+
+J_thres = Diagonal([0.5 / sqrt(x[1]), 0])
+
+# ## Forward mode autodiff
+
+ForwardDiff.jacobian(first ∘ implicit_cstr_optim, x)
+@test ForwardDiff.jacobian(first ∘ implicit_cstr_optim, x) ≈ J_thres #src
+
+#-
+
+ForwardDiff.jacobian(mysqrt_cstr_optim, x)
+
+# ## Reverse mode autodiff
+
+Zygote.jacobian(first ∘ implicit_cstr_optim, x)[1]
+@test Zygote.jacobian(first ∘ implicit_cstr_optim, x)[1] ≈ J_thres #src
+
+#-
+
+try
+    Zygote.jacobian(mysqrt_cstr_optim, x)[1]
+catch e
+    e
+end
diff --git a/ext/ImplicitDifferentiationChainRulesExt.jl b/ext/ImplicitDifferentiationChainRulesExt.jl
new file mode 100644
index 0000000..91fe7b5
--- /dev/null
+++ b/ext/ImplicitDifferentiationChainRulesExt.jl
@@ -0,0 +1,43 @@
+module ImplicitDifferentiationChainRulesExt
+
+using AbstractDifferentiation: ReverseRuleConfigBackend, pullback_function
+using ChainRulesCore: ChainRulesCore, NoTangent, RuleConfig, ZeroTangent, unthunk
+using ImplicitDifferentiation: ImplicitFunction, PullbackMul!, check_solution
+using LinearOperators: LinearOperator
+
+"""
+    rrule(rc, implicit, x[; kwargs...])
+
+Custom reverse rule for [`ImplicitFunction{F,C,L}`](@ref).
+
+We compute the vector-Jacobian product `Jᵀv` by solving `Aᵀu = v` and setting `Jᵀv = -Bᵀu`.
+Keyword arguments are given to both `implicit.forward` and `implicit.conditions`.
+"""
+function ChainRulesCore.rrule(
+    rc::RuleConfig, implicit::ImplicitFunction, x::AbstractArray{R}; kwargs...
+) where {R<:Real}
+    conditions = implicit.conditions
+    linear_solver = implicit.linear_solver
+
+    y, z = implicit(x; kwargs...)
+    n, m = length(x), length(y)
+
+    backend = ReverseRuleConfigBackend(rc)
+    pbA = pullback_function(backend, _y -> conditions(x, _y, z; kwargs...), y)
+    pbB = pullback_function(backend, _x -> conditions(_x, y, z; kwargs...), x)
+    Aᵀ_op = LinearOperator(R, m, m, false, false, PullbackMul!(pbA, size(y)))
+    Bᵀ_op = LinearOperator(R, n, m, false, false, PullbackMul!(pbB, size(y)))
+
+    function implicit_pullback((dy, dz))
+        dy_vec = convert(Vector{R}, vec(unthunk(dy)))
+        dF_vec, stats = linear_solver(Aᵀ_op, dy_vec)
+        check_solution(linear_solver, stats)
+        dx_vec = -(Bᵀ_op * dF_vec)
+        dx = reshape(dx_vec, size(x))
+        return (NoTangent(), dx)
+    end
+
+    return (y, z), implicit_pullback
+end
+
+end
diff --git a/ext/ImplicitDifferentiationForwardDiffExt.jl b/ext/ImplicitDifferentiationForwardDiffExt.jl
new file mode 100644
index 0000000..468b4c8
--- /dev/null
+++ b/ext/ImplicitDifferentiationForwardDiffExt.jl
@@ -0,0 +1,50 @@
+module ImplicitDifferentiationForwardDiffExt
+
+@static if isdefined(Base, :get_extension)
+    using ForwardDiff: Dual, Partials, jacobian, partials, value
+else
+    using ..ForwardDiff: Dual, Partials, jacobian, partials, value
+end
+
+using AbstractDifferentiation: ForwardDiffBackend, pushforward_function
+using ImplicitDifferentiation: ImplicitFunction, PushforwardMul!, check_solution
+using LinearOperators: LinearOperator
+
+"""
+    implicit(x_and_dx::AbstractArray{ForwardDiff.Dual}[; kwargs...])
+
+Overload an [`ImplicitFunction`](@ref) on dual numbers to ensure compatibility with ForwardDiff.jl.
+"""
+function (implicit::ImplicitFunction)(
+    x_and_dx::AbstractArray{Dual{T,R,N}}; kwargs...
+) where {T,R,N}
+    conditions = implicit.conditions
+    linear_solver = implicit.linear_solver
+
+    x = value.(x_and_dx)
+    y, z = implicit(x; kwargs...)
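+
+    # Implicit function theorem: with A = ∂₂F(x, y, z) and B = ∂₁F(x, y, z),
+    # the Jacobian J of x -> y(x) satisfies A * J = -B. Each dual-number
+    # partial dx below therefore yields a tangent dy = J * dx, obtained by
+    # solving the linear system A * dy = -(B * dx) with `linear_solver`.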
+    n, m = length(x), length(y)
+
+    backend = ForwardDiffBackend()
+    pfA = pushforward_function(backend, _y -> conditions(x, _y, z; kwargs...), y)
+    pfB = pushforward_function(backend, _x -> conditions(_x, y, z; kwargs...), x)
+    A_op = LinearOperator(R, m, m, false, false, PushforwardMul!(pfA, size(y)))
+    B_op = LinearOperator(R, m, n, false, false, PushforwardMul!(pfB, size(x)))
+
+    dy = map(1:N) do k
+        dₖx_vec = vec(partials.(x_and_dx, k))
+        dₖy_vec, stats = linear_solver(A_op, -(B_op * dₖx_vec))
+        check_solution(linear_solver, stats)
+        reshape(dₖy_vec, size(y))
+    end
+
+    y_and_dy = map(eachindex(y)) do i
+        Dual{T}(y[i], Partials(Tuple(dy[k][i] for k in 1:N)))
+    end
+
+    z_and_dz = Dual{T}(z, Partials(Tuple(zero(z) for k in 1:N)))
+
+    return y_and_dy, z_and_dz
+end
+
+end
diff --git a/src/ImplicitDifferentiation.jl b/src/ImplicitDifferentiation.jl
index 2ef759d..c84571c 100644
--- a/src/ImplicitDifferentiation.jl
+++ b/src/ImplicitDifferentiation.jl
@@ -1,12 +1,22 @@
 module ImplicitDifferentiation
 
-using ChainRulesCore: ChainRulesCore, NoTangent, RuleConfig
-using ChainRulesCore: frule_via_ad, rrule_via_ad, unthunk
-using Krylov: gmres
+using AbstractDifferentiation: LazyJacobian, ReverseRuleConfigBackend, lazy_jacobian
+using Krylov: KrylovStats, gmres
 using LinearOperators: LinearOperator
+using Requires: @require
 
+include("utils.jl")
 include("implicit_function.jl")
 
 export ImplicitFunction
 
+@static if !isdefined(Base, :get_extension)
+    include("../ext/ImplicitDifferentiationChainRulesExt.jl")
+    function __init__()
+        @require ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" begin
+            include("../ext/ImplicitDifferentiationForwardDiffExt.jl")
+        end
+    end
+end
+
 end
diff --git a/src/implicit_function.jl b/src/implicit_function.jl
index 937ecc6..586fb27 100644
--- a/src/implicit_function.jl
+++ b/src/implicit_function.jl
@@ -1,19 +1,20 @@
 """
     ImplicitFunction{F,C,L}
 
-Differentiable wrapper for an implicit function `x -> ŷ(x)` whose output is defined by explicit conditions `F(x,ŷ(x)) = 0`.
+Differentiable wrapper for an implicit function `x -> y(x)` whose output is defined by conditions `F(x,y(x)) = 0`.
 
-If `x ∈ ℝⁿ` and `y ∈ ℝᵈ`, then we need as many conditions as output dimensions: `F(x,y) ∈ ℝᵈ`.
-Thanks to these conditions, we can compute the Jacobian of `ŷ(⋅)` using the implicit function theorem:
+More generally, we consider functions `x -> (y(x),z(x))` and conditions `F(x,y(x),z(x)) = 0`, where `z(x)` contains additional information that _is considered constant for differentiation purposes_. Beware: the method `zero(z)` must exist.
+
+If `x ∈ ℝⁿ` and `y ∈ ℝᵈ`, then we need as many conditions as output dimensions: `F(x,y,z) ∈ ℝᵈ`. Thanks to these conditions, we can compute the Jacobian of `y(⋅)` using the implicit function theorem:
```
-∂₂F(x,ŷ(x)) * ∂ŷ(x) = -∂₁F(x,ŷ(x))
+∂₂F(x,y(x),z(x)) * ∂y(x) = -∂₁F(x,y(x),z(x))
```
-This requires solving a linear system `A * J = -B`, where `A ∈ ℝᵈˣᵈ`, `B ∈ ℝᵈˣⁿ` and `J ∈ ℝᵈˣⁿ`.
+This amounts to solving a linear system `A * J = -B`, where `A ∈ ℝᵈˣᵈ`, `B ∈ ℝᵈˣⁿ` and `J ∈ ℝᵈˣⁿ`.
 
 # Fields:
-- `forward::F`: callable of the form `x -> ŷ(x)`
-- `conditions::C`: callable of the form `(x,y) -> F(x,y)`
-- `linear_solver::L`: callable of the form `(A,b) -> u` such that `A * u = b`
+- `forward::F`: callable of the form `x -> (y(x),z(x))`
+- `conditions::C`: callable of the form `(x,y,z) -> F(x,y,z)`
+- `linear_solver::L`: callable of the form `(A,b) -> u` such that `Au = b`; it must be taken from Krylov.jl
 """
 struct ImplicitFunction{F,C,L}
     forward::F
@@ -24,103 +25,18 @@ end
 """
     ImplicitFunction(forward, conditions)
 
-Construct an [`ImplicitFunction{F,C,L}`](@ref) with `Krylov.gmres` as the default linear solver.
+Construct an `ImplicitFunction` with `Krylov.gmres` as the default linear solver.
 """
-function ImplicitFunction(forward::F, conditions::C) where {F,C}
+function ImplicitFunction(forward, conditions)
     return ImplicitFunction(forward, conditions, gmres)
 end
 
-struct SolverFailureException{S} <: Exception
-    msg::String
-    stats::S
-end
-
-function Base.show(io::IO, sfe::SolverFailureException)
-    return println(io, "SolverFailureException: $(sfe.msg) \n Solver stats: $(sfe.stats)")
-end
-
 """
     implicit(x[; kwargs...])
 
-Make [`ImplicitFunction{F,C,L}`](@ref) callable by applying `implicit.forward`.
-"""
-(implicit::ImplicitFunction)(x; kwargs...) = implicit.forward(x; kwargs...)
-
+Make `ImplicitFunction` callable by applying `implicit.forward`.
 """
-    frule(rc, (_, dx), implicit, x[; kwargs...])
-
-Custom forward rule for [`ImplicitFunction{F,C,L}`](@ref).
-
-We compute the Jacobian-vector product `Jv` by solving `Au = -Bv` and setting `Jv = u`.
-Keyword arguments are given to both `implicit.forward` and `implicit.conditions`.
-"""
-function ChainRulesCore.frule(
-    rc::RuleConfig, (_, dx), implicit::ImplicitFunction, x::AbstractArray{R}; kwargs...
-) where {R<:Real}
-    conditions = implicit.conditions
-    linear_solver = implicit.linear_solver
-
-    y = implicit(x; kwargs...)
-
-    conditions_x(x̃; kwargs...) = conditions(x̃, y; kwargs...)
-    conditions_y(ỹ; kwargs...) = conditions(x, ỹ; kwargs...)
-
-    pushforward_A(dỹ) = frule_via_ad(rc, (NoTangent(), dỹ), conditions_y, y; kwargs...)[2]
-    pushforward_B(dx̃) = frule_via_ad(rc, (NoTangent(), dx̃), conditions_x, x; kwargs...)[2]
-
-    mul_A!(res::Vector, u::Vector) = res .= vec(pushforward_A(reshape(u, size(y))))
-    mul_B!(res::Vector, v::Vector) = res .= vec(pushforward_B(reshape(v, size(x))))
-
-    n, m = length(x), length(y)
-    A = LinearOperator(R, m, m, false, false, mul_A!)
-    B = LinearOperator(R, m, n, false, false, mul_B!)
-
-    dx_vec = convert(Vector{R}, vec(unthunk(dx)))
-    b = -B * dx_vec
-    dy_vec, stats = linear_solver(A, b)
-    if !stats.solved
-        throw(SolverFailureException("Linear solver failed to converge", stats))
-    end
-    dy = reshape(dy_vec, size(y))
-
-    return y, dy
-end
-
-"""
-    rrule(rc, implicit, x[; kwargs...])
-
-Custom reverse rule for [`ImplicitFunction{F,C,L}`](@ref).
-
-We compute the vector-Jacobian product `Jᵀv` by solving `Aᵀu = v` and setting `Jᵀv = -Bᵀu`.
-Keyword arguments are given to both `implicit.forward` and `implicit.conditions`.
-"""
-function ChainRulesCore.rrule(
-    rc::RuleConfig, implicit::ImplicitFunction, x::AbstractArray{R}; kwargs...
-) where {R<:Real}
-    conditions = implicit.conditions
-    linear_solver = implicit.linear_solver
-
-    y = implicit(x; kwargs...)
-
-    pullback = rrule_via_ad(rc, conditions, x, y; kwargs...)[2]
-
-    mul_Aᵀ!(res::Vector, u::Vector) = res .= vec(pullback(reshape(u, size(y)))[3])
-    mul_Bᵀ!(res::Vector, v::Vector) = res .= vec(pullback(reshape(v, size(y)))[2])
-
-    n, m = length(x), length(y)
-    Aᵀ = LinearOperator(R, m, m, false, false, mul_Aᵀ!)
-    Bᵀ = LinearOperator(R, n, m, false, false, mul_Bᵀ!)
-
-    function implicit_pullback(dy)
-        dy_vec = convert(Vector{R}, vec(unthunk(dy)))
-        u, stats = linear_solver(Aᵀ, dy_vec)
-        if !stats.solved
-            throw(SolverFailureException("Linear solver failed to converge", stats))
-        end
-        dx_vec = -Bᵀ * u
-        dx = reshape(dx_vec, size(x))
-        return (NoTangent(), dx)
-    end
-
-    return y, implicit_pullback
+function (implicit::ImplicitFunction)(x; kwargs...)
+    y, z = implicit.forward(x; kwargs...)
+    return y, z
 end
diff --git a/src/utils.jl b/src/utils.jl
new file mode 100644
index 0000000..3a897ee
--- /dev/null
+++ b/src/utils.jl
@@ -0,0 +1,59 @@
+struct SolverFailureException{A,B} <: Exception
+    solver::A
+    stats::B
+end
+
+function Base.show(io::IO, sfe::SolverFailureException)
+    return println(
+        io,
+        "SolverFailureException: \n Solver: $(sfe.solver) \n Solver stats: $(string(sfe.stats))",
+    )
+end
+
+function check_solution(solver, stats)
+    if stats.solved
+        return nothing
+    else
+        throw(SolverFailureException(solver, stats))
+    end
+end
+
+"""
+    PushforwardMul!{P,N}
+
+Callable structure wrapping a pushforward with `N`-dimensional inputs into an in-place multiplication for vectors.
+
+# Fields
+- `pushforward::P`: the pushforward function
+- `input_size::NTuple{N,Int}`: the array size of the function input
+"""
+struct PushforwardMul!{P,N}
+    pushforward::P
+    input_size::NTuple{N,Int}
+end
+
+"""
+    PullbackMul!{P,N}
+
+Callable structure wrapping a pullback with `N`-dimensional outputs into an in-place multiplication for vectors.
+
+# Fields
+- `pullback::P`: the pullback of the function
+- `output_size::NTuple{N,Int}`: the array size of the function output
+"""
+struct PullbackMul!{P,N}
+    pullback::P
+    output_size::NTuple{N,Int}
+end
+
+function (pfm::PushforwardMul!)(res::Vector, δinput_vec::Vector)
+    δinput = reshape(δinput_vec, pfm.input_size)
+    δoutput = only(pfm.pushforward(δinput))
+    return res .= vec(δoutput)
+end
+
+function (pbm::PullbackMul!)(res::Vector, δoutput_vec::Vector)
+    δoutput = reshape(δoutput_vec, pbm.output_size)
+    δinput = only(pbm.pullback(δoutput))
+    return res .= vec(δinput)
+end
diff --git a/test/runtests.jl b/test/runtests.jl
index ea776e4..c52e9c8 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -2,7 +2,6 @@
 using Aqua
 using Documenter
-using ForwardDiffChainRules
 using ImplicitDifferentiation
 using JET
 using JuliaFormatter
@@ -33,22 +32,25 @@ EXAMPLES_DIR_JL = joinpath(dirname(@__DIR__), "examples")
 ## Test sets
 
 @testset verbose = true "ImplicitDifferentiation.jl" begin
-    @testset verbose = true "Code quality (Aqua.jl)" begin
-        Aqua.test_all(ImplicitDifferentiation)
+    @testset verbose = false "Code quality (Aqua.jl)" begin
+        Aqua.test_ambiguities([ImplicitDifferentiation, Base, Core])
+        Aqua.test_unbound_args(ImplicitDifferentiation)
+        Aqua.test_undefined_exports(ImplicitDifferentiation)
+        Aqua.test_piracy(ImplicitDifferentiation)
+        Aqua.test_project_extras(ImplicitDifferentiation)
+        Aqua.test_stale_deps(ImplicitDifferentiation; ignore=[:ChainRulesCore])
+        Aqua.test_deps_compat(ImplicitDifferentiation)
+        Aqua.test_project_toml_formatting(ImplicitDifferentiation)
     end
     @testset verbose = true "Formatting (JuliaFormatter.jl)" begin
         @test format(ImplicitDifferentiation; verbose=true, overwrite=false)
     end
    @testset verbose = true "Static checking (JET.jl)" begin
        if VERSION >= v"1.8"
-            JET.test_package(
-                ImplicitDifferentiation;
-                toplevel_logger=nothing,
-                ignored_modules=(ForwardDiffChainRules,),
-            ) # TODO: remove once new version released
+            JET.test_package(ImplicitDifferentiation; toplevel_logger=nothing)
        end
    end
-    @testset verbose = true "Doctests (Documenter.jl)" begin
+    @testset verbose = false "Doctests (Documenter.jl)" begin
         doctest(ImplicitDifferentiation)
     end
     for file in readdir(EXAMPLES_DIR_JL)
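To see the new interface end to end, here is a minimal usage sketch consistent with the changes above (the toy functions `forward` and `conditions` are ours, not part of the package):

```julia
using ImplicitDifferentiation
using Zygote

# Componentwise square root written implicitly: y(x) solves F(x, y, z) = y .^ 2 - x = 0.
forward(x) = (sqrt.(x), 0)         # forward solvers now return a pair (y, z)
conditions(x, y, z) = y .^ 2 .- x  # conditions now take three arguments

implicit = ImplicitFunction(forward, conditions)  # Krylov.gmres is the default solver

x = rand(3)
y, z = implicit(x)  # y ≈ sqrt.(x); z carries extra info, constant for differentiation

# Differentiation goes through the custom rules instead of unrolling the solver:
J = Zygote.jacobian(first ∘ implicit, x)[1]  # ≈ Diagonal(0.5 ./ sqrt.(x))
```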