Skip to content

Commit

Permalink
Changed stage level arguments to a string
Browse files Browse the repository at this point in the history
  • Loading branch information
jelmervdl committed Dec 24, 2023
1 parent 7fde43e commit 0fd4742
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 6 deletions.
8 changes: 6 additions & 2 deletions src/opustrainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -496,7 +496,7 @@ def _load_stage(self, ymldata:dict, basepath:str, stage_name:str, available_data
until_dataset=until_dataset_name,
until_epoch=until_epoch,
modifiers=self._load_modifiers(ymldata[stage_name], basepath) if isinstance(ymldata[stage_name], dict) and 'modifiers' in ymldata[stage_name] else None,
arguments=[str(arg) for arg in ymldata[stage_name]['arguments']] if isinstance(ymldata[stage_name], dict) and 'arguments' in ymldata[stage_name] else [],
arguments=shlex.split(ymldata[stage_name]['arguments']) if isinstance(ymldata[stage_name], dict) and 'arguments' in ymldata[stage_name] else [],
)
except Exception as exc:
raise CurriculumLoaderError(f"could not complete the parse of stage '{stage_name}': {exc!s}") from exc
Expand Down Expand Up @@ -839,9 +839,13 @@ def main(args:argparse.Namespace) -> None:
# Make trainer listen to `kill -SIGUSR1 $PID` to print dataset progress
signal.signal(signal.SIGUSR1, lambda signum, handler: print_state(trainer.state()))

# Trainer is whatever is in config['trainer'], unless you specified it on
# the cli. If neither is specified, this will raise an KeyError.
model_trainer_cmd = args.trainer or shlex.split(config['trainer'])

while trainer.stage is not None:
model_trainer = subprocess.Popen(
(args.trainer or shlex.split(config['trainer'])) + trainer.stage.arguments,
model_trainer_cmd + trainer.stage.arguments,
stdin=subprocess.PIPE,
encoding="utf-8",
preexec_fn=ignore_sigint) # ignore_sigint makes marian ignore Ctrl-C. We'll stop it from here.
Expand Down
8 changes: 4 additions & 4 deletions tests/test_trainer_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,14 +54,14 @@ def test_early_stopping(self):
'clean 1.0',
'until clean inf'
],
'arguments': ['5000']
'arguments': '-n 5000'
},
'mid': {
'mix': [
'medium 1.0',
'until medium inf',
],
'arguments': ['10000']
'arguments': '-n 10000'
},
'seed': 1111
}
Expand All @@ -76,7 +76,7 @@ def test_early_stopping(self):
'--do-not-resume',
'--no-shuffle',
'--config', str(Path(tmp) / 'config.yml'),
'head', '-n', # plus value for n, per stage
'head', # plus value for -n, per stage
], stdout=fout, stderr=ferr)

retval = child.wait(30)
Expand All @@ -89,7 +89,7 @@ def test_early_stopping(self):
# Assert we got the number of lines we'd expect
line_count = sum(1 for _ in fout)
expected_line_count = sum(
int(config[stage]['arguments'][0])
int(config[stage]['arguments'][3:]) # interpret the `-n XXX`
for stage in config['stages']
)
self.assertEqual(line_count, expected_line_count)

0 comments on commit 0fd4742

Please sign in to comment.