train.sh
#!/bin/bash
set -xe
echo "Starting training..."
# Use PEFT or not? That is the question.
PEFT_PARAMS=""
if [ "$USE_PEFT" = 1 ] ; then
PEFT_PARAMS="--use_peft"
fi
# Use 4-bit or 8-bit LORA? That is the other question.
LORA_PARAMS=""
if [ "$USE_4BIT" = 1 ] ; then
LORA_PARAMS="--load_in_4bit"
elif [ "$USE_8BIT" = 1 ] ; then
LORA_PARAMS="--load_in_8bit"
fi
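# Push the trained model to the Hugging Face Hub? Requires OUTPUT_MODEL_NAME.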
HUB_PARAMS=""
if [ "$PUSH_TO_HUB" = 1 ] ; then
HUB_PARAMS="--push_to_hub --hub_model_id ${OUTPUT_MODEL_NAME}"
fi
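# Log training metrics to Weights & Biases?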
WANDB_LOG_PARAMS=""
if [ "$LOG_TO_WANDB" = 1 ] ; then
WANDB_LOG_PARAMS="--log_with wandb"
fi
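# Training hyper-parameters, overridable from the environment
# (GAS = gradient accumulation steps).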
BATCH_SIZE="${BATCH_SIZE:-4}"
GAS="${GAS:-2}"
EPOCHS="${EPOCHS:-1}"
LEARNING_RATE="${LEARNING_RATE:-1.41e-5}"
SEQUENCE_LENGTH="${SEQUENCE_LENGTH:-512}"
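# Run supervised fine-tuning; on success, export the model and quantize it.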
python ${HOMEDIR}/sft_train.py \
--model_name "${BASE_MODEL_NAME}" \
--dataset_name "${DATASET_NAME}" \
${LORA_PARAMS} \
${PEFT_PARAMS} \
--batch_size ${BATCH_SIZE} \
--num_train_epochs ${EPOCHS} \
--learning_rate ${LEARNING_RATE} \
--gradient_accumulation_steps ${GAS} \
--seq_length ${SEQUENCE_LENGTH} \
--output_dir "${OUTPUT_MODEL_PATH}" \
${WANDB_LOG_PARAMS} \
${HUB_PARAMS} && \
${HOMEDIR}/export.sh && \
${HOMEDIR}/quantize.sh || { echo "Training pipeline failed." >&2; exit 1; }
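# Example invocation (the model, dataset and path values below are illustrative,
# not taken from this repository):
#   HOMEDIR=/opt/trainer BASE_MODEL_NAME=facebook/opt-350m \
#   DATASET_NAME=timdettmers/openassistant-guanaco \
#   OUTPUT_MODEL_PATH=/models/out USE_PEFT=1 USE_4BIT=1 ./train.sh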