Skip to content

Commit

Permalink
added evaluation script
Browse files Browse the repository at this point in the history
  • Loading branch information
jasonxyliu committed Dec 27, 2022
1 parent 2acaff3 commit 6186e64
Showing 1 changed file with 29 additions and 5 deletions.
34 changes: 29 additions & 5 deletions s2s_sup.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,17 @@
Infer trained model.
"""
import argparse
import numpy as np
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, T5Tokenizer, T5ForConditionalGeneration
import spot

from s2s_hf_transformers import T5_PREFIX

from s2s_pt_transformer import Seq2SeqTransformer, \
NUM_ENCODER_LAYERS, NUM_DECODER_LAYERS, EMBED_SIZE, NHEAD, DIM_FFN_HID
from s2s_pt_transformer import translate as transformer_translate
from s2s_pt_transformer import construct_dataset as transformer_construct_dataset
from utils import save_to_file


class Seq2Seq:
Expand Down Expand Up @@ -42,12 +44,32 @@ def translate(self, queries):
)
ltls = self.tokenizer.batch_decode(output_tokens, skip_special_tokens=True)
elif self.model_type == "pt_transformer":
ltls = self.model_translate(self.model, self.vocab_transform, self.text_transform, queries[0])
ltls = [self.model_translate(self.model, self.vocab_transform, self.text_transform, queries[0])]
else:
raise ValueError(f'ERROR: unrecognized model, {self.model_type}')
return ltls


def evaluate(s2s, results_fpath, data_fpath=None):
    """Translate every validation utterance and score it against the ground truth.

    Each utterance from the validation split is translated with
    ``s2s.translate``; the output formula is compared to the true formula with
    ``spot.are_equivalent``. Per-utterance results are written to
    ``f"{results_fpath}.csv"`` via ``save_to_file`` and the fraction of
    utterances marked ``'True'`` is printed.

    Args:
        s2s: object exposing ``translate(queries)`` that returns one output
            per query in a list.
        results_fpath: output path WITHOUT extension; ``.csv`` is appended.
        data_fpath: dataset path handed to ``transformer_construct_dataset``.
            When ``None`` (the default) the module-level ``args.data`` is used,
            which only exists when this file is run as a script.
    """
    if data_fpath is None:
        # Backward-compatible fallback: the original implementation always read
        # the command-line namespace created under ``if __name__ == '__main__'``.
        data_fpath = args.data

    # Only the second element (the validation iterator) of the 6-tuple is used.
    _, valid_iter, _, _, _, _ = transformer_construct_dataset(data_fpath)
    results = [["utterances", "true_ltl", "output_ltl", "is_correct"]]
    accs = []

    for utt, true_ltl in valid_iter:
        # translate() takes a batch, so wrap the single utterance and unwrap
        # the single result.
        out_ltl = s2s.translate([utt])[0]
        try:
            # Both formulas are parsed here, so a parse failure in either one
            # ends up in the except branch below.
            equivalent = spot.are_equivalent(spot.formula(out_ltl), spot.formula(true_ltl))
        except SyntaxError:  # output LTL formula may have syntax error
            print(f'Syntax error in output LTL: {out_ltl}')
            is_correct = 'Syntax Error'
        else:
            is_correct = 'True' if equivalent else 'False'
        accs.append(is_correct)
        results.append([utt, true_ltl, out_ltl, is_correct])

    save_to_file(results, f"{results_fpath}.csv")

    # Syntax errors count as incorrect. Guard the empty case explicitly:
    # np.mean([]) would still yield nan but emits a RuntimeWarning.
    accuracy = np.mean([acc == 'True' for acc in accs]) if accs else float('nan')
    print(accuracy)


if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--data', type=str, default='data/symbolic_pairs_no_perm.csv', help='file path to train and test data for supervised seq2seq')
Expand All @@ -66,6 +88,8 @@ def translate(self, queries):
else:
raise TypeError(f"ERROR: unrecognized model, {args.model}")

ltls = s2s.translate([args.utt])
print(args.utt)
print(ltls)
evaluate(s2s, f"results/s2s_{args.model}_batch_1_results")

# ltls = s2s.translate([args.utt])
# print(args.utt)
# print(ltls)

0 comments on commit 6186e64

Please sign in to comment.