Skip to content

Commit

Permalink
Merge pull request #1215 from Nadrieril/fix-call-generics
Browse files Browse the repository at this point in the history
Fix generics handling for function calls
  • Loading branch information
Nadrieril authored Jan 6, 2025
2 parents 52ad1f9 + 63422f3 commit 4d7cbff
Show file tree
Hide file tree
Showing 7 changed files with 234 additions and 114 deletions.
4 changes: 2 additions & 2 deletions engine/lib/import_thir.ml
Original file line number Diff line number Diff line change
Expand Up @@ -1004,8 +1004,8 @@ end) : EXPR = struct
| Float k ->
TFloat
(match k with F16 -> F16 | F32 -> F32 | F64 -> F64 | F128 -> F128)
| Arrow value ->
let ({ inputs; output; _ } : Thir.ty_fn_sig) = value.value in
| Arrow signature | Closure (_, { untupled_sig = signature; _ }) ->
let ({ inputs; output; _ } : Thir.ty_fn_sig) = signature.value in
let inputs =
if List.is_empty inputs then [ U.unit_typ ]
else List.map ~f:(c_ty span) inputs
Expand Down
3 changes: 3 additions & 0 deletions frontend/exporter/src/traits/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,9 @@ pub fn required_predicates<'tcx>(
.iter()
.map(|(clause, _span)| *clause),
),
// The tuple struct/variant constructor functions inherit the generics and predicates from
// their parents.
Variant | Ctor(..) => return required_predicates(tcx, tcx.parent(def_id)),
// We consider all predicates on traits to be outputs
Trait => None,
// `predicates_defined_on` ICEs on other def kinds.
Expand Down
175 changes: 94 additions & 81 deletions frontend/exporter/src/types/mir.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
use crate::prelude::*;
use crate::sinto_as_usize;
#[cfg(feature = "rustc")]
use rustc_middle::{mir, ty};
#[cfg(feature = "rustc")]
use tracing::trace;

#[derive_group(Serializers)]
Expand Down Expand Up @@ -410,75 +412,91 @@ pub(crate) fn get_function_from_def_id_and_generics<'tcx, S: BaseState<'tcx> + H
(def_id.sinto(s), generics, trait_refs, source)
}

/// Get a `FunOperand` from an `Operand` used in a function call.
/// Return the [DefId] of the function referenced by an operand, with the
/// parameters substitution.
/// The [Operand] comes from a [TerminatorKind::Call].
#[cfg(feature = "rustc")]
fn get_function_from_operand<'tcx, S: UnderOwnerState<'tcx> + HasMir<'tcx>>(
s: &S,
op: &rustc_middle::mir::Operand<'tcx>,
) -> (FunOperand, Vec<GenericArg>, Vec<ImplExpr>, Option<ImplExpr>) {
// Match on the func operand: it should be a constant as we don't support
// closures for now.
use rustc_middle::mir::Operand;
use rustc_middle::ty::TyKind;
let ty = op.ty(&s.mir().local_decls, s.base().tcx);
trace!("type: {:?}", ty);
// If the type of the value is one of the singleton types that corresponds to each function,
// that's enough information.
if let TyKind::FnDef(def_id, generics) = ty.kind() {
let (fun_id, generics, trait_refs, trait_info) =
get_function_from_def_id_and_generics(s, *def_id, *generics);
return (FunOperand::Id(fun_id), generics, trait_refs, trait_info);
}
match op {
Operand::Constant(_) => {
unimplemented!("{:?}", op);
}
Operand::Move(place) => {
// Function pointer. A fn pointer cannot have bound variables or trait references, so
// we don't need to extract generics, trait refs, etc.
let place = place.sinto(s);
(FunOperand::Move(place), Vec::new(), Vec::new(), None)
}
Operand::Copy(_place) => {
unimplemented!("{:?}", op);
}
}
}

