Skip to content

Commit 7350228

Browse files
committed
fix AbstractArrayStyle dimension promotion rule
1 parent 58d6684 commit 7350228

File tree

2 files changed

+21
-4
lines changed

2 files changed

+21
-4
lines changed

base/broadcast.jl

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -137,11 +137,12 @@ BroadcastStyle(a::AbstractArrayStyle, ::Style{Tuple}) = a
137137
BroadcastStyle(::A, ::A) where A<:ArrayStyle = A()
138138
BroadcastStyle(::ArrayStyle, ::ArrayStyle) = Unknown()
139139
BroadcastStyle(::A, ::A) where A<:AbstractArrayStyle = A()
140-
Base.@pure function BroadcastStyle(a::A, b::B) where {A<:AbstractArrayStyle{M},B<:AbstractArrayStyle{N}} where {M,N}
141-
if Base.typename(A) === Base.typename(B)
142-
return A(Val(max(M, N)))
140+
function BroadcastStyle(::A, ::B) where {M,N,A<:AbstractArrayStyle{M},B<:AbstractArrayStyle{N}}
141+
if M!=N && M!=Any && N!=Any && A(Val(max(M,N)))==B(Val(max(M,N)))
142+
A(Val(max(M,N)))
143+
else
144+
Unknown()
143145
end
144-
return Unknown()
145146
end
146147
# Any specific array type beats DefaultArrayStyle
147148
BroadcastStyle(a::AbstractArrayStyle{Any}, ::DefaultArrayStyle) = a

test/broadcast.jl

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -990,3 +990,19 @@ p0 = copy(p)
990990
@test isequal(identity.(Vector{<:Union{Int, Missing}}[[1, 2],[missing, 1]]),
991991
[[1, 2],[missing, 1]])
992992
end
993+
994+
# test the dimension promotion provided by the generic
995+
# BroadcastStyle(::AbstractArrayStyle, ::AbstractArrayStyle)
996+
struct CustomStyle{T,N} <: Broadcast.AbstractArrayStyle{N} end
997+
Broadcast.BroadcastStyle(::CustomStyle{:A,N}, ::CustomStyle{:B,N}) where {N} = CustomStyle{:A,N}()
998+
CustomStyle{T,N}(::Val{M}) where {T,N,M} = CustomStyle{T,M}()
999+
# test this isn't spoiled:
1000+
@test @inferred(Broadcast.result_style(CustomStyle{:A,1}(), CustomStyle{:B,1}())) ==
1001+
@inferred(Broadcast.result_style(CustomStyle{:B,1}(), CustomStyle{:A,1}())) == CustomStyle{:A,1}()
1002+
# test dimension promotion works here:
1003+
@test @inferred(Broadcast.result_style(CustomStyle{:A,1}(), CustomStyle{:A,2}())) == CustomStyle{:A,2}()
1004+
# here the user would need to specify a custom rule:
1005+
@test @inferred(Broadcast.result_style(CustomStyle{:A,1}(), CustomStyle{:B,2}())) ==
1006+
@inferred(Broadcast.result_style(CustomStyle{:B,2}(), CustomStyle{:A,1}())) ==
1007+
@inferred(Broadcast.result_style(CustomStyle{:A,2}(), CustomStyle{:B,1}())) ==
1008+
@inferred(Broadcast.result_style(CustomStyle{:B,1}(), CustomStyle{:A,2}())) == Broadcast.ArrayConflict()

0 commit comments

Comments
 (0)