Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fewer anonymous funcs and despecialize inference #2131

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 7 additions & 7 deletions src/absint.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Abstractly interpret julia from LLVM

# Returns a tuple: (whether `arg` could be interpreted, the Julia object it was interpreted to)
function absint(@nospecialize(arg::LLVM.Value), partial::Bool = false)::Tuple{Bool,Any}
Base.@nospecializeinfer function absint(@nospecialize(arg::LLVM.Value), partial::Bool = false)::Tuple{Bool,Any}
if isa(arg, LLVM.BitCastInst) || isa(arg, LLVM.AddrSpaceCastInst)
return absint(operands(arg)[1], partial)
end
Expand Down Expand Up @@ -165,7 +165,7 @@ function absint(@nospecialize(arg::LLVM.Value), partial::Bool = false)::Tuple{Bo
return (false, nothing)
end

function actual_size(@nospecialize(typ2))::Int
Base.@nospecializeinfer function actual_size(@nospecialize(typ2))::Int
@static if VERSION < v"1.11-"
if typ2 <: Array
return sizeof(Ptr{Cvoid}) + 2 + 2 + 4 + 2 * sizeof(Csize_t) + sizeof(Csize_t)
Expand All @@ -184,7 +184,7 @@ function actual_size(@nospecialize(typ2))::Int
end
end

@inline function first_non_ghost(@nospecialize(typ2))::Tuple{Int, Int}
Base.@nospecializeinfer @inline function first_non_ghost(@nospecialize(typ2))::Tuple{Int, Int}
@static if VERSION < v"1.11-"
if typ2 <: Array
return (1, 0)
Expand All @@ -204,7 +204,7 @@ end
return (-1, 0)
end

function should_recurse(@nospecialize(typ2), @nospecialize(arg_t::LLVM.LLVMType), byref::GPUCompiler.ArgumentCC, dl::LLVM.DataLayout)::Bool
Base.@nospecializeinfer function should_recurse(@nospecialize(typ2), @nospecialize(arg_t::LLVM.LLVMType), byref::GPUCompiler.ArgumentCC, dl::LLVM.DataLayout)::Bool
sz = if arg_t == LLVM.IntType(1)
1
else
Expand Down Expand Up @@ -232,7 +232,7 @@ function should_recurse(@nospecialize(typ2), @nospecialize(arg_t::LLVM.LLVMType)
end
end

function get_base_and_offset(@nospecialize(larg::LLVM.Value); offsetAllowed::Bool=true, inttoptr::Bool=false)::Tuple{LLVM.Value, Int}
Base.@nospecializeinfer function get_base_and_offset(@nospecialize(larg::LLVM.Value); offsetAllowed::Bool=true, inttoptr::Bool=false)::Tuple{LLVM.Value, Int}
offset = 0
while true
if isa(larg, LLVM.ConstantExpr)
Expand Down Expand Up @@ -280,7 +280,7 @@ function get_base_and_offset(@nospecialize(larg::LLVM.Value); offsetAllowed::Boo
return larg, offset
end

function abs_typeof(
Base.@nospecializeinfer function abs_typeof(
@nospecialize(arg::LLVM.Value),
partial::Bool = false, seenphis=Set{LLVM.PHIInst}()
)::Union{Tuple{Bool,Type,GPUCompiler.ArgumentCC},Tuple{Bool,Nothing,Nothing}}
Expand Down Expand Up @@ -758,7 +758,7 @@ end
return false
end

function abs_cstring(@nospecialize(arg::LLVM.Value))::Tuple{Bool,String}
Base.@nospecializeinfer function abs_cstring(@nospecialize(arg::LLVM.Value))::Tuple{Bool,String}
if isa(arg, ConstantExpr)
ce = arg
while isa(ce, ConstantExpr)
Expand Down
19 changes: 10 additions & 9 deletions src/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ const known_ops = Dict{DataType,Tuple{Symbol,Int,Union{Nothing,Tuple{Symbol,Data
typeof(Base.FastMath.tanh_fast) => (:tanh, 1, nothing),
typeof(Base.fma_emulated) => (:fma, 3, nothing),
)
@inline function find_math_method(@nospecialize(func::Type), sparam_vals::Core.SimpleVector)
@inline Base.@nospecializeinfer function find_math_method(@nospecialize(func::Type), sparam_vals::Core.SimpleVector)
if func ∈ keys(known_ops)
name, arity, toinject = known_ops[func]
Tys = (Float32, Float64)
Expand Down Expand Up @@ -317,7 +317,8 @@ include("llvm/transforms.jl")
include("llvm/passes.jl")
include("typeutils/make_zero.jl")

function nested_codegen!(mode::API.CDerivativeMode, mod::LLVM.Module, @nospecialize(f), @nospecialize(tt::Type), world::UInt)

Base.@nospecializeinfer function nested_codegen!(mode::API.CDerivativeMode, mod::LLVM.Module, @nospecialize(f), @nospecialize(tt::Type), world::UInt)
funcspec = my_methodinstance(typeof(f), tt, world)
nested_codegen!(mode, mod, funcspec, world)
end
Expand Down Expand Up @@ -1345,7 +1346,7 @@ include("rules/activityrules.jl")
const DumpPreEnzyme = Ref(false)
const DumpPostWrap = Ref(false)

function enzyme!(
Base.@nospecializeinfer function enzyme!(
job::CompilerJob,
mod::LLVM.Module,
primalf::LLVM.Function,
Expand Down Expand Up @@ -1685,7 +1686,7 @@ function set_subprogram!(f::LLVM.Function, sp)
end
end

function create_abi_wrapper(
Base.@nospecializeinfer function create_abi_wrapper(
enzymefn::LLVM.Function,
@nospecialize(TT::Type),
@nospecialize(rettype::Type),
Expand Down Expand Up @@ -2167,7 +2168,7 @@ function create_abi_wrapper(
metadata(val)[LLVM.MD_dbg] = DILocation(0, 0, get_subprogram(llvm_f))
end

@inline function fixup_abi(index::Int, @nospecialize(value::LLVM.Value))
@inline Base.@nospecializeinfer function fixup_abi(index::Int, @nospecialize(value::LLVM.Value))
valty = sret_types[index]
# Union becoming part of a tuple needs to be adjusted
# See https://github.com/JuliaLang/julia/blob/81afdbc36b365fcbf3ae25b7451c6cb5798c0c3d/src/cgutils.cpp#L3795C1-L3801C121
Expand Down Expand Up @@ -2505,7 +2506,7 @@ function fixup_metadata!(f::LLVM.Function)
end

# Modified from GPUCompiler/src/irgen.jl:365 lower_byval
function lower_convention(
Base.@nospecializeinfer function lower_convention(
@nospecialize(functy::Type),
mod::LLVM.Module,
entry_f::LLVM.Function,
Expand Down Expand Up @@ -3206,7 +3207,7 @@ end

using Random
# returns arg, return
function no_type_setting(@nospecialize(specTypes::Type{<:Tuple}); world = nothing)
Base.@nospecializeinfer function no_type_setting(@nospecialize(specTypes::Type{<:Tuple}); world = nothing)
# Even though the julia type here is ptr{int8}, the actual data can be something else
if specTypes.parameters[1] == typeof(Random.XoshiroSimd.xoshiro_bulk_simd)
return (true, false)
Expand Down Expand Up @@ -5226,7 +5227,7 @@ end
# JIT
##

function _link(@nospecialize(job::CompilerJob{<:EnzymeTarget}), mod::LLVM.Module, adjoint_name::String, @nospecialize(primal_name::Union{String, Nothing}), @nospecialize(TapeType), prepost::String)
Base.@nospecializeinfer function _link(@nospecialize(job::CompilerJob{<:EnzymeTarget}), mod::LLVM.Module, adjoint_name::String, @nospecialize(primal_name::Union{String, Nothing}), @nospecialize(TapeType), prepost::String)
if job.config.params.ABI <: InlineABI
return CompileResult(
Val((Symbol(mod), Symbol(adjoint_name))),
Expand Down Expand Up @@ -5337,7 +5338,7 @@ const cache_lock = ReentrantLock()
end
end

@inline function thunkbase(
Base.@nospecializeinfer @inline function thunkbase(
mi::Core.MethodInstance,
World::Union{UInt, Nothing},
@nospecialize(FA::Type{<:Annotation}),
Expand Down
41 changes: 21 additions & 20 deletions src/compiler/interpreter.jl
Original file line number Diff line number Diff line change
Expand Up @@ -228,35 +228,36 @@ EnzymeInterpreter(
handler = nothing
) = EnzymeInterpreter(cache_or_token, mt, world, mode == API.DEM_ForwardMode, mode == API.DEM_ReverseModeCombined || mode == API.DEM_ReverseModePrimal || mode == API.DEM_ReverseModeGradient, broadcast_rewrite, handler)

Core.Compiler.InferenceParams(@nospecialize(interp::EnzymeInterpreter)) = interp.inf_params
Core.Compiler.OptimizationParams(@nospecialize(interp::EnzymeInterpreter)) = interp.opt_params
get_inference_world(@nospecialize(interp::EnzymeInterpreter)) = interp.world
Core.Compiler.get_inference_cache(@nospecialize(interp::EnzymeInterpreter)) = interp.local_cache

Base.@nospecializeinfer Core.Compiler.InferenceParams(@nospecialize(interp::EnzymeInterpreter)) = interp.inf_params
Base.@nospecializeinfer Core.Compiler.OptimizationParams(@nospecialize(interp::EnzymeInterpreter)) = interp.opt_params
Base.@nospecializeinfer get_inference_world(@nospecialize(interp::EnzymeInterpreter)) = interp.world
Base.@nospecializeinfer Core.Compiler.get_inference_cache(@nospecialize(interp::EnzymeInterpreter)) = interp.local_cache

@static if HAS_INTEGRATED_CACHE
Core.Compiler.cache_owner(@nospecialize(interp::EnzymeInterpreter)) = interp.token
Base.@nospecializeinfer Core.Compiler.cache_owner(@nospecialize(interp::EnzymeInterpreter)) = interp.token
else
Core.Compiler.code_cache(@nospecialize(interp::EnzymeInterpreter)) =
Base.@nospecializeinfer Core.Compiler.code_cache(@nospecialize(interp::EnzymeInterpreter)) =
WorldView(interp.code_cache, interp.world)
end

# No need to do any locking since we're not putting our results into the runtime cache
Core.Compiler.lock_mi_inference(@nospecialize(::EnzymeInterpreter), ::MethodInstance) = nothing
Core.Compiler.unlock_mi_inference(@nospecialize(::EnzymeInterpreter), ::MethodInstance) = nothing
Base.@nospecializeinfer Core.Compiler.lock_mi_inference(@nospecialize(::EnzymeInterpreter), ::MethodInstance) = nothing
Base.@nospecializeinfer Core.Compiler.unlock_mi_inference(@nospecialize(::EnzymeInterpreter), ::MethodInstance) = nothing

Core.Compiler.may_optimize(@nospecialize(::EnzymeInterpreter)) = true
Core.Compiler.may_compress(@nospecialize(::EnzymeInterpreter)) = true
Base.@nospecializeinfer Core.Compiler.may_optimize(@nospecialize(::EnzymeInterpreter)) = true
Base.@nospecializeinfer Core.Compiler.may_compress(@nospecialize(::EnzymeInterpreter)) = true
# From @aviatesk:
# `may_discard_trees = true` means a complicated (in terms of inlineability) source will be discarded,
# but as far as I understand Enzyme wants "always inlining, except special cased functions",
# so I guess we really don't want to discard sources?
Core.Compiler.may_discard_trees(@nospecialize(::EnzymeInterpreter)) = false
Core.Compiler.verbose_stmt_info(@nospecialize(::EnzymeInterpreter)) = false
Base.@nospecializeinfer Core.Compiler.may_discard_trees(@nospecialize(::EnzymeInterpreter)) = false
Base.@nospecializeinfer Core.Compiler.verbose_stmt_info(@nospecialize(::EnzymeInterpreter)) = false

Core.Compiler.method_table(@nospecialize(interp::EnzymeInterpreter), sv::InferenceState) =
Base.@nospecializeinfer Core.Compiler.method_table(@nospecialize(interp::EnzymeInterpreter), sv::InferenceState) =
Core.Compiler.OverlayMethodTable(interp.world, interp.method_table)

function is_alwaysinline_func(@nospecialize(TT))::Bool
Base.@nospecializeinfer function is_alwaysinline_func(@nospecialize(TT))::Bool
isa(TT, DataType) || return false
@static if VERSION ≥ v"1.11-"
if TT.parameters[1] == typeof(Core.memoryref)
Expand All @@ -266,7 +267,7 @@ function is_alwaysinline_func(@nospecialize(TT))::Bool
return false
end

function is_primitive_func(@nospecialize(TT))::Bool
Base.@nospecializeinfer function is_primitive_func(@nospecialize(TT))::Bool
isa(TT, DataType) || return false
ft = TT.parameters[1]
if ft == typeof(Enzyme.pmap)
Expand All @@ -289,7 +290,7 @@ function is_primitive_func(@nospecialize(TT))::Bool
return false
end

function isKWCallSignature(@nospecialize(TT))::Bool
Base.@nospecializeinfer function isKWCallSignature(@nospecialize(TT))::Bool
return TT <: Tuple{typeof(Core.kwcall),Any,Any,Vararg}
end

Expand Down Expand Up @@ -329,7 +330,7 @@ Core.Compiler.getresult_impl(info::AlwaysInlineCallInfo, idx::Int) =

import .EnzymeRules: FwdConfig, RevConfig, Annotation
using Core.Compiler: ArgInfo, StmtInfo, AbsIntState
function Core.Compiler.abstract_call_gf_by_type(
Base.@nospecializeinfer function Core.Compiler.abstract_call_gf_by_type(
@nospecialize(interp::EnzymeInterpreter),
@nospecialize(f),
arginfo::ArgInfo,
Expand Down Expand Up @@ -424,7 +425,7 @@ let # overload `inlining_policy`
)
end
@static if isdefined(Core.Compiler, :inlining_policy)
@eval function Core.Compiler.inlining_policy($(sigs_ex.args...))
@eval Base.@nospecializeinfer function Core.Compiler.inlining_policy($(sigs_ex.args...))
if info isa NoInlineCallInfo
if info.kind === :primitive
@safe_debug "Blocking inlining for primitive func" info.tt
Expand All @@ -444,7 +445,7 @@ let # overload `inlining_policy`
return @invoke Core.Compiler.inlining_policy($(args_ex.args...))
end
else
@eval function Core.Compiler.src_inlining_policy($(sigs_ex.args...))
@eval Base.@nospecializeinfer function Core.Compiler.src_inlining_policy($(sigs_ex.args...))
if info isa NoInlineCallInfo
if info.kind === :primitive
@safe_debug "Blocking inlining for primitive func" info.tt
Expand Down Expand Up @@ -903,7 +904,7 @@ end
end
end

function abstract_call_known(
Base.@nospecializeinfer function abstract_call_known(
interp::EnzymeInterpreter{Handler},
@nospecialize(f),
arginfo::ArgInfo,
Expand Down
6 changes: 3 additions & 3 deletions src/compiler/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -400,7 +400,7 @@ Base.@assume_effects :removable :foldable :nothrow function has_fn_attr(fn::LLVM
return false
end

function eraseInst(bb::LLVM.BasicBlock, @nospecialize(inst::LLVM.Instruction))
Base.@nospecializeinfer function eraseInst(bb::LLVM.BasicBlock, @nospecialize(inst::LLVM.Instruction))
@static if isdefined(LLVM, Symbol("erase!"))
LLVM.erase!(inst)
else
Expand Down Expand Up @@ -446,7 +446,7 @@ end
NamedTuple{ntuple(Symbol, Val(length(U.parameters))),U}

# Recursively compute the element type indexed by idx[0], idx[1], ...
Base.@assume_effects :removable :foldable :nothrow function recursive_eltype(@nospecialize(val::LLVM.Value), idxs::Vector{Cuint})::LLVM.LLVMType
Base.@nospecializeinfer Base.@assume_effects :removable :foldable :nothrow function recursive_eltype(@nospecialize(val::LLVM.Value), idxs::Vector{Cuint})::LLVM.LLVMType
ty = LLVM.value_type(val)::LLVM.LLVMType
for i in idxs
if isa(ty, LLVM.ArrayType)
Expand All @@ -461,7 +461,7 @@ end

# Fix up Julia's calling convention, under which Tuple{Float,Float} becomes [2 x float] rather than {float, float}
# and Bool becomes i8, not i1
function calling_conv_fixup(
Base.@nospecializeinfer function calling_conv_fixup(
builder::LLVM.IRBuilder,
@nospecialize(val::LLVM.Value),
@nospecialize(tape::LLVM.LLVMType),
Expand Down
Loading
Loading