Skip to content

Commit

Permalink
Introduce implicit resolution pass
Browse files Browse the repository at this point in the history
  • Loading branch information
FlandiaYingman committed Jan 14, 2025
1 parent 4fb2937 commit ca68d06
Show file tree
Hide file tree
Showing 15 changed files with 647 additions and 161 deletions.
21 changes: 20 additions & 1 deletion hkmc2/jvm/src/test/scala/hkmc2/MLsDiffMaker.scala
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import mlscript.utils.*, shorthands.*
import utils.*

import hkmc2.semantics.Elaborator
import hkmc2.semantics.ImplicitResolver


abstract class MLsDiffMaker extends DiffMaker:
Expand Down Expand Up @@ -34,11 +35,14 @@ abstract class MLsDiffMaker extends DiffMaker:
val silent = NullaryCommand("silent")
val dbgElab = NullaryCommand("de")
val dbgParsing = NullaryCommand("dp")
val dbgResolving = NullaryCommand("dr")

val showParse = NullaryCommand("p")
val showParsedTree = DebugTreeCommand("pt")
val showElab = NullaryCommand("el")
val showElaboratedTree = DebugTreeCommand("elt")
val showResolve = NullaryCommand("r")
val showResolvedTree = DebugTreeCommand("rt")
val showLoweredTree = NullaryCommand("lot")
val ppLoweredTree = NullaryCommand("slot")
val showContext = NullaryCommand("ctx")
Expand All @@ -56,6 +60,7 @@ abstract class MLsDiffMaker extends DiffMaker:
override def dbg: Bool =
dbgParsing.isSet
|| dbgElab.isSet
|| dbgResolving.isSet
|| debug.isSet

val etl = new TraceLogger:
Expand All @@ -68,9 +73,13 @@ abstract class MLsDiffMaker extends DiffMaker:
// * Perhaps this should be the default behavior of TraceLogger.
if doTrace then super.trace(pre, post)(thunk)
else thunk

val rtl = new TraceLogger:
override def doTrace = dbgResolving.isSet
override def emitDbg(str: String): Unit = output(str)

var curCtx = Elaborator.State.init

var curICtx = ImplicitResolver.ICtx.empty

override def run(): Unit =
if file =/= preludeFile then importFile(preludeFile, verbose = false)
Expand Down Expand Up @@ -187,6 +196,16 @@ abstract class MLsDiffMaker extends DiffMaker:
showElaboratedTree.get.foreach: post =>
output(s"Elaborated tree:")
output(e.showAsTree(using post))

val resolver = ImplicitResolver(rtl)
curICtx = resolver.resolveBlk(e)(using curICtx)

if showResolve.isSet then
output(s"Resolved: ${e.showDbg}")
showResolvedTree.get.foreach: post =>
output(s"Resolved tree:")
output(e.showAsTree(using post))

processTerm(e, inImport = false)


