-
Notifications
You must be signed in to change notification settings - Fork 17
/
Copy pathparser.py
155 lines (115 loc) · 6.13 KB
/
parser.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
import os
import sys
import torch
import datasets
import transformers
from typing import Any, Dict, Optional, Tuple
from transformers import HfArgumentParser
from transformers.trainer_utils import get_last_checkpoint
from utils.logging import get_logger
from args.data_args import DataArguments
from args.model_args import ModelArguments
from args.finetuning_args import FinetuningArguments
from args.training_args import TrainingArguments
from args.generation_args import GenerationArguments, InferArguments
# Logger shared by the helpers in this module.
logger = get_logger(__name__)

# Argument dataclasses registered with HfArgumentParser for training.
# get_train_args unpacks the parsed result in this same order.
_TRAIN_ARGS = [
    ModelArguments,
    DataArguments,
    TrainingArguments,
    FinetuningArguments,
    GenerationArguments,
]

# Return annotation for the training parsers: one slot per entry in _TRAIN_ARGS.
_TRAIN_CLS = Tuple[
    ModelArguments,
    DataArguments,
    TrainingArguments,
    FinetuningArguments,
    GenerationArguments,
]

# Inference registers the training dataclasses plus InferArguments at the end.
_INFER_ARGS = [*_TRAIN_ARGS, InferArguments]

# Return annotation for the inference parsers: one slot per entry in _INFER_ARGS.
_INFER_CLS = Tuple[
    ModelArguments,
    DataArguments,
    TrainingArguments,
    FinetuningArguments,
    GenerationArguments,
    InferArguments,
]
def parse_args(parser: "HfArgumentParser", args: Optional[Dict[str, Any]] = None) -> Tuple[Any, ...]:
    """Parse arguments from a dict, a single config file, or the command line.

    The source is chosen in this order:

    1. ``args`` when it is not ``None`` (passed to ``parser.parse_dict``).
    2. A lone command-line argument ending in ``.yaml``/``.yml``
       (``parser.parse_yaml_file``) or ``.json`` (``parser.parse_json_file``);
       the path is made absolute first.
    3. Otherwise ``parser.parse_args_into_dataclasses()``.

    Args:
        parser: Object exposing the four parser methods named above.
        args: Optional mapping of argument values; takes precedence over
            ``sys.argv`` when provided.

    Returns:
        Whatever the selected parser method returns; callers in this module
        unpack it as a tuple of parsed argument objects.
    """
    if args is not None:
        return parser.parse_dict(args)

    # A single positional argument is treated as a config file only when its
    # extension identifies the format; anything else goes to the CLI parser.
    if len(sys.argv) == 2:
        config_path = sys.argv[1]
        if config_path.endswith((".yaml", ".yml")):
            return parser.parse_yaml_file(os.path.abspath(config_path))
        if config_path.endswith(".json"):
            return parser.parse_json_file(os.path.abspath(config_path))

    return parser.parse_args_into_dataclasses()
def _verify_model_args(model_args: "ModelArguments", finetuning_args: "FinetuningArguments") -> None:
    """Reject configurations that pass several checkpoints without LoRA.

    Args:
        model_args: Provides ``checkpoint_dir`` (``None`` or a sized collection).
        finetuning_args: Provides ``finetuning_type``.

    Raises:
        ValueError: If ``checkpoint_dir`` is set, its length is not exactly
            one, and ``finetuning_type`` is not ``"lora"``.
    """
    checkpoints = model_args.checkpoint_dir

    # No checkpoint configured: nothing to validate.
    if checkpoints is None:
        return

    # Exactly one checkpoint is accepted for every fine-tuning type.
    if len(checkpoints) == 1:
        return

    # Any other count is accepted only for LoRA.
    if finetuning_args.finetuning_type == "lora":
        return

    raise ValueError("Multiple checkpoints are only available for LoRA tuning.")
def parse_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
    """Parse the training argument groups.

    Builds an ``HfArgumentParser`` over ``_TRAIN_ARGS`` and delegates to
    ``parse_args``.

    Args:
        args: Optional mapping of argument values, forwarded to ``parse_args``.

    Returns:
        The result of ``parse_args`` for the training dataclasses.
    """
    return parse_args(HfArgumentParser(_TRAIN_ARGS), args)
def parse_infer_args(args: Optional[Dict[str, Any]] = None) -> _INFER_CLS:
    """Parse the inference argument groups.

    Builds an ``HfArgumentParser`` over ``_INFER_ARGS`` and delegates to
    ``parse_args``.

    Args:
        args: Optional mapping of argument values, forwarded to ``parse_args``.

    Returns:
        The result of ``parse_args`` for the inference dataclasses.
    """
    return parse_args(HfArgumentParser(_INFER_ARGS), args)
