Skip to content

Commit

Permalink
Handle type unstable getglobal (#1910)
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses authored Sep 28, 2024
1 parent a5c6fee commit 327558b
Showing 1 changed file with 57 additions and 48 deletions.
105 changes: 57 additions & 48 deletions src/rules/jitrules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -456,26 +456,32 @@ function body_runtime_generic_augfwd(N, Width, wrapped, primttypes, active_refs)
else
annotation0
end
world = codegen_world_age(FT, tt)

opt_mi = Val(world)
forward, adjoint = thunk(
opt_mi,
dupClosure0 ? $dupty : Const{FT},
annotationA,
Tuple{$(Types...)},
Val(API.DEM_ReverseModePrimal),
width,
ModifiedBetween,
Val(true),
Val(false),
FFIABI,
Val(false),
runtimeActivity,
) #=erriffuncwritten=#
internal_tape, origRet, initShadow, annotation = if f isa typeof(Core.getglobal)
gv = Core.getglobal(args[1].val, args[2].val)
@assert sizeof(gv) == 0
(nothing, f, nothing, Const)
else
world = codegen_world_age(FT, tt)

internal_tape, origRet, initShadow = forward(dupClosure0 ? $dup : Const(f), args...)
annotation = annotationA
opt_mi = Val(world)
forward, adjoint = thunk(
opt_mi,
dupClosure0 ? $dupty : Const{FT},
annotationA,
Tuple{$(Types...)},
Val(API.DEM_ReverseModePrimal),
width,
ModifiedBetween,
Val(true),
Val(false),
FFIABI,
Val(false),
runtimeActivity,
) #=erriffuncwritten=#

(forward(dupClosure0 ? $dup : Const(f), args...)..., annotationA)
end

resT = typeof(origRet)
if annotation <: Const
Expand Down Expand Up @@ -649,39 +655,42 @@ function body_runtime_generic_rev(N, Width, wrapped, primttypes, shadowargs, act
annotation0
end

world = codegen_world_age(FT, tt)
if f isa typeof(Core.getglobal)
else
world = codegen_world_age(FT, tt)

opt_mi = Val(world)
_, adjoint = thunk(
opt_mi,
dupClosure0 ? $dupty : Const{FT},
annotation,
Tuple{$(Types...)},
Val(API.DEM_ReverseModePrimal),
width,
ModifiedBetween,
Val(true),
Val(false),
FFIABI,
Val(false),
runtimeActivity,
) #=erriffuncwritten=#
opt_mi = Val(world)
_, adjoint = thunk(
opt_mi,
dupClosure0 ? $dupty : Const{FT},
annotation,
Tuple{$(Types...)},
Val(API.DEM_ReverseModePrimal),
width,
ModifiedBetween,
Val(true),
Val(false),
FFIABI,
Val(false),
runtimeActivity,
) #=erriffuncwritten=#

tup =
if annotation0 <: Active ||
annotation0 <: MixedDuplicated ||
annotation0 <: BatchMixedDuplicated
adjoint(
dupClosure0 ? $dup : Const(f),
args...,
$shadowret,
tape.internal_tape,
)[1]
else
adjoint(dupClosure0 ? $dup : Const(f), args..., tape.internal_tape)[1]
end
tup =
if annotation0 <: Active ||
annotation0 <: MixedDuplicated ||
annotation0 <: BatchMixedDuplicated
adjoint(
dupClosure0 ? $dup : Const(f),
args...,
$shadowret,
tape.internal_tape,
)[1]
else
adjoint(dupClosure0 ? $dup : Const(f), args..., tape.internal_tape)[1]
end

$(outs...)
$(outs...)
end

return nothing
end
Expand Down

0 comments on commit 327558b

Please sign in to comment.