Skip to content

Commit

Permalink
Fix #2182: avoid double accumulation in *MixedDuplicated (#2262)
Browse files Browse the repository at this point in the history
* Add nonzero dval tests to expose jit bug

* Fix jit bug
  • Loading branch information
danielwe authored Jan 13, 2025
1 parent 8a0bff4 commit 70a2940
Show file tree
Hide file tree
Showing 3 changed files with 271 additions and 203 deletions.
6 changes: 4 additions & 2 deletions src/rules/jitrules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1488,7 +1488,7 @@ end
end
elseif args[i] <: MixedDuplicated
:(args[$i].dval[])
else
else # args[i] <: BatchMixedDuplicated
:(args[$i].dval[$w][])
end

Expand All @@ -1500,9 +1500,11 @@ end
T = Core.Typeof(vecld)
@assert !(vecld isa Base.RefValue)
vec[] = recursive_index_add(T, vecld, Val(idx_in_vec), $expr)
else
elseif $(args[i] <: Active)
val = @inbounds vec[idx_in_vec]
add_into_vec!(Base.inferencebarrier(val), $expr, vec, idx_in_vec)
else # args[i] <: MixedDuplicated || args[i] <: BatchMixedDuplicated
@inbounds vec[idx_in_vec] = $expr
end
end
else
Expand Down
295 changes: 157 additions & 138 deletions test/applyiter.jl
Original file line number Diff line number Diff line change
Expand Up @@ -105,155 +105,174 @@ end

@testset "Reverse Apply iterate" begin
x = [(2.0, 3.0), (7.9, 11.2)]
dx = [(0.0, 0.0), (0.0, 0.0)]
res = Enzyme.autodiff(Reverse, metasumsq, Active, Const(metaconcat), Duplicated(x, dx))
@test tupapprox(dx, [(4.0, 6.0), (15.8, 22.4)])

dx = [(0.0, 0.0), (0.0, 0.0)]
res = Enzyme.autodiff(ReverseWithPrimal, metasumsq, Active, Const(metaconcat), Duplicated(x, dx))
@test res[2] ≈ 200.84999999999997
@test tupapprox(dx, [(4.0, 6.0), (15.8, 22.4)])

x = [[2.0, 3.0], [7.9, 11.2]]
dx = [[0.0, 0.0], [0.0, 0.0]]

res = Enzyme.autodiff(Reverse, metasumsq2, Active, Const(metaconcat), Duplicated(x, dx))
@test dx ≈ [[4.0, 6.0], [15.8, 22.4]]

dx = [[0.0, 0.0], [0.0, 0.0]]

res = Enzyme.autodiff(ReverseWithPrimal, metasumsq2, Active, Const(metaconcat), Duplicated(x, dx))

@test res[2] ≈ 200.84999999999997
@test tupapprox(dx, [[4.0, 6.0], [15.8, 22.4]])


x = [(2.0, 3.0), (7.9, 11.2)]
dx = [(0.0, 0.0), (0.0, 0.0)]

y = [(13, 17), (25, 31)]
res = Enzyme.autodiff(Reverse, metasumsq3, Active, Const(metaconcat2), Duplicated(x, dx), Const(y))
@test tupapprox(dx, [(4.0, 6.0), (15.8, 22.4)])


x = [(2.0, 3.0), (7.9, 11.2)]
dx = [(0.0, 0.0), (0.0, 0.0)]
y = [(13, 17), (25, 31)]
dy = [(0, 0), (0, 0)]
res = Enzyme.autodiff(Reverse, metasumsq3, Active, Const(metaconcat2), Duplicated(x, dx), Duplicated(y, dy))
@test tupapprox(dx, [(4.0, 6.0), (15.8, 22.4)])



x = [[2.0, 3.0], [7.9, 11.2]]
dx = [[0.0, 0.0], [0.0, 0.0]]
y = [[13, 17], [25, 31]]
res = Enzyme.autodiff(Reverse, metasumsq4, Active, Const(metaconcat2), Duplicated(x, dx), Const(y))
@test tupapprox(dx, [[4.0, 6.0], [15.8, 22.4]])