#[cfg(feature = "rustc")]
fn translate_terminator_kind_call<'tcx, S: BaseState<'tcx> + HasMir<'tcx> + HasOwnerId>(
s: &S,
terminator: &rustc_middle::mir::TerminatorKind<'tcx>,
) -> TerminatorKind {
if let rustc_middle::mir::TerminatorKind::Call {
let tcx = s.base().tcx;
let mir::TerminatorKind::Call {
func,
args,
destination,
target,
unwind,
call_source,
fn_span,
..
} = terminator
{
let (fun, generics, trait_refs, trait_info) = get_function_from_operand(s, func);
else {
unreachable!()
};

TerminatorKind::Call {
fun,
let ty = func.ty(&s.mir().local_decls, tcx);
let hax_ty: crate::Ty = ty.sinto(s);
let sig = match hax_ty.kind() {
TyKind::Arrow(sig) => sig,
TyKind::Closure(_, args) => &args.untupled_sig,
_ => supposely_unreachable_fatal!(
s,
"TerminatorKind_Call_expected_fn_type";
{ ty }
),
};
let fun_op = if let ty::TyKind::FnDef(def_id, generics) = ty.kind() {
// The type of the value is one of the singleton types that corresponds to each function,
// which is enough information.
let (def_id, generics, trait_refs, trait_info) =
get_function_from_def_id_and_generics(s, *def_id, *generics);
FunOperand::Static {
def_id,
generics,
args: args.sinto(s),
destination: destination.sinto(s),
target: target.sinto(s),
unwind: unwind.sinto(s),
call_source: call_source.sinto(s),
fn_span: fn_span.sinto(s),
trait_refs,
trait_info,
}
} else {
unreachable!()
use mir::Operand;
match func {
Operand::Constant(_) => {
unimplemented!("{:?}", func);
}
Operand::Move(place) => {
// Function pointer or closure.
let place = place.sinto(s);
FunOperand::DynamicMove(place)
}
Operand::Copy(_place) => {
unimplemented!("{:?}", func);
}
}
};

let late_bound_generics = sig
.bound_vars
.iter()
.map(|var| match var {
BoundVariableKind::Region(r) => r,
BoundVariableKind::Ty(..) | BoundVariableKind::Const => {
supposely_unreachable_fatal!(
s,
"non_lifetime_late_bound";
{ var }
)
}
})
.map(|_| {
GenericArg::Lifetime(Region {
kind: RegionKind::ReErased,
})
})
.collect();
TerminatorKind::Call {
fun: fun_op,
late_bound_generics,
args: args.sinto(s),
destination: destination.sinto(s),
target: target.sinto(s),
unwind: unwind.sinto(s),
fn_span: fn_span.sinto(s),
}
}

Expand Down Expand Up @@ -562,13 +580,25 @@ pub enum SwitchTargets {
SwitchInt(IntUintTy, Vec<(ScalarInt, BasicBlock)>, BasicBlock),
}

/// A value of type `fn<...> A -> B` that can be called.
#[derive_group(Serializers)]
#[derive(Clone, Debug, JsonSchema)]
pub enum FunOperand {
/// Call to a top-level function designated by its id
Id(DefId),
/// Use of a closure
Move(Place),
/// Call to a statically-known function.
Static {
def_id: DefId,
/// If `Some`, this is a method call on the given trait reference. Otherwise this is a call
/// to a known function.
trait_info: Option<ImplExpr>,
/// If this is a trait method call, this only includes the method generics; the trait
/// generics are included in the `ImplExpr` in `trait_info`.
generics: Vec<GenericArg>,
/// Trait predicates required by the function generics. Like for `generics`, this only
/// includes the predicates required by the method, if applicable.
trait_refs: Vec<ImplExpr>,
},
/// Use of a closure or a function pointer value. Counts as a move from the given place.
DynamicMove(Place),
}

#[derive_group(Serializers)]
Expand Down Expand Up @@ -607,18 +637,16 @@ pub enum TerminatorKind {
)]
Call {
fun: FunOperand,
/// We truncate the substitution so as to only include the arguments
/// relevant to the method (and not the trait) if it is a trait method
/// call. See [ParamsInfo] for the full details.
generics: Vec<GenericArg>,
/// A `FunOperand` is a value of type `fn<...> A -> B`. The generics in `<...>` are called
/// "late-bound" and are instantiated anew at each call site. This list provides the
/// generics used at this call-site. They are all lifetimes and at the time of writing are
/// all erased lifetimes.
late_bound_generics: Vec<GenericArg>,
args: Vec<Spanned<Operand>>,
destination: Place,
target: Option<BasicBlock>,
unwind: UnwindAction,
call_source: CallSource,
fn_span: Span,
trait_refs: Vec<ImplExpr>,
trait_info: Option<ImplExpr>,
},
TailCall {
func: Operand,
Expand Down Expand Up @@ -934,26 +962,12 @@ pub enum AggregateKind {
Option<UserTypeAnnotationIndex>,
Option<FieldIdx>,
),
#[custom_arm(rustc_middle::mir::AggregateKind::Closure(rust_id, generics) => {
let def_id : DefId = rust_id.sinto(s);
// The generics is meant to be converted to a function signature. Note
// that Rustc does its job: the PolyFnSig binds the captured local
// type, regions, etc. variables, which means we can treat the local
// closure like any top-level function.
#[custom_arm(rustc_middle::mir::AggregateKind::Closure(def_id, generics) => {
let closure = generics.as_closure();
let sig = closure.sig().sinto(s);
// Solve the predicates from the parent (i.e., the item which defines the closure).
let tcx = s.base().tcx;
let parent_generics = closure.parent_args();
let parent_generics_ref = tcx.mk_args(parent_generics);
// TODO: does this handle nested closures?
let parent = tcx.generics_of(rust_id).parent.unwrap();
let trait_refs = solve_item_required_traits(s, parent, parent_generics_ref);
AggregateKind::Closure(def_id, parent_generics.sinto(s), trait_refs, sig)
let args = ClosureArgs::sfrom(s, *def_id, closure);
AggregateKind::Closure(def_id.sinto(s), args)
})]
Closure(DefId, Vec<GenericArg>, Vec<ImplExpr>, PolyFnSig),
Closure(DefId, ClosureArgs),
Coroutine(DefId, Vec<GenericArg>),
CoroutineClosure(DefId, Vec<GenericArg>),
RawPtr(Ty, Mutability),
Expand Down Expand Up @@ -1208,7 +1222,6 @@ sinto_todo!(rustc_middle::mir, UserTypeProjection);
sinto_todo!(rustc_middle::mir, MirSource<'tcx>);
sinto_todo!(rustc_middle::mir, CoroutineInfo<'tcx>);
sinto_todo!(rustc_middle::mir, VarDebugInfo<'tcx>);
sinto_todo!(rustc_middle::mir, CallSource);
sinto_todo!(rustc_middle::mir, UnwindTerminateReason);
sinto_todo!(rustc_middle::mir::coverage, CoverageKind);
sinto_todo!(rustc_middle::mir::interpret, ConstAllocation<'a>);
72 changes: 67 additions & 5 deletions frontend/exporter/src/types/new/full_def.rs
Original file line number Diff line number Diff line change
Expand Up @@ -252,7 +252,7 @@ pub enum FullDefKind<Body> {
inline: InlineAttr,
#[value(s.base().tcx.constness(s.owner_id()) == rustc_hir::Constness::Const)]
is_const: bool,
#[value(s.base().tcx.fn_sig(s.owner_id()).instantiate_identity().sinto(s))]
#[value(get_method_sig(s).sinto(s))]
sig: PolyFnSig,
#[value(s.owner_id().as_local().map(|ldid| Body::body(ldid, s)))]
body: Option<Body>,
Expand All @@ -271,10 +271,8 @@ pub enum FullDefKind<Body> {
is_const: bool,
#[value({
let fun_type = s.base().tcx.type_of(s.owner_id()).instantiate_identity();
match fun_type.kind() {
ty::TyKind::Closure(_, args) => args.as_closure().sinto(s),
_ => unreachable!(),
}
let ty::TyKind::Closure(_, args) = fun_type.kind() else { unreachable!() };
ClosureArgs::sfrom(s, s.owner_id(), args.as_closure())
})]
args: ClosureArgs,
},
Expand Down Expand Up @@ -769,6 +767,70 @@ where
}
}

