add fake tokenizer for running test without hf token
samsja committed Nov 22, 2024
1 parent 07e9253 commit 282eef2
Showing 2 changed files with 12 additions and 1 deletion.
src/zeroband/train.py: 4 additions & 1 deletion

@@ -21,6 +21,7 @@
 from zeroband.loss import cross_entropy_max_z_loss
 
 from zeroband.utils import (
+    FakeTokenizer,
     GPUMemoryMonitor,
     PerfCounter,
     get_module_signature,
@@ -137,7 +138,9 @@ def train(config: Config):
             config.ckpt.interval % config.diloco.inner_steps == 0
         ), "ckpt interval must be a multiple of diloco inner steps as we only save at the end of an outer step"
 
-    if config.type_model == "llama2":
+    if config.data.fake and config.name_model == "debugmodel":
+        tokenizer = FakeTokenizer()
+    elif config.type_model == "llama2":
         tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1", use_fast=True)
     elif config.type_model == "llama3":
         tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B", use_fast=True)
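For context, the new branch only bypasses the Hugging Face tokenizers when both fake data and the debug model are configured; real llama2/llama3 runs still download their tokenizers as before. Below is a minimal sketch of the selection logic this hunk adds, pulled out into a standalone function for readability. The helper name select_tokenizer and the final else branch are assumptions for illustration, not part of the diff.

from transformers import AutoTokenizer

from zeroband.utils import FakeTokenizer


def select_tokenizer(config):
    # Hypothetical helper mirroring the branch added to train().
    if config.data.fake and config.name_model == "debugmodel":
        # Debug/test path: no HF token or network access required.
        return FakeTokenizer()
    elif config.type_model == "llama2":
        return AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1", use_fast=True)
    elif config.type_model == "llama3":
        return AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B", use_fast=True)
    else:
        # Assumed fallback; the real train() may handle this differently.
        raise ValueError(f"unknown type_model: {config.type_model}")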
src/zeroband/utils/__init__.py: 8 additions & 0 deletions

@@ -229,3 +229,11 @@ def get_random_available_port_list(num_port):
 
 def get_random_available_port():
     return get_random_available_port_list(1)[0]
+
+
+class FakeTokenizer(object):
+    def __init__(self):
+        self.vocab_size = 1000
+        self.bos_token_id = 0
+        self.eos_token_id = 1
+        self.pad_token_id = 2
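The FakeTokenizer exposes only the attributes the fake-data path actually reads: a vocabulary size and the special-token ids. As a purely illustrative sketch (the fake_batch helper below is hypothetical and not how zeroband's fake data loader is necessarily implemented), this is enough to build batches of random token ids that stay within the debug model's embedding range.

import torch

from zeroband.utils import FakeTokenizer


def fake_batch(tokenizer: FakeTokenizer, batch_size: int = 2, seq_len: int = 8) -> torch.Tensor:
    # Sample random token ids inside the fake vocabulary so the debug model's
    # embedding layer (sized to vocab_size) never sees an out-of-range index.
    tokens = torch.randint(0, tokenizer.vocab_size, (batch_size, seq_len))
    tokens[:, 0] = tokenizer.bos_token_id   # start each sequence with BOS
    tokens[:, -1] = tokenizer.eos_token_id  # end each sequence with EOS
    return tokens


tokenizer = FakeTokenizer()
batch = fake_batch(tokenizer)  # shape (2, 8), values in [0, 1000)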
