Skip to content

Commit

Permalink
Add test for early stopping
Browse files Browse the repository at this point in the history
  • Loading branch information
jelmervdl committed Dec 23, 2023
1 parent f9bef93 commit b1a7f04
Show file tree
Hide file tree
Showing 4 changed files with 68 additions and 42 deletions.
5 changes: 3 additions & 2 deletions src/opustrainer/__main__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from opustrainer.trainer import main
import sys
from opustrainer.trainer import parse_args, main

main()
main(parse_args(sys.argv[1:]))
20 changes: 12 additions & 8 deletions src/opustrainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -739,8 +739,6 @@ class StateTracker:
"""Wraps around the trainer.run() call to restore and dump state its."""
path: str
loader: StateLoader
dump: bool
restore: bool

def __init__(self, path:str, *, loader:StateLoader=StateLoader(), timeout=60):
"""
Expand Down Expand Up @@ -806,9 +804,7 @@ def parse_args(args:List[str]) -> argparse.Namespace:
return parser.parse_args(args)


def main() -> None:
args = parse_args(sys.argv[1:])

def main(args:argparse.Namespace) -> None:
logger.setup_logger(args.log_file, args.log_level)

with open(args.config, 'r', encoding='utf-8') as fh:
Expand Down Expand Up @@ -859,10 +855,14 @@ def main() -> None:
# Tracks whether we interrupted training, or whether the stopping occurred naturally
interrupted = False

# Tracks whether stdin was closed due to BrokenPipeError
closed = False

try:
for batch in state_tracker.run(trainer, batch_size=args.batch_size, chunk_size=args.chunk_size, processes=args.workers):
model_trainer.stdin.writelines(batch)
except BrokenPipeError:
closed = True
logger.log("trainer stopped reading input")
except KeyboardInterrupt:
interrupted = True
Expand All @@ -872,10 +872,11 @@ def main() -> None:
# or because ctrl-c was pressed. Pressing ctrl-c more advances to next level of aggressiveness.
for urgency in ['exit', 'terminate', 'kill']:
try:
logger.log(f"waiting for trainer to {urgency}. Press ctrl-c to be more aggressive")
logger.log(f"waiting for trainer to {urgency}")

if urgency == 'exit':
model_trainer.stdin.close()
if not closed:
model_trainer.stdin.close()
elif urgency == 'terminate':
model_trainer.terminate()
elif urgency == 'kill':
Expand All @@ -887,6 +888,8 @@ def main() -> None:
sys.exit(retval)
elif interrupted:
sys.exit(130)
else:
break # We're done trying to stop to various degrees
except KeyboardInterrupt:
interrupted = True
continue # Skip to the next degree of forcefully stopping
Expand All @@ -895,5 +898,6 @@ def main() -> None:
# would already have called sys.exit() by now.
trainer.next_stage()


if __name__ == '__main__':
main()
main(parse_args(sys.argv[1:]))
32 changes: 0 additions & 32 deletions tests/test_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -452,35 +452,3 @@ def test_num_fields(self):
])
# Assert that we got an error message for one line
self.assertRegex(logger_ctx.output[0], r'\[Trainer\] Expected 3 fields in clean line:')

def test_next_stage(self):
"""Test letting the trainer move to the next stage using early-stopping"""
config = {
'datasets': {
'clean': 'test/data/clean',
'medium': 'test/data/medium',
},
'stages': [
'start',
'mid',
],
'start': [
'clean 1.0',
'until clean inf'
],
'mid': [
'medium 1.0',
'until medium inf',
],
'seed': 1111
}

curriculum = CurriculumLoader().load(config)

# Reference batches (trainer runs without resuming)
with closing(Trainer(curriculum)) as trainer:
self.assertEqual(trainer.stage.name, 'start')
trainer.next_stage()
self.assertEqual(trainer.stage.name, 'mid')
trainer.next_stage()
self.assertIsNone(trainer.stage)
53 changes: 53 additions & 0 deletions tests/test_trainer_cli.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
#!/usr/bin/env python3
import unittest
import yaml
import sys
from subprocess import Popen
from pathlib import Path
from tempfile import TemporaryDirectory, TemporaryFile
from opustrainer.trainer import parse_args


Expand Down Expand Up @@ -28,3 +33,51 @@ def test_marian_log_args(self):
'trainer': ['marian', '--log', 'marian.log']
}
self.assertEqual({**vars(parsed), **expected}, vars(parsed))

def test_early_stopping(self):
"""Test letting the trainer move to the next stage using early-stopping"""
head_lines = 10000

basepath = Path('contrib').absolute()

config = {
'datasets': {
'clean': str(basepath / 'test-data/clean'),
'medium': str(basepath / 'test-data/medium'),
},
'stages': [
'start',
'mid',
],
'start': [
'clean 1.0',
'until clean inf'
],
'mid': [
'medium 1.0',
'until medium inf',
],
'seed': 1111
}

with TemporaryDirectory() as tmp, TemporaryFile() as fout:
with open(Path(tmp) / 'config.yml', 'w+t') as fcfg:
yaml.safe_dump(config, fcfg)

child = Popen([
sys.executable,
'-m', 'opustrainer',
'--do-not-resume',
'--no-shuffle',
'--config', str(Path(tmp) / 'config.yml'),
'head', '-n', str(head_lines)
], stdout=fout)

# Assert we exited neatly
retval = child.wait(30)
self.assertEqual(retval, 0)

# Assert we got the number of lines we'd expect
fout.seek(0)
line_count = sum(1 for _ in fout)
self.assertEqual(line_count, len(config['stages']) * head_lines)

0 comments on commit b1a7f04

Please sign in to comment.