forked from ZonePG/cs-notes
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path01-HfArgumentParser.py
59 lines (50 loc) · 1.87 KB
/
01-HfArgumentParser.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
from transformers import HfArgumentParser
from typing import Literal, Optional
from dataclasses import dataclass, field
from typing import Any, Dict, Optional, Tuple
@dataclass
class ModelArguments:
    r"""
    Arguments pertaining to which model/config/tokenizer we are going to fine-tune.
    """
    # Required: path or hub identifier of the base model to load.
    model_name_or_path: str = field(
        metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models."}
    )
    # Optional quantization precision; only 4 and 8 are accepted (validated in __post_init__).
    quantization_bit: Optional[int] = field(
        default=None,
        metadata={"help": "The number of bits to quantize the model."}
    )
    # Optional RoPE scaling strategy; restricted to the two supported modes.
    rope_scaling: Optional[Literal["linear", "dynamic"]] = field(
        default=None,
        metadata={"help": "Adopt scaled rotary positional embeddings."}
    )

    def __post_init__(self):
        # Raise an explicit exception instead of using `assert`: asserts are
        # stripped when Python runs with -O, which would silently skip this
        # validation.
        if self.quantization_bit is not None and self.quantization_bit not in (4, 8):
            raise ValueError("We only accept 4-bit or 8-bit quantization.")
@dataclass
class DataArguments:
    r"""
    Arguments pertaining to what data we are going to input our model for training and evaluation.
    """
    # Comma-separated dataset name(s); defaults to the English Alpaca set.
    dataset: Optional[str] = field(
        metadata={"help": "The name of provided dataset(s) to use. Use commas to separate multiple datasets."},
        default="alpaca_en",
    )
    # Folder that holds the dataset files.
    dataset_dir: Optional[str] = field(
        metadata={"help": "The name of the folder containing datasets."},
        default="data",
    )
    # Dataset split used for both training and evaluation.
    split: Optional[str] = field(
        metadata={"help": "Which dataset split to use for training and evaluation."},
        default="train",
    )
def parse_args() -> Tuple[ModelArguments, DataArguments]:
    """Parse command-line arguments into the model and data argument groups."""
    argument_classes = (ModelArguments, DataArguments)
    return HfArgumentParser(argument_classes).parse_args_into_dataclasses()
if __name__ == "__main__":
    # parse_args() constructs its own HfArgumentParser internally; the extra
    # parser previously built here was never used, so it has been removed.
    model_args, data_args = parse_args()
    print(model_args, data_args)