Configure/override max_tokens + temperature
drnic committed Apr 20, 2024
1 parent c2ddcd2 commit 43dd4a1
Showing 6 changed files with 147 additions and 19 deletions.
30 changes: 29 additions & 1 deletion README.md
@@ -8,7 +8,7 @@ Speed and pricing at 2024-04-21. Also see their [changelog](https://console.groq

## Groq Cloud API

You can interact with their API using any Ruby HTTP library by following their documentation at <https://console.groq.com/docs/quickstart>
You can interact with their API using any Ruby HTTP library by following their documentation at <https://console.groq.com/docs/quickstart>. Also use their [Playground](https://console.groq.com/playground) and watch the API traffic in the browser's developer tools.

The Groq Cloud API appears to implement a subset of the OpenAI API. For example, you perform chat completions at `https://api.groq.com/openai/v1/chat/completions` with the same POST body schema as OpenAI. The Tools support appears to use the same schema for defining tools/functions.

@@ -244,6 +244,34 @@ messages << T("25 degrees celcius", tool_call_id: tool_call_id, name: "get_weath
# => {"role"=>"assistant", "content"=> "I'm glad you called the function!\n\nAs of your current location, the weather in Paris is indeed 25°C (77°F)..."}
```

### Max Tokens & Temperature

Max tokens is the maximum number of tokens the model can generate in a single response. This limit helps manage computational cost and resource usage.

The temperature setting for each API call controls the randomness of responses. A lower temperature leads to more predictable output, while a higher temperature produces more varied and sometimes more creative output. Valid values range from 0 to 2.

Each API call accepts `max_tokens:` and `temperature:` values.

The defaults are:

```ruby
@client.max_tokens
=> 1024
@client.temperature
=> 1
```

You can override them in the `Groq.configuration` block, or with each `chat()` call:

```ruby
Groq.configuration do |config|
config.max_tokens = 512
config.temperature = 0.5
end
# or
@client.chat("Hello, world!", max_tokens: 512, temperature: 0.5)
```
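A per-call value takes precedence over the configured default. A minimal standalone sketch of that fallback logic in plain Ruby (a hypothetical helper for illustration — the gem resolves this internally with `value || @default`):

```ruby
# Hypothetical helper mirroring the gem's per-call fallback behavior.
DEFAULTS = {max_tokens: 1024, temperature: 1}.freeze

def resolve(defaults, max_tokens: nil, temperature: nil)
  {
    max_tokens: max_tokens || defaults[:max_tokens],
    temperature: temperature || defaults[:temperature]
  }
end

resolve(DEFAULTS)                   # => {max_tokens: 1024, temperature: 1}
resolve(DEFAULTS, temperature: 0.5) # => {max_tokens: 1024, temperature: 0.5}
```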

## Development

After checking out the repo, run `bin/setup` to install dependencies. Then, run `rake test` to run the tests. You can also run `bin/console` for an interactive prompt that will allow you to experiment.
8 changes: 6 additions & 2 deletions lib/groq/client.rb
@@ -5,6 +5,8 @@ class Groq::Client
api_key
api_url
model_id
max_tokens
temperature
].freeze
attr_reader(*CONFIG_KEYS, :faraday_middleware)

@@ -20,7 +22,7 @@ def initialize(config = {}, &faraday_middleware)
end

# TODO: support stream: true; or &stream block
def chat(messages, model_id: nil, tools: nil)
def chat(messages, model_id: nil, tools: nil, max_tokens: nil, temperature: nil)
unless messages.is_a?(Array) || messages.is_a?(String)
raise ArgumentError, "require messages to be an Array or String"
end
@@ -34,7 +36,9 @@ def chat(messages, model_id: nil, tools: nil)
body = {
model: model_id,
messages: messages,
tools: tools
tools: tools,
max_tokens: max_tokens || @max_tokens,
temperature: temperature || @temperature
}.compact
response = post(path: "/openai/v1/chat/completions", body: body)
if response.status == 200
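The `.compact` at the end of the request body drops keys whose value is `nil`, so `tools` is omitted from the POST body when not supplied while the new `max_tokens`/`temperature` keys fall back to the client's configured values. A standalone sketch of that pattern (a hypothetical mirror of `Groq::Client#chat`'s body construction, not the gem itself):

```ruby
# Hypothetical mirror of the request-body construction in Groq::Client#chat.
def build_body(model_id:, messages:, tools: nil, max_tokens: nil, temperature: nil,
               default_max_tokens: 1024, default_temperature: 1)
  {
    model: model_id,
    messages: messages,
    tools: tools,
    max_tokens: max_tokens || default_max_tokens,
    temperature: temperature || default_temperature
  }.compact # removes nil-valued keys, e.g. :tools when no tools are passed
end

body = build_body(model_id: "llama3-8b-8192",
                  messages: [{role: "user", content: "Hello"}])
body.key?(:tools)  # => false
body[:max_tokens]  # => 1024
```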
10 changes: 8 additions & 2 deletions lib/groq/configuration.rb
@@ -1,18 +1,24 @@
class Groq::Configuration
attr_writer :api_key
attr_accessor :model_id, :api_url, :request_timeout, :extra_headers
attr_accessor :model_id, :max_tokens, :temperature
attr_accessor :api_url, :request_timeout, :extra_headers

DEFAULT_API_URL = "https://api.groq.com"
DEFAULT_REQUEST_TIMEOUT = 5
DEFAULT_MAX_TOKENS = 1024
DEFAULT_TEMPERATURE = 1

class Error < StandardError; end

def initialize
@api_key = ENV["GROQ_API_KEY"]
@model_id = Groq::Model.default_model_id
@api_url = DEFAULT_API_URL
@request_timeout = DEFAULT_REQUEST_TIMEOUT
@extra_headers = {}

@model_id = Groq::Model.default_model_id
@max_tokens = DEFAULT_MAX_TOKENS
@temperature = DEFAULT_TEMPERATURE
end

def api_key
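The diff above moves `@model_id` next to the two new defaults so all model-related settings initialize together. A condensed standalone sketch of just the new defaults (a hypothetical class for illustration — the real `Groq::Configuration` also handles `api_key`, `api_url`, and timeouts):

```ruby
# Hypothetical condensed version of the new Groq::Configuration defaults.
class SketchConfiguration
  attr_accessor :max_tokens, :temperature

  DEFAULT_MAX_TOKENS = 1024
  DEFAULT_TEMPERATURE = 1

  def initialize
    @max_tokens = DEFAULT_MAX_TOKENS
    @temperature = DEFAULT_TEMPERATURE
  end
end

config = SketchConfiguration.new
config.temperature = 0.5 # as in the README's Groq.configuration block
```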
72 changes: 72 additions & 0 deletions test/fixtures/vcr_cassettes/llama3-8b-8192/chat_max_tokens.yml


28 changes: 14 additions & 14 deletions test/fixtures/vcr_cassettes/llama3-8b-8192/chat_messages.yml


18 changes: 18 additions & 0 deletions test/groq/test_client.rb
@@ -3,6 +3,13 @@
require "test_helper"

class TestGroqClient < Minitest::Test
def test_defaults
client = Groq::Client.new
assert_equal "llama3-8b-8192", client.model_id
assert_equal 1024, client.max_tokens
assert_equal 1, client.temperature
end

# define "say hello world" for each model, such as: test_hello_world_llama3_8b et al
Groq::Model::MODELS.each do |model|
model_id = model[:model_id]
@@ -103,4 +110,15 @@ def test_tools_weather_report
assert_equal response, {"role" => "assistant", "content" => "The weather in Brisbane, QLD is 25 degrees Celsius."}
end
end

def test_max_tokens
VCR.use_cassette("llama3-8b-8192/chat_max_tokens") do
client = Groq::Client.new(model_id: "llama3-8b-8192")
response = client.chat("What's the next day after Wednesday?", max_tokens: 1)
assert_equal response, {
"role" => "assistant", "content" => "The"
}
      # Yeah, max_tokens=1 still returns a full word, because it's a single token.
end
end
end