Skip to content

Commit

Permalink
Add case
Browse files Browse the repository at this point in the history
  • Loading branch information
nielstron committed Dec 21, 2024
1 parent 4e09583 commit a6057bf
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 3 deletions.
13 changes: 11 additions & 2 deletions uplc/ast.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ class Context:

@dataclass
class FrameApplyFun(Context):
val: Any
fun: Any
ctx: Context


Expand All @@ -58,6 +58,10 @@ class FrameApplyArg(Context):
term: "AST"
ctx: Context

@dataclass
class FrameApplyFunArg(Context):
arg: Any
ctx: Context

@dataclass
class FrameForce(Context):
Expand All @@ -76,6 +80,11 @@ class FrameConstr(Context):
resolved_fields: List["AST"]
ctx: Context

@dataclass
class FrameCases(Context):
env: frozendict.frozendict
branches: List["AST"]
ctx: Context

class Step:
pass
Expand Down Expand Up @@ -1109,7 +1118,7 @@ def _replicate_bytes(length: BuiltinInteger, val: BuiltinInteger):
BuiltInFun.BData: single_bytestring(lambda x: PlutusByteString(x.value)),
BuiltInFun.UnConstrData: single_data_constr(
lambda x: BuiltinPair(
BuiltinInteger(x.constructor), BuiltinList(x.fields, PlutusData())
BuiltinInteger(x.constructor), BuiltinList(x.branches, PlutusData())
)
),
BuiltInFun.UnMapData: single_data_map(
Expand Down
31 changes: 30 additions & 1 deletion uplc/machine.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,11 @@ def budget_cost_of_op_on_model(
}


def transfer_arg_stack(args: List[AST], context: Context) -> Context:
if not args:
return context
return transfer_arg_stack(args[:-1], FrameApplyFunArg(args[-1], context))

class Machine:
def __init__(
self,
Expand Down Expand Up @@ -185,11 +190,23 @@ def compute(self, term: AST, context: Context, state: frozendict.frozendict):
context,
term,
)
elif isinstance(term, Case):
return Compute(
FrameCases(
state,
term.branches,
context,
),
state,
term.scrutinee,
)
raise NotImplementedError(f"Invalid term to compute: {term}")

def return_compute(self, context, value):
if isinstance(context, FrameApplyFun):
return self.apply_evaluate(context.ctx, context.val, value)
return self.apply_evaluate(context.ctx, context.fun, value)
elif isinstance(context, FrameApplyFunArg):
return self.apply_evaluate(context.ctx, value, context.arg)
elif isinstance(context, FrameApplyArg):
return Compute(
FrameApplyFun(
Expand Down Expand Up @@ -221,6 +238,18 @@ def return_compute(self, context, value):
context.ctx,
Constr(context.tag, resolved_fields),
)
elif isinstance(context, FrameCases):
if not isinstance(value, Constr):
raise RuntimeError("Scrutinized non-constr in case")
try:
branch = context.branches[value.tag]
except IndexError as e:
raise RuntimeError("No branch provided for constr tag") from None
return Compute(
transfer_arg_stack(value.fields, context.ctx),
context.env,
branch,
)
raise NotImplementedError(f"Invalid context to return compute: {context}")

def apply_evaluate(self, context, function, argument):
Expand Down

0 comments on commit a6057bf

Please sign in to comment.