From 9ae89ac46fe813f08ec03ee2dc17a7523c3f7690 Mon Sep 17 00:00:00 2001 From: Dr Nic Williams Date: Sat, 20 Apr 2024 19:22:54 +1000 Subject: [PATCH] Default model id llama3-8b-8192 --- README.md | 5 +++-- lib/groq/client.rb | 1 + lib/groq/configuration.rb | 3 ++- lib/groq/model.rb | 8 ++++++++ test/groq/test_client.rb | 15 +++++++++++---- 5 files changed, 25 insertions(+), 7 deletions(-) diff --git a/README.md b/README.md index fd04d48..d3c2ede 100644 --- a/README.md +++ b/README.md @@ -22,11 +22,12 @@ gem install groq - Use the `Groq::Client` to interact with Groq and your favourite model. ```ruby -client = Groq::Client.new # uses ENV["GROQ_API_KEY"] -client = Groq::Client.new(api_key: "...") +client = Groq::Client.new # uses ENV["GROQ_API_KEY"] and "llama3-8b-8192" +client = Groq::Client.new(api_key: "...", model_id: "llama3-8b-8192") Groq.configuration do |config| config.api_key = "..." + config.model_id = "llama3-70b-8192" end client = Groq::Client.new ``` diff --git a/lib/groq/client.rb b/lib/groq/client.rb index 34f3edf..e2578f6 100644 --- a/lib/groq/client.rb +++ b/lib/groq/client.rb @@ -4,6 +4,7 @@ class Groq::Client CONFIG_KEYS = %i[ api_key api_url + model_id ].freeze attr_reader(*CONFIG_KEYS, :faraday_middleware) diff --git a/lib/groq/configuration.rb b/lib/groq/configuration.rb index d07c2d5..c783f6e 100644 --- a/lib/groq/configuration.rb +++ b/lib/groq/configuration.rb @@ -1,6 +1,6 @@ class Groq::Configuration attr_writer :api_key - attr_accessor :api_url, :request_timeout, :extra_headers + attr_accessor :model_id, :api_url, :request_timeout, :extra_headers DEFAULT_API_URL = "https://api.groq.com" DEFAULT_REQUEST_TIMEOUT = 5 @@ -9,6 +9,7 @@ class Error < StandardError; end def initialize @api_key = ENV["GROQ_API_KEY"] + @model_id = Groq::Model.default_model_id @api_url = DEFAULT_API_URL @request_timeout = DEFAULT_REQUEST_TIMEOUT @extra_headers = {} diff --git a/lib/groq/model.rb b/lib/groq/model.rb index 8648da9..e3fe64c 100644 --- a/lib/groq/model.rb +++ b/lib/groq/model.rb @@ -36,4 +36,12 @@ class Groq::Model model_card: "https://huggingface.co/google/gemma-1.1-7b-it" } ] + + def self.default_model + MODELS.first + end + + def self.default_model_id + default_model[:model_id] + end end diff --git a/test/groq/test_client.rb b/test/groq/test_client.rb index 3539da8..3efefa7 100644 --- a/test/groq/test_client.rb +++ b/test/groq/test_client.rb @@ -5,11 +5,12 @@ class TestGroqClient < Minitest::Test # define "say hello world" for each model, such as: test_hello_world_llama3_8b et al Groq::Model::MODELS.each do |model| - define_method :"test_hello_world_#{model[:model_id]}" do - VCR.use_cassette("#{model[:model_id]}/hello_world") do - client = Groq::Client.new + model_id = model[:model_id] + define_method :"test_hello_world_#{model_id}" do + VCR.use_cassette("#{model_id}/hello_world") do + client = Groq::Client.new(model_id: model_id) response = client.post(path: "/openai/v1/chat/completions", body: { - model: model[:model_id], + model: model_id, messages: [{role: "user", content: "Reply with only the words: Hello, World!"}] }) assert_equal 200, response.status @@ -18,4 +19,10 @@ class TestGroqClient < Minitest::Test end end end + + # It's mentioned in the README, so let's assert it too + def test_default_model + client = Groq::Client.new + assert_equal "llama3-8b-8192", client.model_id + end end