Skip to content

Commit

Permalink
Implemented 'ai:models:call' for embeddings models
Browse files Browse the repository at this point in the history
  • Loading branch information
sbosio committed Sep 27, 2024
1 parent e4c2faa commit 2df6d15
Show file tree
Hide file tree
Showing 5 changed files with 258 additions and 55 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ FLAGS
-o, --output=<value> The file path where the command writes the model response.
-p, --prompt=<value> (required) The input prompt for the model.
-r, --remote=<value> git remote of app to use
--browser=<value> browser to open images with (example: "firefox", "safari")
--browser=<value> browser to open URLs with (example: "firefox", "safari")
--optfile=<value> Additional options for model inference, provided as a JSON config file.
--opts=<value> Additional options for model inference, provided as a JSON string.
Expand Down
35 changes: 31 additions & 4 deletions src/commands/ai/models/call.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ import color from '@heroku-cli/color'
import {flags} from '@heroku-cli/command'
import {Args, ux} from '@oclif/core'
import fs from 'node:fs'
import {ChatCompletionResponse, ImageResponse, ModelList} from '../../../lib/ai/types'
import {ChatCompletionResponse, EmbeddingResponse, ImageResponse, ModelList} from '../../../lib/ai/types'
import Command from '../../../lib/base'
import {openUrl} from '../../../lib/open-url'

Expand All @@ -27,7 +27,7 @@ export default class Call extends Command {
// description: 'Use interactive mode for conversation beyond the initial prompt (not available on all models)',
// default: false,
// }),
browser: flags.string({description: 'browser to open images with (example: "firefox", "safari")'}),
browser: flags.string({description: 'browser to open URLs with (example: "firefox", "safari")'}),
json: flags.boolean({char: 'j', description: 'Output response as JSON'}),
optfile: flags.string({
description: 'Additional options for model inference, provided as a JSON config file.',
Expand Down Expand Up @@ -67,8 +67,11 @@ export default class Call extends Command {
const modelType = availableModels.find(m => m.model_id === this.apiModelId)?.type[0]

switch (modelType) {
case 'Embedding':
case 'Embedding': {
const embedding = await this.createEmbedding(prompt, options)
await this.displayEmbedding(embedding, output, json)
break
}

case 'Text-to-Image': {
const image = await this.generateImage(prompt, options)
Expand Down Expand Up @@ -193,6 +196,30 @@ export default class Call extends Command {
return
}

ux.error('Unexpected response format')
// This should never happen, but we'll handle it anyway
ux.error('Unexpected response format', {exit: 1})
}

/**
 * Request an embedding for the given input from the `/v1/embeddings` endpoint.
 *
 * @param input   The text to embed.
 * @param options Extra inference options merged into the request body.
 *                Spread first so the explicit `model`/`input` fields below
 *                always win over caller-supplied overrides.
 * @returns The parsed embedding response body.
 */
private async createEmbedding(input: string, options = {}) {
  // Renamed local from 'EmbeddingResponse' (which shadowed the type of the
  // same name in the generic) to camelCase per convention.
  const {body: embeddingResponse} = await this.herokuAI.post<EmbeddingResponse>('/v1/embeddings', {
    body: {
      ...options,
      model: this.apiModelId,
      input,
    },
    headers: {authorization: `Bearer ${this.apiKey}`},
  })

  return embeddingResponse
}

/**
 * Display or persist an embedding response.
 *
 * @param embedding The embedding response returned by the API.
 * @param output    Optional file path; when set, the result is written there
 *                  instead of being printed.
 * @param json      When true, emit the full response as pretty-printed JSON;
 *                  otherwise emit just the comma-separated vector.
 */
private async displayEmbedding(embedding: EmbeddingResponse, output?: string, json = false) {
  // Guard the whole chain: the original '|| []' implies the field may be
  // absent, but indexing data[0] was still unguarded and would throw on an
  // empty data array. '.join(',')' matches Array.prototype.toString exactly.
  const content = (embedding.data?.[0]?.embeddings ?? []).join(',')

  if (output) {
    fs.writeFileSync(output, json ? JSON.stringify(embedding, null, 2) : content)
  } else if (json) {
    ux.styledJSON(embedding)
  } else {
    ux.log(content)
  }
}
}
38 changes: 33 additions & 5 deletions src/lib/ai/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -143,11 +143,6 @@ export type ChatCompletionResponse = {
readonly prompt_tokens: number
/** Total number of tokens used in the request (prompt + completion) */
readonly total_tokens: number
/** Breakdown of tokens used in a completion */
readonly completion_tokens_details?: {
/** Tokens generated by the model for reasoning */
readonly reasoning_tokens: number
} | null
}
}

Expand All @@ -172,3 +167,36 @@ export type ImageResponse = {
/** A list of images */
readonly data: Array<Image>
}

/**
 * Embedding schema: a single embedding entry within an
 * EmbeddingResponse's `data` list.
 */
export type Embedding = {
  /** The index of the embedding in the list of embeddings */
  readonly index: number
  /** The embedding vector, which is a list of floats */
  readonly embeddings: Array<number>
  /** The object type, which is always "embeddings" */
  readonly object: 'embeddings'
}

/**
 * Embedding response schema, as returned by the `/v1/embeddings` endpoint.
 */
export type EmbeddingResponse = {
  /** The object type, which is always "list" */
  readonly object: 'list'
  /** The list of Embedding objects */
  readonly data: Array<Embedding>
  /** The model used to generate the embeddings */
  readonly model: ModelName
  /**
   * Usage statistics for embeddings generation.
   * NOTE(review): field names and comments mirror the chat-completion usage
   * block; "completion" semantics for an embeddings request are unclear —
   * confirm against the API documentation.
   */
  readonly usage: {
    /** Number of tokens in the generated embeddings */
    readonly completion_tokens: number
    /** Number of tokens in the prompt */
    readonly prompt_tokens: number
    /** Total number of tokens used in the request (prompt + completion) */
    readonly total_tokens: number
  }
}
Loading

0 comments on commit 2df6d15

Please sign in to comment.