add gpu ci #164

Merged (3 commits, Nov 22, 2024)
39 changes: 39 additions & 0 deletions .github/workflows/gpu.yml
@@ -0,0 +1,39 @@
name: Tests on GPU

on:
  push:
    branches:
      - main
  pull_request:
    # This will trigger the workflow for pull requests to any branch
    types: [opened, synchronize, reopened]

jobs:
  gpu-tests:
    name: python
    runs-on: self-hosted

    steps:
      - uses: actions/checkout@v4
        with:
          submodules: true

      - name: Install uv
        uses: astral-sh/setup-uv@v3
        with:
          enable-cache: true
          cache-dependency-glob: "uv.lock"

      - name: Set up Python
        run: uv python install 3.10.13

      - name: Install the project
        run: uv sync --all-extras --dev

      - name: Install flash attention
        run: uv pip install flash-attn --no-build-isolation

      - name: Run tests
        run: uv run pytest tests
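The workflow runs on a self-hosted runner so the suite can see real GPUs. As a quick sanity check that the runner actually exposes a CUDA device, a test along these lines could sit in the suite; the file name and test body are illustrative, not part of this PR.

import pytest
import torch

# Hypothetical smoke test: skip on machines without CUDA, otherwise confirm
# that at least one GPU is visible to the process running the tests.
@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires a CUDA device")
def test_cuda_device_visible():
    assert torch.cuda.device_count() >= 1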
5 changes: 4 additions & 1 deletion src/zeroband/train.py
@@ -21,6 +21,7 @@
 from zeroband.loss import cross_entropy_max_z_loss

 from zeroband.utils import (
+    FakeTokenizer,
     GPUMemoryMonitor,
     PerfCounter,
     get_module_signature,
@@ -137,7 +138,9 @@ def train(config: Config):
             config.ckpt.interval % config.diloco.inner_steps == 0
         ), "ckpt interval must be a multiple of diloco inner steps as we only save at the end of an outer step"

-    if config.type_model == "llama2":
+    if config.data.fake and config.name_model == "debugmodel":
+        tokenizer = FakeTokenizer()
+    elif config.type_model == "llama2":
         tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1", use_fast=True)
     elif config.type_model == "llama3":
         tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B", use_fast=True)
8 changes: 8 additions & 0 deletions src/zeroband/utils/__init__.py
@@ -229,3 +229,11 @@ def get_random_available_port_list(num_port):

 def get_random_available_port():
     return get_random_available_port_list(1)[0]
+
+
+class FakeTokenizer(object):
+    def __init__(self):
+        self.vocab_size = 1000
+        self.bos_token_id = 0
+        self.eos_token_id = 1
+        self.pad_token_id = 2
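FakeTokenizer only carries the handful of attributes the training loop reads (vocab size and special token ids), so debug batches can be drawn without touching Hugging Face. A minimal sketch of how such a batch could be built; the (8, 128) shape is arbitrary and the real fake dataloader in zeroband may differ.

import torch

from zeroband.utils import FakeTokenizer

tokenizer = FakeTokenizer()
# Draw random token ids in [0, vocab_size) for an illustrative (batch, seq) shape.
batch = torch.randint(low=0, high=tokenizer.vocab_size, size=(8, 128))
assert int(batch.max()) < tokenizer.vocab_size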
6 changes: 5 additions & 1 deletion tests/test_torchrun/test_train.py
@@ -6,6 +6,10 @@

 from zeroband.diloco import Compression

+import torch
+
+num_gpu = torch.cuda.device_count()
+

 def get_random_available_port_list(num_port):
     # https://stackoverflow.com/questions/1365265/on-localhost-how-do-i-pick-a-free-port-number
@@ -77,7 +81,7 @@ def test_multi_gpu(num_gpus):
     _test_multi_gpu(num_gpus, "debug/normal.toml")


-@pytest.mark.parametrize("num_gpus", [[2, 1], [2, 2]])
+@pytest.mark.parametrize("num_gpus", [[2, 1], [2, 2]] if num_gpu >= 4 else [[2, 1]])
 def test_multi_gpu_diloco(num_gpus):
     _test_multi_gpu(num_gpus, "debug/diloco.toml", diloco=True)

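The trimmed parametrize list keeps the [2, 2] DiLoCo case off machines with fewer than four GPUs. The same gating could instead be expressed as a skip marker, which keeps the full parameter list visible in test reports; this is a sketch of an alternative, not what the PR does, and the helper name is hypothetical.

import pytest
import torch

requires_four_gpus = pytest.mark.skipif(
    torch.cuda.device_count() < 4,
    reason="the [2, 2] DiLoCo case needs at least four GPUs",
)

@requires_four_gpus
@pytest.mark.parametrize("num_gpus", [[2, 2]])
def test_multi_gpu_diloco_four_gpus(num_gpus):
    # Hypothetical split of the >= 4 GPU case into its own test function.
    _test_multi_gpu(num_gpus, "debug/diloco.toml", diloco=True)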