/// The signature of a method impl may be a subtype of the one expected from the trait decl, as in
/// the example below. For correctness, we must be able to map from the method generics declared in
/// the trait to the actual method generics. Because this would require type inference, we instead
/// simply return the declared signature. This will cause issues if it is possible to use such a
/// more-specific implementation with its more-specific type, but we have a few other issues with
/// lifetime-generic function pointers anyway so this is unlikely to cause problems.
///
/// ```ignore
/// trait MyCompare<Other>: Sized {
/// fn compare(self, other: Other) -> bool;
/// }
/// impl<'a> MyCompare<&'a ()> for &'a () {
/// // This implementation is more general because it works for non-`'a` refs. Note that only
/// // late-bound vars may differ in this way.
/// // `<&'a () as MyCompare<&'a ()>>::compare` has type `fn<'b>(&'a (), &'b ()) -> bool`,
/// // but type `fn(&'a (), &'a ()) -> bool` was expected from the trait declaration.
/// fn compare<'b>(self, _other: &'b ()) -> bool {
/// true
/// }
/// }
/// ```
#[cfg(feature = "rustc")]
fn get_method_sig<'tcx, S>(s: &S) -> ty::PolyFnSig<'tcx>
where
S: UnderOwnerState<'tcx>,
{
let tcx = s.base().tcx;
let def_id = s.owner_id();
let real_sig = tcx.fn_sig(def_id).instantiate_identity();
let item = tcx.associated_item(def_id);
if !matches!(item.container, ty::AssocItemContainer::ImplContainer) {
return real_sig;
}
let Some(decl_method_id) = item.trait_item_def_id else {
return real_sig;
};
let declared_sig = tcx.fn_sig(decl_method_id);

// TODO(Nadrieril): Temporary hack: if the signatures have the same number of bound vars, we
// keep the real signature. While the declared signature is more correct, it is also less
// normalized and we can't normalize without erasing regions but regions are crucial in
// function signatures. Hence we cheat here, until charon gains proper normalization
// capabilities.
if declared_sig.skip_binder().bound_vars().len() == real_sig.bound_vars().len() {
return real_sig;
}

let impl_def_id = item.container_id(tcx);
// The trait predicate that is implemented by the surrounding impl block.
let implemented_trait_ref = tcx
.impl_trait_ref(impl_def_id)
.unwrap()
.instantiate_identity();
// Construct arguments for the declared method generics in the context of the implemented
// method generics.
let impl_args = ty::GenericArgs::identity_for_item(tcx, def_id);
let decl_args = impl_args.rebase_onto(tcx, impl_def_id, implemented_trait_ref.args);
let sig = declared_sig.instantiate(tcx, decl_args);
// Avoids accidentally using the same lifetime name twice in the same scope
// (once in impl parameters, second in the method declaration late-bound vars).
let sig = tcx.anonymize_bound_vars(sig);
sig
}

#[cfg(feature = "rustc")]
fn get_ctor_contents<'tcx, S, Body>(s: &S, ctor_of: CtorOf) -> FullDefKind<Body>
where
Expand Down
Loading

0 comments on commit 4d7cbff

Please sign in to comment.