dy_const = [(0, 0), (0, 0)]
primal = 200.84999999999997
@testset "tuple $label" for (label, dx_pre, dx_post) in [
("dx == 0", [(0.0, 0.0), (0.0, 0.0)], [(4.0, 6.0), (15.8, 22.4)]),
("dx != 0", [(1.0, -2.0), (-3.0, 4.0)], [(5.0, 4.0), (12.8, 26.4)]),
]
dx = deepcopy(dx_pre)
Enzyme.autodiff(Reverse, metasumsq, Active, Const(metaconcat), Duplicated(x, dx))
@test tupapprox(dx, dx_post)

dx = deepcopy(dx_pre)
res = Enzyme.autodiff(ReverseWithPrimal, metasumsq, Active, Const(metaconcat), Duplicated(x, dx))
@test res[2] ≈ primal
@test tupapprox(dx, dx_post)

dx = deepcopy(dx_pre)
Enzyme.autodiff(Reverse, metasumsq3, Active, Const(metaconcat2), Duplicated(x, dx), Const(y))
@test tupapprox(dx, dx_post)

dx = deepcopy(dx_pre)
dy = deepcopy(dy_const)
Enzyme.autodiff(Reverse, metasumsq3, Active, Const(metaconcat2), Duplicated(x, dx), Duplicated(y, dy))
@test tupapprox(dx, dx_post)
@test tupapprox(dy, dy_const)
end

x = [[2.0, 3.0], [7.9, 11.2]]
dx = [[0.0, 0.0], [0.0, 0.0]]
y = [[13, 17], [25, 31]]
dy = [[0, 0], [0, 0]]
res = Enzyme.autodiff(Reverse, metasumsq4, Active, Const(metaconcat2), Duplicated(x, dx), Duplicated(y, dy))
@test tupapprox(dx, [[4.0, 6.0], [15.8, 22.4]])
dy_const = [[0, 0], [0, 0]]
primal = 200.84999999999997
@testset "list $label" for (label, dx_pre, dx_post) in [
("dx == 0", [[0.0, 0.0], [0.0, 0.0]], [[4.0, 6.0], [15.8, 22.4]]),
("dx != 0", [[1.0, -2.0], [-3.0, 4.0]], [[5.0, 4.0], [12.8, 26.4]]),
]
dx = deepcopy(dx_pre)
Enzyme.autodiff(Reverse, metasumsq2, Active, Const(metaconcat), Duplicated(x, dx))
@test dx ≈ dx_post

dx = deepcopy(dx_pre)
res = Enzyme.autodiff(ReverseWithPrimal, metasumsq2, Active, Const(metaconcat), Duplicated(x, dx))
@test res[2] ≈ primal
@test dx ≈ dx_post

dx = deepcopy(dx_pre)
Enzyme.autodiff(Reverse, metasumsq4, Active, Const(metaconcat2), Duplicated(x, dx), Const(y))
@test dx ≈ dx_post

dx = deepcopy(dx_pre)
dy = deepcopy(dy_const)
Enzyme.autodiff(Reverse, metasumsq4, Active, Const(metaconcat2), Duplicated(x, dx), Duplicated(y, dy))
@test dx ≈ dx_post
@test dy ≈ dy_const
end
end

@testset "BatchReverse Apply iterate" begin
x = [(2.0, 3.0), (7.9, 11.2)]
dx = [(0.0, 0.0), (0.0, 0.0)]
dx2 = [(0.0, 0.0), (0.0, 0.0)]
out = Ref(0.0)
dout = Ref(1.0)
dout2 = Ref(3.0)
Enzyme.autodiff(Reverse, make_byref, Const, BatchDuplicatedNoNeed(out, (dout, dout2)), Const(metasumsq), Const(metaconcat), BatchDuplicated(x, (dx, dx2)))
@test tupapprox(dx, [(4.0, 6.0), (15.8, 22.4)])
@test tupapprox(dx2, [(3*4.0, 3*6.0), (3*15.8, 3*22.4)])

dx = [(0.0, 0.0), (0.0, 0.0)]
dx2 = [(0.0, 0.0), (0.0, 0.0)]
out = Ref(0.0)
dout = Ref(1.0)
dout2 = Ref(3.0)
Enzyme.autodiff(Reverse, make_byref, Const, BatchDuplicated(out, (dout, dout2)), Const(metasumsq), Const(metaconcat), BatchDuplicated(x, (dx, dx2)))
@test out[] ≈ 200.84999999999997
@test tupapprox(dx, [(4.0, 6.0), (15.8, 22.4)])
@test tupapprox(dx2, [(3*4.0, 3*6.0), (3*15.8, 3*22.4)])

x = [[2.0, 3.0], [7.9, 11.2]]
dx = [[0.0, 0.0], [0.0, 0.0]]
dx2 = [[0.0, 0.0], [0.0, 0.0]]
out = Ref(0.0)
dout = Ref(1.0)
dout2 = Ref(3.0)

Enzyme.autodiff(Reverse, make_byref, Const, BatchDuplicatedNoNeed(out, (dout, dout2)), Const(metasumsq2), Const(metaconcat), BatchDuplicated(x, (dx, dx2)))
@test dx ≈ [[4.0, 6.0], [15.8, 22.4]]
@test dx2 ≈ [[3*4.0, 3*6.0], [3*15.8, 3*22.4]]

dx = [[0.0, 0.0], [0.0, 0.0]]
dx2 = [[0.0, 0.0], [0.0, 0.0]]
out = Ref(0.0)
dout = Ref(1.0)
dout2 = Ref(3.0)
Enzyme.autodiff(Reverse, make_byref, Const, BatchDuplicated(out, (dout, dout2)), Const(metasumsq2), Const(metaconcat), BatchDuplicated(x, (dx, dx2)))

@test out[] ≈ 200.84999999999997
@test tupapprox(dx, [[4.0, 6.0], [15.8, 22.4]])
@test tupapprox(dx2, [[3*4.0, 3*6.0], [3*15.8, 3*22.4]])


x = [(2.0, 3.0), (7.9, 11.2)]
dx = [(0.0, 0.0), (0.0, 0.0)]
dx2 = [(0.0, 0.0), (0.0, 0.0)]

y = [(13, 17), (25, 31)]
out = Ref(0.0)
dout = Ref(1.0)
dout2 = Ref(3.0)
Enzyme.autodiff(Reverse, make_byref, Const, BatchDuplicatedNoNeed(out, (dout, dout2)), Const(metasumsq3), Const(metaconcat2), BatchDuplicated(x, (dx, dx2)), Const(y))
@test tupapprox(dx, [(4.0, 6.0), (15.8, 22.4)])
@test tupapprox(dx2, [(3*4.0, 3*6.0), (3*15.8, 3*22.4)])


x = [(2.0, 3.0), (7.9, 11.2)]
dx = [(0.0, 0.0), (0.0, 0.0)]
dx2 = [(0.0, 0.0), (0.0, 0.0)]
y = [(13, 17), (25, 31)]
dy = [(0, 0), (0, 0)]
dy2 = [(0, 0), (0, 0)]
out = Ref(0.0)
dout = Ref(1.0)
dout2 = Ref(3.0)
Enzyme.autodiff(Reverse, make_byref, Const, BatchDuplicatedNoNeed(out, (dout, dout2)), Const(metasumsq3),Const(metaconcat2), BatchDuplicated(x, (dx, dx2)), BatchDuplicated(y, (dy, dy2)))
@test tupapprox(dx, [(4.0, 6.0), (15.8, 22.4)])
@test tupapprox(dx2, [(3*4.0, 3*6.0), (3*15.8, 3*22.4)])


x = [[2.0, 3.0], [7.9, 11.2]]
dx = [[0.0, 0.0], [0.0, 0.0]]
dx2 = [[0.0, 0.0], [0.0, 0.0]]
y = [[13, 17], [25, 31]]
out = Ref(0.0)
dout = Ref(1.0)
dout2 = Ref(3.0)
Enzyme.autodiff(Reverse, make_byref, Const, BatchDuplicated(out, (dout, dout2)), Const(metasumsq4), Const(metaconcat2), BatchDuplicated(x, (dx, dx2)), Const(y))
@test tupapprox(dx, [[4.0, 6.0], [15.8, 22.4]])
@test tupapprox(dx2, [[3*4.0, 3*6.0], [3*15.8, 3*22.4]])
dy_const = [(0, 0), (0, 0)]
primal = 200.84999999999997
out_pre, dout_pre, dout2_pre = 0.0, 1.0, 3.0
@testset "tuple $label" for (label, dx_pre, dx_post, dx2_post) in [
(
"dx == 0",
[(0.0, 0.0), (0.0, 0.0)],
[(4.0, 6.0), (15.8, 22.4)],
[(3 * 4.0, 3 * 6.0), (3 * 15.8, 3 * 22.4)],
),
(
"dx != 0",
[(1.0, -2.0), (-3.0, 4.0)],
[(5.0, 4.0), (12.8, 26.4)],
[(1.0 + 3 * 4.0, -2.0 + 3 * 6.0), (-3.0 + 3 * 15.8, 4.0 + 3 * 22.4)],
),
]
dx, dx2 = deepcopy.((dx_pre, dx_pre))
out, dout, dout2 = Ref.((out_pre, dout_pre, dout2_pre))
Enzyme.autodiff(Reverse, make_byref, Const, BatchDuplicatedNoNeed(out, (dout, dout2)), Const(metasumsq), Const(metaconcat), BatchDuplicated(x, (dx, dx2)))
@test dout[] ≈ 0
@test dout2[] ≈ 0
@test tupapprox(dx, dx_post)
@test tupapprox(dx2, dx2_post)

dx, dx2 = deepcopy.((dx_pre, dx_pre))
out, dout, dout2 = Ref.((out_pre, dout_pre, dout2_pre))
Enzyme.autodiff(Reverse, make_byref, Const, BatchDuplicated(out, (dout, dout2)), Const(metasumsq), Const(metaconcat), BatchDuplicated(x, (dx, dx2)))
@test out[] ≈ primal
@test dout[] ≈ 0
@test dout2[] ≈ 0
@test tupapprox(dx, dx_post)
@test tupapprox(dx2, dx2_post)

dx, dx2 = deepcopy.((dx_pre, dx_pre))
out, dout, dout2 = Ref.((out_pre, dout_pre, dout2_pre))
Enzyme.autodiff(Reverse, make_byref, Const, BatchDuplicatedNoNeed(out, (dout, dout2)), Const(metasumsq3), Const(metaconcat2), BatchDuplicated(x, (dx, dx2)), Const(y))
@test dout[] ≈ 0
@test dout2[] ≈ 0
@test tupapprox(dx, dx_post)
@test tupapprox(dx2, dx2_post)

dx, dx2 = deepcopy.((dx_pre, dx_pre))
dy, dy2 = deepcopy.((dy_const, dy_const))
out, dout, dout2 = Ref.((out_pre, dout_pre, dout2_pre))
Enzyme.autodiff(Reverse, make_byref, Const, BatchDuplicatedNoNeed(out, (dout, dout2)), Const(metasumsq3), Const(metaconcat2), BatchDuplicated(x, (dx, dx2)), BatchDuplicated(y, (dy, dy2)))
@test dout[] ≈ 0
@test dout2[] ≈ 0
@test tupapprox(dx, dx_post)
@test tupapprox(dx2, dx2_post)
@test tupapprox(dy, dy_const)
@test tupapprox(dy2, dy_const)
end

