diff --git a/eng/Versions.props b/eng/Versions.props index aeb07aebbe..b494d4d1b3 100644 --- a/eng/Versions.props +++ b/eng/Versions.props @@ -41,7 +41,7 @@ 1.0.0-beta.23509.3 1.16.3 0.0.0.12 - 2023.0.0.23189 @@ -87,6 +87,7 @@ 0.0.13-test 0.0.6-test 0.0.7-test + 2.0.0-beta.24218.2 4.8.6 1.0.118 1.2.7 diff --git a/src/Microsoft.ML.Tokenizers/EncodingResult.cs b/src/Microsoft.ML.Tokenizers/EncodingResult.cs deleted file mode 100644 index 9401cfa490..0000000000 --- a/src/Microsoft.ML.Tokenizers/EncodingResult.cs +++ /dev/null @@ -1,152 +0,0 @@ -// Licensed to the .NET Foundation under one or more agreements. -// The .NET Foundation licenses this file to you under the MIT license. -// See the LICENSE file in the project root for more information. - -using System; -using System.Collections.Generic; -using System.Text; - -namespace Microsoft.ML.Tokenizers -{ - /// - /// The Encoding represents the output of a Tokenizer. - /// - public sealed class EncodingResult - { - /// - /// Create a new object of the EncodingResult object. - /// - /// The list of tokens to merge. - /// The list of tokens to merge. - /// The list of tokens to merge. - /// Indicate whether the offsets is mapped to the original string or the normalized string. - public EncodingResult(string originalString, string normalizedString, IEnumerable splits, bool offsetsMappedToOriginalString) - { - OriginalString = originalString; - NormalizedString = normalizedString; - Splits = splits; - OffsetsMappedToOriginalString = offsetsMappedToOriginalString; - } - - /// - /// Gets the original tokenized string. - /// - public string? OriginalString { get; } - - /// - /// Gets the normalized form of the original string. - /// - public string? NormalizedString { get; } - - /// - /// Gets the normalized form of the original string. - /// - public bool OffsetsMappedToOriginalString { get; } - - internal IEnumerable Splits { get; } - private List? _tokens; - private List? _tokensWords; - private List? _ids; - private List<(int Index, int Length)>? _offsets; - - internal void AddTokens(IReadOnlyList addedTokens) - { - if (_tokens is null) - { - _tokens = new(addedTokens); - return; - } - - foreach (var token in addedTokens) - { - _tokens.Add(token); - } - } - - /// - /// Gets list of the tokens Ids. - /// The Ids are the main input to a Language Model. They are the token indices, the numerical representations that a LM understands. - /// - public IReadOnlyList Ids - { - get - { - if (_ids is not null) - { - return _ids; - } - - if (_tokens is null) - { - return Array.Empty(); - } - - _ids = new List(_tokens.Count); - - foreach (var token in _tokens) - { - _ids.Add(token.Id); - } - - return _ids; - } - } - - /// - /// Gets the generated tokens. They are the string representation of the Ids. - /// - public IReadOnlyList Tokens - { - get - { - if (_tokensWords is not null) - { - return _tokensWords; - } - - if (_tokens is null) - { - return Array.Empty(); - } - - _tokensWords = new List(_tokens.Count); - - foreach (var token in _tokens) - { - _tokensWords.Add(token.Value); - } - - return _tokensWords; - } - } - - /// - /// Gets The list of offsets. These offsets let's you slice the input string, and thus retrieve - /// the original part that led to producing the corresponding token. - /// - public IReadOnlyList<(int Index, int Length)> Offsets - { - get - { - if (_offsets is not null) - { - return _offsets; - } - - if (_tokens is null) - { - return Array.Empty<(int, int)>(); - } - - _offsets = new List<(int Index, int Length)>(_tokens.Count); - - foreach (var token in _tokens) - { - _offsets.Add(token.Offset); - } - - return _offsets; - } - } - } -} diff --git a/src/Microsoft.ML.Tokenizers/Model/BPE.cs b/src/Microsoft.ML.Tokenizers/Model/BPE.cs index 5a8fb4d330..ad6627da8c 100644 --- a/src/Microsoft.ML.Tokenizers/Model/BPE.cs +++ b/src/Microsoft.ML.Tokenizers/Model/BPE.cs @@ -17,13 +17,15 @@ namespace Microsoft.ML.Tokenizers /// /// Represent the Byte Pair Encoding model. /// - public sealed class Bpe : Model + public sealed class Bpe : Tokenizer { /// A [Byte Pair Encoding](https://www.aclweb.org/anthology/P16-1162/) model. private const int MaxWordLengthToCache = 15; private string? _unknownToken; private int? _unknownTokenId; + private readonly PreTokenizer? _preTokenizer; + private readonly Normalizer? _normalizer; /// /// Gets or Sets unknown token. The unknown token to be used when we encounter an unknown char @@ -74,13 +76,15 @@ private set /// /// The JSON file path containing the dictionary of string keys and their ids. /// The file path containing the tokens's pairs list. + /// The pre-tokenizer to use. + /// The normalizer to use. /// The unknown token to be used by the model. /// The prefix to attach to sub-word units that don’t represent a beginning of word. /// The suffix to attach to sub-word units that represent an end of word. /// Indicate whether allowing multiple unknown tokens get fused. - public Bpe(string vocabFile, string? mergesFile, string? unknownToken = null, string? continuingSubwordPrefix = null, string? endOfWordSuffix = null, bool fuseUnknownTokens = false) : + public Bpe(string vocabFile, string? mergesFile, PreTokenizer? preTokenizer = null, Normalizer? normalizer = null, string? unknownToken = null, string? continuingSubwordPrefix = null, string? endOfWordSuffix = null, bool fuseUnknownTokens = false) : this(vocabFile is null ? throw new ArgumentNullException(nameof(vocabFile)) : File.Open(vocabFile, FileMode.Open, FileAccess.Read), - mergesFile is null ? null : File.Open(mergesFile, FileMode.Open, FileAccess.Read), unknownToken, continuingSubwordPrefix, endOfWordSuffix, fuseUnknownTokens, disposeStreams: true) + mergesFile is null ? null : File.Open(mergesFile, FileMode.Open, FileAccess.Read), preTokenizer, normalizer, unknownToken, continuingSubwordPrefix, endOfWordSuffix, fuseUnknownTokens, disposeStreams: true) { } @@ -89,16 +93,18 @@ public Bpe(string vocabFile, string? mergesFile, string? unknownToken = null, st /// /// The JSON stream containing the dictionary of string keys and their ids. /// The stream containing the tokens's pairs list. + /// The pre-tokenizer to use. + /// The normalizer to use. /// The unknown token to be used by the model. /// The prefix to attach to sub-word units that don’t represent a beginning of word. /// The suffix to attach to sub-word units that represent an end of word. /// Indicate whether allowing multiple unknown tokens get fused. - public Bpe(Stream vocabStream, Stream? mergesStream, string? unknownToken = null, string? continuingSubwordPrefix = null, string? endOfWordSuffix = null, bool fuseUnknownTokens = false) : - this(vocabStream, mergesStream, unknownToken, continuingSubwordPrefix, endOfWordSuffix, fuseUnknownTokens, disposeStreams: false) + public Bpe(Stream vocabStream, Stream? mergesStream, PreTokenizer? preTokenizer = null, Normalizer? normalizer = null, string? unknownToken = null, string? continuingSubwordPrefix = null, string? endOfWordSuffix = null, bool fuseUnknownTokens = false) : + this(vocabStream, mergesStream, preTokenizer, normalizer, unknownToken, continuingSubwordPrefix, endOfWordSuffix, fuseUnknownTokens, disposeStreams: false) { } - private Bpe(Stream vocabStream, Stream? mergesStream, string? unknownToken, string? continuingSubwordPrefix, string? endOfWordSuffix, bool fuseUnknownTokens, bool disposeStreams) + private Bpe(Stream vocabStream, Stream? mergesStream, PreTokenizer? preTokenizer, Normalizer? normalizer, string? unknownToken, string? continuingSubwordPrefix, string? endOfWordSuffix, bool fuseUnknownTokens, bool disposeStreams) { try { @@ -110,6 +116,8 @@ private Bpe(Stream vocabStream, Stream? mergesStream, string? unknownToken, stri FuseUnknownTokens = fuseUnknownTokens; ContinuingSubwordPrefix = continuingSubwordPrefix; EndOfWordSuffix = endOfWordSuffix; + _preTokenizer = preTokenizer ?? WhiteSpace.Instance; // Default to WhiteSpace pre-tokenizer + _normalizer = normalizer; (Dictionary? vocab1, Vec<(string, string)> merges) = ReadModelData(vocabStream, mergesStream); _vocab = vocab1 ?? new Dictionary(); @@ -166,47 +174,320 @@ private Bpe(Stream vocabStream, Stream? mergesStream, string? unknownToken, stri } /// - /// Encode a text string to a list of tokens. + /// Gets the PreTokenizer used by the Tokenizer. + /// + public override PreTokenizer? PreTokenizer => _preTokenizer; + + /// + /// Gets the Normalizer in use by the Tokenizer. + /// + public override Normalizer? Normalizer => _normalizer; + + /// + /// Encodes input text a list of s with string value of the token, id, and offset. /// /// The text to encode. - /// The list of tokens generated from the text tokenization. - public override IReadOnlyList Encode(ReadOnlySpan text) + /// If the tokenizer's normalization is enabled or is false, this will be set to in its normalized form; otherwise, this value will be set to null. + /// Indicate whether to consider pre-tokenization before tokenization. + /// Indicate whether to consider normalization before tokenization. + /// The tokenization result includes a list of s with string value of the token, id, and offset. + public override IReadOnlyList Encode(string text, out string? normalizedString, bool considerPreTokenization = true, bool considerNormalization = true) => Encode(text, Span.Empty, out normalizedString, considerPreTokenization, considerNormalization); + + /// + /// Encodes input text a list of s with string value of the token, id, and offset. + /// + /// The text to encode. + /// If the tokenizer's normalization is enabled or is false, this will be set to in its normalized form; otherwise, this value will be set to null. + /// Indicate whether to consider pre-tokenization before tokenization. + /// Indicate whether to consider normalization before tokenization. + /// The tokenization result includes a list of s with string value of the token, id, and offset. + public override IReadOnlyList Encode(ReadOnlySpan text, out string? normalizedString, bool considerPreTokenization = true, bool considerNormalization = true) => Encode(null, text, out normalizedString, considerPreTokenization, considerNormalization); + + private IReadOnlyList Encode(string? text, ReadOnlySpan textSpan, out string? normalizedString, bool considerPreTokenization, bool considerNormalization) { - if (text.Length == 0) + if (string.IsNullOrEmpty(text) && textSpan.IsEmpty) + { + normalizedString = null; + return []; + } + + IEnumerable<(int Offset, int Length)>? splits = InitializeForEncoding(text, textSpan, considerPreTokenization, considerNormalization, _normalizer, _preTokenizer, out normalizedString, out ReadOnlySpan textSpanToEncode); + + List tokens = new(); + PriorityQueue? priorityQueue = null; + + if (splits is not null) + { + foreach ((int Offset, int Length) split in splits) + { + EncodeWithCache(textSpanToEncode.Slice(split.Offset, split.Length), tokens, split.Offset, ref priorityQueue); + } + } + else { - return EmptyTokensList; + EncodeWithCache(textSpanToEncode, tokens, 0, ref priorityQueue); } - return EncodeWithCache(text); + return tokens; } /// - /// Encode a split text string to a list of Ids and add them to the accumulatedIds list. + /// Encodes input text to token Ids. /// /// The text to encode. - /// The list of accumulated encoded Ids. + /// Indicate whether to consider pre-tokenization before tokenization. + /// Indicate whether to consider normalization before tokenization. + /// The list of encoded Ids. + public override IReadOnlyList EncodeToIds(string text, bool considerPreTokenization = true, bool considerNormalization = true) + => EncodeToIds(text, Span.Empty, considerPreTokenization, considerNormalization, out _, out _); + + /// + /// Encodes input text to token Ids. + /// + /// The text to encode. + /// Indicate whether to consider pre-tokenization before tokenization. + /// Indicate whether to consider normalization before tokenization. + /// The list of encoded Ids. + public override IReadOnlyList EncodeToIds(ReadOnlySpan text, bool considerPreTokenization = true, bool considerNormalization = true) + => EncodeToIds(null, text, considerPreTokenization, considerNormalization, out _, out _); + + /// + /// Encodes input text to token Ids up to maximum number of tokens. + /// + /// The text to encode. + /// The maximum number of tokens to encode. + /// If the tokenizer's normalization is enabled or is false, this will be set to in its normalized form; otherwise, this value will be set to null. /// The length of the text that encompasses the maximum encoded tokens. - /// The maximum number of tokens to encode. - /// The number of tokens that the input text will be encoded to. - public override int EncodeToIds(ReadOnlySpan text, IList accumulatedIds, out int textLength, int maxTokens = int.MaxValue) => EncodeToIdsWithCache(text, accumulatedIds, maxTokens, out textLength); + /// Indicate whether to consider pre-tokenization before tokenization. + /// Indicate whether to consider normalization before tokenization. + /// The list of encoded Ids. + public override IReadOnlyList EncodeToIds(string text, int maxTokenCount, out string? normalizedString, out int textLength, bool considerPreTokenization = true, bool considerNormalization = true) + => EncodeToIds(text, Span.Empty, considerPreTokenization, considerNormalization, out normalizedString, out textLength, maxTokenCount); /// - /// Get the number of tokens that the input text will be encoded to. + /// Encodes input text to token Ids up to maximum number of tokens. /// /// The text to encode. + /// The maximum number of tokens to encode. + /// If the tokenizer's normalization is enabled or is false, this will be set to in its normalized form; otherwise, this value will be set to null. /// The length of the text that encompasses the maximum encoded tokens. - /// The maximum number of tokens to encode. - /// The number of tokens that the input text will be encoded to. - public override int CountTokens(ReadOnlySpan text, out int textLength, int maxTokens = int.MaxValue) => EncodeToIdsWithCache(text, null, maxTokens, out textLength); + /// Indicate whether to consider pre-tokenization before tokenization. + /// Indicate whether to consider normalization before tokenization. + /// The list of encoded Ids. + public override IReadOnlyList EncodeToIds(ReadOnlySpan text, int maxTokenCount, out string? normalizedString, out int textLength, bool considerPreTokenization = true, bool considerNormalization = true) + => EncodeToIds(null, text, considerPreTokenization, considerNormalization, out normalizedString, out textLength, maxTokenCount); + + private IReadOnlyList EncodeToIds(string? text, ReadOnlySpan textSpan, bool considerPreTokenization, bool considerNormalization, out string? normalizedString, out int textLength, int maxTokenCount = int.MaxValue) + { + if (maxTokenCount <= 0) + { + throw new ArgumentOutOfRangeException(nameof(maxTokenCount), "The maximum number of tokens must be greater than zero."); + } + + if (string.IsNullOrEmpty(text) && textSpan.IsEmpty) + { + textLength = 0; + normalizedString = null; + return []; + } + + IEnumerable<(int Offset, int Length)>? splits = InitializeForEncoding(text, textSpan, considerPreTokenization, considerNormalization, _normalizer, _preTokenizer, out normalizedString, out ReadOnlySpan textSpanToEncode); + + List ids = new(); + PriorityQueue? priorityQueue = null; + + if (splits is not null) + { + textLength = 0; + foreach ((int Offset, int Length) split in splits) + { + EncodeToIdsWithCache(textSpanToEncode.Slice(split.Offset, split.Length), ids, maxTokenCount - ids.Count, out int length, ref priorityQueue); + textLength = split.Offset + length; + + if (length < split.Length || ids.Count >= maxTokenCount) + { + break; + } + } + } + else + { + EncodeToIdsWithCache(textSpanToEncode, ids, maxTokenCount, out textLength, ref priorityQueue); + } + + return ids; + } /// /// Get the number of tokens that the input text will be encoded to. /// /// The text to encode. - /// Starting from this index to the end of the text will encompasses the maximum encoded tokens. - /// The maximum number of tokens to encode. - /// The number of tokens that the input text will be encoded to. - public override int CountTokensFromEnd(ReadOnlySpan text, out int textIndex, int maxTokens = int.MaxValue) => EncodeToIdsFromEndWithCache(text, null, maxTokens, out textIndex); + /// Indicate whether to consider pre-tokenization before tokenization. + /// Indicate whether to consider normalization before tokenization. + /// The number of token Ids that the input text will be encoded to. + public override int CountTokens(string text, bool considerPreTokenization = true, bool considerNormalization = true) + => CountTokens(text, Span.Empty, considerPreTokenization, considerNormalization, out _, out _); + + /// + /// Get the number of tokens that the input text will be encoded to. + /// + /// The text to encode. + /// Indicate whether to consider pre-tokenization before tokenization. + /// Indicate whether to consider normalization before tokenization. + /// The number of token Ids that the input text will be encoded to. + public override int CountTokens(ReadOnlySpan text, bool considerPreTokenization = true, bool considerNormalization = true) + => CountTokens(null, text, considerPreTokenization, considerNormalization, out _, out _); + + /// + /// Find the index of the maximum encoding capacity from the start within the text without surpassing the token limit. + /// + /// The text to encode. + /// The maximum token count to limit the encoding capacity. + /// If the tokenizer's normalization is enabled or is false, this will be set to in its normalized form; otherwise, this value will be set to null. + /// The token count can be generated which should be smaller than the maximum token count. + /// Indicate whether to consider pre-tokenization before tokenization. + /// Indicate whether to consider normalization before tokenization. + /// + /// The index of the maximum encoding capacity within the processed text without surpassing the token limit. + /// It represents the index immediately following the last character to be included. In cases where no tokens fit, the result will be 0; conversely, + /// if all tokens fit, the result will be length of the text or the if the normalization is enabled. + /// + public override int IndexOfTokenCount(string text, int maxTokenCount, out string? normalizedString, out int tokenCount, bool considerPreTokenization = true, bool considerNormalization = true) + { + tokenCount = CountTokens(text, Span.Empty, considerPreTokenization, considerNormalization, out normalizedString, out int textLength, maxTokenCount); + return textLength; + } + + /// + /// Find the index of the maximum encoding capacity from the start within the text without surpassing the token limit. + /// + /// The text to encode. + /// The maximum token count to limit the encoding capacity. + /// If the tokenizer's normalization is enabled or is false, this will be set to in its normalized form; otherwise, this value will be set to null. + /// The token count can be generated which should be smaller than the maximum token count. + /// Indicate whether to consider pre-tokenization before tokenization. + /// Indicate whether to consider normalization before tokenization. + /// + /// The index of the maximum encoding capacity within the processed text without surpassing the token limit. + /// It represents the index immediately following the last character to be included. In cases where no tokens fit, the result will be 0; conversely, + /// if all tokens fit, the result will be length of the text or the if the normalization is enabled. + /// + public override int IndexOfTokenCount(ReadOnlySpan text, int maxTokenCount, out string? normalizedString, out int tokenCount, bool considerPreTokenization = true, bool considerNormalization = true) + { + tokenCount = CountTokens(null, text, considerPreTokenization, considerNormalization, out normalizedString, out int textLength, maxTokenCount); + return textLength; + } + + private int CountTokens(string? text, ReadOnlySpan textSpan, bool considerPreTokenization, bool considerNormalization, out string? normalizedString, out int textLength, int maxTokenCount = int.MaxValue) + { + if (maxTokenCount <= 0) + { + throw new ArgumentOutOfRangeException(nameof(maxTokenCount), "The maximum number of tokens must be greater than zero."); + } + + textLength = 0; + if (string.IsNullOrEmpty(text) && textSpan.IsEmpty) + { + normalizedString = null; + return 0; + } + + IEnumerable<(int Offset, int Length)>? splits = InitializeForEncoding(text, textSpan, considerPreTokenization, considerNormalization, _normalizer, _preTokenizer, out normalizedString, out ReadOnlySpan textSpanToEncode); + + PriorityQueue? priorityQueue = null; + int count = 0; + if (splits is not null) + { + foreach ((int Offset, int Length) split in splits) + { + count += EncodeToIdsWithCache(textSpanToEncode.Slice(split.Offset, split.Length), null, maxTokenCount - count, out int length, ref priorityQueue); + textLength = split.Offset + length; + + if (length < split.Length || count >= maxTokenCount) + { + break; + } + } + } + else + { + count = EncodeToIdsWithCache(textSpanToEncode, null, maxTokenCount, out textLength, ref priorityQueue); + } + + return count; + } + + /// + /// Find the index of the maximum encoding capacity from the end within the text without surpassing the token limit. + /// + /// The text to encode. + /// The maximum token count to limit the encoding capacity. + /// If the tokenizer's normalization is enabled or is false, this will be set to in its normalized form; otherwise, this value will be set to null. + /// The token count can be generated which should be smaller than the maximum token count. + /// Indicate whether to consider pre-tokenization before tokenization. + /// Indicate whether to consider normalization before tokenization. + /// + /// The start index of the maximum encoding capacity within the processed text without surpassing the token limit. + /// It represents the index at the first character to be included. In cases where no tokens fit, the result will be length of the text or the if normalization is enabled; + /// conversely, if all tokens fit, the result will be 0. + /// + public override int LastIndexOfTokenCount(string text, int maxTokenCount, out string? normalizedString, out int tokenCount, bool considerPreTokenization = true, bool considerNormalization = true) + => LastIndexOf(text, Span.Empty, maxTokenCount, considerPreTokenization, considerNormalization, out normalizedString, out tokenCount); + + /// + /// Find the index of the maximum encoding capacity from the end within the text without surpassing the token limit. + /// + /// The text to encode. + /// The maximum token count to limit the encoding capacity. + /// If the tokenizer's normalization is enabled or is false, this will be set to in its normalized form; otherwise, this value will be set to null. + /// The token count can be generated which should be smaller than the maximum token count. + /// Indicate whether to consider pre-tokenization before tokenization. + /// Indicate whether to consider normalization before tokenization. + /// + /// The start index of the maximum encoding capacity within the processed text without surpassing the token limit. + /// It represents the index at the first character to be included. In cases where no tokens fit, the result will be length of the ; conversely, if all tokens fit, the result will be 0. + /// + public override int LastIndexOfTokenCount(ReadOnlySpan text, int maxTokenCount, out string? normalizedString, out int tokenCount, bool considerPreTokenization = true, bool considerNormalization = true) + => LastIndexOf(null, text, maxTokenCount, considerPreTokenization, considerNormalization, out normalizedString, out tokenCount); + + private int LastIndexOf(string? text, ReadOnlySpan textSpan, int maxTokenCount, bool considerPreTokenization, bool considerNormalization, out string? normalizedString, out int tokenCount) + { + if (maxTokenCount <= 0) + { + throw new ArgumentOutOfRangeException(nameof(maxTokenCount), "The max token count must be greater than 0."); + } + + if (string.IsNullOrEmpty(text) && textSpan.IsEmpty) + { + normalizedString = null; + tokenCount = 0; + return 0; + } + + IEnumerable<(int Offset, int Length)>? splits = InitializeForEncoding(text, textSpan, considerPreTokenization, considerNormalization, _normalizer, _preTokenizer, out normalizedString, out ReadOnlySpan textSpanToEncode); + + PriorityQueue? priorityQueue = null; + + if (splits is not null) + { + tokenCount = 0; + foreach ((int Offset, int Length) split in splits.Reverse()) + { + tokenCount += EncodeToIdsFromEndWithCache(textSpanToEncode.Slice(split.Offset, split.Length), null, maxTokenCount - tokenCount, out int textIndex, ref priorityQueue); + if (textIndex > 0 || tokenCount >= maxTokenCount) + { + return split.Offset + textIndex; + } + } + } + else + { + tokenCount = EncodeToIdsFromEndWithCache(textSpanToEncode, null, maxTokenCount, out int textLength, ref priorityQueue); + return textLength; + } + + return 0; + } /// /// Map the token to encoded Id. @@ -383,7 +664,7 @@ internal string CharToString(char c) return s; } - internal Word MergeWord(ReadOnlySpan w) + internal Word MergeWord(ReadOnlySpan w, ref PriorityQueue? priorityQueue) { Word word = Word.WithCapacity(w.Length); (int Id, int Len)? unk = null; @@ -494,23 +775,24 @@ internal Word MergeWord(ReadOnlySpan w) word.Add(unk.Value.Id, unk.Value.Len); } - word.MergeAll(Merges, Dropout); + word.MergeAll(Merges, Dropout, ref priorityQueue); return word; } - internal List WordToTokens(ref Word word) => word.ToTokens(VocabReverse); + internal void WordToTokens(ref Word word, List tokens, int offset) => word.ToTokens(VocabReverse, tokens, offset); - internal List EncodeWithCache(ReadOnlySpan text) + internal void EncodeWithCache(ReadOnlySpan text, List tokens, int offset, ref PriorityQueue? priorityQueue) { Word word; if (Cache is not null) { if (Cache.TryGetValue(text, out word)) { - return WordToTokens(ref word); + WordToTokens(ref word, tokens, offset); + return; } - word = MergeWord(text); + word = MergeWord(text, ref priorityQueue); if (text.Length <= MaxWordLengthToCache) { @@ -519,15 +801,15 @@ internal List EncodeWithCache(ReadOnlySpan text) } else { - word = MergeWord(text); + word = MergeWord(text, ref priorityQueue); } - return WordToTokens(ref word); + WordToTokens(ref word, tokens, offset); } internal int WordToIds(ref Word word, IList? accumulatedIds, out int textLength, int fullTextLength, int maxTokens) { - if (word.SymbolsCount < maxTokens) + if (word.SymbolsCount <= maxTokens) { textLength = fullTextLength; if (accumulatedIds is not null) @@ -548,7 +830,7 @@ internal int WordToIds(ref Word word, IList? accumulatedIds, out int textLe internal int WordToIdsFromEnd(ref Word word, IList? accumulatedIds, out int textIndex, int fullTextLength, int maxTokens) { - if (word.SymbolsCount < maxTokens) + if (word.SymbolsCount <= maxTokens) { textIndex = 0; if (accumulatedIds is not null) @@ -567,7 +849,7 @@ internal int WordToIdsFromEnd(ref Word word, IList? accumulatedIds, out int return word.CountIdsUpToMaxFromEnd(maxTokens, fullTextLength, out textIndex); } - internal int EncodeToIdsWithCache(ReadOnlySpan text, IList? accumulatedIds, int maxTokens, out int textLength) + private int EncodeToIdsWithCache(ReadOnlySpan text, List? accumulatedIds, int maxTokens, out int textLength, ref PriorityQueue? priorityQueue) { Word word; @@ -578,7 +860,7 @@ internal int EncodeToIdsWithCache(ReadOnlySpan text, IList? accumulat return WordToIds(ref hit, accumulatedIds, out textLength, text.Length, maxTokens); } - word = MergeWord(text); + word = MergeWord(text, ref priorityQueue); if (text.Length <= MaxWordLengthToCache) { @@ -587,13 +869,13 @@ internal int EncodeToIdsWithCache(ReadOnlySpan text, IList? accumulat } else { - word = MergeWord(text); + word = MergeWord(text, ref priorityQueue); } return WordToIds(ref word, accumulatedIds, out textLength, text.Length, maxTokens); } - internal int EncodeToIdsFromEndWithCache(ReadOnlySpan text, IList? accumulatedIds, int maxTokens, out int textIndex) + internal int EncodeToIdsFromEndWithCache(ReadOnlySpan text, IList? accumulatedIds, int maxTokens, out int textIndex, ref PriorityQueue? priorityQueue) { Word word; @@ -604,7 +886,7 @@ internal int EncodeToIdsFromEndWithCache(ReadOnlySpan text, IList? ac return WordToIdsFromEnd(ref hit, accumulatedIds, out textIndex, text.Length, maxTokens); } - word = MergeWord(text); + word = MergeWord(text, ref priorityQueue); if (text.Length <= MaxWordLengthToCache) { @@ -613,12 +895,10 @@ internal int EncodeToIdsFromEndWithCache(ReadOnlySpan text, IList? ac } else { - word = MergeWord(text); + word = MergeWord(text, ref priorityQueue); } return WordToIdsFromEnd(ref word, accumulatedIds, out textIndex, text.Length, maxTokens); } - - internal static readonly List EmptyTokensList = new(); } } diff --git a/src/Microsoft.ML.Tokenizers/Model/EnglishRoberta.cs b/src/Microsoft.ML.Tokenizers/Model/EnglishRoberta.cs index 013af1917f..ac4da084c2 100644 --- a/src/Microsoft.ML.Tokenizers/Model/EnglishRoberta.cs +++ b/src/Microsoft.ML.Tokenizers/Model/EnglishRoberta.cs @@ -8,6 +8,7 @@ using System.Diagnostics; using System.IO; using System.Linq; +using System.Text; using System.Text.Json; namespace Microsoft.ML.Tokenizers @@ -15,7 +16,7 @@ namespace Microsoft.ML.Tokenizers /// /// Represent the Byte Pair Encoding model. /// - public sealed class EnglishRoberta : Model + public sealed class EnglishRoberta : Tokenizer { private readonly HighestOccurrenceMapping _vocabIdToHighestOccurrence; private readonly Dictionary _vocab; @@ -26,6 +27,8 @@ public sealed class EnglishRoberta : Model private readonly IReadOnlyDictionary _unicodeToByte; private readonly string[] _charToString; private readonly StringSpanOrdinalKeyCache> _cache; + private readonly PreTokenizer? _preTokenizer; + private readonly Normalizer? _normalizer; /// /// Indicate if want to filter the unsupported characters during the decoding. @@ -38,47 +41,15 @@ public sealed class EnglishRoberta : Model /// The JSON file path containing the dictionary of string keys and their ids. /// The file path containing the tokens's pairs list. /// Remap the original GPT-2 model Ids to high occurrence ranks and values. + /// The pre-tokenizer to use. + /// The normalizer to use. /// Indicate if want to filter the unsupported characters during the decoding. - public EnglishRoberta(string vocabularyPath, string mergePath, string highestOccurrenceMappingPath, bool filterUnsupportedChars = true) + public EnglishRoberta(string vocabularyPath, string mergePath, string highestOccurrenceMappingPath, PreTokenizer? preTokenizer = null, Normalizer? normalizer = null, bool filterUnsupportedChars = true) : + this(vocabularyPath is null ? throw new ArgumentNullException(nameof(vocabularyPath)) : File.OpenRead(vocabularyPath), + mergePath is null ? throw new ArgumentNullException(nameof(mergePath)) : File.OpenRead(mergePath), + highestOccurrenceMappingPath is null ? throw new ArgumentNullException(nameof(highestOccurrenceMappingPath)) : File.OpenRead(highestOccurrenceMappingPath), + preTokenizer, normalizer, filterUnsupportedChars, true) { - if (vocabularyPath is null) - { - throw new ArgumentNullException(nameof(vocabularyPath)); - } - - if (mergePath is null) - { - throw new ArgumentNullException(nameof(mergePath)); - } - - if (highestOccurrenceMappingPath is null) - { - throw new ArgumentNullException(nameof(highestOccurrenceMappingPath)); - } - - FilterUnsupportedChars = filterUnsupportedChars; - - using Stream vocabularyStream = File.OpenRead(vocabularyPath); - using Stream mergeStream = File.OpenRead(mergePath); - using Stream highestOccurrenceMappingStream = File.OpenRead(highestOccurrenceMappingPath); - - // vocabularyPath like "https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/encoder.json" - // merge file like "https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/vocab.bpe" - // highestOccurrenceMappingPath like "https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/dict.txt" - - _vocabIdToHighestOccurrence = GetHighestOccurrenceMapping(highestOccurrenceMappingStream); - _vocab = GetVocabulary(vocabularyStream); - _vocabReverse = _vocab.ReverseSorted(); - _mergeRanks = GetMergeRanks(mergeStream); - int maxCharValue = GetByteToUnicode(out _byteToUnicode); - _charToString = new string[maxCharValue]; - for (char c = (char)0; c < (char)maxCharValue; c++) - { - _charToString[c] = c.ToString(); - } - - _unicodeToByte = _byteToUnicode.Reverse(); - _cache = new StringSpanOrdinalKeyCache>(); } /// @@ -87,8 +58,15 @@ public EnglishRoberta(string vocabularyPath, string mergePath, string highestOcc /// The stream of a JSON file containing the dictionary of string keys and their ids. /// The stream of a file containing the tokens's pairs list. /// Remap the original GPT-2 model Ids to high occurrence ranks and values. + /// The pre-tokenizer to use. + /// The normalizer to use. /// Indicate if want to filter the unsupported characters during the decoding. - public EnglishRoberta(Stream vocabularyStream, Stream mergeStream, Stream highestOccurrenceMappingStream, bool filterUnsupportedChars = true) + public EnglishRoberta(Stream vocabularyStream, Stream mergeStream, Stream highestOccurrenceMappingStream, PreTokenizer? preTokenizer = null, Normalizer? normalizer = null, bool filterUnsupportedChars = true) : + this(vocabularyStream, mergeStream, highestOccurrenceMappingStream, preTokenizer, normalizer, filterUnsupportedChars, false) + { + } + + public EnglishRoberta(Stream vocabularyStream, Stream mergeStream, Stream highestOccurrenceMappingStream, PreTokenizer? preTokenizer, Normalizer? normalizer, bool filterUnsupportedChars, bool disposeStream) { if (vocabularyStream is null) { @@ -106,6 +84,12 @@ public EnglishRoberta(Stream vocabularyStream, Stream mergeStream, Stream highes } FilterUnsupportedChars = filterUnsupportedChars; + _preTokenizer = preTokenizer; + _normalizer = normalizer; + + // vocabularyPath like "https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/encoder.json" + // merge file like "https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/vocab.bpe" + // highestOccurrenceMappingPath like "https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/dict.txt" _vocabIdToHighestOccurrence = GetHighestOccurrenceMapping(highestOccurrenceMappingStream); _vocab = GetVocabulary(vocabularyStream); @@ -120,8 +104,25 @@ public EnglishRoberta(Stream vocabularyStream, Stream mergeStream, Stream highes _unicodeToByte = _byteToUnicode.Reverse(); _cache = new StringSpanOrdinalKeyCache>(); + + if (disposeStream) + { + vocabularyStream.Dispose(); + mergeStream.Dispose(); + highestOccurrenceMappingStream.Dispose(); + } } + /// + /// Gets the PreTokenizer used by the Tokenizer. + /// + public override PreTokenizer? PreTokenizer => _preTokenizer; + + /// + /// Gets the Normalizer in use by the Tokenizer. + /// + public override Normalizer? Normalizer => _normalizer; + /// /// Gets the dictionary mapping tokens to Ids. /// @@ -167,16 +168,64 @@ public EnglishRoberta(Stream vocabularyStream, Stream mergeStream, Stream highes return null; } + /// + /// Encodes input text a list of s with string value of the token, id, and offset. + /// + /// The text to encode. + /// If the tokenizer's normalization is enabled or is false, this will be set to in its normalized form; otherwise, this value will be set to null. + /// Indicate whether to consider pre-tokenization before tokenization. + /// Indicate whether to consider normalization before tokenization. + /// The tokenization result includes a list of s with string value of the token, id, and offset. + public override IReadOnlyList Encode(string text, out string? normalizedString, bool considerPreTokenization = true, bool considerNormalization = true) => Encode(text, Span.Empty, out normalizedString, considerPreTokenization, considerNormalization); + + /// + /// Encodes input text a list of s with string value of the token, id, and offset. + /// + /// The text to encode. + /// If the tokenizer's normalization is enabled or is false, this will be set to in its normalized form; otherwise, this value will be set to null. + /// Indicate whether to consider pre-tokenization before tokenization. + /// Indicate whether to consider normalization before tokenization. + /// The tokenization result includes a list of s with string value of the token, id, and offset. + public override IReadOnlyList Encode(ReadOnlySpan text, out string? normalizedString, bool considerPreTokenization = true, bool considerNormalization = true) => Encode(null, text, out normalizedString, considerPreTokenization, considerNormalization); + + private IReadOnlyList Encode(string? text, ReadOnlySpan textSpan, out string? normalizedString, bool considerPreTokenization, bool considerNormalization) + { + if (string.IsNullOrEmpty(text) && textSpan.IsEmpty) + { + normalizedString = null; + return []; + } + + IEnumerable<(int Offset, int Length)>? splits = InitializeForEncoding(text, textSpan, considerPreTokenization, considerNormalization, _normalizer, _preTokenizer, out normalizedString, out ReadOnlySpan textSpanToEncode); + + if (splits is not null) + { + List tokens = new(); + foreach ((int Offset, int Length) split in splits) + { + foreach (Token t in EncodeInternal(textSpanToEncode.Slice(split.Offset, split.Length))) + { + tokens.Add(new Token(t.Id, t.Value, (split.Offset + t.Offset.Index, t.Offset.Length))); + } + } + return tokens; + } + else + { + return EncodeInternal(textSpanToEncode); + } + } + /// /// Encode a text string to a list of tokens. /// /// The text to encode. /// The list of tokens generated from the text tokenization. - public override IReadOnlyList Encode(ReadOnlySpan text) + private IReadOnlyList EncodeInternal(ReadOnlySpan text) { if (text.IsEmpty) { - return Bpe.EmptyTokensList; + return []; } char[] token = ArrayPool.Shared.Rent(text.Length); @@ -197,7 +246,7 @@ public override IReadOnlyList Encode(ReadOnlySpan text) { ArrayPool.Shared.Return(token); ArrayPool.Shared.Return(indexMapping); - return Array.Empty(); + return []; } if (_cache.TryGetValue(text, out List? hit)) @@ -215,32 +264,257 @@ public override IReadOnlyList Encode(ReadOnlySpan text) } /// - /// Encode a split text string to a list of Ids and add them to the accumulatedIds list. + /// Encodes input text to token Ids. + /// + /// The text to encode. + /// Indicate whether to consider pre-tokenization before tokenization. + /// Indicate whether to consider normalization before tokenization. + /// The list of encoded Ids. + public override IReadOnlyList EncodeToIds(string text, bool considerPreTokenization = true, bool considerNormalization = true) + => EncodeToIds(text, Span.Empty, considerPreTokenization, considerNormalization, out _, out _); + + /// + /// Encodes input text to token Ids. + /// + /// The text to encode. + /// Indicate whether to consider pre-tokenization before tokenization. + /// Indicate whether to consider normalization before tokenization. + /// The list of encoded Ids. + public override IReadOnlyList EncodeToIds(ReadOnlySpan text, bool considerPreTokenization = true, bool considerNormalization = true) + => EncodeToIds(null, text, considerPreTokenization, considerNormalization, out _, out _); + + /// + /// Encodes input text to token Ids up to maximum number of tokens. /// /// The text to encode. - /// The list of accumulated encoded Ids. + /// The maximum number of tokens to encode. + /// If the tokenizer's normalization is enabled or is false, this will be set to in its normalized form; otherwise, this value will be set to null. /// The length of the text that encompasses the maximum encoded tokens. - /// The maximum number of tokens to encode. - /// The number of tokens that the input text will be encoded to. - public override int EncodeToIds(ReadOnlySpan text, IList accumulatedIds, out int textLength, int maxTokens = int.MaxValue) => EncodeToIdsInternal(text, accumulatedIds, out textLength, maxTokens); + /// Indicate whether to consider pre-tokenization before tokenization. + /// Indicate whether to consider normalization before tokenization. + /// The list of encoded Ids. + public override IReadOnlyList EncodeToIds(string text, int maxTokenCount, out string? normalizedString, out int textLength, bool considerPreTokenization = true, bool considerNormalization = true) + => EncodeToIds(text, Span.Empty, considerPreTokenization, considerNormalization, out normalizedString, out textLength, maxTokenCount); /// - /// Get the number of tokens that the input text will be encoded to. + /// Encodes input text to token Ids up to maximum number of tokens. /// /// The text to encode. + /// The maximum number of tokens to encode. + /// If the tokenizer's normalization is enabled or is false, this will be set to in its normalized form; otherwise, this value will be set to null. /// The length of the text that encompasses the maximum encoded tokens. - /// The maximum number of tokens to encode. - /// The number of tokens that the input text will be encoded to. - public override int CountTokens(ReadOnlySpan text, out int textLength, int maxTokens = int.MaxValue) => EncodeToIdsInternal(text, null, out textLength, maxTokens); + /// Indicate whether to consider pre-tokenization before tokenization. + /// Indicate whether to consider normalization before tokenization. + /// The list of encoded Ids. + public override IReadOnlyList EncodeToIds(ReadOnlySpan text, int maxTokenCount, out string? normalizedString, out int textLength, bool considerPreTokenization = true, bool considerNormalization = true) + => EncodeToIds(null, text, considerPreTokenization, considerNormalization, out normalizedString, out textLength, maxTokenCount); + + private IReadOnlyList EncodeToIds(string? text, ReadOnlySpan textSpan, bool considerPreTokenization, bool considerNormalization, out string? normalizedString, out int textLength, int maxTokenCount = int.MaxValue) + { + if (maxTokenCount <= 0) + { + throw new ArgumentOutOfRangeException(nameof(maxTokenCount), "The maximum number of tokens must be greater than zero."); + } + + if (string.IsNullOrEmpty(text) && textSpan.IsEmpty) + { + textLength = 0; + normalizedString = null; + return []; + } + + IEnumerable<(int Offset, int Length)>? splits = InitializeForEncoding(text, textSpan, considerPreTokenization, considerNormalization, _normalizer, _preTokenizer, out normalizedString, out ReadOnlySpan textSpanToEncode); + + List ids = new(); + if (splits is not null) + { + textLength = 0; + foreach ((int Offset, int Length) split in splits) + { + EncodeToIdsInternal(textSpanToEncode.Slice(split.Offset, split.Length), ids, out int length, maxTokenCount - ids.Count); + textLength = split.Offset + length; + + if (length < split.Length || ids.Count >= maxTokenCount) + { + break; + } + } + } + else + { + EncodeToIdsInternal(textSpanToEncode, ids, out textLength, maxTokenCount); + } + + return ids; + } /// /// Get the number of tokens that the input text will be encoded to. /// /// The text to encode. - /// Starting from this index to the end of the text will encompasses the maximum encoded tokens. - /// The maximum number of tokens to encode. - /// The number of tokens that the input text will be encoded to. - public override int CountTokensFromEnd(ReadOnlySpan text, out int textIndex, int maxTokens = int.MaxValue) => EncodeToIdsFromEndInternal(text, null, out textIndex, maxTokens); + /// Indicate whether to consider pre-tokenization before tokenization. + /// Indicate whether to consider normalization before tokenization. + /// The number of token Ids that the input text will be encoded to. + public override int CountTokens(string text, bool considerPreTokenization = true, bool considerNormalization = true) + => CountTokens(text, Span.Empty, considerPreTokenization, considerNormalization, out _, out _); + + /// + /// Get the number of tokens that the input text will be encoded to. + /// + /// The text to encode. + /// Indicate whether to consider pre-tokenization before tokenization. + /// Indicate whether to consider normalization before tokenization. + /// The number of token Ids that the input text will be encoded to. + public override int CountTokens(ReadOnlySpan text, bool considerPreTokenization = true, bool considerNormalization = true) + => CountTokens(null, text, considerPreTokenization, considerNormalization, out _, out _); + + /// + /// Find the index of the maximum encoding capacity from the start within the text without surpassing the token limit. + /// + /// The text to encode. + /// The maximum token count to limit the encoding capacity. + /// If the tokenizer's normalization is enabled or is false, this will be set to in its normalized form; otherwise, this value will be set to null. + /// The token count can be generated which should be smaller than the maximum token count. + /// Indicate whether to consider pre-tokenization before tokenization. + /// Indicate whether to consider normalization before tokenization. + /// + /// The index of the maximum encoding capacity within the processed text without surpassing the token limit. + /// It represents the index immediately following the last character to be included. In cases where no tokens fit, the result will be 0; conversely, + /// if all tokens fit, the result will be length of the text or the if the normalization is enabled. + /// + public override int IndexOfTokenCount(string text, int maxTokenCount, out string? normalizedString, out int tokenCount, bool considerPreTokenization = true, bool considerNormalization = true) + { + tokenCount = CountTokens(text, Span.Empty, considerPreTokenization, considerNormalization, out normalizedString, out int textLength, maxTokenCount); + return textLength; + } + + /// + /// Find the index of the maximum encoding capacity from the start within the text without surpassing the token limit. + /// + /// The text to encode. + /// The maximum token count to limit the encoding capacity. + /// If the tokenizer's normalization is enabled or is false, this will be set to in its normalized form; otherwise, this value will be set to null. + /// The token count can be generated which should be smaller than the maximum token count. + /// Indicate whether to consider pre-tokenization before tokenization. + /// Indicate whether to consider normalization before tokenization. + /// + /// The index of the maximum encoding capacity within the processed text without surpassing the token limit. + /// It represents the index immediately following the last character to be included. In cases where no tokens fit, the result will be 0; conversely, + /// if all tokens fit, the result will be length of the text or the if the normalization is enabled. + /// + public override int IndexOfTokenCount(ReadOnlySpan text, int maxTokenCount, out string? normalizedString, out int tokenCount, bool considerPreTokenization = true, bool considerNormalization = true) + { + tokenCount = CountTokens(null, text, considerPreTokenization, considerNormalization, out normalizedString, out int textLength, maxTokenCount); + return textLength; + } + + private int CountTokens(string? text, ReadOnlySpan textSpan, bool considerPreTokenization, bool considerNormalization, out string? normalizedString, out int textLength, int maxTokenCount = int.MaxValue) + { + if (maxTokenCount <= 0) + { + throw new ArgumentOutOfRangeException(nameof(maxTokenCount), "The maximum number of tokens must be greater than zero."); + } + + textLength = 0; + if (string.IsNullOrEmpty(text) && textSpan.IsEmpty) + { + normalizedString = null; + return 0; + } + + IEnumerable<(int Offset, int Length)>? splits = InitializeForEncoding(text, textSpan, considerPreTokenization, considerNormalization, _normalizer, _preTokenizer, out normalizedString, out ReadOnlySpan textSpanToEncode); + + int count = 0; + if (splits is not null) + { + foreach ((int Offset, int Length) split in splits) + { + count += EncodeToIdsInternal(textSpanToEncode.Slice(split.Offset, split.Length), null, out int length, maxTokenCount - count); + textLength = split.Offset + length; + + if (length < split.Length || count >= maxTokenCount) + { + break; + } + } + } + else + { + count += EncodeToIdsInternal(textSpanToEncode, null, out textLength, maxTokenCount); + } + + return count; + } + + /// + /// Find the index of the maximum encoding capacity from the end within the text without surpassing the token limit. + /// + /// The text to encode. + /// The maximum token count to limit the encoding capacity. + /// If the tokenizer's normalization is enabled or is false, this will be set to in its normalized form; otherwise, this value will be set to null. + /// The token count can be generated which should be smaller than the maximum token count. + /// Indicate whether to consider pre-tokenization before tokenization. + /// Indicate whether to consider normalization before tokenization. + /// + /// The start index of the maximum encoding capacity within the processed text without surpassing the token limit. + /// It represents the index at the first character to be included. In cases where no tokens fit, the result will be length of the text or the if normalization is enabled; + /// conversely, if all tokens fit, the result will be 0. + /// + public override int LastIndexOfTokenCount(string text, int maxTokenCount, out string? normalizedString, out int tokenCount, bool considerPreTokenization = true, bool considerNormalization = true) + => LastIndexOf(text, Span.Empty, maxTokenCount, considerPreTokenization, considerNormalization, out normalizedString, out tokenCount); + + /// + /// Find the index of the maximum encoding capacity from the end within the text without surpassing the token limit. + /// + /// The text to encode. + /// The maximum token count to limit the encoding capacity. + /// If the tokenizer's normalization is enabled or is false, this will be set to in its normalized form; otherwise, this value will be set to null. + /// The token count can be generated which should be smaller than the maximum token count. + /// Indicate whether to consider pre-tokenization before tokenization. + /// Indicate whether to consider normalization before tokenization. + /// + /// The start index of the maximum encoding capacity within the processed text without surpassing the token limit. + /// It represents the index at the first character to be included. In cases where no tokens fit, the result will be length of the ; conversely, if all tokens fit, the result will be 0. + /// + public override int LastIndexOfTokenCount(ReadOnlySpan text, int maxTokenCount, out string? normalizedString, out int tokenCount, bool considerPreTokenization = true, bool considerNormalization = true) + => LastIndexOf(null, text, maxTokenCount, considerPreTokenization, considerNormalization, out normalizedString, out tokenCount); + + private int LastIndexOf(string? text, ReadOnlySpan textSpan, int maxTokenCount, bool considerPreTokenization, bool considerNormalization, out string? normalizedString, out int tokenCount) + { + if (maxTokenCount <= 0) + { + throw new ArgumentOutOfRangeException(nameof(maxTokenCount), "The max token count must be greater than 0."); + } + + if (string.IsNullOrEmpty(text) && textSpan.IsEmpty) + { + normalizedString = null; + tokenCount = 0; + return 0; + } + + IEnumerable<(int Offset, int Length)>? splits = InitializeForEncoding(text, textSpan, considerPreTokenization, considerNormalization, _normalizer, _preTokenizer, out normalizedString, out ReadOnlySpan textSpanToEncode); + + if (splits is not null) + { + tokenCount = 0; + foreach ((int Offset, int Length) split in splits.Reverse()) + { + tokenCount += EncodeToIdsFromEndInternal(textSpanToEncode.Slice(split.Offset, split.Length), null, out int textIndex, maxTokenCount - tokenCount); + if (textIndex > 0 || tokenCount >= maxTokenCount) + { + return split.Offset + textIndex; + } + } + } + else + { + tokenCount = EncodeToIdsFromEndInternal(textSpanToEncode, null, out int textLength, maxTokenCount); + return textLength; + } + + return 0; + } private int EncodeToIdsResult(List tokens, IList? accumulatedIds, int maxTokens, int fullTextLength, out int textLength) { @@ -347,7 +621,7 @@ private int EncodeToIdsInternal(ReadOnlySpan text, IList? accumulated { ArrayPool.Shared.Return(token); ArrayPool.Shared.Return(indexMapping); - textLength = 0; + textLength = text.Length; return 0; } @@ -390,7 +664,7 @@ private int EncodeToIdsFromEndInternal(ReadOnlySpan text, IList? accu { ArrayPool.Shared.Return(token); ArrayPool.Shared.Return(indexMapping); - textIndex = text.Length; + textIndex = 0; return 0; } @@ -410,7 +684,32 @@ private int EncodeToIdsFromEndInternal(ReadOnlySpan text, IList? accu public override int? MapTokenToId(ReadOnlySpan token) => _vocab.TryGetValue(token, out int value) ? value : null; /// - /// Convert a list of tokens Ids to highest occurrence rankings. + /// Decode the given ids, back to a String. + /// + /// The list of ids that we want to decode. + /// The decoded string. + public override string? Decode(IEnumerable ids) + { + if (ids is null) + { + throw new ArgumentNullException(nameof(ids)); + } + + ValueStringBuilder sb = new ValueStringBuilder(); + + foreach (int id in ids) + { + if (MapIdToToken(id) is string s) + { + sb.Append(s); + } + } + + return sb.ToString(); + } + + /// + /// Convert a list of token Ids to highest occurrence rankings. /// /// The Ids list to map to the high occurrence rank. /// The list of ranks mapped from the list of Ids. @@ -432,7 +731,7 @@ public IReadOnlyList ConvertIdsToOccurrenceRanks(IReadOnlyList ids) } /// - /// Convert a list of tokens Ids to highest occurrence values. + /// Convert a list of token Ids to highest occurrence values. /// /// The Ids list to map to the high occurrence values. /// The list of occurrence values mapped from the list of Ids. @@ -454,7 +753,7 @@ public IReadOnlyList ConvertIdsToOccurrenceValues(IReadOnlyList ids) } /// - /// Convert a list of highest occurrence rankings to tokens Ids list . + /// Convert a list of highest occurrence rankings to token Ids list . /// /// The high occurrence ranks list to map to the Ids list. /// The list of Ids mapped from the list of ranks. @@ -638,7 +937,7 @@ private List EncodeToTokens(Span token, Span indexMapping) { if (token.Length == 0) { - return Bpe.EmptyTokensList; + return []; } if (token.Length == 1) diff --git a/src/Microsoft.ML.Tokenizers/Model/Model.cs b/src/Microsoft.ML.Tokenizers/Model/Model.cs deleted file mode 100644 index c6cc5578df..0000000000 --- a/src/Microsoft.ML.Tokenizers/Model/Model.cs +++ /dev/null @@ -1,180 +0,0 @@ -// Licensed to the .NET Foundation under one or more agreements. -// The .NET Foundation licenses this file to you under the MIT license. -// See the LICENSE file in the project root for more information. - -using System; -using System.Collections.Generic; -using System.Text; - -namespace Microsoft.ML.Tokenizers -{ - /// - /// Represents a model used during Tokenization (like BPE or Word Piece or Unigram). - /// - public abstract class Model - { - /// - /// Encode a text to a list of tokens. - /// - /// The text to encode. - /// The list of tokens generated from the text tokenization. - public abstract IReadOnlyList Encode(ReadOnlySpan text); - - /// - /// Encode a text to a list of Ids and add them to the accumulatedIds list. - /// - /// The text to encode. - /// The list of accumulated encoded Ids. - /// The length of the text that encompasses the maximum encoded tokens. - /// The maximum number of tokens to encode. - /// The number of tokens that the input text will be encoded to. - /// - /// This method does the default implementation that uses the Encode method to get the token's Ids. - /// Tokenizer's models which care about performance may choose to override this method to provide a more efficient implementation. - /// - public virtual int EncodeToIds(ReadOnlySpan text, IList accumulatedIds, out int textLength, int maxTokens = int.MaxValue) - { - if (accumulatedIds is null) - { - throw new ArgumentNullException(nameof(accumulatedIds)); - } - - // Default implementation is not optimized for memory allocation. It is recommended to override this method for the sake of the performance. - textLength = 0; - var tokens = Encode(text); - - int count = Math.Min(tokens.Count, maxTokens); - - for (int i = 0; i < count; i++) - { - textLength += tokens[i].Offset.Length; - accumulatedIds.Add(tokens[i].Id); - } - - return count; - } - - /// - /// Get the number of tokens that the input text will be encoded to. - /// - /// The text to encode. - /// The length of the text that encompasses the maximum encoded tokens. - /// The maximum number of tokens to encode. - /// The number of tokens that the input text will be encoded to. - /// - /// This method does the default implementation that uses the EncodeToIds method to get the number of token's Ids. - /// Tokenizer's models which care about performance may choose to override this method to provide a more efficient implementation. - /// - public virtual int CountTokens(ReadOnlySpan text, out int textLength, int maxTokens = int.MaxValue) - { - if (maxTokens <= 0) - { - throw new ArgumentOutOfRangeException(nameof(maxTokens), "The maximum number of tokens must be greater than 0."); - } - - var ids = new List(); - - if (maxTokens == int.MaxValue) - { - EncodeToIds(text, ids, out _); - textLength = text.Length; - return ids.Count; - } - - IReadOnlyList tokens = Encode(text); - textLength = 0; - int count = Math.Min(tokens.Count, maxTokens); - for (int i = 0; i < count; i++) - { - textLength += tokens[i].Offset.Length; - } - - return count; - } - - /// - /// Get the number of tokens that the input text will be encoded to. - /// - /// The text to encode. - /// Starting from this index to the end of the text will encompasses the maximum encoded tokens. - /// The maximum number of tokens to encode. - /// The number of tokens that the input text will be encoded to. - /// - /// This method does the default implementation that uses the EncodeToIds method to get the number of token's Ids. - /// Tokenizer's models which care about performance may choose to override this method to provide a more efficient implementation. - /// - public virtual int CountTokensFromEnd(ReadOnlySpan text, out int textIndex, int maxTokens = int.MaxValue) - { - if (maxTokens <= 0) - { - throw new ArgumentOutOfRangeException(nameof(maxTokens), "The maximum number of tokens must be greater than 0."); - } - - var ids = new List(); - - if (maxTokens == int.MaxValue) - { - EncodeToIds(text, ids, out _); - textIndex = 0; - return ids.Count; - } - - IReadOnlyList tokens = Encode(text); - textIndex = text.Length; - int count = Math.Min(tokens.Count, maxTokens); - - int tokensCount = tokens.Count; - int end = tokensCount - count; - for (int i = tokensCount - 1; i >= end; i--) - { - textIndex -= tokens[i].Offset.Length; - } - - return count; - } - - /// - /// Map the token to encoded id with the option to skip the special tokens. - /// - /// The token to map to Id - /// The mapped Id of the token. - public abstract int? MapTokenToId(ReadOnlySpan token); - - /// - /// Map the encoded Id to the token. - /// - /// The Id to map to the token. - /// The mapped token of the Id. - public abstract string? MapIdToToken(int id); - - /// - /// Decode the given ids, back to a String. - /// - /// The list of ids that we want to decode. - /// The decoded string. - /// - /// This method does the default implementation that uses the MapIdToToken method to get the token. - /// Tokenizer models may opt to override this method to ensure accurate results if the default implementation - /// provided here proves insufficient for the model's specific scenario. - /// - public virtual string? Decode(IEnumerable ids) - { - if (ids is null) - { - throw new ArgumentNullException(nameof(ids)); - } - - ValueStringBuilder sb = new ValueStringBuilder(); - - foreach (int id in ids) - { - if (MapIdToToken(id) is string s) - { - sb.Append(s); - } - } - - return sb.ToString(); - } - } -} diff --git a/src/Microsoft.ML.Tokenizers/Model/SentencePieceBpe.cs b/src/Microsoft.ML.Tokenizers/Model/SentencePieceBpe.cs index ab43325349..889d63f059 100644 --- a/src/Microsoft.ML.Tokenizers/Model/SentencePieceBpe.cs +++ b/src/Microsoft.ML.Tokenizers/Model/SentencePieceBpe.cs @@ -19,7 +19,7 @@ namespace Microsoft.ML.Tokenizers /// /// SentencePieceBpe is a tokenizer that splits the input into tokens using the SentencePiece Bpe model. /// - public sealed class SentencePieceBpe : Model + public sealed class SentencePieceBpe : Tokenizer { private const int UninitializedId = -2; // indicate if the symbol contains uninitialized id. @@ -29,9 +29,10 @@ public sealed class SentencePieceBpe : Model private readonly int _maxByteId; private readonly int _byteCodeToIdOffset; // offset of mapping byte code to the to the Ids. private readonly int _oneByteUtf8EncodingMaxId; // the maximum value of the one byte UTF-8 character. + private readonly Normalizer? _normalizer; internal SentencePieceBpe(ModelProto modelProto, bool addBos, bool addEos) : - this(modelProto is null ? throw new ArgumentNullException(nameof(modelProto)) : modelProto) + this(modelProto is null ? throw new ArgumentNullException(nameof(modelProto)) : modelProto) { AddBeginningOfSentence = addBos; AddEndOfSentence = addEos; @@ -65,6 +66,8 @@ private SentencePieceBpe(ModelProto modelProto) EscapeWhiteSpaces = modelProto.NormalizerSpec.EscapeWhitespaces; TreatWhitespaceAsSuffix = modelProto.TrainerSpec.TreatWhitespaceAsSuffix; ByteFallback = modelProto.TrainerSpec.ByteFallback; + + _normalizer = new SentencePieceNormalizer(modelProto.NormalizerSpec.RemoveExtraWhitespaces, AddDummyPrefix, EscapeWhiteSpaces, modelProto.TrainerSpec.TreatWhitespaceAsSuffix); } /// @@ -127,6 +130,16 @@ private SentencePieceBpe(ModelProto modelProto) /// public int UnknownId { get; } + /// + /// Gets the PreTokenizer used by the Tokenizer. + /// + public override PreTokenizer? PreTokenizer => null; + + /// + /// Gets the Normalizer in use by the Tokenizer. + /// + public override Normalizer? Normalizer => _normalizer; + /// /// The vocabulary of the model. /// @@ -152,12 +165,74 @@ public IReadOnlyDictionary Vocab } /// - /// Encode a text to a list of tokens. + /// Encodes input text a list of s with string value of the token, id, and offset. /// /// The text to encode. - /// The list of tokens generated from the text tokenization. - /// The input text has to be normalized before calling this method. - public override IReadOnlyList Encode(ReadOnlySpan text) => Encode(text, AddBeginningOfSentence, AddEndOfSentence); + /// If the tokenizer's normalization is enabled or is false, this will be set to in its normalized form; otherwise, this value will be set to null. + /// Indicate whether to consider pre-tokenization before tokenization. + /// Indicate whether to consider normalization before tokenization. + /// The tokenization result includes a list of s with string value of the token, id, and offset. + public override IReadOnlyList Encode(string text, out string? normalizedString, bool considerPreTokenization = true, bool considerNormalization = true) + => Encode(text, Span.Empty, out normalizedString, AddBeginningOfSentence, AddEndOfSentence, considerPreTokenization, considerNormalization); + + /// + /// Encodes input text a list of s with string value of the token, id, and offset. + /// + /// The text to encode. + /// If the tokenizer's normalization is enabled or is false, this will be set to in its normalized form; otherwise, this value will be set to null. + /// Indicate whether to consider pre-tokenization before tokenization. + /// Indicate whether to consider normalization before tokenization. + /// The tokenization result includes a list of s with string value of the token, id, and offset. + public override IReadOnlyList Encode(ReadOnlySpan text, out string? normalizedString, bool considerPreTokenization = true, bool considerNormalization = true) + => Encode(null, text, out normalizedString, AddBeginningOfSentence, AddEndOfSentence, considerPreTokenization, considerNormalization); + + /// + /// Encodes input text a list of s with string value of the token, id, and offset. + /// + /// The text to encode. + /// If the tokenizer's normalization is enabled or is false, this will be set to in its normalized form; otherwise, this value will be set to null. + /// Indicate emitting the beginning of sentence token during the encoding. + /// Indicate emitting the end of sentence token during the encoding. + /// Indicate whether to consider pre-tokenization before tokenization. + /// Indicate whether to consider normalization before tokenization. + /// The tokenization result includes a list of s with string value of the token, id, and offset. + public IReadOnlyList Encode(string text, out string? normalizedString, bool addBeginningOfSentence, bool addEndOfSentence, bool considerPreTokenization = true, bool considerNormalization = true) + => Encode(text, Span.Empty, out normalizedString, addBeginningOfSentence, addEndOfSentence, considerPreTokenization, considerNormalization); + + /// + /// Encodes input text a list of s with string value of the token, id, and offset. + /// + /// The text to encode. + /// If the tokenizer's normalization is enabled or is false, this will be set to in its normalized form; otherwise, this value will be set to null. + /// Indicate emitting the beginning of sentence token during the encoding. + /// Indicate emitting the end of sentence token during the encoding. + /// Indicate whether to consider pre-tokenization before tokenization. + /// Indicate whether to consider normalization before tokenization. + /// The tokenization result includes a list of s with string value of the token, id, and offset. + public IReadOnlyList Encode(ReadOnlySpan text, out string? normalizedString, bool addBeginningOfSentence, bool addEndOfSentence, bool considerPreTokenization = true, bool considerNormalization = true) + => Encode(null, text, out normalizedString, addBeginningOfSentence, addEndOfSentence, considerPreTokenization, considerNormalization); + + private IReadOnlyList Encode(string? text, ReadOnlySpan textSpan, out string? normalizedString, bool addBeginningOfSentence, bool addEndOfSentence, bool considerPreTokenization, bool considerNormalization) + { + if (string.IsNullOrEmpty(text) && textSpan.IsEmpty) + { + normalizedString = null; + return []; + } + + ReadOnlySpan textToEncode = text is null ? textSpan : text.AsSpan(); + if (considerNormalization && _normalizer is not null) + { + normalizedString = text is not null ? _normalizer.Normalize(text) : _normalizer.Normalize(textSpan); + textToEncode = normalizedString.AsSpan(); + } + else + { + normalizedString = null; + } + + return EncodeInternal(textToEncode, addBeginningOfSentence, addEndOfSentence); + } /// /// Encode a text to a list of tokens. @@ -167,11 +242,11 @@ public IReadOnlyDictionary Vocab /// Indicate emitting the end of sentence token during the encoding. /// The list of tokens generated from the text tokenization. /// The input text has to be normalized before calling this method. - public IReadOnlyList Encode(ReadOnlySpan text, bool addBeginOfSentence, bool addEndOfSentence) + private IReadOnlyList EncodeInternal(ReadOnlySpan text, bool addBeginOfSentence, bool addEndOfSentence) { if (text.Length == 0) { - return Array.Empty(); + return []; } BpeSymbol[] symbols = ArrayPool.Shared.Rent(text.Length); @@ -306,16 +381,168 @@ revMerge is null || } /// - /// Encode a text to a list of Ids and add them to the accumulatedIds list. + /// Encodes input text to tokes Ids. /// /// The text to encode. - /// The list of accumulated encoded Ids. + /// Indicate whether to consider pre-tokenization before tokenization. + /// Indicate whether to consider normalization before tokenization. + /// The list of encoded Ids. + public override IReadOnlyList EncodeToIds(string text, bool considerPreTokenization = true, bool considerNormalization = true) + => EncodeToIds(text, Span.Empty, AddBeginningOfSentence, AddEndOfSentence, considerNormalization, out _, out _); + + /// + /// Encodes input text to token Ids. + /// + /// The text to encode. + /// Indicate whether to consider pre-tokenization before tokenization. + /// Indicate whether to consider normalization before tokenization. + /// The list of encoded Ids. + public override IReadOnlyList EncodeToIds(ReadOnlySpan text, bool considerPreTokenization = true, bool considerNormalization = true) + => EncodeToIds(null, text, AddBeginningOfSentence, AddEndOfSentence, considerNormalization, out _, out _); + + /// + /// Encodes input text to token Ids up to maximum number of tokens. + /// + /// The text to encode. + /// The maximum number of tokens to encode. + /// If the tokenizer's normalization is enabled or is false, this will be set to in its normalized form; otherwise, this value will be set to null. /// The length of the text that encompasses the maximum encoded tokens. - /// The maximum number of tokens to encode. - /// The number of tokens that the input text will be encoded to. - /// The input text has to be normalized before calling this method. - public override int EncodeToIds(ReadOnlySpan text, IList accumulatedIds, out int textLength, int maxTokens = int.MaxValue) - => EncodeToIds(text, AddBeginningOfSentence, AddEndOfSentence, accumulatedIds, out textLength, maxTokens); + /// Indicate whether to consider pre-tokenization before tokenization. + /// Indicate whether to consider normalization before tokenization. + /// The list of encoded Ids. + public override IReadOnlyList EncodeToIds(string text, int maxTokenCount, out string? normalizedString, out int textLength, bool considerPreTokenization = true, bool considerNormalization = true) + => EncodeToIds(text, Span.Empty, AddBeginningOfSentence, AddEndOfSentence, considerNormalization, out normalizedString, out textLength, maxTokenCount); + + /// + /// Encodes input text to token Ids up to maximum number of tokens. + /// + /// The text to encode. + /// The maximum number of tokens to encode. + /// If the tokenizer's normalization is enabled or is false, this will be set to in its normalized form; otherwise, this value will be set to null. + /// The length of the text that encompasses the maximum encoded tokens. + /// Indicate whether to consider pre-tokenization before tokenization. + /// Indicate whether to consider normalization before tokenization. + /// The list of encoded Ids. + public override IReadOnlyList EncodeToIds(ReadOnlySpan text, int maxTokenCount, out string? normalizedString, out int textLength, bool considerPreTokenization = true, bool considerNormalization = true) + => EncodeToIds(null, text, AddBeginningOfSentence, AddEndOfSentence, considerNormalization, out normalizedString, out textLength, maxTokenCount); + + /// + /// Encodes input text to token Ids. + /// + /// The text to encode. + /// Indicate emitting the beginning of sentence token during the encoding. + /// Indicate emitting the end of sentence token during the encoding. + /// Indicate whether to consider pre-tokenization before tokenization. + /// Indicate whether to consider normalization before tokenization. + /// The list of encoded Ids. + public IReadOnlyList EncodeToIds(string text, bool addBeginningOfSentence, bool addEndOfSentence, bool considerPreTokenization = true, bool considerNormalization = true) + => EncodeToIds(text, Span.Empty, addBeginningOfSentence, addEndOfSentence, considerNormalization, out _, out _); + + /// + /// Encodes input text to token Ids. + /// + /// The text to encode. + /// Indicate emitting the beginning of sentence token during the encoding. + /// Indicate emitting the end of sentence token during the encoding. + /// Indicate whether to consider pre-tokenization before tokenization. + /// Indicate whether to consider normalization before tokenization. + /// The list of encoded Ids. + public IReadOnlyList EncodeToIds(ReadOnlySpan text, bool addBeginningOfSentence, bool addEndOfSentence, bool considerPreTokenization = true, bool considerNormalization = true) + => EncodeToIds(null, text, addBeginningOfSentence, addEndOfSentence, considerNormalization, out _, out _); + + /// + /// Encodes input text to token Ids up to maximum number of tokens. + /// + /// The text to encode. + /// Indicate emitting the beginning of sentence token during the encoding. + /// Indicate emitting the end of sentence token during the encoding. + /// The maximum number of tokens to encode. + /// If the tokenizer's normalization is enabled or is false, this will be set to in its normalized form; otherwise, this value will be set to null. + /// The length of the text that encompasses the maximum encoded tokens. + /// Indicate whether to consider pre-tokenization before tokenization. + /// Indicate whether to consider normalization before tokenization. + /// The list of encoded Ids. + public IReadOnlyList EncodeToIds(string text, bool addBeginningOfSentence, bool addEndOfSentence, int maxTokenCount, out string? normalizedString, out int textLength, bool considerPreTokenization = true, bool considerNormalization = true) + => EncodeToIds(text, Span.Empty, addBeginningOfSentence, addEndOfSentence, considerNormalization, out normalizedString, out textLength, maxTokenCount); + + /// + /// Encodes input text to token Ids up to maximum number of tokens. + /// + /// The text to encode. + /// Indicate emitting the beginning of sentence token during the encoding. + /// Indicate emitting the end of sentence token during the encoding. + /// The maximum number of tokens to encode. + /// If the tokenizer's normalization is enabled or is false, this will be set to in its normalized form; otherwise, this value will be set to null. + /// The length of the text that encompasses the maximum encoded tokens. + /// Indicate whether to consider pre-tokenization before tokenization. + /// Indicate whether to consider normalization before tokenization. + /// The list of encoded Ids. + public IReadOnlyList EncodeToIds(ReadOnlySpan text, bool addBeginningOfSentence, bool addEndOfSentence, int maxTokenCount, out string? normalizedString, out int textLength, bool considerPreTokenization = true, bool considerNormalization = true) + => EncodeToIds(null, text, addBeginningOfSentence, addEndOfSentence, considerNormalization, out normalizedString, out textLength, maxTokenCount); + + + private IReadOnlyList EncodeToIds(string? text, ReadOnlySpan textSpan, bool addBeginningOfSentence, bool addEndOfSentence, bool considerNormalization, out string? normalizedString, out int textLength, int maxTokenCount = int.MaxValue) + { + if (maxTokenCount <= 0) + { + throw new ArgumentOutOfRangeException(nameof(maxTokenCount), "The maximum number of tokens must be greater than zero."); + } + + if (string.IsNullOrEmpty(text) && textSpan.IsEmpty) + { + normalizedString = null; + textLength = 0; + return []; + } + + return EncodeToIds(text is null ? textSpan : text.AsSpan(), addBeginningOfSentence, addEndOfSentence, considerNormalization, out normalizedString, out textLength, maxTokenCount); + } + + /// + /// Encodes input text to token Ids up to maximum number of tokens. + /// + /// The text to encode. + /// Indicate emitting the beginning of sentence token during the encoding. + /// Indicate emitting the end of sentence token during the encoding. + /// Indicate whether to consider normalization before tokenization. + /// If the tokenizer's normalization is enabled or is false, this will be set to in its normalized form; otherwise, this value will be set to null. + /// The length of the text that encompasses the maximum encoded tokens. + /// The maximum number of tokens to encode. + /// The list of encoded Ids. + public IReadOnlyList EncodeToIds(ReadOnlySpan text, bool addBeginningOfSentence, bool addEndOfSentence, bool considerNormalization, + out string? normalizedString, out int textLength, int maxTokenCount = int.MaxValue) + { + if (maxTokenCount <= 0) + { + throw new ArgumentOutOfRangeException(nameof(maxTokenCount), "The maximum number of tokens must be greater than zero."); + } + + if (text.IsEmpty) + { + normalizedString = null; + textLength = 0; + return []; + } + + ReadOnlySpan textToEncode; + + if (considerNormalization && _normalizer is not null) + { + normalizedString = _normalizer.Normalize(text); + textToEncode = normalizedString.AsSpan(); + } + else + { + normalizedString = null; + textToEncode = text; + } + + List ids = new(); + + EncodeToIds(textToEncode, addBeginningOfSentence, addEndOfSentence, ids, out textLength, maxTokenCount); + + return ids; + } /// /// Encode a text to a list of Ids and add them to the accumulatedIds list. @@ -328,7 +555,7 @@ public override int EncodeToIds(ReadOnlySpan text, IList accumulatedI /// The maximum number of tokens to encode. /// The number of tokens that the input text will be encoded to. /// The input text has to be normalized before calling this method. - public int EncodeToIds(ReadOnlySpan text, bool addBeginOfSentence, bool addEndOfSentence, IList accumulatedIds, out int textLength, int maxTokens = int.MaxValue) + private int EncodeToIds(ReadOnlySpan text, bool addBeginOfSentence, bool addEndOfSentence, IList accumulatedIds, out int textLength, int maxTokens = int.MaxValue) { if (maxTokens <= 0) { @@ -378,7 +605,8 @@ public int EncodeToIds(ReadOnlySpan text, bool addBeginOfSentence, bool ad { if (!EncodeAsBytes(text.Slice(symbols[index].pieceSpan.Index, symbols[index].pieceSpan.Length), symbols[index].pieceSpan.Index, ref textLength)) { - break; + ArrayPool.Shared.Return(symbols); + return idsCount; } } else @@ -391,6 +619,7 @@ public int EncodeToIds(ReadOnlySpan text, bool addBeginOfSentence, bool ad } else { + ArrayPool.Shared.Return(symbols); return idsCount; } } @@ -510,21 +739,252 @@ revMerge is null || /// Get the number of tokens that the input text will be encoded to. /// /// The text to encode. - /// The number of tokens that the input text will be encoded to. - /// The length of the text that encompasses the maximum encoded tokens. - /// The maximum number of tokens to encode. - /// The number of tokens that the input text will be encoded to. - /// The input text has to be normalized before calling this method. - public override int CountTokens(ReadOnlySpan text, out int textLength, int maxTokens = int.MaxValue) => CountTokens(text, AddBeginningOfSentence, AddEndOfSentence, out textLength, maxTokens); + /// Indicate whether to consider pre-tokenization before tokenization. + /// Indicate whether to consider normalization before tokenization. + /// The number of token Ids that the input text will be encoded to. + public override int CountTokens(string text, bool considerPreTokenization = true, bool considerNormalization = true) + => CountTokens(text, Span.Empty, AddBeginningOfSentence, AddEndOfSentence, considerNormalization, out _, out _); /// /// Get the number of tokens that the input text will be encoded to. /// /// The text to encode. - /// Starting from this index to the end of the text will encompasses the maximum encoded tokens. - /// The maximum number of tokens to encode. + /// Indicate whether to consider pre-tokenization before tokenization. + /// Indicate whether to consider normalization before tokenization. + /// The number of token Ids that the input text will be encoded to. + public override int CountTokens(ReadOnlySpan text, bool considerPreTokenization = true, bool considerNormalization = true) + => CountTokens(null, text, AddBeginningOfSentence, AddEndOfSentence, considerNormalization, out _, out _); + + /// + /// Find the index of the maximum encoding capacity from the start within the text without surpassing the token limit. + /// + /// The text to encode. + /// The maximum token count to limit the encoding capacity. + /// If the tokenizer's normalization is enabled or is false, this will be set to in its normalized form; otherwise, this value will be set to null. + /// The token count can be generated which should be smaller than the maximum token count. + /// Indicate whether to consider pre-tokenization before tokenization. + /// Indicate whether to consider normalization before tokenization. + /// + /// The index of the maximum encoding capacity within the processed text without surpassing the token limit. + /// It represents the index immediately following the last character to be included. In cases where no tokens fit, the result will be 0; conversely, + /// if all tokens fit, the result will be length of the text or the if the normalization is enabled. + /// + public override int IndexOfTokenCount(string text, int maxTokenCount, out string? normalizedString, out int tokenCount, bool considerPreTokenization = true, bool considerNormalization = true) + { + tokenCount = CountTokens(text, Span.Empty, AddBeginningOfSentence, AddEndOfSentence, considerNormalization, out normalizedString, out int textLength, maxTokenCount); + return textLength; + } + + /// + /// Find the index of the maximum encoding capacity from the start within the text without surpassing the token limit. + /// + /// The text to encode. + /// The maximum token count to limit the encoding capacity. + /// If the tokenizer's normalization is enabled or is false, this will be set to in its normalized form; otherwise, this value will be set to null. + /// The token count can be generated which should be smaller than the maximum token count. + /// Indicate whether to consider pre-tokenization before tokenization. + /// Indicate whether to consider normalization before tokenization. + /// + /// The index of the maximum encoding capacity within the processed text without surpassing the token limit. + /// It represents the index immediately following the last character to be included. In cases where no tokens fit, the result will be 0; conversely, + /// if all tokens fit, the result will be length of the text or the if the normalization is enabled. + /// + public override int IndexOfTokenCount(ReadOnlySpan text, int maxTokenCount, out string? normalizedString, out int tokenCount, bool considerPreTokenization = true, bool considerNormalization = true) + { + tokenCount = CountTokens(null, text, AddBeginningOfSentence, AddEndOfSentence, considerNormalization, out normalizedString, out int textLength, maxTokenCount); + return textLength; + } + + /// + /// Get the number of tokens that the input text will be encoded to. + /// + /// The text to encode. + /// Indicate emitting the beginning of sentence token during the encoding. + /// Indicate emitting the end of sentence token during the encoding. + /// Indicate whether to consider pre-tokenization before tokenization. + /// Indicate whether to consider normalization before tokenization. + /// The number of token Ids that the input text will be encoded to. + public int CountTokens(string text, bool addBeginningOfSentence, bool addEndOfSentence, bool considerPreTokenization = true, bool considerNormalization = true) + => CountTokens(text, Span.Empty, addBeginningOfSentence, addEndOfSentence, considerNormalization, out _, out _); + + /// + /// Get the number of tokens that the input text will be encoded to. + /// + /// The text to encode. + /// Indicate emitting the beginning of sentence token during the encoding. + /// Indicate emitting the end of sentence token during the encoding. + /// Indicate whether to consider pre-tokenization before tokenization. + /// Indicate whether to consider normalization before tokenization. + /// The number of token Ids that the input text will be encoded to. + public int CountTokens(ReadOnlySpan text, bool addBeginningOfSentence, bool addEndOfSentence, bool considerPreTokenization = true, bool considerNormalization = true) + => CountTokens(null, text, addBeginningOfSentence, addEndOfSentence, considerNormalization, out _, out _); + + /// + /// Find the index of the maximum encoding capacity from the start within the text without surpassing the token limit. + /// + /// The text to encode. + /// Indicate emitting the beginning of sentence token during the encoding. + /// Indicate emitting the end of sentence token during the encoding. + /// The maximum token count to limit the encoding capacity. + /// If the tokenizer's normalization is enabled or is false, this will be set to in its normalized form; otherwise, this value will be set to null. + /// The token count can be generated which should be smaller than the maximum token count. + /// Indicate whether to consider pre-tokenization before tokenization. + /// Indicate whether to consider normalization before tokenization. + /// + /// The index of the maximum encoding capacity within the processed text without surpassing the token limit. + /// It represents the index immediately following the last character to be included. In cases where no tokens fit, the result will be 0; conversely, + /// if all tokens fit, the result will be length of the text or the if the normalization is enabled. + /// + public int IndexOfTokenCount(string text, bool addBeginningOfSentence, bool addEndOfSentence, int maxTokenCount, out string? normalizedString, out int tokenCount, bool considerPreTokenization = true, bool considerNormalization = true) + { + tokenCount = CountTokens(text, Span.Empty, addBeginningOfSentence, addEndOfSentence, considerNormalization, out normalizedString, out int textLength, maxTokenCount); + return textLength; + } + + /// + /// Find the index of the maximum encoding capacity from the start within the text without surpassing the token limit. + /// + /// The text to encode. + /// Indicate emitting the beginning of sentence token during the encoding. + /// Indicate emitting the end of sentence token during the encoding. + /// The maximum token count to limit the encoding capacity. + /// If the tokenizer's normalization is enabled or is false, this will be set to in its normalized form; otherwise, this value will be set to null. + /// The token count can be generated which should be smaller than the maximum token count. + /// Indicate whether to consider pre-tokenization before tokenization. + /// Indicate whether to consider normalization before tokenization. + /// + /// The index of the maximum encoding capacity within the processed text without surpassing the token limit. + /// It represents the index immediately following the last character to be included. In cases where no tokens fit, the result will be 0; conversely, + /// if all tokens fit, the result will be length of the text or the if the normalization is enabled. + /// + public int IndexOfTokenCount(ReadOnlySpan text, bool addBeginningOfSentence, bool addEndOfSentence, int maxTokenCount, out string? normalizedString, out int tokenCount, bool considerPreTokenization = true, bool considerNormalization = true) + { + tokenCount = CountTokens(null, text, addBeginningOfSentence, addEndOfSentence, considerNormalization, out normalizedString, out int textLength, maxTokenCount); + return textLength; + } + + private int CountTokens(string? text, ReadOnlySpan textSpan, bool addBeginningOfSentence, bool addEndOfSentence, bool considerNormalization, out string? normalizedString, out int textLength, int maxTokenCount = int.MaxValue) + => CountTokens(text is null ? textSpan : text.AsSpan(), addBeginningOfSentence, addEndOfSentence, considerNormalization, out normalizedString, out textLength, maxTokenCount); + + /// + /// Get the number of tokens that the input text will be encoded to. + /// + /// The text to encode. + /// Indicate emitting the beginning of sentence token during the encoding. + /// Indicate emitting the end of sentence token during the encoding. + /// Indicate whether to consider normalization before tokenization. + /// If the tokenizer's normalization is enabled or is false, this will be set to in its normalized form; otherwise, this value will be set to null. + /// The length of the text that encompasses the maximum encoded tokens. + /// The maximum number of tokens to encode. /// The number of tokens that the input text will be encoded to. - public override int CountTokensFromEnd(ReadOnlySpan text, out int textIndex, int maxTokens = int.MaxValue) => CountTokensFromEnd(text, AddBeginningOfSentence, AddEndOfSentence, out textIndex, maxTokens); + public int CountTokens(ReadOnlySpan text, bool addBeginningOfSentence, bool addEndOfSentence, bool considerNormalization, out string? normalizedString, out int textLength, int maxTokenCount = int.MaxValue) + { + if (maxTokenCount <= 0) + { + throw new ArgumentOutOfRangeException(nameof(maxTokenCount), "The maximum number of tokens must be greater than zero."); + } + + if (text.IsEmpty) + { + normalizedString = null; + textLength = 0; + return 0; + } + + ReadOnlySpan textToEncode; + if (considerNormalization && _normalizer is not null) + { + normalizedString = _normalizer.Normalize(text); + textToEncode = normalizedString.AsSpan(); + } + else + { + normalizedString = null; + textToEncode = text; + } + + return CountTokens(textToEncode, addBeginningOfSentence, addEndOfSentence, out textLength, maxTokenCount); + } + + /// + /// Find the index of the maximum encoding capacity from the end within the text without surpassing the token limit. + /// + /// The text to encode. + /// The maximum token count to limit the encoding capacity. + /// If the tokenizer's normalization is enabled or is false, this will be set to in its normalized form; otherwise, this value will be set to null. + /// The token count can be generated which should be smaller than the maximum token count. + /// Indicate whether to consider pre-tokenization before tokenization. + /// Indicate whether to consider normalization before tokenization. + /// + /// The start index of the maximum encoding capacity within the processed text without surpassing the token limit. + /// It represents the index at the first character to be included. In cases where no tokens fit, the result will be length of the text or the if normalization is enabled; + /// conversely, if all tokens fit, the result will be 0. + /// + public override int LastIndexOfTokenCount(string text, int maxTokenCount, out string? normalizedString, out int tokenCount, bool considerPreTokenization = true, bool considerNormalization = true) + => LastIndexOf(text, Span.Empty, maxTokenCount, considerNormalization, out normalizedString, out tokenCount); + + /// + /// Find the index of the maximum encoding capacity from the end within the text without surpassing the token limit. + /// + /// The text to encode. + /// The maximum token count to limit the encoding capacity. + /// If the tokenizer's normalization is enabled or is false, this will be set to in its normalized form; otherwise, this value will be set to null. + /// The token count can be generated which should be smaller than the maximum token count. + /// Indicate whether to consider pre-tokenization before tokenization. + /// Indicate whether to consider normalization before tokenization. + /// + /// The start index of the maximum encoding capacity within the processed text without surpassing the token limit. + /// It represents the index at the first character to be included. In cases where no tokens fit, the result will be length of the ; conversely, if all tokens fit, the result will be 0. + /// + public override int LastIndexOfTokenCount(ReadOnlySpan text, int maxTokenCount, out string? normalizedString, out int tokenCount, bool considerPreTokenization = true, bool considerNormalization = true) + => LastIndexOf(null, text, maxTokenCount, considerNormalization, out normalizedString, out tokenCount); + + private int LastIndexOf(string? text, ReadOnlySpan textSpan, int maxTokenCount, bool considerNormalization, out string? normalizedString, out int tokenCount) + => LastIndexOfTokenCount(text is null ? textSpan : text.AsSpan(), maxTokenCount, AddBeginningOfSentence, AddEndOfSentence, considerNormalization, out normalizedString, out tokenCount); + + /// + /// Find the index of the maximum encoding capacity from the end within the text without surpassing the token limit. + /// + /// The text to encode. + /// The maximum token count to limit the encoding capacity. + /// Indicate emitting the beginning of sentence token during the encoding. + /// Indicate emitting the end of sentence token during the encoding. + /// Indicate whether to consider normalization before tokenization. + /// If the tokenizer's normalization is enabled or is false, this will be set to in its normalized form; otherwise, this value will be set to null. + /// The token count can be generated which should be smaller than the maximum token count. + /// + /// The start index of the maximum encoding capacity within the processed text without surpassing the token limit. + /// It represents the index at the first character to be included. In cases where no tokens fit, the result will be length of the ; conversely, if all tokens fit, the result will be 0. + /// + public int LastIndexOfTokenCount(ReadOnlySpan text, int maxTokenCount, bool addBeginningOfSentence, bool addEndOfSentence, bool considerNormalization, out string? normalizedString, out int tokenCount) + { + if (maxTokenCount <= 0) + { + throw new ArgumentOutOfRangeException(nameof(maxTokenCount), "The max token count must be greater than 0."); + } + + if (text.IsEmpty) + { + normalizedString = null; + tokenCount = 0; + return 0; + } + + ReadOnlySpan textToEncode; + if (considerNormalization && _normalizer is not null) + { + normalizedString = _normalizer.Normalize(text); + textToEncode = normalizedString.AsSpan(); + } + else + { + normalizedString = null; + textToEncode = text; + } + + tokenCount = CountTokensFromEnd(textToEncode, addBeginningOfSentence, addEndOfSentence, out int textIndex, maxTokenCount); + return textIndex; + } /// /// Get the number of tokens that the input text will be encoded to. @@ -536,7 +996,7 @@ revMerge is null || /// The maximum number of tokens to encode. /// The number of tokens that the input text will be encoded to. /// The input text has to be normalized before calling this method. - public int CountTokens(ReadOnlySpan text, bool addBeginOfSentence, bool addEndOfSentence, out int textLength, int maxTokens = int.MaxValue) + private int CountTokens(ReadOnlySpan text, bool addBeginOfSentence, bool addEndOfSentence, out int textLength, int maxTokens = int.MaxValue) { textLength = 0; if (text.IsEmpty) @@ -705,7 +1165,7 @@ revMerge is null || /// The maximum number of tokens to encode. /// The number of tokens that the input text will be encoded to. /// The input text has to be normalized before calling this method. - public int CountTokensFromEnd(ReadOnlySpan text, bool addBeginOfSentence, bool addEndOfSentence, out int textIndex, int maxTokens = int.MaxValue) + private int CountTokensFromEnd(ReadOnlySpan text, bool addBeginOfSentence, bool addEndOfSentence, out int textIndex, int maxTokens = int.MaxValue) { textIndex = text.Length; if (text.IsEmpty) @@ -944,7 +1404,7 @@ revMerge is null || else if (_vocabReverse.TryGetValue(enumerator.Current, out string? token)) { // escape the dummy prefix if needed. - sb.Append(AddDummyPrefix && !TreatWhitespaceAsSuffix && token.Length > 0 && token[0] == LlamaNormalizer.DummyPrefix ? + sb.Append(AddDummyPrefix && !TreatWhitespaceAsSuffix && token.Length > 0 && token[0] == SentencePieceNormalizer.DummyPrefix ? token.AsSpan(1) : token.AsSpan()); } @@ -999,7 +1459,7 @@ revMerge is null || FlushBytes(ref bytesCount, ref bytesPoolArray, ref charPoolArray, ref sb); } - if (AddDummyPrefix && TreatWhitespaceAsSuffix && sb.Length > 0 && sb[sb.Length - 1] == LlamaNormalizer.DummyPrefix) + if (AddDummyPrefix && TreatWhitespaceAsSuffix && sb.Length > 0 && sb[sb.Length - 1] == SentencePieceNormalizer.DummyPrefix) { sb.RemoveLastChar(); } @@ -1014,7 +1474,7 @@ revMerge is null || ArrayPool.Shared.Return(charPoolArray); } - return sb.ToString(LlamaNormalizer.DummyPrefix, ' '); + return sb.ToString(SentencePieceNormalizer.DummyPrefix, ' '); static void FlushBytes(ref int bytesCount, ref byte[]? bytesPoolArray, ref char[]? charPoolArray, ref ValueStringBuilder sb) { diff --git a/src/Microsoft.ML.Tokenizers/Model/Tiktoken.cs b/src/Microsoft.ML.Tokenizers/Model/Tiktoken.cs index 8817ce6616..91b1d2cf9d 100644 --- a/src/Microsoft.ML.Tokenizers/Model/Tiktoken.cs +++ b/src/Microsoft.ML.Tokenizers/Model/Tiktoken.cs @@ -19,9 +19,9 @@ namespace Microsoft.ML.Tokenizers { /// - /// Represent the rapid Byte Pair Encoding model commonly referred to as Tiktoken. + /// Represent the rapid Byte Pair Encoding tokenizer. /// - public sealed partial class Tiktoken : Model + public sealed partial class Tiktoken : Tokenizer { private readonly Dictionary, int> _encoder; private readonly Dictionary> _decoder; @@ -29,46 +29,56 @@ public sealed partial class Tiktoken : Model private readonly Dictionary _vocab; private IReadOnlyDictionary? _vocabOriginal; private const int MaxWordLengthToCache = 15; + private readonly PreTokenizer? _preTokenizer; + private readonly Normalizer? _normalizer; /// - /// Create a new Tiktoken tokenizer's model object. + /// Create a new Tiktoken tokenizer's object. /// /// The path to the BPE vocab file. + /// The pre-tokenizer to use. /// The dictionary mapping special tokens to Ids. + /// The normalizer to use. /// The size of the cache to use. /// Thrown when is null or empty. /// Thrown when failed to load the BPE vocab file. - public Tiktoken(string vocabFilePath, IReadOnlyDictionary? specialTokens = null, int cacheSize = LruCache.DefaultCacheSize) : - this(string.IsNullOrEmpty(vocabFilePath) ? throw new ArgumentNullException(nameof(vocabFilePath)) : File.OpenRead(vocabFilePath), specialTokens, cacheSize, disposeStream: true) + public Tiktoken(string vocabFilePath, PreTokenizer? preTokenizer, IReadOnlyDictionary? specialTokens = null, Normalizer? normalizer = null, int cacheSize = LruCache.DefaultCacheSize) : + this(string.IsNullOrEmpty(vocabFilePath) ? throw new ArgumentNullException(nameof(vocabFilePath)) : File.OpenRead(vocabFilePath), preTokenizer, specialTokens, normalizer, cacheSize, disposeStream: true) { } /// - /// Create a new Tiktoken tokenizer's model object. + /// Create a new Tiktoken tokenizer's object. /// /// The stream to the BPE vocab file. + /// The pre-tokenizer to use. /// The dictionary mapping special tokens to Ids. + /// The normalizer to use. /// The size of the cache to use. /// Thrown when is null or empty. /// Thrown when failed to load the BPE vocab file. - public Tiktoken(Stream vocabStream, IReadOnlyDictionary? specialTokens = null, int cacheSize = LruCache.DefaultCacheSize) : - this(vocabStream ?? throw new ArgumentNullException(nameof(vocabStream)), specialTokens, cacheSize, disposeStream: false) + public Tiktoken(Stream vocabStream, PreTokenizer? preTokenizer, IReadOnlyDictionary? specialTokens = null, Normalizer? normalizer = null, int cacheSize = LruCache.DefaultCacheSize) : + this(vocabStream ?? throw new ArgumentNullException(nameof(vocabStream)), preTokenizer, specialTokens, normalizer, cacheSize, disposeStream: false) { } /// - /// Create a new Tiktoken tokenizer's model object. + /// Create a new Tiktoken tokenizer's object. /// /// The dictionary mapping token utf-8 bytes to Ids. /// The dictionary mapping Ids to token utf-8 bytes. /// The dictionary mapping string tokens to Ids. + /// The pre-tokenizer to use. /// The dictionary mapping special tokens to Ids. + /// The normalizer to use. /// The max size of the cache to use. internal Tiktoken( Dictionary, int> encoder, Dictionary> decoder, Dictionary vocab, + PreTokenizer? preTokenizer, IReadOnlyDictionary? specialTokens, + Normalizer? normalizer = null, int cacheSize = LruCache.DefaultCacheSize) { _encoder = encoder ?? throw new ArgumentNullException(nameof(encoder)); @@ -78,19 +88,26 @@ internal Tiktoken( _encoder = encoder!; _decoder = decoder!; _vocab = vocab!; + + _preTokenizer = preTokenizer; + _normalizer = normalizer; + _cache = new LruCache<(int[] Bytes, string Token)>(cacheSize); SpecialTokens = specialTokens; CacheSpecialTokensEncoding(specialTokens); } - private Tiktoken(Stream vocabStream, IReadOnlyDictionary? specialTokens, int cacheSize, bool disposeStream) + private Tiktoken(Stream vocabStream, PreTokenizer? preTokenizer, IReadOnlyDictionary? specialTokens, Normalizer? normalizer, int cacheSize, bool disposeStream) { try { _cache = new LruCache<(int[] Bytes, string Token)>(cacheSize); (_encoder, _vocab, _decoder) = LoadTiktokenBpeAsync(vocabStream, useAsync: false).GetAwaiter().GetResult(); + _preTokenizer = preTokenizer; + _normalizer = normalizer; + SpecialTokens = specialTokens; CacheSpecialTokensEncoding(specialTokens); } @@ -103,6 +120,16 @@ private Tiktoken(Stream vocabStream, IReadOnlyDictionary? specialTo } } + /// + /// Gets the PreTokenizer used by the Tokenizer. + /// + public override PreTokenizer? PreTokenizer => _preTokenizer; + + /// + /// Gets the Normalizer in use by the Tokenizer. + /// + public override Normalizer? Normalizer => _normalizer; + private void CacheSpecialTokensEncoding(IReadOnlyDictionary? specialTokens) { Debug.Assert(_cache is not null); @@ -118,54 +145,6 @@ private void CacheSpecialTokensEncoding(IReadOnlyDictionary? specia } } - /// - /// Create a new Tiktoken tokenizer's model object asynchronously. - /// - /// The stream to the BPE vocab file. - /// The dictionary mapping special tokens to Ids. - /// The size of the cache to use. - /// used to request cancellation of the operation. - /// Tiktoken tokenizer's object. - public static async Task CreateAsync( - Stream vocabStream, - IReadOnlyDictionary? specialTokens = null, - int cacheSize = LruCache.DefaultCacheSize, - CancellationToken cancellationToken = default) - { - if (vocabStream is null) - { - throw new ArgumentNullException(nameof(vocabStream)); - } - - (Dictionary, int> encoder, Dictionary vocab, Dictionary> decoder) = - await LoadTiktokenBpeAsync(vocabStream, useAsync: true, cancellationToken).ConfigureAwait(false); - - return new Tiktoken(encoder, decoder, vocab, specialTokens, cacheSize); - } - - /// - /// Create a new Tiktoken tokenizer's object asynchronously. - /// - /// The BPE vocab file. - /// The dictionary mapping special tokens to Ids. - /// The size of the cache to use. - /// used to request cancellation of the operation. - /// Tiktoken tokenizer's model object. - public static async Task CreateAsync( - string vocabFilePath, - IReadOnlyDictionary? specialTokensEncoder = null, - int cacheSize = LruCache.DefaultCacheSize, - CancellationToken cancellationToken = default) - { - if (vocabFilePath is null) - { - throw new ArgumentNullException(nameof(vocabFilePath)); - } - - using Stream vocabStream = File.OpenRead(vocabFilePath); - return await CreateAsync(vocabStream, specialTokensEncoder, cacheSize, cancellationToken).ConfigureAwait(false); - } - /// /// Load BPE vocab dictionary from a stream. /// @@ -255,36 +234,79 @@ void AddData(byte[] tokenBytes, int rank) } /// - /// Encode text to a list of tokens. + /// Encodes input text a list of s with string value of the token, id, and offset. /// /// The text to encode. - /// The list of tokens generated from the text tokenization. - public override IReadOnlyList Encode(ReadOnlySpan text) + /// If the tokenizer's normalization is enabled or is false, this will be set to in its normalized form; otherwise, this value will be set to null. + /// Indicate whether to consider pre-tokenization before tokenization. + /// Indicate whether to consider normalization before tokenization. + /// The tokenization result includes a list of s with string value of the token, id, and offset. + public override IReadOnlyList Encode(string text, out string? normalizedString, bool considerPreTokenization = true, bool considerNormalization = true) => Encode(text, Span.Empty, out normalizedString, considerPreTokenization, considerNormalization); + + /// + /// Encodes input text a list of s with string value of the token, id, and offset. + /// + /// The text to encode. + /// If the tokenizer's normalization is enabled or is false, this will be set to in its normalized form; otherwise, this value will be set to null. + /// Indicate whether to consider pre-tokenization before tokenization. + /// Indicate whether to consider normalization before tokenization. + /// The tokenization result includes a list of s with string value of the token, id, and offset. + public override IReadOnlyList Encode(ReadOnlySpan text, out string? normalizedString, bool considerPreTokenization = true, bool considerNormalization = true) => Encode(null, text, out normalizedString, considerPreTokenization, considerNormalization); + + private IReadOnlyList Encode(string? text, ReadOnlySpan textSpan, out string? normalizedString, bool considerPreTokenization, bool considerNormalization) { - Token[] tokens; + if (string.IsNullOrEmpty(text) && textSpan.IsEmpty) + { + normalizedString = null; + return []; + } - if (text.IsEmpty) + IEnumerable<(int Offset, int Length)>? splits = InitializeForEncoding(text, textSpan, considerPreTokenization, considerNormalization, _normalizer, _preTokenizer, out normalizedString, out ReadOnlySpan textSpanToEncode); + + List tokens = new(); + + if (splits is not null) { - return Array.Empty(); + foreach ((int Offset, int Length) split in splits) + { + Encode(textSpanToEncode.Slice(split.Offset, split.Length), tokens, split.Offset); + } } + else + { + Encode(textSpanToEncode, tokens, 0); + } + + return tokens; + } + + /// + /// Encode text to a list of tokens. + /// + /// The text to encode. + /// The list of tokens to populate. + /// The offset to start encoding from. + private void Encode(ReadOnlySpan text, List tokens, int offset) + { + Debug.Assert(!text.IsEmpty); if (_cache.TryGetValue(text, out (int[] Ids, string Token) value)) { - tokens = new Token[value.Ids.Length]; - tokens[0] = new Token(value.Ids[0], value.Token, (0, value.Token.Length)); + tokens.Add(new Token(value.Ids[0], value.Token, (offset, value.Token.Length))); for (int i = 1; i < value.Ids.Length; i++) { // One word split mapped to multiple Ids. Make the offset of the remaining token point at the end with zero width. - tokens[i] = new Token(value.Ids[i], "", (text.Length, 0)); + tokens.Add(new Token(value.Ids[i], "", (offset + text.Length, 0))); } - return tokens; + return; } // cache miss if (_vocab.TryGetValue(text, out (int Id, string Token) mappedId)) { - return new Token[1] { new(mappedId.Id, mappedId.Token, (0, mappedId.Token.Length)) }; + tokens.Add(new Token(mappedId.Id, mappedId.Token, (offset, mappedId.Token.Length))); + return; } byte[] arrayPoolArray = ArrayPool.Shared.Rent(Encoding.UTF8.GetMaxByteCount(text.Length)); @@ -301,15 +323,98 @@ public override IReadOnlyList Encode(ReadOnlySpan text) _cache.Add(textAsString, (encodedIds, textAsString)); } - tokens = new Token[encodedIds.Length]; - tokens[0] = new Token(encodedIds[0], textAsString, (0, text.Length)); + tokens.Add(new Token(encodedIds[0], textAsString, (offset, text.Length))); for (int i = 1; i < encodedIds.Length; i++) { // One word split mapped to multiple Ids. Make the offset of the remaining token point at the end with zero width. - tokens[i] = new Token(encodedIds[i], "", (text.Length, 0)); + tokens.Add(new Token(encodedIds[i], "", (offset + text.Length, 0))); } + } - return tokens; + /// + /// Encodes input text to token Ids. + /// + /// The text to encode. + /// Indicate whether to consider pre-tokenization before tokenization. + /// Indicate whether to consider normalization before tokenization. + /// The list of encoded Ids. + public override IReadOnlyList EncodeToIds(string text, bool considerPreTokenization = true, bool considerNormalization = true) + => EncodeToIds(text, Span.Empty, considerPreTokenization, considerNormalization, out _, out _); + + /// + /// Encodes input text to token Ids. + /// + /// The text to encode. + /// Indicate whether to consider pre-tokenization before tokenization. + /// Indicate whether to consider normalization before tokenization. + /// The list of encoded Ids. + public override IReadOnlyList EncodeToIds(ReadOnlySpan text, bool considerPreTokenization = true, bool considerNormalization = true) + => EncodeToIds(null, text, considerPreTokenization, considerNormalization, out _, out _); + + /// + /// Encodes input text to token Ids up to maximum number of tokens. + /// + /// The text to encode. + /// The maximum number of tokens to encode. + /// If the tokenizer's normalization is enabled or is false, this will be set to in its normalized form; otherwise, this value will be set to null. + /// The length of the text that encompasses the maximum encoded tokens. + /// Indicate whether to consider pre-tokenization before tokenization. + /// Indicate whether to consider normalization before tokenization. + /// The list of encoded Ids. + public override IReadOnlyList EncodeToIds(string text, int maxTokenCount, out string? normalizedString, out int textLength, bool considerPreTokenization = true, bool considerNormalization = true) + => EncodeToIds(text, Span.Empty, considerPreTokenization, considerNormalization, out normalizedString, out textLength, maxTokenCount); + + /// + /// Encodes input text to token Ids up to maximum number of tokens. + /// + /// The text to encode. + /// The maximum number of tokens to encode. + /// If the tokenizer's normalization is enabled or is false, this will be set to in its normalized form; otherwise, this value will be set to null. + /// The length of the text that encompasses the maximum encoded tokens. + /// Indicate whether to consider pre-tokenization before tokenization. + /// Indicate whether to consider normalization before tokenization. + /// The list of encoded Ids. + public override IReadOnlyList EncodeToIds(ReadOnlySpan text, int maxTokenCount, out string? normalizedString, out int textLength, bool considerPreTokenization = true, bool considerNormalization = true) + => EncodeToIds(null, text, considerPreTokenization, considerNormalization, out normalizedString, out textLength, maxTokenCount); + + private IReadOnlyList EncodeToIds(string? text, ReadOnlySpan textSpan, bool considerPreTokenization, bool considerNormalization, out string? normalizedString, out int textLength, int maxTokenCount = int.MaxValue) + { + if (maxTokenCount <= 0) + { + throw new ArgumentOutOfRangeException(nameof(maxTokenCount), "The maximum number of tokens must be greater than zero."); + } + + if (string.IsNullOrEmpty(text) && textSpan.IsEmpty) + { + textLength = 0; + normalizedString = null; + return []; + } + + IEnumerable<(int Offset, int Length)>? splits = InitializeForEncoding(text, textSpan, considerPreTokenization, considerNormalization, _normalizer, _preTokenizer, out normalizedString, out ReadOnlySpan textSpanToEncode); + + List ids = new(); + + if (splits is not null) + { + textLength = 0; + foreach ((int Offset, int Length) split in splits) + { + EncodeToIds(textSpanToEncode.Slice(split.Offset, split.Length), ids, out int length, maxTokenCount - ids.Count); + textLength = split.Offset + length; + + if (length < split.Length || ids.Count >= maxTokenCount) + { + break; + } + } + } + else + { + EncodeToIds(textSpanToEncode, ids, out textLength); + } + + return ids; } /// @@ -318,11 +423,11 @@ public override IReadOnlyList Encode(ReadOnlySpan text) /// The text to encode. /// The list of accumulated Ids. /// The length of the text that encompasses the maximum encoded tokens. - /// The maximum number of tokens to encode. + /// The maximum number of tokens to encode. /// The number of tokens that the input text will be encoded to. - public override int EncodeToIds(ReadOnlySpan text, IList accumulatedIds, out int textLength, int maxTokens = int.MaxValue) + private int EncodeToIds(ReadOnlySpan text, IList accumulatedIds, out int textLength, int maxTokenCount = int.MaxValue) { - Debug.Assert(maxTokens > 0); + Debug.Assert(maxTokenCount > 0); if (text.IsEmpty) { @@ -332,7 +437,7 @@ public override int EncodeToIds(ReadOnlySpan text, IList accumulatedI if (_cache.TryGetValue(text, out (int[] Ids, string Token) value)) { - if (value.Ids.Length <= maxTokens) + if (value.Ids.Length <= maxTokenCount) { accumulatedIds.AddRange(value.Ids); textLength = text.Length; @@ -362,7 +467,7 @@ public override int EncodeToIds(ReadOnlySpan text, IList accumulatedI } int result; - if (encodedIds.Length <= maxTokens) + if (encodedIds.Length <= maxTokenCount) { accumulatedIds.AddRange(encodedIds); textLength = text.Length; @@ -378,6 +483,104 @@ public override int EncodeToIds(ReadOnlySpan text, IList accumulatedI return result; } + /// + /// Get the number of tokens that the input text will be encoded to. + /// + /// The text to encode. + /// Indicate whether to consider pre-tokenization before tokenization. + /// Indicate whether to consider normalization before tokenization. + /// The number of token Ids that the input text will be encoded to. + public override int CountTokens(string text, bool considerPreTokenization = true, bool considerNormalization = true) + => CountTokens(text, Span.Empty, considerPreTokenization, considerNormalization, out _, out _); + + /// + /// Get the number of tokens that the input text will be encoded to. + /// + /// The text to encode. + /// Indicate whether to consider pre-tokenization before tokenization. + /// Indicate whether to consider normalization before tokenization. + /// The number of token Ids that the input text will be encoded to. + public override int CountTokens(ReadOnlySpan text, bool considerPreTokenization = true, bool considerNormalization = true) + => CountTokens(null, text, considerPreTokenization, considerNormalization, out _, out _); + + /// + /// Find the index of the maximum encoding capacity from the start within the text without surpassing the token limit. + /// + /// The text to encode. + /// The maximum token count to limit the encoding capacity. + /// If the tokenizer's normalization is enabled or is false, this will be set to in its normalized form; otherwise, this value will be set to null. + /// The token count can be generated which should be smaller than the maximum token count. + /// Indicate whether to consider pre-tokenization before tokenization. + /// Indicate whether to consider normalization before tokenization. + /// + /// The index of the maximum encoding capacity within the processed text without surpassing the token limit. + /// It represents the index immediately following the last character to be included. In cases where no tokens fit, the result will be 0; conversely, + /// if all tokens fit, the result will be length of the text or the if the normalization is enabled. + /// + public override int IndexOfTokenCount(string text, int maxTokenCount, out string? normalizedString, out int tokenCount, bool considerPreTokenization = true, bool considerNormalization = true) + { + tokenCount = CountTokens(text, Span.Empty, considerPreTokenization, considerNormalization, out normalizedString, out int textLength, maxTokenCount); + return textLength; + } + + /// + /// Find the index of the maximum encoding capacity from the start within the text without surpassing the token limit. + /// + /// The text to encode. + /// The maximum token count to limit the encoding capacity. + /// If the tokenizer's normalization is enabled or is false, this will be set to in its normalized form; otherwise, this value will be set to null. + /// The token count can be generated which should be smaller than the maximum token count. + /// Indicate whether to consider pre-tokenization before tokenization. + /// Indicate whether to consider normalization before tokenization. + /// + /// The index of the maximum encoding capacity within the processed text without surpassing the token limit. + /// It represents the index immediately following the last character to be included. In cases where no tokens fit, the result will be 0; conversely, + /// if all tokens fit, the result will be length of the text or the if the normalization is enabled. + /// + public override int IndexOfTokenCount(ReadOnlySpan text, int maxTokenCount, out string? normalizedString, out int tokenCount, bool considerPreTokenization = true, bool considerNormalization = true) + { + tokenCount = CountTokens(null, text, considerPreTokenization, considerNormalization, out normalizedString, out int textLength, maxTokenCount); + return textLength; + } + + private int CountTokens(string? text, ReadOnlySpan textSpan, bool considerPreTokenization, bool considerNormalization, out string? normalizedString, out int textLength, int maxTokenCount = int.MaxValue) + { + if (maxTokenCount <= 0) + { + throw new ArgumentOutOfRangeException(nameof(maxTokenCount), "The maximum number of tokens must be greater than zero."); + } + + textLength = 0; + if (string.IsNullOrEmpty(text) && textSpan.IsEmpty) + { + normalizedString = null; + return 0; + } + + IEnumerable<(int Offset, int Length)>? splits = InitializeForEncoding(text, textSpan, considerPreTokenization, considerNormalization, _normalizer, _preTokenizer, out normalizedString, out ReadOnlySpan textSpanToEncode); + + int count = 0; + if (splits is not null) + { + foreach ((int Offset, int Length) split in splits) + { + count += CountTokens(textSpanToEncode.Slice(split.Offset, split.Length), out int length, maxTokenCount - count); + textLength = split.Offset + length; + + if (length < split.Length || count >= maxTokenCount) + { + break; + } + } + } + else + { + count = CountTokens(textSpanToEncode, out textLength, maxTokenCount); + } + + return count; + } + /// /// Get the number of tokens that the input text will be encoded to. /// @@ -385,7 +588,7 @@ public override int EncodeToIds(ReadOnlySpan text, IList accumulatedI /// The length of the text that encompasses the maximum encoded tokens. /// The maximum number of tokens to encode. /// The number of tokens that the input text will be encoded to. - public override int CountTokens(ReadOnlySpan text, out int textLength, int maxTokens = int.MaxValue) + private int CountTokens(ReadOnlySpan text, out int textLength, int maxTokens = int.MaxValue) { Debug.Assert(maxTokens > 0); @@ -439,6 +642,76 @@ public override int CountTokens(ReadOnlySpan text, out int textLength, int return result; } + /// + /// Find the index of the maximum encoding capacity from the end within the text without surpassing the token limit. + /// + /// The text to encode. + /// The maximum token count to limit the encoding capacity. + /// If the tokenizer's normalization is enabled or is false, this will be set to in its normalized form; otherwise, this value will be set to null. + /// The token count can be generated which should be smaller than the maximum token count. + /// Indicate whether to consider pre-tokenization before tokenization. + /// Indicate whether to consider normalization before tokenization. + /// + /// The start index of the maximum encoding capacity within the processed text without surpassing the token limit. + /// It represents the index at the first character to be included. In cases where no tokens fit, the result will be length of the text or the if normalization is enabled; + /// conversely, if all tokens fit, the result will be 0. + /// + public override int LastIndexOfTokenCount(string text, int maxTokenCount, out string? normalizedString, out int tokenCount, bool considerPreTokenization = true, bool considerNormalization = true) + => LastIndexOf(text, Span.Empty, maxTokenCount, considerPreTokenization, considerNormalization, out normalizedString, out tokenCount); + + /// + /// Find the index of the maximum encoding capacity from the end within the text without surpassing the token limit. + /// + /// The text to encode. + /// The maximum token count to limit the encoding capacity. + /// If the tokenizer's normalization is enabled or is false, this will be set to in its normalized form; otherwise, this value will be set to null. + /// The token count can be generated which should be smaller than the maximum token count. + /// Indicate whether to consider pre-tokenization before tokenization. + /// Indicate whether to consider normalization before tokenization. + /// + /// The start index of the maximum encoding capacity within the processed text without surpassing the token limit. + /// It represents the index at the first character to be included. In cases where no tokens fit, the result will be length of the ; conversely, if all tokens fit, the result will be 0. + /// + public override int LastIndexOfTokenCount(ReadOnlySpan text, int maxTokenCount, out string? normalizedString, out int tokenCount, bool considerPreTokenization = true, bool considerNormalization = true) + => LastIndexOf(null, text, maxTokenCount, considerPreTokenization, considerNormalization, out normalizedString, out tokenCount); + + private int LastIndexOf(string? text, ReadOnlySpan textSpan, int maxTokenCount, bool considerPreTokenization, bool considerNormalization, out string? normalizedString, out int tokenCount) + { + if (maxTokenCount <= 0) + { + throw new ArgumentOutOfRangeException(nameof(maxTokenCount), "The max token count must be greater than 0."); + } + + if (string.IsNullOrEmpty(text) && textSpan.IsEmpty) + { + normalizedString = null; + tokenCount = 0; + return 0; + } + + IEnumerable<(int Offset, int Length)>? splits = InitializeForEncoding(text, textSpan, considerPreTokenization, considerNormalization, _normalizer, _preTokenizer, out normalizedString, out ReadOnlySpan textSpanToEncode); + + if (splits is not null) + { + tokenCount = 0; + foreach ((int Offset, int Length) split in splits.Reverse()) + { + tokenCount += CountTokensFromEnd(textSpanToEncode.Slice(split.Offset, split.Length), out int textIndex, maxTokenCount - tokenCount); + if (textIndex > 0 || tokenCount >= maxTokenCount) + { + return split.Offset + textIndex; + } + } + + return 0; + } + else + { + tokenCount = CountTokensFromEnd(textSpanToEncode, out int textLength, maxTokenCount); + return textLength; + } + } + /// /// Get the number of tokens that the input text will be encoded to. /// @@ -446,7 +719,7 @@ public override int CountTokens(ReadOnlySpan text, out int textLength, int /// Starting from this index to the end of the text will encompasses the maximum encoded tokens. /// The maximum number of tokens to encode. /// The number of tokens that the input text will be encoded to. - public override int CountTokensFromEnd(ReadOnlySpan text, out int textIndex, int maxTokens = int.MaxValue) + private int CountTokensFromEnd(ReadOnlySpan text, out int textIndex, int maxTokens = int.MaxValue) { Debug.Assert(maxTokens > 0); @@ -510,7 +783,7 @@ public override int CountTokensFromEnd(ReadOnlySpan text, out int textInde { if (token.IsEmpty) { - return 0; + return null; } if (_cache.TryGetValue(token, out (int[] Ids, string Token) value)) @@ -676,8 +949,8 @@ private static readonly (string Prefix, ModelEncoding Encoding)[] _modelPrefixTo [ // chat ( "gpt-4-", ModelEncoding.Cl100kBase), // e.g., gpt-4-0314, etc., plus gpt-4-32k - ( "gpt-3.5-turbo-", ModelEncoding.Cl100kBase), // e.g, gpt-3.5-turbo-0301, -0401, etc. - ( "gpt-35-turbo-", ModelEncoding.Cl100kBase ) // Azure deployment name + ( "gpt-3.5-", ModelEncoding.Cl100kBase), // e.g, gpt-3.5-turbo-0301, -0401, etc. + ( "gpt-35-", ModelEncoding.Cl100kBase ) // Azure deployment name ]; private static readonly Dictionary _modelToEncoding = @@ -687,6 +960,7 @@ private static readonly (string Prefix, ModelEncoding Encoding)[] _modelPrefixTo { "gpt-4", ModelEncoding.Cl100kBase }, { "gpt-3.5-turbo", ModelEncoding.Cl100kBase }, { "gpt-3.5-turbo-16k", ModelEncoding.Cl100kBase }, + { "gpt-35", ModelEncoding.Cl100kBase }, // Azure deployment name { "gpt-35-turbo", ModelEncoding.Cl100kBase }, // Azure deployment name { "gpt-35-turbo-16k", ModelEncoding.Cl100kBase }, // Azure deployment name @@ -817,26 +1091,13 @@ internal static (Dictionary SpecialTokens, Regex Regex, string Voca private static readonly ConcurrentDictionary, int> encoder, Dictionary vocab, Dictionary> decoder)> _tiktokenCache = new(StringComparer.OrdinalIgnoreCase); - internal static Tokenizer CreateTokenizerForModel( - string modelName, - IReadOnlyDictionary? extraSpecialTokens = null, - Normalizer? normalizer = null) - { - if (string.IsNullOrEmpty(modelName)) - { - throw new ArgumentNullException(nameof(modelName)); - } - - return CreateTokenizerForModel(GetModelEncoding(modelName), modelName, extraSpecialTokens, normalizer); - } - - internal static Tokenizer CreateTokenizerForModel( - ModelEncoding modelEncoding, - string? modelName = null, - IReadOnlyDictionary? extraSpecialTokens = null, - Normalizer? normalizer = null) + internal static Tokenizer CreateForModel( + ModelEncoding modelEncoding, + string? modelName = null, + IReadOnlyDictionary? extraSpecialTokens = null, + Normalizer? normalizer = null) { - (Dictionary SpecialTokens, Regex Regex, string VocabFile) tiktokenConfiguration = Tiktoken.GetTiktokenConfigurations(modelEncoding, modelName); + (Dictionary SpecialTokens, Regex Regex, string VocabFile) tiktokenConfiguration = GetTiktokenConfigurations(modelEncoding, modelName); if (extraSpecialTokens is not null) { @@ -858,10 +1119,14 @@ internal static Tokenizer CreateTokenizerForModel( _tiktokenCache.TryAdd(tiktokenConfiguration.VocabFile, cache); } - return new Tokenizer( - new Tiktoken(cache.encoder, cache.decoder, cache.vocab, tiktokenConfiguration.SpecialTokens, LruCache.DefaultCacheSize), - new TiktokenPreTokenizer(tiktokenConfiguration.Regex, tiktokenConfiguration.SpecialTokens), - normalizer); + return new Tiktoken( + cache.encoder, + cache.decoder, + cache.vocab, + new TiktokenPreTokenizer(tiktokenConfiguration.Regex, tiktokenConfiguration.SpecialTokens), + tiktokenConfiguration.SpecialTokens, + normalizer, + LruCache.DefaultCacheSize); } } } diff --git a/src/Microsoft.ML.Tokenizers/Model/Word.cs b/src/Microsoft.ML.Tokenizers/Model/Word.cs index c1bfe2e4f2..555942825d 100644 --- a/src/Microsoft.ML.Tokenizers/Model/Word.cs +++ b/src/Microsoft.ML.Tokenizers/Model/Word.cs @@ -97,24 +97,24 @@ public void Add(int c, int charLength) return changes; } - public void MergeAll(Dictionary, (int, int)> merges, float? dropout) + public void MergeAll(Dictionary, (int, int)> merges, float? dropout, ref PriorityQueue? priorityQueue) { - // Queue queue = new Queue(_symbols.Count); - PriorityQueue queue = new PriorityQueue(_symbols.Count); + priorityQueue ??= new PriorityQueue(_symbols.Count); + priorityQueue.Clear(); - Vec skip = new Vec(queue.Count); + Vec skip = new Vec(priorityQueue.Count); for (int i = 0; i < _symbols.Count - 1; i++) { if (merges.TryGetValue(Pair.Create(_symbols[i].C, _symbols[i + 1].C), out (int m1, int m2) value)) { - queue.Enqueue(new Merge(i, value.m1, value.m2)); + priorityQueue.Enqueue(new Merge(i, value.m1, value.m2)); } } - while (queue.Count > 0) + while (priorityQueue.Count > 0) { - Merge top = queue.Dequeue(); + Merge top = priorityQueue.Dequeue(); if (dropout.HasValue && (_random ??= new()).NextDouble() < dropout) { skip.Push(top); @@ -124,7 +124,7 @@ public void MergeAll(Dictionary, (int, int)> merges, float? dropout) // Re-insert the skipped elements for (int i = 0; i < skip.Count; i++) { - queue.Enqueue(skip[i]); + priorityQueue.Enqueue(skip[i]); } skip.Clear(); @@ -166,7 +166,7 @@ public void MergeAll(Dictionary, (int, int)> merges, float? dropout) if (merges.TryGetValue(newPair, out value)) { - queue.Enqueue(new Merge(current.Prev, value.m1, value.m2)); + priorityQueue.Enqueue(new Merge(current.Prev, value.m1, value.m2)); } } @@ -178,7 +178,7 @@ public void MergeAll(Dictionary, (int, int)> merges, float? dropout) Pair newPair = Pair.Create(current.C, nextSymbol.C); if (merges.TryGetValue(newPair, out value)) { - queue.Enqueue(new Merge(top.Pos, value.m1, value.m2)); + priorityQueue.Enqueue(new Merge(top.Pos, value.m1, value.m2)); } } } @@ -289,19 +289,16 @@ public override string ToString() return sb.ToString(); } - public List ToTokens(SortedDictionary vocabReverse) + public void ToTokens(SortedDictionary vocabReverse, List tokens, int offset) { - List tokens = new(SymbolsCount); int index = 0; for (int i = 0; i < SymbolsCount; i++) { int endIndex = index + _symbols[i].Len; - tokens.Add(new Token(_symbols[i].C, vocabReverse[_symbols[i].C], (index, _symbols[i].Len))); + tokens.Add(new Token(_symbols[i].C, vocabReverse[_symbols[i].C], (index + offset, _symbols[i].Len))); index += _symbols[i].Len; } - - return tokens; } } } diff --git a/src/Microsoft.ML.Tokenizers/Normalizer/LowerCaseNormalizer.cs b/src/Microsoft.ML.Tokenizers/Normalizer/LowerCaseNormalizer.cs index 8d4947d367..d793698099 100644 --- a/src/Microsoft.ML.Tokenizers/Normalizer/LowerCaseNormalizer.cs +++ b/src/Microsoft.ML.Tokenizers/Normalizer/LowerCaseNormalizer.cs @@ -3,6 +3,8 @@ // See the LICENSE file in the project root for more information. using System; +using System.Buffers; +using System.Diagnostics; namespace Microsoft.ML.Tokenizers { @@ -22,5 +24,27 @@ public LowerCaseNormalizer() { } /// The original string to normalize to lowercase form. /// The lower-cased normalized string. public override string Normalize(string original) => original.ToLowerInvariant(); + + /// + /// Lowercase the original string. + /// + /// The original string to normalize to lowercase form. + /// The lower-cased normalized string. + public override string Normalize(ReadOnlySpan original) + { + if (original.IsEmpty) + { + return string.Empty; + } + + char[] arrayPoolArray = ArrayPool.Shared.Rent(original.Length); + + int length = original.ToLowerInvariant(arrayPoolArray); + Debug.Assert(length == original.Length); + + string result = new string(arrayPoolArray, 0, length); + ArrayPool.Shared.Return(arrayPoolArray); + return result; + } } } diff --git a/src/Microsoft.ML.Tokenizers/Normalizer/Normalizer.cs b/src/Microsoft.ML.Tokenizers/Normalizer/Normalizer.cs index c662b43032..dedabda64c 100644 --- a/src/Microsoft.ML.Tokenizers/Normalizer/Normalizer.cs +++ b/src/Microsoft.ML.Tokenizers/Normalizer/Normalizer.cs @@ -17,5 +17,12 @@ public abstract class Normalizer /// The original string to normalize. /// The normalized string. public abstract string Normalize(string original); + + /// + /// Process the original string to modify it and obtain a normalized string. + /// + /// The original string to normalize. + /// The normalized string. + public abstract string Normalize(ReadOnlySpan original); } } diff --git a/src/Microsoft.ML.Tokenizers/Normalizer/LlamaNormalizer.cs b/src/Microsoft.ML.Tokenizers/Normalizer/SentencePieceNormalizer.cs similarity index 86% rename from src/Microsoft.ML.Tokenizers/Normalizer/LlamaNormalizer.cs rename to src/Microsoft.ML.Tokenizers/Normalizer/SentencePieceNormalizer.cs index 72ec6cb547..84ac3aaad9 100644 --- a/src/Microsoft.ML.Tokenizers/Normalizer/LlamaNormalizer.cs +++ b/src/Microsoft.ML.Tokenizers/Normalizer/SentencePieceNormalizer.cs @@ -10,14 +10,14 @@ namespace Microsoft.ML.Tokenizers /// /// Normalize the string to lowercase form before processing it with the tokenizer. /// - public sealed class LlamaNormalizer : Normalizer + public sealed class SentencePieceNormalizer : Normalizer { internal const char DummyPrefix = '\u2581'; // '▁' (LOWER ONE EIGHT BLOCK) /// /// Creates a LowerCaseNormalizer object. /// - public LlamaNormalizer(bool removeExtraWhiteSpaces, bool addDummyPrefix, bool escapeWhiteSpaces, bool treatWhitespaceAsSuffix) + public SentencePieceNormalizer(bool removeExtraWhiteSpaces, bool addDummyPrefix, bool escapeWhiteSpaces, bool treatWhitespaceAsSuffix) { RemoveExtraWhiteSpaces = removeExtraWhiteSpaces; AddDummyPrefix = addDummyPrefix; @@ -40,7 +40,7 @@ public LlamaNormalizer(bool removeExtraWhiteSpaces, bool addDummyPrefix, bool es public bool TreatWhitespaceAsSuffix { get; } /// - /// Normalize the original string according to SentencePiece normalization with Llama model. + /// Normalize the original string according to SentencePiece normalization. /// /// The original string to normalize. /// The normalized string. @@ -51,6 +51,16 @@ public override string Normalize(string original) return string.Empty; } + return Normalize(original.AsSpan()); + } + + /// + /// Normalize the original string according to SentencePiece normalization. + /// + /// The original string to normalize. + /// The normalized string. + public override string Normalize(ReadOnlySpan original) + { int startIndex = 0; int endIndex = original.Length - 1; diff --git a/src/Microsoft.ML.Tokenizers/Normalizer/UpperCaseNormalizer.cs b/src/Microsoft.ML.Tokenizers/Normalizer/UpperCaseNormalizer.cs index 25a69b2813..d6d40f7065 100644 --- a/src/Microsoft.ML.Tokenizers/Normalizer/UpperCaseNormalizer.cs +++ b/src/Microsoft.ML.Tokenizers/Normalizer/UpperCaseNormalizer.cs @@ -3,6 +3,8 @@ // See the LICENSE file in the project root for more information. using System; +using System.Buffers; +using System.Diagnostics; namespace Microsoft.ML.Tokenizers { @@ -22,5 +24,27 @@ public UpperCaseNormalizer() { } /// The original string to normalize to uppercase form. /// The upper-cased normalized string. public override string Normalize(string original) => original.ToUpperInvariant(); + + /// + /// Uppercase the original string. + /// + /// The original string to normalize to uppercase form. + /// The upper-cased normalized string. + public override string Normalize(ReadOnlySpan original) + { + if (original.IsEmpty) + { + return string.Empty; + } + + char[] arrayPoolArray = ArrayPool.Shared.Rent(original.Length); + + int length = original.ToUpperInvariant(arrayPoolArray); + Debug.Assert(length == original.Length); + + string result = new string(arrayPoolArray, 0, length); + ArrayPool.Shared.Return(arrayPoolArray); + return result; + } } } diff --git a/src/Microsoft.ML.Tokenizers/PreTokenizer/PreTokenizer.cs b/src/Microsoft.ML.Tokenizers/PreTokenizer/PreTokenizer.cs index c38109b3d6..e98a8c32eb 100644 --- a/src/Microsoft.ML.Tokenizers/PreTokenizer/PreTokenizer.cs +++ b/src/Microsoft.ML.Tokenizers/PreTokenizer/PreTokenizer.cs @@ -3,66 +3,12 @@ // See the LICENSE file in the project root for more information. using System; +using System.Buffers; using System.Collections.Generic; using System.Text.RegularExpressions; namespace Microsoft.ML.Tokenizers { - /// - /// This Split contains the underlying split token as well as its offsets - /// in the original string. These offsets are in the `original` referential. - /// It also contains any `Token` associated to the current split. - /// - public struct Split : IEquatable - { - private readonly string? _originalString; - private string? _tokenString; - - /// - /// Gets the underlying split token. Each SubString is represented by a token - /// and in the end we might be carrying a lot of SubString representing various parts of the - /// original input string. - /// - public string TokenString => _tokenString ??= _originalString!.Substring(Offset.Index, Offset.Length); - - /// - /// Gets the underlying split token as a span. - /// - public ReadOnlySpan TokenSpan => _tokenString is string s ? s.AsSpan() : _originalString.AsSpan(Offset.Index, Offset.Length); - - /// - /// Returns the offset mapping to the original string - /// - public (int Index, int Length) Offset { get; } - - /// - /// create a Split object using the token and the offset - /// - /// The token string - /// The offset mapping to the original string - public Split(string token, (int Index, int Length) offset) - { - _tokenString = token; - Offset = offset; - } - - internal Split(string originalString, string? token, (int Index, int Length) offset) - { - _originalString = originalString; - _tokenString = token; - Offset = offset; - } - - /// - /// Indicates whether the current Split object is equal to another Split object. - /// - /// The Split object to compare with the current object. - public bool Equals(Split other) => - (_originalString == other._originalString || TokenString == other.TokenString) && - Offset.Index == other.Offset.Index && - Offset.Length == other.Offset.Length; - } - /// /// Base class for all pre-tokenizers classes. /// The PreTokenizer is in charge of doing the pre-segmentation step. @@ -70,23 +16,54 @@ public bool Equals(Split other) => public abstract class PreTokenizer { /// - /// Splits the given string in multiple substrings at the word boundary, keeping track of the offsets of said substrings from the original string. + /// Get the offsets and lengths of the tokens relative to the . /// /// The string to split into tokens. - /// The list of the splits containing the tokens and the token's offsets to the original string. - public abstract IEnumerable PreTokenize(string text); + /// The offsets and lengths of the tokens, expressed as pairs, are relative to the original string. + public abstract IEnumerable<(int Offset, int Length)> PreTokenize(string text); + + /// + /// Get the offsets and lengths of the tokens relative to the original string. + /// + /// The character span to split into tokens. + /// The offsets and lengths of the tokens, expressed as pairs, are relative to the original string. + public abstract IEnumerable<(int Offset, int Length)> PreTokenize(ReadOnlySpan text); - internal static IEnumerable SplitText(string text, Regex regex) + internal static IEnumerable<(int Offset, int Length)> SplitText(string text, Regex regex) { (int Offset, int Length) match; int beginning = 0; while (TryGetMatch(regex, text, beginning, text.Length - beginning, out match)) { - yield return new Split(text, null, (match.Offset, match.Length)); + yield return (match.Offset, match.Length); beginning = match.Offset + match.Length; } } + internal static IEnumerable<(int Offset, int Length)> SplitText(ReadOnlySpan text, Regex regex) + { +#if NET7_0_OR_GREATER + char[] buffer = ArrayPool.Shared.Rent(text.Length); + text.CopyTo(buffer); + return SplitText(buffer, regex, text.Length); + + static IEnumerable<(int Offset, int Length)> SplitText(char[] text, Regex regex, int textLength) + { + (int Offset, int Length) match; + int beginning = 0; + while (TryGetMatch(regex, text, beginning, textLength - beginning, out match)) + { + yield return (match.Offset, match.Length); + beginning = match.Offset + match.Length; + } + + ArrayPool.Shared.Return(text); + } +#else + return SplitText(text.ToString(), regex); +#endif // NET7_0_OR_GREATER + } + internal static bool TryGetMatch(Regex regex, string text, int beginning, int length, out (int offset, int length) match) { #if NET7_0_OR_GREATER @@ -106,5 +83,18 @@ internal static bool TryGetMatch(Regex regex, string text, int beginning, int le match = default; return false; } + +#if NET7_0_OR_GREATER + internal static bool TryGetMatch(Regex regex, scoped ReadOnlySpan text, int beginning, int length, out (int offset, int length) match) + { + foreach (ValueMatch m in regex.EnumerateMatches(text.Slice(beginning, length))) + { + match = (beginning + m.Index, m.Length); + return true; + } + match = default; + return false; + } +#endif // NET7_0_OR_GREATER } } diff --git a/src/Microsoft.ML.Tokenizers/PreTokenizer/Roberta.cs b/src/Microsoft.ML.Tokenizers/PreTokenizer/Roberta.cs index 154df6fa4c..e871f6e114 100644 --- a/src/Microsoft.ML.Tokenizers/PreTokenizer/Roberta.cs +++ b/src/Microsoft.ML.Tokenizers/PreTokenizer/Roberta.cs @@ -18,15 +18,30 @@ public sealed partial class RobertaPreTokenizer : PreTokenizer public static RobertaPreTokenizer Instance { get; } = new RobertaPreTokenizer(); /// - /// Splits the given string in multiple substrings at the word boundary, keeping track of the offsets of said substrings from the original string. + /// Get the offsets and lengths of the tokens relative to the . /// /// The string to split into tokens. - /// The list of the splits containing the tokens and the token's offsets to the original string. - public override IEnumerable PreTokenize(string text) + /// The offsets and lengths of the tokens, expressed as pairs, are relative to the original string. + public override IEnumerable<(int Offset, int Length)> PreTokenize(string text) { if (string.IsNullOrEmpty(text)) { - return Array.Empty(); + return []; + } + + return SplitText(text, Tiktoken.P50kBaseRegex()); + } + + /// + /// Get the offsets and lengths of the tokens relative to the . + /// + /// The string to split into tokens. + /// The offsets and lengths of the tokens, expressed as pairs, are relative to the original string. + public override IEnumerable<(int Offset, int Length)> PreTokenize(ReadOnlySpan text) + { + if (text.IsEmpty) + { + return []; } return SplitText(text, Tiktoken.P50kBaseRegex()); diff --git a/src/Microsoft.ML.Tokenizers/PreTokenizer/SentencePiecePreTokenizer.cs b/src/Microsoft.ML.Tokenizers/PreTokenizer/SentencePiecePreTokenizer.cs deleted file mode 100644 index e30322630c..0000000000 --- a/src/Microsoft.ML.Tokenizers/PreTokenizer/SentencePiecePreTokenizer.cs +++ /dev/null @@ -1,35 +0,0 @@ -// Licensed to the .NET Foundation under one or more agreements. -// The .NET Foundation licenses this file to you under the MIT license. -// See the LICENSE file in the project root for more information. - -using System; -using System.Collections.Generic; - -namespace Microsoft.ML.Tokenizers -{ - /// - /// The pre-tokenizer for SentencePiece tokenizers. - /// - internal sealed partial class SentencePiecePreTokenizer : PreTokenizer - { - /// - /// Gets a singleton instance of the Roberta pre-tokenizer.. - /// - public static SentencePiecePreTokenizer Instance { get; } = new SentencePiecePreTokenizer(); - - /// - /// Return the whole text as one chunk. - /// - /// The string to split into tokens. - /// The original string as one chunk. - public override IEnumerable PreTokenize(string text) - { - if (string.IsNullOrEmpty(text)) - { - yield break; - } - - yield return new Split(text, (0, text.Length)); - } - } -} diff --git a/src/Microsoft.ML.Tokenizers/PreTokenizer/TiktokenPreTokenizer.cs b/src/Microsoft.ML.Tokenizers/PreTokenizer/TiktokenPreTokenizer.cs index fd6cd50c8b..4050f75d07 100644 --- a/src/Microsoft.ML.Tokenizers/PreTokenizer/TiktokenPreTokenizer.cs +++ b/src/Microsoft.ML.Tokenizers/PreTokenizer/TiktokenPreTokenizer.cs @@ -3,6 +3,7 @@ // See the LICENSE file in the project root for more information. using System; +using System.Buffers; using System.Collections.Generic; using System.Linq; using System.Text.RegularExpressions; @@ -39,20 +40,20 @@ public TiktokenPreTokenizer(Regex regex, IReadOnlyDictionary? speci } /// - /// Splits the given string in multiple substrings at the word boundary, keeping track of the offsets of said substrings from the original string. + /// Get the offsets and lengths of the tokens relative to the . /// /// The string to split into tokens. - /// The list of the splits containing the tokens and the token's offsets to the original string. - public override IEnumerable PreTokenize(string text) + /// The offsets and lengths of the tokens, expressed as pairs, are relative to the original string. + public override IEnumerable<(int Offset, int Length)> PreTokenize(string text) { if (string.IsNullOrEmpty(text)) { - return Array.Empty(); + return []; } return SplitText(text, _regex, _specialTokensRegex); - static IEnumerable SplitText(string text, Regex regex, Regex? specialTokensRegex) + static IEnumerable<(int Offset, int Length)> SplitText(string text, Regex regex, Regex? specialTokensRegex) { (int Offset, int Length) match; int beginning = 0; @@ -69,21 +70,77 @@ static IEnumerable SplitText(string text, Regex regex, Regex? specialToke while (TryGetMatch(regex, text, beginning, specialMatch.Offset - beginning, out match)) { - yield return new Split(text, null, (match.Offset, match.Length)); + yield return (match.Offset, match.Length); beginning = match.Offset + match.Length; } - yield return new Split(text, null, (specialMatch.Offset, specialMatch.Length)); + yield return (specialMatch.Offset, specialMatch.Length); beginning = specialMatch.Offset + specialMatch.Length; } } while (TryGetMatch(regex, text, beginning, text.Length - beginning, out match)) { - yield return new Split(text, null, (match.Offset, match.Length)); + yield return (match.Offset, match.Length); beginning = match.Length + match.Offset; } } } + + /// + /// Get the offsets and lengths of the tokens relative to the . + /// + /// The string to split into tokens. + /// The offsets and lengths of the tokens, expressed as pairs, are relative to the original string. + public override IEnumerable<(int Offset, int Length)> PreTokenize(ReadOnlySpan text) + { + if (text.IsEmpty) + { + return []; + } + +#if NET7_0_OR_GREATER + char[] buffer = ArrayPool.Shared.Rent(text.Length); + text.CopyTo(buffer); + return SplitText(buffer, _regex, _specialTokensRegex, text.Length); + + static IEnumerable<(int Offset, int Length)> SplitText(char[] text, Regex regex, Regex? specialTokensRegex, int textLength) + { + (int Offset, int Length) match; + int beginning = 0; + + if (specialTokensRegex is not null) + { + while (true) + { + (int Offset, int Length) specialMatch; + if (!TryGetMatch(specialTokensRegex, text.AsSpan(), beginning, textLength - beginning, out specialMatch)) + { + break; + } + + while (TryGetMatch(regex, text.AsSpan(), beginning, specialMatch.Offset - beginning, out match)) + { + yield return (match.Offset, match.Length); + beginning = match.Offset + match.Length; + } + + yield return (specialMatch.Offset, specialMatch.Length); + beginning = specialMatch.Offset + specialMatch.Length; + } + } + + while (TryGetMatch(regex, text.AsSpan(), beginning, textLength - beginning, out match)) + { + yield return (match.Offset, match.Length); + beginning = match.Length + match.Offset; + } + + ArrayPool.Shared.Return(text); + } +#else + return PreTokenize(text.ToString()); +#endif // NET7_0_OR_GREATER + } } } diff --git a/src/Microsoft.ML.Tokenizers/PreTokenizer/Whitespace.cs b/src/Microsoft.ML.Tokenizers/PreTokenizer/Whitespace.cs index f0a7fa812f..e9bd8cf0e9 100644 --- a/src/Microsoft.ML.Tokenizers/PreTokenizer/Whitespace.cs +++ b/src/Microsoft.ML.Tokenizers/PreTokenizer/Whitespace.cs @@ -29,15 +29,30 @@ public sealed partial class WhiteSpace : PreTokenizer #endif /// - /// Splits the given string in multiple substrings at the word boundary, keeping track of the offsets of said substrings from the original string. + /// Get the offsets and lengths of the tokens relative to the . /// /// The string to split into tokens. - /// The list of the splits containing the tokens and the token's offsets to the original string. - public override IEnumerable PreTokenize(string text) + /// The offsets and lengths of the tokens, expressed as pairs, are relative to the original string. + public override IEnumerable<(int Offset, int Length)> PreTokenize(string text) { if (string.IsNullOrEmpty(text)) { - return Array.Empty(); + return []; + } + + return SplitText(text, PretokenizeRegex()); + } + + /// + /// Get the offsets and lengths of the tokens relative to the . + /// + /// The string to split into tokens. + /// The offsets and lengths of the tokens, expressed as pairs, are relative to the original string. + public override IEnumerable<(int Offset, int Length)> PreTokenize(ReadOnlySpan text) + { + if (text.IsEmpty) + { + return []; } return SplitText(text, PretokenizeRegex()); diff --git a/src/Microsoft.ML.Tokenizers/Token.cs b/src/Microsoft.ML.Tokenizers/Token.cs index ae68e3c267..d7efd60019 100644 --- a/src/Microsoft.ML.Tokenizers/Token.cs +++ b/src/Microsoft.ML.Tokenizers/Token.cs @@ -12,7 +12,7 @@ namespace Microsoft.ML.Tokenizers /// Represent the token produced from the tokenization process containing the token substring, /// the id associated to the token substring, and the offset mapping to the original string. /// - public sealed class Token + public readonly struct Token { /// /// Gets the Id value associated to the token. @@ -22,12 +22,12 @@ public sealed class Token /// /// Gets the token string value. /// - public string Value { get; set; } + public string Value { get; } /// /// Gets the offset mapping to the original string. /// - public (int Index, int Length) Offset { get; internal set; } + public (int Index, int Length) Offset { get; } /// /// Construct a new Token object using the token value, Id, and the offset mapping to the original string. diff --git a/src/Microsoft.ML.Tokenizers/Tokenizer.cs b/src/Microsoft.ML.Tokenizers/Tokenizer.cs index af1f729c67..7329b3648c 100644 --- a/src/Microsoft.ML.Tokenizers/Tokenizer.cs +++ b/src/Microsoft.ML.Tokenizers/Tokenizer.cs @@ -8,6 +8,7 @@ using System.Diagnostics; using System.IO; using System.Linq; +using System.Text; using System.Text.RegularExpressions; using System.Threading; using System.Threading.Tasks; @@ -15,263 +16,279 @@ namespace Microsoft.ML.Tokenizers { /// - /// A Tokenizer works as a pipeline. It processes some raw text as input and outputs a EncodingResult object. + /// serves as an abstraction for concrete tokenizers, enabling the encoding of text into tokens and IDs, as well as the decoding of IDs back into text. /// - public partial class Tokenizer + public abstract class Tokenizer { /// - /// Create a new Tokenizer object. + /// Gets the PreTokenizer used by the Tokenizer. /// - /// The Model in use by the Tokenizer. - /// The optional PreTokenizer in use by the Tokenizer. WhiteSpace PreTokenizer will be used if this parameter is null. - /// The optional Normalizer in use by the Tokenizer. - public Tokenizer(Model model, PreTokenizer? preTokenizer = null, Normalizer? normalizer = null) - { - Model = model; - PreTokenizer = preTokenizer ?? WhiteSpace.Instance; - Normalizer = normalizer; - } + public virtual PreTokenizer? PreTokenizer => null; /// - /// Gets the Model in use by the Tokenizer. + /// Gets the Normalizer in use by the Tokenizer. /// - public Model Model { get; } + public virtual Normalizer? Normalizer => null; /// - /// Gets or sets the PreTokenizer used by the Tokenizer. + /// Encodes input text a list of s with string value of the token, id, and offset. /// - public PreTokenizer PreTokenizer { get; } + /// The text to encode. + /// If the tokenizer's normalization is enabled or is false, this will be set to in its normalized form; otherwise, this value will be set to null. + /// Indicate whether to consider pre-tokenization before tokenization. + /// Indicate whether to consider normalization before tokenization. + /// The tokenization result includes a list of s with string value of the token, id, and offset. + public virtual IReadOnlyList Encode(string text, out string? normalizedString, bool considerPreTokenization = true, bool considerNormalization = true) => Encode(text.AsSpan(), out normalizedString, considerPreTokenization, considerNormalization); /// - /// Gets or sets the Normalizer in use by the Tokenizer. + /// Encodes input text a list of s with string value of the token, id, and offset. /// - public Normalizer? Normalizer { get; } + /// The text to encode. + /// If the tokenizer's normalization is enabled or is false, this will be set to in its normalized form; otherwise, this value will be set to null. + /// Indicate whether to consider pre-tokenization before tokenization. + /// Indicate whether to consider normalization before tokenization. + /// The tokenization result includes a list of s with string value of the token, id, and offset. + public abstract IReadOnlyList Encode(ReadOnlySpan text, out string? normalizedString, bool considerPreTokenization = true, bool considerNormalization = true); /// - /// Encodes input text to object has the tokens list, tokens Ids, tokens offset mapping. + /// Encodes input text to token Ids. /// /// The text to encode. - /// The tokenization result includes the tokens list, tokens Ids, tokens offset mapping. - public EncodingResult Encode(string text) - { - if (text is null) - { - throw new ArgumentNullException(nameof(text)); - } - - string normalized = Normalizer is null ? text : Normalizer.Normalize(text); - bool offsetsMappedToOriginal = true; - - EncodingResult encoding = new(text, normalized, PreTokenizer.PreTokenize(normalized), offsetsMappedToOriginal); - - foreach (Split split in encoding.Splits) - { - IReadOnlyList tokens = Model.Encode(split.TokenString.AsSpan()); - foreach (Token token in tokens) - { - token.Offset = (token.Offset.Index + split.Offset.Index, token.Offset.Length); - } - - encoding.AddTokens(tokens); - } - - return encoding; - } + /// Indicate whether to consider pre-tokenization before tokenization. + /// Indicate whether to consider normalization before tokenization. + /// The list of encoded Ids. + public virtual IReadOnlyList EncodeToIds(string text, bool considerPreTokenization = true, bool considerNormalization = true) => EncodeToIds(text.AsSpan(), considerPreTokenization, considerNormalization); /// - /// Encodes input text to tokens Ids. + /// Encodes input text to token Ids. /// /// The text to encode. + /// Indicate whether to consider pre-tokenization before tokenization. + /// Indicate whether to consider normalization before tokenization. /// The list of encoded Ids. - public IReadOnlyList EncodeToIds(string text) - { - if (text is null) - { - throw new ArgumentNullException(nameof(text)); - } - - string normalized = Normalizer is not null ? Normalizer.Normalize(text) : text; - List idsList = new(); - - foreach (Split split in PreTokenizer.PreTokenize(normalized)) - { - Model.EncodeToIds(split.TokenSpan, idsList, out _); - } - - return idsList; - } + public abstract IReadOnlyList EncodeToIds(ReadOnlySpan text, bool considerPreTokenization = true, bool considerNormalization = true); /// - /// Encodes input text to tokens Ids up to maximum number of tokens. + /// Encodes input text to token Ids up to maximum number of tokens. /// /// The text to encode. /// The maximum number of tokens to encode. - /// If the tokenizer's normalization is enabled, the input text will be represented in its normalization form; otherwise, it will remain unchanged as the input text. + /// If the tokenizer's normalization is enabled or is false, this will be set to in its normalized form; otherwise, this value will be set to null. /// The length of the text that encompasses the maximum encoded tokens. + /// Indicate whether to consider pre-tokenization before tokenization. + /// Indicate whether to consider normalization before tokenization. /// The list of encoded Ids. - public IReadOnlyList EncodeToIds(string text, int maxTokenCount, out string processedText, out int textLength) - { - processedText = text; - textLength = 0; + public virtual IReadOnlyList EncodeToIds(string text, int maxTokenCount, out string? normalizedText, out int textLength, bool considerPreTokenization = true, bool considerNormalization = true) + => EncodeToIds(text.AsSpan(), maxTokenCount, out normalizedText, out textLength, considerPreTokenization, considerNormalization); - if (text is null) - { - throw new ArgumentNullException(nameof(text)); - } - - if (maxTokenCount <= 0) - { - throw new ArgumentOutOfRangeException(nameof(maxTokenCount), "The maximum number of tokens must be greater than 0."); - } - - if (Normalizer is not null) - { - processedText = Normalizer.Normalize(text); - } - - List idsList = new(); - - foreach (Split split in PreTokenizer.PreTokenize(processedText)) - { - Model.EncodeToIds(split.TokenSpan, idsList, out int length, maxTokenCount - idsList.Count); - if (length < split.Offset.Length || idsList.Count >= maxTokenCount) - { - break; - } - } - - return idsList; - } + /// + /// Encodes input text to token Ids up to maximum number of tokens. + /// + /// The text to encode. + /// The maximum number of tokens to encode. + /// If the tokenizer's normalization is enabled or is false, this will be set to in its normalized form; otherwise, this value will be set to null. + /// The length of the text that encompasses the maximum encoded tokens. + /// Indicate whether to consider pre-tokenization before tokenization. + /// Indicate whether to consider normalization before tokenization. + /// The list of encoded Ids. + public abstract IReadOnlyList EncodeToIds(ReadOnlySpan text, int maxTokenCount, out string? normalizedText, out int textLength, bool considerPreTokenization = true, bool considerNormalization = true); /// /// Get the number of tokens that the input text will be encoded to. /// /// The text to encode. - /// The number of tokens Ids that the input text will be encoded to. - /// The input text is null. - /// Unable to encode the text. - public int CountTokens(string text) - { - if (text is null) - { - throw new ArgumentNullException(nameof(text)); - } - - string normalized = Normalizer is not null ? Normalizer.Normalize(text) : text; + /// Indicate whether to consider pre-tokenization before tokenization. + /// Indicate whether to consider normalization before tokenization. + /// The number of token Ids that the input text will be encoded to. + public virtual int CountTokens(string text, bool considerPreTokenization = true, bool considerNormalization = true) + => CountTokens(text.AsSpan(), considerPreTokenization, considerNormalization); - int idsCount = 0; - foreach (Split split in PreTokenizer.PreTokenize(normalized)) - { - idsCount += Model.CountTokens(split.TokenSpan, out _); - } + /// + /// Get the number of tokens that the input text will be encoded to. + /// + /// The text to encode. + /// Indicate whether to consider pre-tokenization before tokenization. + /// Indicate whether to consider normalization before tokenization. + /// The number of token Ids that the input text will be encoded to. + public abstract int CountTokens(ReadOnlySpan text, bool considerPreTokenization = true, bool considerNormalization = true); - return idsCount; - } + /// + /// Find the index of the maximum encoding capacity from the start within the text without surpassing the token limit. + /// + /// The text to encode. + /// The maximum token count to limit the encoding capacity. + /// If the tokenizer's normalization is enabled or is false, this will be set to in its normalized form; otherwise, this value will be set to null. + /// The token count can be generated which should be smaller than the maximum token count. + /// Indicate whether to consider pre-tokenization before tokenization. + /// Indicate whether to consider normalization before tokenization. + /// + /// The index of the maximum encoding capacity within the processed text without surpassing the token limit. + /// It represents the index immediately following the last character to be included. In cases where no tokens fit, the result will be 0; conversely, + /// if all tokens fit, the result will be length of the text or the if the normalization is enabled. + /// + public virtual int IndexOfTokenCount(string text, int maxTokenCount, out string? normalizedString, out int tokenCount, bool considerPreTokenization = true, bool considerNormalization = true) + => IndexOfTokenCount(text.AsSpan(), maxTokenCount, out normalizedString, out tokenCount, considerPreTokenization, considerNormalization); /// /// Find the index of the maximum encoding capacity from the start within the text without surpassing the token limit. /// /// The text to encode. /// The maximum token count to limit the encoding capacity. - /// If the tokenizer's normalization is enabled, the input text will be represented in its normalization form; otherwise, it will remain unchanged as the input text. + /// If the tokenizer's normalization is enabled or is false, this will be set to in its normalized form; otherwise, this value will be set to null. /// The token count can be generated which should be smaller than the maximum token count. + /// Indicate whether to consider pre-tokenization before tokenization. + /// Indicate whether to consider normalization before tokenization. /// /// The index of the maximum encoding capacity within the processed text without surpassing the token limit. - /// It represents the index immediately following the last character to be included. In cases where no tokens fit, the result will be 0; conversely, if all tokens fit, the result will be length of the . + /// It represents the index immediately following the last character to be included. In cases where no tokens fit, the result will be 0; conversely, + /// if all tokens fit, the result will be length of the text or the if the normalization is enabled. + /// + public abstract int IndexOfTokenCount(ReadOnlySpan text, int maxTokenCount, out string? normalizedString, out int tokenCount, bool considerPreTokenization = true, bool considerNormalization = true); + + /// + /// Find the index of the maximum encoding capacity from the end within the text without surpassing the token limit. + /// + /// The text to encode. + /// The maximum token count to limit the encoding capacity. + /// If the tokenizer's normalization is enabled or is false, this will be set to in its normalized form; otherwise, this value will be set to null. + /// The token count can be generated which should be smaller than the maximum token count. + /// Indicate whether to consider pre-tokenization before tokenization. + /// Indicate whether to consider normalization before tokenization. + /// + /// The start index of the maximum encoding capacity within the processed text without surpassing the token limit. + /// It represents the index at the first character to be included. In cases where no tokens fit, the result will be length of the ; conversely, if all tokens fit, the result will be 0. /// - /// The input text is null. - /// The maximum token count must be greater than 0. - public int IndexOfTokenCount(string text, int maxTokenCount, out string processedText, out int tokenCount) - => IndexOf(text, maxTokenCount, out processedText, out tokenCount); + public virtual int LastIndexOfTokenCount(string text, int maxTokenCount, out string? processedText, out int tokenCount, bool considerPreTokenization = true, bool considerNormalization = true) + => LastIndexOfTokenCount(text.AsSpan(), maxTokenCount, out processedText, out tokenCount, considerPreTokenization, considerNormalization); /// /// Find the index of the maximum encoding capacity from the end within the text without surpassing the token limit. /// /// The text to encode. /// The maximum token count to limit the encoding capacity. - /// If the tokenizer's normalization is enabled, the input text will be represented in its normalization form; otherwise, it will remain unchanged as the input text. + /// If the tokenizer's normalization is enabled or is false, this will be set to in its normalized form; otherwise, this value will be set to null. /// The token count can be generated which should be smaller than the maximum token count. + /// Indicate whether to consider pre-tokenization before tokenization. + /// Indicate whether to consider normalization before tokenization. /// /// The start index of the maximum encoding capacity within the processed text without surpassing the token limit. /// It represents the index at the first character to be included. In cases where no tokens fit, the result will be length of the ; conversely, if all tokens fit, the result will be 0. /// - /// The input text is null. - /// The maximum token count must be greater than 0. - /// - /// If the whole text can be encoded within the token limit, the returned index will be 0. - /// - public int LastIndexOfTokenCount(string text, int maxTokenCount, out string processedText, out int tokenCount) - => LastIndexOf(text, maxTokenCount, out processedText, out tokenCount); - - private int IndexOf(string text, int maxTokenCount, out string processedText, out int tokenCount) + public abstract int LastIndexOfTokenCount(ReadOnlySpan text, int maxTokenCount, out string? processedText, out int tokenCount, bool considerPreTokenization = true, bool considerNormalization = true); + + /// + /// Map the token to encoded Id. + /// + /// The token to map to the Id. + /// The mapped Id of the token. + public virtual int? MapTokenToId(string token) { - if (text is null) + if (token is null) { - throw new ArgumentNullException(nameof(text)); + throw new ArgumentNullException(nameof(token)); } - if (maxTokenCount <= 0) + return MapTokenToId(token.AsSpan()); + } + + /// + /// Map the token to encoded Id. + /// + /// The token to map to the Id. + /// The mapped Id of the token. + public abstract int? MapTokenToId(ReadOnlySpan token); + + /// + /// Decodes the Id to the mapped token. + /// + /// The id to map to the token. + /// The decoded string or null if there is no token mapped to the input id. + public abstract string? MapIdToToken(int id); + + /// + /// Decode the given ids, back to a String. + /// + /// The list of ids that we want to decode. + /// The decoded string. + public virtual string? Decode(IEnumerable ids) + { + if (ids is null) { - throw new ArgumentOutOfRangeException(nameof(maxTokenCount), "The max token count must be greater than 0."); + throw new ArgumentNullException(nameof(ids)); } - processedText = Normalizer is not null ? Normalizer.Normalize(text) : text; - tokenCount = 0; + ValueStringBuilder sb = new ValueStringBuilder(); - IEnumerable splits = PreTokenizer.PreTokenize(processedText); - foreach (Split split in splits) + foreach (int id in ids) { - tokenCount += Model.CountTokens(split.TokenSpan, out int textLength, maxTokenCount - tokenCount); - if (textLength < split.Offset.Length || tokenCount >= maxTokenCount) + if (MapIdToToken(id) is string s) { - return split.Offset.Index + textLength; + sb.Append(s); } } - return processedText.Length; + return sb.ToString(); } - private int LastIndexOf(string text, int maxTokenCount, out string processedText, out int tokenCount) - { - if (text is null) - { - throw new ArgumentNullException(nameof(text)); - } + // + // Factory Methods + // - if (maxTokenCount <= 0) + /// + /// Create a new Tiktoken tokenizer's object asynchronously. + /// + /// The stream to the BPE vocab file. + /// The pre-tokenizer to use. + /// The normalizer to use. + /// The dictionary mapping special tokens to Ids. + /// The size of the cache to use. + /// used to request cancellation of the operation. + /// The tokenizer's object. + public static async Task CreateTiktokenAsync( + Stream vocabStream, + PreTokenizer? preTokenizer, + Normalizer? normalizer, + IReadOnlyDictionary? specialTokens = null, + int cacheSize = LruCache.DefaultCacheSize, + CancellationToken cancellationToken = default) + { + if (vocabStream is null) { - throw new ArgumentOutOfRangeException(nameof(maxTokenCount), "The max token count must be greater than 0."); + throw new ArgumentNullException(nameof(vocabStream)); } - processedText = Normalizer is not null ? Normalizer.Normalize(text) : text; - tokenCount = 0; + (Dictionary, int> encoder, Dictionary vocab, Dictionary> decoder) = + await Tiktoken.LoadTiktokenBpeAsync(vocabStream, useAsync: true, cancellationToken).ConfigureAwait(false); - IEnumerable splits = PreTokenizer.PreTokenize(processedText); - foreach (Split split in splits.Reverse()) - { - tokenCount += Model.CountTokensFromEnd(split.TokenSpan, out int textIndex, maxTokenCount - tokenCount); - if (textIndex > 0 || tokenCount >= maxTokenCount) - { - return split.Offset.Index + textIndex; - } - } - - return 0; + return new Tiktoken(encoder, decoder, vocab, preTokenizer, specialTokens, normalizer, cacheSize); } /// - /// Decodes the Id to the mapped token. + /// Create a new Tiktoken tokenizer's object asynchronously. /// - /// The id to map to the token. - /// The decoded string or null if there is no token mapped to the input id. - public string? Decode(int id) => Model.MapIdToToken(id); + /// The BPE vocab file. + /// The pre-tokenizer to use. + /// The normalizer to use. + /// The dictionary mapping special tokens to Ids. + /// The size of the cache to use. + /// used to request cancellation of the operation. + /// The tokenizer's object. + public static async Task CreateTiktokenAsync( + string vocabFilePath, + PreTokenizer? preTokenizer, + Normalizer? normalizer, + IReadOnlyDictionary? specialTokensEncoder = null, + int cacheSize = LruCache.DefaultCacheSize, + CancellationToken cancellationToken = default) + { + if (vocabFilePath is null) + { + throw new ArgumentNullException(nameof(vocabFilePath)); + } - /// - /// Decode the given ids, back to a String. - /// - /// The list of ids that we want to decode. - /// The decoded string. - public string? Decode(IEnumerable ids) => Model.Decode(ids); + using Stream vocabStream = File.OpenRead(vocabFilePath); + return await CreateTiktokenAsync(vocabStream, preTokenizer, normalizer, specialTokensEncoder, cacheSize, cancellationToken).ConfigureAwait(false); + } /// /// Create a Tiktoken tokenizer based on model name and vocab file. @@ -304,10 +321,11 @@ public static Tokenizer CreateTiktokenForModel( } } - return new Tokenizer( - new Tiktoken(vocabStream, tiktokenConfiguration.SpecialTokens, cacheSize), - new TiktokenPreTokenizer(tiktokenConfiguration.Regex, tiktokenConfiguration.SpecialTokens), - normalizer); + return new Tiktoken(vocabStream, + new TiktokenPreTokenizer(tiktokenConfiguration.Regex, tiktokenConfiguration.SpecialTokens), + tiktokenConfiguration.SpecialTokens, + normalizer, + cacheSize); } /// @@ -343,10 +361,11 @@ public static async Task CreateTiktokenForModelAsync( } } - return new Tokenizer( - await Tiktoken.CreateAsync(vocabStream, tiktokenConfiguration.SpecialTokens, cacheSize, cancellationToken).ConfigureAwait(false), - new TiktokenPreTokenizer(tiktokenConfiguration.Regex, tiktokenConfiguration.SpecialTokens), - normalizer); + return await CreateTiktokenAsync(vocabStream, + new TiktokenPreTokenizer(tiktokenConfiguration.Regex, tiktokenConfiguration.SpecialTokens), + normalizer, + tiktokenConfiguration.SpecialTokens, + cacheSize, cancellationToken).ConfigureAwait(false); } /// @@ -357,7 +376,7 @@ await Tiktoken.CreateAsync(vocabStream, tiktokenConfiguration.SpecialTokens, cac /// To normalize the text before tokenization /// The tokenizer public static Tokenizer CreateTiktokenForModel(string modelName, IReadOnlyDictionary? extraSpecialTokens = null, Normalizer? normalizer = null) - => Tiktoken.CreateTokenizerForModel(modelName, extraSpecialTokens, normalizer); + => Tiktoken.CreateForModel(Tiktoken.GetModelEncoding(modelName), modelName, extraSpecialTokens, normalizer); /// /// Create tokenizer based on encoding name @@ -395,7 +414,7 @@ public static Tokenizer CreateTiktokenForEncoding(string encodingName, IReadOnly throw new ArgumentException($"The encoding name '{encodingName}' is not supported. The only supported encoding names are: {Tiktoken.Cl100kBaseEncodingName}, {Tiktoken.P50kBaseEncodingName}, {Tiktoken.P50kEditEncodingName}, and {Tiktoken.R50kBaseEncodingName}.", nameof(encodingName)); } - return Tiktoken.CreateTokenizerForModel(modelEncoding, modelName: null, extraSpecialTokens, normalizer); + return Tiktoken.CreateForModel(modelEncoding, modelName: null, extraSpecialTokens, normalizer); } /// @@ -427,16 +446,70 @@ public static Tokenizer CreateLlama( throw new ArgumentException($"Normalization '{modelProto.NormalizerSpec.Name}' is not supported.", nameof(modelProto)); } - LlamaNormalizer normalizer = new( + SentencePieceNormalizer normalizer = new( modelProto.NormalizerSpec.RemoveExtraWhitespaces, modelProto.NormalizerSpec.AddDummyPrefix, modelProto.NormalizerSpec.EscapeWhitespaces, modelProto.TrainerSpec.TreatWhitespaceAsSuffix); - return new Tokenizer( - new SentencePieceBpe(modelProto, addBeginOfSentence, addEndOfSentence), - SentencePiecePreTokenizer.Instance, - normalizer); + return new SentencePieceBpe(modelProto, addBeginOfSentence, addEndOfSentence); + } + + internal static IEnumerable<(int Offset, int Length)>? InitializeForEncoding( + string? text, + ReadOnlySpan textSpan, + bool considerPreTokenization, + bool considerNormalization, + Normalizer? normalizer, + PreTokenizer? preTokenizer, + out string? normalizedString, + out ReadOnlySpan textSpanToEncode) + { + normalizedString = null; + IEnumerable<(int Offset, int Length)>? splits = null; + + if (text is null) + { + if (considerNormalization && (normalizer is not null)) + { + normalizedString = normalizer.Normalize(textSpan.ToString()); + textSpanToEncode = normalizedString.AsSpan(); + if (considerPreTokenization && preTokenizer is not null) + { + splits = preTokenizer.PreTokenize(normalizedString); + } + } + else + { + textSpanToEncode = textSpan; + if (considerPreTokenization && preTokenizer is not null) + { + splits = preTokenizer.PreTokenize(textSpan); + } + } + } + else + { + if (considerNormalization && (normalizer is not null)) + { + normalizedString = normalizer.Normalize(text); + textSpanToEncode = normalizedString.AsSpan(); + if (considerPreTokenization && preTokenizer is not null) + { + splits = preTokenizer.PreTokenize(normalizedString); + } + } + else + { + textSpanToEncode = text.AsSpan(); + if (considerPreTokenization && preTokenizer is not null) + { + splits = preTokenizer.PreTokenize(text); + } + } + } + + return splits; } } } diff --git a/src/Microsoft.ML.Tokenizers/Utils/PriorityQueue.cs b/src/Microsoft.ML.Tokenizers/Utils/PriorityQueue.cs index e2eaa38846..751ce6bc10 100644 --- a/src/Microsoft.ML.Tokenizers/Utils/PriorityQueue.cs +++ b/src/Microsoft.ML.Tokenizers/Utils/PriorityQueue.cs @@ -71,6 +71,8 @@ public override string ToString() return s; } + public void Clear() => _data.Clear(); + public bool IsConsistent() { // is the heap property true for all data? diff --git a/src/Microsoft.ML.TorchSharp/Extensions/TokenizerExtensions.cs b/src/Microsoft.ML.TorchSharp/Extensions/TokenizerExtensions.cs index 19e649524d..9504760e98 100644 --- a/src/Microsoft.ML.TorchSharp/Extensions/TokenizerExtensions.cs +++ b/src/Microsoft.ML.TorchSharp/Extensions/TokenizerExtensions.cs @@ -26,12 +26,12 @@ internal static Tokenizer GetInstance(IChannel ch) // "https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/dict.txt" Assembly assembly = typeof(TokenizerExtensions).Assembly; - EnglishRoberta model = new EnglishRoberta( + _instance = new EnglishRoberta( assembly.GetManifestResourceStream("encoder.json"), assembly.GetManifestResourceStream("vocab.bpe"), - assembly.GetManifestResourceStream("dict.txt")); - model.AddMaskSymbol(); - _instance = new Tokenizer(model, new RobertaPreTokenizer()); + assembly.GetManifestResourceStream("dict.txt"), + new RobertaPreTokenizer()); + (_instance as EnglishRoberta).AddMaskSymbol(); } return _instance; @@ -40,7 +40,7 @@ internal static Tokenizer GetInstance(IChannel ch) [MethodImpl(MethodImplOptions.AggressiveInlining)] internal static EnglishRoberta RobertaModel(this Tokenizer tokenizer) { - EnglishRoberta model = tokenizer.Model as EnglishRoberta; + EnglishRoberta model = tokenizer as EnglishRoberta; if (model is null) { throw new InvalidOperationException($"The input tokenizer is not using the EnglishRoberta model."); @@ -51,8 +51,7 @@ internal static EnglishRoberta RobertaModel(this Tokenizer tokenizer) internal static IReadOnlyList EncodeToConverted(this Tokenizer tokenizer, string sentence) { - EncodingResult encoding = tokenizer.Encode(sentence); - return tokenizer.RobertaModel().ConvertIdsToOccurrenceRanks(encoding.Ids); + return tokenizer.RobertaModel().ConvertIdsToOccurrenceRanks(tokenizer.EncodeToIds(sentence)); } } } \ No newline at end of file diff --git a/src/Microsoft.ML.TorchSharp/NasBert/NerTrainer.cs b/src/Microsoft.ML.TorchSharp/NasBert/NerTrainer.cs index 5d9e5c596f..67518ee815 100644 --- a/src/Microsoft.ML.TorchSharp/NasBert/NerTrainer.cs +++ b/src/Microsoft.ML.TorchSharp/NasBert/NerTrainer.cs @@ -167,16 +167,16 @@ private protected override torch.Tensor PrepareRowTensor(ref VBuffer targe Sentence1Getter(ref sentenceRom); var sentence = sentenceRom.ToString(); Tensor t; - var encoding = Tokenizer.Encode(sentence); + IReadOnlyList encoding = Tokenizer.Encode(sentence, out string normalizedString); - if (target.Length != encoding.Tokens.Count) + if (target.Length != encoding.Count) { var targetIndex = 0; - var targetEditor = VBufferEditor.Create(ref target, encoding.Tokens.Count); + var targetEditor = VBufferEditor.Create(ref target, encoding.Count); var newValues = targetEditor.Values; - for (var i = 0; i < encoding.Tokens.Count; i++) + for (var i = 0; i < encoding.Count; i++) { - if (NerTrainer.TokenStartsWithSpace(encoding.Tokens[i])) + if (NerTrainer.TokenStartsWithSpace(encoding[i].Value)) { newValues[i] = target.GetItemOrDefault(++targetIndex); } @@ -187,7 +187,7 @@ private protected override torch.Tensor PrepareRowTensor(ref VBuffer targe } target = targetEditor.Commit(); } - t = torch.tensor((ZeroArray).Concat(Tokenizer.RobertaModel().ConvertIdsToOccurrenceRanks(encoding.Ids)).ToList(), device: Device); + t = torch.tensor((ZeroArray).Concat(Tokenizer.RobertaModel().ConvertIdsToOccurrenceRanks(encoding.Select(t => t.Id).ToArray())).ToList(), device: Device); if (t.NumberOfElements > 512) t = t.slice(0, 0, 512, 1); @@ -377,16 +377,16 @@ private protected override Delegate CreateGetter(DataViewRow input, int iinfo, T private void CondenseOutput(ref VBuffer dst, string sentence, Tokenizer tokenizer, TensorCacher outputCacher) { var pre = tokenizer.PreTokenizer.PreTokenize(sentence); - EncodingResult encoding = tokenizer.Encode(sentence); + IReadOnlyList encoding = tokenizer.Encode(sentence, out string normalizedString); var argmax = (outputCacher as BertTensorCacher).Result.argmax(-1); var prediction = argmax.ToArray(); var targetIndex = 0; // Figure out actual count of output tokens - for (var i = 0; i < encoding.Tokens.Count; i++) + for (var i = 0; i < encoding.Count; i++) { - if (NerTrainer.TokenStartsWithSpace(encoding.Tokens[i])) + if (NerTrainer.TokenStartsWithSpace(encoding[i].Value)) { targetIndex++; } @@ -398,9 +398,9 @@ private void CondenseOutput(ref VBuffer dst, string sentence, Tokenizer newValues[targetIndex++] = (uint)prediction[0]; - for (var i = 1; i < encoding.Tokens.Count; i++) + for (var i = 1; i < encoding.Count; i++) { - if (NerTrainer.TokenStartsWithSpace(encoding.Tokens[i])) + if (NerTrainer.TokenStartsWithSpace(encoding[i].Value)) { newValues[targetIndex++] = (uint)prediction[i]; } diff --git a/src/Microsoft.ML.TorchSharp/Roberta/QATrainer.cs b/src/Microsoft.ML.TorchSharp/Roberta/QATrainer.cs index 07fe172bc9..ec90feb9c2 100644 --- a/src/Microsoft.ML.TorchSharp/Roberta/QATrainer.cs +++ b/src/Microsoft.ML.TorchSharp/Roberta/QATrainer.cs @@ -401,9 +401,9 @@ private torch.Tensor PrepareBatchTensor(ref List inputTensors, Device de answerIndexGetter(ref answerIndex); var contextString = context.ToString(); - var contextTokens = Tokenizer.Encode(contextString); - var contextToken = contextTokens.Tokens; - var contextTokenId = Tokenizer.RobertaModel().ConvertIdsToOccurrenceRanks(contextTokens.Ids); + var contextTokens = Tokenizer.Encode(contextString, out string normalized); + var contextToken = contextTokens.Select(t => t.Value).ToArray(); + var contextTokenId = Tokenizer.RobertaModel().ConvertIdsToOccurrenceRanks(contextTokens.Select(t => t.Id).ToArray()); var mapping = AlignAnswerPosition(contextToken, contextString); if (mapping == null) @@ -437,7 +437,7 @@ private torch.Tensor PrepareBatchTensor(ref List inputTensors, Device de private Dictionary AlignAnswerPosition(IReadOnlyList tokens, string text) { - EnglishRoberta robertaModel = Tokenizer.Model as EnglishRoberta; + EnglishRoberta robertaModel = Tokenizer as EnglishRoberta; Debug.Assert(robertaModel is not null); var mapping = new Dictionary(); @@ -854,9 +854,9 @@ private Tensor PrepInputTensors(ref ReadOnlyMemory context, ref ReadOnlyMe contextGetter(ref context); questionGetter(ref question); - var contextTokenId = _parent.Tokenizer.RobertaModel().ConvertIdsToOccurrenceRanks(_parent.Tokenizer.Encode(context.ToString()).Ids); + var contextTokenId = _parent.Tokenizer.RobertaModel().ConvertIdsToOccurrenceRanks(_parent.Tokenizer.EncodeToIds(context.ToString())); - var questionTokenId = _parent.Tokenizer.RobertaModel().ConvertIdsToOccurrenceRanks(_parent.Tokenizer.Encode(question.ToString()).Ids); + var questionTokenId = _parent.Tokenizer.RobertaModel().ConvertIdsToOccurrenceRanks(_parent.Tokenizer.EncodeToIds(question.ToString())); var srcTensor = torch.tensor((new[] { 0 /* InitToken */ }).Concat(questionTokenId).Concat(new[] { 2 /* SeparatorToken */ }).Concat(contextTokenId).ToList(), device: _parent.Device); diff --git a/test/Microsoft.ML.Tokenizers.Tests/BpeTests.cs b/test/Microsoft.ML.Tokenizers.Tests/BpeTests.cs index 1cbd7b092d..d6d8271802 100644 --- a/test/Microsoft.ML.Tokenizers.Tests/BpeTests.cs +++ b/test/Microsoft.ML.Tokenizers.Tests/BpeTests.cs @@ -248,28 +248,28 @@ public void SimpleTestWithUnknownToken( try { - Bpe bpe = new Bpe(vocabFile, mergesFile, unknownToken, continuingSubwordPrefix, endOfWordSuffix, fuseUnknownToken); - Tokenizer tokenizer = new Tokenizer(bpe); - EncodingResult encoding = tokenizer.Encode(sentence); + Bpe bpe = new Bpe(vocabFile, mergesFile, unknownToken: unknownToken, continuingSubwordPrefix: continuingSubwordPrefix, endOfWordSuffix: endOfWordSuffix, fuseUnknownTokens: fuseUnknownToken); + Tokenizer tokenizer = bpe; + IReadOnlyList encoding = tokenizer.Encode(sentence, out _); + int[] encodingIds = encoding.Select(t => t.Id).ToArray(); IReadOnlyList idsList = tokenizer.EncodeToIds(sentence); - Assert.Equal(expectedTokens.Length, encoding.Tokens.Count); - Assert.Equal(offsets.Length, encoding.Offsets.Count); - Assert.Equal(ids.Length, encoding.Ids.Count); + Assert.Equal(expectedTokens.Length, encoding.Count); + Assert.Equal(offsets.Length, encoding.Count); + Assert.Equal(ids.Length, encoding.Count); Assert.Equal(ids.Length, idsList.Count); Assert.Equal(ids.Length, tokenizer.CountTokens(sentence)); - Assert.Equal(decodedTokens, tokenizer.Decode(encoding.Ids)); - Assert.Equal(decodedTokensWithoutUnknownToken, bpe.Decode(encoding.Ids, considerSpecialTokens: false)); + Assert.Equal(decodedTokens, tokenizer.Decode(encodingIds)); + Assert.Equal(decodedTokensWithoutUnknownToken, bpe.Decode(encodingIds, considerSpecialTokens: false)); - for (int i = 0; i < encoding.Tokens.Count; i++) + for (int i = 0; i < encoding.Count; i++) { - Assert.Equal(expectedTokens[i], encoding.Tokens[i]); - Assert.Equal(offsets[i], encoding.Offsets[i]); - Assert.Equal(ids[i], encoding.Ids[i]); + Assert.Equal(expectedTokens[i], encoding[i].Value); + Assert.Equal(offsets[i], encoding[i].Offset); + Assert.Equal(ids[i], encoding[i].Id); Assert.Equal(ids[i], idsList[i]); - Assert.Equal(encoding.Tokens[i], tokenizer.Model.MapIdToToken(encoding.Ids[i])); - Assert.Equal(encoding.Ids[i], tokenizer.Model.MapTokenToId(encoding.Tokens[i].AsSpan())); - Assert.Equal(encoding.Tokens[i], tokenizer.Decode(encoding.Ids[i])); + Assert.Equal(encoding[i].Value, tokenizer.MapIdToToken(encodingIds[i])); + Assert.Equal(encodingIds[i], tokenizer.MapTokenToId(encoding[i].Value.AsSpan())); } } finally @@ -282,6 +282,41 @@ public void SimpleTestWithUnknownToken( } } + private static Tokenizer? _gpt2Tokenizer = null; + + private static Tokenizer GetGpt2Tokenizer() + { + if (_gpt2Tokenizer is null) + { + // "https://huggingface.co/openai-community/gpt2/raw/main/vocab.json"; + // "https://huggingface.co/openai-community/gpt2/raw/main/merges.txt"; + using Stream vocabStream = File.OpenRead(Path.Combine(@"Gpt-2", "vocab.json")); + using Stream mergesStream = File.OpenRead(Path.Combine(@"Gpt-2", "merges.txt")); + + _gpt2Tokenizer = new Bpe(vocabStream, mergesStream); + } + + return _gpt2Tokenizer; + } + + [Fact] + public void TestGpt2Vocab() + { + Tokenizer tokenizer = GetGpt2Tokenizer(); + + string text = "The quick brown fox jumps over the lazy dog!"; + + IReadOnlyList encoding = tokenizer.Encode(text, out _); + IReadOnlyList ids = tokenizer.EncodeToIds(text); + + Assert.Equal(12, encoding.Count); + Assert.Equal(encoding.Select(t => t.Id).ToArray(), ids); + Assert.Equal(12, tokenizer.CountTokens(text)); + + TokenizerTests.TestTokenLimits(tokenizer); + } + + public static IEnumerable BpeTestData { get @@ -290,55 +325,84 @@ public static IEnumerable BpeTestData yield return new object?[] { "the brown fox jumped over the lazy dog!", - new string[] {"the", "brown", "fox", "jumped", "over", "the", "lazy", "dog", "!"}, - new (int, int)[] {(0, 3), (4, 9), (10, 13), (14, 20), (21, 25), (26, 29), (30, 34), (35, 38), (38, 39)} + new string[] { "the", "brown", "fox", "j", "umped", "over", "the", "l", "azy", "dog", "!" }, + new (int Index, int Length)[] { (0, 3), (4, 5), (10, 3), (14, 1), (15, 5), (21, 4), (26, 3), (30, 1), (31, 3), (35, 3), (38, 1) }, + new int[] { 1169, 33282, 12792, 73, 27073, 2502, 1169, 75, 12582, 9703, 0 } }; yield return new object?[] { "he traveled to Egypt during the summer, the weather was hot and ammunition." , - new string[] {"he", "traveled", "to", "Egypt", "during", "the", "summer", ",", "the", "weather", "was", "hot", "and", "ammunition", "."}, - new (int, int)[] {(0, 2), (3, 11), (12, 14), (15, 20), (21, 27), (28, 31), (32, 38), (38, 39), (40, 43), (44, 51), (52, 55), (56, 59), (60, 63), (64, 74), (74, 75)} + new string[] { "he", "travel", "ed", "to", "Egypt", "during", "the", "sum", "mer", ",", "the", "weather", "was", "hot", "and", "am", "munition", "." }, + new (int Index, int Length)[] { (0, 2), (3, 6), (9, 2), (12, 2), (15, 5), (21, 6), (28, 3), (32, 3), (35, 3), (38, 1), (40, 3), (44, 7), (52, 3), (56, 3), (60, 3), (64, 2), (66, 8), (74, 1) }, + new int[] { 258, 35927, 276, 1462, 39299, 42122, 1169, 16345, 647, 11, 1169, 23563, 9776, 8940, 392, 321, 12640, 13 } }; yield return new object?[] { "She played many games and she felt exhausted afterward", - new string[] {"She", "played", "many", "games", "and", "she", "felt", "exhausted", "afterward"}, - new (int, int)[] {(0, 3), (4, 10), (11, 15), (16, 21), (22, 25), (26, 29), (30, 34), (35, 44), (45, 54)} + new string[] { "She", "played", "many", "games", "and", "she", "felt", "ex", "ha", "usted", "after", "ward" }, + new (int Index, int Length)[] { (0, 3), (4, 6), (11, 4), (16, 5), (22, 3), (26, 3), (30, 4), (35, 2), (37, 2), (39, 5), (45, 5), (50, 4) }, + new int[] { 3347, 21542, 21834, 19966, 392, 7091, 31985, 1069, 3099, 8459, 8499, 904 } }; yield return new object?[] { "Hello, y'all! How are you 😁 ?", - new string[] {"Hello", ",", "y", "'", "all", "!", "How", "are", "you", "[UNK]", "?"}, - new (int, int)[] {(0, 5), (5, 6), (7, 8), (8, 9), (9, 12), (12, 13), (14, 17), (18, 21), (22, 25), (26, 28), (29, 30)} + new string[] { "Hello", ",", "y", "'", "all", "!", "How", "are", "you", "?" }, + new (int Index, int Length)[] { (0, 5), (5, 1), (7, 1), (8, 1), (9, 3), (12, 1), (14, 3), (18, 3), (22, 3), (29, 1) }, + new int[] { 15496, 11, 88, 6, 439, 0, 2437, 533, 5832, 30 } }; } } - private const string Gpt2VocabUrl = "https://huggingface.co/openai-community/gpt2/raw/main/vocab.json"; - private const string Gpt2MergesUrl = "https://huggingface.co/openai-community/gpt2/raw/main/merges.txt"; - - [Fact] - public async void TestGpt2Vocab() + [Theory] + [MemberData(nameof(BpeTestData))] + public void TestBpeTokenizer(string text, string[] expectedTokens, (int Index, int Length)[] expectedOffsets, int[] expectedIds) { - using HttpClient httpClient = new HttpClient(); - using Stream vocabStream = await httpClient.GetStreamAsync(Gpt2VocabUrl); - using Stream mergesStream = await httpClient.GetStreamAsync(Gpt2MergesUrl); + Tokenizer tokenizer = GetGpt2Tokenizer(); - Bpe bpe = new Bpe(vocabStream, mergesStream); - Tokenizer tokenizer = new Tokenizer(bpe); + IReadOnlyList encoding = tokenizer.Encode(text, out _); + IReadOnlyList encoding1 = tokenizer.Encode(text.AsSpan(), out _); - string text = "The quick brown fox jumps over the lazy dog!"; + Assert.Equal(expectedTokens, encoding.Select(t => t.Value).ToArray()); + Assert.Equal(expectedOffsets, encoding.Select(t => t.Offset).ToArray()); + Assert.Equal(expectedIds, encoding.Select(t => t.Id).ToArray()); - EncodingResult encoding = tokenizer.Encode(text); - IReadOnlyList ids = tokenizer.EncodeToIds(text); + Assert.Equal(expectedTokens, encoding1.Select(t => t.Value).ToArray()); + Assert.Equal(expectedOffsets, encoding1.Select(t => t.Offset).ToArray()); + Assert.Equal(expectedIds, encoding1.Select(t => t.Id).ToArray()); - Assert.Equal(12, encoding.Tokens.Count); - Assert.Equal(12, encoding.Offsets.Count); - Assert.Equal(12, encoding.Ids.Count); - Assert.Equal(encoding.Ids, ids); - Assert.Equal(12, tokenizer.CountTokens(text)); + Assert.Equal(expectedIds, tokenizer.EncodeToIds(text)); + Assert.Equal(expectedIds, tokenizer.EncodeToIds(text.AsSpan())); + Assert.Equal(expectedIds, tokenizer.EncodeToIds(text, expectedIds.Length, out string? normalizedString, out int length)); + Assert.Null(normalizedString); + Assert.Equal(text.Length, length); + Assert.Equal(expectedIds, tokenizer.EncodeToIds(text.AsSpan(), expectedIds.Length, out normalizedString, out length)); + Assert.Null(normalizedString); + Assert.Equal(text.Length, length); - TokenizerTests.TestTokenLimits(tokenizer); + Assert.Equal(expectedIds.Take(expectedIds.Length - 2), tokenizer.EncodeToIds(text, expectedIds.Length - 2, out normalizedString, out length)); + Assert.Null(normalizedString); + int expectedLength = expectedOffsets[expectedOffsets.Length - 3].Index + expectedOffsets[expectedOffsets.Length - 3].Length; + Assert.Equal(expectedLength, length); + Assert.Equal(expectedIds.Take(expectedIds.Length - 2), tokenizer.EncodeToIds(text.AsSpan(), expectedIds.Length - 2, out normalizedString, out length)); + Assert.Null(normalizedString); + Assert.Equal(expectedLength, length); + + Assert.Equal(expectedIds.Length, tokenizer.CountTokens(text)); + Assert.Equal(expectedIds.Length, tokenizer.CountTokens(text.AsSpan())); + + Assert.Equal(expectedOffsets[expectedOffsets.Length - 4].Index + expectedOffsets[expectedOffsets.Length - 4].Length, tokenizer.IndexOfTokenCount(text, expectedIds.Length - 3, out normalizedString, out int tokenCount)); + Assert.Null(normalizedString); + Assert.Equal(expectedIds.Length - 3, tokenCount); + Assert.Equal(expectedOffsets[expectedOffsets.Length - 4].Index + expectedOffsets[expectedOffsets.Length - 4].Length, tokenizer.IndexOfTokenCount(text.AsSpan(), expectedIds.Length - 3, out normalizedString, out tokenCount)); + Assert.Null(normalizedString); + Assert.Equal(expectedIds.Length - 3, tokenCount); + + Assert.Equal(expectedOffsets[expectedOffsets.Length - 3].Index, tokenizer.LastIndexOfTokenCount(text, 3, out normalizedString, out tokenCount)); + Assert.Null(normalizedString); + Assert.Equal(3, tokenCount); + Assert.Equal(expectedOffsets[expectedOffsets.Length - 3].Index, tokenizer.LastIndexOfTokenCount(text.AsSpan(), 3, out normalizedString, out tokenCount)); + Assert.Null(normalizedString); + Assert.Equal(3, tokenCount); } private static string WriteToMergeFile((string, string)[] mergeEntries) @@ -360,7 +424,7 @@ private static string WriteToVocabFile(Dictionary dic) return fileName; } - internal static Bpe CreateEmptyBpe() + internal static Bpe CreateEmptyBpe(PreTokenizer? preTokenizer = null, Normalizer? normalizer = null) { using MemoryStream emptyVocabStream = new MemoryStream(); using StreamWriter writer = new StreamWriter(emptyVocabStream); @@ -368,7 +432,7 @@ internal static Bpe CreateEmptyBpe() writer.Flush(); emptyVocabStream.Position = 0; - return new Bpe(vocabStream: emptyVocabStream, mergesStream: null, unknownToken: "Ukn"); + return new Bpe(vocabStream: emptyVocabStream, mergesStream: null, preTokenizer: preTokenizer ?? WhiteSpace.Instance, normalizer: normalizer, unknownToken: "Ukn"); } } } diff --git a/test/Microsoft.ML.Tokenizers.Tests/EnglishRobertaTests.cs b/test/Microsoft.ML.Tokenizers.Tests/EnglishRobertaTests.cs index 6280866b6d..0dcd6f2399 100644 --- a/test/Microsoft.ML.Tokenizers.Tests/EnglishRobertaTests.cs +++ b/test/Microsoft.ML.Tokenizers.Tests/EnglishRobertaTests.cs @@ -6,9 +6,11 @@ using System.IO; using System.Collections.Generic; using System.Linq; -using System.Text; +using System.Net.Http; using Xunit; +using System.Diagnostics; +using System.Threading.Tasks; namespace Microsoft.ML.Tokenizers.Tests { @@ -80,6 +82,34 @@ public static IEnumerable BertaData } } + private static Tokenizer? _robertaTokenizer = null; + private async static Task GetRobertaTokenizer() + { + if (_robertaTokenizer is null) + { + string vocabFile = Utils.CreateTemporaryFile("json"); + string mergeFile = Utils.CreateTemporaryFile("txt"); + string translationFile = Utils.CreateTemporaryFile("txt"); + + try + { + await Utils.DownloadFile(_vocabUrl, vocabFile); + await Utils.DownloadFile(_mergeUrl, mergeFile); + await Utils.DownloadFile(_dictUrl, translationFile); + + _robertaTokenizer = new EnglishRoberta(vocabFile, mergeFile, translationFile, RobertaPreTokenizer.Instance); + } + finally + { + Utils.DeleteFile(vocabFile); + Utils.DeleteFile(mergeFile); + Utils.DeleteFile(translationFile); + } + } + + return _robertaTokenizer; + } + [Fact] public async void TokenizationTest() { @@ -94,26 +124,26 @@ public async void TokenizationTest() await Utils.DownloadFile(_mergeUrl, mergeFile); await Utils.DownloadFile(_dictUrl, translationFile); - Tokenizer tokenizer = new Tokenizer(new EnglishRoberta(vocabFile, mergeFile, translationFile), RobertaPreTokenizer.Instance); + Tokenizer tokenizer = new EnglishRoberta(vocabFile, mergeFile, translationFile, RobertaPreTokenizer.Instance); TestTokenizer(tokenizer); TokenizerTests.TestTokenLimits(tokenizer); - tokenizer = new Tokenizer(new EnglishRoberta(vocabFile, mergeFile, translationFile, filterUnsupportedChars: false), RobertaPreTokenizer.Instance); + tokenizer = new EnglishRoberta(vocabFile, mergeFile, translationFile, RobertaPreTokenizer.Instance, filterUnsupportedChars: false); TestTokenizer(tokenizer); using Stream vocabStream = File.OpenRead(vocabFile); using Stream mergeStream = File.OpenRead(mergeFile); using Stream translationStream = File.OpenRead(translationFile); - tokenizer = new Tokenizer(new EnglishRoberta(vocabStream, mergeStream, translationStream), RobertaPreTokenizer.Instance); + tokenizer = new EnglishRoberta(vocabStream, mergeStream, translationStream, RobertaPreTokenizer.Instance); TestTokenizer(tokenizer); // Ensure caching works regardless of which method is called first. for (CallingOrder order = CallingOrder.Encode; order <= CallingOrder.CountTokens; order++) { - tokenizer = new Tokenizer(new EnglishRoberta(vocabFile, mergeFile, translationFile), RobertaPreTokenizer.Instance); + tokenizer = new EnglishRoberta(vocabFile, mergeFile, translationFile, RobertaPreTokenizer.Instance); TestTokenizer(tokenizer, order); - tokenizer = new Tokenizer(new EnglishRoberta(vocabFile, mergeFile, translationFile, filterUnsupportedChars: false), RobertaPreTokenizer.Instance); + tokenizer = new EnglishRoberta(vocabFile, mergeFile, translationFile, RobertaPreTokenizer.Instance, filterUnsupportedChars: false); TestTokenizer(tokenizer, order); } } @@ -132,6 +162,94 @@ public async void TokenizationTest() } } + public static IEnumerable RobertaTestData + { + get + { + // string to tokenize, produced tokens, the token offsets + yield return new object?[] + { + "the brown fox jumped over the lazy dog!", + new string[] { "the", "Ġbrown", "Ġfox", "Ġjumped", "Ġover", "Ġthe", "Ġlazy", "Ġdog", "!" }, + new (int Index, int Length)[] { (0, 3), (3, 6), (9, 4), (13, 7), (20, 5), (25, 4), (29, 5), (34, 4), (38, 1) }, + new int[] { 1169, 7586, 21831, 11687, 625, 262, 16931, 3290, 0 } + }; + yield return new object?[] + { + "he traveled to Egypt during the summer, the weather was hot and ammunition." , + new string[] { "he", "Ġtraveled", "Ġto", "ĠEgypt", "Ġduring", "Ġthe", "Ġsummer", ",", "Ġthe", "Ġweather", "Ġwas", "Ġhot", "Ġand", "Ġammunition", "." }, + new (int Index, int Length)[] { (0, 2), (2, 9), (11, 3), (14, 6), (20, 7), (27, 4), (31, 7), (38, 1), (39, 4), (43, 8), (51, 4), (55, 4), (59, 4), (63, 11), (74, 1) }, + new int[] { 258, 14113, 284, 6365, 1141, 262, 3931, 11, 262, 6193, 373, 3024, 290, 14271, 13 } + }; + yield return new object?[] + { + "She played many games and she felt exhausted afterward", + new string[] { "She", "Ġplayed", "Ġmany", "Ġgames", "Ġand", "Ġshe", "Ġfelt", "Ġexhausted", "Ġafterward" }, + new (int Index, int Length)[] { (0, 3), (3, 7), (10, 5), (15, 6), (21, 4), (25, 4), (29, 5), (34, 10), (44, 10) }, + new int[] { 3347, 2826, 867, 1830, 290, 673, 2936, 19064, 20875 } + }; + yield return new object?[] + { + "Hello, y'all! How are you 😁 ?", + new string[] { "Hello", ",", "Ġy", "'", "all", "!", "ĠHow", "Ġare", "Ġyou", "Ġ", "Ġ?" }, + new (int Index, int Length)[] { (0, 5), (5, 1), (6, 2), (8, 1), (9, 3), (12, 1), (13, 4), (17, 4), (21, 4), (25, 1), (28, 2) }, + new int[] { 15496, 11, 331, 6, 439, 0, 1374, 389, 345, 220, 5633 } + }; + } + } + + [Theory] + [MemberData(nameof(RobertaTestData))] + public async void TestTokenizerEncoding(string text, string[] expectedTokens, (int Index, int Length)[] expectedOffsets, int[] expectedIds) + { + Tokenizer tokenizer = await GetRobertaTokenizer(); + + IReadOnlyList encoding = tokenizer.Encode(text, out _); + IReadOnlyList encoding1 = tokenizer.Encode(text.AsSpan(), out _); + + Assert.Equal(expectedTokens, encoding.Select(t => t.Value).ToArray()); + Assert.Equal(expectedOffsets, encoding.Select(t => t.Offset).ToArray()); + Assert.Equal(expectedIds, encoding.Select(t => t.Id).ToArray()); + + Assert.Equal(expectedTokens, encoding1.Select(t => t.Value).ToArray()); + Assert.Equal(expectedOffsets, encoding1.Select(t => t.Offset).ToArray()); + Assert.Equal(expectedIds, encoding1.Select(t => t.Id).ToArray()); + + Assert.Equal(expectedIds, tokenizer.EncodeToIds(text)); + Assert.Equal(expectedIds, tokenizer.EncodeToIds(text.AsSpan())); + Assert.Equal(expectedIds, tokenizer.EncodeToIds(text, expectedIds.Length, out string? normalizedString, out int length)); + Assert.Null(normalizedString); + Assert.Equal(text.Length, length); + Assert.Equal(expectedIds, tokenizer.EncodeToIds(text.AsSpan(), expectedIds.Length, out normalizedString, out length)); + Assert.Null(normalizedString); + Assert.Equal(text.Length, length); + + Assert.Equal(expectedIds.Take(expectedIds.Length - 2), tokenizer.EncodeToIds(text, expectedIds.Length - 2, out normalizedString, out length)); + Assert.Null(normalizedString); + int expectedLength = expectedOffsets[expectedOffsets.Length - 3].Index + expectedOffsets[expectedOffsets.Length - 3].Length; + Assert.Equal(expectedLength, length); + Assert.Equal(expectedIds.Take(expectedIds.Length - 2), tokenizer.EncodeToIds(text.AsSpan(), expectedIds.Length - 2, out normalizedString, out length)); + Assert.Null(normalizedString); + Assert.Equal(expectedLength, length); + + Assert.Equal(expectedIds.Length, tokenizer.CountTokens(text)); + Assert.Equal(expectedIds.Length, tokenizer.CountTokens(text.AsSpan())); + + Assert.Equal(expectedOffsets[expectedOffsets.Length - 4].Index + expectedOffsets[expectedOffsets.Length - 4].Length, tokenizer.IndexOfTokenCount(text, expectedIds.Length - 3, out normalizedString, out int tokenCount)); + Assert.Null(normalizedString); + Assert.Equal(expectedIds.Length - 3, tokenCount); + Assert.Equal(expectedOffsets[expectedOffsets.Length - 4].Index + expectedOffsets[expectedOffsets.Length - 4].Length, tokenizer.IndexOfTokenCount(text.AsSpan(), expectedIds.Length - 3, out normalizedString, out tokenCount)); + Assert.Null(normalizedString); + Assert.Equal(expectedIds.Length - 3, tokenCount); + + Assert.Equal(expectedOffsets[expectedOffsets.Length - 3].Index, tokenizer.LastIndexOfTokenCount(text, 3, out normalizedString, out tokenCount)); + Assert.Null(normalizedString); + Assert.Equal(3, tokenCount); + Assert.Equal(expectedOffsets[expectedOffsets.Length - 3].Index, tokenizer.LastIndexOfTokenCount(text.AsSpan(), 3, out normalizedString, out tokenCount)); + Assert.Null(normalizedString); + Assert.Equal(3, tokenCount); + } + private enum CallingOrder { Encode, @@ -143,66 +261,67 @@ private enum CallingOrder // Calling with callIdsFirst = true will test the other way around. private void TestTokenizer(Tokenizer tokenizer, CallingOrder callingOrder = CallingOrder.Encode) { - Assert.NotNull(tokenizer.Model); - Assert.True(tokenizer.Model is EnglishRoberta); + Assert.True(tokenizer is EnglishRoberta); Assert.True(tokenizer.PreTokenizer is RobertaPreTokenizer); foreach (object[] p in BertaData) { IReadOnlyList ids; - EncodingResult encoding; + IReadOnlyList encoding; int idsCount; if (callingOrder == CallingOrder.Encode) { - encoding = tokenizer.Encode((string)p[0]); + encoding = tokenizer.Encode((string)p[0], out _); ids = tokenizer.EncodeToIds((string)p[0]); idsCount = tokenizer.CountTokens((string)p[0]); } else if (callingOrder == CallingOrder.EncodeToIds) { ids = tokenizer.EncodeToIds((string)p[0]); - encoding = tokenizer.Encode((string)p[0]); + encoding = tokenizer.Encode((string)p[0], out _); idsCount = tokenizer.CountTokens((string)p[0]); } else // CountTokens { idsCount = tokenizer.CountTokens((string)p[0]); ids = tokenizer.EncodeToIds((string)p[0]); - encoding = tokenizer.Encode((string)p[0]); + encoding = tokenizer.Encode((string)p[0], out _); } - Assert.Equal(p[1], encoding.Ids); + int[] encodingIds = encoding.Select(t => t.Id).ToArray(); + (int, int)[] offsets = encoding.Select(t => t.Offset).ToArray(); + string[] tokens = encoding.Select(t => t.Value).ToArray(); + + Assert.Equal(p[1], encodingIds); Assert.Equal(p[1], ids); Assert.Equal(((int[])p[1]).Length, idsCount); - Assert.Equal(p[3], encoding.Offsets); - Assert.Equal(encoding.Ids.Count, encoding.Tokens.Count); - Assert.Equal(encoding.Ids.Count, encoding.Offsets.Count); + Assert.Equal(p[3], offsets); - EnglishRoberta? robertaModel = tokenizer.Model as EnglishRoberta; - Assert.Equal(p[2], encoding.Tokens); + EnglishRoberta? robertaModel = tokenizer as EnglishRoberta; + Assert.Equal(p[2], tokens); - Assert.Equal(string.Concat((string[])(p[robertaModel!.FilterUnsupportedChars ? 5 : 2])), tokenizer.Decode(encoding.Ids)); + Assert.Equal(string.Concat((string[])(p[robertaModel!.FilterUnsupportedChars ? 5 : 2])), tokenizer.Decode(encodingIds)); Assert.NotNull(robertaModel); - Assert.Equal(encoding.Ids, robertaModel!.ConvertOccurrenceRanksToIds(robertaModel!.ConvertIdsToOccurrenceRanks(encoding.Ids))); - Assert.Equal(p[4], robertaModel.ConvertIdsToOccurrenceValues(encoding.Ids)); + Assert.Equal(encodingIds, robertaModel!.ConvertOccurrenceRanksToIds(robertaModel!.ConvertIdsToOccurrenceRanks(encodingIds))); + Assert.Equal(p[4], robertaModel.ConvertIdsToOccurrenceValues(encodingIds)); - for (int i = 0; i < encoding.Tokens.Count; i++) + for (int i = 0; i < tokens.Length; i++) { if (robertaModel.FilterUnsupportedChars) { string[]? filteredToken = p[5] as string[]; - Assert.Equal(filteredToken![i], tokenizer.Model.MapIdToToken(encoding.Ids[i])); + Assert.Equal(filteredToken![i], tokenizer.MapIdToToken(encodingIds[i])); } else { - Assert.Equal(encoding.Tokens[i], tokenizer.Model.MapIdToToken(encoding.Ids[i])); + Assert.Equal(tokens[i], tokenizer.MapIdToToken(encodingIds[i])); string[]? unfilteredToken = p[2] as string[]; - Assert.Equal(unfilteredToken![i], tokenizer.Model.MapIdToToken(encoding.Ids[i])); + Assert.Equal(unfilteredToken![i], tokenizer.MapIdToToken(encodingIds[i])); } - Assert.Equal(encoding.Ids[i], tokenizer.Model.MapTokenToId(encoding.Tokens[i].AsSpan())); + Assert.Equal(encodingIds[i], tokenizer.MapTokenToId(tokens[i].AsSpan())); } } } diff --git a/test/Microsoft.ML.Tokenizers.Tests/LlamaTests.cs b/test/Microsoft.ML.Tokenizers.Tests/LlamaTests.cs index 21c09893ff..5034771800 100644 --- a/test/Microsoft.ML.Tokenizers.Tests/LlamaTests.cs +++ b/test/Microsoft.ML.Tokenizers.Tests/LlamaTests.cs @@ -17,20 +17,21 @@ namespace Microsoft.ML.Tokenizers.Tests public class LlamaTests { private static readonly HttpClient _httpClient = new HttpClient() { Timeout = TimeSpan.FromMinutes(5) }; - private static Tokenizer _llamaTokenizer = CreateLlamaTokenizer().GetAwaiter().GetResult(); - private static Tokenizer _llamaMistralTokenizer = CreateLMistralTokenizer().GetAwaiter().GetResult(); + private static Tokenizer _llamaTokenizer = CreateLlamaTokenizer(); + private static Tokenizer _llamaMistralTokenizer = CreateLMistralTokenizer(); - private static async Task CreateLlamaTokenizer() + private static Tokenizer CreateLlamaTokenizer() { - const string modelUrl = @"https://huggingface.co/hf-internal-testing/llama-tokenizer/resolve/main/tokenizer.model"; - using Stream remoteStream = await _httpClient.GetStreamAsync(modelUrl); + // @"https://huggingface.co/meta-llama/Llama-2-7b-chat-hf/resolve/main/tokenizer.model?download=true"; + // @"https://huggingface.co/hf-internal-testing/llama-tokenizer/resolve/main/tokenizer.model"; + using Stream remoteStream = File.OpenRead(Path.Combine(@"Llama", "tokenizer.model")); return Tokenizer.CreateLlama(remoteStream); } - private static async Task CreateLMistralTokenizer() + private static Tokenizer CreateLMistralTokenizer() { - const string modelUrl = @"https://huggingface.co/mistralai/Mistral-7B-v0.1/resolve/main/tokenizer.model?download=true"; - using Stream remoteStream = await _httpClient.GetStreamAsync(modelUrl); + // @"https://huggingface.co/mistralai/Mistral-7B-v0.1/resolve/main/tokenizer.model?download=true"; + using Stream remoteStream = File.OpenRead(Path.Combine(@"Mistral", "tokenizer.model")); return Tokenizer.CreateLlama(remoteStream); } @@ -185,13 +186,13 @@ public static IEnumerable LlamaTestData() [MemberData(nameof(LlamaTestData))] public void TestLlamaTokenizer(Tokenizer llamaTokenizer, string input, int[] ids, string[] tokens, (int Index, int Length)[] offsets) { - SentencePieceBpe? bpe = llamaTokenizer.Model as SentencePieceBpe; + SentencePieceBpe? bpe = llamaTokenizer as SentencePieceBpe; Assert.NotNull(bpe); - EncodingResult result = llamaTokenizer.Encode(input); - Assert.Equal(ids, result.Ids); - Assert.Equal(tokens, result.Tokens); - Assert.Equal(offsets, result.Offsets); + IReadOnlyList result = llamaTokenizer.Encode(input, out _); + Assert.Equal(ids, result.Select(t => t.Id).ToArray()); + Assert.Equal(tokens, result.Select(t => t.Value).ToArray()); + Assert.Equal(offsets, result.Select(t => t.Offset).ToArray()); Assert.Equal(input, llamaTokenizer.Decode(ids)); Assert.Equal(ids, llamaTokenizer.EncodeToIds(input)); Assert.Equal(ids.Length, llamaTokenizer.CountTokens(input)); @@ -208,32 +209,29 @@ public void TestLlamaTokenizer(Tokenizer llamaTokenizer, string input, int[] ids bool isEmptyInput = string.IsNullOrEmpty(input); - IReadOnlyList bpeTokens = bpe.Encode(normalizedInput.AsSpan(), addBeginOfSentence: false, addEndOfSentence: false); + IReadOnlyList bpeTokens = bpe.Encode(normalizedInput.AsSpan(), out _, addBeginningOfSentence: false, addEndOfSentence: false, considerNormalization: false); Assert.Equal(ids.Skip(1), bpeTokens.Select(token => token.Id)); Assert.Equal(tokens.Skip(1), bpeTokens.Select(token => token.Value)); Assert.Equal(input, llamaTokenizer.Decode(bpeTokens.Select(token => token.Id))); - List encodedIds = new(); - bpe.EncodeToIds(normalizedInput.AsSpan(), addBeginOfSentence: false, addEndOfSentence: false, accumulatedIds: encodedIds, out _); + IReadOnlyList encodedIds = bpe.EncodeToIds(normalizedInput.AsSpan(), addBeginningOfSentence: false, addEndOfSentence: false, considerNormalization: false); Assert.Equal(ids.Skip(1), encodedIds); - Assert.Equal(isEmptyInput ? 0 : ids.Length - 1, bpe.CountTokens(normalizedInput.AsSpan(), addBeginOfSentence: false, addEndOfSentence: false, out _)); + Assert.Equal(isEmptyInput ? 0 : ids.Length - 1, bpe.CountTokens(normalizedInput.AsSpan(), addBeginningOfSentence: false, addEndOfSentence: false, considerNormalization: false)); - bpeTokens = bpe.Encode(normalizedInput.AsSpan(), addBeginOfSentence: false, addEndOfSentence: true); + bpeTokens = bpe.Encode(normalizedInput.AsSpan(), out _, addBeginningOfSentence: false, addEndOfSentence: true, considerNormalization: false); Assert.Equal(isEmptyInput ? Array.Empty() : ids.Skip(1).Concat(new[] { bpe.EndOfSentenceId }), bpeTokens.Select(token => token.Id)); Assert.Equal(isEmptyInput ? Array.Empty() : tokens.Skip(1).Concat(new[] { bpe.EndOfSentenceToken }), bpeTokens.Select(token => token.Value)); Assert.Equal(input, llamaTokenizer.Decode(bpeTokens.Select(token => token.Id))); - encodedIds.Clear(); - bpe.EncodeToIds(normalizedInput.AsSpan(), addBeginOfSentence: false, addEndOfSentence: true, accumulatedIds: encodedIds, out _); + encodedIds = bpe.EncodeToIds(normalizedInput.AsSpan(), addBeginningOfSentence: false, addEndOfSentence: true, considerNormalization: false); Assert.Equal(isEmptyInput ? Array.Empty() : ids.Skip(1).Concat(new[] { bpe.EndOfSentenceId }), encodedIds); - Assert.Equal(isEmptyInput ? 0 : ids.Length, bpe.CountTokens(normalizedInput.AsSpan(), addBeginOfSentence: false, addEndOfSentence: true, out _)); + Assert.Equal(isEmptyInput ? 0 : ids.Length, bpe.CountTokens(normalizedInput.AsSpan(), addBeginningOfSentence: false, addEndOfSentence: true, considerNormalization: false)); - bpeTokens = bpe.Encode(normalizedInput.AsSpan(), addBeginOfSentence: true, addEndOfSentence: true); + bpeTokens = bpe.Encode(normalizedInput.AsSpan(), out _, addBeginningOfSentence: true, addEndOfSentence: true, considerNormalization: false); Assert.Equal(isEmptyInput ? Array.Empty() : ids.Concat(new[] { bpe.EndOfSentenceId }), bpeTokens.Select(token => token.Id)); Assert.Equal(isEmptyInput ? Array.Empty() : tokens.Concat(new[] { bpe.EndOfSentenceToken }), bpeTokens.Select(token => token.Value)); Assert.Equal(input, llamaTokenizer.Decode(bpeTokens.Select(token => token.Id))); - encodedIds.Clear(); - bpe.EncodeToIds(normalizedInput.AsSpan(), addBeginOfSentence: true, addEndOfSentence: true, accumulatedIds: encodedIds, out _); + encodedIds = bpe.EncodeToIds(normalizedInput.AsSpan(), addBeginningOfSentence: true, addEndOfSentence: true, considerNormalization: false); Assert.Equal(isEmptyInput ? Array.Empty() : ids.Concat(new[] { bpe.EndOfSentenceId }), encodedIds); - Assert.Equal(isEmptyInput ? 0 : ids.Length + 1, bpe.CountTokens(normalizedInput.AsSpan(), addBeginOfSentence: true, addEndOfSentence: true, out _)); + Assert.Equal(isEmptyInput ? 0 : ids.Length + 1, bpe.CountTokens(normalizedInput.AsSpan(), addBeginningOfSentence: true, addEndOfSentence: true, considerNormalization: false)); } public static IEnumerable LlamaTokenizersListData() @@ -244,11 +242,17 @@ public static IEnumerable LlamaTokenizersListData() [Theory] [MemberData(nameof(LlamaTokenizersListData))] - public void TestLlamaTokenizerWithInvalidInput(Tokenizer llamaTokenizer) + public void TestLlamaTokenizerWithEmptyInput(Tokenizer llamaTokenizer) { - Assert.Throws(() => llamaTokenizer.Encode(null!)); - Assert.Throws(() => llamaTokenizer.EncodeToIds(null!)); - Assert.Throws(() => llamaTokenizer.CountTokens(null!)); + Assert.Equal([], llamaTokenizer.Encode((string)null!, out _)); + Assert.Equal([], llamaTokenizer.Encode(Span.Empty, out _)); + + Assert.Equal([], llamaTokenizer.EncodeToIds((string)null!)); + Assert.Equal([], llamaTokenizer.EncodeToIds(Span.Empty)); + + Assert.Equal(0, llamaTokenizer.CountTokens((string)null!)); + Assert.Equal(0, llamaTokenizer.CountTokens(Span.Empty)); + Assert.Throws(() => llamaTokenizer.Decode(null!)); } @@ -256,7 +260,7 @@ public void TestLlamaTokenizerWithInvalidInput(Tokenizer llamaTokenizer) [MemberData(nameof(LlamaTokenizersListData))] public void TestLlamaTokenizerProperties(Tokenizer llamaTokenizer) { - SentencePieceBpe? bpe = llamaTokenizer.Model as SentencePieceBpe; + SentencePieceBpe? bpe = llamaTokenizer as SentencePieceBpe; Assert.NotNull(bpe); Assert.NotNull(llamaTokenizer.Normalizer); @@ -284,34 +288,242 @@ public void TestLlamaTokenizerProperties(Tokenizer llamaTokenizer) } [Fact] - public void TestLlamaNormalizer() + public void TestSentencePieceNormalizer() { - LlamaNormalizer normalizer = new LlamaNormalizer(removeExtraWhiteSpaces: false, addDummyPrefix: false, escapeWhiteSpaces: false, treatWhitespaceAsSuffix: false); + SentencePieceNormalizer normalizer = new SentencePieceNormalizer(removeExtraWhiteSpaces: false, addDummyPrefix: false, escapeWhiteSpaces: false, treatWhitespaceAsSuffix: false); Assert.Equal("Hello, World!", normalizer.Normalize("Hello, World!")); + Assert.Equal("Hello, World!", normalizer.Normalize("Hello, World!".AsSpan())); - normalizer = new LlamaNormalizer(removeExtraWhiteSpaces: true, addDummyPrefix: false, escapeWhiteSpaces: false, treatWhitespaceAsSuffix: false); + normalizer = new SentencePieceNormalizer(removeExtraWhiteSpaces: true, addDummyPrefix: false, escapeWhiteSpaces: false, treatWhitespaceAsSuffix: false); Assert.Equal("Hello, World!", normalizer.Normalize("Hello, World!")); + Assert.Equal("Hello, World!", normalizer.Normalize("Hello, World!".AsSpan())); - normalizer = new LlamaNormalizer(removeExtraWhiteSpaces: true, addDummyPrefix: true, escapeWhiteSpaces: false, treatWhitespaceAsSuffix: false); + normalizer = new SentencePieceNormalizer(removeExtraWhiteSpaces: true, addDummyPrefix: true, escapeWhiteSpaces: false, treatWhitespaceAsSuffix: false); Assert.Equal(" Hello, World!", normalizer.Normalize("Hello, World!")); + Assert.Equal(" Hello, World!", normalizer.Normalize("Hello, World!".AsSpan())); - normalizer = new LlamaNormalizer(removeExtraWhiteSpaces: true, addDummyPrefix: true, escapeWhiteSpaces: true, treatWhitespaceAsSuffix: false); + normalizer = new SentencePieceNormalizer(removeExtraWhiteSpaces: true, addDummyPrefix: true, escapeWhiteSpaces: true, treatWhitespaceAsSuffix: false); Assert.Equal("▁Hello,▁World!", normalizer.Normalize("Hello, World!")); + Assert.Equal("▁Hello,▁World!", normalizer.Normalize("Hello, World!".AsSpan())); - normalizer = new LlamaNormalizer(removeExtraWhiteSpaces: false, addDummyPrefix: true, escapeWhiteSpaces: true, treatWhitespaceAsSuffix: false); + normalizer = new SentencePieceNormalizer(removeExtraWhiteSpaces: false, addDummyPrefix: true, escapeWhiteSpaces: true, treatWhitespaceAsSuffix: false); Assert.Equal("▁Hello,▁▁▁▁▁▁World!", normalizer.Normalize("Hello, World!")); + Assert.Equal("▁Hello,▁▁▁▁▁▁World!", normalizer.Normalize("Hello, World!".AsSpan())); - normalizer = new LlamaNormalizer(removeExtraWhiteSpaces: true, addDummyPrefix: true, escapeWhiteSpaces: true, treatWhitespaceAsSuffix: true); + normalizer = new SentencePieceNormalizer(removeExtraWhiteSpaces: true, addDummyPrefix: true, escapeWhiteSpaces: true, treatWhitespaceAsSuffix: true); Assert.Equal("Hello,▁World!▁", normalizer.Normalize("Hello, World!")); + Assert.Equal("Hello,▁World!▁", normalizer.Normalize("Hello, World!".AsSpan())); - normalizer = new LlamaNormalizer(removeExtraWhiteSpaces: true, addDummyPrefix: false, escapeWhiteSpaces: true, treatWhitespaceAsSuffix: true); + normalizer = new SentencePieceNormalizer(removeExtraWhiteSpaces: true, addDummyPrefix: false, escapeWhiteSpaces: true, treatWhitespaceAsSuffix: true); Assert.Equal("Hello,▁World!", normalizer.Normalize("Hello, World!")); + Assert.Equal("Hello,▁World!", normalizer.Normalize("Hello, World!".AsSpan())); - normalizer = new LlamaNormalizer(removeExtraWhiteSpaces: false, addDummyPrefix: true, escapeWhiteSpaces: true, treatWhitespaceAsSuffix: true); + normalizer = new SentencePieceNormalizer(removeExtraWhiteSpaces: false, addDummyPrefix: true, escapeWhiteSpaces: true, treatWhitespaceAsSuffix: true); Assert.Equal("Hello,▁▁▁▁▁▁World!▁", normalizer.Normalize("Hello, World!")); + Assert.Equal("Hello,▁▁▁▁▁▁World!▁", normalizer.Normalize("Hello, World!".AsSpan())); - normalizer = new LlamaNormalizer(removeExtraWhiteSpaces: false, addDummyPrefix: true, escapeWhiteSpaces: false, treatWhitespaceAsSuffix: true); + normalizer = new SentencePieceNormalizer(removeExtraWhiteSpaces: false, addDummyPrefix: true, escapeWhiteSpaces: false, treatWhitespaceAsSuffix: true); Assert.Equal("Hello, World! ", normalizer.Normalize("Hello, World!")); + Assert.Equal("Hello, World! ", normalizer.Normalize("Hello, World!".AsSpan())); + } + + public static IEnumerable TokenizerTestData + { + get + { + // string to tokenize, produced tokens, the token offsets + yield return new object?[] + { + "the brown fox jumped over the lazy dog!", + "▁the▁brown▁fox▁jumped▁over▁the▁lazy▁dog!", + new string[] { "", "▁the", "▁brown", "▁fo", "x", "▁jump", "ed", "▁over", "▁the", "▁lazy", "▁dog", "!" }, + new (int Index, int Length)[] { (0, 0), (0, 4), (4, 6), (10, 3), (13, 1), (14, 5), (19, 2), (21, 5), (26, 4), (30, 5), (35, 4), (39, 1) }, + new int[] { 1, 278, 17354, 1701, 29916, 12500, 287, 975, 278, 17366, 11203, 29991 } + }; + yield return new object?[] + { + "he traveled to Egypt during the summer, the weather was hot and ammunition." , + "▁he▁traveled▁to▁Egypt▁during▁the▁summer,▁the▁weather▁was▁hot▁and▁ammunition." , + new string[] { "", "▁he", "▁tra", "ve", "led", "▁to", "▁Egypt", "▁during", "▁the", "▁summer", ",", "▁the", "▁weather", "▁was", "▁hot", "▁and", "▁am", "mun", "ition", "." }, + new (int Index, int Length)[] { (0, 0), (0, 3), (3, 4), (7, 2), (9, 3), (12, 3), (15, 6), (21, 7), (28, 4), (32, 7), (39, 1), (40, 4), (44, 8), (52, 4), (56, 4), (60, 4), (64, 3), (67, 3), (70, 5), (75, 1) }, + new int[] { 1, 540, 1020, 345, 839, 304, 12892, 2645, 278, 11801, 29892, 278, 14826, 471, 7375, 322, 626, 24579, 654, 29889 } + }; + yield return new object?[] + { + "She played many games and she felt exhausted afterward", + "▁She▁played▁many▁games▁and▁she▁felt▁exhausted▁afterward", + new string[] { "", "▁She", "▁played", "▁many", "▁games", "▁and", "▁she", "▁felt", "▁exha", "usted", "▁after", "ward" }, + new (int Index, int Length)[] { (0, 0), (0, 4), (4, 7), (11, 5), (16, 6), (22, 4), (26, 4), (30, 5), (35, 5), (40, 5), (45, 6), (51, 4) }, + new int[] { 1, 2296, 5318, 1784, 8090, 322, 1183, 7091, 18782, 16656, 1156, 1328 } + }; + yield return new object?[] + { + "Hello, y'all! How are you 😁 ?", + "▁Hello,▁y'all!▁How▁are▁you▁😁▁?", + new string[] { "", "▁Hello", ",", "▁y", "'", "all", "!", "▁How", "▁are", "▁you", "▁", "<0xF0>", "<0x9F>", "<0x98>", "<0x81>", "▁?" }, + new (int Index, int Length)[] { (0, 0), (0, 6), (6, 1), (7, 2), (9, 1), (10, 3), (13, 1), (14, 4), (18, 4), (22, 4), (26, 1), (27, 2), (27, 0), (27, 0), (27, 0), (29, 2) }, + new int[] { 1, 15043, 29892, 343, 29915, 497, 29991, 1128, 526, 366, 29871, 243, 162, 155, 132, 1577 } + }; + } + } + + [Theory] + [MemberData(nameof(TokenizerTestData))] + public void TestTokenizerEncoding(string text, string normalizedText, string[] expectedTokens, (int Index, int Length)[] expectedOffsets, int[] expectedIds) + { + Tokenizer tokenizer = _llamaTokenizer; + + Assert.NotNull(tokenizer.Normalizer); + Assert.Null(tokenizer.PreTokenizer); + + IReadOnlyList encoding = tokenizer.Encode(text, out _); + IReadOnlyList encoding1 = tokenizer.Encode(text.AsSpan(), out _); + + Assert.Equal(expectedTokens, encoding.Select(t => t.Value).ToArray()); + Assert.Equal(expectedOffsets, encoding.Select(t => t.Offset).ToArray()); + Assert.Equal(expectedIds, encoding.Select(t => t.Id).ToArray()); + + Assert.Equal(expectedTokens, encoding1.Select(t => t.Value).ToArray()); + Assert.Equal(expectedOffsets, encoding1.Select(t => t.Offset).ToArray()); + Assert.Equal(expectedIds, encoding1.Select(t => t.Id).ToArray()); + + SentencePieceBpe sentencePieceBpe = (tokenizer as SentencePieceBpe)!; + foreach (bool considerNormalization in new[] { true, false }) + foreach (bool addBeginningOfSentence in new[] { true, false }) + foreach (bool addEndOfSentence in new[] { true, false }) + { + encoding = sentencePieceBpe.Encode( + considerNormalization ? text : normalizedText, + out _, + addBeginningOfSentence: addBeginningOfSentence, + addEndOfSentence: addEndOfSentence, + considerPreTokenization: false, + considerNormalization: considerNormalization); + + encoding1 = sentencePieceBpe.Encode( + considerNormalization ? text.AsSpan() : normalizedText.AsSpan(), + out _, + addBeginningOfSentence: addBeginningOfSentence, + addEndOfSentence: addEndOfSentence, + considerPreTokenization: false, + considerNormalization: considerNormalization); + + string[] expectedTokens1 = addBeginningOfSentence ? expectedTokens : expectedTokens.Skip(1).ToArray(); + expectedTokens1 = addEndOfSentence ? expectedTokens1.Concat(new[] { sentencePieceBpe.EndOfSentenceToken }).ToArray() : expectedTokens1; + + (int Index, int Length)[] expectedOffsets1 = addBeginningOfSentence ? expectedOffsets : expectedOffsets.Skip(1).ToArray(); + expectedOffsets1 = addEndOfSentence ? expectedOffsets1.Concat(new[] { (normalizedText.Length, 0) }).ToArray() : expectedOffsets1; + + int[] expectedIds1 = addBeginningOfSentence ? expectedIds : expectedIds.Skip(1).ToArray(); + expectedIds1 = addEndOfSentence ? expectedIds1.Concat(new[] { sentencePieceBpe.EndOfSentenceId }).ToArray() : expectedIds1; + + Assert.Equal(expectedTokens1, encoding.Select(t => t.Value).ToArray()); + Assert.Equal(expectedOffsets1, encoding.Select(t => t.Offset).ToArray()); + Assert.Equal(expectedIds1, encoding.Select(t => t.Id).ToArray()); + } + } + + [Theory] + [MemberData(nameof(TokenizerTestData))] + public void TestTokenizerEncodingToIds(string text, string normalizedText, string[] expectedTokens, (int Index, int Length)[] expectedOffsets, int[] expectedIds) + { + Tokenizer tokenizer = _llamaTokenizer; + + Assert.NotNull(expectedTokens); + Assert.NotNull(expectedOffsets); + + Assert.Equal(expectedIds, tokenizer.EncodeToIds(text)); + Assert.Equal(expectedIds, tokenizer.EncodeToIds(text.AsSpan())); + Assert.Equal(expectedIds, tokenizer.EncodeToIds(text, expectedIds.Length, out string? normalizedString, out int length)); + Assert.Equal(normalizedText, normalizedString); + Assert.Equal(normalizedText.Length, length); + Assert.Equal(expectedIds, tokenizer.EncodeToIds(text.AsSpan(), expectedIds.Length, out normalizedString, out length)); + Assert.Equal(normalizedText, normalizedString); + Assert.Equal(normalizedText.Length, length); + + SentencePieceBpe sentencePieceBpe = (tokenizer as SentencePieceBpe)!; + foreach (bool considerNormalization in new[] { true, false }) + foreach (bool addBeginningOfSentence in new[] { true, false }) + foreach (bool addEndOfSentence in new[] { true, false }) + { + // (string text, bool addBeginningOfSentence, bool addEndOfSentence, int maxTokenCount, out string? normalizedString, out int textLength, bool considerPreTokenization = true, bool considerNormalization = true) + + int[] expectedIds1 = addBeginningOfSentence ? expectedIds : expectedIds.Skip(1).ToArray(); + expectedIds1 = addEndOfSentence ? expectedIds1.Concat(new[] { sentencePieceBpe.EndOfSentenceId }).ToArray() : expectedIds1; + + Assert.Equal(expectedIds1, sentencePieceBpe.EncodeToIds( + considerNormalization ? text : normalizedText, + addBeginningOfSentence: addBeginningOfSentence, + addEndOfSentence: addEndOfSentence, + expectedIds1.Length, + out normalizedString, + out length, + considerNormalization: considerNormalization)); + + Assert.Equal(expectedIds1, sentencePieceBpe.EncodeToIds( + considerNormalization ? text.AsSpan() : normalizedText.AsSpan(), + addBeginningOfSentence: addBeginningOfSentence, + addEndOfSentence: addEndOfSentence, + expectedIds1.Length, + out normalizedString, + out length, + considerNormalization: considerNormalization)); + + Assert.Equal(considerNormalization ? normalizedText : null, normalizedString); + Assert.Equal(normalizedText.Length, length); + + Assert.Equal(expectedIds1.Take(expectedIds1.Length - 6), sentencePieceBpe.EncodeToIds( + considerNormalization ? text : normalizedText, + addBeginningOfSentence: addBeginningOfSentence, + addEndOfSentence: addEndOfSentence, + expectedIds1.Length - 6, + out normalizedString, + out length, + considerNormalization: considerNormalization)); + Assert.Equal(considerNormalization ? normalizedText : null, normalizedString); + + (int Index, int Length)[] expectedOffsets1 = addBeginningOfSentence ? expectedOffsets.Take(expectedIds1.Length - 6).ToArray() : expectedOffsets.Skip(1).Take(expectedIds1.Length - 6).ToArray(); + + int expectedLength = expectedOffsets1[expectedOffsets1.Length - 1].Index + expectedOffsets1[expectedOffsets1.Length - 1].Length; + Assert.Equal(expectedLength, length); + + Assert.Equal(expectedIds1.Take(expectedIds1.Length - 6), sentencePieceBpe.EncodeToIds( + considerNormalization ? text.AsSpan() : normalizedText.AsSpan(), + addBeginningOfSentence: addBeginningOfSentence, + addEndOfSentence: addEndOfSentence, + expectedIds1.Length - 6, + out normalizedString, + out length, + considerNormalization: considerNormalization)); + Assert.Equal(expectedLength, length); + } + } + + + [Theory] + [MemberData(nameof(TokenizerTestData))] + public void TestTokenizerCountTokens(string text, string normalizedText, string[] expectedTokens, (int Index, int Length)[] expectedOffsets, int[] expectedIds) + { + Tokenizer tokenizer = _llamaTokenizer; + + Assert.NotNull(expectedTokens); + + Assert.Equal(expectedIds.Length, tokenizer.CountTokens(text)); + Assert.Equal(expectedIds.Length, tokenizer.CountTokens(text.AsSpan())); + + Assert.Equal(expectedOffsets[expectedOffsets.Length - 7].Index + expectedOffsets[expectedOffsets.Length - 7].Length, tokenizer.IndexOfTokenCount(text, expectedIds.Length - 6, out string? normalizedString, out int tokenCount)); + Assert.Equal(normalizedText, normalizedString); + Assert.Equal(expectedIds.Length - 6, tokenCount); + Assert.Equal(expectedOffsets[expectedOffsets.Length - 7].Index + expectedOffsets[expectedOffsets.Length - 7].Length, tokenizer.IndexOfTokenCount(text.AsSpan(), expectedIds.Length - 6, out normalizedString, out tokenCount)); + Assert.Equal(normalizedText, normalizedString); + Assert.Equal(expectedIds.Length - 6, tokenCount); + + Assert.Equal(expectedOffsets[expectedOffsets.Length - 7].Index, tokenizer.LastIndexOfTokenCount(text, 7, out normalizedString, out tokenCount)); + Assert.Equal(normalizedText, normalizedString); + Assert.Equal(7, tokenCount); + Assert.Equal(expectedOffsets[expectedOffsets.Length - 7].Index, tokenizer.LastIndexOfTokenCount(text.AsSpan(), 7, out normalizedString, out tokenCount)); + Assert.Equal(normalizedText, normalizedString); + Assert.Equal(7, tokenCount); } } } diff --git a/test/Microsoft.ML.Tokenizers.Tests/Microsoft.ML.Tokenizers.Tests.csproj b/test/Microsoft.ML.Tokenizers.Tests/Microsoft.ML.Tokenizers.Tests.csproj index 9ec9c38d4c..be99d7ad54 100644 --- a/test/Microsoft.ML.Tokenizers.Tests/Microsoft.ML.Tokenizers.Tests.csproj +++ b/test/Microsoft.ML.Tokenizers.Tests/Microsoft.ML.Tokenizers.Tests.csproj @@ -42,6 +42,7 @@ + \ No newline at end of file diff --git a/test/Microsoft.ML.Tokenizers.Tests/NormalizerTests.cs b/test/Microsoft.ML.Tokenizers.Tests/NormalizerTests.cs index 5da96a4b30..fb37e93f9d 100644 --- a/test/Microsoft.ML.Tokenizers.Tests/NormalizerTests.cs +++ b/test/Microsoft.ML.Tokenizers.Tests/NormalizerTests.cs @@ -61,10 +61,9 @@ public void TestNormalizer(Normalizer normalizer, string text, string normalized string normalizedText = normalizer.Normalize(text); Assert.Equal(normalized, normalizedText); - Tokenizer tokenizer = new Tokenizer(BpeTests.CreateEmptyBpe(), WhiteSpace.Instance, normalizer); - EncodingResult encoding = tokenizer.Encode(text); - Assert.Equal(text, encoding.OriginalString); - Assert.Equal(normalized, encoding.NormalizedString); + Tokenizer tokenizer = BpeTests.CreateEmptyBpe(preTokenizer: null, normalizer); + IReadOnlyList tokens = tokenizer.Encode(text, out string? normalizedString); + Assert.Equal(normalized, normalizedString); } public class RemoveQuotesNormalizer : Normalizer @@ -77,6 +76,22 @@ public override string Normalize(string original) return original; } + return RemoveQuotes(original.AsSpan(), index); + } + + public override string Normalize(ReadOnlySpan original) + { + int index = original.IndexOf('"'); + if (index <= 0) + { + return original.ToString(); + } + + return RemoveQuotes(original, index); + } + + private string RemoveQuotes(ReadOnlySpan original, int index) + { StringBuilder sb = new StringBuilder(original.Length); List mapping = new List(); @@ -97,7 +112,7 @@ public override string Normalize(string original) break; } - index = original.IndexOf('"', start); + index = original.Slice(start).IndexOf('"'); if (index <= 0) { for (int i = start; i < original.Length; i++) @@ -107,6 +122,8 @@ public override string Normalize(string original) } break; } + + index += start; } while (true); return sb.ToString(); @@ -130,6 +147,16 @@ public override string Normalize(string original) return original.Normalize(_normalizationForm); } + + public override string Normalize(ReadOnlySpan original) + { + if (original.IsEmpty) + { + return string.Empty; + } + + return original.ToString().Normalize(_normalizationForm); + } } } } diff --git a/test/Microsoft.ML.Tokenizers.Tests/PreTokenizerTests.cs b/test/Microsoft.ML.Tokenizers.Tests/PreTokenizerTests.cs index cd3488db58..b028d4ce6d 100644 --- a/test/Microsoft.ML.Tokenizers.Tests/PreTokenizerTests.cs +++ b/test/Microsoft.ML.Tokenizers.Tests/PreTokenizerTests.cs @@ -20,62 +20,63 @@ public static IEnumerable PreTokenizerData { WhiteSpace.Instance, "How are you doing?", - new Split[] { new Split("How", (0, 3)), new Split("are", (4, 3)), new Split("you", (8, 3)), new Split("doing", (12, 5)), new Split("?", (17, 1)),} + new (int Offset, int Length)[] { (0, 3), (4, 3), (8, 3), (12, 5), (17, 1), } }; yield return new object[] { WhiteSpace.Instance, "I_am_Just_Fine!", - new Split[] { new Split("I_am_Just_Fine", (0, 14)), new Split("!", (14, 1)) } + new (int Offset, int Length)[] { (0, 14), (14, 1) } }; yield return new object[] { new SpacePreTokenizer(), "How are you doing?!", - new Split[] { new Split("How", (0, 3)), new Split("are", (4, 3)), new Split("you", (11, 3)), new Split("doing?!", (15, 7)) } + new (int Offset, int Length)[] { (0, 3), (4, 3), (11, 3), (15, 7) } }; yield return new object[] { new SpacePreTokenizer(), new string(' ', 100), - new Split[] { } + new (int Offset, int Length)[] { } }; } } [Theory] [MemberData(nameof(PreTokenizerData))] - public void TestPreTokenizer(PreTokenizer preTokenizer, string text, Split[] splits) + public void TestPreTokenizer(PreTokenizer preTokenizer, string text, (int Offset, int Length)[] splits) { - Split[] splitParts = preTokenizer.PreTokenize(text).ToArray(); + (int Offset, int Length)[] splitParts = preTokenizer.PreTokenize(text).ToArray<(int Offset, int Length)>(); Assert.Equal(splits, splitParts); // Empty tokenizer which tokenize all parts as unknown tokens. - Tokenizer tokenizer = new Tokenizer(BpeTests.CreateEmptyBpe(), preTokenizer); + Tokenizer tokenizer = BpeTests.CreateEmptyBpe(normalizer: null, preTokenizer: preTokenizer); - EncodingResult encoding = tokenizer.Encode(text); - Assert.True(encoding.Tokens.Count >= splitParts.Length, $"Expected to have {encoding.Tokens.Count} >= {splitParts.Length}"); + IReadOnlyList encoding = tokenizer.Encode(text, out _); + Assert.True(encoding.Count >= splitParts.Length, $"Expected to have {encoding.Count} >= {splitParts.Length}"); } [Fact] public void TestWhiteSpacePreTokenizer() { - Assert.Empty(WhiteSpace.Instance.PreTokenize(null!)); + Assert.Empty(WhiteSpace.Instance.PreTokenize((string)null!)); } public class SpacePreTokenizer : PreTokenizer { - public override IEnumerable PreTokenize(string text) + public override IEnumerable<(int Offset, int Length)> PreTokenize(ReadOnlySpan text) { - List splits = new(); - if (string.IsNullOrEmpty(text)) + if (text.IsEmpty) { - return splits; + return []; } + List<(int Offset, int Length)> splits = new(); + int index = 0; while (true) { @@ -92,7 +93,7 @@ public override IEnumerable PreTokenize(string text) if (index < text.Length) { - splits.Add(new Split(text.Substring(index, end - index), (index, end - index))); + splits.Add((index, end - index)); } else { @@ -104,6 +105,16 @@ public override IEnumerable PreTokenize(string text) return splits; } + + public override IEnumerable<(int Offset, int Length)> PreTokenize(string text) + { + if (string.IsNullOrEmpty(text)) + { + return []; + } + + return PreTokenize(text.AsSpan()); + } } } } diff --git a/test/Microsoft.ML.Tokenizers.Tests/TitokenTests.cs b/test/Microsoft.ML.Tokenizers.Tests/TitokenTests.cs index cdaf84bad7..15af661994 100644 --- a/test/Microsoft.ML.Tokenizers.Tests/TitokenTests.cs +++ b/test/Microsoft.ML.Tokenizers.Tests/TitokenTests.cs @@ -38,8 +38,8 @@ public async void TestTokenizerCreation() { TestGPT4TokenizationEncoding(GPT4); - Assert.True(GPT4.Model is Tiktoken); - IReadOnlyDictionary? specialTokensEncoder = (GPT4.Model as Tiktoken)!.SpecialTokens; + Assert.True(GPT4 is Tiktoken); + IReadOnlyDictionary? specialTokensEncoder = (GPT4 as Tiktoken)!.SpecialTokens; string tokenizerDataFileName = Utils.CreateTemporaryFile("tiktoken"); @@ -53,21 +53,21 @@ public async void TestTokenizerCreation() try { - Tokenizer tokenizer = new Tokenizer(new Tiktoken(tokenizerDataFileName, specialTokensEncoder), GPT4.PreTokenizer); + Tokenizer tokenizer = new Tiktoken(tokenizerDataFileName, GPT4.PreTokenizer, specialTokensEncoder); TestGPT4TokenizationEncoding(tokenizer); using (Stream stream = File.OpenRead(tokenizerDataFileName)) { - tokenizer = new Tokenizer(new Tiktoken(stream, specialTokensEncoder), GPT4.PreTokenizer); + tokenizer = new Tiktoken(stream, GPT4.PreTokenizer, specialTokensEncoder); } TestGPT4TokenizationEncoding(tokenizer); - tokenizer = new Tokenizer(await Tiktoken.CreateAsync(tokenizerDataFileName, specialTokensEncoder), GPT4.PreTokenizer); + tokenizer = await Tokenizer.CreateTiktokenAsync(tokenizerDataFileName, GPT4.PreTokenizer, normalizer: null, specialTokensEncoder); TestGPT4TokenizationEncoding(tokenizer); using (Stream stream = File.OpenRead(tokenizerDataFileName)) { - tokenizer = new Tokenizer(await Tiktoken.CreateAsync(stream, specialTokensEncoder), GPT4.PreTokenizer); + tokenizer = await Tokenizer.CreateTiktokenAsync(stream, GPT4.PreTokenizer, normalizer: null, specialTokensEncoder); } TestGPT4TokenizationEncoding(tokenizer); @@ -109,11 +109,11 @@ public async void TestTokenizerUsingExternalVocab(Tokenizer tokenizer, string ur try { - Tiktoken tiktoken = (tokenizer.Model as Tiktoken)!; - Tokenizer externalTokenizer = new Tokenizer(new Tiktoken(tokenizerDataFileName, tiktoken.SpecialTokens), tokenizer.PreTokenizer); + Tiktoken tiktoken = (tokenizer as Tiktoken)!; + Tokenizer externalTokenizer = new Tiktoken(tokenizerDataFileName, tokenizer.PreTokenizer, tiktoken.SpecialTokens); IReadOnlyDictionary, int> encoder = tiktoken.Encoder; - IReadOnlyDictionary, int> externalEncoder = (externalTokenizer.Model as Tiktoken)!.Encoder; + IReadOnlyDictionary, int> externalEncoder = (externalTokenizer as Tiktoken)!.Encoder; Assert.Equal(externalEncoder.Count, encoder.Count); foreach (KeyValuePair, int> kvp in encoder) @@ -135,13 +135,17 @@ private void TestGPT4TokenizationEncoding(Tokenizer tokenizer) Assert.Equal(new List() { 9906, 4435 }, encoded); Assert.Equal(text, tokenizer.Decode(encoded.ToArray())!); - EncodingResult result = tokenizer.Encode(text); + IReadOnlyList result = tokenizer.Encode(text, out string? normalizedString); int idsCount = tokenizer.CountTokens(text); - Assert.Equal(encoded, result.Ids); - Assert.Equal(new string[] { "Hello", " World" }, result.Tokens); - Assert.Equal(new List<(int, int)> { (0, 5), (5, 6) }, result.Offsets); + + int[] ids = result.Select(token => token.Id).ToArray(); + string[] tokens = result.Select(token => token.Value).ToArray(); + (int, int)[] offsets = result.Select(token => token.Offset).ToArray(); + Assert.Equal(encoded, ids); + Assert.Equal(new string[] { "Hello", " World" }, tokens); + Assert.Equal(new List<(int, int)> { (0, 5), (5, 6) }, offsets); Assert.Equal(encoded.Count, idsCount); - Assert.Equal(encoded, result.Ids); + Assert.Equal(encoded, ids); TestGPT4Tokenizer(tokenizer); } @@ -154,13 +158,18 @@ public void TestEncode1() Assert.Equal(new List() { 100264, 9906, 4435, 100265 }, encoded); Assert.Equal(text, GPT4.Decode(encoded.ToArray())); - EncodingResult result = GPT4.Encode(text); + IReadOnlyList result = GPT4.Encode(text, out string? normalizedString); int idsCount = GPT4.CountTokens(text); - Assert.Equal(encoded, result.Ids); - Assert.Equal(new string[] { "<|im_start|>", "Hello", " World", "<|im_end|>" }, result.Tokens); - Assert.Equal(new List<(int, int)> { (0, 12), (12, 5), (17, 6), (23, 10) }, result.Offsets); + + int[] ids = result.Select(token => token.Id).ToArray(); + string[] tokens = result.Select(token => token.Value).ToArray(); + (int, int)[] offsets = result.Select(token => token.Offset).ToArray(); + + Assert.Equal(encoded, ids); + Assert.Equal(new string[] { "<|im_start|>", "Hello", " World", "<|im_end|>" }, tokens); + Assert.Equal(new List<(int, int)> { (0, 12), (12, 5), (17, 6), (23, 10) }, offsets); Assert.Equal(encoded.Count, idsCount); - Assert.Equal(encoded, result.Ids); + Assert.Equal(encoded, ids); } private void TestGPT4Tokenizer(Tokenizer gpt4Tokenizer) @@ -192,12 +201,16 @@ public void TestEncode3() string? decoded = GPT4.Decode(encoded.ToArray()); Assert.Equal(text, decoded); - EncodingResult result = GPT4.Encode(text); + IReadOnlyList result = GPT4.Encode(text, out string? normalizedString); + int[] ids = result.Select(token => token.Id).ToArray(); + string[] tokens = result.Select(token => token.Value).ToArray(); + (int, int)[] offsets = result.Select(token => token.Offset).ToArray(); + int idsCount = GPT4.CountTokens(text); - Assert.Equal(encoded, result.Ids); + Assert.Equal(encoded, ids); Assert.Equal(encoded.Count, idsCount); - Assert.Equal(new string[] { "<|im_start|>", "Hello", "<|im_end|>", " World" }, result.Tokens); - Assert.Equal(new List<(int, int)> { (0, 12), (12, 5), (17, 10), (27, 6) }, result.Offsets); + Assert.Equal(new string[] { "<|im_start|>", "Hello", "<|im_end|>", " World" }, tokens); + Assert.Equal(new List<(int, int)> { (0, 12), (12, 5), (17, 10), (27, 6) }, offsets); } [Fact] @@ -207,12 +220,10 @@ public void TestEncode4() IReadOnlyList encoded = GPT4.EncodeToIds(text); Assert.Empty(encoded); - EncodingResult result = GPT4.Encode(text); + IReadOnlyList result = GPT4.Encode(text, out string? normalizedString); int idsCount = GPT4.CountTokens(text); - Assert.Empty(result.Ids); - Assert.Empty(result.Tokens); - Assert.Empty(result.Offsets); - Assert.Equal(result.Ids.Count, idsCount); + Assert.Empty(result); + Assert.Equal(0, idsCount); } [Fact] @@ -224,11 +235,11 @@ public void TestEncode5() Assert.Equal(new List() { 100264, 9906, 2928, 99834, 4435, 100265 }, encoded); Assert.Equal(text, GPT4.Decode(encoded.ToArray())); - EncodingResult result = GPT4.Encode(text); - Assert.Equal(encoded, result.Ids); + IReadOnlyList result = GPT4.Encode(text, out string? normalizedString); + Assert.Equal(encoded, result.Select(token => token.Id).ToArray()); Assert.Equal(encoded.Count, idsCount); - Assert.Equal(new string[] { "<|im_start|>", "Hello", " ⭐", "", " World", "<|im_end|>" }, result.Tokens); - Assert.Equal(new List<(int, int)> { (0, 12), (12, 5), (17, 2), (19, 0), (19, 6), (25, 10) }, result.Offsets); + Assert.Equal(new string[] { "<|im_start|>", "Hello", " ⭐", "", " World", "<|im_end|>" }, result.Select(token => token.Value).ToArray()); + Assert.Equal(new List<(int, int)> { (0, 12), (12, 5), (17, 2), (19, 0), (19, 6), (25, 10) }, result.Select(token => token.Offset).ToArray()); } [Fact] @@ -310,9 +321,12 @@ public void TestEncodeR50kBase() [Theory] [InlineData("gpt-4")] [InlineData("gpt-4-")] + [InlineData("gpt-3.5-")] [InlineData("gpt-3.5-turbo")] [InlineData("gpt-3.5-turbo-")] [InlineData("gpt-3.5-turbo-16k")] + [InlineData("gpt-35")] + [InlineData("gpt-35-")] [InlineData("gpt-35-turbo")] [InlineData("gpt-35-turbo-16k")] [InlineData("gpt-35-turbo-")] @@ -351,7 +365,7 @@ public void TestEncodeR50kBase() public void TestAllSupportedModelNames(string modelName) { Tokenizer tokenizer = Tokenizer.CreateTiktokenForModel(modelName); - Assert.NotNull(tokenizer.Model); + Assert.True(tokenizer is Tiktoken); Assert.NotNull(tokenizer.PreTokenizer); } @@ -363,7 +377,7 @@ public void TestAllSupportedModelNames(string modelName) public void TestAllSupportedEncodingNames(string encodingName) { Tokenizer tokenizer = Tokenizer.CreateTiktokenForEncoding(encodingName); - Assert.NotNull(tokenizer.Model); + Assert.True(tokenizer is Tiktoken); Assert.NotNull(tokenizer.PreTokenizer); string modelName = encodingName.ToLowerInvariant() switch @@ -377,15 +391,16 @@ public void TestAllSupportedEncodingNames(string encodingName) Tokenizer tokenizer1 = Tokenizer.CreateTiktokenForModel(modelName); - Tiktoken? model1 = tokenizer.Model as Tiktoken; - Tiktoken? model2 = tokenizer1.Model as Tiktoken; - Assert.NotNull(model1); - Assert.NotNull(model2); + Assert.True(tokenizer is Tiktoken); + Assert.True(tokenizer1 is Tiktoken); + + Tiktoken tiktoken = (tokenizer as Tiktoken)!; + Tiktoken tiktoken1 = (tokenizer1 as Tiktoken)!; - Assert.Equal(model2.Encoder, model1.Encoder); - Assert.Equal(model2.Decoder, model1.Decoder); - Assert.Equal(model2.SpecialTokens, model1.SpecialTokens); - Assert.Equal(model2.Vocab, model1.Vocab); + Assert.Equal(tiktoken1.Encoder, tiktoken.Encoder); + Assert.Equal(tiktoken1.Decoder, tiktoken.Decoder); + Assert.Equal(tiktoken1.SpecialTokens, tiktoken.SpecialTokens); + Assert.Equal(tiktoken1.Vocab, tiktoken.Vocab); } [Fact] @@ -408,11 +423,99 @@ public void TestCreationUsingModel(string modelName) RemoteExecutor.Invoke(static (name) => { Tokenizer tokenizer = Tokenizer.CreateTiktokenForModel(name); - Assert.NotNull(tokenizer.Model); + Assert.True(tokenizer is Tiktoken); Assert.NotNull(tokenizer.PreTokenizer); }, modelName).Dispose(); } + public static IEnumerable TokenizerTestData + { + get + { + // string to tokenize, produced tokens, the token offsets + yield return new object?[] + { + "the brown fox jumped over the lazy dog!", + new string[] { "the", " brown", " fox", " jumped", " over", " the", " lazy", " dog", "!" }, + new (int Index, int Length)[] { (0, 3), (3, 6), (9, 4), (13, 7), (20, 5), (25, 4), (29, 5), (34, 4), (38, 1) }, + new int[] { 1820, 14198, 39935, 27096, 927, 279, 16053, 5679, 0 } + }; + yield return new object?[] + { + "he traveled to Egypt during the summer, the weather was hot and ammunition." , + new string[] { "he", " traveled", " to", " Egypt", " during", " the", " summer", ",", " the", " weather", " was", " hot", " and", " ammunition", "." }, + new (int Index, int Length)[] { (0, 2), (2, 9), (11, 3), (14, 6), (20, 7), (27, 4), (31, 7), (38, 1), (39, 4), (43, 8), (51, 4), (55, 4), (59, 4), (63, 11), (74, 1) }, + new int[] { 383, 31796, 311, 15212, 2391, 279, 7474, 11, 279, 9282, 574, 4106, 323, 37768, 13 } + }; + yield return new object?[] + { + "She played many games and she felt exhausted afterward", + new string[] { "She", " played", " many", " games", " and", " she", " felt", " exhausted", " afterward" }, + new (int Index, int Length)[] { (0, 3), (3, 7), (10, 5), (15, 6), (21, 4), (25, 4), (29, 5), (34, 10), (44, 10) }, + new int[] { 8100, 6476, 1690, 3953, 323, 1364, 6612, 39019, 49043 } + }; + yield return new object?[] + { + "Hello, y'all! How are you 😁 ?", + new string[] { "Hello", ",", " y", "'all", "!", " How", " are", " you", " 😁", "", " ?" }, + new (int Index, int Length)[] { (0, 5), (5, 1), (6, 2), (8, 4), (12, 1), (13, 4), (17, 4), (21, 4), (25, 3), (28, 0), (28, 2) }, + new int[] { 9906, 11, 379, 65948, 0, 2650, 527, 499, 27623, 223, 949 } + }; + } + } + + [Theory] + [MemberData(nameof(TokenizerTestData))] + public void TestTokenizerEncoding(string text, string[] expectedTokens, (int Index, int Length)[] expectedOffsets, int[] expectedIds) + { + Tokenizer tokenizer = GPT4; + + IReadOnlyList encoding = tokenizer.Encode(text, out _); + IReadOnlyList encoding1 = tokenizer.Encode(text.AsSpan(), out _); + + Assert.Equal(expectedTokens, encoding.Select(t => t.Value).ToArray()); + Assert.Equal(expectedOffsets, encoding.Select(t => t.Offset).ToArray()); + Assert.Equal(expectedIds, encoding.Select(t => t.Id).ToArray()); + + Assert.Equal(expectedTokens, encoding1.Select(t => t.Value).ToArray()); + Assert.Equal(expectedOffsets, encoding1.Select(t => t.Offset).ToArray()); + Assert.Equal(expectedIds, encoding1.Select(t => t.Id).ToArray()); + + Assert.Equal(expectedIds, tokenizer.EncodeToIds(text)); + Assert.Equal(expectedIds, tokenizer.EncodeToIds(text.AsSpan())); + Assert.Equal(expectedIds, tokenizer.EncodeToIds(text, expectedIds.Length, out string? normalizedString, out int length)); + Assert.Null(normalizedString); + Assert.Equal(text.Length, length); + Assert.Equal(expectedIds, tokenizer.EncodeToIds(text.AsSpan(), expectedIds.Length, out normalizedString, out length)); + Assert.Null(normalizedString); + Assert.Equal(text.Length, length); + + Assert.Equal(expectedIds.Take(expectedIds.Length - 4), tokenizer.EncodeToIds(text, expectedIds.Length - 4, out normalizedString, out length)); + Assert.Null(normalizedString); + int expectedLength = expectedOffsets[expectedOffsets.Length - 5].Index + expectedOffsets[expectedOffsets.Length - 5].Length; + Assert.Equal(expectedLength, length); + Assert.Equal(expectedIds.Take(expectedIds.Length - 4), tokenizer.EncodeToIds(text.AsSpan(), expectedIds.Length - 4, out normalizedString, out length)); + Assert.Null(normalizedString); + Assert.Equal(expectedLength, length); + + Assert.Equal(expectedIds.Length, tokenizer.CountTokens(text)); + Assert.Equal(expectedIds.Length, tokenizer.CountTokens(text.AsSpan())); + + Assert.Equal(expectedOffsets[expectedOffsets.Length - 4].Index + expectedOffsets[expectedOffsets.Length - 4].Length, tokenizer.IndexOfTokenCount(text, expectedIds.Length - 3, out normalizedString, out int tokenCount)); + Assert.Null(normalizedString); + Assert.Equal(expectedIds.Length - 3, tokenCount); + Assert.Equal(expectedOffsets[expectedOffsets.Length - 4].Index + expectedOffsets[expectedOffsets.Length - 4].Length, tokenizer.IndexOfTokenCount(text.AsSpan(), expectedIds.Length - 3, out normalizedString, out tokenCount)); + Assert.Null(normalizedString); + Assert.Equal(expectedIds.Length - 3, tokenCount); + + Assert.Equal(expectedOffsets[expectedOffsets.Length - 3].Index, tokenizer.LastIndexOfTokenCount(text, 3, out normalizedString, out tokenCount)); + Assert.Null(normalizedString); + Assert.Equal(3, tokenCount); + Assert.Equal(expectedOffsets[expectedOffsets.Length - 3].Index, tokenizer.LastIndexOfTokenCount(text.AsSpan(), 3, out normalizedString, out tokenCount)); + Assert.Null(normalizedString); + Assert.Equal(3, tokenCount); + } + // Test running copy the test data files to the output folder but sometimes the file content is mutated replacing '\n' with '\r\n'. // This method reads the file and removes the extra inserted '\r' characters. Having '\r' in the file content will cause the tests to fail. private string ReadAndSanitizeFile(string path) diff --git a/test/Microsoft.ML.Tokenizers.Tests/TokenizerTests.cs b/test/Microsoft.ML.Tokenizers.Tests/TokenizerTests.cs index 106f278ce7..aa9751889d 100644 --- a/test/Microsoft.ML.Tokenizers.Tests/TokenizerTests.cs +++ b/test/Microsoft.ML.Tokenizers.Tests/TokenizerTests.cs @@ -26,12 +26,12 @@ internal static void TestTokenLimits(Tokenizer tokenizer) for (int i = 1; i <= fullIdsList.Count; i++) { - int index1 = tokenizer.IndexOfTokenCount(input, maxTokenCount: i, out string processedText1, out int tokenCount1); - int index2 = tokenizer.LastIndexOfTokenCount(input, maxTokenCount: i, out string processedText2, out int tokenCount2); - IReadOnlyList partialIdsList = tokenizer.EncodeToIds(input, maxTokenCount: i, out string processedText, out int textLength); + int index1 = tokenizer.IndexOfTokenCount(input, maxTokenCount: i, out string? processedText1, out int tokenCount1); + int index2 = tokenizer.LastIndexOfTokenCount(input, maxTokenCount: i, out string? processedText2, out int tokenCount2); + IReadOnlyList partialIdsList = tokenizer.EncodeToIds(input, maxTokenCount: i, out string? processedText, out int textLength); - Assert.True(textLength <= processedText.Length); - Assert.True(tokenizer.Normalizer is not null || processedText == input); + Assert.True(processedText is null || textLength <= processedText.Length); + Assert.True(tokenizer.Normalizer is not null || processedText is null); Assert.Equal(fullIdsList.Take(partialIdsList.Count), partialIdsList); @@ -42,14 +42,13 @@ internal static void TestTokenLimits(Tokenizer tokenizer) // In this case, we'll get index1 equal to zero and nothing really will need to be tested. if (tokenCount1 > 0 && index1 > 0) { - string prefixString = processedText1.Substring(0, index1); + string prefixString = (processedText1 ?? input).Substring(0, index1); - if (tokenizer.Model is SentencePieceBpe) + if (tokenizer is SentencePieceBpe) { // SentencePieceBpe model normalize the text and insert more characters. // We call the model directly to bypass the normalization step - prefixIds = new List(); - tokenizer.Model.EncodeToIds(prefixString.AsSpan(), (prefixIds as IList)!, out _); + prefixIds = tokenizer.EncodeToIds(prefixString.AsSpan(), considerNormalization: false); } else { @@ -61,14 +60,13 @@ internal static void TestTokenLimits(Tokenizer tokenizer) if (tokenCount2 > 0) { - string suffixString = processedText2.Substring(index2); + string suffixString = (processedText2 ?? input).Substring(index2); - if (tokenizer.Model is SentencePieceBpe) + if (tokenizer is SentencePieceBpe) { // SentencePieceBpe model normalize the text and insert more characters. // We call the model directly to bypass the normalization step - suffixIds = new List(); - tokenizer.Model.EncodeToIds(suffixString.AsSpan(), (suffixIds as IList)!, out _); + suffixIds = tokenizer.EncodeToIds(suffixString.AsSpan(), considerNormalization: false); if (i < fullIdsList.Count) { suffixIds = suffixIds.Skip(1).ToList(); // Skip the start of sentence token @@ -85,12 +83,13 @@ internal static void TestTokenLimits(Tokenizer tokenizer) if (i == fullIdsList.Count) { - if (index1 != processedText1.Length) + string s = processedText1 ?? input; + if (index1 != s.Length) { // It's possible that the remaining text on the left doesn't produce any tokens, as in the case of BPE, // where the pre-tokenizer removes spaces and the left text consists entirely of spaces. - Assert.True(index1 < processedText1.Length); - Assert.Equal(0, tokenizer.CountTokens(processedText1.Substring(index1))); + Assert.True(index1 < s.Length); + Assert.Equal(0, tokenizer.CountTokens(s.Substring(index1))); } if (index2 != 0) @@ -98,7 +97,7 @@ internal static void TestTokenLimits(Tokenizer tokenizer) // It's possible that the remaining text on the right doesn't produce any tokens, as in the case of BPE, // where the pre-tokenizer removes spaces and the left text consists entirely of spaces. Assert.True(index2 > 0); - Assert.Equal(0, tokenizer.CountTokens(processedText1.Substring(0, index2))); + Assert.Equal(0, tokenizer.CountTokens(s.Substring(0, index2))); } Assert.Equal(fullIdsList, prefixIds); @@ -106,13 +105,15 @@ internal static void TestTokenLimits(Tokenizer tokenizer) } } + Assert.Equal(0, tokenizer.IndexOfTokenCount((string)null!, maxTokenCount: 10, out _, out _)); + Assert.Equal(0, tokenizer.LastIndexOfTokenCount((string)null!, maxTokenCount: 10, out _, out _)); + Assert.Equal(0, tokenizer.IndexOfTokenCount(Span.Empty, maxTokenCount: 10, out _, out _)); + Assert.Equal(0, tokenizer.LastIndexOfTokenCount(Span.Empty, maxTokenCount: 10, out _, out _)); + Assert.Throws(() => tokenizer.IndexOfTokenCount(input, maxTokenCount: 0, out _, out _)); Assert.Throws(() => tokenizer.IndexOfTokenCount(input, maxTokenCount: -1, out _, out _)); Assert.Throws(() => tokenizer.LastIndexOfTokenCount(input, maxTokenCount: 0, out _, out _)); Assert.Throws(() => tokenizer.LastIndexOfTokenCount(input, maxTokenCount: -1, out _, out _)); - - Assert.Throws(() => tokenizer.IndexOfTokenCount(null!, maxTokenCount: 10, out _, out _)); - Assert.Throws(() => tokenizer.LastIndexOfTokenCount(null!, maxTokenCount: 10, out _, out _)); } } } \ No newline at end of file