Commit: `.by` to `by`

bnicenboim committed Nov 4, 2023
1 parent d8ec90a commit 74847d0
Showing 11 changed files with 89 additions and 50 deletions.
2 changes: 1 addition & 1 deletion DESCRIPTION
@@ -1,7 +1,7 @@
 Package: pangoling
 Type: Package
 Title: Access to Large Language Model Predictions
-Version: 0.0.0.9008
+Version: 0.0.0.9009
 Authors@R: c(
     person("Bruno", "Nicenboim",
       email = "[email protected]",
4 changes: 4 additions & 0 deletions NEWS.md
@@ -32,3 +32,7 @@
 
 # pangoling 0.0.0.9008
 * Fix a bug when `.by` is unordered
+
+# pangoling 0.0.0.9009
+* Deprecated `.by` in favor of `by`.
+
64 changes: 47 additions & 17 deletions R/tr_causal.R
@@ -137,9 +137,10 @@ causal_next_tokens_tbl <- function(context,
 #'
 #'
 #' @param x Vector of words, phrases or texts.
-#' @param .by Vector that indicates how the text should be split.
+#' @param by Vector that indicates how the text should be split.
 #' @param l_contexts Left context for each word in `x`. If `l_contexts` is used,
-#' `.by` is ignored. Set `.by = NULL` to avoid a message notifying that.
+#' `by` is ignored. Set `by = NULL` to avoid a message notifying that.
+#' @param ... not in use.
 #' @inheritParams causal_preload
 #' @param ignore_regex Can ignore certain characters when calculating the log
 #' probabilities. For example `^[[:punct:]]$` will ignore
@@ -159,30 +160,45 @@ causal_next_tokens_tbl <- function(context,
 #'causal_lp(
 #'  x = "tree.",
 #'  l_contexts = "The apple doesn't fall far from the tree.",
-#'  .by = NULL, # it's ignored anyways
+#'  by = NULL, # it's ignored anyways
 #'  model = "gpt2"
 #' )
 
 #' @family causal model functions
 #' @export
 causal_lp <- function(x,
-                      .by = rep(1, length(x)),
+                      by = rep(1, length(x)),
                       l_contexts = NULL,
                       ignore_regex = "",
                       model = getOption("pangoling.causal.default"),
                       checkpoint = NULL,
                       add_special_tokens = NULL,
                       config_model = NULL,
                       config_tokenizer = NULL,
-                      batch_size = 1) {
+                      batch_size = 1,
+                      ...) {
+  dots <- list(...)
+  # Check for the deprecated .by argument
+  if (!is.null(dots$.by)) {
+    warning("The '.by' argument is deprecated. Please use 'by' instead.")
+    by <- dots$.by # Assume that if .by is supplied, it takes precedence
+  }
+  # Check for unknown arguments
+  if (length(dots) > 0) {
+    unknown_args <- setdiff(names(dots), ".by")
+    if (length(unknown_args) > 0) {
+      stop("Unknown arguments: ", paste(unknown_args, collapse = ", "), ".")
+    }
+  }
+
   stride <- 1 # fixed for now
   message_verbose("Processing using causal model '", file.path(model, checkpoint), "'...")
   if(!is.null(l_contexts)){
-    if(all(!is.null(.by))) message_verbose("Ignoring `.by` argument")
+    if(all(!is.null(by))) message_verbose("Ignoring `by` argument")
     x <- c(rbind(l_contexts, x))
-    .by <- rep(seq_len(length(x)/2), each = 2)
+    by <- rep(seq_len(length(x)/2), each = 2)
   }
-  word_by_word_texts <- get_word_by_word_texts(x, .by)
+  word_by_word_texts <- get_word_by_word_texts(x, by)
 
   pasted_texts <- lapply(
     word_by_word_texts,
@@ -245,13 +261,13 @@ causal_lp <- function(x,
   } else {
     keep <- TRUE
   }
-  # split(x, .by) |> unsplit(.by)
+  # split(x, by) |> unsplit(by)
   # tidytable::map2_dfr(, ~ tidytable::tidytable(x = .x))
   out <- out |> lapply(function(x) x[keep])
-  lps <- out |> unsplit(.by[keep], drop = TRUE)
+  lps <- out |> unsplit(by[keep], drop = TRUE)
 
   names(lps) <- out |> lapply(function(x) paste0(names(x),"")) |>
-    unsplit(.by[keep], drop = TRUE)
+    unsplit(by[keep], drop = TRUE)
   lps
 }
@@ -409,7 +425,7 @@ causal_mat <- function(tensor,
 #'
 #' @inheritParams causal_lp
 #' @inheritParams causal_preload
-#' @param sorted When default FALSE it will retain the order of groups we are splitting on. When TRUE then sorted (according to `.by`) list(s) are returned.
+#' @param sorted When `FALSE` (the default), the order of the groups being split on is retained; when `TRUE`, the returned list(s) are sorted according to `by`.
 #' @inherit causal_preload details
 #' @inheritSection causal_next_tokens_tbl More examples
 #' @return A list of matrices with tokens in their columns and the vocabulary of the model in their rows.
@@ -424,14 +440,28 @@ causal_mat <- function(tensor,
 #' @export
 #'
 causal_lp_mats <- function(x,
-                           .by = rep(1, length(x)),
+                           by = rep(1, length(x)),
                            sorted = FALSE,
                            model = getOption("pangoling.causal.default"),
                            checkpoint = NULL,
                            add_special_tokens = NULL,
                            config_model = NULL,
                            config_tokenizer = NULL,
-                           batch_size = 1) {
+                           batch_size = 1,
+                           ...) {
+  dots <- list(...)
+  # Check for the deprecated .by argument
+  if (!is.null(dots$.by)) {
+    warning("The '.by' argument is deprecated. Please use 'by' instead.")
+    by <- dots$.by # Assume that if .by is supplied, it takes precedence
+  }
+  # Check for unknown arguments
+  if (length(dots) > 0) {
+    unknown_args <- setdiff(names(dots), ".by")
+    if (length(unknown_args) > 0) {
+      stop("Unknown arguments: ", paste(unknown_args, collapse = ", "), ".")
+    }
+  }
   stride <- 1
   message_verbose("Processing using causal model '", file.path(model, checkpoint), "'...")
   tkzr <- tokenizer(model,
@@ -444,7 +474,7 @@ causal_lp_mats <- function(x,
     config_model = config_model
   )
   x <- trimws(x, whitespace = "[ \t]")
-  word_by_word_texts <- split(x, .by)
+  word_by_word_texts <- split(x, by)
   pasted_texts <- lapply(
     word_by_word_texts,
     function(word) paste0(word, collapse = " ")
@@ -466,8 +496,8 @@
       )
     }
   )
-  names(lmat) <- levels(as.factor(.by))
-  if(!sorted) lmat <- lmat[unique(as.factor(.by))]
+  names(lmat) <- levels(as.factor(by))
+  if(!sorted) lmat <- lmat[unique(as.factor(by))]
   lmat |>
     unlist(recursive = FALSE)
 }
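For context, here is a minimal sketch of how the new `...` handling behaves from the caller's side. It is not part of the commit: `causal_lp()` and the `"gpt2"` model are pangoling's and need a working Python `transformers` backend, `byy` is a hypothetical typo, and the rest is base R.

``` r
library(pangoling)

words <- c("The", "apple", "falls")
grp <- c(1, 1, 1)

# Supported interface: `by` is a named argument.
lp_new <- causal_lp(x = words, by = grp, model = "gpt2")

# Deprecated interface: `.by` lands in `...`, triggers the deprecation
# warning, and is forwarded to `by`, so the result is the same.
lp_old <- withCallingHandlers(
  causal_lp(x = words, .by = grp, model = "gpt2"),
  warning = function(w) {
    message("caught: ", conditionMessage(w))
    invokeRestart("muffleWarning")
  }
)
stopifnot(identical(lp_new, lp_old))

# Any other name in `...` fails the setdiff() check and errors with
# "Unknown arguments: byy."
try(causal_lp(x = words, by = grp, byy = grp, model = "gpt2"))
```

The `# Assume that if .by is supplied, it takes precedence` comment in the diff is the design choice at work here: since `by` always has a default, the code cannot tell a user-set `by` from the default one, so a supplied `.by` simply wins.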
1 change: 0 additions & 1 deletion R/tr_masked.R
@@ -185,7 +185,6 @@ masked_lp <- function(l_contexts,
 
   message_verbose("Processing using masked model '", model, "'...")
 
-  # word_by_word_texts <- get_word_by_word_texts(x, .by)
   target_tokens <- char_to_token(targets, tkzr)
   masked_sentences <- tidytable::pmap_chr(
     list(
8 changes: 4 additions & 4 deletions R/tr_utils.R
@@ -110,12 +110,12 @@ encode <- function(x, tkzr, add_special_tokens = NULL, ...) {
   }
 }
 
-get_word_by_word_texts <- function(x, .by) {
-  if (length(x) != length(.by)) {
-    stop2("The argument `.by` has an incorrect length.")
+get_word_by_word_texts <- function(x, by) {
+  if (length(x) != length(by)) {
+    stop2("The argument `by` has an incorrect length.")
   }
   x <- trimws(x, whitespace = "[ \t]")
-  split(x, .by)
+  split(x, by)
 }
 
 #' Sends a var to python
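Since `get_word_by_word_texts()` is internal, here is a hypothetical standalone replica showing the contract of the renamed helper; base `stop()` stands in for pangoling's internal `stop2()`.

``` r
# Hypothetical standalone replica of the renamed internal helper.
word_by_word_demo <- function(x, by) {
  if (length(x) != length(by)) {
    stop("The argument `by` has an incorrect length.")
  }
  x <- trimws(x, whitespace = "[ \t]") # strip spaces/tabs, keep newlines
  split(x, by)                         # one character vector per group
}

word_by_word_demo(c("The", "apple ", "Don't", "judge"), by = c(1, 1, 2, 2))
#> $`1`
#> [1] "The"   "apple"
#>
#> $`2`
#> [1] "Don't" "judge"

# A `by` that doesn't pair up with `x` errors, mirroring the updated test
# expect_error(causal_lp(c("It", "is."), by = 3)).
try(word_by_word_demo(c("It", "is."), by = 3))
```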
2 changes: 1 addition & 1 deletion README.Rmd
@@ -63,7 +63,7 @@ One can get the log-transformed probability of each word based on GPT-2 as follo
 
 ```{r, cache = TRUE}
 df_sent <- df_sent |>
-  mutate(lp = causal_lp(word, .by = sent_n))
+  mutate(lp = causal_lp(word, by = sent_n))
 df_sent
 ```
 
10 changes: 5 additions & 5 deletions README.md
@@ -16,7 +16,7 @@ public.](https://www.repostatus.org/badges/latest/wip.svg)](https://www.repostat
 Review](https://badges.ropensci.org/575_status.svg)](https://github.com/ropensci/software-review/issues/575)
 <!-- badges: end -->
 
-`pangoling`\[1\] is an R package for estimating the log-probabilities of
+`pangoling`[^1] is an R package for estimating the log-probabilities of
 words in a given context using transformer models. The package provides
 an interface for utilizing pre-trained transformer models (such as GPT-2
 or BERT) to obtain word probabilities. These log-probabilities are often
Expand All @@ -28,7 +28,7 @@ The package is mostly a wrapper of the python package
[`transformers`](https://pypi.org/project/transformers/) to process data
in a convenient format.

## Important\! Limitations and bias
## Important! Limitations and bias

The training data of the most popular models (such as GPT-2) haven’t
been released, so one cannot inspect it. It’s clear that the data
@@ -93,7 +93,7 @@ as follows:
 
 ``` r
 df_sent <- df_sent |>
-  mutate(lp = causal_lp(word, .by = sent_n))
+  mutate(lp = causal_lp(word, by = sent_n))
 #> Processing using causal model ''...
 #> Processing a batch of size 1 with 10 tokens.
 #> Processing a batch of size 1 with 9 tokens.
@@ -125,7 +125,7 @@ df_sent
 ## How to cite
 
 > Nicenboim B (2023). *pangoling: Access to language model predictions
-> in R*. R package version 0.0.0.9007, DOI:
+> in R*. R package version 0.0.0.9008, DOI:
 > [10.5281/zenodo.7637526](https://zenodo.org/badge/latestdoi/497831295),
 > <https://github.com/bnicenboim/pangoling>.
@@ -146,7 +146,7 @@ Another R package that act as a wrapper for
 [`text`](https://r-text.org//) However, `text` is more general, and its
 focus is on Natural Language Processing and Machine Learning.
 
-1. The logo of the package was created with [stable
+[^1]: The logo of the package was created with [stable
     diffusion](https://huggingface.co/spaces/stabilityai/stable-diffusion)
     and the R package
     [hexSticker](https://github.com/GuangchuangYu/hexSticker).
13 changes: 8 additions & 5 deletions man/causal_lp.Rd

Some generated files are not rendered by default.

11 changes: 7 additions & 4 deletions man/causal_lp_mats.Rd

Some generated files are not rendered by default.

20 changes: 10 additions & 10 deletions tests/testthat/test-tr_causal.R
@@ -26,9 +26,9 @@ test_that("empty or small strings", {
   expect_warning(lp_NA <- causal_tokens_lp_tbl(texts = ""))
   expect_equal(as.data.frame(lp_NA), data.frame(token = "", lp = NA_real_))
   small_str <- c("It", "It", "is")
-  lp_small <- causal_lp(x = small_str, .by = c(1, 2, 2))
+  lp_small <- causal_lp(x = small_str, by = c(1, 2, 2))
   expect_equal(lp_small[1:2], c(It = NA_real_, It = NA_real_))
-  expect_warning(lp_small_ <- causal_lp(x = c("", "It"), .by = c(1, 2)))
+  expect_warning(lp_small_ <- causal_lp(x = c("", "It"), by = c(1, 2)))
   expect_equal(lp_small_, c(NA_real_, "It" = NA_real_))
 })

@@ -49,7 +49,7 @@ test_that("long input work", {
 
 test_that("errors work", {
   skip_if_no_python_stuff()
-  expect_error(causal_lp(c("It", "is."), .by = 3))
+  expect_error(causal_lp(c("It", "is."), by = 3))
 })
 
 test_that("gpt2 get prob work", {
@@ -106,7 +106,7 @@ test_that("gpt2 get prob work", {
   lp_sent_rep <-
     causal_lp(
       x = rep(sent_w, 2),
-      .by = rep(seq_len(2), each = length(sent_w))
+      by = rep(seq_len(2), each = length(sent_w))
     )
   expect_equal(
     unname(lp_sent_rep[seq_along(sent_w)]),
@@ -118,10 +118,10 @@ test_that("gpt2 get prob work", {
   df_order2 <- data.frame(word = c(sent2_words,prov_words),
                           item = c(rep(2, each = length(sent2_words)),
                                    rep(1, each = length(prov_words))))
-  expect_equal(causal_lp(df_order1$word, .by = df_order1$item),
-               causal_lp(x = df_order2$word, .by = df_order2$item))
-  expect_equal(causal_lp_mats(x = df_order1$word, .by = df_order1$item),
-               causal_lp_mats(x = df_order2$word, .by = df_order2$item) |>
+  expect_equal(causal_lp(df_order1$word, by = df_order1$item),
+               causal_lp(x = df_order2$word, by = df_order2$item))
+  expect_equal(causal_lp_mats(x = df_order1$word, by = df_order1$item),
+               causal_lp_mats(x = df_order2$word, by = df_order2$item) |>
                  setNames(c("1","2")))
 
 })
@@ -148,8 +148,8 @@ test_that("batches work", {
       rep(6, length(sent2_words))
     )
   )
-  lp_2_batch <- causal_lp(x = df$x, .by = df$.id, batch_size = 4)
-  lp_2_no_batch <- causal_lp(x = df$x, .by = df$.id, batch_size = 1)
+  lp_2_batch <- causal_lp(x = df$x, by = df$.id, batch_size = 4)
+  lp_2_no_batch <- causal_lp(x = df$x, by = df$.id, batch_size = 1)
   expect_equal(lp_2_batch, lp_2_no_batch, tolerance = .0001)
 
   df <- data.frame(l_contexts = rep(c("Don't judge a book by its","The apple doesn't fall far from the"),5), x = rep(c("cover", "tree"),5))
4 changes: 2 additions & 2 deletions vignettes/articles/intro-gpt2.Rmd
@@ -100,11 +100,11 @@ One can get the log-transformed probability of each word based on GPT-2 as follo
 
 ```{r}
 df_sent <- df_sent |>
-  mutate(lp = causal_lp(word, .by = sent_n))
+  mutate(lp = causal_lp(word, by = sent_n))
 df_sent
 ```
 
-Notice that the `.by` is inside the `causal_lp()` function. It' also possible to use `.by` in the mutate call, or `group_by()`, but it will be slower.
+Notice that `by` goes inside the `causal_lp()` call. It's also possible to use `by` in the `mutate()` call, or with `group_by()`, but it will be slower; see the sketch below.
 
 
 The attentive reader might have noticed that the log-probability of "tree" here is not the same as the one presented before. This is because the actual word is `" tree."` (notice the space), which contains two tokens:
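To make the vignette's speed point concrete, here is a sketch contrasting the two grouping styles. It is not part of the commit; it assumes dplyr-style verbs, a working model backend, and a made-up toy `df_sent`.

``` r
library(pangoling)
library(dplyr)

df_sent <- data.frame(
  sent_n = rep(1:2, each = 3),
  word = c("The", "apple", "falls", "Time", "flies", "fast")
)

# Grouping handled inside causal_lp(): one call processes all sentences.
fast <- df_sent |>
  mutate(lp = causal_lp(word, by = sent_n))

# Grouping handled outside: causal_lp() runs once per group, with `by`
# left at its default single-group value.
slow <- df_sent |>
  group_by(sent_n) |>
  mutate(lp = causal_lp(word)) |>
  ungroup()

all.equal(fast$lp, slow$lp) # same log-probabilities, different runtime
```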
