generated from JuliaReach/JuliaReachTemplatePkg.jl
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #27 from JuliaReach/schillic/ai2
Add partial AI2 algorithm
- Loading branch information
Showing
17 changed files
with
257 additions
and
5 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -19,5 +19,8 @@ LazyForward | |
BoxForward | ||
SplitForward | ||
DeepZ | ||
AI2Box | ||
AI2Zonotope | ||
AI2Polytope | ||
Verisig | ||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
# Util | ||
|
||
This section of the manual describes the module for utilities. | ||
|
||
```@contents | ||
Pages = ["Util.md"] | ||
Depth = 3 | ||
``` | ||
|
||
```@meta | ||
CurrentModule = NeuralNetworkReachability.Util | ||
``` | ||
|
||
```@docs | ||
ConvSet | ||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,122 @@ | ||
""" | ||
AI2Box <: AI2 | ||
AI2 forward algorithm for ReLU activation functions based on abstract | ||
interpretation with the interval domain from [1]. | ||
### Notes | ||
This algorithm is less precise than [`BoxForward`](@ref) because it abstracts | ||
after every step, including the affine map. | ||
[1]: Gehr et al.: *AI²: Safety and robustness certification of neural networks | ||
with abstract interpretation*, SP 2018. | ||
""" | ||
struct AI2Box <: ForwardAlgorithm end | ||
|
||
""" | ||
AI2Zonotope <: AI2 | ||
AI2 forward algorithm for ReLU activation functions based on abstract | ||
interpretation with the zonotope domain from [1]. | ||
### Fields | ||
- `join_algorithm` -- (optional; default: `"join"`) algorithm to compute the | ||
join of two zonotopes | ||
[1]: Gehr et al.: *AI²: Safety and robustness certification of neural networks | ||
with abstract interpretation*, SP 2018. | ||
""" | ||
struct AI2Zonotope{S} <: ForwardAlgorithm | ||
join_algorithm::S | ||
end | ||
|
||
# the default join algorithm is "join" | ||
AI2Zonotope() = AI2Zonotope("join") | ||
|
||
""" | ||
AI2Polytope <: AI2 | ||
AI2 forward algorithm for ReLU activation functions based on abstract | ||
interpretation with the polytope domain from [1]. | ||
[1]: Gehr et al.: *AI²: Safety and robustness certification of neural networks | ||
with abstract interpretation*, SP 2018. | ||
""" | ||
struct AI2Polytope <: ForwardAlgorithm end | ||
|
||
# meet and join algorithms for different abstract domains | ||
const _meet_zonotope = (X, Y) -> overapproximate(X ∩ Y, Zonotope) | ||
const _join_zonotope(algo) = (X, Y) -> overapproximate(X ∪ Y, Zonotope; algorithm=algo) | ||
const _meet_polytope = intersection | ||
const _join_polytope = convex_hull | ||
|
||
# apply affine map | ||
|
||
# box: box approximation of the affine map | ||
function forward(H, W::AbstractMatrix, b::AbstractVector, ::AI2Box) | ||
return box_approximation(W * H + b) | ||
end | ||
|
||
# zonotope and polytope: closed under affine map | ||
function forward(X, W::AbstractMatrix, b::AbstractVector, ::Union{AI2Zonotope,AI2Polytope}) | ||
return affine_map(W, X, b) | ||
end | ||
|
||
# apply ReLU activation function | ||
# for each dimension 1:n | ||
# 1(a) if nonnegative: nothing | ||
# 1(b) if negative: project | ||
# 1(c) if both nonnegative and negative: intersect with half-spaces and project negative | ||
# 2: take the domain element(s) corresponding to the previous set(s) | ||
# 3(c): union of the two sets, then take the corresponding domain element | ||
|
||
# box: exploits that Box(ReLU(H)) = ReLU(H) | ||
function forward(H::AbstractHyperrectangle, ::ReLU, ::AI2Box) | ||
return rectify(H) | ||
end | ||
|
||
# zonotope: intersection the zonotope overapproximation of all pairwise projected intersections | ||
function forward(Z::AbstractZonotope, ::ReLU, algo::AI2Zonotope) | ||
require(@__MODULE__, :IntervalConstraintProgramming; fun_name="forward", | ||
explanation="with AI2Zonotope") | ||
|
||
return _forward_AI2_ReLU(Z; meet=_meet_zonotope, join=_join_zonotope(algo.join_algorithm)) | ||
end | ||
|
||
# polytope: the convex hull of all pairwise polytopes | ||
function forward(P::AbstractPolytope, ::ReLU, ::AI2Polytope) | ||
return _forward_AI2_ReLU(P; meet=_meet_polytope, join=_join_polytope) | ||
end | ||
|
||
function _forward_AI2_ReLU(X::LazySet{N}; meet, join) where {N} | ||
n = dim(X) | ||
d = ones(N, n) # reused vector for "almost" identity matrices | ||
for i in 1:n | ||
if low(X, i) >= 0 # nonnegative case | ||
continue | ||
elseif high(X, i) <= 0 # negative case | ||
d[i] = zero(N) | ||
D = Diagonal(d) | ||
X = linear_map(D, X) | ||
d[i] = one(N) | ||
else # mixed case | ||
# nonnegative part | ||
H1 = HalfSpace(SingleEntryVector(i, n, -one(N)), zero(N)) | ||
X1 = meet(X, H1) | ||
|
||
# negative part | ||
H2 = HalfSpace(SingleEntryVector(i, n, one(N)), zero(N)) | ||
X2 = meet(X, H2) | ||
d[i] = zero(N) | ||
D = Diagonal(d) | ||
X2′ = linear_map(D, X2) | ||
d[i] = one(N) | ||
|
||
# join | ||
X = join(X1, X2′) | ||
end | ||
end | ||
return X | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
""" | ||
ConvSet{T<:LazySet{N}} | ||
Wrapper of a set to represent a three-dimensional structure. | ||
### Fields | ||
- `set` -- set of dimension `dims[1] * dims[2] * dims[3]` | ||
- `dims` -- 3-tuple with the dimensions | ||
""" | ||
struct ConvSet{T<:LazySet} | ||
set::T | ||
dims::NTuple{3,Int} | ||
|
||
function ConvSet(set::T, dims::NTuple{3,Int}; validate=Val(true)) where {T} | ||
if validate isa Val{true} && (dim(set) != dims[1] * dims[2] * dims[3] || | ||
dims[1] <= 0 || dims[2] <= 0 || dims[3] <= 0) | ||
throw(ArgumentError("invalid dimensions $(dim(set)) and $dims")) | ||
end | ||
return new{T}(set, dims) | ||
end | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
module Util | ||
|
||
using LazySets: LazySet, dim | ||
|
||
export ConvSet | ||
|
||
include("ConvSet.jl") | ||
|
||
end # module |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
using NeuralNetworkReachability.Util: ConvSet | ||
|
||
@testset "ConvSet" begin | ||
X = BallInf(zeros(12), 1.0) | ||
ConvSet(X, (1, 2, 6)) | ||
ConvSet(X, (2, 2, 3)) | ||
@test_throws ArgumentError ConvSet(X, (0, 1, 12)) | ||
@test_throws ArgumentError ConvSet(X, (2, 2, 2)) | ||
@test_throws ArgumentError ConvSet(X, (0, 0, 0)) | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters