diff --git a/src/rules/jitrules.jl b/src/rules/jitrules.jl index 0304c7297b..8f2867c21c 100644 --- a/src/rules/jitrules.jl +++ b/src/rules/jitrules.jl @@ -1488,7 +1488,7 @@ end end elseif args[i] <: MixedDuplicated :(args[$i].dval[]) - else + else # args[i] <: BatchMixedDuplicated :(args[$i].dval[$w][]) end @@ -1500,9 +1500,11 @@ end T = Core.Typeof(vecld) @assert !(vecld isa Base.RefValue) vec[] = recursive_index_add(T, vecld, Val(idx_in_vec), $expr) - else + elseif $(args[i] <: Active) val = @inbounds vec[idx_in_vec] add_into_vec!(Base.inferencebarrier(val), $expr, vec, idx_in_vec) + else # args[i] <: MixedDuplicated || args[i] <: BatchMixedDuplicated + @inbounds vec[idx_in_vec] = $expr end end else diff --git a/test/applyiter.jl b/test/applyiter.jl index 642ad62035..699a3cd69e 100644 --- a/test/applyiter.jl +++ b/test/applyiter.jl @@ -105,155 +105,174 @@ end @testset "Reverse Apply iterate" begin x = [(2.0, 3.0), (7.9, 11.2)] - dx = [(0.0, 0.0), (0.0, 0.0)] - res = Enzyme.autodiff(Reverse, metasumsq, Active, Const(metaconcat), Duplicated(x, dx)) - @test tupapprox(dx, [(4.0, 6.0), (15.8, 22.4)]) - - dx = [(0.0, 0.0), (0.0, 0.0)] - res = Enzyme.autodiff(ReverseWithPrimal, metasumsq, Active, Const(metaconcat), Duplicated(x, dx)) - @test res[2] ≈ 200.84999999999997 - @test tupapprox(dx, [(4.0, 6.0), (15.8, 22.4)]) - - x = [[2.0, 3.0], [7.9, 11.2]] - dx = [[0.0, 0.0], [0.0, 0.0]] - - res = Enzyme.autodiff(Reverse, metasumsq2, Active, Const(metaconcat), Duplicated(x, dx)) - @test dx ≈ [[4.0, 6.0], [15.8, 22.4]] - - dx = [[0.0, 0.0], [0.0, 0.0]] - - res = Enzyme.autodiff(ReverseWithPrimal, metasumsq2, Active, Const(metaconcat), Duplicated(x, dx)) - - @test res[2] ≈ 200.84999999999997 - @test tupapprox(dx, [[4.0, 6.0], [15.8, 22.4]]) - - - x = [(2.0, 3.0), (7.9, 11.2)] - dx = [(0.0, 0.0), (0.0, 0.0)] - y = [(13, 17), (25, 31)] - res = Enzyme.autodiff(Reverse, metasumsq3, Active, Const(metaconcat2), Duplicated(x, dx), Const(y)) - @test tupapprox(dx, [(4.0, 6.0), (15.8, 22.4)]) - - - x = [(2.0, 3.0), (7.9, 11.2)] - dx = [(0.0, 0.0), (0.0, 0.0)] - y = [(13, 17), (25, 31)] - dy = [(0, 0), (0, 0)] - res = Enzyme.autodiff(Reverse, metasumsq3, Active, Const(metaconcat2), Duplicated(x, dx), Duplicated(y, dy)) - @test tupapprox(dx, [(4.0, 6.0), (15.8, 22.4)]) - - - - x = [[2.0, 3.0], [7.9, 11.2]] - dx = [[0.0, 0.0], [0.0, 0.0]] - y = [[13, 17], [25, 31]] - res = Enzyme.autodiff(Reverse, metasumsq4, Active, Const(metaconcat2), Duplicated(x, dx), Const(y)) - @test tupapprox(dx, [[4.0, 6.0], [15.8, 22.4]]) - + dy_const = [(0, 0), (0, 0)] + primal = 200.84999999999997 + @testset "tuple $label" for (label, dx_pre, dx_post) in [ + ("dx == 0", [(0.0, 0.0), (0.0, 0.0)], [(4.0, 6.0), (15.8, 22.4)]), + ("dx != 0", [(1.0, -2.0), (-3.0, 4.0)], [(5.0, 4.0), (12.8, 26.4)]), + ] + dx = deepcopy(dx_pre) + Enzyme.autodiff(Reverse, metasumsq, Active, Const(metaconcat), Duplicated(x, dx)) + @test tupapprox(dx, dx_post) + + dx = deepcopy(dx_pre) + res = Enzyme.autodiff(ReverseWithPrimal, metasumsq, Active, Const(metaconcat), Duplicated(x, dx)) + @test res[2] ≈ primal + @test tupapprox(dx, dx_post) + + dx = deepcopy(dx_pre) + Enzyme.autodiff(Reverse, metasumsq3, Active, Const(metaconcat2), Duplicated(x, dx), Const(y)) + @test tupapprox(dx, dx_post) + + dx = deepcopy(dx_pre) + dy = deepcopy(dy_const) + Enzyme.autodiff(Reverse, metasumsq3, Active, Const(metaconcat2), Duplicated(x, dx), Duplicated(y, dy)) + @test tupapprox(dx, dx_post) + @test tupapprox(dy, dy_const) + end x = [[2.0, 3.0], [7.9, 11.2]] - dx = [[0.0, 0.0], [0.0, 0.0]] y = [[13, 17], [25, 31]] - dy = [[0, 0], [0, 0]] - res = Enzyme.autodiff(Reverse, metasumsq4, Active, Const(metaconcat2), Duplicated(x, dx), Duplicated(y, dy)) - @test tupapprox(dx, [[4.0, 6.0], [15.8, 22.4]]) + dy_const = [[0, 0], [0, 0]] + primal = 200.84999999999997 + @testset "list $label" for (label, dx_pre, dx_post) in [ + ("dx == 0", [[0.0, 0.0], [0.0, 0.0]], [[4.0, 6.0], [15.8, 22.4]]), + ("dx != 0", [[1.0, -2.0], [-3.0, 4.0]], [[5.0, 4.0], [12.8, 26.4]]), + ] + dx = deepcopy(dx_pre) + Enzyme.autodiff(Reverse, metasumsq2, Active, Const(metaconcat), Duplicated(x, dx)) + @test dx ≈ dx_post + + dx = deepcopy(dx_pre) + res = Enzyme.autodiff(ReverseWithPrimal, metasumsq2, Active, Const(metaconcat), Duplicated(x, dx)) + @test res[2] ≈ primal + @test dx ≈ dx_post + + dx = deepcopy(dx_pre) + Enzyme.autodiff(Reverse, metasumsq4, Active, Const(metaconcat2), Duplicated(x, dx), Const(y)) + @test dx ≈ dx_post + + dx = deepcopy(dx_pre) + dy = deepcopy(dy_const) + Enzyme.autodiff(Reverse, metasumsq4, Active, Const(metaconcat2), Duplicated(x, dx), Duplicated(y, dy)) + @test dx ≈ dx_post + @test dy ≈ dy_const + end end @testset "BatchReverse Apply iterate" begin x = [(2.0, 3.0), (7.9, 11.2)] - dx = [(0.0, 0.0), (0.0, 0.0)] - dx2 = [(0.0, 0.0), (0.0, 0.0)] - out = Ref(0.0) - dout = Ref(1.0) - dout2 = Ref(3.0) - Enzyme.autodiff(Reverse, make_byref, Const, BatchDuplicatedNoNeed(out, (dout, dout2)), Const(metasumsq), Const(metaconcat), BatchDuplicated(x, (dx, dx2))) - @test tupapprox(dx, [(4.0, 6.0), (15.8, 22.4)]) - @test tupapprox(dx2, [(3*4.0, 3*6.0), (3*15.8, 3*22.4)]) - - dx = [(0.0, 0.0), (0.0, 0.0)] - dx2 = [(0.0, 0.0), (0.0, 0.0)] - out = Ref(0.0) - dout = Ref(1.0) - dout2 = Ref(3.0) - Enzyme.autodiff(Reverse, make_byref, Const, BatchDuplicated(out, (dout, dout2)), Const(metasumsq), Const(metaconcat), BatchDuplicated(x, (dx, dx2))) - @test out[] ≈ 200.84999999999997 - @test tupapprox(dx, [(4.0, 6.0), (15.8, 22.4)]) - @test tupapprox(dx2, [(3*4.0, 3*6.0), (3*15.8, 3*22.4)]) - - x = [[2.0, 3.0], [7.9, 11.2]] - dx = [[0.0, 0.0], [0.0, 0.0]] - dx2 = [[0.0, 0.0], [0.0, 0.0]] - out = Ref(0.0) - dout = Ref(1.0) - dout2 = Ref(3.0) - - Enzyme.autodiff(Reverse, make_byref, Const, BatchDuplicatedNoNeed(out, (dout, dout2)), Const(metasumsq2), Const(metaconcat), BatchDuplicated(x, (dx, dx2))) - @test dx ≈ [[4.0, 6.0], [15.8, 22.4]] - @test dx2 ≈ [[3*4.0, 3*6.0], [3*15.8, 3*22.4]] - - dx = [[0.0, 0.0], [0.0, 0.0]] - dx2 = [[0.0, 0.0], [0.0, 0.0]] - out = Ref(0.0) - dout = Ref(1.0) - dout2 = Ref(3.0) - Enzyme.autodiff(Reverse, make_byref, Const, BatchDuplicated(out, (dout, dout2)), Const(metasumsq2), Const(metaconcat), BatchDuplicated(x, (dx, dx2))) - - @test out[] ≈ 200.84999999999997 - @test tupapprox(dx, [[4.0, 6.0], [15.8, 22.4]]) - @test tupapprox(dx2, [[3*4.0, 3*6.0], [3*15.8, 3*22.4]]) - - - x = [(2.0, 3.0), (7.9, 11.2)] - dx = [(0.0, 0.0), (0.0, 0.0)] - dx2 = [(0.0, 0.0), (0.0, 0.0)] - y = [(13, 17), (25, 31)] - out = Ref(0.0) - dout = Ref(1.0) - dout2 = Ref(3.0) - Enzyme.autodiff(Reverse, make_byref, Const, BatchDuplicatedNoNeed(out, (dout, dout2)), Const(metasumsq3), Const(metaconcat2), BatchDuplicated(x, (dx, dx2)), Const(y)) - @test tupapprox(dx, [(4.0, 6.0), (15.8, 22.4)]) - @test tupapprox(dx2, [(3*4.0, 3*6.0), (3*15.8, 3*22.4)]) - - - x = [(2.0, 3.0), (7.9, 11.2)] - dx = [(0.0, 0.0), (0.0, 0.0)] - dx2 = [(0.0, 0.0), (0.0, 0.0)] - y = [(13, 17), (25, 31)] - dy = [(0, 0), (0, 0)] - dy2 = [(0, 0), (0, 0)] - out = Ref(0.0) - dout = Ref(1.0) - dout2 = Ref(3.0) - Enzyme.autodiff(Reverse, make_byref, Const, BatchDuplicatedNoNeed(out, (dout, dout2)), Const(metasumsq3),Const(metaconcat2), BatchDuplicated(x, (dx, dx2)), BatchDuplicated(y, (dy, dy2))) - @test tupapprox(dx, [(4.0, 6.0), (15.8, 22.4)]) - @test tupapprox(dx2, [(3*4.0, 3*6.0), (3*15.8, 3*22.4)]) - - - x = [[2.0, 3.0], [7.9, 11.2]] - dx = [[0.0, 0.0], [0.0, 0.0]] - dx2 = [[0.0, 0.0], [0.0, 0.0]] - y = [[13, 17], [25, 31]] - out = Ref(0.0) - dout = Ref(1.0) - dout2 = Ref(3.0) - Enzyme.autodiff(Reverse, make_byref, Const, BatchDuplicated(out, (dout, dout2)), Const(metasumsq4), Const(metaconcat2), BatchDuplicated(x, (dx, dx2)), Const(y)) - @test tupapprox(dx, [[4.0, 6.0], [15.8, 22.4]]) - @test tupapprox(dx2, [[3*4.0, 3*6.0], [3*15.8, 3*22.4]]) + dy_const = [(0, 0), (0, 0)] + primal = 200.84999999999997 + out_pre, dout_pre, dout2_pre = 0.0, 1.0, 3.0 + @testset "tuple $label" for (label, dx_pre, dx_post, dx2_post) in [ + ( + "dx == 0", + [(0.0, 0.0), (0.0, 0.0)], + [(4.0, 6.0), (15.8, 22.4)], + [(3 * 4.0, 3 * 6.0), (3 * 15.8, 3 * 22.4)], + ), + ( + "dx != 0", + [(1.0, -2.0), (-3.0, 4.0)], + [(5.0, 4.0), (12.8, 26.4)], + [(1.0 + 3 * 4.0, -2.0 + 3 * 6.0), (-3.0 + 3 * 15.8, 4.0 + 3 * 22.4)], + ), + ] + dx, dx2 = deepcopy.((dx_pre, dx_pre)) + out, dout, dout2 = Ref.((out_pre, dout_pre, dout2_pre)) + Enzyme.autodiff(Reverse, make_byref, Const, BatchDuplicatedNoNeed(out, (dout, dout2)), Const(metasumsq), Const(metaconcat), BatchDuplicated(x, (dx, dx2))) + @test dout[] ≈ 0 + @test dout2[] ≈ 0 + @test tupapprox(dx, dx_post) + @test tupapprox(dx2, dx2_post) + + dx, dx2 = deepcopy.((dx_pre, dx_pre)) + out, dout, dout2 = Ref.((out_pre, dout_pre, dout2_pre)) + Enzyme.autodiff(Reverse, make_byref, Const, BatchDuplicated(out, (dout, dout2)), Const(metasumsq), Const(metaconcat), BatchDuplicated(x, (dx, dx2))) + @test out[] ≈ primal + @test dout[] ≈ 0 + @test dout2[] ≈ 0 + @test tupapprox(dx, dx_post) + @test tupapprox(dx2, dx2_post) + + dx, dx2 = deepcopy.((dx_pre, dx_pre)) + out, dout, dout2 = Ref.((out_pre, dout_pre, dout2_pre)) + Enzyme.autodiff(Reverse, make_byref, Const, BatchDuplicatedNoNeed(out, (dout, dout2)), Const(metasumsq3), Const(metaconcat2), BatchDuplicated(x, (dx, dx2)), Const(y)) + @test dout[] ≈ 0 + @test dout2[] ≈ 0 + @test tupapprox(dx, dx_post) + @test tupapprox(dx2, dx2_post) + + dx, dx2 = deepcopy.((dx_pre, dx_pre)) + dy, dy2 = deepcopy.((dy_const, dy_const)) + out, dout, dout2 = Ref.((out_pre, dout_pre, dout2_pre)) + Enzyme.autodiff(Reverse, make_byref, Const, BatchDuplicatedNoNeed(out, (dout, dout2)), Const(metasumsq3), Const(metaconcat2), BatchDuplicated(x, (dx, dx2)), BatchDuplicated(y, (dy, dy2))) + @test dout[] ≈ 0 + @test dout2[] ≈ 0 + @test tupapprox(dx, dx_post) + @test tupapprox(dx2, dx2_post) + @test tupapprox(dy, dy_const) + @test tupapprox(dy2, dy_const) + end x = [[2.0, 3.0], [7.9, 11.2]] - dx = [[0.0, 0.0], [0.0, 0.0]] - dx2 = [[0.0, 0.0], [0.0, 0.0]] y = [[13, 17], [25, 31]] - dy = [[0, 0], [0, 0]] - dy2 = [[0, 0], [0, 0]] - out = Ref(0.0) - dout = Ref(1.0) - dout2 = Ref(3.0) - Enzyme.autodiff(Reverse, make_byref, Const, BatchDuplicated(out, (dout, dout2)), Const(metasumsq4), Const(metaconcat2), BatchDuplicated(x, (dx, dx2)), BatchDuplicated(y, (dy, dy2))) - @test tupapprox(dx, [[4.0, 6.0], [15.8, 22.4]]) - @test tupapprox(dx2, [[3*4.0, 3*6.0], [3*15.8, 3*22.4]]) + dy_const = [[0, 0], [0, 0]] + primal = 200.84999999999997 + out_pre, dout_pre, dout2_pre = 0.0, 1.0, 3.0 + @testset "tuple $label" for (label, dx_pre, dx_post, dx2_post) in [ + ( + "dx == 0", + [[0.0, 0.0], [0.0, 0.0]], + [[4.0, 6.0], [15.8, 22.4]], + [[3 * 4.0, 3 * 6.0], [3 * 15.8, 3 * 22.4]], + ), + ( + "dx != 0", + [[1.0, -2.0], [-3.0, 4.0]], + [[5.0, 4.0], [12.8, 26.4]], + [[1.0 + 3 * 4.0, -2.0 + 3 * 6.0], [-3.0 + 3 * 15.8, 4.0 + 3 * 22.4]], + ), + ] + dx, dx2 = deepcopy.((dx_pre, dx_pre)) + out, dout, dout2 = Ref.((out_pre, dout_pre, dout2_pre)) + Enzyme.autodiff(Reverse, make_byref, Const, BatchDuplicatedNoNeed(out, (dout, dout2)), Const(metasumsq2), Const(metaconcat), BatchDuplicated(x, (dx, dx2))) + @test dout[] ≈ 0 + @test dout2[] ≈ 0 + @test dx ≈ dx_post + @test dx2 ≈ dx2_post + + dx, dx2 = deepcopy.((dx_pre, dx_pre)) + out, dout, dout2 = Ref.((out_pre, dout_pre, dout2_pre)) + Enzyme.autodiff(Reverse, make_byref, Const, BatchDuplicated(out, (dout, dout2)), Const(metasumsq2), Const(metaconcat), BatchDuplicated(x, (dx, dx2))) + @test out[] ≈ primal + @test dout[] ≈ 0 + @test dout2[] ≈ 0 + @test dx ≈ dx_post + @test dx2 ≈ dx2_post + + dx, dx2 = deepcopy.((dx_pre, dx_pre)) + out, dout, dout2 = Ref.((out_pre, dout_pre, dout2_pre)) + Enzyme.autodiff(Reverse, make_byref, Const, BatchDuplicatedNoNeed(out, (dout, dout2)), Const(metasumsq4), Const(metaconcat2), BatchDuplicated(x, (dx, dx2)), Const(y)) + @test dout[] ≈ 0 + @test dout2[] ≈ 0 + @test dx ≈ dx_post + @test dx2 ≈ dx2_post + + dx, dx2 = deepcopy.((dx_pre, dx_pre)) + dy, dy2 = deepcopy.((dy_const, dy_const)) + out, dout, dout2 = Ref.((out_pre, dout_pre, dout2_pre)) + Enzyme.autodiff(Reverse, make_byref, Const, BatchDuplicatedNoNeed(out, (dout, dout2)), Const(metasumsq4), Const(metaconcat2), BatchDuplicated(x, (dx, dx2)), BatchDuplicated(y, (dy, dy2))) + @test dout[] ≈ 0 + @test dout2[] ≈ 0 + @test dx ≈ dx_post + @test dx2 ≈ dx2_post + @test dy ≈ dy_const + @test dy2 ≈ dy_const + end end @testset "Forward Apply iterate" begin @@ -502,4 +521,4 @@ end @test ddata[1][1] ≈ 6.0 end -include("mixedapplyiter.jl") \ No newline at end of file +include("mixedapplyiter.jl") diff --git a/test/mixedapplyiter.jl b/test/mixedapplyiter.jl index 0a4f06cbb9..73fe35b837 100644 --- a/test/mixedapplyiter.jl +++ b/test/mixedapplyiter.jl @@ -66,81 +66,128 @@ end @testset "Mixed Reverse Apply iterate (tuple)" begin x = [((2.0, [2.7]), (3.0, [3.14])), ((7.9, [47.0]), (11.2, [56.0]))] - dx = [((0.0, [0.0]), (0.0, [0.0])), ((0.0, [0.0]), (0.0, [0.0]))] - res = Enzyme.autodiff(Reverse, mixed_metasumsq, Active, Const(metaconcat), Duplicated(x, dx)) - @test tupapprox(dx, [((4.0, [5.4]), (6.0, [6.28])), ((15.8, [94.0]), (22.4, [112.0]))]) - - x = [((2.0, [2.7]), (3.0, [3.14])), ((7.9, [47.0]), (11.2, [56.0]))] - - dx = [((0.0, [0.0]), (0.0, [0.0])), ((0.0, [0.0]), (0.0, [0.0]))] - res = Enzyme.autodiff(ReverseWithPrimal, mixed_metasumsq, Active, Const(metaconcat), Duplicated(x, dx)) - @test res[2] ≈ 5562.9996 - @test tupapprox(dx, [((4.0, [5.4]), (6.0, [6.28])), ((15.8, [94.0]), (22.4, [112.0]))]) + primal = 5562.9996 + @testset "$label" for (label, dx_pre, dx_post) in [ + ( + "dx == 0", + [((0.0, [0.0]), (0.0, [0.0])), ((0.0, [0.0]), (0.0, [0.0]))], + [((4.0, [5.4]), (6.0, [6.28])), ((15.8, [94.0]), (22.4, [112.0]))], + ), + ( + "dx != 0", + [((1.0, [-2.0]), (-3.0, [4.0])), ((5.0, [-6.0]), (-7.0, [8.0]))], + [((5.0, [3.4]), (3.0, [10.28])), ((20.8, [88.0]), (15.4, [120.0]))], + ), + ] + dx = deepcopy(dx_pre) + Enzyme.autodiff(Reverse, mixed_metasumsq, Active, Const(metaconcat), Duplicated(x, dx)) + @test tupapprox(dx, dx_post) + + dx = deepcopy(dx_pre) + res = Enzyme.autodiff(ReverseWithPrimal, mixed_metasumsq, Active, Const(metaconcat), Duplicated(x, dx)) + @test res[2] ≈ primal + @test tupapprox(dx, dx_post) + end end @testset "BatchMixed Reverse Apply iterate (tuple)" begin x = [((2.0, [2.7]), (3.0, [3.14])), ((7.9, [47.0]), (11.2, [56.0]))] - dx = [((0.0, [0.0]), (0.0, [0.0])), ((0.0, [0.0]), (0.0, [0.0]))] - dx2 = [((0.0, [0.0]), (0.0, [0.0])), ((0.0, [0.0]), (0.0, [0.0]))] - - out = Ref(0.0) - dout = Ref(1.0) - dout2 = Ref(3.0) - Enzyme.autodiff(Reverse, make_byref, Const, BatchDuplicatedNoNeed(out, (dout, dout2)), Const(mixed_metasumsq), Const(metaconcat), BatchDuplicated(x, (dx, dx2))) - @test tupapprox(dx, [((4.0, [5.4]), (6.0, [6.28])), ((15.8, [94.0]), (22.4, [112.0]))]) - @test tupapprox(dx2, [((3*4.0, [3*5.4]), (3*6.0, [3*6.28])), ((3*15.8, [3*94.0]), (3*22.4, [3*112.0]))]) - - x = [((2.0, [2.7]), (3.0, [3.14])), ((7.9, [47.0]), (11.2, [56.0]))] - dx = [((0.0, [0.0]), (0.0, [0.0])), ((0.0, [0.0]), (0.0, [0.0]))] - dx2 = [((0.0, [0.0]), (0.0, [0.0])), ((0.0, [0.0]), (0.0, [0.0]))] - - out = Ref(0.0) - dout = Ref(1.0) - dout2 = Ref(3.0) - res = Enzyme.autodiff(Reverse, make_byref, Const, BatchDuplicated(out, (dout, dout2)), Const(mixed_metasumsq), Const(metaconcat), BatchDuplicated(x, (dx, dx2))) - @test out[] ≈ 5562.9996 - @test tupapprox(dx, [((4.0, [5.4]), (6.0, [6.28])), ((15.8, [94.0]), (22.4, [112.0]))]) - @test tupapprox(dx2, [((3*4.0, [3*5.4]), (3*6.0, [3*6.28])), ((3*15.8, [3*94.0]), (3*22.4, [3*112.0]))]) + primal = 5562.9996 + out_pre, dout_pre, dout2_pre = 0.0, 1.0, 3.0 + @testset "$label" for (label, dx_pre, dx_post, dx2_post) in [ + ( + "dx == 0", + [((0.0, [0.0]), (0.0, [0.0])), ((0.0, [0.0]), (0.0, [0.0]))], + [((4.0, [5.4]), (6.0, [6.28])), ((15.8, [94.0]), (22.4, [112.0]))], + [((3 * 4.0, [3 * 5.4]), (3 * 6.0, [3 * 6.28])), ((3 * 15.8, [3 * 94.0]), (3 * 22.4, [3 * 112.0]))], + ), + ( + "dx != 0", + [((1.0, [-2.0]), (-3.0, [4.0])), ((5.0, [-6.0]), (-7.0, [8.0]))], + [((5.0, [3.4]), (3.0, [10.28])), ((20.8, [88.0]), (15.4, [120.0]))], + [((1.0 + 3 * 4.0, [-2.0 + 3 * 5.4]), (-3.0 + 3 * 6.0, [4.0 + 3 * 6.28])), ((5.0 + 3 * 15.8, [-6.0 + 3 * 94.0]), (-7.0 + 3 * 22.4, [8.0 + 3 * 112.0]))], + ), + ] + out, dout, dout2 = Ref.((out_pre, dout_pre, dout2_pre)) + dx, dx2 = deepcopy.((dx_pre, dx_pre)) + Enzyme.autodiff(Reverse, make_byref, Const, BatchDuplicatedNoNeed(out, (dout, dout2)), Const(mixed_metasumsq), Const(metaconcat), BatchDuplicated(x, (dx, dx2))) + @test dout[] ≈ 0 + @test dout2[] ≈ 0 + @test tupapprox(dx, dx_post) + @test tupapprox(dx2, dx2_post) + + out, dout, dout2 = Ref.((out_pre, dout_pre, dout2_pre)) + dx, dx2 = deepcopy.((dx_pre, dx_pre)) + Enzyme.autodiff(Reverse, make_byref, Const, BatchDuplicated(out, (dout, dout2)), Const(mixed_metasumsq), Const(metaconcat), BatchDuplicated(x, (dx, dx2))) + @test out[] ≈ primal + @test dout[] ≈ 0 + @test dout2[] ≈ 0 + @test tupapprox(dx, dx_post) + @test tupapprox(dx2, dx2_post) + end end - @testset "Mixed Reverse Apply iterate (list)" begin x = [[(2.0, [2.7]), (3.0, [3.14])], [(7.9, [47.0]), (11.2, [56.0])]] - dx = [[(0.0, [0.0]), (0.0, [0.0])], [(0.0, [0.0]), (0.0, [0.0])]] - - res = Enzyme.autodiff(Reverse, mixed_metasumsq, Active, Const(metaconcat), Duplicated(x, dx)) - @test tupapprox(dx, [[(4.0, [5.4]), (6.0, [6.28])], [(15.8, [94.0]), (22.4, [112.0])]]) - - dx = [[(0.0, [0.0]), (0.0, [0.0])], [(0.0, [0.0]), (0.0, [0.0])]] - - res = Enzyme.autodiff(ReverseWithPrimal, mixed_metasumsq, Active, Const(metaconcat), Duplicated(x, dx)) - @test res[2] ≈ 5562.9996 - @test tupapprox(dx, [[(4.0, [5.4]), (6.0, [6.28])], [(15.8, [94.0]), (22.4, [112.0])]]) + primal = 5562.9996 + @testset "$label" for (label, dx_pre, dx_post) in [ + ( + "dx == 0", + [[(0.0, [0.0]), (0.0, [0.0])], [(0.0, [0.0]), (0.0, [0.0])]], + [[(4.0, [5.4]), (6.0, [6.28])], [(15.8, [94.0]), (22.4, [112.0])]], + ), + ( + "dx != 0", + [[(1.0, [-2.0]), (-3.0, [4.0])], [(5.0, [-6.0]), (-7.0, [8.0])]], + [[(5.0, [3.4]), (3.0, [10.28])], [(20.8, [88.0]), (15.4, [120.0])]], + ), + ] + dx = deepcopy(dx_pre) + Enzyme.autodiff(Reverse, mixed_metasumsq, Active, Const(metaconcat), Duplicated(x, dx)) + @test tupapprox(dx, dx_post) + + dx = deepcopy(dx_pre) + res = Enzyme.autodiff(ReverseWithPrimal, mixed_metasumsq, Active, Const(metaconcat), Duplicated(x, dx)) + @test res[2] ≈ primal + @test tupapprox(dx, dx_post) + end end @testset "BatchMixed Reverse Apply iterate (list)" begin x = [[(2.0, [2.7]), (3.0, [3.14])], [(7.9, [47.0]), (11.2, [56.0])]] - dx = [[(0.0, [0.0]), (0.0, [0.0])], [(0.0, [0.0]), (0.0, [0.0])]] - dx2 = [[(0.0, [0.0]), (0.0, [0.0])], [(0.0, [0.0]), (0.0, [0.0])]] - - out = Ref(0.0) - dout = Ref(1.0) - dout2 = Ref(3.0) - Enzyme.autodiff(Reverse, make_byref, Const, BatchDuplicatedNoNeed(out, (dout, dout2)), Const(mixed_metasumsq), Const(metaconcat), BatchDuplicated(x, (dx, dx2))) - @test tupapprox(dx, [[(4.0, [5.4]), (6.0, [6.28])], [(15.8, [94.0]), (22.4, [112.0])]]) - @test tupapprox(dx2, [[(3*4.0, [3*5.4]), (3*6.0, [3*6.28])], [(3*15.8, [3*94.0]), (3*22.4, [3*112.0])]]) - - x = [[(2.0, [2.7]), (3.0, [3.14])], [(7.9, [47.0]), (11.2, [56.0])]] - dx = [[(0.0, [0.0]), (0.0, [0.0])], [(0.0, [0.0]), (0.0, [0.0])]] - dx2 = [[(0.0, [0.0]), (0.0, [0.0])], [(0.0, [0.0]), (0.0, [0.0])]] - - out = Ref(0.0) - dout = Ref(1.0) - dout2 = Ref(3.0) - res = Enzyme.autodiff(Reverse, make_byref, Const, BatchDuplicated(out, (dout, dout2)), Const(mixed_metasumsq), Const(metaconcat), BatchDuplicated(x, (dx, dx2))) - @test out[] ≈ 5562.9996 - @test tupapprox(dx, [[(4.0, [5.4]), (6.0, [6.28])], [(15.8, [94.0]), (22.4, [112.0])]]) - @test tupapprox(dx2, [[(3*4.0, [3*5.4]), (3*6.0, [3*6.28])], [(3*15.8, [3*94.0]), (3*22.4, [3*112.0])]]) + primal = 5562.9996 + out_pre, dout_pre, dout2_pre = 0.0, 1.0, 3.0 + @testset "$label" for (label, dx_pre, dx_post, dx2_post) in [ + ( + "dx == 0", + [[(0.0, [0.0]), (0.0, [0.0])], [(0.0, [0.0]), (0.0, [0.0])]], + [[(4.0, [5.4]), (6.0, [6.28])], [(15.8, [94.0]), (22.4, [112.0])]], + [[(3 * 4.0, [3 * 5.4]), (3 * 6.0, [3 * 6.28])], [(3 * 15.8, [3 * 94.0]), (3 * 22.4, [3 * 112.0])]], + ), + ( + "dx != 0", + [[(1.0, [-2.0]), (-3.0, [4.0])], [(5.0, [-6.0]), (-7.0, [8.0])]], + [[(5.0, [3.4]), (3.0, [10.28])], [(20.8, [88.0]), (15.4, [120.0])]], + [[(1.0 + 3 * 4.0, [-2.0 + 3 * 5.4]), (-3.0 + 3 * 6.0, [4.0 + 3 * 6.28])], [(5.0 + 3 * 15.8, [-6.0 + 3 * 94.0]), (-7.0 + 3 * 22.4, [8.0 + 3 * 112.0])]], + ), + ] + out, dout, dout2 = Ref.((out_pre, dout_pre, dout2_pre)) + dx, dx2 = deepcopy.((dx_pre, dx_pre)) + Enzyme.autodiff(Reverse, make_byref, Const, BatchDuplicatedNoNeed(out, (dout, dout2)), Const(mixed_metasumsq), Const(metaconcat), BatchDuplicated(x, (dx, dx2))) + @test dout[] ≈ 0 + @test dout2[] ≈ 0 + @test tupapprox(dx, dx_post) + @test tupapprox(dx2, dx2_post) + + out, dout, dout2 = Ref.((out_pre, dout_pre, dout2_pre)) + dx, dx2 = deepcopy.((dx_pre, dx_pre)) + Enzyme.autodiff(Reverse, make_byref, Const, BatchDuplicated(out, (dout, dout2)), Const(mixed_metasumsq), Const(metaconcat), BatchDuplicated(x, (dx, dx2))) + @test out[] ≈ primal + @test dout[] ≈ 0 + @test dout2[] ≈ 0 + @test tupapprox(dx, dx_post) + @test tupapprox(dx2, dx2_post) + end end struct MyRectilinearGrid5{FT,FZ}