Welcome! This guide will walk you through the Tune Factory application. The application uses llama-factory for tuning.
Before you start, make sure you have the following:
- An active Google Cloud account.
- A Google Cloud Storage (GCS) bucket where you'll store your datasets and training outputs.
- The necessary permissions to access and use the required Google Cloud Vertex AI services.
- An llama-factory docker image
All API requests will be sent to the following base URL:
http://localhost:8000
You can replace localhost
with your server's address if you're running this remotely.
Your datasets, training configurations, and model outputs will be stored in your GCS bucket. Let's call it shkhose-tune-factory
for now, but feel free to replace this with your actual bucket name.
To upload a dataset, you'll send a POST
request to /datasets/upload
. This request should include your dataset file in multipart/form-data
format. Here's how you do it:
POST http://localhost:8000/datasets/upload
Content-Type: multipart/form-data; boundary=----WebKitFormBoundary7MA4YWxkTrZu0gW
------WebKitFormBoundary7MA4YWxkTrZu0gW
Content-Disposition: form-data; name="file"; filename="alpaca_en_demo.json"
Content-Type: application/json
< ./alpaca_en_demo.json
------WebKitFormBoundary7MA4YWxkTrZu0gW--
After sending this request, you'll receive a response containing the GCS URL of your uploaded dataset. It'll look something like this:
{
"gcs_url": "gs://shkhose-tune-factory/datasets/alpaca_en_demo.json"
}
This GCS URL will be important for future steps, so keep it handy!
To see all the datasets you've uploaded, simply send a GET
request to /datasets
:
GET http://localhost:8000/datasets
This will return a list of all your datasets.
If you need to retrieve a specific dataset, you can use its GCS URL. Send a GET
request to /datasets/{gcs_url}
, replacing {gcs_url}
with the actual GCS URL of your dataset:
GET http://localhost:8000/datasets/gs://shkhose-tune-factory/datasets/alpaca_en_demo.json
Before you can start training, you need to generate a training configuration. This is done by sending a POST
request to /training/generate_config
with a JSON payload specifying your training parameters.
Here's an example:
POST http://localhost:8000/training/generate_config
Content-Type: application/json
{
"dataset_dir": "/gcs/shkhose-tune-factory/datasets",
"dataset": "alpaca_en_demo",
"model_name_or_path": "meta-llama/Meta-Llama-3-8B-Instruct",
"output_dir": "/gcs/shkhose-tune-factory/meta-llama/saves/llama3-8b/lora/sft",
"training_config": {
"learning_rate": "1.0e-4",
"template": "llama3",
"stage": "sft",
"do_train": "true",
"finetuning_type": "lora",
"lora_target": "all",
"per_device_train_batch_size": "1",
"gradient_accumulation_steps": "8",
"num_train_epochs": "1.0",
"lr_scheduler_type": "cosine",
"warmup_ratio": "0.1",
"bf16": "true",
"ddp_timeout": "180000000",
"val_size": "0.1",
"per_device_eval_batch_size": "1",
"eval_strategy": "steps",
"eval_steps": "500"
}
}
This will generate a configuration file in your GCS bucket, and you'll receive a response with its GCS URL.
To view all your training configurations, send a GET
request to /training/configs
:
GET http://localhost:8000/training/configs
With your training configuration ready, you can start a training job. Send a POST
request to /training/start
with the GCS URL of your configuration file:
POST http://localhost:8000/training/start
Content-Type: application/json
{
"config_gcs_url": "gs://shkhose-tune-factory/training_configs/training_config_1178797c.yaml"
}
To check the status of your training job, you'll need the job ID. Send a GET
request to /training/status/{job_id}
, replacing {job_id}
with your job's ID:
GET http://localhost:8000/training/status/3604209251972546560
Once your model is trained, you can deploy it using vLLM. Send a POST
request to /deployment/deploy_vllm
with the necessary deployment parameters:
POST http://localhost:8000/deployment/deploy_vllm
Content-Type: application/json
{
"model_name": "llama-3-vllm-model-api",
"model_id": "meta-llama/Meta-Llama-3-8B-Instruct",
"service_account": "<123>[email protected]",
"machine_type": "g2-standard-8",
"accelerator_type": "NVIDIA_L4",
"accelerator_count": "1",
"gpu_memory_utilization": "0.9",
"max_model_len": "4096",
"dtype": "auto",
"enable_trust_remote_code": "false",
"enforce_eager": "false",
"enable_lora": "true",
"max_loras": "1",
"max_cpu_loras": "8",
"use_dedicated_endpoint": "false",
"max_num_seqs": "256"
}
You'll receive a response containing the endpoint ID of your deployed model.
To check the deployment status, use the endpoint ID you received. Send a GET
request to /deployment/status/{endpoint_id}
:
GET http://localhost:8000/deployment/status/5712755654179946496
python predict-vllm.py
Enter your prompt: what is a car
Prompt:
what is a car?
Output:
a car is a vehicle that is powered by an internal combustion engine or electric motor and is designed for transporting people or goods. cars are typically designed for use on roads and highways, and are often equipped with features such as air conditioning, radio, and