# Multicharacter Token Support (#193)

Open · wants to merge 1 commit into `master`
17 changes: 15 additions & 2 deletions README.md
@@ -17,6 +17,16 @@ git clone --recursive https://github.com/parlance/ctcdecode.git
cd ctcdecode && pip install .
```

For a faster parallel build, set `MAX_JOBS` to the number of CPUs available (replace `<N>` below):

```bash
# get the code
git clone --recursive https://github.com/parlance/ctcdecode.git
cd ctcdecode
MAX_JOBS=<N> python3 setup.py build
python3 setup.py install
```

## How to Use

```python
@@ -32,7 +42,8 @@ decoder = CTCBeamDecoder(
beam_width=100,
num_processes=4,
blank_id=0,
log_probs_input=False
log_probs_input=False,
is_token_based=False
)
beam_results, beam_scores, timesteps, out_lens = decoder.decode(output)
```
@@ -52,6 +63,7 @@ beam_results, beam_scores, timesteps, out_lens = decoder.decode(output)
- `num_processes` Parallelize the batch using num_processes workers. You probably want to pass the number of cpus your computer has. You can find this in python with `import multiprocessing` then `n_cpus = multiprocessing.cpu_count()`. Default 4.
- `blank_id` This should be the index of the CTC blank token (probably 0).
- `log_probs_input` If your outputs have passed through a softmax and represent probabilities, this should be False; if they have passed through a LogSoftmax and represent log probabilities, pass True. If you're not sure, run `print(output[0][0].sum())`: if it's a negative number you probably have log probabilities and should pass True; if it sums to ~1.0 you should pass False. Default False.
- `is_token_based` Set this to True if your language model is built over custom tokens (e.g., BPE units) rather than words or characters; see the sketch after this list. Default False.
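
As a rough illustration of the new flag, a token-based decoder might be constructed as follows. The vocabulary, LM path, and network output below are hypothetical placeholders, not files shipped with this repository:

```python
import torch
from ctcdecode import CTCBeamDecoder

# Hypothetical BPE vocabulary; index 0 is the CTC blank
labels = ["_", "▁the", "▁cat", "▁sat", "ing", "ed"]

decoder = CTCBeamDecoder(
    labels,
    model_path="bpe_lm.arpa",  # assumed: a KenLM model trained on the same BPE tokens
    alpha=0.5,
    beta=1.0,
    beam_width=100,
    blank_id=0,
    is_token_based=True,
)

# Fake network output: BATCHSIZE x N_TIMESTEPS x N_LABELS probabilities
output = torch.rand(1, 50, len(labels)).softmax(dim=2)
beam_results, beam_scores, timesteps, out_lens = decoder.decode(output)
```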

### Inputs to the `decode` method
- `output` should be the output activations from your model. If your output has passed through a SoftMax layer, you shouldn't need to alter it (except maybe to transpose). If it contains raw logits, pass it through an additional `torch.nn.functional.softmax` first; if it contains log probabilities (e.g., from a LogSoftmax), construct the decoder with `log_probs_input=True` instead. Your output should be BATCHSIZE x N_TIMESTEPS x N_LABELS, so you may need to transpose it before passing it to the decoder; a sketch follows below. Note that if you pass things in the wrong order, the beam search will probably still run; you'll just get back nonsense results.
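
For example, a time-major model output would be converted like this before decoding. The shapes and the softmax step are assumptions for illustration, not requirements of a specific model:

```python
import torch

# Assumed raw logits in time-major order: N_TIMESTEPS x BATCHSIZE x N_LABELS
logits = torch.randn(50, 1, 6)

probs = torch.nn.functional.softmax(logits, dim=2)  # logits -> probabilities
probs = probs.transpose(0, 1)  # -> BATCHSIZE x N_TIMESTEPS x N_LABELS, as decode() expects

beam_results, beam_scores, timesteps, out_lens = decoder.decode(probs)
```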
@@ -79,7 +91,8 @@ decoder = OnlineCTCBeamDecoder(
beam_width=100,
num_processes=4,
blank_id=0,
log_probs_input=False
log_probs_input=False,
is_token_based=False
)

state1 = ctcdecode.DecoderState(decoder)
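
From here the stream is decoded chunk by chunk. A hedged sketch based on the `decode(probs, states, is_eos_s)` signature in this PR; the chunk contents, shapes, and end-of-stream flags are illustrative assumptions:

```python
import torch

# Two hypothetical probability chunks for a single stream: 1 x T x N_LABELS each
chunk1 = torch.rand(1, 20, len(labels)).softmax(dim=2)
chunk2 = torch.rand(1, 20, len(labels)).softmax(dim=2)

# Intermediate chunk: the stream is not finished yet
decoder.decode(chunk1, [state1], [False])

# Final chunk: mark end of stream so the decoder flushes its results
beam_results, beam_scores, timesteps, out_lens = decoder.decode(chunk2, [state1], [True])
```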
16 changes: 14 additions & 2 deletions ctcdecode/__init__.py
@@ -21,6 +21,7 @@ class CTCBeamDecoder(object):
num_processes (int): Parallelize the batch using num_processes workers.
blank_id (int): Index of the CTC blank token (probably 0) used when training your model.
log_probs_input (bool): False if your model has passed through a softmax and output probabilities sum to 1.
is_token_based (bool): True if the language model is built over custom tokens (e.g., BPE units).
"""

def __init__(
@@ -35,6 +36,7 @@ def __init__(
num_processes=4,
blank_id=0,
log_probs_input=False,
is_token_based=False,
):
self.cutoff_top_n = cutoff_top_n
self._beam_width = beam_width
@@ -44,9 +46,10 @@
self._num_labels = len(labels)
self._blank_id = blank_id
self._log_probs = 1 if log_probs_input else 0
self._is_token_based = 1 if is_token_based else 0
if model_path:
self._scorer = ctc_decode.paddle_get_scorer(
alpha, beta, model_path.encode(), self._labels, self._num_labels
alpha, beta, model_path.encode(), self._labels, self._num_labels, self._is_token_based
)
self._cutoff_prob = cutoff_prob

@@ -124,6 +127,9 @@ def decode(self, probs, seq_lens=None):

def character_based(self):
return ctc_decode.is_character_based(self._scorer) if self._scorer else None

def token_based(self):
return ctc_decode.is_token_based(self._scorer) if self._scorer else None

def max_order(self):
return ctc_decode.get_max_order(self._scorer) if self._scorer else None
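
For illustration, the new accessor pairs with the existing `character_based()` to sanity-check how a scorer was loaded. A hypothetical snippet, assuming `decoder` was constructed with a `model_path`:

```python
# Hypothetical check after constructing a decoder with an LM attached
if decoder.token_based():
    print("LM is scored over whole tokens (e.g., BPE units)")
elif decoder.character_based():
    print("LM is scored character by character")
```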
@@ -158,6 +164,7 @@ class OnlineCTCBeamDecoder(object):
num_processes (int): Parallelize the batch using num_processes workers.
blank_id (int): Index of the CTC blank token (probably 0) used when training your model.
log_probs_input (bool): False if your model has passed through a softmax and output probabilities sum to 1.
is_token_based (bool): True if the language model is built over custom tokens (e.g., BPE units).
"""
def __init__(
self,
@@ -171,6 +178,7 @@ def __init__(
num_processes=4,
blank_id=0,
log_probs_input=False,
is_token_based=False,
):
self._cutoff_top_n = cutoff_top_n
self._beam_width = beam_width
@@ -180,9 +188,10 @@
self._num_labels = len(labels)
self._blank_id = blank_id
self._log_probs = 1 if log_probs_input else 0
self._is_token_based = 1 if is_token_based else 0
if model_path:
self._scorer = ctc_decode.paddle_get_scorer(
alpha, beta, model_path.encode(), self._labels, self._num_labels
alpha, beta, model_path.encode(), self._labels, self._num_labels, self._is_token_based
)
self._cutoff_prob = cutoff_prob

@@ -240,6 +249,9 @@ def decode(self, probs, states, is_eos_s, seq_lens=None):
def character_based(self):
return ctc_decode.is_character_based(self._scorer) if self._scorer else None

def token_based(self):
return ctc_decode.is_token_based(self._scorer) if self._scorer else None

def max_order(self):
return ctc_decode.get_max_order(self._scorer) if self._scorer else None

10 changes: 8 additions & 2 deletions ctcdecode/src/binding.cpp
@@ -144,8 +144,9 @@ void* paddle_get_scorer(double alpha,
double beta,
const char* lm_path,
vector<std::string> new_vocab,
int vocab_size) {
Scorer* scorer = new Scorer(alpha, beta, lm_path, new_vocab);
int vocab_size,
bool is_token_based) {
Scorer* scorer = new Scorer(alpha, beta, lm_path, new_vocab, is_token_based);
return static_cast<void*>(scorer);
}

@@ -272,6 +273,10 @@ int is_character_based(void *scorer){
Scorer *ext_scorer = static_cast<Scorer *>(scorer);
return ext_scorer->is_character_based();
}
int is_token_based(void *scorer){
Scorer *ext_scorer = static_cast<Scorer *>(scorer);
return ext_scorer->is_token_based();
}
size_t get_max_order(void *scorer){
Scorer *ext_scorer = static_cast<Scorer *>(scorer);
return ext_scorer->get_max_order();
@@ -293,6 +298,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("paddle_get_scorer", &paddle_get_scorer, "paddle_get_scorer");
m.def("paddle_release_scorer", &paddle_release_scorer, "paddle_release_scorer");
m.def("is_character_based", &is_character_based, "is_character_based");
m.def("is_token_based", &is_token_based, "is_token_based");
m.def("get_max_order", &get_max_order, "get_max_order");
m.def("get_dict_size", &get_dict_size, "get_max_order");
m.def("reset_params", &reset_params, "reset_params");
4 changes: 3 additions & 1 deletion ctcdecode/src/binding.h
@@ -34,7 +34,8 @@ void* paddle_get_scorer(double alpha,
double beta,
const char* lm_path,
std::vector<std::string> labels,
int vocab_size);
int vocab_size,
bool is_token_based);


void* paddle_get_decoder_state(const std::vector<std::string> &vocabulary,
@@ -50,6 +51,7 @@ void paddle_release_state(void* state);


int is_character_based(void *scorer);
int is_token_based(void *scorer);
size_t get_max_order(void *scorer);
size_t get_dict_size(void *scorer);
void reset_params(void *scorer, double alpha, double beta);
8 changes: 4 additions & 4 deletions ctcdecode/src/ctc_beam_search_decoder.cpp
@@ -43,7 +43,7 @@ DecoderState::DecoderState(const std::vector<std::string> &vocabulary,
root.score = root.log_prob_b_prev = 0.0;
prefixes.push_back(&root);

if (ext_scorer != nullptr && !ext_scorer->is_character_based()) {
if (ext_scorer != nullptr && !(ext_scorer->is_character_based() || ext_scorer->is_token_based())) {
auto fst_dict = static_cast<fst::StdVectorFst *>(ext_scorer->dictionary);
fst::StdVectorFst *dict_ptr = fst_dict->Copy(true);
root.set_dictionary(dict_ptr);
@@ -119,10 +119,10 @@ DecoderState::next(const std::vector<std::vector<double>> &probs_seq)

// language model scoring
if (ext_scorer != nullptr &&
(c == space_id || ext_scorer->is_character_based())) {
(c == space_id || ext_scorer->is_character_based() || ext_scorer->is_token_based())) {
PathTrie *prefix_to_score = nullptr;
// skip scoring the space
if (ext_scorer->is_character_based()) {
if (ext_scorer->is_character_based() || ext_scorer->is_token_based()) {
prefix_to_score = prefix_new;
} else {
prefix_to_score = prefix;
@@ -171,7 +171,7 @@ DecoderState::decode()
}

// score the last word of each prefix that doesn't end with space
if (ext_scorer != nullptr && !ext_scorer->is_character_based()) {
if (ext_scorer != nullptr && !(ext_scorer->is_character_based() || ext_scorer->is_token_based())) {
for (size_t i = 0; i < beam_size && i < prefixes_copy.size(); ++i) {
auto prefix = prefixes_copy[i];
if (!prefix->is_empty() && prefix->character != space_id) {
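
In plain terms, the decoder changes above make a token-based scorer behave like a character-based one: the LM is queried every time a label is appended to a prefix, rather than only at space (word-boundary) labels, and the FST dictionary used for word-based decoding is skipped. A rough Python paraphrase of the gating condition in `DecoderState::next` (names are illustrative, not the C++ API):

```python
def should_apply_lm(scorer, emitted_label, space_id):
    # Paraphrase of the C++ condition; illustrative only
    if scorer is None:
        return False
    return (
        emitted_label == space_id     # word-based LM: score at word boundaries
        or scorer.is_character_based  # character LM: score on every label
        or scorer.is_token_based      # token LM (this PR): score on every label
    )
```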
32 changes: 21 additions & 11 deletions ctcdecode/src/scorer.cpp
@@ -16,12 +16,14 @@ using namespace lm::ngram;
Scorer::Scorer(double alpha,
double beta,
const std::string& lm_path,
const std::vector<std::string>& vocab_list) {
const std::vector<std::string>& vocab_list,
bool is_token_based) {
this->alpha = alpha;
this->beta = beta;

dictionary = nullptr;
is_character_based_ = true;
is_character_based_ = !is_token_based;
is_token_based_ = is_token_based;
language_model_ = nullptr;

max_order_ = 0;
@@ -47,7 +49,7 @@ void Scorer::setup(const std::string& lm_path,
// set char map for scorer
set_char_map(vocab_list);
// fill the dictionary for FST
if (!is_character_based()) {
if (!(is_character_based() || is_token_based())) {
fill_dictionary(true);
}
}
@@ -126,21 +128,29 @@

std::string Scorer::vec2str(const std::vector<int>& input) {
std::string word;
for (auto ind : input) {
word += char_list_[ind];
for (size_t i = 0; i < input.size(); ++i) {
word += char_list_[input[i]];
if(is_token_based_ && i + 1 < input.size())
word += " ";
}
return word;
}

std::vector<std::string> Scorer::split_labels(const std::vector<int>& labels) {
if (labels.empty()) return {};

std::string s = vec2str(labels);
std::vector<std::string> words;
if (is_character_based_) {
words = split_utf8_str(s);
} else {
words = split_str(s, " ");
if(is_token_based_) {
for (auto ind : labels)
words.push_back(char_list_[ind]);
}
else {
std::string s = vec2str(labels);
if (is_character_based_) {
words = split_utf8_str(s);
} else {
words = split_str(s, " ");
}
}
return words;
}
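
To make the branches above concrete, here is a rough Python mirror of how labels become LM words under each mode; the vocabulary and indices are made up for illustration:

```python
# Hypothetical label list shared by the decoder and the LM
char_list = ["▁he", "llo", "▁wor", "ld"]
labels = [0, 1]

# Token-based: each label index is already one LM word (vec2str joins with spaces)
token_words = [char_list[i] for i in labels]        # ["▁he", "llo"]

# Word-based: concatenate into a string, then split on spaces
joined = "".join(char_list[i] for i in labels)      # "▁hello"
word_words = joined.split(" ")                      # ["▁hello"]

# Character-based: the concatenated string is split into UTF-8 characters instead
```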
@@ -169,7 +179,7 @@ std::vector<std::string> Scorer::make_ngram(PathTrie* prefix) {
std::vector<int> prefix_vec;
std::vector<int> prefix_steps;

if (is_character_based_) {
if (is_character_based_ || is_token_based_) {
new_node = current_node->get_path_vec(prefix_vec, prefix_steps, -1, 1);
current_node = new_node;
} else {
7 changes: 6 additions & 1 deletion ctcdecode/src/scorer.h
@@ -43,7 +43,8 @@ class Scorer {
Scorer(double alpha,
double beta,
const std::string &lm_path,
const std::vector<std::string> &vocabulary);
const std::vector<std::string> &vocabulary,
bool is_token_based);
~Scorer();

double get_log_cond_prob(const std::vector<std::string> &words);
@@ -58,6 +59,9 @@

// return true if the language model is character based
bool is_character_based() const { return is_character_based_; }

// return true if the language model is token based (e.g., BPE)
bool is_token_based() const { return is_token_based_; }

// reset params alpha & beta
void reset_params(float alpha, float beta);
@@ -99,6 +103,7 @@
private:
void *language_model_;
bool is_character_based_;
bool is_token_based_;
size_t max_order_;
size_t dict_size_;
