diff --git a/eng/Versions.props b/eng/Versions.props
index aeb07aebbe..b494d4d1b3 100644
--- a/eng/Versions.props
+++ b/eng/Versions.props
@@ -41,7 +41,7 @@
1.0.0-beta.23509.3
1.16.3
0.0.0.12
-
2023.0.0.23189
@@ -87,6 +87,7 @@
0.0.13-test
0.0.6-test
0.0.7-test
+ 2.0.0-beta.24218.2
4.8.6
1.0.118
1.2.7
diff --git a/src/Microsoft.ML.Tokenizers/EncodingResult.cs b/src/Microsoft.ML.Tokenizers/EncodingResult.cs
deleted file mode 100644
index 9401cfa490..0000000000
--- a/src/Microsoft.ML.Tokenizers/EncodingResult.cs
+++ /dev/null
@@ -1,152 +0,0 @@
-// Licensed to the .NET Foundation under one or more agreements.
-// The .NET Foundation licenses this file to you under the MIT license.
-// See the LICENSE file in the project root for more information.
-
-using System;
-using System.Collections.Generic;
-using System.Text;
-
-namespace Microsoft.ML.Tokenizers
-{
- /// <summary>
- /// The Encoding represents the output of a Tokenizer.
- /// </summary>
- public sealed class EncodingResult
- {
- /// <summary>
- /// Create a new object of the EncodingResult object.
- /// </summary>
- /// <param name="originalString">The original string before tokenization.</param>
- /// <param name="normalizedString">The normalized form of the original string.</param>
- /// <param name="splits">The splits produced for the tokenized text.</param>
- /// <param name="offsetsMappedToOriginalString">Indicate whether the offsets are mapped to the original string or the normalized string.</param>
- public EncodingResult(string originalString, string normalizedString, IEnumerable<Split> splits, bool offsetsMappedToOriginalString)
- {
- OriginalString = originalString;
- NormalizedString = normalizedString;
- Splits = splits;
- OffsetsMappedToOriginalString = offsetsMappedToOriginalString;
- }
-
- /// <summary>
- /// Gets the original tokenized string.
- /// </summary>
- public string? OriginalString { get; }
-
- /// <summary>
- /// Gets the normalized form of the original string.
- /// </summary>
- public string? NormalizedString { get; }
-
- /// <summary>
- /// Gets a value indicating whether the offsets are mapped to the original string or to the normalized string.
- /// </summary>
- public bool OffsetsMappedToOriginalString { get; }
-
- internal IEnumerable<Split> Splits { get; }
- private List<Token>? _tokens;
- private List<string>? _tokensWords;
- private List<int>? _ids;
- private List<(int Index, int Length)>? _offsets;
-
- internal void AddTokens(IReadOnlyList<Token> addedTokens)
- {
- if (_tokens is null)
- {
- _tokens = new(addedTokens);
- return;
- }
-
- foreach (var token in addedTokens)
- {
- _tokens.Add(token);
- }
- }
-
- /// <summary>
- /// Gets the list of token Ids.
- /// The Ids are the main input to a Language Model. They are the token indices, the numerical representations that an LM understands.
- /// </summary>
- public IReadOnlyList<int> Ids
- {
- get
- {
- if (_ids is not null)
- {
- return _ids;
- }
-
- if (_tokens is null)
- {
- return Array.Empty<int>();
- }
-
- _ids = new List<int>(_tokens.Count);
-
- foreach (var token in _tokens)
- {
- _ids.Add(token.Id);
- }
-
- return _ids;
- }
- }
-
- /// <summary>
- /// Gets the generated tokens. They are the string representation of the Ids.
- /// </summary>
- public IReadOnlyList<string> Tokens
- {
- get
- {
- if (_tokensWords is not null)
- {
- return _tokensWords;
- }
-
- if (_tokens is null)
- {
- return Array.Empty<string>();
- }
-
- _tokensWords = new List<string>(_tokens.Count);
-
- foreach (var token in _tokens)
- {
- _tokensWords.Add(token.Value);
- }
-
- return _tokensWords;
- }
- }
-
- /// <summary>
- /// Gets the list of offsets. These offsets let you slice the input string and thus retrieve
- /// the original part that led to producing the corresponding token.
- /// </summary>
- public IReadOnlyList<(int Index, int Length)> Offsets
- {
- get
- {
- if (_offsets is not null)
- {
- return _offsets;
- }
-
- if (_tokens is null)
- {
- return Array.Empty<(int, int)>();
- }
-
- _offsets = new List<(int Index, int Length)>(_tokens.Count);
-
- foreach (var token in _tokens)
- {
- _offsets.Add(token.Offset);
- }
-
- return _offsets;
- }
- }
- }
-}
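With EncodingResult deleted, Encode hands the token list back directly. A minimal usage sketch of the new surface described in this diff (the vocab/merges file names are placeholders, not assets shipped with the library):

```csharp
using System;
using System.Collections.Generic;
using Microsoft.ML.Tokenizers;

class EncodeSample
{
    static void Main()
    {
        // Placeholder asset paths; any BPE vocab/merges pair works here.
        Bpe bpe = new Bpe("vocab.json", "merges.txt");

        // No EncodingResult wrapper: tokens, ids, and offsets come back directly.
        IReadOnlyList<Token> tokens = bpe.Encode("Hello world", out string? normalized);

        foreach (Token token in tokens)
        {
            // Token carries the string value, the id, and the (Index, Length) offset.
            Console.WriteLine($"{token.Value} -> {token.Id} @ {token.Offset}");
        }
    }
}
```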
diff --git a/src/Microsoft.ML.Tokenizers/Model/BPE.cs b/src/Microsoft.ML.Tokenizers/Model/BPE.cs
index 5a8fb4d330..ad6627da8c 100644
--- a/src/Microsoft.ML.Tokenizers/Model/BPE.cs
+++ b/src/Microsoft.ML.Tokenizers/Model/BPE.cs
@@ -17,13 +17,15 @@ namespace Microsoft.ML.Tokenizers
/// <summary>
/// Represent the Byte Pair Encoding model.
/// </summary>
- public sealed class Bpe : Model
+ public sealed class Bpe : Tokenizer
{
/// A [Byte Pair Encoding](https://www.aclweb.org/anthology/P16-1162/) model.
private const int MaxWordLengthToCache = 15;
private string? _unknownToken;
private int? _unknownTokenId;
+ private readonly PreTokenizer? _preTokenizer;
+ private readonly Normalizer? _normalizer;
/// <summary>
/// Gets or Sets unknown token. The unknown token to be used when we encounter an unknown char
@@ -74,13 +76,15 @@ private set
/// </summary>
/// <param name="vocabFile">The JSON file path containing the dictionary of string keys and their ids.</param>
/// <param name="mergesFile">The file path containing the tokens' pairs list.</param>
+ /// <param name="preTokenizer">The pre-tokenizer to use.</param>
+ /// <param name="normalizer">The normalizer to use.</param>
/// <param name="unknownToken">The unknown token to be used by the model.</param>
/// <param name="continuingSubwordPrefix">The prefix to attach to sub-word units that don’t represent a beginning of word.</param>
/// <param name="endOfWordSuffix">The suffix to attach to sub-word units that represent an end of word.</param>
/// <param name="fuseUnknownTokens">Indicate whether to allow multiple unknown tokens to get fused.</param>
- public Bpe(string vocabFile, string? mergesFile, string? unknownToken = null, string? continuingSubwordPrefix = null, string? endOfWordSuffix = null, bool fuseUnknownTokens = false) :
+ public Bpe(string vocabFile, string? mergesFile, PreTokenizer? preTokenizer = null, Normalizer? normalizer = null, string? unknownToken = null, string? continuingSubwordPrefix = null, string? endOfWordSuffix = null, bool fuseUnknownTokens = false) :
this(vocabFile is null ? throw new ArgumentNullException(nameof(vocabFile)) : File.Open(vocabFile, FileMode.Open, FileAccess.Read),
- mergesFile is null ? null : File.Open(mergesFile, FileMode.Open, FileAccess.Read), unknownToken, continuingSubwordPrefix, endOfWordSuffix, fuseUnknownTokens, disposeStreams: true)
+ mergesFile is null ? null : File.Open(mergesFile, FileMode.Open, FileAccess.Read), preTokenizer, normalizer, unknownToken, continuingSubwordPrefix, endOfWordSuffix, fuseUnknownTokens, disposeStreams: true)
{
}
@@ -89,16 +93,18 @@ public Bpe(string vocabFile, string? mergesFile, string? unknownToken = null, st
/// </summary>
/// <param name="vocabStream">The JSON stream containing the dictionary of string keys and their ids.</param>
/// <param name="mergesStream">The stream containing the tokens' pairs list.</param>
+ /// <param name="preTokenizer">The pre-tokenizer to use.</param>
+ /// <param name="normalizer">The normalizer to use.</param>
/// <param name="unknownToken">The unknown token to be used by the model.</param>
/// <param name="continuingSubwordPrefix">The prefix to attach to sub-word units that don’t represent a beginning of word.</param>
/// <param name="endOfWordSuffix">The suffix to attach to sub-word units that represent an end of word.</param>
/// <param name="fuseUnknownTokens">Indicate whether to allow multiple unknown tokens to get fused.</param>
- public Bpe(Stream vocabStream, Stream? mergesStream, string? unknownToken = null, string? continuingSubwordPrefix = null, string? endOfWordSuffix = null, bool fuseUnknownTokens = false) :
- this(vocabStream, mergesStream, unknownToken, continuingSubwordPrefix, endOfWordSuffix, fuseUnknownTokens, disposeStreams: false)
+ public Bpe(Stream vocabStream, Stream? mergesStream, PreTokenizer? preTokenizer = null, Normalizer? normalizer = null, string? unknownToken = null, string? continuingSubwordPrefix = null, string? endOfWordSuffix = null, bool fuseUnknownTokens = false) :
+ this(vocabStream, mergesStream, preTokenizer, normalizer, unknownToken, continuingSubwordPrefix, endOfWordSuffix, fuseUnknownTokens, disposeStreams: false)
{
}
- private Bpe(Stream vocabStream, Stream? mergesStream, string? unknownToken, string? continuingSubwordPrefix, string? endOfWordSuffix, bool fuseUnknownTokens, bool disposeStreams)
+ private Bpe(Stream vocabStream, Stream? mergesStream, PreTokenizer? preTokenizer, Normalizer? normalizer, string? unknownToken, string? continuingSubwordPrefix, string? endOfWordSuffix, bool fuseUnknownTokens, bool disposeStreams)
{
try
{
@@ -110,6 +116,8 @@ private Bpe(Stream vocabStream, Stream? mergesStream, string? unknownToken, stri
FuseUnknownTokens = fuseUnknownTokens;
ContinuingSubwordPrefix = continuingSubwordPrefix;
EndOfWordSuffix = endOfWordSuffix;
+ _preTokenizer = preTokenizer ?? WhiteSpace.Instance; // Default to WhiteSpace pre-tokenizer
+ _normalizer = normalizer;
(Dictionary<StringSpanOrdinalKey, int>? vocab1, Vec<(string, string)> merges) = ReadModelData(vocabStream, mergesStream);
_vocab = vocab1 ?? new Dictionary<StringSpanOrdinalKey, int>();
@@ -166,47 +174,320 @@ private Bpe(Stream vocabStream, Stream? mergesStream, string? unknownToken, stri
}
/// <summary>
- /// Encode a text string to a list of tokens.
+ /// Gets the PreTokenizer used by the Tokenizer.
+ /// </summary>
+ public override PreTokenizer? PreTokenizer => _preTokenizer;
+
+ /// <summary>
+ /// Gets the Normalizer in use by the Tokenizer.
+ /// </summary>
+ public override Normalizer? Normalizer => _normalizer;
+
+ /// <summary>
+ /// Encodes input text to a list of <see cref="Token" />s with string value of the token, id, and offset.
/// </summary>
/// <param name="text">The text to encode.</param>
- /// <returns>The list of tokens generated from the text tokenization.</returns>
- public override IReadOnlyList<Token> Encode(ReadOnlySpan<char> text)
+ /// <param name="normalizedString">If the tokenizer's normalization is enabled or <paramref name="considerNormalization"/> is false, this will be set to <paramref name="text"/> in its normalized form; otherwise, this value will be set to null.</param>
+ /// <param name="considerPreTokenization">Indicate whether to consider pre-tokenization before tokenization.</param>
+ /// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
+ /// <returns>The tokenization result includes a list of <see cref="Token" />s with string value of the token, id, and offset.</returns>
+ public override IReadOnlyList<Token> Encode(string text, out string? normalizedString, bool considerPreTokenization = true, bool considerNormalization = true) => Encode(text, ReadOnlySpan<char>.Empty, out normalizedString, considerPreTokenization, considerNormalization);
+
+ /// <summary>
+ /// Encodes input text to a list of <see cref="Token" />s with string value of the token, id, and offset.
+ /// </summary>
+ /// <param name="text">The text to encode.</param>
+ /// <param name="normalizedString">If the tokenizer's normalization is enabled or <paramref name="considerNormalization"/> is false, this will be set to <paramref name="text"/> in its normalized form; otherwise, this value will be set to null.</param>
+ /// <param name="considerPreTokenization">Indicate whether to consider pre-tokenization before tokenization.</param>
+ /// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
+ /// <returns>The tokenization result includes a list of <see cref="Token" />s with string value of the token, id, and offset.</returns>
+ public override IReadOnlyList<Token> Encode(ReadOnlySpan<char> text, out string? normalizedString, bool considerPreTokenization = true, bool considerNormalization = true) => Encode(null, text, out normalizedString, considerPreTokenization, considerNormalization);
+
+ private IReadOnlyList<Token> Encode(string? text, ReadOnlySpan<char> textSpan, out string? normalizedString, bool considerPreTokenization, bool considerNormalization)
{
- if (text.Length == 0)
+ if (string.IsNullOrEmpty(text) && textSpan.IsEmpty)
+ {
+ normalizedString = null;
+ return [];
+ }
+
+ IEnumerable<(int Offset, int Length)>? splits = InitializeForEncoding(text, textSpan, considerPreTokenization, considerNormalization, _normalizer, _preTokenizer, out normalizedString, out ReadOnlySpan<char> textSpanToEncode);
+
+ List<Token> tokens = new();
+ PriorityQueue<Merge>? priorityQueue = null;
+
+ if (splits is not null)
+ {
+ foreach ((int Offset, int Length) split in splits)
+ {
+ EncodeWithCache(textSpanToEncode.Slice(split.Offset, split.Length), tokens, split.Offset, ref priorityQueue);
+ }
+ }
+ else
{
- return EmptyTokensList;
+ EncodeWithCache(textSpanToEncode, tokens, 0, ref priorityQueue);
}
- return EncodeWithCache(text);
+ return tokens;
}
/// <summary>
- /// Encode a split text string to a list of Ids and add them to the accumulatedIds list.
+ /// Encodes input text to token Ids.
/// </summary>
/// <param name="text">The text to encode.</param>
- /// <param name="accumulatedIds">The list of accumulated encoded Ids.</param>
+ /// <param name="considerPreTokenization">Indicate whether to consider pre-tokenization before tokenization.</param>
+ /// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
+ /// <returns>The list of encoded Ids.</returns>
+ public override IReadOnlyList<int> EncodeToIds(string text, bool considerPreTokenization = true, bool considerNormalization = true)
+ => EncodeToIds(text, ReadOnlySpan<char>.Empty, considerPreTokenization, considerNormalization, out _, out _);
+
+ /// <summary>
+ /// Encodes input text to token Ids.
+ /// </summary>
+ /// <param name="text">The text to encode.</param>
+ /// <param name="considerPreTokenization">Indicate whether to consider pre-tokenization before tokenization.</param>
+ /// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
+ /// <returns>The list of encoded Ids.</returns>
+ public override IReadOnlyList<int> EncodeToIds(ReadOnlySpan<char> text, bool considerPreTokenization = true, bool considerNormalization = true)
+ => EncodeToIds(null, text, considerPreTokenization, considerNormalization, out _, out _);
+
+ /// <summary>
+ /// Encodes input text to token Ids up to maximum number of tokens.
+ /// </summary>
+ /// <param name="text">The text to encode.</param>
+ /// <param name="maxTokenCount">The maximum number of tokens to encode.</param>
+ /// <param name="normalizedString">If the tokenizer's normalization is enabled or <paramref name="considerNormalization"/> is false, this will be set to <paramref name="text"/> in its normalized form; otherwise, this value will be set to null.</param>
/// <param name="textLength">The length of the text that encompasses the maximum encoded tokens.</param>
- /// <param name="maxTokens">The maximum number of tokens to encode.</param>
- /// <returns>The number of tokens that the input text will be encoded to.</returns>
- public override int EncodeToIds(ReadOnlySpan<char> text, IList<int> accumulatedIds, out int textLength, int maxTokens = int.MaxValue) => EncodeToIdsWithCache(text, accumulatedIds, maxTokens, out textLength);
+ /// <param name="considerPreTokenization">Indicate whether to consider pre-tokenization before tokenization.</param>
+ /// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
+ /// <returns>The list of encoded Ids.</returns>
+ public override IReadOnlyList<int> EncodeToIds(string text, int maxTokenCount, out string? normalizedString, out int textLength, bool considerPreTokenization = true, bool considerNormalization = true)
+ => EncodeToIds(text, ReadOnlySpan<char>.Empty, considerPreTokenization, considerNormalization, out normalizedString, out textLength, maxTokenCount);
/// <summary>
- /// Get the number of tokens that the input text will be encoded to.
+ /// Encodes input text to token Ids up to maximum number of tokens.
/// </summary>
/// <param name="text">The text to encode.</param>
+ /// <param name="maxTokenCount">The maximum number of tokens to encode.</param>
+ /// <param name="normalizedString">If the tokenizer's normalization is enabled or <paramref name="considerNormalization"/> is false, this will be set to <paramref name="text"/> in its normalized form; otherwise, this value will be set to null.</param>
/// <param name="textLength">The length of the text that encompasses the maximum encoded tokens.</param>
- /// <param name="maxTokens">The maximum number of tokens to encode.</param>
- /// <returns>The number of tokens that the input text will be encoded to.</returns>
- public override int CountTokens(ReadOnlySpan<char> text, out int textLength, int maxTokens = int.MaxValue) => EncodeToIdsWithCache(text, null, maxTokens, out textLength);
+ /// <param name="considerPreTokenization">Indicate whether to consider pre-tokenization before tokenization.</param>
+ /// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
+ /// <returns>The list of encoded Ids.</returns>
+ public override IReadOnlyList<int> EncodeToIds(ReadOnlySpan<char> text, int maxTokenCount, out string? normalizedString, out int textLength, bool considerPreTokenization = true, bool considerNormalization = true)
+ => EncodeToIds(null, text, considerPreTokenization, considerNormalization, out normalizedString, out textLength, maxTokenCount);
+
+ private IReadOnlyList<int> EncodeToIds(string? text, ReadOnlySpan<char> textSpan, bool considerPreTokenization, bool considerNormalization, out string? normalizedString, out int textLength, int maxTokenCount = int.MaxValue)
+ {
+ if (maxTokenCount <= 0)
+ {
+ throw new ArgumentOutOfRangeException(nameof(maxTokenCount), "The maximum number of tokens must be greater than zero.");
+ }
+
+ if (string.IsNullOrEmpty(text) && textSpan.IsEmpty)
+ {
+ textLength = 0;
+ normalizedString = null;
+ return [];
+ }
+
+ IEnumerable<(int Offset, int Length)>? splits = InitializeForEncoding(text, textSpan, considerPreTokenization, considerNormalization, _normalizer, _preTokenizer, out normalizedString, out ReadOnlySpan<char> textSpanToEncode);
+
+ List<int> ids = new();
+ PriorityQueue<Merge>? priorityQueue = null;
+
+ if (splits is not null)
+ {
+ textLength = 0;
+ foreach ((int Offset, int Length) split in splits)
+ {
+ EncodeToIdsWithCache(textSpanToEncode.Slice(split.Offset, split.Length), ids, maxTokenCount - ids.Count, out int length, ref priorityQueue);
+ textLength = split.Offset + length;
+
+ if (length < split.Length || ids.Count >= maxTokenCount)
+ {
+ break;
+ }
+ }
+ }
+ else
+ {
+ EncodeToIdsWithCache(textSpanToEncode, ids, maxTokenCount, out textLength, ref priorityQueue);
+ }
+
+ return ids;
+ }
/// <summary>
/// Get the number of tokens that the input text will be encoded to.
/// </summary>
/// <param name="text">The text to encode.</param>
- /// <param name="textIndex">Starting from this index to the end of the text will encompass the maximum encoded tokens.</param>
- /// <param name="maxTokens">The maximum number of tokens to encode.</param>
- /// <returns>The number of tokens that the input text will be encoded to.</returns>
- public override int CountTokensFromEnd(ReadOnlySpan<char> text, out int textIndex, int maxTokens = int.MaxValue) => EncodeToIdsFromEndWithCache(text, null, maxTokens, out textIndex);
+ /// <param name="considerPreTokenization">Indicate whether to consider pre-tokenization before tokenization.</param>
+ /// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
+ /// <returns>The number of token Ids that the input text will be encoded to.</returns>
+ public override int CountTokens(string text, bool considerPreTokenization = true, bool considerNormalization = true)
+ => CountTokens(text, ReadOnlySpan<char>.Empty, considerPreTokenization, considerNormalization, out _, out _);
+
+ /// <summary>
+ /// Get the number of tokens that the input text will be encoded to.
+ /// </summary>
+ /// <param name="text">The text to encode.</param>
+ /// <param name="considerPreTokenization">Indicate whether to consider pre-tokenization before tokenization.</param>
+ /// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
+ /// <returns>The number of token Ids that the input text will be encoded to.</returns>
+ public override int CountTokens(ReadOnlySpan<char> text, bool considerPreTokenization = true, bool considerNormalization = true)
+ => CountTokens(null, text, considerPreTokenization, considerNormalization, out _, out _);
+
+ /// <summary>
+ /// Find the index of the maximum encoding capacity from the start within the text without surpassing the token limit.
+ /// </summary>
+ /// <param name="text">The text to encode.</param>
+ /// <param name="maxTokenCount">The maximum token count to limit the encoding capacity.</param>
+ /// <param name="normalizedString">If the tokenizer's normalization is enabled or <paramref name="considerNormalization"/> is false, this will be set to <paramref name="text"/> in its normalized form; otherwise, this value will be set to null.</param>
+ /// <param name="tokenCount">The token count that can be generated, which should be smaller than the maximum token count.</param>
+ /// <param name="considerPreTokenization">Indicate whether to consider pre-tokenization before tokenization.</param>
+ /// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
+ /// <returns>
+ /// The index of the maximum encoding capacity within the processed text without surpassing the token limit.
+ /// It represents the index immediately following the last character to be included. In cases where no tokens fit, the result will be 0; conversely,
+ /// if all tokens fit, the result will be the length of the text or the <paramref name="normalizedString"/> if the normalization is enabled.
+ /// </returns>
+ public override int IndexOfTokenCount(string text, int maxTokenCount, out string? normalizedString, out int tokenCount, bool considerPreTokenization = true, bool considerNormalization = true)
+ {
+ tokenCount = CountTokens(text, ReadOnlySpan<char>.Empty, considerPreTokenization, considerNormalization, out normalizedString, out int textLength, maxTokenCount);
+ return textLength;
+ }
+
+ /// <summary>
+ /// Find the index of the maximum encoding capacity from the start within the text without surpassing the token limit.
+ /// </summary>
+ /// <param name="text">The text to encode.</param>
+ /// <param name="maxTokenCount">The maximum token count to limit the encoding capacity.</param>
+ /// <param name="normalizedString">If the tokenizer's normalization is enabled or <paramref name="considerNormalization"/> is false, this will be set to <paramref name="text"/> in its normalized form; otherwise, this value will be set to null.</param>
+ /// <param name="tokenCount">The token count that can be generated, which should be smaller than the maximum token count.</param>
+ /// <param name="considerPreTokenization">Indicate whether to consider pre-tokenization before tokenization.</param>
+ /// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
+ /// <returns>
+ /// The index of the maximum encoding capacity within the processed text without surpassing the token limit.
+ /// It represents the index immediately following the last character to be included. In cases where no tokens fit, the result will be 0; conversely,
+ /// if all tokens fit, the result will be the length of the text or the <paramref name="normalizedString"/> if the normalization is enabled.
+ /// </returns>
+ public override int IndexOfTokenCount(ReadOnlySpan<char> text, int maxTokenCount, out string? normalizedString, out int tokenCount, bool considerPreTokenization = true, bool considerNormalization = true)
+ {
+ tokenCount = CountTokens(null, text, considerPreTokenization, considerNormalization, out normalizedString, out int textLength, maxTokenCount);
+ return textLength;
+ }
+
+ private int CountTokens(string? text, ReadOnlySpan<char> textSpan, bool considerPreTokenization, bool considerNormalization, out string? normalizedString, out int textLength, int maxTokenCount = int.MaxValue)
+ {
+ if (maxTokenCount <= 0)
+ {
+ throw new ArgumentOutOfRangeException(nameof(maxTokenCount), "The maximum number of tokens must be greater than zero.");
+ }
+
+ textLength = 0;
+ if (string.IsNullOrEmpty(text) && textSpan.IsEmpty)
+ {
+ normalizedString = null;
+ return 0;
+ }
+
+ IEnumerable<(int Offset, int Length)>? splits = InitializeForEncoding(text, textSpan, considerPreTokenization, considerNormalization, _normalizer, _preTokenizer, out normalizedString, out ReadOnlySpan<char> textSpanToEncode);
+
+ PriorityQueue<Merge>? priorityQueue = null;
+ int count = 0;
+ if (splits is not null)
+ {
+ foreach ((int Offset, int Length) split in splits)
+ {
+ count += EncodeToIdsWithCache(textSpanToEncode.Slice(split.Offset, split.Length), null, maxTokenCount - count, out int length, ref priorityQueue);
+ textLength = split.Offset + length;
+
+ if (length < split.Length || count >= maxTokenCount)
+ {
+ break;
+ }
+ }
+ }
+ else
+ {
+ count = EncodeToIdsWithCache(textSpanToEncode, null, maxTokenCount, out textLength, ref priorityQueue);
+ }
+
+ return count;
+ }
+
+ /// <summary>
+ /// Find the index of the maximum encoding capacity from the end within the text without surpassing the token limit.
+ /// </summary>
+ /// <param name="text">The text to encode.</param>
+ /// <param name="maxTokenCount">The maximum token count to limit the encoding capacity.</param>
+ /// <param name="normalizedString">If the tokenizer's normalization is enabled or <paramref name="considerNormalization"/> is false, this will be set to <paramref name="text"/> in its normalized form; otherwise, this value will be set to null.</param>
+ /// <param name="tokenCount">The token count that can be generated, which should be smaller than the maximum token count.</param>
+ /// <param name="considerPreTokenization">Indicate whether to consider pre-tokenization before tokenization.</param>
+ /// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
+ /// <returns>
+ /// The start index of the maximum encoding capacity within the processed text without surpassing the token limit.
+ /// It represents the index of the first character to be included. In cases where no tokens fit, the result will be the length of the text or the <paramref name="normalizedString"/> if normalization is enabled;
+ /// conversely, if all tokens fit, the result will be 0.
+ /// </returns>
+ public override int LastIndexOfTokenCount(string text, int maxTokenCount, out string? normalizedString, out int tokenCount, bool considerPreTokenization = true, bool considerNormalization = true)
+ => LastIndexOf(text, ReadOnlySpan<char>.Empty, maxTokenCount, considerPreTokenization, considerNormalization, out normalizedString, out tokenCount);
+
+ /// <summary>
+ /// Find the index of the maximum encoding capacity from the end within the text without surpassing the token limit.
+ /// </summary>
+ /// <param name="text">The text to encode.</param>
+ /// <param name="maxTokenCount">The maximum token count to limit the encoding capacity.</param>
+ /// <param name="normalizedString">If the tokenizer's normalization is enabled or <paramref name="considerNormalization"/> is false, this will be set to <paramref name="text"/> in its normalized form; otherwise, this value will be set to null.</param>
+ /// <param name="tokenCount">The token count that can be generated, which should be smaller than the maximum token count.</param>
+ /// <param name="considerPreTokenization">Indicate whether to consider pre-tokenization before tokenization.</param>
+ /// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
+ /// <returns>
+ /// The start index of the maximum encoding capacity within the processed text without surpassing the token limit.
+ /// It represents the index of the first character to be included. In cases where no tokens fit, the result will be the length of the <paramref name="text"/>; conversely, if all tokens fit, the result will be 0.
+ /// </returns>
+ public override int LastIndexOfTokenCount(ReadOnlySpan<char> text, int maxTokenCount, out string? normalizedString, out int tokenCount, bool considerPreTokenization = true, bool considerNormalization = true)
+ => LastIndexOf(null, text, maxTokenCount, considerPreTokenization, considerNormalization, out normalizedString, out tokenCount);
+
+ private int LastIndexOf(string? text, ReadOnlySpan<char> textSpan, int maxTokenCount, bool considerPreTokenization, bool considerNormalization, out string? normalizedString, out int tokenCount)
+ {
+ if (maxTokenCount <= 0)
+ {
+ throw new ArgumentOutOfRangeException(nameof(maxTokenCount), "The max token count must be greater than 0.");
+ }
+
+ if (string.IsNullOrEmpty(text) && textSpan.IsEmpty)
+ {
+ normalizedString = null;
+ tokenCount = 0;
+ return 0;
+ }
+
+ IEnumerable<(int Offset, int Length)>? splits = InitializeForEncoding(text, textSpan, considerPreTokenization, considerNormalization, _normalizer, _preTokenizer, out normalizedString, out ReadOnlySpan<char> textSpanToEncode);
+
+ PriorityQueue<Merge>? priorityQueue = null;
+
+ if (splits is not null)
+ {
+ tokenCount = 0;
+ foreach ((int Offset, int Length) split in splits.Reverse())
+ {
+ tokenCount += EncodeToIdsFromEndWithCache(textSpanToEncode.Slice(split.Offset, split.Length), null, maxTokenCount - tokenCount, out int textIndex, ref priorityQueue);
+ if (textIndex > 0 || tokenCount >= maxTokenCount)
+ {
+ return split.Offset + textIndex;
+ }
+ }
+ }
+ else
+ {
+ tokenCount = EncodeToIdsFromEndWithCache(textSpanToEncode, null, maxTokenCount, out int textLength, ref priorityQueue);
+ return textLength;
+ }
+
+ return 0;
+ }
/// <summary>
/// Map the token to encoded Id.
@@ -383,7 +664,7 @@ internal string CharToString(char c)
return s;
}
- internal Word MergeWord(ReadOnlySpan<char> w)
+ internal Word MergeWord(ReadOnlySpan<char> w, ref PriorityQueue<Merge>? priorityQueue)
{
Word word = Word.WithCapacity(w.Length);
(int Id, int Len)? unk = null;
@@ -494,23 +775,24 @@ internal Word MergeWord(ReadOnlySpan<char> w)
word.Add(unk.Value.Id, unk.Value.Len);
}
- word.MergeAll(Merges, Dropout);
+ word.MergeAll(Merges, Dropout, ref priorityQueue);
return word;
}
- internal List<Token> WordToTokens(ref Word word) => word.ToTokens(VocabReverse);
+ internal void WordToTokens(ref Word word, List<Token> tokens, int offset) => word.ToTokens(VocabReverse, tokens, offset);
- internal List<Token> EncodeWithCache(ReadOnlySpan<char> text)
+ internal void EncodeWithCache(ReadOnlySpan<char> text, List<Token> tokens, int offset, ref PriorityQueue<Merge>? priorityQueue)
{
Word word;
if (Cache is not null)
{
if (Cache.TryGetValue(text, out word))
{
- return WordToTokens(ref word);
+ WordToTokens(ref word, tokens, offset);
+ return;
}
- word = MergeWord(text);
+ word = MergeWord(text, ref priorityQueue);
if (text.Length <= MaxWordLengthToCache)
{
@@ -519,15 +801,15 @@ internal List<Token> EncodeWithCache(ReadOnlySpan<char> text)
}
else
{
- word = MergeWord(text);
+ word = MergeWord(text, ref priorityQueue);
}
- return WordToTokens(ref word);
+ WordToTokens(ref word, tokens, offset);
}
internal int WordToIds(ref Word word, IList<int>? accumulatedIds, out int textLength, int fullTextLength, int maxTokens)
{
- if (word.SymbolsCount < maxTokens)
+ if (word.SymbolsCount <= maxTokens)
{
textLength = fullTextLength;
if (accumulatedIds is not null)
@@ -548,7 +830,7 @@ internal int WordToIds(ref Word word, IList<int>? accumulatedIds, out int textLe
internal int WordToIdsFromEnd(ref Word word, IList<int>? accumulatedIds, out int textIndex, int fullTextLength, int maxTokens)
{
- if (word.SymbolsCount < maxTokens)
+ if (word.SymbolsCount <= maxTokens)
{
textIndex = 0;
if (accumulatedIds is not null)
@@ -567,7 +849,7 @@ internal int WordToIdsFromEnd(ref Word word, IList<int>? accumulatedIds, out int
return word.CountIdsUpToMaxFromEnd(maxTokens, fullTextLength, out textIndex);
}
- internal int EncodeToIdsWithCache(ReadOnlySpan<char> text, IList<int>? accumulatedIds, int maxTokens, out int textLength)
+ private int EncodeToIdsWithCache(ReadOnlySpan<char> text, List<int>? accumulatedIds, int maxTokens, out int textLength, ref PriorityQueue<Merge>? priorityQueue)
{
Word word;
@@ -578,7 +860,7 @@ internal int EncodeToIdsWithCache(ReadOnlySpan<char> text, IList<int>? accumulat
return WordToIds(ref hit, accumulatedIds, out textLength, text.Length, maxTokens);
}
- word = MergeWord(text);
+ word = MergeWord(text, ref priorityQueue);
if (text.Length <= MaxWordLengthToCache)
{
@@ -587,13 +869,13 @@ internal int EncodeToIdsWithCache(ReadOnlySpan<char> text, IList<int>? accumulat
}
else
{
- word = MergeWord(text);
+ word = MergeWord(text, ref priorityQueue);
}
return WordToIds(ref word, accumulatedIds, out textLength, text.Length, maxTokens);
}
- internal int EncodeToIdsFromEndWithCache(ReadOnlySpan<char> text, IList<int>? accumulatedIds, int maxTokens, out int textIndex)
+ internal int EncodeToIdsFromEndWithCache(ReadOnlySpan<char> text, IList<int>? accumulatedIds, int maxTokens, out int textIndex, ref PriorityQueue<Merge>? priorityQueue)
{
Word word;
@@ -604,7 +886,7 @@ internal int EncodeToIdsFromEndWithCache(ReadOnlySpan<char> text, IList<int>? ac
return WordToIdsFromEnd(ref hit, accumulatedIds, out textIndex, text.Length, maxTokens);
}
- word = MergeWord(text);
+ word = MergeWord(text, ref priorityQueue);
if (text.Length <= MaxWordLengthToCache)
{
@@ -613,12 +895,10 @@ internal int EncodeToIdsFromEndWithCache(ReadOnlySpan<char> text, IList<int>? ac
}
else
{
- word = MergeWord(text);
+ word = MergeWord(text, ref priorityQueue);
}
return WordToIdsFromEnd(ref word, accumulatedIds, out textIndex, text.Length, maxTokens);
}
-
- internal static readonly List<Token> EmptyTokensList = new();
}
}
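The net effect of the Bpe changes: construction now takes the optional pre-tokenization and normalization components (a null pre-tokenizer falls back to WhiteSpace.Instance, per the constructor above), and the budgeted APIs report how much text a token limit covers. A sketch of the new call shapes, assuming placeholder vocab files:

```csharp
using System;
using Microsoft.ML.Tokenizers;

class BpeSample
{
    static void Main()
    {
        // preTokenizer: null falls back to WhiteSpace.Instance per the constructor above.
        Bpe bpe = new Bpe("vocab.json", "merges.txt", preTokenizer: null, normalizer: null, unknownToken: "<unk>");

        // Budgeted encoding: how much of the input do the first 10 tokens cover?
        string text = "the quick brown fox jumps over the lazy dog";
        int index = bpe.IndexOfTokenCount(text, maxTokenCount: 10, out string? normalized, out int tokenCount);
        Console.WriteLine($"{tokenCount} tokens cover text[0..{index}]");

        // And from the end: where must encoding start to stay within 10 tokens?
        int start = bpe.LastIndexOfTokenCount(text, maxTokenCount: 10, out normalized, out tokenCount);
        Console.WriteLine($"{tokenCount} tokens cover text[{start}..]");
    }
}
```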
diff --git a/src/Microsoft.ML.Tokenizers/Model/EnglishRoberta.cs b/src/Microsoft.ML.Tokenizers/Model/EnglishRoberta.cs
index 013af1917f..ac4da084c2 100644
--- a/src/Microsoft.ML.Tokenizers/Model/EnglishRoberta.cs
+++ b/src/Microsoft.ML.Tokenizers/Model/EnglishRoberta.cs
@@ -8,6 +8,7 @@
using System.Diagnostics;
using System.IO;
using System.Linq;
+using System.Text;
using System.Text.Json;
namespace Microsoft.ML.Tokenizers
@@ -15,7 +16,7 @@ namespace Microsoft.ML.Tokenizers
/// <summary>
/// Represent the Byte Pair Encoding model.
/// </summary>
- public sealed class EnglishRoberta : Model
+ public sealed class EnglishRoberta : Tokenizer
{
private readonly HighestOccurrenceMapping _vocabIdToHighestOccurrence;
private readonly Dictionary<StringSpanOrdinalKey, int> _vocab;
@@ -26,6 +27,8 @@ public sealed class EnglishRoberta : Model
private readonly IReadOnlyDictionary<char, char> _unicodeToByte;
private readonly string[] _charToString;
private readonly StringSpanOrdinalKeyCache<List<Token>> _cache;
+ private readonly PreTokenizer? _preTokenizer;
+ private readonly Normalizer? _normalizer;
/// <summary>
/// Indicate whether to filter the unsupported characters during the decoding.
@@ -38,47 +41,15 @@ public sealed class EnglishRoberta : Model
/// <param name="vocabularyPath">The JSON file path containing the dictionary of string keys and their ids.</param>
/// <param name="mergePath">The file path containing the tokens' pairs list.</param>
/// <param name="highestOccurrenceMappingPath">Remap the original GPT-2 model Ids to high occurrence ranks and values.</param>
+ /// <param name="preTokenizer">The pre-tokenizer to use.</param>
+ /// <param name="normalizer">The normalizer to use.</param>
/// <param name="filterUnsupportedChars">Indicate whether to filter the unsupported characters during the decoding.</param>
- public EnglishRoberta(string vocabularyPath, string mergePath, string highestOccurrenceMappingPath, bool filterUnsupportedChars = true)
+ public EnglishRoberta(string vocabularyPath, string mergePath, string highestOccurrenceMappingPath, PreTokenizer? preTokenizer = null, Normalizer? normalizer = null, bool filterUnsupportedChars = true) :
+ this(vocabularyPath is null ? throw new ArgumentNullException(nameof(vocabularyPath)) : File.OpenRead(vocabularyPath),
+ mergePath is null ? throw new ArgumentNullException(nameof(mergePath)) : File.OpenRead(mergePath),
+ highestOccurrenceMappingPath is null ? throw new ArgumentNullException(nameof(highestOccurrenceMappingPath)) : File.OpenRead(highestOccurrenceMappingPath),
+ preTokenizer, normalizer, filterUnsupportedChars, true)
{
- if (vocabularyPath is null)
- {
- throw new ArgumentNullException(nameof(vocabularyPath));
- }
-
- if (mergePath is null)
- {
- throw new ArgumentNullException(nameof(mergePath));
- }
-
- if (highestOccurrenceMappingPath is null)
- {
- throw new ArgumentNullException(nameof(highestOccurrenceMappingPath));
- }
-
- FilterUnsupportedChars = filterUnsupportedChars;
-
- using Stream vocabularyStream = File.OpenRead(vocabularyPath);
- using Stream mergeStream = File.OpenRead(mergePath);
- using Stream highestOccurrenceMappingStream = File.OpenRead(highestOccurrenceMappingPath);
-
- // vocabularyPath like "https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/encoder.json"
- // merge file like "https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/vocab.bpe"
- // highestOccurrenceMappingPath like "https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/dict.txt"
-
- _vocabIdToHighestOccurrence = GetHighestOccurrenceMapping(highestOccurrenceMappingStream);
- _vocab = GetVocabulary(vocabularyStream);
- _vocabReverse = _vocab.ReverseSorted();
- _mergeRanks = GetMergeRanks(mergeStream);
- int maxCharValue = GetByteToUnicode(out _byteToUnicode);
- _charToString = new string[maxCharValue];
- for (char c = (char)0; c < (char)maxCharValue; c++)
- {
- _charToString[c] = c.ToString();
- }
-
- _unicodeToByte = _byteToUnicode.Reverse();
- _cache = new StringSpanOrdinalKeyCache<List<Token>>();
}
///
@@ -87,8 +58,15 @@ public EnglishRoberta(string vocabularyPath, string mergePath, string highestOcc
/// <param name="vocabularyStream">The stream of a JSON file containing the dictionary of string keys and their ids.</param>
/// <param name="mergeStream">The stream of a file containing the tokens' pairs list.</param>
/// <param name="highestOccurrenceMappingStream">Remap the original GPT-2 model Ids to high occurrence ranks and values.</param>
+ /// <param name="preTokenizer">The pre-tokenizer to use.</param>
+ /// <param name="normalizer">The normalizer to use.</param>
/// <param name="filterUnsupportedChars">Indicate whether to filter the unsupported characters during the decoding.</param>
- public EnglishRoberta(Stream vocabularyStream, Stream mergeStream, Stream highestOccurrenceMappingStream, bool filterUnsupportedChars = true)
+ public EnglishRoberta(Stream vocabularyStream, Stream mergeStream, Stream highestOccurrenceMappingStream, PreTokenizer? preTokenizer = null, Normalizer? normalizer = null, bool filterUnsupportedChars = true) :
+ this(vocabularyStream, mergeStream, highestOccurrenceMappingStream, preTokenizer, normalizer, filterUnsupportedChars, false)
+ {
+ }
+
+ private EnglishRoberta(Stream vocabularyStream, Stream mergeStream, Stream highestOccurrenceMappingStream, PreTokenizer? preTokenizer, Normalizer? normalizer, bool filterUnsupportedChars, bool disposeStream)
{
if (vocabularyStream is null)
{
@@ -106,6 +84,12 @@ public EnglishRoberta(Stream vocabularyStream, Stream mergeStream, Stream highes
}
FilterUnsupportedChars = filterUnsupportedChars;
+ _preTokenizer = preTokenizer;
+ _normalizer = normalizer;
+
+ // vocabularyPath like "https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/encoder.json"
+ // merge file like "https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/vocab.bpe"
+ // highestOccurrenceMappingPath like "https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/dict.txt"
_vocabIdToHighestOccurrence = GetHighestOccurrenceMapping(highestOccurrenceMappingStream);
_vocab = GetVocabulary(vocabularyStream);
@@ -120,8 +104,25 @@ public EnglishRoberta(Stream vocabularyStream, Stream mergeStream, Stream highes
_unicodeToByte = _byteToUnicode.Reverse();
_cache = new StringSpanOrdinalKeyCache<List<Token>>();
+
+ if (disposeStream)
+ {
+ vocabularyStream.Dispose();
+ mergeStream.Dispose();
+ highestOccurrenceMappingStream.Dispose();
+ }
}
+ /// <summary>
+ /// Gets the PreTokenizer used by the Tokenizer.
+ /// </summary>
+ public override PreTokenizer? PreTokenizer => _preTokenizer;
+
+ /// <summary>
+ /// Gets the Normalizer in use by the Tokenizer.
+ /// </summary>
+ public override Normalizer? Normalizer => _normalizer;
+
/// <summary>
/// Gets the dictionary mapping tokens to Ids.
/// </summary>
@@ -167,16 +168,64 @@ public EnglishRoberta(Stream vocabularyStream, Stream mergeStream, Stream highes
return null;
}
+ /// <summary>
+ /// Encodes input text to a list of <see cref="Token" />s with string value of the token, id, and offset.
+ /// </summary>
+ /// <param name="text">The text to encode.</param>
+ /// <param name="normalizedString">If the tokenizer's normalization is enabled or <paramref name="considerNormalization"/> is false, this will be set to <paramref name="text"/> in its normalized form; otherwise, this value will be set to null.</param>
+ /// <param name="considerPreTokenization">Indicate whether to consider pre-tokenization before tokenization.</param>
+ /// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
+ /// <returns>The tokenization result includes a list of <see cref="Token" />s with string value of the token, id, and offset.</returns>
+ public override IReadOnlyList<Token> Encode(string text, out string? normalizedString, bool considerPreTokenization = true, bool considerNormalization = true) => Encode(text, ReadOnlySpan<char>.Empty, out normalizedString, considerPreTokenization, considerNormalization);
+
+ /// <summary>
+ /// Encodes input text to a list of <see cref="Token" />s with string value of the token, id, and offset.
+ /// </summary>
+ /// <param name="text">The text to encode.</param>
+ /// <param name="normalizedString">If the tokenizer's normalization is enabled or <paramref name="considerNormalization"/> is false, this will be set to <paramref name="text"/> in its normalized form; otherwise, this value will be set to null.</param>
+ /// <param name="considerPreTokenization">Indicate whether to consider pre-tokenization before tokenization.</param>
+ /// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
+ /// <returns>The tokenization result includes a list of <see cref="Token" />s with string value of the token, id, and offset.</returns>
+ public override IReadOnlyList<Token> Encode(ReadOnlySpan<char> text, out string? normalizedString, bool considerPreTokenization = true, bool considerNormalization = true) => Encode(null, text, out normalizedString, considerPreTokenization, considerNormalization);
+
+ private IReadOnlyList<Token> Encode(string? text, ReadOnlySpan<char> textSpan, out string? normalizedString, bool considerPreTokenization, bool considerNormalization)
+ {
+ if (string.IsNullOrEmpty(text) && textSpan.IsEmpty)
+ {
+ normalizedString = null;
+ return [];
+ }
+
+ IEnumerable<(int Offset, int Length)>? splits = InitializeForEncoding(text, textSpan, considerPreTokenization, considerNormalization, _normalizer, _preTokenizer, out normalizedString, out ReadOnlySpan<char> textSpanToEncode);
+
+ if (splits is not null)
+ {
+ List<Token> tokens = new();
+ foreach ((int Offset, int Length) split in splits)
+ {
+ foreach (Token t in EncodeInternal(textSpanToEncode.Slice(split.Offset, split.Length)))
+ {
+ tokens.Add(new Token(t.Id, t.Value, (split.Offset + t.Offset.Index, t.Offset.Length)));
+ }
+ }
+ return tokens;
+ }
+ else
+ {
+ return EncodeInternal(textSpanToEncode);
+ }
+ }
+
/// <summary>
/// Encode a text string to a list of tokens.
/// </summary>
/// <param name="text">The text to encode.</param>
/// <returns>The list of tokens generated from the text tokenization.</returns>
- public override IReadOnlyList<Token> Encode(ReadOnlySpan<char> text)
+ private IReadOnlyList<Token> EncodeInternal(ReadOnlySpan<char> text)
{
if (text.IsEmpty)
{
- return Bpe.EmptyTokensList;
+ return [];
}
char[] token = ArrayPool<char>.Shared.Rent(text.Length);
@@ -197,7 +246,7 @@ public override IReadOnlyList<Token> Encode(ReadOnlySpan<char> text)
{
ArrayPool<char>.Shared.Return(token);
ArrayPool<int>.Shared.Return(indexMapping);
- return Array.Empty<Token>();
+ return [];
}
if (_cache.TryGetValue(text, out List<Token>? hit))
@@ -215,32 +264,257 @@ public override IReadOnlyList Encode(ReadOnlySpan text)
}
/// <summary>
- /// Encode a split text string to a list of Ids and add them to the accumulatedIds list.
+ /// Encodes input text to token Ids.
+ /// </summary>
+ /// <param name="text">The text to encode.</param>
+ /// <param name="considerPreTokenization">Indicate whether to consider pre-tokenization before tokenization.</param>
+ /// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
+ /// <returns>The list of encoded Ids.</returns>
+ public override IReadOnlyList<int> EncodeToIds(string text, bool considerPreTokenization = true, bool considerNormalization = true)
+ => EncodeToIds(text, ReadOnlySpan<char>.Empty, considerPreTokenization, considerNormalization, out _, out _);
+
+ /// <summary>
+ /// Encodes input text to token Ids.
+ /// </summary>
+ /// <param name="text">The text to encode.</param>
+ /// <param name="considerPreTokenization">Indicate whether to consider pre-tokenization before tokenization.</param>
+ /// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
+ /// <returns>The list of encoded Ids.</returns>
+ public override IReadOnlyList<int> EncodeToIds(ReadOnlySpan<char> text, bool considerPreTokenization = true, bool considerNormalization = true)
+ => EncodeToIds(null, text, considerPreTokenization, considerNormalization, out _, out _);
+
+ /// <summary>
+ /// Encodes input text to token Ids up to maximum number of tokens.
/// </summary>
/// <param name="text">The text to encode.</param>
- /// <param name="accumulatedIds">The list of accumulated encoded Ids.</param>
+ /// <param name="maxTokenCount">The maximum number of tokens to encode.</param>
+ /// <param name="normalizedString">If the tokenizer's normalization is enabled or <paramref name="considerNormalization"/> is false, this will be set to <paramref name="text"/> in its normalized form; otherwise, this value will be set to null.</param>
/// <param name="textLength">The length of the text that encompasses the maximum encoded tokens.</param>
- /// <param name="maxTokens">The maximum number of tokens to encode.</param>
- /// <returns>The number of tokens that the input text will be encoded to.</returns>
- public override int EncodeToIds(ReadOnlySpan<char> text, IList<int> accumulatedIds, out int textLength, int maxTokens = int.MaxValue) => EncodeToIdsInternal(text, accumulatedIds, out textLength, maxTokens);
+ /// <param name="considerPreTokenization">Indicate whether to consider pre-tokenization before tokenization.</param>
+ /// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
+ /// <returns>The list of encoded Ids.</returns>
+ public override IReadOnlyList<int> EncodeToIds(string text, int maxTokenCount, out string? normalizedString, out int textLength, bool considerPreTokenization = true, bool considerNormalization = true)
+ => EncodeToIds(text, ReadOnlySpan<char>.Empty, considerPreTokenization, considerNormalization, out normalizedString, out textLength, maxTokenCount);
/// <summary>
- /// Get the number of tokens that the input text will be encoded to.
+ /// Encodes input text to token Ids up to maximum number of tokens.
/// </summary>
/// <param name="text">The text to encode.</param>
+ /// <param name="maxTokenCount">The maximum number of tokens to encode.</param>
+ /// <param name="normalizedString">If the tokenizer's normalization is enabled or <paramref name="considerNormalization"/> is false, this will be set to <paramref name="text"/> in its normalized form; otherwise, this value will be set to null.</param>
/// <param name="textLength">The length of the text that encompasses the maximum encoded tokens.</param>
- /// <param name="maxTokens">The maximum number of tokens to encode.</param>
- /// <returns>The number of tokens that the input text will be encoded to.</returns>
- public override int CountTokens(ReadOnlySpan<char> text, out int textLength, int maxTokens = int.MaxValue) => EncodeToIdsInternal(text, null, out textLength, maxTokens);
+ /// <param name="considerPreTokenization">Indicate whether to consider pre-tokenization before tokenization.</param>
+ /// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
+ /// <returns>The list of encoded Ids.</returns>
+ public override IReadOnlyList<int> EncodeToIds(ReadOnlySpan<char> text, int maxTokenCount, out string? normalizedString, out int textLength, bool considerPreTokenization = true, bool considerNormalization = true)
+ => EncodeToIds(null, text, considerPreTokenization, considerNormalization, out normalizedString, out textLength, maxTokenCount);
+
+ private IReadOnlyList<int> EncodeToIds(string? text, ReadOnlySpan<char> textSpan, bool considerPreTokenization, bool considerNormalization, out string? normalizedString, out int textLength, int maxTokenCount = int.MaxValue)
+ {
+ if (maxTokenCount <= 0)
+ {
+ throw new ArgumentOutOfRangeException(nameof(maxTokenCount), "The maximum number of tokens must be greater than zero.");
+ }
+
+ if (string.IsNullOrEmpty(text) && textSpan.IsEmpty)
+ {
+ textLength = 0;
+ normalizedString = null;
+ return [];
+ }
+
+ IEnumerable<(int Offset, int Length)>? splits = InitializeForEncoding(text, textSpan, considerPreTokenization, considerNormalization, _normalizer, _preTokenizer, out normalizedString, out ReadOnlySpan<char> textSpanToEncode);
+
+ List<int> ids = new();
+ if (splits is not null)
+ {
+ textLength = 0;
+ foreach ((int Offset, int Length) split in splits)
+ {
+ EncodeToIdsInternal(textSpanToEncode.Slice(split.Offset, split.Length), ids, out int length, maxTokenCount - ids.Count);
+ textLength = split.Offset + length;
+
+ if (length < split.Length || ids.Count >= maxTokenCount)
+ {
+ break;
+ }
+ }
+ }
+ else
+ {
+ EncodeToIdsInternal(textSpanToEncode, ids, out textLength, maxTokenCount);
+ }
+
+ return ids;
+ }
/// <summary>
/// Get the number of tokens that the input text will be encoded to.
/// </summary>
/// <param name="text">The text to encode.</param>
- /// <param name="textIndex">Starting from this index to the end of the text will encompass the maximum encoded tokens.</param>
- /// <param name="maxTokens">The maximum number of tokens to encode.</param>
- /// <returns>The number of tokens that the input text will be encoded to.</returns>
- public override int CountTokensFromEnd(ReadOnlySpan<char> text, out int textIndex, int maxTokens = int.MaxValue) => EncodeToIdsFromEndInternal(text, null, out textIndex, maxTokens);
+ /// <param name="considerPreTokenization">Indicate whether to consider pre-tokenization before tokenization.</param>
+ /// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
+ /// <returns>The number of token Ids that the input text will be encoded to.</returns>
+ public override int CountTokens(string text, bool considerPreTokenization = true, bool considerNormalization = true)
+ => CountTokens(text, ReadOnlySpan<char>.Empty, considerPreTokenization, considerNormalization, out _, out _);
+
+ /// <summary>
+ /// Get the number of tokens that the input text will be encoded to.
+ /// </summary>
+ /// <param name="text">The text to encode.</param>
+ /// <param name="considerPreTokenization">Indicate whether to consider pre-tokenization before tokenization.</param>
+ /// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
+ /// <returns>The number of token Ids that the input text will be encoded to.</returns>
+ public override int CountTokens(ReadOnlySpan<char> text, bool considerPreTokenization = true, bool considerNormalization = true)
+ => CountTokens(null, text, considerPreTokenization, considerNormalization, out _, out _);
+
+ /// <summary>
+ /// Find the index of the maximum encoding capacity from the start within the text without surpassing the token limit.
+ /// </summary>
+ /// <param name="text">The text to encode.</param>
+ /// <param name="maxTokenCount">The maximum token count to limit the encoding capacity.</param>
+ /// <param name="normalizedString">If the tokenizer's normalization is enabled or <paramref name="considerNormalization"/> is false, this will be set to <paramref name="text"/> in its normalized form; otherwise, this value will be set to null.</param>
+ /// <param name="tokenCount">The token count that can be generated, which should be smaller than the maximum token count.</param>
+ /// <param name="considerPreTokenization">Indicate whether to consider pre-tokenization before tokenization.</param>
+ /// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
+ /// <returns>
+ /// The index of the maximum encoding capacity within the processed text without surpassing the token limit.
+ /// It represents the index immediately following the last character to be included. In cases where no tokens fit, the result will be 0; conversely,
+ /// if all tokens fit, the result will be the length of the text or the <paramref name="normalizedString"/> if the normalization is enabled.
+ /// </returns>
+ public override int IndexOfTokenCount(string text, int maxTokenCount, out string? normalizedString, out int tokenCount, bool considerPreTokenization = true, bool considerNormalization = true)
+ {
+ tokenCount = CountTokens(text, ReadOnlySpan<char>.Empty, considerPreTokenization, considerNormalization, out normalizedString, out int textLength, maxTokenCount);
+ return textLength;
+ }
+
+ /// <summary>
+ /// Find the index of the maximum encoding capacity from the start within the text without surpassing the token limit.
+ /// </summary>
+ /// <param name="text">The text to encode.</param>
+ /// <param name="maxTokenCount">The maximum token count to limit the encoding capacity.</param>
+ /// <param name="normalizedString">If the tokenizer's normalization is enabled or <paramref name="considerNormalization"/> is false, this will be set to <paramref name="text"/> in its normalized form; otherwise, this value will be set to null.</param>
+ /// <param name="tokenCount">The token count that can be generated, which should be smaller than the maximum token count.</param>
+ /// <param name="considerPreTokenization">Indicate whether to consider pre-tokenization before tokenization.</param>
+ /// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
+ /// <returns>
+ /// The index of the maximum encoding capacity within the processed text without surpassing the token limit.
+ /// It represents the index immediately following the last character to be included. In cases where no tokens fit, the result will be 0; conversely,
+ /// if all tokens fit, the result will be the length of the text or the <paramref name="normalizedString"/> if the normalization is enabled.
+ /// </returns>
+ public override int IndexOfTokenCount(ReadOnlySpan<char> text, int maxTokenCount, out string? normalizedString, out int tokenCount, bool considerPreTokenization = true, bool considerNormalization = true)
+ {
+ tokenCount = CountTokens(null, text, considerPreTokenization, considerNormalization, out normalizedString, out int textLength, maxTokenCount);
+ return textLength;
+ }
+
+ private int CountTokens(string? text, ReadOnlySpan<char> textSpan, bool considerPreTokenization, bool considerNormalization, out string? normalizedString, out int textLength, int maxTokenCount = int.MaxValue)
+ {
+ if (maxTokenCount <= 0)
+ {
+ throw new ArgumentOutOfRangeException(nameof(maxTokenCount), "The maximum number of tokens must be greater than zero.");
+ }
+
+ textLength = 0;
+ if (string.IsNullOrEmpty(text) && textSpan.IsEmpty)
+ {
+ normalizedString = null;
+ return 0;
+ }
+
+ IEnumerable<(int Offset, int Length)>? splits = InitializeForEncoding(text, textSpan, considerPreTokenization, considerNormalization, _normalizer, _preTokenizer, out normalizedString, out ReadOnlySpan<char> textSpanToEncode);
+
+ int count = 0;
+ if (splits is not null)
+ {
+ foreach ((int Offset, int Length) split in splits)
+ {
+ count += EncodeToIdsInternal(textSpanToEncode.Slice(split.Offset, split.Length), null, out int length, maxTokenCount - count);
+ textLength = split.Offset + length;
+
+ if (length < split.Length || count >= maxTokenCount)
+ {
+ break;
+ }
+ }
+ }
+ else
+ {
+ count += EncodeToIdsInternal(textSpanToEncode, null, out textLength, maxTokenCount);
+ }
+
+ return count;
+ }
+
+ /// <summary>
+ /// Find the index of the maximum encoding capacity from the end within the text without surpassing the token limit.
+ /// </summary>
+ /// <param name="text">The text to encode.</param>
+ /// <param name="maxTokenCount">The maximum token count to limit the encoding capacity.</param>
+ /// <param name="normalizedString">If the tokenizer's normalization is enabled or <paramref name="considerNormalization"/> is false, this will be set to <paramref name="text"/> in its normalized form; otherwise, this value will be set to null.</param>
+ /// <param name="tokenCount">The token count that can be generated, which should be smaller than the maximum token count.</param>
+ /// <param name="considerPreTokenization">Indicate whether to consider pre-tokenization before tokenization.</param>
+ /// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
+ /// <returns>
+ /// The start index of the maximum encoding capacity within the processed text without surpassing the token limit.
+ /// It represents the index of the first character to be included. In cases where no tokens fit, the result will be the length of the text or the <paramref name="normalizedString"/> if normalization is enabled;
+ /// conversely, if all tokens fit, the result will be 0.
+ /// </returns>
+ public override int LastIndexOfTokenCount(string text, int maxTokenCount, out string? normalizedString, out int tokenCount, bool considerPreTokenization = true, bool considerNormalization = true)
+ => LastIndexOf(text, ReadOnlySpan<char>.Empty, maxTokenCount, considerPreTokenization, considerNormalization, out normalizedString, out tokenCount);
+
+ /// <summary>
+ /// Find the index of the maximum encoding capacity from the end within the text without surpassing the token limit.
+ /// </summary>
+ /// <param name="text">The text to encode.</param>
+ /// <param name="maxTokenCount">The maximum token count to limit the encoding capacity.</param>
+ /// <param name="normalizedString">If the tokenizer's normalization is enabled or <paramref name="considerNormalization"/> is false, this will be set to <paramref name="text"/> in its normalized form; otherwise, this value will be set to null.</param>
+ /// <param name="tokenCount">The token count that can be generated, which should be smaller than the maximum token count.</param>
+ /// <param name="considerPreTokenization">Indicate whether to consider pre-tokenization before tokenization.</param>
+ /// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
+ /// <returns>
+ /// The start index of the maximum encoding capacity within the processed text without surpassing the token limit.
+ /// It represents the index of the first character to be included. In cases where no tokens fit, the result will be the length of the <paramref name="text"/>; conversely, if all tokens fit, the result will be 0.
+ /// </returns>
+ public override int LastIndexOfTokenCount(ReadOnlySpan<char> text, int maxTokenCount, out string? normalizedString, out int tokenCount, bool considerPreTokenization = true, bool considerNormalization = true)
+ => LastIndexOf(null, text, maxTokenCount, considerPreTokenization, considerNormalization, out normalizedString, out tokenCount);
+
+ private int LastIndexOf(string? text, ReadOnlySpan<char> textSpan, int maxTokenCount, bool considerPreTokenization, bool considerNormalization, out string? normalizedString, out int tokenCount)
+ {
+ if (maxTokenCount <= 0)
+ {
+ throw new ArgumentOutOfRangeException(nameof(maxTokenCount), "The max token count must be greater than 0.");
+ }
+
+ if (string.IsNullOrEmpty(text) && textSpan.IsEmpty)
+ {
+ normalizedString = null;
+ tokenCount = 0;
+ return 0;
+ }
+
+ IEnumerable<(int Offset, int Length)>? splits = InitializeForEncoding(text, textSpan, considerPreTokenization, considerNormalization, _normalizer, _preTokenizer, out normalizedString, out ReadOnlySpan<char> textSpanToEncode);
+
+ if (splits is not null)
+ {
+ tokenCount = 0;
+ foreach ((int Offset, int Length) split in splits.Reverse())
+ {
+ tokenCount += EncodeToIdsFromEndInternal(textSpanToEncode.Slice(split.Offset, split.Length), null, out int textIndex, maxTokenCount - tokenCount);
+ if (textIndex > 0 || tokenCount >= maxTokenCount)
+ {
+ return split.Offset + textIndex;
+ }
+ }
+ }
+ else
+ {
+ tokenCount = EncodeToIdsFromEndInternal(textSpanToEncode, null, out int textLength, maxTokenCount);
+ return textLength;
+ }
+
+ return 0;
+ }
private int EncodeToIdsResult(List<Token> tokens, IList<int>? accumulatedIds, int maxTokens, int fullTextLength, out int textLength)
{
@@ -347,7 +621,7 @@ private int EncodeToIdsInternal(ReadOnlySpan<char> text, IList<int>? accumulated
{
ArrayPool<char>.Shared.Return(token);
ArrayPool<int>.Shared.Return(indexMapping);
- textLength = 0;
+ textLength = text.Length;
return 0;
}
@@ -390,7 +664,7 @@ private int EncodeToIdsFromEndInternal(ReadOnlySpan<char> text, IList<int>? accu
{
ArrayPool<char>.Shared.Return(token);
ArrayPool<int>.Shared.Return(indexMapping);
- textIndex = text.Length;
+ textIndex = 0;
return 0;
}
@@ -410,7 +684,32 @@ public override int? MapTokenToId(ReadOnlySpan<char> token)
public override int? MapTokenToId(ReadOnlySpan<char> token) => _vocab.TryGetValue(token, out int value) ? value : null;
/// <summary>
- /// Convert a list of tokens Ids to highest occurrence rankings.
+ /// Decode the given ids back to a String.
+ /// </summary>
+ /// <param name="ids">The list of ids that we want to decode.</param>
+ /// <returns>The decoded string.</returns>
+ public override string? Decode(IEnumerable<int> ids)
+ {
+ if (ids is null)
+ {
+ throw new ArgumentNullException(nameof(ids));
+ }
+
+ ValueStringBuilder sb = new ValueStringBuilder();
+
+ foreach (int id in ids)
+ {
+ if (MapIdToToken(id) is string s)
+ {
+ sb.Append(s);
+ }
+ }
+
+ return sb.ToString();
+ }
+
+ /// <summary>
+ /// Convert a list of token Ids to highest occurrence rankings.
/// </summary>
/// <param name="ids">The Ids list to map to the high occurrence rank.</param>
/// <returns>The list of ranks mapped from the list of Ids.</returns>
@@ -432,7 +731,7 @@ public IReadOnlyList<int> ConvertIdsToOccurrenceRanks(IReadOnlyList<int> ids)
}
/// <summary>
- /// Convert a list of tokens Ids to highest occurrence values.
+ /// Convert a list of token Ids to highest occurrence values.
/// </summary>
/// <param name="ids">The Ids list to map to the high occurrence values.</param>
/// <returns>The list of occurrence values mapped from the list of Ids.</returns>
@@ -454,7 +753,7 @@ public IReadOnlyList<int> ConvertIdsToOccurrenceValues(IReadOnlyList<int> ids)
}
/// <summary>
- /// Convert a list of highest occurrence rankings to tokens Ids list.
+ /// Convert a list of highest occurrence rankings to token Ids list.
/// </summary>
/// <param name="ranks">The high occurrence ranks list to map to the Ids list.</param>
/// <returns>The list of Ids mapped from the list of ranks.</returns>
@@ -638,7 +937,7 @@ private List<Token> EncodeToTokens(Span<char> token, Span<int> indexMapping)
{
if (token.Length == 0)
{
- return Bpe.EmptyTokensList;
+ return [];
}
if (token.Length == 1)
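EnglishRoberta follows the same template: the encode/count/limit overloads land on the tokenizer itself, and the stream constructor gains the same optional components. A sketch using local copies of the fairseq GPT-2 BPE assets named in the comments above:

```csharp
using System;
using System.Collections.Generic;
using System.IO;
using Microsoft.ML.Tokenizers;

class RobertaSample
{
    static void Main()
    {
        // Local copies of encoder.json / vocab.bpe / dict.txt, per the URLs cited above.
        using Stream vocab = File.OpenRead("encoder.json");
        using Stream merges = File.OpenRead("vocab.bpe");
        using Stream dict = File.OpenRead("dict.txt");

        EnglishRoberta roberta = new EnglishRoberta(vocab, merges, dict);

        // Ids, counting, and decoding all live on the tokenizer itself now.
        IReadOnlyList<int> ids = roberta.EncodeToIds("Hello world");
        int count = roberta.CountTokens("Hello world");   // same number as ids.Count
        string? roundTripped = roberta.Decode(ids);       // uses the new Decode override

        Console.WriteLine($"{count} ids; decoded back to: {roundTripped}");
    }
}
```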
diff --git a/src/Microsoft.ML.Tokenizers/Model/Model.cs b/src/Microsoft.ML.Tokenizers/Model/Model.cs
deleted file mode 100644
index c6cc5578df..0000000000
--- a/src/Microsoft.ML.Tokenizers/Model/Model.cs
+++ /dev/null
@@ -1,180 +0,0 @@
-// Licensed to the .NET Foundation under one or more agreements.
-// The .NET Foundation licenses this file to you under the MIT license.
-// See the LICENSE file in the project root for more information.
-
-using System;
-using System.Collections.Generic;
-using System.Text;
-
-namespace Microsoft.ML.Tokenizers
-{
- /// <summary>
- /// Represents a model used during Tokenization (like BPE or Word Piece or Unigram).
- /// </summary>
- public abstract class Model
- {
- /// <summary>
- /// Encode a text to a list of tokens.
- /// </summary>
- /// <param name="text">The text to encode.</param>
- /// <returns>The list of tokens generated from the text tokenization.</returns>
- public abstract IReadOnlyList<Token> Encode(ReadOnlySpan<char> text);
-
- /// <summary>
- /// Encode a text to a list of Ids and add them to the accumulatedIds list.
- /// </summary>
- /// <param name="text">The text to encode.</param>
- /// <param name="accumulatedIds">The list of accumulated encoded Ids.</param>
- /// <param name="textLength">The length of the text that encompasses the maximum encoded tokens.</param>
- /// <param name="maxTokens">The maximum number of tokens to encode.</param>
- /// <returns>The number of tokens that the input text will be encoded to.</returns>
- /// <remarks>
- /// This method does the default implementation that uses the Encode method to get the token Ids.
- /// Tokenizer models which care about performance may choose to override this method to provide a more efficient implementation.
- /// </remarks>
- public virtual int EncodeToIds(ReadOnlySpan<char> text, IList<int> accumulatedIds, out int textLength, int maxTokens = int.MaxValue)
- {
- if (accumulatedIds is null)
- {
- throw new ArgumentNullException(nameof(accumulatedIds));
- }
-
- // Default implementation is not optimized for memory allocation. It is recommended to override this method for the sake of the performance.
- textLength = 0;
- var tokens = Encode(text);
-
- int count = Math.Min(tokens.Count, maxTokens);
-
- for (int i = 0; i < count; i++)
- {
- textLength += tokens[i].Offset.Length;
- accumulatedIds.Add(tokens[i].Id);
- }
-
- return count;
- }
-
- /// <summary>
- /// Get the number of tokens that the input text will be encoded to.
- /// </summary>
- /// <param name="text">The text to encode.</param>
- /// <param name="textLength">The length of the text that encompasses the maximum encoded tokens.</param>
- /// <param name="maxTokens">The maximum number of tokens to encode.</param>
- /// <returns>The number of tokens that the input text will be encoded to.</returns>
- /// <remarks>
- /// This method does the default implementation that uses the EncodeToIds method to get the number of token Ids.
- /// Tokenizer models which care about performance may choose to override this method to provide a more efficient implementation.
- /// </remarks>
- public virtual int CountTokens(ReadOnlySpan<char> text, out int textLength, int maxTokens = int.MaxValue)
- {
- if (maxTokens <= 0)
- {
- throw new ArgumentOutOfRangeException(nameof(maxTokens), "The maximum number of tokens must be greater than 0.");
- }
-
- var ids = new List<int>();
-
- if (maxTokens == int.MaxValue)
- {
- EncodeToIds(text, ids, out _);
- textLength = text.Length;
- return ids.Count;
- }
-
- IReadOnlyList<Token> tokens = Encode(text);
- textLength = 0;
- int count = Math.Min(tokens.Count, maxTokens);
- for (int i = 0; i < count; i++)
- {
- textLength += tokens[i].Offset.Length;
- }
-
- return count;
- }
-
- /// <summary>
- /// Get the number of tokens that the input text will be encoded to.
- /// </summary>
- /// <param name="text">The text to encode.</param>
- /// <param name="textIndex">Starting from this index to the end of the text will encompass the maximum encoded tokens.</param>
- /// <param name="maxTokens">The maximum number of tokens to encode.</param>
- /// <returns>The number of tokens that the input text will be encoded to.</returns>
- /// <remarks>
- /// This method provides the default implementation, which uses the EncodeToIds method to count the tokens' Ids.
- /// Tokenizer models that care about performance may override this method to provide a more efficient implementation.
- /// </remarks>
- public virtual int CountTokensFromEnd(ReadOnlySpan<char> text, out int textIndex, int maxTokens = int.MaxValue)
- {
- if (maxTokens <= 0)
- {
- throw new ArgumentOutOfRangeException(nameof(maxTokens), "The maximum number of tokens must be greater than 0.");
- }
-
- var ids = new List<int>();
-
- if (maxTokens == int.MaxValue)
- {
- EncodeToIds(text, ids, out _);
- textIndex = 0;
- return ids.Count;
- }
-
- IReadOnlyList<Token> tokens = Encode(text);
- textIndex = text.Length;
- int count = Math.Min(tokens.Count, maxTokens);
-
- int tokensCount = tokens.Count;
- int end = tokensCount - count;
- for (int i = tokensCount - 1; i >= end; i--)
- {
- textIndex -= tokens[i].Offset.Length;
- }
-
- return count;
- }
-
- /// <summary>
- /// Map the token to encoded id with the option to skip the special tokens.
- /// </summary>
- /// <param name="token">The token to map to Id.</param>
- /// <returns>The mapped Id of the token.</returns>
- public abstract int? MapTokenToId(ReadOnlySpan<char> token);
-
- /// <summary>
- /// Map the encoded Id to the token.
- /// </summary>
- /// <param name="id">The Id to map to the token.</param>
- /// <returns>The mapped token of the Id.</returns>
- public abstract string? MapIdToToken(int id);
-
- /// <summary>
- /// Decode the given ids back to a string.
- /// </summary>
- /// <param name="ids">The list of ids that we want to decode.</param>
- /// <returns>The decoded string.</returns>
- /// <remarks>
- /// This method provides the default implementation, which uses the MapIdToToken method to get the token.
- /// Tokenizer models may opt to override this method to ensure accurate results if the default implementation
- /// provided here proves insufficient for the model's specific scenario.
- /// </remarks>
- public virtual string? Decode(IEnumerable<int> ids)
- {
- if (ids is null)
- {
- throw new ArgumentNullException(nameof(ids));
- }
-
- ValueStringBuilder sb = new ValueStringBuilder();
-
- foreach (int id in ids)
- {
- if (MapIdToToken(id) is string s)
- {
- sb.Append(s);
- }
- }
-
- return sb.ToString();
- }
- }
-}
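With Model.cs deleted, its responsibilities fold into the Tokenizer base class itself, as the SentencePieceBpe and Tiktoken changes below show. A hedged sketch of what a custom tokenizer now overrides (the exact abstract/virtual split on Tokenizer is assumed from the overrides visible in this patch):

    public sealed class MyTokenizer : Tokenizer
    {
        // Pipeline components are now surfaced by the tokenizer itself.
        public override PreTokenizer? PreTokenizer => null;
        public override Normalizer? Normalizer => null;

        public override IReadOnlyList<Token> Encode(string text, out string? normalizedString,
            bool considerPreTokenization = true, bool considerNormalization = true)
        {
            normalizedString = null;
            return []; // real implementations produce Token(value, id, offset) entries
        }

        // EncodeToIds, CountTokens, IndexOfTokenCount, LastIndexOfTokenCount, Decode,
        // MapTokenToId, and MapIdToToken follow the same override pattern.
    }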
diff --git a/src/Microsoft.ML.Tokenizers/Model/SentencePieceBpe.cs b/src/Microsoft.ML.Tokenizers/Model/SentencePieceBpe.cs
index ab43325349..889d63f059 100644
--- a/src/Microsoft.ML.Tokenizers/Model/SentencePieceBpe.cs
+++ b/src/Microsoft.ML.Tokenizers/Model/SentencePieceBpe.cs
@@ -19,7 +19,7 @@ namespace Microsoft.ML.Tokenizers
 /// <summary>
 /// SentencePieceBpe is a tokenizer that splits the input into tokens using the SentencePiece Bpe model.
 /// </summary>
- public sealed class SentencePieceBpe : Model
+ public sealed class SentencePieceBpe : Tokenizer
{
private const int UninitializedId = -2; // indicate if the symbol contains uninitialized id.
@@ -29,9 +29,10 @@ public sealed class SentencePieceBpe : Model
private readonly int _maxByteId;
private readonly int _byteCodeToIdOffset; // offset of mapping byte code to the Ids.
private readonly int _oneByteUtf8EncodingMaxId; // the maximum value of the one byte UTF-8 character.
+ private readonly Normalizer? _normalizer;
internal SentencePieceBpe(ModelProto modelProto, bool addBos, bool addEos) :
- this(modelProto is null ? throw new ArgumentNullException(nameof(modelProto)) : modelProto)
+ this(modelProto is null ? throw new ArgumentNullException(nameof(modelProto)) : modelProto)
{
AddBeginningOfSentence = addBos;
AddEndOfSentence = addEos;
@@ -65,6 +66,8 @@ private SentencePieceBpe(ModelProto modelProto)
EscapeWhiteSpaces = modelProto.NormalizerSpec.EscapeWhitespaces;
TreatWhitespaceAsSuffix = modelProto.TrainerSpec.TreatWhitespaceAsSuffix;
ByteFallback = modelProto.TrainerSpec.ByteFallback;
+
+ _normalizer = new SentencePieceNormalizer(modelProto.NormalizerSpec.RemoveExtraWhitespaces, AddDummyPrefix, EscapeWhiteSpaces, modelProto.TrainerSpec.TreatWhitespaceAsSuffix);
}
///
@@ -127,6 +130,16 @@ private SentencePieceBpe(ModelProto modelProto)
///
public int UnknownId { get; }
+ /// <summary>
+ /// Gets the PreTokenizer used by the Tokenizer.
+ /// </summary>
+ public override PreTokenizer? PreTokenizer => null;
+
+ /// <summary>
+ /// Gets the Normalizer in use by the Tokenizer.
+ /// </summary>
+ public override Normalizer? Normalizer => _normalizer;
+
 /// <summary>
 /// The vocabulary of the model.
 /// </summary>
@@ -152,12 +165,74 @@ public IReadOnlyDictionary<string, int> Vocab
}
 /// <summary>
- /// Encode a text to a list of tokens.
+ /// Encodes input text to a list of <see cref="Token" />s with string value of the token, id, and offset.
 /// </summary>
 /// <param name="text">The text to encode.</param>
- /// <returns>The list of tokens generated from the text tokenization.</returns>
- /// <remarks>The input text has to be normalized before calling this method.</remarks>
- public override IReadOnlyList<Token> Encode(ReadOnlySpan<char> text) => Encode(text, AddBeginningOfSentence, AddEndOfSentence);
+ /// <param name="normalizedString">If the tokenizer's normalization is enabled or <paramref name="considerNormalization"/> is false, this will be set to <paramref name="text"/> in its normalized form; otherwise, this value will be set to null.</param>
+ /// <param name="considerPreTokenization">Indicate whether to consider pre-tokenization before tokenization.</param>
+ /// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
+ /// <returns>The tokenization result includes a list of <see cref="Token" />s with string value of the token, id, and offset.</returns>
+ public override IReadOnlyList<Token> Encode(string text, out string? normalizedString, bool considerPreTokenization = true, bool considerNormalization = true)
+ => Encode(text, Span<char>.Empty, out normalizedString, AddBeginningOfSentence, AddEndOfSentence, considerPreTokenization, considerNormalization);
+
+ /// <summary>
+ /// Encodes input text to a list of <see cref="Token" />s with string value of the token, id, and offset.
+ /// </summary>
+ /// <param name="text">The text to encode.</param>
+ /// <param name="normalizedString">If the tokenizer's normalization is enabled or <paramref name="considerNormalization"/> is false, this will be set to <paramref name="text"/> in its normalized form; otherwise, this value will be set to null.</param>
+ /// <param name="considerPreTokenization">Indicate whether to consider pre-tokenization before tokenization.</param>
+ /// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
+ /// <returns>The tokenization result includes a list of <see cref="Token" />s with string value of the token, id, and offset.</returns>
+ public override IReadOnlyList<Token> Encode(ReadOnlySpan<char> text, out string? normalizedString, bool considerPreTokenization = true, bool considerNormalization = true)
+ => Encode(null, text, out normalizedString, AddBeginningOfSentence, AddEndOfSentence, considerPreTokenization, considerNormalization);
+
+ /// <summary>
+ /// Encodes input text to a list of <see cref="Token" />s with string value of the token, id, and offset.
+ /// </summary>
+ /// <param name="text">The text to encode.</param>
+ /// <param name="normalizedString">If the tokenizer's normalization is enabled or <paramref name="considerNormalization"/> is false, this will be set to <paramref name="text"/> in its normalized form; otherwise, this value will be set to null.</param>
+ /// <param name="addBeginningOfSentence">Indicate emitting the beginning of sentence token during the encoding.</param>
+ /// <param name="addEndOfSentence">Indicate emitting the end of sentence token during the encoding.</param>
+ /// <param name="considerPreTokenization">Indicate whether to consider pre-tokenization before tokenization.</param>
+ /// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
+ /// <returns>The tokenization result includes a list of <see cref="Token" />s with string value of the token, id, and offset.</returns>
+ public IReadOnlyList<Token> Encode(string text, out string? normalizedString, bool addBeginningOfSentence, bool addEndOfSentence, bool considerPreTokenization = true, bool considerNormalization = true)
+ => Encode(text, Span<char>.Empty, out normalizedString, addBeginningOfSentence, addEndOfSentence, considerPreTokenization, considerNormalization);
+
+ /// <summary>
+ /// Encodes input text to a list of <see cref="Token" />s with string value of the token, id, and offset.
+ /// </summary>
+ /// <param name="text">The text to encode.</param>
+ /// <param name="normalizedString">If the tokenizer's normalization is enabled or <paramref name="considerNormalization"/> is false, this will be set to <paramref name="text"/> in its normalized form; otherwise, this value will be set to null.</param>
+ /// <param name="addBeginningOfSentence">Indicate emitting the beginning of sentence token during the encoding.</param>
+ /// <param name="addEndOfSentence">Indicate emitting the end of sentence token during the encoding.</param>
+ /// <param name="considerPreTokenization">Indicate whether to consider pre-tokenization before tokenization.</param>
+ /// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
+ /// <returns>The tokenization result includes a list of <see cref="Token" />s with string value of the token, id, and offset.</returns>
+ public IReadOnlyList<Token> Encode(ReadOnlySpan<char> text, out string? normalizedString, bool addBeginningOfSentence, bool addEndOfSentence, bool considerPreTokenization = true, bool considerNormalization = true)
+ => Encode(null, text, out normalizedString, addBeginningOfSentence, addEndOfSentence, considerPreTokenization, considerNormalization);
+
+ private IReadOnlyList<Token> Encode(string? text, ReadOnlySpan<char> textSpan, out string? normalizedString, bool addBeginningOfSentence, bool addEndOfSentence, bool considerPreTokenization, bool considerNormalization)
+ {
+ if (string.IsNullOrEmpty(text) && textSpan.IsEmpty)
+ {
+ normalizedString = null;
+ return [];
+ }
+
+ ReadOnlySpan<char> textToEncode = text is null ? textSpan : text.AsSpan();
+ if (considerNormalization && _normalizer is not null)
+ {
+ normalizedString = text is not null ? _normalizer.Normalize(text) : _normalizer.Normalize(textSpan);
+ textToEncode = normalizedString.AsSpan();
+ }
+ else
+ {
+ normalizedString = null;
+ }
+
+ return EncodeInternal(textToEncode, addBeginningOfSentence, addEndOfSentence);
+ }
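A usage sketch for the new Encode surface (creation of the tokenizer is elided, since SentencePieceBpe exposes no public constructor in this patch; assume an existing instance):

    // tokenizer: an existing SentencePieceBpe instance (creation elided).
    IReadOnlyList<Token> tokens = tokenizer.Encode("Hello world", out string? normalized);
    // With normalization enabled, token offsets refer to `normalized`, not the raw input.
    foreach (Token t in tokens)
    {
        Console.WriteLine($"{t.Value} -> {t.Id}");
    }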
 /// <summary>
 /// Encode a text to a list of tokens.
@@ -167,11 +242,11 @@ public IReadOnlyDictionary<string, int> Vocab
 /// <param name="addEndOfSentence">Indicate emitting the end of sentence token during the encoding.</param>
 /// <returns>The list of tokens generated from the text tokenization.</returns>
 /// <remarks>The input text has to be normalized before calling this method.</remarks>
- public IReadOnlyList<Token> Encode(ReadOnlySpan<char> text, bool addBeginOfSentence, bool addEndOfSentence)
+ private IReadOnlyList<Token> EncodeInternal(ReadOnlySpan<char> text, bool addBeginOfSentence, bool addEndOfSentence)
 {
 if (text.Length == 0)
 {
- return Array.Empty<Token>();
+ return [];
 }
 BpeSymbol[] symbols = ArrayPool<BpeSymbol>.Shared.Rent(text.Length);
@@ -306,16 +381,168 @@ revMerge is null ||
}
 /// <summary>
- /// Encode a text to a list of Ids and add them to the accumulatedIds list.
+ /// Encodes input text to token Ids.
 /// </summary>
 /// <param name="text">The text to encode.</param>
- /// <param name="accumulatedIds">The list of accumulated encoded Ids.</param>
+ /// <param name="considerPreTokenization">Indicate whether to consider pre-tokenization before tokenization.</param>
+ /// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
+ /// <returns>The list of encoded Ids.</returns>
+ public override IReadOnlyList<int> EncodeToIds(string text, bool considerPreTokenization = true, bool considerNormalization = true)
+ => EncodeToIds(text, Span<char>.Empty, AddBeginningOfSentence, AddEndOfSentence, considerNormalization, out _, out _);
+
+ /// <summary>
+ /// Encodes input text to token Ids.
+ /// </summary>
+ /// <param name="text">The text to encode.</param>
+ /// <param name="considerPreTokenization">Indicate whether to consider pre-tokenization before tokenization.</param>
+ /// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
+ /// <returns>The list of encoded Ids.</returns>
+ public override IReadOnlyList<int> EncodeToIds(ReadOnlySpan<char> text, bool considerPreTokenization = true, bool considerNormalization = true)
+ => EncodeToIds(null, text, AddBeginningOfSentence, AddEndOfSentence, considerNormalization, out _, out _);
+
+ /// <summary>
+ /// Encodes input text to token Ids up to maximum number of tokens.
+ /// </summary>
+ /// <param name="text">The text to encode.</param>
+ /// <param name="maxTokenCount">The maximum number of tokens to encode.</param>
+ /// <param name="normalizedString">If the tokenizer's normalization is enabled or <paramref name="considerNormalization"/> is false, this will be set to <paramref name="text"/> in its normalized form; otherwise, this value will be set to null.</param>
 /// <param name="textLength">The length of the text that encompasses the maximum encoded tokens.</param>
- /// <param name="maxTokens">The maximum number of tokens to encode.</param>
- /// <returns>The number of tokens that the input text will be encoded to.</returns>
- /// <remarks>The input text has to be normalized before calling this method.</remarks>
- public override int EncodeToIds(ReadOnlySpan<char> text, IList<int> accumulatedIds, out int textLength, int maxTokens = int.MaxValue)
- => EncodeToIds(text, AddBeginningOfSentence, AddEndOfSentence, accumulatedIds, out textLength, maxTokens);
+ /// <param name="considerPreTokenization">Indicate whether to consider pre-tokenization before tokenization.</param>
+ /// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
+ /// <returns>The list of encoded Ids.</returns>
+ public override IReadOnlyList<int> EncodeToIds(string text, int maxTokenCount, out string? normalizedString, out int textLength, bool considerPreTokenization = true, bool considerNormalization = true)
+ => EncodeToIds(text, Span<char>.Empty, AddBeginningOfSentence, AddEndOfSentence, considerNormalization, out normalizedString, out textLength, maxTokenCount);
+
+ /// <summary>
+ /// Encodes input text to token Ids up to maximum number of tokens.
+ /// </summary>
+ /// <param name="text">The text to encode.</param>
+ /// <param name="maxTokenCount">The maximum number of tokens to encode.</param>
+ /// <param name="normalizedString">If the tokenizer's normalization is enabled or <paramref name="considerNormalization"/> is false, this will be set to <paramref name="text"/> in its normalized form; otherwise, this value will be set to null.</param>
+ /// <param name="textLength">The length of the text that encompasses the maximum encoded tokens.</param>
+ /// <param name="considerPreTokenization">Indicate whether to consider pre-tokenization before tokenization.</param>
+ /// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
+ /// <returns>The list of encoded Ids.</returns>
+ public override IReadOnlyList<int> EncodeToIds(ReadOnlySpan<char> text, int maxTokenCount, out string? normalizedString, out int textLength, bool considerPreTokenization = true, bool considerNormalization = true)
+ => EncodeToIds(null, text, AddBeginningOfSentence, AddEndOfSentence, considerNormalization, out normalizedString, out textLength, maxTokenCount);
+
+ /// <summary>
+ /// Encodes input text to token Ids.
+ /// </summary>
+ /// <param name="text">The text to encode.</param>
+ /// <param name="addBeginningOfSentence">Indicate emitting the beginning of sentence token during the encoding.</param>
+ /// <param name="addEndOfSentence">Indicate emitting the end of sentence token during the encoding.</param>
+ /// <param name="considerPreTokenization">Indicate whether to consider pre-tokenization before tokenization.</param>
+ /// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
+ /// <returns>The list of encoded Ids.</returns>
+ public IReadOnlyList<int> EncodeToIds(string text, bool addBeginningOfSentence, bool addEndOfSentence, bool considerPreTokenization = true, bool considerNormalization = true)
+ => EncodeToIds(text, Span<char>.Empty, addBeginningOfSentence, addEndOfSentence, considerNormalization, out _, out _);
+
+ /// <summary>
+ /// Encodes input text to token Ids.
+ /// </summary>
+ /// <param name="text">The text to encode.</param>
+ /// <param name="addBeginningOfSentence">Indicate emitting the beginning of sentence token during the encoding.</param>
+ /// <param name="addEndOfSentence">Indicate emitting the end of sentence token during the encoding.</param>
+ /// <param name="considerPreTokenization">Indicate whether to consider pre-tokenization before tokenization.</param>
+ /// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
+ /// <returns>The list of encoded Ids.</returns>
+ public IReadOnlyList<int> EncodeToIds(ReadOnlySpan<char> text, bool addBeginningOfSentence, bool addEndOfSentence, bool considerPreTokenization = true, bool considerNormalization = true)
+ => EncodeToIds(null, text, addBeginningOfSentence, addEndOfSentence, considerNormalization, out _, out _);
+
+ /// <summary>
+ /// Encodes input text to token Ids up to maximum number of tokens.
+ /// </summary>
+ /// <param name="text">The text to encode.</param>
+ /// <param name="addBeginningOfSentence">Indicate emitting the beginning of sentence token during the encoding.</param>
+ /// <param name="addEndOfSentence">Indicate emitting the end of sentence token during the encoding.</param>
+ /// <param name="maxTokenCount">The maximum number of tokens to encode.</param>
+ /// <param name="normalizedString">If the tokenizer's normalization is enabled or <paramref name="considerNormalization"/> is false, this will be set to <paramref name="text"/> in its normalized form; otherwise, this value will be set to null.</param>
+ /// <param name="textLength">The length of the text that encompasses the maximum encoded tokens.</param>
+ /// <param name="considerPreTokenization">Indicate whether to consider pre-tokenization before tokenization.</param>
+ /// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
+ /// <returns>The list of encoded Ids.</returns>
+ public IReadOnlyList<int> EncodeToIds(string text, bool addBeginningOfSentence, bool addEndOfSentence, int maxTokenCount, out string? normalizedString, out int textLength, bool considerPreTokenization = true, bool considerNormalization = true)
+ => EncodeToIds(text, Span<char>.Empty, addBeginningOfSentence, addEndOfSentence, considerNormalization, out normalizedString, out textLength, maxTokenCount);
+
+ /// <summary>
+ /// Encodes input text to token Ids up to maximum number of tokens.
+ /// </summary>
+ /// <param name="text">The text to encode.</param>
+ /// <param name="addBeginningOfSentence">Indicate emitting the beginning of sentence token during the encoding.</param>
+ /// <param name="addEndOfSentence">Indicate emitting the end of sentence token during the encoding.</param>
+ /// <param name="maxTokenCount">The maximum number of tokens to encode.</param>
+ /// <param name="normalizedString">If the tokenizer's normalization is enabled or <paramref name="considerNormalization"/> is false, this will be set to <paramref name="text"/> in its normalized form; otherwise, this value will be set to null.</param>
+ /// <param name="textLength">The length of the text that encompasses the maximum encoded tokens.</param>
+ /// <param name="considerPreTokenization">Indicate whether to consider pre-tokenization before tokenization.</param>
+ /// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
+ /// <returns>The list of encoded Ids.</returns>
+ public IReadOnlyList<int> EncodeToIds(ReadOnlySpan<char> text, bool addBeginningOfSentence, bool addEndOfSentence, int maxTokenCount, out string? normalizedString, out int textLength, bool considerPreTokenization = true, bool considerNormalization = true)
+ => EncodeToIds(null, text, addBeginningOfSentence, addEndOfSentence, considerNormalization, out normalizedString, out textLength, maxTokenCount);
+
+ private IReadOnlyList<int> EncodeToIds(string? text, ReadOnlySpan<char> textSpan, bool addBeginningOfSentence, bool addEndOfSentence, bool considerNormalization, out string? normalizedString, out int textLength, int maxTokenCount = int.MaxValue)
+ {
+ if (maxTokenCount <= 0)
+ {
+ throw new ArgumentOutOfRangeException(nameof(maxTokenCount), "The maximum number of tokens must be greater than zero.");
+ }
+
+ if (string.IsNullOrEmpty(text) && textSpan.IsEmpty)
+ {
+ normalizedString = null;
+ textLength = 0;
+ return [];
+ }
+
+ return EncodeToIds(text is null ? textSpan : text.AsSpan(), addBeginningOfSentence, addEndOfSentence, considerNormalization, out normalizedString, out textLength, maxTokenCount);
+ }
+
+ /// <summary>
+ /// Encodes input text to token Ids up to maximum number of tokens.
+ /// </summary>
+ /// <param name="text">The text to encode.</param>
+ /// <param name="addBeginningOfSentence">Indicate emitting the beginning of sentence token during the encoding.</param>
+ /// <param name="addEndOfSentence">Indicate emitting the end of sentence token during the encoding.</param>
+ /// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
+ /// <param name="normalizedString">If the tokenizer's normalization is enabled or <paramref name="considerNormalization"/> is false, this will be set to <paramref name="text"/> in its normalized form; otherwise, this value will be set to null.</param>
+ /// <param name="textLength">The length of the text that encompasses the maximum encoded tokens.</param>
+ /// <param name="maxTokenCount">The maximum number of tokens to encode.</param>
+ /// <returns>The list of encoded Ids.</returns>
+ public IReadOnlyList<int> EncodeToIds(ReadOnlySpan<char> text, bool addBeginningOfSentence, bool addEndOfSentence, bool considerNormalization,
+ out string? normalizedString, out int textLength, int maxTokenCount = int.MaxValue)
+ {
+ if (maxTokenCount <= 0)
+ {
+ throw new ArgumentOutOfRangeException(nameof(maxTokenCount), "The maximum number of tokens must be greater than zero.");
+ }
+
+ if (text.IsEmpty)
+ {
+ normalizedString = null;
+ textLength = 0;
+ return [];
+ }
+
+ ReadOnlySpan<char> textToEncode;
+
+ if (considerNormalization && _normalizer is not null)
+ {
+ normalizedString = _normalizer.Normalize(text);
+ textToEncode = normalizedString.AsSpan();
+ }
+ else
+ {
+ normalizedString = null;
+ textToEncode = text;
+ }
+
+ List<int> ids = new();
+
+ EncodeToIds(textToEncode, addBeginningOfSentence, addEndOfSentence, ids, out textLength, maxTokenCount);
+
+ return ids;
+ }
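The out parameters make truncation explicit: textLength reports how much of the (possibly normalized) text the returned ids cover. A usage sketch, assuming an existing tokenizer instance:

    // Keep at most 10 tokens; textLength tells us where the cut landed.
    string input = "Some long input text";
    IReadOnlyList<int> ids = tokenizer.EncodeToIds(input,
        addBeginningOfSentence: true, addEndOfSentence: false,
        maxTokenCount: 10, out string? normalized, out int textLength);
    // The covered prefix is measured against the normalized text when normalization ran.
    string covered = (normalized ?? input).Substring(0, textLength);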
 /// <summary>
 /// Encode a text to a list of Ids and add them to the accumulatedIds list.
@@ -328,7 +555,7 @@ public override int EncodeToIds(ReadOnlySpan<char> text, IList<int> accumulatedI
 /// <param name="maxTokens">The maximum number of tokens to encode.</param>
 /// <returns>The number of tokens that the input text will be encoded to.</returns>
 /// <remarks>The input text has to be normalized before calling this method.</remarks>
- public int EncodeToIds(ReadOnlySpan<char> text, bool addBeginOfSentence, bool addEndOfSentence, IList<int> accumulatedIds, out int textLength, int maxTokens = int.MaxValue)
+ private int EncodeToIds(ReadOnlySpan<char> text, bool addBeginOfSentence, bool addEndOfSentence, IList<int> accumulatedIds, out int textLength, int maxTokens = int.MaxValue)
{
if (maxTokens <= 0)
{
@@ -378,7 +605,8 @@ public int EncodeToIds(ReadOnlySpan text, bool addBeginOfSentence, bool ad
{
if (!EncodeAsBytes(text.Slice(symbols[index].pieceSpan.Index, symbols[index].pieceSpan.Length), symbols[index].pieceSpan.Index, ref textLength))
{
- break;
+ ArrayPool<BpeSymbol>.Shared.Return(symbols);
+ return idsCount;
}
}
else
@@ -391,6 +619,7 @@ public int EncodeToIds(ReadOnlySpan text, bool addBeginOfSentence, bool ad
}
else
{
+ ArrayPool<BpeSymbol>.Shared.Return(symbols);
return idsCount;
}
}
@@ -510,21 +739,252 @@ revMerge is null ||
 /// Get the number of tokens that the input text will be encoded to.
 /// </summary>
 /// <param name="text">The text to encode.</param>
- /// <returns>The number of tokens that the input text will be encoded to.</returns>
- /// <param name="textLength">The length of the text that encompasses the maximum encoded tokens.</param>
- /// <param name="maxTokens">The maximum number of tokens to encode.</param>
- /// <returns>The number of tokens that the input text will be encoded to.</returns>
- /// <remarks>The input text has to be normalized before calling this method.</remarks>
- public override int CountTokens(ReadOnlySpan<char> text, out int textLength, int maxTokens = int.MaxValue) => CountTokens(text, AddBeginningOfSentence, AddEndOfSentence, out textLength, maxTokens);
+ /// <param name="considerPreTokenization">Indicate whether to consider pre-tokenization before tokenization.</param>
+ /// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
+ /// <returns>The number of token Ids that the input text will be encoded to.</returns>
+ public override int CountTokens(string text, bool considerPreTokenization = true, bool considerNormalization = true)
+ => CountTokens(text, Span<char>.Empty, AddBeginningOfSentence, AddEndOfSentence, considerNormalization, out _, out _);
 /// <summary>
 /// Get the number of tokens that the input text will be encoded to.
 /// </summary>
 /// <param name="text">The text to encode.</param>
- /// <param name="textIndex">Starting from this index to the end of the text will encompass the maximum encoded tokens.</param>
- /// <param name="maxTokens">The maximum number of tokens to encode.</param>
+ /// <param name="considerPreTokenization">Indicate whether to consider pre-tokenization before tokenization.</param>
+ /// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
+ /// <returns>The number of token Ids that the input text will be encoded to.</returns>
+ public override int CountTokens(ReadOnlySpan<char> text, bool considerPreTokenization = true, bool considerNormalization = true)
+ => CountTokens(null, text, AddBeginningOfSentence, AddEndOfSentence, considerNormalization, out _, out _);
+
+ /// <summary>
+ /// Find the index of the maximum encoding capacity from the start within the text without surpassing the token limit.
+ /// </summary>
+ /// <param name="text">The text to encode.</param>
+ /// <param name="maxTokenCount">The maximum token count to limit the encoding capacity.</param>
+ /// <param name="normalizedString">If the tokenizer's normalization is enabled or <paramref name="considerNormalization"/> is false, this will be set to <paramref name="text"/> in its normalized form; otherwise, this value will be set to null.</param>
+ /// <param name="tokenCount">The number of tokens that were generated, which will be at most the maximum token count.</param>
+ /// <param name="considerPreTokenization">Indicate whether to consider pre-tokenization before tokenization.</param>
+ /// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
+ /// <returns>
+ /// The index of the maximum encoding capacity within the processed text without surpassing the token limit.
+ /// It represents the index immediately following the last character to be included. In cases where no tokens fit, the result will be 0; conversely,
+ /// if all tokens fit, the result will be the length of the text or of the <paramref name="normalizedString"/> if normalization is enabled.
+ /// </returns>
+ public override int IndexOfTokenCount(string text, int maxTokenCount, out string? normalizedString, out int tokenCount, bool considerPreTokenization = true, bool considerNormalization = true)
+ {
+ tokenCount = CountTokens(text, Span<char>.Empty, AddBeginningOfSentence, AddEndOfSentence, considerNormalization, out normalizedString, out int textLength, maxTokenCount);
+ return textLength;
+ }
+
+ /// <summary>
+ /// Find the index of the maximum encoding capacity from the start within the text without surpassing the token limit.
+ /// </summary>
+ /// <param name="text">The text to encode.</param>
+ /// <param name="maxTokenCount">The maximum token count to limit the encoding capacity.</param>
+ /// <param name="normalizedString">If the tokenizer's normalization is enabled or <paramref name="considerNormalization"/> is false, this will be set to <paramref name="text"/> in its normalized form; otherwise, this value will be set to null.</param>
+ /// <param name="tokenCount">The number of tokens that were generated, which will be at most the maximum token count.</param>
+ /// <param name="considerPreTokenization">Indicate whether to consider pre-tokenization before tokenization.</param>
+ /// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
+ /// <returns>
+ /// The index of the maximum encoding capacity within the processed text without surpassing the token limit.
+ /// It represents the index immediately following the last character to be included. In cases where no tokens fit, the result will be 0; conversely,
+ /// if all tokens fit, the result will be the length of the text or of the <paramref name="normalizedString"/> if normalization is enabled.
+ /// </returns>
+ public override int IndexOfTokenCount(ReadOnlySpan<char> text, int maxTokenCount, out string? normalizedString, out int tokenCount, bool considerPreTokenization = true, bool considerNormalization = true)
+ {
+ tokenCount = CountTokens(null, text, AddBeginningOfSentence, AddEndOfSentence, considerNormalization, out normalizedString, out int textLength, maxTokenCount);
+ return textLength;
+ }
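IndexOfTokenCount is the building block for head-truncation: the returned index is the exclusive end of the prefix that fits the budget. A sketch, assuming an existing tokenizer instance:

    string prompt = "...";
    int index = tokenizer.IndexOfTokenCount(prompt, 512,
        out string? normalized, out int tokenCount);
    // `head` encodes to at most 512 tokens; slice the normalized text when normalization ran.
    string head = (normalized ?? prompt).Substring(0, index);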
+
+ /// <summary>
+ /// Get the number of tokens that the input text will be encoded to.
+ /// </summary>
+ /// <param name="text">The text to encode.</param>
+ /// <param name="addBeginningOfSentence">Indicate emitting the beginning of sentence token during the encoding.</param>
+ /// <param name="addEndOfSentence">Indicate emitting the end of sentence token during the encoding.</param>
+ /// <param name="considerPreTokenization">Indicate whether to consider pre-tokenization before tokenization.</param>
+ /// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
+ /// <returns>The number of token Ids that the input text will be encoded to.</returns>
+ public int CountTokens(string text, bool addBeginningOfSentence, bool addEndOfSentence, bool considerPreTokenization = true, bool considerNormalization = true)
+ => CountTokens(text, Span<char>.Empty, addBeginningOfSentence, addEndOfSentence, considerNormalization, out _, out _);
+
+ /// <summary>
+ /// Get the number of tokens that the input text will be encoded to.
+ /// </summary>
+ /// <param name="text">The text to encode.</param>
+ /// <param name="addBeginningOfSentence">Indicate emitting the beginning of sentence token during the encoding.</param>
+ /// <param name="addEndOfSentence">Indicate emitting the end of sentence token during the encoding.</param>
+ /// <param name="considerPreTokenization">Indicate whether to consider pre-tokenization before tokenization.</param>
+ /// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
+ /// <returns>The number of token Ids that the input text will be encoded to.</returns>
+ public int CountTokens(ReadOnlySpan<char> text, bool addBeginningOfSentence, bool addEndOfSentence, bool considerPreTokenization = true, bool considerNormalization = true)
+ => CountTokens(null, text, addBeginningOfSentence, addEndOfSentence, considerNormalization, out _, out _);
+
+ /// <summary>
+ /// Find the index of the maximum encoding capacity from the start within the text without surpassing the token limit.
+ /// </summary>
+ /// <param name="text">The text to encode.</param>
+ /// <param name="addBeginningOfSentence">Indicate emitting the beginning of sentence token during the encoding.</param>
+ /// <param name="addEndOfSentence">Indicate emitting the end of sentence token during the encoding.</param>
+ /// <param name="maxTokenCount">The maximum token count to limit the encoding capacity.</param>
+ /// <param name="normalizedString">If the tokenizer's normalization is enabled or <paramref name="considerNormalization"/> is false, this will be set to <paramref name="text"/> in its normalized form; otherwise, this value will be set to null.</param>
+ /// <param name="tokenCount">The number of tokens that were generated, which will be at most the maximum token count.</param>
+ /// <param name="considerPreTokenization">Indicate whether to consider pre-tokenization before tokenization.</param>
+ /// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
+ /// <returns>
+ /// The index of the maximum encoding capacity within the processed text without surpassing the token limit.
+ /// It represents the index immediately following the last character to be included. In cases where no tokens fit, the result will be 0; conversely,
+ /// if all tokens fit, the result will be the length of the text or of the <paramref name="normalizedString"/> if normalization is enabled.
+ /// </returns>
+ public int IndexOfTokenCount(string text, bool addBeginningOfSentence, bool addEndOfSentence, int maxTokenCount, out string? normalizedString, out int tokenCount, bool considerPreTokenization = true, bool considerNormalization = true)
+ {
+ tokenCount = CountTokens(text, Span<char>.Empty, addBeginningOfSentence, addEndOfSentence, considerNormalization, out normalizedString, out int textLength, maxTokenCount);
+ return textLength;
+ }
+
+ /// <summary>
+ /// Find the index of the maximum encoding capacity from the start within the text without surpassing the token limit.
+ /// </summary>
+ /// <param name="text">The text to encode.</param>
+ /// <param name="addBeginningOfSentence">Indicate emitting the beginning of sentence token during the encoding.</param>
+ /// <param name="addEndOfSentence">Indicate emitting the end of sentence token during the encoding.</param>
+ /// <param name="maxTokenCount">The maximum token count to limit the encoding capacity.</param>
+ /// <param name="normalizedString">If the tokenizer's normalization is enabled or <paramref name="considerNormalization"/> is false, this will be set to <paramref name="text"/> in its normalized form; otherwise, this value will be set to null.</param>
+ /// <param name="tokenCount">The number of tokens that were generated, which will be at most the maximum token count.</param>
+ /// <param name="considerPreTokenization">Indicate whether to consider pre-tokenization before tokenization.</param>
+ /// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
+ /// <returns>
+ /// The index of the maximum encoding capacity within the processed text without surpassing the token limit.
+ /// It represents the index immediately following the last character to be included. In cases where no tokens fit, the result will be 0; conversely,
+ /// if all tokens fit, the result will be the length of the text or of the <paramref name="normalizedString"/> if normalization is enabled.
+ /// </returns>
+ public int IndexOfTokenCount(ReadOnlySpan<char> text, bool addBeginningOfSentence, bool addEndOfSentence, int maxTokenCount, out string? normalizedString, out int tokenCount, bool considerPreTokenization = true, bool considerNormalization = true)
+ {
+ tokenCount = CountTokens(null, text, addBeginningOfSentence, addEndOfSentence, considerNormalization, out normalizedString, out int textLength, maxTokenCount);
+ return textLength;
+ }
+
+ private int CountTokens(string? text, ReadOnlySpan<char> textSpan, bool addBeginningOfSentence, bool addEndOfSentence, bool considerNormalization, out string? normalizedString, out int textLength, int maxTokenCount = int.MaxValue)
+ => CountTokens(text is null ? textSpan : text.AsSpan(), addBeginningOfSentence, addEndOfSentence, considerNormalization, out normalizedString, out textLength, maxTokenCount);
+
+ /// <summary>
+ /// Get the number of tokens that the input text will be encoded to.
+ /// </summary>
+ /// <param name="text">The text to encode.</param>
+ /// <param name="addBeginningOfSentence">Indicate emitting the beginning of sentence token during the encoding.</param>
+ /// <param name="addEndOfSentence">Indicate emitting the end of sentence token during the encoding.</param>
+ /// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
+ /// <param name="normalizedString">If the tokenizer's normalization is enabled or <paramref name="considerNormalization"/> is false, this will be set to <paramref name="text"/> in its normalized form; otherwise, this value will be set to null.</param>
+ /// <param name="textLength">The length of the text that encompasses the maximum encoded tokens.</param>
+ /// <param name="maxTokenCount">The maximum number of tokens to encode.</param>
 /// <returns>The number of tokens that the input text will be encoded to.</returns>
- public override int CountTokensFromEnd(ReadOnlySpan<char> text, out int textIndex, int maxTokens = int.MaxValue) => CountTokensFromEnd(text, AddBeginningOfSentence, AddEndOfSentence, out textIndex, maxTokens);
+ public int CountTokens(ReadOnlySpan<char> text, bool addBeginningOfSentence, bool addEndOfSentence, bool considerNormalization, out string? normalizedString, out int textLength, int maxTokenCount = int.MaxValue)
+ {
+ if (maxTokenCount <= 0)
+ {
+ throw new ArgumentOutOfRangeException(nameof(maxTokenCount), "The maximum number of tokens must be greater than zero.");
+ }
+
+ if (text.IsEmpty)
+ {
+ normalizedString = null;
+ textLength = 0;
+ return 0;
+ }
+
+ ReadOnlySpan<char> textToEncode;
+ if (considerNormalization && _normalizer is not null)
+ {
+ normalizedString = _normalizer.Normalize(text);
+ textToEncode = normalizedString.AsSpan();
+ }
+ else
+ {
+ normalizedString = null;
+ textToEncode = text;
+ }
+
+ return CountTokens(textToEncode, addBeginningOfSentence, addEndOfSentence, out textLength, maxTokenCount);
+ }
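Callers that pre-normalize can skip the normalizer to avoid a second pass. A sketch of both modes (the counts should match for the same input and flags, assuming an existing tokenizer whose Normalizer is non-null):

    int n1 = tokenizer.CountTokens("Hello world");                      // normalizes internally
    string pre = tokenizer.Normalizer!.Normalize("Hello world");        // normalize once up front...
    int n2 = tokenizer.CountTokens(pre, considerNormalization: false);  // ...then count the raw text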
+
+ /// <summary>
+ /// Find the index of the maximum encoding capacity from the end within the text without surpassing the token limit.
+ /// </summary>
+ /// <param name="text">The text to encode.</param>
+ /// <param name="maxTokenCount">The maximum token count to limit the encoding capacity.</param>
+ /// <param name="normalizedString">If the tokenizer's normalization is enabled or <paramref name="considerNormalization"/> is false, this will be set to <paramref name="text"/> in its normalized form; otherwise, this value will be set to null.</param>
+ /// <param name="tokenCount">The number of tokens that were generated, which will be at most the maximum token count.</param>
+ /// <param name="considerPreTokenization">Indicate whether to consider pre-tokenization before tokenization.</param>
+ /// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
+ /// <returns>
+ /// The start index of the maximum encoding capacity within the processed text without surpassing the token limit.
+ /// It represents the index at the first character to be included. In cases where no tokens fit, the result will be the length of the text or of the <paramref name="normalizedString"/> if normalization is enabled;
+ /// conversely, if all tokens fit, the result will be 0.
+ /// </returns>
+ public override int LastIndexOfTokenCount(string text, int maxTokenCount, out string? normalizedString, out int tokenCount, bool considerPreTokenization = true, bool considerNormalization = true)
+ => LastIndexOf(text, Span<char>.Empty, maxTokenCount, considerNormalization, out normalizedString, out tokenCount);
+
+ /// <summary>
+ /// Find the index of the maximum encoding capacity from the end within the text without surpassing the token limit.
+ /// </summary>
+ /// <param name="text">The text to encode.</param>
+ /// <param name="maxTokenCount">The maximum token count to limit the encoding capacity.</param>
+ /// <param name="normalizedString">If the tokenizer's normalization is enabled or <paramref name="considerNormalization"/> is false, this will be set to <paramref name="text"/> in its normalized form; otherwise, this value will be set to null.</param>
+ /// <param name="tokenCount">The number of tokens that were generated, which will be at most the maximum token count.</param>
+ /// <param name="considerPreTokenization">Indicate whether to consider pre-tokenization before tokenization.</param>
+ /// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
+ /// <returns>
+ /// The start index of the maximum encoding capacity within the processed text without surpassing the token limit.
+ /// It represents the index at the first character to be included. In cases where no tokens fit, the result will be the length of the <paramref name="text"/>; conversely, if all tokens fit, the result will be 0.
+ /// </returns>
+ public override int LastIndexOfTokenCount(ReadOnlySpan<char> text, int maxTokenCount, out string? normalizedString, out int tokenCount, bool considerPreTokenization = true, bool considerNormalization = true)
+ => LastIndexOf(null, text, maxTokenCount, considerNormalization, out normalizedString, out tokenCount);
+
+ private int LastIndexOf(string? text, ReadOnlySpan<char> textSpan, int maxTokenCount, bool considerNormalization, out string? normalizedString, out int tokenCount)
+ => LastIndexOfTokenCount(text is null ? textSpan : text.AsSpan(), maxTokenCount, AddBeginningOfSentence, AddEndOfSentence, considerNormalization, out normalizedString, out tokenCount);
+
+ /// <summary>
+ /// Find the index of the maximum encoding capacity from the end within the text without surpassing the token limit.
+ /// </summary>
+ /// <param name="text">The text to encode.</param>
+ /// <param name="maxTokenCount">The maximum token count to limit the encoding capacity.</param>
+ /// <param name="addBeginningOfSentence">Indicate emitting the beginning of sentence token during the encoding.</param>
+ /// <param name="addEndOfSentence">Indicate emitting the end of sentence token during the encoding.</param>
+ /// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
+ /// <param name="normalizedString">If the tokenizer's normalization is enabled or <paramref name="considerNormalization"/> is false, this will be set to <paramref name="text"/> in its normalized form; otherwise, this value will be set to null.</param>
+ /// <param name="tokenCount">The number of tokens that were generated, which will be at most the maximum token count.</param>
+ /// <returns>
+ /// The start index of the maximum encoding capacity within the processed text without surpassing the token limit.
+ /// It represents the index at the first character to be included. In cases where no tokens fit, the result will be the length of the <paramref name="text"/>; conversely, if all tokens fit, the result will be 0.
+ /// </returns>
+ public int LastIndexOfTokenCount(ReadOnlySpan<char> text, int maxTokenCount, bool addBeginningOfSentence, bool addEndOfSentence, bool considerNormalization, out string? normalizedString, out int tokenCount)
+ {
+ if (maxTokenCount <= 0)
+ {
+ throw new ArgumentOutOfRangeException(nameof(maxTokenCount), "The max token count must be greater than 0.");
+ }
+
+ if (text.IsEmpty)
+ {
+ normalizedString = null;
+ tokenCount = 0;
+ return 0;
+ }
+
+ ReadOnlySpan<char> textToEncode;
+ if (considerNormalization && _normalizer is not null)
+ {
+ normalizedString = _normalizer.Normalize(text);
+ textToEncode = normalizedString.AsSpan();
+ }
+ else
+ {
+ normalizedString = null;
+ textToEncode = text;
+ }
+
+ tokenCount = CountTokensFromEnd(textToEncode, addBeginningOfSentence, addEndOfSentence, out int textIndex, maxTokenCount);
+ return textIndex;
+ }
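LastIndexOfTokenCount mirrors IndexOfTokenCount for tail-truncation: the returned index is the inclusive start of the suffix that fits the budget. A sketch, assuming an existing tokenizer instance:

    string transcript = "...";
    int start = tokenizer.LastIndexOfTokenCount(transcript, 512,
        out string? normalized, out int tokenCount);
    // `tail` holds the last <= 512 tokens' worth of (possibly normalized) text.
    string tail = (normalized ?? transcript).Substring(start);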
 /// <summary>
 /// Get the number of tokens that the input text will be encoded to.
@@ -536,7 +996,7 @@ revMerge is null ||
 /// <param name="maxTokens">The maximum number of tokens to encode.</param>
 /// <returns>The number of tokens that the input text will be encoded to.</returns>
 /// <remarks>The input text has to be normalized before calling this method.</remarks>
- public int CountTokens(ReadOnlySpan<char> text, bool addBeginOfSentence, bool addEndOfSentence, out int textLength, int maxTokens = int.MaxValue)
+ private int CountTokens(ReadOnlySpan<char> text, bool addBeginOfSentence, bool addEndOfSentence, out int textLength, int maxTokens = int.MaxValue)
{
textLength = 0;
if (text.IsEmpty)
@@ -705,7 +1165,7 @@ revMerge is null ||
 /// <param name="maxTokens">The maximum number of tokens to encode.</param>
 /// <returns>The number of tokens that the input text will be encoded to.</returns>
 /// <remarks>The input text has to be normalized before calling this method.</remarks>
- public int CountTokensFromEnd(ReadOnlySpan<char> text, bool addBeginOfSentence, bool addEndOfSentence, out int textIndex, int maxTokens = int.MaxValue)
+ private int CountTokensFromEnd(ReadOnlySpan<char> text, bool addBeginOfSentence, bool addEndOfSentence, out int textIndex, int maxTokens = int.MaxValue)
{
textIndex = text.Length;
if (text.IsEmpty)
@@ -944,7 +1404,7 @@ revMerge is null ||
else if (_vocabReverse.TryGetValue(enumerator.Current, out string? token))
{
// escape the dummy prefix if needed.
- sb.Append(AddDummyPrefix && !TreatWhitespaceAsSuffix && token.Length > 0 && token[0] == LlamaNormalizer.DummyPrefix ?
+ sb.Append(AddDummyPrefix && !TreatWhitespaceAsSuffix && token.Length > 0 && token[0] == SentencePieceNormalizer.DummyPrefix ?
token.AsSpan(1) :
token.AsSpan());
}
@@ -999,7 +1459,7 @@ revMerge is null ||
FlushBytes(ref bytesCount, ref bytesPoolArray, ref charPoolArray, ref sb);
}
- if (AddDummyPrefix && TreatWhitespaceAsSuffix && sb.Length > 0 && sb[sb.Length - 1] == LlamaNormalizer.DummyPrefix)
+ if (AddDummyPrefix && TreatWhitespaceAsSuffix && sb.Length > 0 && sb[sb.Length - 1] == SentencePieceNormalizer.DummyPrefix)
{
sb.RemoveLastChar();
}
@@ -1014,7 +1474,7 @@ revMerge is null ||
ArrayPool.Shared.Return(charPoolArray);
}
- return sb.ToString(LlamaNormalizer.DummyPrefix, ' ');
+ return sb.ToString(SentencePieceNormalizer.DummyPrefix, ' ');
static void FlushBytes(ref int bytesCount, ref byte[]? bytesPoolArray, ref char[]? charPoolArray, ref ValueStringBuilder sb)
{
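The rename from LlamaNormalizer to SentencePieceNormalizer does not change decode behavior: the dummy prefix character (U+2581, '▁') is still stripped when it leads the output and mapped back to a space elsewhere. An illustrative sketch (the ids are hypothetical, not taken from any real vocab):

    // Hypothetical ids for "▁Hello" and "▁world" in a SentencePiece vocab.
    string text = tokenizer.Decode(new[] { 15043, 3186 });
    // "▁Hello▁world" decodes to "Hello world": the leading dummy prefix is dropped
    // (AddDummyPrefix) and interior U+2581 characters become spaces.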
diff --git a/src/Microsoft.ML.Tokenizers/Model/Tiktoken.cs b/src/Microsoft.ML.Tokenizers/Model/Tiktoken.cs
index 8817ce6616..91b1d2cf9d 100644
--- a/src/Microsoft.ML.Tokenizers/Model/Tiktoken.cs
+++ b/src/Microsoft.ML.Tokenizers/Model/Tiktoken.cs
@@ -19,9 +19,9 @@
namespace Microsoft.ML.Tokenizers
{
 /// <summary>
- /// Represent the rapid Byte Pair Encoding model commonly referred to as Tiktoken.
+ /// Represent the rapid Byte Pair Encoding tokenizer.
 /// </summary>
- public sealed partial class Tiktoken : Model
+ public sealed partial class Tiktoken : Tokenizer
{
 private readonly Dictionary<ReadOnlyMemory<byte>, int> _encoder;
 private readonly Dictionary<int, ReadOnlyMemory<byte>> _decoder;
@@ -29,46 +29,56 @@ public sealed partial class Tiktoken : Model
 private readonly Dictionary<StringSpanOrdinalKey, int> _vocab;
 private IReadOnlyDictionary<string, int>? _vocabOriginal;
private const int MaxWordLengthToCache = 15;
+ private readonly PreTokenizer? _preTokenizer;
+ private readonly Normalizer? _normalizer;
 /// <summary>
- /// Create a new Tiktoken tokenizer's model object.
+ /// Create a new Tiktoken tokenizer object.
 /// </summary>
 /// <param name="vocabFilePath">The path to the BPE vocab file.</param>
+ /// <param name="preTokenizer">The pre-tokenizer to use.</param>
 /// <param name="specialTokens">The dictionary mapping special tokens to Ids.</param>
+ /// <param name="normalizer">The normalizer to use.</param>
 /// <param name="cacheSize">The size of the cache to use.</param>
 /// <exception cref="ArgumentNullException">Thrown when <paramref name="vocabFilePath"/> is null or empty.</exception>
 /// <exception cref="InvalidOperationException">Thrown when failed to load the BPE vocab file.</exception>
- public Tiktoken(string vocabFilePath, IReadOnlyDictionary<string, int>? specialTokens = null, int cacheSize = LruCache.DefaultCacheSize) :
- this(string.IsNullOrEmpty(vocabFilePath) ? throw new ArgumentNullException(nameof(vocabFilePath)) : File.OpenRead(vocabFilePath), specialTokens, cacheSize, disposeStream: true)
+ public Tiktoken(string vocabFilePath, PreTokenizer? preTokenizer, IReadOnlyDictionary