def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
    """Parse, validate and post-process the arguments for a training run.

    Steps, in order: parse the argument groups, configure the ``datasets`` and
    ``transformers`` loggers, validate the ``predict_with_generate`` /
    checkpoint combinations, adjust ``training_args`` for LoRA under DDP and
    for automatic checkpoint resumption, derive ``model_args.compute_dtype``
    and ``model_args.model_max_length``, log a summary and set the seed.

    Args:
        args: Optional mapping of argument values, forwarded to
            ``parse_train_args``.

    Returns:
        ``(model_args, data_args, training_args, finetuning_args,
        generating_args)``. ``training_args`` may be a rebuilt instance when
        one of the post-processing steps applies.

    Raises:
        ValueError: On an invalid ``predict_with_generate`` combination, on
            multiple checkpoints without LoRA, or when the output directory
            exists, is not empty and holds no checkpoint to resume from.
    """
    model_args, data_args, training_args, finetuning_args, generating_args = parse_train_args(args)

    # --- Logging -----------------------------------------------------------
    hf_logging = transformers.utils.logging
    if training_args.should_log:
        # training_args.log_level defaults to passive, so start from info.
        hf_logging.set_verbosity_info()

    verbosity = training_args.get_process_log_level()
    datasets.utils.logging.set_verbosity(verbosity)
    hf_logging.set_verbosity(verbosity)
    hf_logging.enable_default_handler()
    hf_logging.enable_explicit_format()

    # --- Validation --------------------------------------------------------
    if finetuning_args.stage != "sft" and training_args.predict_with_generate:
        raise ValueError("`predict_with_generate` cannot be set as True except SFT.")

    if finetuning_args.stage == "sft" and training_args.do_predict and not training_args.predict_with_generate:
        raise ValueError("Please enable `predict_with_generate` to save model predictions.")

    if training_args.do_train and training_args.predict_with_generate:
        raise ValueError("`predict_with_generate` cannot be set as True while training.")

    _verify_model_args(model_args, finetuning_args)

    if training_args.do_train and (not training_args.fp16) and (not training_args.bf16):
        logger.warning("We recommend enable mixed precision training.")

    # --- Post-process training_args ----------------------------------------
    needs_ddp_override = (
        training_args.local_rank != -1
        and training_args.ddp_find_unused_parameters is None
        and finetuning_args.finetuning_type == "lora"
    )
    if needs_ddp_override:
        logger.warning("`ddp_find_unused_parameters` needs to be set as False for LoRA in DDP training.")
        overrides = training_args.to_dict()
        overrides["ddp_find_unused_parameters"] = False
        training_args = TrainingArguments(**overrides)

    may_auto_resume = (
        training_args.resume_from_checkpoint is None
        and training_args.do_train
        and os.path.isdir(training_args.output_dir)
        and not training_args.overwrite_output_dir
    )
    if may_auto_resume:
        found_checkpoint = get_last_checkpoint(training_args.output_dir)
        if found_checkpoint is None:
            # Nothing to resume from: refuse to write into a populated directory.
            if len(os.listdir(training_args.output_dir)) > 0:
                raise ValueError("Output directory already exists and is not empty. Please set `overwrite_output_dir`.")
        else:
            overrides = training_args.to_dict()
            overrides["resume_from_checkpoint"] = found_checkpoint
            training_args = TrainingArguments(**overrides)
            logger.info("Resuming training from {}. Change `output_dir` or use `overwrite_output_dir` to avoid.".format(
                training_args.resume_from_checkpoint
            ))

    # --- Post-process model_args -------------------------------------------
    # bf16 takes priority over fp16; neither flag leaves the dtype unset.
    if training_args.bf16:
        compute_dtype = torch.bfloat16
    elif training_args.fp16:
        compute_dtype = torch.float16
    else:
        compute_dtype = None
    model_args.compute_dtype = compute_dtype
    model_args.model_max_length = data_args.cutoff_len

    # Short per-process summary.
    logger.info("Process rank: {}, device: {}, n_gpu: {}\n distributed training: {}, compute dtype: {}".format(
        training_args.local_rank, training_args.device, training_args.n_gpu,
        bool(training_args.local_rank != -1), str(model_args.compute_dtype)
    ))
    logger.info(f"Training/evaluation parameters {training_args}")

    # Seed before any model is initialized.
    transformers.set_seed(training_args.seed)

    return model_args, data_args, training_args, finetuning_args, generating_args
def get_infer_args(args: Optional[Dict[str, Any]] = None) -> _INFER_CLS:
    """Parse and validate the arguments for an inference run.

    Args:
        args: Optional mapping of argument values, forwarded to
            ``parse_infer_args``.

    Returns:
        ``(model_args, data_args, training_args, finetuning_args,
        generating_args, inference_args)`` exactly as parsed.

    Raises:
        ValueError: If ``data_args.template`` is ``None``, or if
            ``_verify_model_args`` rejects the checkpoint configuration.
    """
    parsed = parse_infer_args(args)
    model_args, data_args, training_args, finetuning_args, generating_args, inference_args = parsed

    # The template check runs before the checkpoint check.
    template_missing = data_args.template is None
    if template_missing:
        raise ValueError("Please specify which `template` to use.")

    _verify_model_args(model_args, finetuning_args)

    return (
        model_args,
        data_args,
        training_args,
        finetuning_args,
        generating_args,
        inference_args,
    )