Skip to content

Commit

Permalink
address PR comments
Browse files Browse the repository at this point in the history
  • Loading branch information
CAG2Mark committed Jan 24, 2025
1 parent cb8b917 commit 1f1ce3f
Show file tree
Hide file tree
Showing 5 changed files with 53 additions and 48 deletions.
2 changes: 1 addition & 1 deletion hkmc2/shared/src/main/scala/hkmc2/MLsCompiler.scala
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ class MLsCompiler(preludeFile: os.Path):
val parsed = mainParse.resultBlk
val (blk, newCtx) = elab.importFrom(parsed)
val low = ltl.givenIn:
codegen.Lowering(lowerHandlers = false, None) // TODO: properly hook up stack limit
codegen.Lowering(lowerHandlers = false, stackLimit = None) // TODO: properly hook up stack limit
val jsb = codegen.js.JSBuilder()
val le = low.program(blk)
val baseScp: utils.Scope =
Expand Down
16 changes: 9 additions & 7 deletions hkmc2/shared/src/main/scala/hkmc2/codegen/Block.scala
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,6 @@ import semantics.*
import semantics.Term.*
import sem.Elaborator.State

import scala.collection.mutable.ListBuffer

case class Program(
imports: Ls[Local -> Str],
main: Block,
Expand Down Expand Up @@ -101,20 +99,24 @@ sealed abstract class Block extends Product with AutoLocated:
case _: Return | _: Throw | _: Label | _: Break | _: Continue | _: End | _: HandleBlockReturn => Nil

// Moves definitions in a block to the top. Only scans the top-level definitions of the block;
// i.e, definitions inside other definitions are not moved out. Definitions inside `if` and
// `while` statements are moved out.
// i.e, definitions inside other definitions are not moved out. Definitions inside `match`/`if`
// and `while` statements are moved out.
//
// Note that this returns the definitions in reverse order, with the bottommost definiton appearing
// last. This is so that using defns.foldLeft later to add the definitions to the front of a block,
// we don't need to reverse the list again to preserve the order of the definitions.
def floatOutDefns =
val defns = ListBuffer[Defn]()
var defns: List[Defn] = Nil
val transformer = new BlockTransformerShallow(SymbolSubst()):
override def applyBlock(b: Block): Block = b match
case Define(defn, rest) => defn match
case v: ValDefn => super.applyBlock(b)
case _ =>
defns.addOne(defn)
defns ::= defn
applyBlock(rest)
case _ => super.applyBlock(b)

(transformer.applyBlock(this), defns.reverse.toList)
(transformer.applyBlock(this), defns)

end Block

Expand Down
15 changes: 9 additions & 6 deletions hkmc2/shared/src/main/scala/hkmc2/codegen/HandlerLowering.scala
Original file line number Diff line number Diff line change
Expand Up @@ -288,25 +288,28 @@ class HandlerLowering(using TL, Raise, Elaborator.State, Elaborator.Ctx):
case b: HandleBlock =>
val rest = applyBlock(b.rest)
translateHandleBlock(b.copy(rest = rest))
// This block optimizes tail-calls in the handler transformation. We do not optimize tail-calls
// on the top-level since it does not make sense. This optimization also prevents the
// "throw 'Unhandled effects'" code from being added at the top level.
case Return(c @ Call(fun, args), implct) if handlerCtx.isHandleFree && !handlerCtx.isTopLevel =>
val fun2 = applyPath(fun)
val args2 = args.map(applyArg)
val c2 = if (fun2 is fun) && (args2 zip args).forall(_ is _) then c else Call(fun2, args2)(c.isMlsFun)
val args2 = args.mapConserve(applyArg)
val c2 = if (fun2 is fun) && (args2 is args) then c else Call(fun2, args2)(c.isMlsFun)
if c2 is c then b else Return(c2, implct)
case _ => super.applyBlock(b)
override def applyResult2(r: Result)(k: Result => Block): Block = r match
case r @ Call(Value.Ref(_: BuiltinSymbol), _) => super.applyResult2(r)(k)
case c @ Call(fun, args) =>
val res = freshTmp("res")
val fun2 = applyPath(fun)
val args2 = args.map(applyArg)
val c2 = if (fun2 is fun) && (args2 zip args).forall(_ is _) then c else Call(fun2, args2)(c.isMlsFun)
val args2 = args.mapConserve(applyArg)
val c2 = if (fun2 is fun) && (args2 is args) then c else Call(fun2, args2)(c.isMlsFun)
ResultPlaceholder(res, freshId(), false, c2, k(Value.Ref(res)))
case c @ Instantiate(cls, args) =>
val res = freshTmp("res")
val cls2 = applyPath(cls)
val args2 = args.map(applyPath)
val c2 = if (cls2 is cls) && (args2 zip args).forall(_ is _) then c else Instantiate(cls2, args2)
val args2 = args.mapConserve(applyPath)
val c2 = if (cls2 is cls) && (args2 is args) then c else Instantiate(cls2, args2)
ResultPlaceholder(res, freshId(), false, c2, k(Value.Ref(res)))
case r => super.applyResult2(r)(k)
override def applyLam(lam: Value.Lam): Value.Lam = Value.Lam(lam.params, translateBlock(lam.body, functionHandlerCtx))
Expand Down
32 changes: 16 additions & 16 deletions hkmc2/shared/src/main/scala/hkmc2/codegen/StackSafeTransform.scala
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ class StackSafeTransform(depthLimit: Int)(using State):
private def intLit(n: BigInt) = Value.Lit(Tree.IntLit(n))

private def op(op: String, a: Path, b: Path) =
Call(State.builtinOpsMap(op).asPath, List(a.asArg, b.asArg))(true)
Call(State.builtinOpsMap(op).asPath, a.asArg :: b.asArg :: Nil)(true)

// Increases the stack depth, assigns the call to a value, then decreases the stack depth
// then binds that value to a desired block
Expand Down Expand Up @@ -59,8 +59,8 @@ class StackSafeTransform(depthLimit: Int)(using State):
HandleBlock(
handlerSym, resSym,
stackDelayClsPath, clsSym,
List(Handler(
BlockMemberSymbol("perform", Nil), resumeSym, List(ParamList(ParamListFlags.empty, Nil, N)),
Handler(
BlockMemberSymbol("perform", Nil), resumeSym, ParamList(ParamListFlags.empty, Nil, N) :: Nil,
/*
fun perform() =
let curOffset = stackOffset
Expand All @@ -72,10 +72,10 @@ class StackSafeTransform(depthLimit: Int)(using State):
blockBuilder
.assign(curOffsetSym, stackOffsetPath)
.assignFieldN(predefPath, STACK_OFFSET_IDENT, stackDepthPath)
.assign(handlerRes, Call(Value.Ref(resumeSym), List())(true))
.assign(handlerRes, Call(Value.Ref(resumeSym), Nil)(true))
.assignFieldN(predefPath, STACK_OFFSET_IDENT, curOffsetSym.asPath)
.ret(handlerRes.asPath)
)),
) :: Nil,
blockBuilder
.assignFieldN(predefPath, STACK_LIMIT_IDENT, intLit(depthLimit)) // set stackLimit before call
.assignFieldN(predefPath, STACK_DEPTH_IDENT, intLit(1)) // set stackDepth = 1 before call
Expand Down Expand Up @@ -145,22 +145,22 @@ class StackSafeTransform(depthLimit: Int)(using State):
newBody
else
val diffSym = TempSymbol(None, "diff")
val scrut1Sym = TempSymbol(None, "scrut1")
val scrut2Sym = TempSymbol(None, "scrut2")
val diffGeqLimitSym = TempSymbol(None, "diffGeqLimit")
val handlerExistsSym = TempSymbol(None, "handlerExists")
val scrutSym = TempSymbol(None, "scrut")
val diff = op("-", stackDepthPath, stackOffsetPath)
val scrut1 = op(">=", diffSym.asPath, stackLimitPath)
val scrut2 = op("!==", stackHandlerPath, Value.Lit(Tree.UnitLit(false)))
val scrutVal = op("&&", scrut1Sym.asPath, scrut2Sym.asPath)
val diffGeqLimit = op(">=", diffSym.asPath, stackLimitPath)
val handlerExists = op("!==", stackHandlerPath, Value.Lit(Tree.UnitLit(false)))
val scrutVal = op("&&", diffGeqLimitSym.asPath, handlerExistsSym.asPath)
blockBuilder
.assign(diffSym, diff) // diff = stackDepth - stackOffset
.assign(scrut1Sym, scrut1) // diff >= depthLimit
.assign(scrut2Sym, scrut2) // stackHandler !== null
.assign(scrutSym, scrutVal) // diff >= depthLimit && stackHandler !== null
.assign(diffSym, diff) // diff = stackDepth - stackOffset
.assign(diffGeqLimitSym, diffGeqLimit) // diff >= depthLimit
.assign(handlerExistsSym, handlerExists) // stackHandler !== null
.assign(scrutSym, scrutVal) // diff >= depthLimit && stackHandler !== null
.ifthen(
scrutSym.asPath, Case.Lit(Tree.BoolLit(true)),
blockBuilder.assign( // tmp = perform(undefined)
TempSymbol(None, "tmp"),
blockBuilder.assign( // dummy = perform(undefined) (is called `dummy` as the value is not used)
TempSymbol(None, "dummy"),
Call(Select(stackHandlerPath, Tree.Ident("perform"))(N), Nil)(true)).end)
.rest(newBody)

Expand Down
36 changes: 18 additions & 18 deletions hkmc2/shared/src/test/mlscript/handlers/StackSafety.mls
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,12 @@ hi(0)
//│ JS (unsanitized):
//│ let hi1, res, handleBlock$1;
//│ hi1 = function hi(n) {
//│ let scrut, tmp, diff, scrut1, scrut2, scrut3, tmp1, res1, Cont$;
//│ let scrut, tmp, diff, diffGeqLimit, handlerExists, scrut1, dummy, res1, Cont$;
//│ Cont$ = function Cont$(pc1) { return new Cont$.class(pc1); };
//│ Cont$.class = class Cont$ extends globalThis.Predef.__Cont.class {
//│ constructor(pc) {
//│ let tmp2;
//│ tmp2 = super(null, null);
//│ let tmp1;
//│ tmp1 = super(null, null);
//│ this.pc = pc;
//│ }
//│ resume(value$) {
Expand All @@ -49,7 +49,7 @@ hi(0)
//│ } else if (this.pc === 1) {
//│ break contLoop;
//│ } else if (this.pc === 0) {
//│ tmp1 = res1;
//│ dummy = res1;
//│ this.pc = 2;
//│ continue contLoop;
//│ }
Expand All @@ -59,17 +59,17 @@ hi(0)
//│ toString() { return "Cont$(" + this.pc + ")"; }
//│ };
//│ diff = globalThis.Predef.__stackDepth - globalThis.Predef.__stackOffset;
//│ scrut1 = diff >= globalThis.Predef.__stackLimit;
//│ scrut2 = globalThis.Predef.__stackHandler !== undefined;
//│ scrut3 = scrut1 && scrut2;
//│ if (scrut3 === true) {
//│ diffGeqLimit = diff >= globalThis.Predef.__stackLimit;
//│ handlerExists = globalThis.Predef.__stackHandler !== undefined;
//│ scrut1 = diffGeqLimit && handlerExists;
//│ if (scrut1 === true) {
//│ res1 = globalThis.Predef.__stackHandler.perform();
//│ if (res1 instanceof globalThis.Predef.__EffectSig.class) {
//│ res1.tail.next = new Cont$.class(0);
//│ res1.tail = res1.tail.next;
//│ return res1;
//│ }
//│ tmp1 = res1;
//│ dummy = res1;
//│ }
//│ scrut = n == 0;
//│ if (scrut === true) {
Expand Down Expand Up @@ -179,12 +179,12 @@ sum(10000)
//│ JS (unsanitized):
//│ let sum3, res1, handleBlock$3;
//│ sum3 = function sum(n) {
//│ let scrut, tmp, tmp1, tmp2, prevDepth, diff, scrut1, scrut2, scrut3, tmp3, res2, res3, Cont$;
//│ let scrut, tmp, tmp1, tmp2, prevDepth, diff, diffGeqLimit, handlerExists, scrut1, dummy, res2, res3, Cont$;
//│ Cont$ = function Cont$(pc1) { return new Cont$.class(pc1); };
//│ Cont$.class = class Cont$ extends globalThis.Predef.__Cont.class {
//│ constructor(pc) {
//│ let tmp4;
//│ tmp4 = super(null, null);
//│ let tmp3;
//│ tmp3 = super(null, null);
//│ this.pc = pc;
//│ }
//│ resume(value$) {
Expand Down Expand Up @@ -221,7 +221,7 @@ sum(10000)
//│ tmp1 = tmp2;
//│ return n + tmp1;
//│ } else if (this.pc === 0) {
//│ tmp3 = res2;
//│ dummy = res2;
//│ this.pc = 3;
//│ continue contLoop;
//│ }
Expand All @@ -231,17 +231,17 @@ sum(10000)
//│ toString() { return "Cont$(" + this.pc + ")"; }
//│ };
//│ diff = globalThis.Predef.__stackDepth - globalThis.Predef.__stackOffset;
//│ scrut1 = diff >= globalThis.Predef.__stackLimit;
//│ scrut2 = globalThis.Predef.__stackHandler !== undefined;
//│ scrut3 = scrut1 && scrut2;
//│ if (scrut3 === true) {
//│ diffGeqLimit = diff >= globalThis.Predef.__stackLimit;
//│ handlerExists = globalThis.Predef.__stackHandler !== undefined;
//│ scrut1 = diffGeqLimit && handlerExists;
//│ if (scrut1 === true) {
//│ res2 = globalThis.Predef.__stackHandler.perform();
//│ if (res2 instanceof globalThis.Predef.__EffectSig.class) {
//│ res2.tail.next = new Cont$.class(0);
//│ res2.tail = res2.tail.next;
//│ return res2;
//│ }
//│ tmp3 = res2;
//│ dummy = res2;
//│ }
//│ scrut = n == 0;
//│ if (scrut === true) {
Expand Down

0 comments on commit 1f1ce3f

Please sign in to comment.