Skip to content

Commit

Permalink
Refactor gpu tests into gpu/ folder
Browse files Browse the repository at this point in the history
  • Loading branch information
nHackel authored and dpo committed Oct 16, 2024
1 parent 6dc3b23 commit b8ea9bc
Show file tree
Hide file tree
Showing 7 changed files with 31 additions and 11 deletions.
1 change: 1 addition & 0 deletions .buildkite/pipeline.yml
Original file line number Diff line number Diff line change
Expand Up @@ -11,5 +11,6 @@ steps:
using Pkg
Pkg.add("CUDA")
Pkg.instantiate()
include("test/gpu/test_S_kwarg.jl")
include("test/gpu/nvidia.jl")'
timeout_in_minutes: 30
15 changes: 15 additions & 0 deletions test/gpu/amdgpu.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
using Test, LinearAlgebra, SparseArrays
using LinearOperators, AMDGPU

@testset "AMDGPU -- AMDGPU.jl" begin
A = ROCArray(rand(Float32, 5, 5))
B = ROCArray(rand(Float32, 10, 10))
C = ROCArray(rand(Float32, 20, 20))
M = BlockDiagonalOperator(A, B, C)

v = ROCArray(rand(Float32, 35))
y = M * v
@test y isa ROCArray{Float32}

@testset "AMDGPU S kwarg" test_S_kwarg(arrayType = ROCArray)
end
1 change: 1 addition & 0 deletions test/gpu/jlarrays.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
test_S_kwarg(arrayType = JLArray)
3 changes: 3 additions & 0 deletions test/gpu/metal.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
if Sys.isapple() && occursin("arm64", Sys.MACHINE)
test_S_kwarg(arrayType = MtlArray, notMetal = false)
end
1 change: 1 addition & 0 deletions test/gpu/nvidia.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,5 @@ using LinearOperators, CUDA, CUDA.CUSPARSE, CUDA.CUSOLVER
v = CUDA.rand(35)
y = M * v
@test y isa CuVector{Float32}
@testset "Nvidia S kwarg" test_S_kwarg(arrayType = CuArray)
end
11 changes: 4 additions & 7 deletions test/test_S_kwarg.jl → test/gpu/test_S_kwarg.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
function test_S_kwarg(; arrayType = JLArray, notMetal = true)
using Test, LinearOperators, LinearAlgebra

function test_S_kwarg(; arrayType, notMetal = true)
mat = arrayType(rand(Float32, 32, 32))
vec = arrayType(rand(Float32, 32))
vecT = typeof(vec)
Expand All @@ -8,7 +10,7 @@ function test_S_kwarg(; arrayType = JLArray, notMetal = true)
vecTother = typeof(arrayType(rand(Float32, 32)))
end

@testset ExtendedTestSet "S Kwarg with arrayType $(arrayType)" begin
@testset "S Kwarg with arrayType $(arrayType)" begin
@test vecT == LinearOperators.storage_type(mat)

# constructors.jl
Expand Down Expand Up @@ -37,9 +39,4 @@ function test_S_kwarg(; arrayType = JLArray, notMetal = true)
notMetal && @test LinearOperators.storage_type(BlockDiagonalOperator(mat, mat; S = vecTother)) == vecTother
end

end

test_S_kwarg()
if Sys.isapple() && occursin("arm64", Sys.MACHINE)
test_S_kwarg(arrayType = MtlArray, notMetal = false)
end
10 changes: 6 additions & 4 deletions test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,9 +1,6 @@
using Arpack, Test, TestSetExtensions, LinearOperators
using LinearAlgebra, LDLFactorizations, SparseArrays, JLArrays
using Zygote
if Sys.isapple() && occursin("arm64", Sys.MACHINE)
using Metal
end
include("test_aux.jl")

include("test_linop.jl")
Expand All @@ -19,4 +16,9 @@ include("test_normest.jl")
include("test_diag.jl")
include("test_chainrules.jl")
include("test_solve_shifted_system.jl")
include("test_S_kwarg.jl")
include("gpu/test_S_kwarg.jl")
include("gpu/jlarrays.jl")
if Sys.isapple() && occursin("arm64", Sys.MACHINE)
using Metal
include("gpu/metal.jl")
end

0 comments on commit b8ea9bc

Please sign in to comment.