Restart trainer between stages #45

Closed
wants to merge 10 commits
README.md — 15 changes: 10 additions & 5 deletions
@@ -95,12 +95,12 @@ trainer: /path/to/trainer/run.py
```

### Number of fields
If `num_fields` is provided, at read time, the trainer will strip any extra TSV fields that the dataset contains (such as optinal alignment field that you are not going to use). Furthermore, any line that doesn't have enough fields gets filtered (eg lines missing alignment info when you do actually care about alignment).
If `num_fields` is provided, at read time the trainer will strip any extra TSV fields that the dataset contains (such as an optional alignment field that you are not going to use). Furthermore, any line that doesn't have enough fields is filtered out (e.g. lines missing alignment info when you do care about alignment).
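
A minimal sketch of this (assuming `num_fields` sits at the top level of the configuration, next to `datasets`):

```yaml
# Hypothetical config fragment: keep only the first two TSV fields
# (source and target), dropping e.g. a third alignment field. Lines
# with fewer than two fields are filtered out at read time.
datasets:
  clean: /path/to/clean.tsv.gz
num_fields: 2
```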

### Extended stage configuration
If you want to change which modifiers are used for a specific stage, you can the extended stage configuration format. If a `modifiers` is mentioned here, it will override the curriculum-wide defined `modifiers` for just this stage.
If you want to change which modifiers are used for a specific stage, you can use the extended stage configuration format.

In the extended format, the list of datasets is defined in the `mix` key. You can optionally add a `modifiers` key. For example:
In the extended format, the list of datasets is defined in the `mix` key. You can optionally add `modifiers` and `arguments` keys. For example:

```yaml
start:
@@ -110,10 +110,15 @@ start:
- dirty 0
- until clean 2 # Until two epochs of clean
modifiers:
- UpperCase: 0.05
- TitleCase: 0.05
- UpperCase: 0.05
- TitleCase: 0.05
arguments: --stop-early
```

If a `modifiers` key is mentioned here, it will override the curriculum-wide `modifiers` for just this stage.

If the optional `arguments` key is present, its value is appended to the end of the trainer command.
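
The sketch below combines the `trainer` command from earlier in this README with the stage above; the resulting command line is an assumption based on simple concatenation:

```yaml
trainer: /path/to/trainer/run.py

start:
  mix:
    - clean 1.0
    - until clean 2
  arguments: --stop-early
# For this stage, the trainer is launched as, roughly:
#   /path/to/trainer/run.py --stop-early
```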

Note that you can use YAML references (anchors and aliases) if you wish to combine global and local modifiers.
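
For example, a sketch that anchors the global entries and reuses them in a stage (`Typos` is just an illustrative extra modifier):

```yaml
modifiers:                    # curriculum-wide defaults, anchored for reuse
  - &uppercase UpperCase: 0.05
  - &titlecase TitleCase: 0.05

start:
  mix:
    - clean 1.0
    - until clean 2
  modifiers:                  # stage-local: repeat the globals, then extend
    - *uppercase
    - *titlecase
    - Typos: 0.05
```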

### Modifiers
contrib/test-data/test_enzh_config_plain_expected.log — 2 changes: 1 addition & 1 deletion
@@ -9,4 +9,4 @@
[2023-07-04 02:48:12] [Trainer] [INFO] Reading clean for epoch 7
[2023-07-04 02:48:12] [Trainer] [INFO] Reading clean for epoch 8
[2023-07-04 02:48:12] [Trainer] [INFO] Reading clean for epoch 9
[2023-07-04 02:48:12] [Trainer] [INFO] waiting for trainer to exit. Press ctrl-c to be more aggressive
[2023-07-04 02:48:12] [Trainer] [INFO] waiting for trainer to exit
src/opustrainer/__main__.py — 5 changes: 3 additions & 2 deletions
@@ -1,3 +1,4 @@
from opustrainer.trainer import main
import sys
from opustrainer.trainer import parse_args, main

main()
main(parse_args(sys.argv[1:]))
src/opustrainer/trainer.py — 129 changes: 75 additions & 54 deletions
@@ -71,6 +71,7 @@ class Stage:
until_dataset: str
until_epoch: Optional[int]
modifiers: Optional[List[Modifier]]
arguments: List[str]


@dataclass(frozen=True)
@@ -448,6 +449,9 @@ def _load_stages(self, ymldata:dict, basepath:str, stages_order:List[str], datas
- until dataset3 epochs
modifiers:
- Modifier: freq
arguments: arg1 arg2
```
"""
return {
@@ -491,7 +495,8 @@ def _load_stage(self, ymldata:dict, basepath:str, stage_name:str, available_data
datasets=datasets,
until_dataset=until_dataset_name,
until_epoch=until_epoch,
modifiers=self._load_modifiers(ymldata[stage_name], basepath) if isinstance(ymldata[stage_name], dict) and 'modifiers' in ymldata[stage_name] else None
modifiers=self._load_modifiers(ymldata[stage_name], basepath) if isinstance(ymldata[stage_name], dict) and 'modifiers' in ymldata[stage_name] else None,
arguments=shlex.split(ymldata[stage_name]['arguments']) if isinstance(ymldata[stage_name], dict) and 'arguments' in ymldata[stage_name] else [],
)
except Exception as exc:
raise CurriculumLoaderError(f"could not complete the parse of stage '{stage_name}': {exc!s}") from exc
@@ -739,32 +744,24 @@ class StateTracker:
"""Wraps around the trainer.run() call to restore and dump state its."""
path: str
loader: StateLoader
dump: bool
restore: bool

def __init__(self, path:str, *, loader:StateLoader=StateLoader(), restore:bool=True, dump:bool=True, timeout=60):
def __init__(self, path:str, *, loader:StateLoader=StateLoader(), timeout=60):
"""
Parameters
----------
path : str
Path to state file
loader : type, optional
Loader class for encoding/decoding the state file
restore : bool, optional
Whether to restore the state if the state file currently exists (default is True)
dump : bool, optional
Whether to dump the state to the file after training (default is True)
timeout : int, optional
Minimum number of seconds between state dumps
"""
self.path = path
self.loader = loader
self.dump = dump
self.restore = restore
self.timeout = timeout
self._last_dump = 0

def _restore(self, trainer:Trainer):
def restore(self, trainer:Trainer):
with open(self.path, 'r', encoding='utf-8') as fh:
return trainer.restore(self.loader.load(fh))

@@ -776,20 +773,17 @@ def _dump(self, trainer:Trainer):
self._last_dump = time.monotonic()

def run(self, trainer:Trainer, *args, **kwargs):
if self.restore and os.path.exists(self.path):
self._restore(trainer)

"""Wraps trainer.run() and dumps its state to disk from time to time."""
try:
for batch in trainer.run(*args, **kwargs):
# TODO: Replace this with something that listens to Marian, and
# writes the state to disk after marian performed validation.
if self.dump and time.monotonic() - self._last_dump > self.timeout:
if time.monotonic() - self._last_dump > self.timeout:
self._dump(trainer)
yield batch
finally:
# Dump on clean exit as well as on exception.
if self.dump:
self._dump(trainer)
self._dump(trainer)


def print_state(state:TrainerState) -> None:
@@ -815,9 +809,7 @@ def parse_args(args:List[str]) -> argparse.Namespace:
return parser.parse_args(args)


def main() -> None:
args = parse_args(sys.argv[1:])

def main(args:argparse.Namespace) -> None:
logger.setup_logger(args.log_file, args.log_level)

with open(args.config, 'r', encoding='utf-8') as fh:
@@ -836,56 +828,85 @@ def main() -> None:
tmpdir=args.temporary_directory,
shuffle=args.shuffle)

state_tracker = StateTracker(args.state or f'{args.config}.state', restore=not args.do_not_resume)
state_tracker = StateTracker(args.state or f'{args.config}.state')

if not args.do_not_resume:
try:
state_tracker.restore(trainer)
except FileNotFoundError:
pass

# Make trainer listen to `kill -SIGUSR1 $PID` to print dataset progress
signal.signal(signal.SIGUSR1, lambda signum, handler: print_state(trainer.state()))

model_trainer = subprocess.Popen(
args.trainer or shlex.split(config['trainer']),
stdin=subprocess.PIPE,
encoding="utf-8",
preexec_fn=ignore_sigint) # ignore_sigint makes marian ignore Ctrl-C. We'll stop it from here.

assert model_trainer.stdin is not None

# TODO: This logic looks complicated, should be able to do this simpler. Three scenarios:
# 1. ctrl-c is pressed and trainer is told this is the end of the training data
# 2. ctrl-c is pressed and trainer has much training data in its buffers, ctrl-c needs to be
# pressed again to tell trainer to really terminate. Just closing its stdin and waiting for
# it to notice takes too long
# 3. trainer decides it has read enough and will train no longer. This is the BrokenPipeError
# scenario. We don't need to deal with multiple levels of terminating the trainer because
# the trainer is already dead at this point.
try:
# Trainer is whatever is in config['trainer'], unless you specified it on
# the CLI. If neither is specified, this will raise a KeyError.
model_trainer_cmd = args.trainer or shlex.split(config['trainer'])

while trainer.stage is not None:
model_trainer = subprocess.Popen(
model_trainer_cmd + trainer.stage.arguments,
stdin=subprocess.PIPE,
encoding="utf-8",
preexec_fn=ignore_sigint) # ignore_sigint makes marian ignore Ctrl-C. We'll stop it from here.

assert model_trainer.stdin is not None

# TODO: This logic looks complicated; it should be possible to simplify it. Four scenarios:
# 1. ctrl-c is pressed and trainer is told this is the end of the training data
# 2. ctrl-c is pressed and trainer has much training data in its buffers, ctrl-c needs to be
# pressed again to tell trainer to really terminate. Just closing its stdin and waiting for
# it to notice takes too long
# 3. trainer decides it has read enough and will train no longer. This is the BrokenPipeError
# scenario.
# 4. end of the epochs of this stage was reached, and the for-loop ran out of steam.

# Tracks whether we interrupted training, or whether the stopping occurred naturally
interrupted = False

# Tracks whether stdin was closed due to BrokenPipeError
closed = False

try:
for batch in state_tracker.run(trainer, batch_size=args.batch_size, chunk_size=args.chunk_size, processes=args.workers):
model_trainer.stdin.writelines(batch)
except BrokenPipeError:
closed = True
logger.log("trainer stopped reading input")
except KeyboardInterrupt:
interrupted = True
logger.log("Ctrl-c pressed, stopping training")

# Levels of waiting for the trainer. This is reached either because we ran out of batches
# or because ctrl-c was pressed. Pressing ctrl-c more advances to next level of aggressiveness.
for stage in ['exit', 'terminate', 'kill']:
for urgency in ['exit', 'terminate', 'kill']:
try:
if stage == 'exit':
model_trainer.stdin.close()
elif stage == 'terminate':
logger.log(f"waiting for trainer to {urgency}")

if urgency == 'exit':
if not closed:
model_trainer.stdin.close()
elif urgency == 'terminate':
model_trainer.terminate()
else:
elif urgency == 'kill':
model_trainer.kill()

logger.log(f"waiting for trainer to {stage}. Press ctrl-c to be more aggressive")
sys.exit(model_trainer.wait()) # blocking

retval = model_trainer.wait() # blocking
if retval != 0:
logger.log(f"trainer stopped with non-zero exit code: {retval}")
sys.exit(retval)
elif interrupted:
sys.exit(130)
else:
break # We're done trying to stop to various degrees
except KeyboardInterrupt:
continue
except BrokenPipeError:
# BrokenPipeError is thrown by writelines() or close() and indicates that the child trainer
# process is no more. We can safely retrieve its return code and exit with that, it should
# not block at this point.
logger.log("trainer stopped reading input")
sys.exit(model_trainer.wait())
interrupted = True
continue # Skip to the next degree of forcefully stopping

# Move to next stage. If it was interrupted or exited uncleanly, we
# would already have called sys.exit() by now.
trainer.next_stage()


if __name__ == '__main__':
main()
main(parse_args(sys.argv[1:]))
tests/test_trainer.py — 8 changes: 5 additions & 3 deletions
@@ -9,7 +9,7 @@
from contextlib import closing
from textwrap import dedent
from io import StringIO
from itertools import chain
from itertools import chain, islice

import yaml

@@ -173,10 +173,11 @@ def test_resume(self):

# Train on trainer1
with closing(Trainer(curriculum)) as trainer1:
batches = [batch for _, batch in zip(range(10), state_tracker.run(trainer1))]
batches = list(islice(state_tracker.run(trainer1), 10))

# Resume on trainer2
with closing(Trainer(curriculum)) as trainer2:
state_tracker.restore(trainer2)
batches.extend(state_tracker.run(trainer2))

self.assertEqual(batches, batches_ref)
@@ -258,7 +259,8 @@ def test_simple(self):
],
until_dataset='clean',
until_epoch=5,
modifiers=None)
modifiers=None,
arguments=[])
})
self.assertEqual(curriculum.seed, 1111)
self.assertEqual(len(curriculum.modifiers), 1)
tests/test_trainer_cli.py — 65 changes: 65 additions & 0 deletions
@@ -1,5 +1,12 @@
#!/usr/bin/env python3
import unittest
import sys
from subprocess import Popen
from pathlib import Path
from tempfile import TemporaryDirectory, TemporaryFile

import yaml

from opustrainer.trainer import parse_args


@@ -28,3 +35,61 @@ def test_marian_log_args(self):
'trainer': ['marian', '--log', 'marian.log']
}
self.assertEqual({**vars(parsed), **expected}, vars(parsed))

def test_early_stopping(self):
"""Test letting the trainer move to the next stage using early-stopping"""
basepath = Path('contrib').absolute()

config = {
'datasets': {
'clean': str(basepath / 'test-data/clean'),
'medium': str(basepath / 'test-data/medium'),
},
'stages': [
'start',
'mid',
],
'start': {
'mix': [
'clean 1.0',
'until clean inf'
],
'arguments': '-n 5000'
},
'mid': {
'mix': [
'medium 1.0',
'until medium inf',
],
'arguments': '-n 10000'
},
'seed': 1111
}

with TemporaryDirectory() as tmp, TemporaryFile() as fout, TemporaryFile() as ferr:
with open(Path(tmp) / 'config.yml', 'w+t') as fcfg:
yaml.safe_dump(config, fcfg)

child = Popen([
sys.executable,
'-m', 'opustrainer',
'--do-not-resume',
'--no-shuffle',
'--config', str(Path(tmp) / 'config.yml'),
'head', # plus value for -n, per stage
], stdout=fout, stderr=ferr)

retval = child.wait(30)
fout.seek(0)
ferr.seek(0)

# Assert we exited neatly
self.assertEqual(retval, 0, msg=ferr.read().decode())

# Assert we got the number of lines we'd expect
line_count = sum(1 for _ in fout)
expected_line_count = sum(
int(config[stage]['arguments'][3:]) # interpret the `-n XXX`
for stage in config['stages']
)
self.assertEqual(line_count, expected_line_count)