diff --git a/lib/ollama.ex b/lib/ollama.ex
index 07a4731..0b51f6e 100644
--- a/lib/ollama.ex
+++ b/lib/ollama.ex
@@ -472,7 +472,7 @@ defmodule Ollama do
     ],
     context: [
       type: {:list, {:or, [:integer, :float]}},
-      doc: "The context parameter returned from a previous `f:completion/2` call (enabling short conversational memory).",
+      doc: "The context parameter returned from a previous `completion/2` call (enabling short conversational memory).",
     ],
     format: [
       type: :string,
@@ -842,6 +842,61 @@ defmodule Ollama do
   end
 
+  schema :embed, [
+    model: [
+      type: :string,
+      required: true,
+      doc: "The name of the model used to generate the embeddings.",
+    ],
+    input: [
+      type: {:or, [:string, {:list, :string}]},
+      required: true,
+      doc: "Text or list of texts to generate embeddings for.",
+    ],
+    truncate: [
+      type: :boolean,
+      doc: "Truncates the end of each input to fit within the model's context length.",
+    ],
+    keep_alive: [
+      type: {:or, [:integer, :string]},
+      doc: "How long to keep the model loaded.",
+    ],
+    options: [
+      type: {:map, {:or, [:atom, :string]}, :any},
+      doc: "Additional advanced [model parameters](https://github.com/jmorganca/ollama/blob/main/docs/modelfile.md#valid-parameters-and-values).",
+    ],
+  ]
+
+  @doc """
+  Generate embeddings from a model for the given input text or list of texts.
+
+  ## Options
+
+  #{doc(:embed)}
+
+  ## Example
+
+      iex> Ollama.embed(client, [
+      ...>   model: "nomic-embed-text",
+      ...>   input: ["Why is the sky blue?", "Why is the grass green?"],
+      ...> ])
+      {:ok, %{"embeddings" => [
+        [ 0.009724553, 0.04449892, -0.14063916, 0.0013168337, 0.032128844,
+          0.10730086, -0.008447222, 0.010106917, 5.2289694e-4, -0.03554127, ...],
+        [ 0.028196355, 0.043162502, -0.18592504, 0.035034444, 0.055619627,
+          0.12082449, -0.0090096295, 0.047170386, -0.032078084, 0.0047163847, ...]
+      ]}}
+  """
+  @spec embed(client(), keyword()) :: response()
+  def embed(%__MODULE__{} = client, params) when is_list(params) do
+    with {:ok, params} <- NimbleOptions.validate(params, schema(:embed)) do
+      client
+      |> req(:post, "/embed", json: Enum.into(params, %{}))
+      |> res()
+    end
+  end
+
   schema :embeddings, [
     model: [
       type: :string,
@@ -881,6 +936,7 @@ defmodule Ollama do
       0.8785552978515625, -0.34576427936553955, 0.5742510557174683,
       -0.04222835972905159, -0.137906014919281
     ]}}
   """
+  @deprecated "Superseded by embed/2"
   @spec embeddings(client(), keyword()) :: response()
   def embeddings(%__MODULE__{} = client, params) when is_list(params) do
     with {:ok, params} <- NimbleOptions.validate(params, schema(:embeddings)) do
diff --git a/test/ollama_test.exs b/test/ollama_test.exs
index 2f69f2f..a360d1a 100644
--- a/test/ollama_test.exs
+++ b/test/ollama_test.exs
@@ -251,6 +251,39 @@ defmodule OllamaTest do
     end
   end
 
+  describe "embed/2" do
+    test "generates an embedding for a given input", %{client: client} do
+      assert {:ok, res} = Ollama.embed(client, [
+        model: "nomic-embed-text",
+        input: "Why is the sky blue?",
+      ])
+
+      assert res["model"] == "nomic-embed-text"
+      assert is_list(res["embeddings"])
+      assert length(res["embeddings"]) == 1
+      assert Enum.all?(res["embeddings"], &is_list/1)
+    end
+
+    test "generates an embedding for a list of input texts", %{client: client} do
+      assert {:ok, res} = Ollama.embed(client, [
+        model: "nomic-embed-text",
+        input: ["Why is the sky blue?", "Why is the grass green?"],
+      ])
+
+      assert res["model"] == "nomic-embed-text"
+      assert is_list(res["embeddings"])
+      assert length(res["embeddings"]) == 2
+      assert Enum.all?(res["embeddings"], &is_list/1)
+    end
+
+    test "returns an error when the model is not found", %{client: client} do
+      assert {:error, %HTTPError{status: 404}} = Ollama.embed(client, [
+        model: "not-found",
+        input: "Why is the sky blue?",
+      ])
+    end
+  end
+
   describe "embeddings/2" do
     test "generates an embedding for a given prompt", %{client: client} do
       assert {:ok, res} = Ollama.embeddings(client, [
diff --git a/test/support/mock_server.ex b/test/support/mock_server.ex
index e049c11..39328ae 100644
--- a/test/support/mock_server.ex
+++ b/test/support/mock_server.ex
@@ -172,6 +172,42 @@ defmodule Ollama.MockServer do
     }
     """,
 
+    # embedding values truncated for simplicity
+    embed_one: """
+    {
+      "embeddings": [
+        [
+          0.009724553, 0.04449892, -0.14063916, 0.0013168337, 0.032128844,
+          0.10730086, -0.008447222, 0.010106917, 5.2289694e-4, -0.03554127
+        ]
+      ],
+      "load_duration": 1881917,
+      "model": "nomic-embed-text",
+      "prompt_eval_count": 8,
+      "total_duration": 48675959
+    }
+    """,
+
+    # embedding values truncated for simplicity
+    embed_many: """
+    {
+      "embeddings": [
+        [
+          0.009724553, 0.04449892, -0.14063916, 0.0013168337, 0.032128844,
+          0.10730086, -0.008447222, 0.010106917, 5.2289694e-4, -0.03554127
+        ],
+        [
+          0.028196355, 0.043162502, -0.18592504, 0.035034444, 0.055619627,
+          0.12082449, -0.0090096295, 0.047170386, -0.032078084, 0.0047163847
+        ]
+      ],
+      "load_duration": 1902709,
+      "model": "nomic-embed-text",
+      "prompt_eval_count": 16,
+      "total_duration": 53473292
+    }
+    """,
+
     embeddings: """
     {
       "embedding": [
@@ -339,6 +375,14 @@ defmodule Ollama.MockServer do
   post "/blobs/:digest", do: respond(conn, 200)
   post "/embeddings", do: handle_request(conn, :embeddings)
 
+  post "/embed" do
+    case conn.body_params do
+      %{"model" => "not-found"} -> respond(conn, 404)
+      %{"input" => input} when is_binary(input) -> respond(conn, :embed_one)
+      %{"input" => input} when is_list(input) -> respond(conn, :embed_many)
+    end
+  end
+
   defp handle_request(conn, name) do
     case conn.body_params do
       %{"model" => "not-found"} -> respond(conn, 404)
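
For reference, a minimal sketch of how the new `embed/2` could be exercised against a locally running Ollama server. The setup below is an assumption, not part of this diff: it relies on the default base URL accepted by `Ollama.init/1` and on the `nomic-embed-text` model having already been pulled.

    # build a client pointing at the default local Ollama API (assumption: defaults apply)
    client = Ollama.init()

    # single input: the "embeddings" key holds a list containing one vector
    {:ok, %{"embeddings" => [vector]}} =
      Ollama.embed(client, model: "nomic-embed-text", input: "Why is the sky blue?")

    # batched input: one vector per input string, in the same order
    {:ok, %{"embeddings" => vectors}} =
      Ollama.embed(client,
        model: "nomic-embed-text",
        input: ["Why is the sky blue?", "Why is the grass green?"]
      )

    length(vectors) #=> 2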