Skip to content

Commit

Permalink
Implement explicit forall parsing
Browse files Browse the repository at this point in the history
  • Loading branch information
osa1 committed Jan 1, 2024
1 parent 9d20407 commit b5cb058
Show file tree
Hide file tree
Showing 12 changed files with 176 additions and 48 deletions.
15 changes: 11 additions & 4 deletions src/ast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ impl<Id: fmt::Debug> Decl<Id> {
}
}

pub fn type_synonym(&self) -> &TypeDecl<Id> {
pub fn type_syn(&self) -> &TypeDecl<Id> {
match &self.node {
Decl_::Type(type_syn) => type_syn,
other => panic!("Not type synonym: {:?}", other),
Expand All @@ -183,9 +183,14 @@ pub type ValueDecl<Id> = AstNode<ValueDecl_<Id>>;
#[cfg(test)]
impl<Id: fmt::Debug> ValueDecl<Id> {
// TODO: Add a type for type sig
pub fn type_sig(&self) -> (&[Id], &[Type<Id>], &Type<Id>) {
pub fn type_sig(&self) -> (&[Id], &[Id], &[Type<Id>], &Type<Id>) {
match &self.node {
ValueDecl_::TypeSig { vars, context, ty } => (&*vars, &*context, ty),
ValueDecl_::TypeSig {
vars,
foralls,
context,
ty,
} => (&*vars, &*foralls, &*context, ty),
other => panic!("Not type sig: {:?}", other),
}
}
Expand All @@ -201,13 +206,15 @@ impl<Id: fmt::Debug> ValueDecl<Id> {

#[derive(Debug, Clone)]
pub enum ValueDecl_<Id> {
/// In `x, y, z :: Show a => a -> String`:
/// In `x, y, z :: forall a . Show a => a -> String`:
///
/// - vars = `[x, y, z]`
/// - foralls = `[a]`
/// - context = `[Show a]`
/// - ty = `a -> String`
TypeSig {
vars: Vec<Id>,
foralls: Vec<Id>,
context: Vec<Type<Id>>,
ty: Type<Id>,
},
Expand Down
8 changes: 7 additions & 1 deletion src/class_collector.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,13 @@ pub(crate) fn module_class_env(module: &[ast::RenamedDecl], kinds: &Map<Id, Kind

let mut methods: Map<Id, Scheme> = Default::default();
for decl in decls {
if let ast::ValueDecl_::TypeSig { vars, context, ty } = &decl.node {
if let ast::ValueDecl_::TypeSig {
vars,
foralls: _,
context,
ty,
} = &decl.node
{
let mut method_context: Vec<ast::RenamedType> =
Vec::with_capacity(context.len() + 1);

Expand Down
1 change: 1 addition & 0 deletions src/kind_inference.rs
Original file line number Diff line number Diff line change
Expand Up @@ -330,6 +330,7 @@ fn collect_types(decls: &[ast::RenamedDecl]) -> Vec<TypeDecl> {
for decl in &decl.node.decls {
if let ast::ValueDecl_::TypeSig {
vars: _,
foralls: _,
context,
ty,
} = &decl.node
Expand Down
4 changes: 3 additions & 1 deletion src/lexer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ lexgen::lexer! {
"deriving" = Token::ReservedId(ReservedId::Deriving),
"do" = Token::ReservedId(ReservedId::Do),
"else" = Token::ReservedId(ReservedId::Else),
"forall" = Token::ReservedId(ReservedId::Forall),
"foreign" = Token::ReservedId(ReservedId::Foreign),
"if" = Token::ReservedId(ReservedId::If),
"import" = Token::ReservedId(ReservedId::Import),
Expand Down Expand Up @@ -150,7 +151,7 @@ mod test {
#[test]
fn lex_id_sym() {
assert_eq!(
lex("a A ++ :+: A.a A.A A.++ A.:+: *"),
lex("a A ++ :+: A.a A.A A.++ A.:+: * ."),
vec![
Token::VarId,
Token::ConId,
Expand All @@ -161,6 +162,7 @@ mod test {
Token::QVarSym,
Token::QConSym,
Token::VarSym,
Token::VarSym,
]
);
}
Expand Down
21 changes: 20 additions & 1 deletion src/parser.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ pub fn parse_module(module_str: &str) -> ParserResult<Vec<ParsedDecl>> {

/// Parse a type with predicates. Does not handle layout.
#[cfg(test)]
pub fn parse_type(type_str: &str) -> ParserResult<(Vec<ParsedType>, ParsedType)> {
pub fn parse_type(type_str: &str) -> ParserResult<(Vec<String>, Vec<ParsedType>, ParsedType)> {
Parser::new(
type_str,
"<input>".into(),
Expand Down Expand Up @@ -342,6 +342,25 @@ impl<'input, L: LayoutLexer_> Parser<'input, L> {
self.context_parser(Self::class)
}

fn try_foralls(&mut self) -> ParserResult<Vec<String>> {
let mut ids = vec![];
if self.skip_token(Token::ReservedId(ReservedId::Forall)) {
loop {
let (l, t, r) = self.next()?;
match t {
Token::VarId => {
ids.push(self.string(l, r));
}
Token::VarSym if self.str(l, r) == "." => {
break;
}
_ => return self.fail(l, ErrorKind::UnexpectedToken),
}
}
}
Ok(ids)
}

fn try_context_arrow(&mut self) -> Option<Vec<ParsedType>> {
self.try_(|self_| {
let context = self_.context()?;
Expand Down
13 changes: 11 additions & 2 deletions src/parser/decl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -223,9 +223,18 @@ impl<'input, L: LayoutLexer_> Parser<'input, L> {
let (l, _, _) = self_.peek()?;
let vars = self_.vars()?;
self_.expect_token(Token::ReservedOp(ReservedOp::ColonColon))?;
let (context, ty) = self_.type_with_context()?;
let (foralls, context, ty) = self_.type_with_context()?;
let r = ty.span.end;
Ok(self_.spanned(l, r, ValueDecl_::TypeSig { vars, context, ty }))
Ok(self_.spanned(
l,
r,
ValueDecl_::TypeSig {
vars,
foralls,
context,
ty,
},
))
});

if let Ok(ty_sig) = ty_sig {
Expand Down
13 changes: 12 additions & 1 deletion src/parser/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -353,5 +353,16 @@ type T f = f Int
let ast = parse_module(pgm).unwrap();
assert_eq!(ast.len(), 2);
ast[0].kind_sig();
ast[1].type_synonym();
ast[1].type_syn();
}

#[test]
fn explicit_forall() {
let pgm = "f :: forall f a b . Functor f => (a -> b) -> f a -> f b";
let ast = parse_module(pgm).unwrap();
assert_eq!(ast.len(), 1);
let (vars, foralls, context, _ty) = ast[0].value().type_sig();
assert_eq!(vars, ["f"]);
assert_eq!(foralls, ["f", "a", "b"]);
assert_eq!(context.len(), 1);
}
10 changes: 7 additions & 3 deletions src/parser/ty.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,15 @@ impl<'input, L: LayoutLexer_> Parser<'input, L> {
/// Parses the part after `::` in a type signature. E.g. `x :: Show a => a -> String` the part
/// `Show a => a -> String`.
///
/// This is used in tests. Haskell 2010 doesn't have a non-terminal for this.
pub(crate) fn type_with_context(&mut self) -> ParserResult<(Vec<ParsedType>, ParsedType)> {
/// This is used type signatures and tests. Haskell 2010 spec doesn't have a non-terminal for
/// this.
pub(crate) fn type_with_context(
&mut self,
) -> ParserResult<(Vec<String>, Vec<ParsedType>, ParsedType)> {
let foralls = self.try_foralls()?;
let context = self.try_context_arrow().unwrap_or_default();
let ty = self.type_()?;
Ok((context, ty))
Ok((foralls, context, ty))
}

/*
Expand Down
111 changes: 82 additions & 29 deletions src/renaming.rs
Original file line number Diff line number Diff line change
Expand Up @@ -248,17 +248,37 @@ impl Renamer {

fn rename_value_decl(&mut self, decl: &ast::ParsedValueDecl) -> ast::RenamedValueDecl {
match &decl.node {
ast::ValueDecl_::TypeSig { vars, context, ty } => {
ast::ValueDecl_::TypeSig {
vars,
foralls,
context,
ty,
} => {
let vars: Vec<Id> = vars.iter().map(|var| self.rename_var(var)).collect();

// `context` and `ty` may use variables before declaring
self.tys.enter();
let context: Vec<ast::RenamedType> =
context.iter().map(|ty| self.rename_type(ty)).collect();
let ty = self.rename_type(ty);

let foralls: Vec<Id> = foralls.iter().map(|var| self.bind_type_var(var)).collect();

// When the type has a scoped type variables all type variables should be listed.
let allow_binding = foralls.is_empty();

let context: Vec<ast::RenamedType> = context
.iter()
.map(|ty| self.rename_type(ty, allow_binding))
.collect();

let ty = self.rename_type(ty, allow_binding);

self.tys.exit();

decl.with_node(ast::ValueDecl_::TypeSig { vars, context, ty })
decl.with_node(ast::ValueDecl_::TypeSig {
vars,
foralls,
context,
ty,
})
}

ast::ValueDecl_::Fixity { fixity, prec, ops } => {
Expand Down Expand Up @@ -287,7 +307,7 @@ impl Renamer {
.iter()
.map(|var| self.bind_fresh_type_var(var))
.collect();
let rhs = self.rename_type(rhs);
let rhs = self.rename_type(rhs, true);
self.tys.exit();
decl.with_node(ast::TypeDecl_ { ty, vars, rhs })
}
Expand All @@ -302,7 +322,10 @@ impl Renamer {
} = &decl.node;
let ty_con = self.rename_type_var(ty_con);
self.tys.enter();
let context = context.iter().map(|ctx| self.rename_type(ctx)).collect();
let context = context
.iter()
.map(|ctx| self.rename_type(ctx, true))
.collect();
let ty_args = ty_args
.iter()
.map(|var| self.bind_fresh_type_var(var))
Expand Down Expand Up @@ -332,7 +355,7 @@ impl Renamer {
fn rename_field(&mut self, field: &ast::ParsedFieldDecl) -> ast::RenamedFieldDecl {
let ast::FieldDecl_ { vars, ty } = &field.node;
let vars = vars.iter().map(|var| self.rename_var(var)).collect();
let ty = self.rename_type(ty);
let ty = self.rename_type(ty, true);
field.with_node(ast::FieldDecl_ { vars, ty })
}

Expand All @@ -345,7 +368,10 @@ impl Renamer {
} = &decl.node;
let ty_con = self.rename_type_var(ty_con);
self.tys.enter();
let context = context.iter().map(|ctx| self.rename_type(ctx)).collect();
let context = context
.iter()
.map(|ctx| self.rename_type(ctx, true))
.collect();
for ty_arg in ty_args {
self.bind_fresh_type_var(ty_arg);
}
Expand Down Expand Up @@ -374,7 +400,10 @@ impl Renamer {
self.tys.enter();
// Context will refer to the type variable, so bind the type variable first.
self.bind_fresh_type_var(ty_arg);
let context = context.iter().map(|ty| self.rename_type(ty)).collect();
let context = context
.iter()
.map(|ty| self.rename_type(ty, true))
.collect();
let ty_arg = self.rename_type_var(ty_arg);
let decls = decls
.iter()
Expand All @@ -398,8 +427,11 @@ impl Renamer {
} = &decl.node;
let ty_con = self.rename_type_var(ty_con);
self.tys.enter();
let context = context.iter().map(|ty| self.rename_type(ty)).collect();
let ty = self.rename_type(ty);
let context = context
.iter()
.map(|ty| self.rename_type(ty, true))
.collect();
let ty = self.rename_type(ty, true);

// Instance methods are top-level, but we create a new scope to be able to able to override
// class method ids while renaming instance declarations.
Expand Down Expand Up @@ -427,7 +459,7 @@ impl Renamer {
fn rename_default_decl(&mut self, decl: &ast::ParsedDefaultDecl) -> ast::RenamedDefaultDecl {
let ast::DefaultDecl_ { tys } = &decl.node;
decl.with_node(ast::DefaultDecl_ {
tys: tys.iter().map(|ty| self.rename_type(ty)).collect(),
tys: tys.iter().map(|ty| self.rename_type(ty, true)).collect(),
})
}

Expand Down Expand Up @@ -591,8 +623,11 @@ impl Renamer {
type_,
} => ast::Exp_::TypeAnnotation {
exp: Box::new(self.rename_exp(exp)),
context: context.iter().map(|ty| self.rename_type(ty)).collect(),
type_: self.rename_type(type_),
context: context
.iter()
.map(|ty| self.rename_type(ty, true))
.collect(),
type_: self.rename_type(type_, true),
},

ast::Exp_::ArithmeticSeq { exp1, exp2, exp3 } => ast::Exp_::ArithmeticSeq {
Expand Down Expand Up @@ -725,27 +760,37 @@ impl Renamer {
}

// NB. Types can use varibles without declaring them first
pub(crate) fn rename_type(&mut self, ty: &ast::ParsedType) -> ast::RenamedType {
pub(crate) fn rename_type(
&mut self,
ty: &ast::ParsedType,
allow_binding: bool,
) -> ast::RenamedType {
ty.map(|ty| match ty {
ast::Type_::Tuple(tys) => {
ast::Type_::Tuple(tys.iter().map(|ty| self.rename_type(ty)).collect())
}
ast::Type_::Tuple(tys) => ast::Type_::Tuple(
tys.iter()
.map(|ty| self.rename_type(ty, allow_binding))
.collect(),
),

ast::Type_::List(ty) => ast::Type_::List(Box::new(self.rename_type(ty))),
ast::Type_::List(ty) => ast::Type_::List(Box::new(self.rename_type(ty, allow_binding))),

ast::Type_::Arrow(ty1, ty2) => ast::Type_::Arrow(
Box::new(self.rename_type(ty1)),
Box::new(self.rename_type(ty2)),
Box::new(self.rename_type(ty1, allow_binding)),
Box::new(self.rename_type(ty2, allow_binding)),
),

ast::Type_::App(ty, tys) => ast::Type_::App(
Box::new(self.rename_type(ty)),
tys.iter().map(|ty| self.rename_type(ty)).collect(),
Box::new(self.rename_type(ty, allow_binding)),
tys.iter()
.map(|ty| self.rename_type(ty, allow_binding))
.collect(),
),

ast::Type_::Con(con) => ast::Type_::Con(self.rename_tycon(con)),

ast::Type_::Var(var) => ast::Type_::Var(self.rename_or_bind_type_var(var)),
ast::Type_::Var(var) => {
ast::Type_::Var(self.rename_or_bind_type_var(var, allow_binding))
}
})
}

Expand Down Expand Up @@ -775,17 +820,25 @@ impl Renamer {
.clone()
}

fn rename_or_bind_type_var(&mut self, var: &str) -> Id {
fn rename_or_bind_type_var(&mut self, var: &str, allow_binding: bool) -> Id {
match self.tys.get(var) {
Some(id) => id.clone(),
None => {
let id = self.fresh_id(var, IdKind::TyVar);
self.tys.bind(var.to_owned(), id.clone());
id
if allow_binding {
self.bind_type_var(var)
} else {
panic!("Unbound type varible: {}", var)
}
}
}
}

fn bind_type_var(&mut self, var: &str) -> Id {
let id = self.fresh_id(var, IdKind::TyVar);
self.tys.bind(var.to_owned(), id.clone());
id
}

fn rename_type_var(&mut self, var: &str) -> Id {
self.tys
.get(var)
Expand Down
Loading

0 comments on commit b5cb058

Please sign in to comment.