Expand Down
6 changes: 3 additions & 3 deletions hkmc2/shared/src/main/scala/hkmc2/bbml/bbML.scala
Original file line number Diff line number Diff line change
Expand Up @@ -420,13 +420,13 @@ class BBTyper(using elState: Elaborator.State, tl: TL, scope: Scope):
effBuff += eff
ctx += sym -> rhsTy
goStats(stats)
case TermDefinition(_, Fun, sym, ps :: Nil, sig, Some(body), _, _, _) :: stats =>
case TermDefinition(_, Fun, sym, ps :: Nil, _, sig, Some(body), _, _, _) :: stats =>
typeFunDef(sym, Term.Lam(ps, body), sig, ctx)
goStats(stats)
case TermDefinition(_, Fun, sym, Nil, sig, Some(body), _, _, _) :: stats =>
case TermDefinition(_, Fun, sym, Nil, _, sig, Some(body), _, _, _) :: stats =>
typeFunDef(sym, body, sig, ctx) // * may be a case expressions
goStats(stats)
case TermDefinition(_, Fun, sym, _, S(sig), None, _, _, _) :: stats =>
case TermDefinition(_, Fun, sym, _, _, S(sig), None, _, _, _) :: stats =>
ctx += sym -> typeType(sig)
goStats(stats)
case (clsDef: ClassDef) :: stats =>
Expand Down
15 changes: 15 additions & 0 deletions hkmc2/shared/src/main/scala/hkmc2/codegen/Lowering.scala
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,14 @@ class Lowering(using TL, Raise, Elaborator.State):
case sem.Fld(flags, value, asc) =>
TODO("Other argument forms")
case spd: Spd => true -> spd.term
case ca: sem.CtxArg => ca.term match
case S(t) =>
false -> t
case N =>
// All contextual arguments should have been
// populated by implicit resolution before lowering.
// Fail silently.
false -> Term.Error
val l = new TempSymbol(S(t))
def rec(as: Ls[Bool -> st], asr: Ls[Arg]): Block = as match
case Nil => k(Call(fr, asr.reverse)(isMlsFun))
Expand All @@ -103,6 +111,7 @@ class Lowering(using TL, Raise, Elaborator.State):
subTerm(prefix): p =>
conclude(Select(p, nme)(sel.sym))
case _ => subTerm(f)(conclude)
case st.TyApp(lhs, _) => term(lhs)(k)
case st.Blk(Nil, res) => term(res)(k)
case st.Blk(Lit(Tree.UnitLit(true)) :: stats, res) =>
subTerm(st.Blk(stats, res))(k)
Expand Down Expand Up @@ -131,6 +140,12 @@ class Lowering(using TL, Raise, Elaborator.State):
val (paramLists, bodyBlock) = setupFunctionDef(td.params, bod, S(td.sym.nme))
Define(FunDefn(td.sym, paramLists, bodyBlock),
term(st.Blk(stats, res))(k))
case syntax.Ins =>
// Implciit instances are not parameterized for now.
assert(td.params.isEmpty)
subTerm(bod)(r =>
Define(ValDefn(td.owner, syntax.ImmutVal, td.sym, r),
term(st.Blk(stats, res))(k)))
// case cls: ClassDef =>
case cls: ClassLikeDef =>
reportAnnotations(cls, cls.annotations)
Expand Down
147 changes: 67 additions & 80 deletions hkmc2/shared/src/main/scala/hkmc2/semantics/Elaborator.scala
Original file line number Diff line number Diff line change
Expand Up @@ -180,10 +180,12 @@ extends Importer:
else Term.SynthSel(trm, Ident("class"))(mem.clsTree.orElse(mem.modTree).map(_.symbol))
case _ => trm