x = [[2.0, 3.0], [7.9, 11.2]]
dx = [[0.0, 0.0], [0.0, 0.0]]
dx2 = [[0.0, 0.0], [0.0, 0.0]]
y = [[13, 17], [25, 31]]
dy = [[0, 0], [0, 0]]
dy2 = [[0, 0], [0, 0]]
out = Ref(0.0)
dout = Ref(1.0)
dout2 = Ref(3.0)
Enzyme.autodiff(Reverse, make_byref, Const, BatchDuplicated(out, (dout, dout2)), Const(metasumsq4), Const(metaconcat2), BatchDuplicated(x, (dx, dx2)), BatchDuplicated(y, (dy, dy2)))
@test tupapprox(dx, [[4.0, 6.0], [15.8, 22.4]])
@test tupapprox(dx2, [[3*4.0, 3*6.0], [3*15.8, 3*22.4]])
dy_const = [[0, 0], [0, 0]]
primal = 200.84999999999997
out_pre, dout_pre, dout2_pre = 0.0, 1.0, 3.0
@testset "tuple $label" for (label, dx_pre, dx_post, dx2_post) in [
(
"dx == 0",
[[0.0, 0.0], [0.0, 0.0]],
[[4.0, 6.0], [15.8, 22.4]],
[[3 * 4.0, 3 * 6.0], [3 * 15.8, 3 * 22.4]],
),
(
"dx != 0",
[[1.0, -2.0], [-3.0, 4.0]],
[[5.0, 4.0], [12.8, 26.4]],
[[1.0 + 3 * 4.0, -2.0 + 3 * 6.0], [-3.0 + 3 * 15.8, 4.0 + 3 * 22.4]],
),
]
dx, dx2 = deepcopy.((dx_pre, dx_pre))
out, dout, dout2 = Ref.((out_pre, dout_pre, dout2_pre))
Enzyme.autodiff(Reverse, make_byref, Const, BatchDuplicatedNoNeed(out, (dout, dout2)), Const(metasumsq2), Const(metaconcat), BatchDuplicated(x, (dx, dx2)))
@test dout[] ≈ 0
@test dout2[] ≈ 0
@test dx ≈ dx_post
@test dx2 ≈ dx2_post

dx, dx2 = deepcopy.((dx_pre, dx_pre))
out, dout, dout2 = Ref.((out_pre, dout_pre, dout2_pre))
Enzyme.autodiff(Reverse, make_byref, Const, BatchDuplicated(out, (dout, dout2)), Const(metasumsq2), Const(metaconcat), BatchDuplicated(x, (dx, dx2)))
@test out[] ≈ primal
@test dout[] ≈ 0
@test dout2[] ≈ 0
@test dx ≈ dx_post
@test dx2 ≈ dx2_post

dx, dx2 = deepcopy.((dx_pre, dx_pre))
out, dout, dout2 = Ref.((out_pre, dout_pre, dout2_pre))
Enzyme.autodiff(Reverse, make_byref, Const, BatchDuplicatedNoNeed(out, (dout, dout2)), Const(metasumsq4), Const(metaconcat2), BatchDuplicated(x, (dx, dx2)), Const(y))
@test dout[] ≈ 0
@test dout2[] ≈ 0
@test dx ≈ dx_post
@test dx2 ≈ dx2_post

dx, dx2 = deepcopy.((dx_pre, dx_pre))
dy, dy2 = deepcopy.((dy_const, dy_const))
out, dout, dout2 = Ref.((out_pre, dout_pre, dout2_pre))
Enzyme.autodiff(Reverse, make_byref, Const, BatchDuplicatedNoNeed(out, (dout, dout2)), Const(metasumsq4), Const(metaconcat2), BatchDuplicated(x, (dx, dx2)), BatchDuplicated(y, (dy, dy2)))
@test dout[] ≈ 0
@test dout2[] ≈ 0
@test dx ≈ dx_post
@test dx2 ≈ dx2_post
@test dy ≈ dy_const
@test dy2 ≈ dy_const
end
end

@testset "Forward Apply iterate" begin
Expand Down Expand Up @@ -502,4 +521,4 @@ end
@test ddata[1][1] ≈ 6.0
end

include("mixedapplyiter.jl")
include("mixedapplyiter.jl")
Loading

0 comments on commit 70a2940

Please sign in to comment.