Skip to content

Commit

Permalink
Default model id llama3-8b-8192
Browse files Browse the repository at this point in the history
  • Loading branch information
drnic committed Apr 20, 2024
1 parent 1fbaaa0 commit 9ae89ac
Show file tree
Hide file tree
Showing 5 changed files with 25 additions and 7 deletions.
5 changes: 3 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,12 @@ gem install groq
- Use the `Groq::Client` to interact with Groq and your favourite model.

```ruby
client = Groq::Client.new # uses ENV["GROQ_API_KEY"]
client = Groq::Client.new(api_key: "...")
client = Groq::Client.new # uses ENV["GROQ_API_KEY"] and "llama3-8b-8192"
client = Groq::Client.new(api_key: "...", model_id: "llama3-8b-8192")

Groq.configuration do |config|
config.api_key = "..."
config.model_id = "llama3-70b-8192"
end
client = Groq::Client.new
```
Expand Down
1 change: 1 addition & 0 deletions lib/groq/client.rb
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ class Groq::Client
CONFIG_KEYS = %i[
api_key
api_url
model_id
].freeze
attr_reader(*CONFIG_KEYS, :faraday_middleware)

Expand Down
3 changes: 2 additions & 1 deletion lib/groq/configuration.rb
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
class Groq::Configuration
attr_writer :api_key
attr_accessor :api_url, :request_timeout, :extra_headers
attr_accessor :model_id, :api_url, :request_timeout, :extra_headers

DEFAULT_API_URL = "https://api.groq.com"
DEFAULT_REQUEST_TIMEOUT = 5
Expand All @@ -9,6 +9,7 @@ class Error < StandardError; end

def initialize
@api_key = ENV["GROQ_API_KEY"]
@model_id = Groq::Model.default_model_id
@api_url = DEFAULT_API_URL
@request_timeout = DEFAULT_REQUEST_TIMEOUT
@extra_headers = {}
Expand Down
8 changes: 8 additions & 0 deletions lib/groq/model.rb
Original file line number Diff line number Diff line change
Expand Up @@ -36,4 +36,12 @@ class Groq::Model
model_card: "https://huggingface.co/google/gemma-1.1-7b-it"
}
]

def self.default_model
MODELS.first
end

def self.default_model_id
default_model[:model_id]
end
end
15 changes: 11 additions & 4 deletions test/groq/test_client.rb
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,12 @@
class TestGroqClient < Minitest::Test
# define "say hello world" for each model, such as: test_hello_world_llama3_8b et al
Groq::Model::MODELS.each do |model|
define_method :"test_hello_world_#{model[:model_id]}" do
VCR.use_cassette("#{model[:model_id]}/hello_world") do
client = Groq::Client.new
model_id = model[:model_id]
define_method :"test_hello_world_#{model_id}" do
VCR.use_cassette("#{model_id}/hello_world") do
client = Groq::Client.new(model_id: model_id)
response = client.post(path: "/openai/v1/chat/completions", body: {
model: model[:model_id],
model: model_id,
messages: [{role: "user", content: "Reply with only the words: Hello, World!"}]
})
assert_equal 200, response.status
Expand All @@ -18,4 +19,10 @@ class TestGroqClient < Minitest::Test
end
end
end

# It's mentioned in the README, so let's assert it too
def test_default_model
client = Groq::Client.new
assert_equal "llama3-8b-8192", client.model_id
end
end

0 comments on commit 9ae89ac

Please sign in to comment.