diff --git a/src/rulesets/LinearAlgebra/norm.jl b/src/rulesets/LinearAlgebra/norm.jl index 89d280134..a6205a8b3 100644 --- a/src/rulesets/LinearAlgebra/norm.jl +++ b/src/rulesets/LinearAlgebra/norm.jl @@ -40,7 +40,10 @@ function rrule(::typeof(norm), x::AbstractArray{<:Number}, p::Real) end, # out-of-place versions @thunk(if isempty(x) || p == 0 - zero.(x) .* (zero(y) * zero(real(Δy))) + # Note: post-julia-1.11 the zero.(Diagonal(Float64[;])) .* 0.0) + # only infers down to Union(Diagonal{Float64}, Matrix{Float64}) + # rather than Diagonal{Float64}, so we cast it back. + maybe_withsomezeros_rewrap(x, zero.(x) .* (zero(y) * zero(real(Δy)))) elseif p == 2 _norm2_back(x, y, Δy) elseif p == 1 @@ -72,7 +75,10 @@ function rrule(::typeof(norm), x::AbstractArray{<:Number}) end , @thunk(if isempty(x) - zero.(x) .* (zero(y) * zero(real(Δy))) + # Note: post-julia-1.11 the zero.(Diagonal(Float64[;])) .* 0.0) + # only infers down to Union(Diagonal{Float64}, Matrix{Float64}) + # rather than Diagonal{Float64}, so we cast it back. + maybe_withsomezeros_rewrap(x, zero.(x) .* (zero(y) * zero(real(Δy)))) else _norm2_back(x, y, Δy) end) @@ -99,7 +105,7 @@ function rrule(::typeof(norm), x::Number, p::Real) function norm_pullback(ȳ) Δy = unthunk(ȳ) ∂x = if iszero(Δy) || iszero(p) - zero(x) * zero(real(Δy)) + zero(x) * zero(real(Δy)) else signx = x isa Real ? sign(x) : x * pinv(y) signx * real(Δy) diff --git a/src/rulesets/LinearAlgebra/utils.jl b/src/rulesets/LinearAlgebra/utils.jl index 3d8ad923f..8d6ac72fa 100644 --- a/src/rulesets/LinearAlgebra/utils.jl +++ b/src/rulesets/LinearAlgebra/utils.jl @@ -58,7 +58,9 @@ for S in [ :UnitLowerTriangular, ] @eval withsomezeros_rewrap(::$S, x) = $S(x) + @eval maybe_withsomezeros_rewrap(::$S, x) = $S(x) end +maybe_withsomezeros_rewrap(::AbstractArray, x) = x # Bidiagonal, Tridiagonal have more complicated storage. # AdjOrTransUpperOrUnitUpperTriangular would need adjoint(parent(parent()))