def term(tree: Tree, inAppPrefix: Bool = false): Ctxl[Term] =
def term(tree: Tree, inAppPrefix: Bool = false, inTyAppPrefix: Bool = false): Ctxl[Term] =
trace[Term](s"Elab term ${tree.showDbg}", r => s"~> $r"):
def maybeModuleMethodApp(t: Term): Ctxl[Term] =
if !inAppPrefix then moduleMethodApp(t)
def maybeApp(t: Term): Ctxl[Term] =
// !inAppPrefix && !inTyAppPrefix is to ensure that nested App/TyApp are only wrapped once.
if !inAppPrefix && !inTyAppPrefix
then maybeModuleMethodApp(t)
else t
tree.desugared match
case Bra(k, e) =>
Expand Down Expand Up @@ -250,8 +252,8 @@ extends Importer:
case N =>
raise(ErrorReport(msg"Name not found: $name" -> tree.toLoc :: Nil))
Term.Error
case TyApp(lhs, targs) =>
Term.TyApp(term(lhs), targs.map {
case TyApp(lhs, targs) => maybeApp:
Term.TyApp(term(lhs, inTyAppPrefix = true), targs.map {
case Modified(Keyword.`in`, inLoc, arg) => Term.WildcardTy(S(term(arg)), N)
case Modified(Keyword.`out`, outLoc, arg) => Term.WildcardTy(N, S(term(arg)))
case Tup(Modified(Keyword.`in`, inLoc, arg1) :: Modified(Keyword.`out`, outLoc, arg2) :: Nil) =>
Expand Down Expand Up @@ -379,15 +381,18 @@ extends Importer:
msg"Only module parameters may receive module arguments (values)." ->
arg.toLoc :: Nil

maybeModuleMethodApp(Term.App(lt, rt)(tree, sym))
maybeApp:
Term.App(lt, rt)(tree, sym)
case SynthSel(pre, nme) =>
val preTrm = term(pre)
val sym = resolveField(nme, preTrm.symbol, nme)
maybeModuleMethodApp(Term.SynthSel(preTrm, nme)(sym))
maybeApp:
Term.SynthSel(preTrm, nme)(sym)
case Sel(pre, nme) =>
val preTrm = term(pre)
val sym = resolveField(nme, preTrm.symbol, nme)
maybeModuleMethodApp(Term.Sel(preTrm, nme)(sym))
maybeApp:
Term.Sel(preTrm, nme)(sym)
case MemberProj(ct, nme) =>
val c = cls(ct, inAppPrefix = false)
val f = c.symbol.flatMap(_.asCls) match
Expand Down Expand Up @@ -524,48 +529,53 @@ extends Importer:
// ???

/** Module method applications that require further elaboration with type information. */
def moduleMethodApp(t: Term): Ctxl[Term] =
trace[Term](s"Elab module method ${t.showDbg}", r => s"~> $r"):
/** Returns the module method definition of the innermost Sel wrapped by some App and TyApp. */
def defn(t: Term, argLists: Ls[Term]): (Opt[Definition], Term, Ls[Term]) = t match
case Term.App(f, r) => defn(f, r :: argLists)
case Term.TyApp(f, _) => defn(f, argLists)
case Term.SynthSel(pre, nme) if ModuleChecker.evalsToModule(pre) =>
(t.symbol.flatMap(_.asBlkMember).flatMap(_.defn), t, argLists)
case Term.Sel(pre, nme) if ModuleChecker.evalsToModule(pre) =>
(t.symbol.flatMap(_.asBlkMember).flatMap(_.defn), t, argLists)
case _ => (N, t, argLists)
defn(t, Nil) match
case (S(defn: TermDefinition), inner, argLists) =>
log(s"Elab module method definition w/ type information ${defn}.")
val emptyTreeApp = new Tree.App(Tree.Empty(), Tree.Empty())
/**
* Zips a module method application term along with its parameter lists,
* inserting any missing contextual argument lists.
*
* M.foo -> M.foo(<using> ...)
* M.foo(a, b) -> M.foo(<using> ...)(a, b)(<using> ...)
*
* Note: This *doesn't* handle explicit contextual arguments.
*/
def zip(t: Term, paramLists: Ls[ParamList]): Term = (t, paramLists) match
case (_, params :: pRest) if params.flags.ctx =>
log(s"Insert a missing contextual argument list for ${params}")
val args = Term.Tup(params.params.map(CtxArgImpl(_)))(Tree.Tup(Nil))
Term.App(zip(t, pRest), args)(emptyTreeApp, FlowSymbol("‹app-res›"))
case (t @ Term.App(lhs, rhs), params :: pRest) =>
// Match the outermost App with the next non-contextual parameter list.
Term.App(zip(lhs, pRest), rhs)(t.tree, t.resSym)
case (t, params :: pRest) =>
// LHS is not a app but it still expects more param lists - a partial application.
// Just suppose it is legal and don't fail here.
// TODO: Check in the implicit resolver.
t
case (_, Nil) => t
val newTerm = zip(t, defn.params.reverse)
log(s"Zip module method application: ${newTerm}")
newTerm
case _ =>
def maybeModuleMethodApp(t: Term): Ctxl[Term] =
// * Some function definitions might not be fully elaborated yet.
// * We need to do some very lightweight elaboration here.
case class Param(ctx: Bool)(tree: Tree)
case class ParamList(ps: Ls[Param], ctx: Bool)(tree: Tree)
def param(tree: Tree): Param = tree match
case Tree.Modified(Keyword.`using`, _, tree) => Param(true)(tree)
case _ => Param(false)(tree)
def paramList(tree: Tree.Tup): ParamList =
val ps = tree.fields.map(param)
ParamList(ps, ps.exists(_.ctx))(tree)

/**
* Zips a module method application term along with its parameter lists,
* inserting any missing contextual argument lists.
*
* M.foo -> M.foo(<using> ...)
* M.foo(a, b) -> M.foo(<using> ...)(a, b)(<using> ...)
*
* Note: This *doesn't* handle explicit contextual arguments.
*/
def zip(t: Term, paramLists: Ls[ParamList]): Term = (t, paramLists) match
case (t, ps :: pss) if ps.ctx =>
val appTree = new Tree.App(Tree.Empty(), Tree.Empty())
val tupTree = new Tree.Tup(Nil)
val args = Term.Tup(ps.ps.map(_ => CtxArgImpl()))(tupTree)
Term.App(zip(t, pss), args)(appTree, FlowSymbol("‹app-res›"))
case (t @ Term.App(lhs, rhs), ps :: pss) =>
Term.App(zip(lhs, pss), rhs)(t.tree, t.resSym)
case (t, params :: pRest) =>
// LHS is not a app but it still expects more param lists - a partial application.
// Just suppose it is legal and don't fail here.
// TODO: Check in the implicit resolver.
t
case (_, Nil) => t

t match
// M.f[T](foo)(bar)
case semantics.Apps(Term.TyApp(ModuleChecker.MethodTreeDef(tree), _), argss) =>
trace[Term](s"Elab module method application ${t.showDbg}", r => s"~> $r"):
zip(t, tree.paramLists.map(paramList).reverse)
// M.f(foo)(bar)
case semantics.Apps(ModuleChecker.MethodTreeDef(tree), argss) =>
trace[Term](s"Elab module method application ${t.showDbg}", r => s"~> $r"):
zip(t, tree.paramLists.map(paramList).reverse)
// Not a module method application.
case _ =>
t

def fld(tree: Tree): Ctxl[Elem] = tree match
Expand Down Expand Up @@ -725,10 +735,10 @@ extends Importer:
case trm => raise(WarningReport(msg"Terms in handler block do nothing" -> trm.toLoc :: Nil))

val tds = elabed.stats.map {
case td @ TermDefinition(owner, Fun, sym, params, sign, body, resSym, flags, annotations) =>
case td @ TermDefinition(owner, Fun, sym, params, tparams, sign, body, resSym, flags, annotations) =>
params.reverse match
case ParamList(_, value :: Nil, _) :: newParams =>
val newTd = TermDefinition(owner, Fun, sym, newParams.reverse, sign, body, resSym, flags, annotations)
val newTd = TermDefinition(owner, Fun, sym, newParams.reverse, tparams, sign, body, resSym, flags, annotations)
S(HandlerTermDefinition(value.sym, newTd))
case _ =>
raise(ErrorReport(msg"Handler function is missing resumption parameter" -> td.toLoc :: Nil))
Expand Down Expand Up @@ -778,7 +788,9 @@ extends Importer:
val tdf = ctx.nest(N).givenIn:
// * Add type parameters to context
val (tps, newCtx1) = td.typeParams match
case S(t) => typeParams(t)
case S(t) =>
val (tps, ctx) = typeParams(t)
(S(tps), ctx)
case N => (N, ctx)
// * Add parameters to context
val (pss, newCtx) =
Expand All @@ -801,7 +813,7 @@ extends Importer:
case Nil if k is Fun =>
ParamList(ParamListFlags.empty, Nil, N) :: Nil
case _ => pss
val tdf = TermDefinition(owner, k, sym, real_pss, s, b, r,
val tdf = TermDefinition(owner, k, sym, real_pss, tps, s, b, r,
TermDefFlags.empty.copy(isModMember = isModMember), annotations)
sym.defn = S(tdf)

Expand Down Expand Up @@ -1043,7 +1055,7 @@ extends Importer:
def computeVariances(s: Statement): Unit =
val trav = VarianceTraverser()
def go(s: Statement): Unit = s match
case TermDefinition(_, k, sym, pss, sign, body, r, _, _) =>
case TermDefinition(_, k, sym, pss, _, sign, body, r, _, _) =>
pss.foreach(ps => ps.params.foreach(trav.traverseType(S(false))))
sign.foreach(trav.traverseType(S(true)))
body match
Expand All @@ -1062,31 +1074,6 @@ extends Importer:
trav.changed = false
go(s)

object ModuleChecker:

/** Checks if a term is a reference to a type parameter. */
def isTypeParam(t: Term): Bool = t.symbol
.filter(_.isInstanceOf[VarSymbol])
.flatMap(_.asInstanceOf[VarSymbol].decl)
.exists(_.isInstanceOf[TyParam])

/** Checks if a term evaluates to a module value. */
def evalsToModule(t: Term): Bool =
def isModule(t: Tree): Bool = t match
case TypeDef(Mod, _, _, _) => true
case _ => false
def returnsModule(t: TermDef): Bool = t.annotatedResultType match
case S(TypeDef(Mod, _, N, N)) => true
case _ => false
t match
case Term.Blk(_, res) => evalsToModule(res)
case Term.App(lhs, rhs) => lhs.symbol match
case S(sym: BlockMemberSymbol) => sym.trmTree.exists(returnsModule)
case _ => false
case t => t.symbol match
case S(sym: BlockMemberSymbol) => sym.modTree.exists(isModule)
case _ => false

class VarianceTraverser(var changed: Bool = true) extends Traverser:
override def traverseType(pol: Pol)(trm: Term): Unit = trm match
case Term.TyApp(lhs, targs) =>
Expand Down
Loading

0 comments on commit ca68d06

Please sign in to comment.