Skip to content

Commit

Permalink
Refactor core for optimizer (#694)
Browse files Browse the repository at this point in the history
  • Loading branch information
b-studios authored Nov 14, 2024
1 parent d2c9547 commit 7650914
Show file tree
Hide file tree
Showing 18 changed files with 323 additions and 404 deletions.
2 changes: 1 addition & 1 deletion effekt/shared/src/main/scala/effekt/Compiler.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ package effekt

import effekt.PhaseResult.{ AllTransformed, CoreTransformed }
import effekt.context.Context
import effekt.core.{ DirectStyleMutableState, Transformer }
import effekt.core.Transformer
import effekt.namer.Namer
import effekt.source.{ AnnotateCaptures, ExplicitCapabilities, ResolveExternDefs, ModuleDecl }
import effekt.symbols.Module
Expand Down
141 changes: 0 additions & 141 deletions effekt/shared/src/main/scala/effekt/core/Deadcode.scala
Original file line number Diff line number Diff line change
Expand Up @@ -36,144 +36,3 @@ object Deadcode {
remove(Set(entrypoint), m)
}

/**
* A simple reachability analysis for toplevel definitions
*
* TODO this could also be extended to cover record and interface declarations.
*/
class Reachable(
var reachable: Map[Id, Usage],
var stack: List[Id],
var seen: Set[Id]
) {

def within(id: Id)(f: => Unit): Unit = {
stack = id :: stack
f
stack = stack.tail
}

def process(d: Definition)(using defs: Map[Id, Definition]): Unit =
if stack.contains(d.id) then
reachable = reachable.updated(d.id, Usage.Recursive)
else d match {
case Definition.Def(id, block) =>
seen = seen + id
within(id) { process(block) }

case Definition.Let(id, _, binding) =>
seen = seen + id
process(binding)
}

def process(id: Id)(using defs: Map[Id, Definition]): Unit =
if (stack.contains(id)) {
reachable = reachable.updated(id, Usage.Recursive)
return;
}

val count = reachable.get(id) match {
case Some(Usage.Once) => Usage.Many
case Some(Usage.Many) => Usage.Many
case Some(Usage.Recursive) => Usage.Recursive
case None => Usage.Once
}
reachable = reachable.updated(id, count)
if (!seen.contains(id)) {
defs.get(id).foreach(process)
}

def process(b: Block)(using defs: Map[Id, Definition]): Unit =
b match {
case Block.BlockVar(id, annotatedTpe, annotatedCapt) => process(id)
case Block.BlockLit(tparams, cparams, vparams, bparams, body) => process(body)
case Block.Member(block, field, annotatedTpe) => process(block)
case Block.Unbox(pure) => process(pure)
case Block.New(impl) => process(impl)
}

def process(s: Stmt)(using defs: Map[Id, Definition]): Unit = s match {
case Stmt.Scope(definitions, body) =>
var currentDefs = defs
definitions.foreach {
case d: Definition.Def =>
currentDefs += d.id -> d // recursive
process(d)(using currentDefs)
case d: Definition.Let =>
process(d)(using currentDefs)
currentDefs += d.id -> d // non-recursive
}
process(body)(using currentDefs)
case Stmt.Return(expr) => process(expr)
case Stmt.Val(id, tpe, binding, body) => process(binding); process(body)
case Stmt.App(callee, targs, vargs, bargs) =>
process(callee)
vargs.foreach(process)
bargs.foreach(process)
case Stmt.If(cond, thn, els) => process(cond); process(thn); process(els)
case Stmt.Match(scrutinee, clauses, default) =>
process(scrutinee)
clauses.foreach { case (id, value) => process(value) }
default.foreach(process)
case Stmt.Alloc(id, init, region, body) =>
process(init)
process(region)
process(body)
case Stmt.Var(id, init, capture, body) =>
process(init)
process(body)
case Stmt.Get(id, capt, tpe) => process(id)
case Stmt.Put(id, tpe, value) => process(id); process(value)
case Stmt.Reset(body) => process(body)
case Stmt.Shift(prompt, body) => process(prompt); process(body)
case Stmt.Resume(k, body) => process(k); process(body)
case Stmt.Region(body) => process(body)
case Stmt.Hole() => ()
}

def process(e: Expr)(using defs: Map[Id, Definition]): Unit = e match {
case DirectApp(b, targs, vargs, bargs) =>
process(b);
vargs.foreach(process)
bargs.foreach(process)
case Run(s) => process(s)
case Pure.ValueVar(id, annotatedType) => process(id)
case Pure.Literal(value, annotatedType) => ()
case Pure.PureApp(b, targs, vargs) => process(b); vargs.foreach(process)
case Pure.Make(data, tag, vargs) => process(tag); vargs.foreach(process)
case Pure.Select(target, field, annotatedType) => process(target)
case Pure.Box(b, annotatedCapture) => process(b)
}

def process(i: Implementation)(using defs: Map[Id, Definition]): Unit =
i.operations.foreach { op => process(op.body) }

}

object Reachable {
def apply(entrypoints: Set[Id], definitions: Map[Id, Definition]): Map[Id, Usage] = {
val analysis = new Reachable(Map.empty, Nil, Set.empty)
entrypoints.foreach(d => analysis.process(d)(using definitions))
analysis.reachable
}

def apply(m: ModuleDecl): Map[Id, Usage] = {
val analysis = new Reachable(Map.empty, Nil, Set.empty)
val defs = m.definitions.map(d => d.id -> d).toMap
m.definitions.foreach(d => analysis.process(d)(using defs))
analysis.reachable
}

def apply(s: Stmt.Scope): Map[Id, Usage] = {
val analysis = new Reachable(Map.empty, Nil, Set.empty)
analysis.process(s)(using Map.empty)
analysis.reachable
}
}

enum Usage {
case Once
case Many
case Recursive
}

This file was deleted.

6 changes: 5 additions & 1 deletion effekt/shared/src/main/scala/effekt/core/Inline.scala
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,9 @@ object Inline {
case Stmt.App(b, targs, vargs, bargs) =>
app(rewrite(b), targs, vargs.map(rewrite), bargs.map(rewrite))

case Stmt.Invoke(b, method, methodTpe, targs, vargs, bargs) =>
invoke(rewrite(b), method, methodTpe, targs, vargs.map(rewrite), bargs.map(rewrite))

case Stmt.Reset(body) =>
rewrite(body) match {
case BlockLit(tparams, cparams, vparams, List(prompt), body) if !used(prompt.id) => body
Expand Down Expand Up @@ -173,7 +176,6 @@ object Inline {

// congruences
case b @ Block.BlockLit(tparams, cparams, vparams, bparams, body) => rewrite(b)
case Block.Member(block, field, annotatedTpe) => member(rewrite(block), field, annotatedTpe)
case Block.Unbox(pure) => unbox(rewrite(pure))
case Block.New(impl) => New(rewrite(impl))
}
Expand Down Expand Up @@ -224,6 +226,7 @@ object Inline {
case Stmt.Return(expr) => false
case Stmt.Val(id, annotatedTpe, binding, body) => tailResumptive(k, body) && !freeInStmt(binding)
case Stmt.App(callee, targs, vargs, bargs) => false
case Stmt.Invoke(callee, method, methodTpe, targs, vargs, bargs) => false
case Stmt.If(cond, thn, els) => !freeInExpr(cond) && tailResumptive(k, thn) && tailResumptive(k, els)
// Interestingly, we introduce a join point making this more difficult to implement properly
case Stmt.Match(scrutinee, clauses, default) => !freeInExpr(scrutinee) && clauses.forall {
Expand Down Expand Up @@ -261,6 +264,7 @@ object Inline {
case Stmt.Hole() => stmt
case Stmt.Return(expr) => stmt
case Stmt.App(callee, targs, vargs, bargs) => stmt
case Stmt.Invoke(callee, method, methodTpe, targs, vargs, bargs) => stmt
case Stmt.Get(id, annotatedCapt, annotatedTpe) => stmt
case Stmt.Put(id, annotatedCapt, value) => stmt
}
Expand Down
14 changes: 4 additions & 10 deletions effekt/shared/src/main/scala/effekt/core/Parser.scala
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,10 @@ class CoreParsers(positions: Positions, names: Names) extends EffektLexers(posit
| `val` ~> id ~ maybeTypeAnnotation ~ (`=` ~> stmt) ~ (`;` ~> stmt) ^^ {
case id ~ tpe ~ binding ~ body => Stmt.Val(id, tpe.getOrElse(binding.tpe), binding, body)
}
| block ~ maybeTypeArgs ~ valueArgs ~ blockArgs ^^ Stmt.App.apply
| block ~ (`.` ~> id ~ (`:` ~> blockType)).? ~ maybeTypeArgs ~ valueArgs ~ blockArgs ^^ {
case (recv ~ Some(method ~ tpe) ~ targs ~ vargs ~ bargs) => Invoke(recv, method, tpe, targs, vargs, bargs)
case (recv ~ None ~ targs ~ vargs ~ bargs) => App(recv, targs, vargs, bargs)
}
| (`if` ~> `(` ~/> pure <~ `)`) ~ stmt ~ (`else` ~> stmt) ^^ Stmt.If.apply
| `region` ~> blockLit ^^ Stmt.Region.apply
| `var` ~> id ~ (`in` ~> id) ~ (`=` ~> pure) ~ (`;` ~> stmt) ^^ { case id ~ region ~ init ~ body => Alloc(id, init, region, body) }
Expand Down Expand Up @@ -195,15 +198,6 @@ class CoreParsers(positions: Positions, names: Names) extends EffektLexers(posit
// Blocks
// ------
lazy val block: P[Block] =
( blockNonMember ~ many((`.` ~> id) ~ (`:` ~> blockType)) ^^ {
case firstReceiver ~ fields => fields.foldLeft(firstReceiver) {
case (receiver, field ~ tpe) => Block.Member(receiver, field, tpe)
}
}
| blockNonMember
)

lazy val blockNonMember: P[Block] =
( id ~ (`:` ~> blockType) ~ (`@` ~> captures) ^^ Block.BlockVar.apply
| `unbox` ~> pure ^^ Block.Unbox.apply
| `new` ~> implementation ^^ Block.New.apply
Expand Down
63 changes: 52 additions & 11 deletions effekt/shared/src/main/scala/effekt/core/PolymorphismBoxing.scala
Original file line number Diff line number Diff line change
Expand Up @@ -215,17 +215,6 @@ object PolymorphismBoxing extends Phase[CoreTransformed, CoreTransformed] {
case Block.BlockVar(id, annotatedTpe, annotatedCapt) =>
Block.BlockVar(id, transform(annotatedTpe), annotatedCapt)
case b: Block.BlockLit => transform(b)
case Block.Member(block, field, annotatedTpe) => // TODO properly box/unbox arguments and result
val tpe = block.tpe.asInstanceOf[core.BlockType.Interface]
PContext.findInterface(tpe.name) match {
case None => // assume it's an external interface and will be handled elsewhere
Block.Member(transform(block), field, transform(annotatedTpe))
case Some(ifce) =>
val prop = ifce.properties.find { p => p.id == field }.getOrElse { Context.abort(s"Cannot find field ${field} in declaration ${ifce}") }
val propTpe = Type.substitute(prop.tpe.asInstanceOf[BlockType.Function], (ifce.tparams zip (tpe.targs.map(transformArg))).toMap, Map.empty)
val coerce = coercer[Block](transform(propTpe), transform(annotatedTpe))
coerce(Block.Member(transform(block), field, transform(propTpe)))
}
case Block.Unbox(pure) =>
Block.Unbox(transform(pure))
case Block.New(impl) =>
Expand Down Expand Up @@ -277,6 +266,56 @@ object PolymorphismBoxing extends Phase[CoreTransformed, CoreTransformed] {
val bcoercers = (tBargs zip itpe.bparams).map { (a, p) => coercer[Block](a.tpe, p) }
val fcoercer = coercer[Block](tpe, itpe, targs)
fcoercer.call(calleeT, (vcoercers zip tVargs).map(_(_)), (bcoercers zip tBargs).map(_(_)))

// [S](S) => (Int, S)
case Stmt.Invoke(callee, method, methodTpe: BlockType.Function, targs, vargs, bargs) =>
// Double

val calleeT = transform(callee)

// [S](S) => (T, S)
val (tpe: BlockType.Function, interfaceParams, interfaceArgs) = calleeT.tpe match {
// [Int]
case BlockType.Interface(name, targs) =>
PContext.findInterface(name) match {
// [T]
case Some(Interface(id, tparams, properties)) =>
// op: [S](S) => (T, S)
val prop = properties.find { p => p.id == method }.getOrElse { Context.panic(pp"Cannot find field ${method} in declaration of ${id}") }

(prop.tpe.asInstanceOf[BlockType.Function], tparams, targs)
case _ =>
Context.panic(pp"Should not happen. Method call on extern interface: ${stmt}")
}
case _ =>
Context.panic("Should not happen. Method call on non-interface")
}

// [S](S) => (BoxedInt, S)
val boxedTpe = Type.substitute(tpe, (interfaceParams zip interfaceArgs.map(transformArg)).toMap, Map.empty).asInstanceOf[BlockType.Function]

// duplicated from App
val itpe = Type.instantiate(methodTpe, targs, methodTpe.cparams.map(Set(_)))
val tVargs = vargs map transform
val tBargs = bargs map transform
val vcoercers = (tVargs zip boxedTpe.vparams).map { (a, p) => coercer(a.tpe, p) }
val bcoercers = (tBargs zip boxedTpe.bparams).map { (a, p) => coercer[Block](a.tpe, p) }
// (T, S) (Int, Double)
val rcoercer = coercer(tpe.result, itpe.result)

val result = Invoke(calleeT, method, boxedTpe, targs.map(transformArg), (vcoercers zip tVargs).map(_(_)), (bcoercers zip tBargs).map(_(_)))

// (BoxedInt, BoxedDouble)
val out = result.tpe
if (rcoercer.isIdentity) {
result
} else {
val orig = TmpValue("result")
Stmt.Val(orig, out, result,
Stmt.Return(rcoercer(Pure.ValueVar(orig, out))))
}
case Stmt.Invoke(callee, method, methodTpe, targs, vargs, bargs) => ???

case Stmt.Get(id, capt, tpe) => Stmt.Get(id, capt, transform(tpe))
case Stmt.Put(id, capt, value) => Stmt.Put(id, capt, transform(value))
case Stmt.If(cond, thn, els) =>
Expand Down Expand Up @@ -344,6 +383,8 @@ object PolymorphismBoxing extends Phase[CoreTransformed, CoreTransformed] {
case Stmt.Hole() => Stmt.Hole()
}



def transform(expr: Expr)(using PContext): Expr = expr match {
case DirectApp(b, targs, vargs, bargs) =>
val callee = transform(b)
Expand Down
5 changes: 3 additions & 2 deletions effekt/shared/src/main/scala/effekt/core/PrettyPrinter.scala
Original file line number Diff line number Diff line change
Expand Up @@ -81,8 +81,6 @@ object PrettyPrinter extends ParenPrettyPrinter {
case BlockVar(id, _, _) => toDoc(id)
case BlockLit(tps, cps, vps, bps, body) =>
braces { space <> paramsToDoc(tps, vps, bps) <+> "=>" <+> nest(line <> toDoc(body)) <> line }
case Member(b, id, _) =>
toDoc(b) <> "." <> id.name.toString
case Unbox(e) => parens("unbox" <+> toDoc(e))
case New(handler) => "new" <+> toDoc(handler)
}
Expand Down Expand Up @@ -181,6 +179,9 @@ object PrettyPrinter extends ParenPrettyPrinter {
case App(b, targs, vargs, bargs) =>
toDoc(b) <> argsToDoc(targs, vargs, bargs)

case Invoke(b, method, methodTpe, targs, vargs, bargs) =>
toDoc(b) <> "." <> method.name.toString <> argsToDoc(targs, vargs, bargs)

case If(cond, thn, els) =>
"if" <+> parens(toDoc(cond)) <+> block(toDoc(thn)) <+> "else" <+> block(toDoc(els))

Expand Down
Loading

0 comments on commit 7650914

Please sign in to comment.