diff --git a/src/opustrainer/__main__.py b/src/opustrainer/__main__.py index 4615aca..e5216c9 100644 --- a/src/opustrainer/__main__.py +++ b/src/opustrainer/__main__.py @@ -1,3 +1,4 @@ -from opustrainer.trainer import main +import sys +from opustrainer.trainer import parse_args, main -main() +main(parse_args(sys.argv[1:])) diff --git a/src/opustrainer/trainer.py b/src/opustrainer/trainer.py index 5a29e22..901112f 100755 --- a/src/opustrainer/trainer.py +++ b/src/opustrainer/trainer.py @@ -739,8 +739,6 @@ class StateTracker: """Wraps around the trainer.run() call to restore and dump state its.""" path: str loader: StateLoader - dump: bool - restore: bool def __init__(self, path:str, *, loader:StateLoader=StateLoader(), timeout=60): """ @@ -806,9 +804,7 @@ def parse_args(args:List[str]) -> argparse.Namespace: return parser.parse_args(args) -def main() -> None: - args = parse_args(sys.argv[1:]) - +def main(args:argparse.Namespace) -> None: logger.setup_logger(args.log_file, args.log_level) with open(args.config, 'r', encoding='utf-8') as fh: @@ -859,10 +855,14 @@ def main() -> None: # Tracks whether we interrupted training, or whether the stopping occurred naturally interrupted = False + # Tracks whether stdin was closed due to BrokenPipeError + closed = False + try: for batch in state_tracker.run(trainer, batch_size=args.batch_size, chunk_size=args.chunk_size, processes=args.workers): model_trainer.stdin.writelines(batch) except BrokenPipeError: + closed = True logger.log("trainer stopped reading input") except KeyboardInterrupt: interrupted = True @@ -872,10 +872,11 @@ def main() -> None: # or because ctrl-c was pressed. Pressing ctrl-c more advances to next level of aggressiveness. for urgency in ['exit', 'terminate', 'kill']: try: - logger.log(f"waiting for trainer to {urgency}. Press ctrl-c to be more aggressive") + logger.log(f"waiting for trainer to {urgency}") if urgency == 'exit': - model_trainer.stdin.close() + if not closed: + model_trainer.stdin.close() elif urgency == 'terminate': model_trainer.terminate() elif urgency == 'kill': @@ -887,6 +888,8 @@ def main() -> None: sys.exit(retval) elif interrupted: sys.exit(130) + else: + break # We're done trying to stop to various degrees except KeyboardInterrupt: interrupted = True continue # Skip to the next degree of forcefully stopping @@ -895,5 +898,6 @@ def main() -> None: # would already have called sys.exit() by now. trainer.next_stage() + if __name__ == '__main__': - main() + main(parse_args(sys.argv[1:])) diff --git a/tests/test_trainer.py b/tests/test_trainer.py index ff80e1b..7c40ff7 100755 --- a/tests/test_trainer.py +++ b/tests/test_trainer.py @@ -452,35 +452,3 @@ def test_num_fields(self): ]) # Assert that we got an error message for one line self.assertRegex(logger_ctx.output[0], r'\[Trainer\] Expected 3 fields in clean line:') - - def test_next_stage(self): - """Test letting the trainer move to the next stage using early-stopping""" - config = { - 'datasets': { - 'clean': 'test/data/clean', - 'medium': 'test/data/medium', - }, - 'stages': [ - 'start', - 'mid', - ], - 'start': [ - 'clean 1.0', - 'until clean inf' - ], - 'mid': [ - 'medium 1.0', - 'until medium inf', - ], - 'seed': 1111 - } - - curriculum = CurriculumLoader().load(config) - - # Reference batches (trainer runs without resuming) - with closing(Trainer(curriculum)) as trainer: - self.assertEqual(trainer.stage.name, 'start') - trainer.next_stage() - self.assertEqual(trainer.stage.name, 'mid') - trainer.next_stage() - self.assertIsNone(trainer.stage) diff --git a/tests/test_trainer_cli.py b/tests/test_trainer_cli.py index ddd31a2..7534621 100644 --- a/tests/test_trainer_cli.py +++ b/tests/test_trainer_cli.py @@ -1,5 +1,10 @@ #!/usr/bin/env python3 import unittest +import yaml +import sys +from subprocess import Popen +from pathlib import Path +from tempfile import TemporaryDirectory, TemporaryFile from opustrainer.trainer import parse_args @@ -28,3 +33,51 @@ def test_marian_log_args(self): 'trainer': ['marian', '--log', 'marian.log'] } self.assertEqual({**vars(parsed), **expected}, vars(parsed)) + + def test_early_stopping(self): + """Test letting the trainer move to the next stage using early-stopping""" + head_lines = 10000 + + basepath = Path('contrib').absolute() + + config = { + 'datasets': { + 'clean': str(basepath / 'test-data/clean'), + 'medium': str(basepath / 'test-data/medium'), + }, + 'stages': [ + 'start', + 'mid', + ], + 'start': [ + 'clean 1.0', + 'until clean inf' + ], + 'mid': [ + 'medium 1.0', + 'until medium inf', + ], + 'seed': 1111 + } + + with TemporaryDirectory() as tmp, TemporaryFile() as fout: + with open(Path(tmp) / 'config.yml', 'w+t') as fcfg: + yaml.safe_dump(config, fcfg) + + child = Popen([ + sys.executable, + '-m', 'opustrainer', + '--do-not-resume', + '--no-shuffle', + '--config', str(Path(tmp) / 'config.yml'), + 'head', '-n', str(head_lines) + ], stdout=fout) + + # Assert we exited neatly + retval = child.wait(30) + self.assertEqual(retval, 0) + + # Assert we got the number of lines we'd expect + fout.seek(0) + line_count = sum(1 for _ in fout) + self.assertEqual(line_count, len(config['stages']) * head_lines)