diff --git a/magnet/utils/globals.py b/magnet/utils/globals.py
index 09cb6cf..cd4b825 100644
--- a/magnet/utils/globals.py
+++ b/magnet/utils/globals.py
@@ -5,7 +5,38 @@
 import boto3
 from spacy.lang.en import English
 import inspect
+from transformers import AutoTokenizer
+def break_into_chunks(text, model_name, max_tokens):
+    """
+    Break a text into chunks of a specified max number of tokens using the tokenizer of a given model.
+
+    Parameters:
+    - text (str): The text to break into chunks.
+    - model_name (str): The model or tokenizer name to use for tokenizing the text.
+    - max_tokens (int): The maximum number of tokens per chunk.
+
+    Returns:
+    - List[str]: A list of text chunks, each with up to max_tokens tokens.
+    """
+    # Load the tokenizer for the specified model
+    tokenizer = AutoTokenizer.from_pretrained(model_name)
+
+    tokens = tokenizer.tokenize(text)
+
+    current_chunk = []
+    chunks = []
+
+    for token in tokens:
+        current_chunk.append(token)
+        if len(current_chunk) == max_tokens:
+            chunks.append(tokenizer.convert_tokens_to_string(current_chunk))
+            current_chunk = []
+
+    if current_chunk:
+        chunks.append(tokenizer.convert_tokens_to_string(current_chunk))
+
+    return chunks
 
 
 def reversal():
     return inspect.getsource(inspect.currentframe().f_back)
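
For review context, a minimal usage sketch of the new helper; the model name `bert-base-uncased`, the sample text, and `max_tokens=64` are illustrative values, not anything defined in this patch:

```python
from magnet.utils.globals import break_into_chunks

# Illustrative values only: any Hugging Face tokenizer name works here.
text = "Some long passage of text that needs to be split for embedding. " * 50
chunks = break_into_chunks(text, model_name="bert-base-uncased", max_tokens=64)

# Each chunk decodes back to a string built from at most 64 tokens.
for i, chunk in enumerate(chunks):
    print(f"chunk {i}: {len(chunk)} characters")
```

One caveat worth noting: because chunks are rebuilt with `convert_tokens_to_string` from subword tokens, whitespace and casing may not round-trip exactly for every tokenizer, so the concatenated chunks are not guaranteed to reproduce the input byte-for-byte.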