Implementation of 'ai:models:call' for chat completions
sbosio committed Sep 27, 2024
1 parent a3a9e01 commit 90ff2d2
Showing 7 changed files with 600 additions and 27 deletions.
33 changes: 33 additions & 0 deletions README.md
@@ -28,6 +28,7 @@ USAGE
* [`heroku ai:docs`](#heroku-aidocs)
* [`heroku ai:models`](#heroku-aimodels)
* [`heroku ai:models:attach MODEL_RESOURCE`](#heroku-aimodelsattach-model_resource)
* [`heroku ai:models:call MODEL_RESOURCE`](#heroku-aimodelscall-model_resource)
* [`heroku ai:models:create MODEL_NAME`](#heroku-aimodelscreate-model_name)
* [`heroku ai:models:list`](#heroku-aimodelslist)

@@ -94,6 +95,38 @@ EXAMPLES

_See code: [dist/commands/ai/models/attach.ts](https://github.com/heroku/heroku-cli-plugin-integration/blob/v0.0.0/dist/commands/ai/models/attach.ts)_

## `heroku ai:models:call MODEL_RESOURCE`

make an inference request to a specific AI model resource

```
USAGE
$ heroku ai:models:call [MODEL_RESOURCE] -a <value> -p <value> [-j] [--optfile <value>] [--opts <value>] [-o
<value>] [-r <value>]

ARGUMENTS
  MODEL_RESOURCE  The resource ID or alias of the model to call.

FLAGS
  -a, --app=<value>      (required) app to run command against
  -j, --json             Output response as JSON
  -o, --output=<value>   The file path where the command writes the model response.
  -p, --prompt=<value>   (required) The input prompt for the model.
  -r, --remote=<value>   git remote of app to use
      --optfile=<value>  Additional options for model inference, provided as a JSON config file.
      --opts=<value>     Additional options for model inference, provided as a JSON string.

DESCRIPTION
  make an inference request to a specific AI model resource

EXAMPLES
$ heroku ai:models:call my_llm --prompt "What is the meaning of life?"
$ heroku ai:models:call sdxl --prompt "Generate an image of a sunset" --opts '{"quality": "hd"}'
```

_See code: [dist/commands/ai/models/call.ts](https://github.com/heroku/heroku-cli-plugin-integration/blob/v0.0.0/dist/commands/ai/models/call.ts)_
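
The `--optfile` and `--opts` flags both feed extra inference options into the request body, and when the same key appears in both, the `--opts` value wins because it is applied last. A minimal sketch of using them together (the app name, file name, and option names such as `temperature` and `max_tokens` are illustrative, not a confirmed list of supported options):

```
# options.json (hypothetical contents)
{"temperature": 0.5, "max_tokens": 256}

$ heroku ai:models:call my_llm -a my-app --prompt "Summarize this release" --optfile options.json --opts '{"temperature": 0.9}'
```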

## `heroku ai:models:create MODEL_NAME`

provision access to an AI model
158 changes: 158 additions & 0 deletions src/commands/ai/models/call.ts
@@ -0,0 +1,158 @@
import color from '@heroku-cli/color'
import {flags} from '@heroku-cli/command'
import {Args, ux} from '@oclif/core'
import fs from 'node:fs'
// import path from 'node:path'
import {ChatCompletionResponse, ModelList} from '../../../lib/ai/types'
import Command from '../../../lib/base'

export default class Call extends Command {
static args = {
model_resource: Args.string({
description: 'The resource ID or alias of the model to call.',
required: true,
}),
}

static description = 'make an inference request to a specific AI model resource'
static examples = [
'heroku ai:models:call my_llm --prompt "What is the meaning of life?"',
'heroku ai:models:call sdxl --prompt "Generate an image of a sunset" --opts \'{"quality": "hd"}\'',
]

static flags = {
app: flags.app({required: true}),
// interactive: flags.boolean({
// char: 'i',
// description: 'Use interactive mode for conversation beyond the initial prompt (not available on all models)',
// default: false,
// }),
json: flags.boolean({char: 'j', description: 'Output response as JSON'}),
optfile: flags.string({
description: 'Additional options for model inference, provided as a JSON config file.',
required: false,
}),
opts: flags.string({
description: 'Additional options for model inference, provided as a JSON string.',
required: false,
}),
output: flags.string({
char: 'o',
// description: 'The file path where the command writes the model response. If used with --interactive, this flag writes the entire exchange when the session closes.',
description: 'The file path where the command writes the model response.',
required: false,
}),
prompt: flags.string({
char: 'p',
description: 'The input prompt for the model.',
required: true,
}),
remote: flags.remote(),
}

public async run(): Promise<void> {
const {args, flags} = await this.parse(Call)
const {model_resource: modelResource} = args
const {app, json, optfile, opts, output, prompt} = flags

// Initially, configure the default client to fetch the available model classes
await this.configureHerokuAIClient()
const {body: availableModels} = await this.herokuAI.get<ModelList>('/available-models')

// Now, configure the client to send a request for the target model resource
await this.configureHerokuAIClient(modelResource, app)
const options = this.parseOptions(optfile, opts)
    // `type` is an array in ModelListItem; we use the type from the first entry.
const modelType = availableModels.find(m => m.model_id === this.apiModelId)?.type[0]

switch (modelType) {
case 'Embedding':
break

case 'Text-to-Image':
break

case 'Text-to-Text': {
const completion = await this.createChatCompletion(prompt, options)
this.displayChatCompletion(completion, output, json)
break
}

default:
throw new Error(`Unsupported model type: ${modelType}`)
}
}

/**
* Parse the model call request options from the command flags.
*
* @param optfile Path to a JSON file containing options.
* @param opts JSON string containing options.
* @returns The parsed options as an object.
*/
private parseOptions(optfile?: string, opts?: string) {
const options = {}

if (optfile) {
const optfileContents = fs.readFileSync(optfile)

try {
Object.assign(options, JSON.parse(optfileContents.toString()))
} catch (error: unknown) {
if (error instanceof SyntaxError) {
const {message} = error as SyntaxError
return ux.error(
`Invalid JSON in ${color.yellow(optfile)}. Check the formatting in your file.\n${message}`,
{exit: 1},
)
}

throw error
}
}

if (opts) {
try {
Object.assign(options, JSON.parse(opts))
} catch (error: unknown) {
if (error instanceof SyntaxError) {
const {message} = error as SyntaxError
return ux.error(
`Invalid JSON. Check the formatting in your ${color.yellow('--opts')} value.\n${message}`,
{exit: 1},
)
}

throw error
}
}

return options
}

private async createChatCompletion(prompt: string, options = {}) {
const {body: chatCompletionResponse} = await this.herokuAI.post<ChatCompletionResponse>('/v1/chat/completions', {
body: {
...options,
model: this.apiModelId,
messages: [{
role: 'user',
content: prompt,
}],
},
headers: {authorization: `Bearer ${this.apiKey}`},
})

return chatCompletionResponse
}

private displayChatCompletion(completion: ChatCompletionResponse, output?: string, json = false) {
const content = json ? JSON.stringify(completion, null, 2) : completion.choices[0].message.content || ''

if (output) {
fs.writeFileSync(output, content)
} else {
json ? ux.styledJSON(completion) : ux.log(content)
}
}
}
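
Worth noting from `parseOptions` above: `--optfile` contents are merged into the options object first and `--opts` second, so a key given on the command line overrides the same key from the file. A small standalone sketch of that merge order (the option names and values are hypothetical):

```ts
// Later Object.assign calls overwrite keys set by earlier ones,
// mirroring the optfile-then-opts order used in parseOptions.
const options = {}
Object.assign(options, {temperature: 0.2, max_tokens: 256}) // pretend --optfile contents
Object.assign(options, {temperature: 0.9})                  // pretend --opts value

console.log(options) // { temperature: 0.9, max_tokens: 256 }
```
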
105 changes: 101 additions & 4 deletions src/lib/ai/types.ts
@@ -14,13 +14,13 @@ export type ModelName =
'cohere-embed-english' |
'cohere-embed-multilingual'

export type ModelType = 'Text to Text' | 'Embedding'
export type ModelType = 'Text-to-Image' | 'Text-to-Text' | 'Embedding'

/**
* Object schema for each collection item returned by the Model List endpoint.
*/
export type ModelListItem = {
name: ModelName
model_id: ModelName
type: Array<ModelType>
}

@@ -50,6 +50,103 @@ export type ModelResource = {
}

/**
 * OpenAI compatible response schemas for model calls
*/

/**
* Tool call schema
*/
export type ToolCall = {
/** The ID of the tool call. Currently, only function is supported */
id: string
/** The type of the tool call */
type: string
/** The function that the model called */
function: {
/** The name of the function to call */
name: string
/** The arguments to call the function with, as generated by the model in JSON format */
arguments: string
}
}

/**
* Log probability token schema
*/
export type LogProbToken = {
/** The token */
token: string
/** The log probability of this token */
logprob: number
/** The encoded bytes representing the token */
bytes: Array<number> | null
}

/**
* Log probability schema
*/
export type LogProb = LogProbToken & {
/** List of the most likely tokens and their log probability */
top_logprobs: Array<LogProbToken> | null
}

/**
* Chat completion choice schema
*/
export type ChatCompletionChoice = {
/** The reason the model stopped generating tokens */
readonly finish_reason: 'stop' | 'length' | 'content_filter' | 'tool_calls'
/** The index of the choice in the list of choices */
readonly index: number
/** A chat completion message generated by the model */
readonly message: {
/** The contents of the message */
readonly content: string | null
/** The refusal message generated by the model */
readonly refusal: string | null
readonly tool_calls?: Array<ToolCall> | null
/** The role of the author of this message */
readonly role: string
}
/** Log probability information for the choice */
readonly logprobs?: {
/** A list of message content tokens with log probability information */
content: Array<LogProb> | null
/** A list of message refusal tokens with log probability information */
refusal: Array<LogProb> | null
} | null
}

/**
* Chat completion response schema.
*/
export type ChatCompletionResponse = {
/** A unique identifier for the chat completion */
readonly id: string
/** A list of chat completion choices. Can be more than one if n is greater than 1 */
readonly choices: Array<ChatCompletionChoice>
/** The Unix timestamp (in seconds) of when the chat completion was created */
readonly created: number
/** The model used for the chat completion */
readonly model: ModelName
/** The service tier used for processing the request */
readonly service_tier?: string | null
/** This fingerprint represents the backend configuration that the model runs with */
readonly system_fingerprint: string
/** The object type, which is always chat.completion */
readonly object: string
/** Usage statistics for the completion request */
readonly usage: {
/** Number of tokens in the generated completion */
readonly completion_tokens: number
/** Number of tokens in the prompt */
readonly prompt_tokens: number
/** Total number of tokens used in the request (prompt + completion) */
readonly total_tokens: number
/** Breakdown of tokens used in a completion */
readonly completion_tokens_details?: {
/** Tokens generated by the model for reasoning */
readonly reasoning_tokens: number
} | null
}
}
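
To make the shape above concrete, here is a hypothetical value that satisfies `ChatCompletionResponse` (identifiers, timestamps, token counts, and the import path are invented for illustration; the model name is widened because the full `ModelName` union is not visible in this diff):

```ts
import {ChatCompletionResponse, ModelName} from './types' // assumed relative import path

// The ModelName union isn't fully shown in this diff, so widen the invented
// name to string before narrowing it back for the example.
const modelName = 'claude-3-5-sonnet' as string

const example: ChatCompletionResponse = {
  id: 'chatcmpl-123',               // invented identifier
  object: 'chat.completion',
  created: 1727395200,              // Unix seconds, illustrative
  model: modelName as ModelName,
  system_fingerprint: 'fp_example', // invented fingerprint
  choices: [{
    index: 0,
    finish_reason: 'stop',
    message: {role: 'assistant', content: '42.', refusal: null},
  }],
  usage: {prompt_tokens: 9, completion_tokens: 3, total_tokens: 12},
}
```
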
3 changes: 1 addition & 2 deletions src/lib/base.ts
@@ -3,7 +3,6 @@ import {APIClient, Command} from '@heroku-cli/command'
import * as Heroku from '@heroku-cli/schema'
import {ux} from '@oclif/core'
import heredoc from 'tsheredoc'
// import {inspect} from 'node:util'
import {HerokuAPIError} from '@heroku-cli/command/lib/api-client'

export class NotFound extends Error {
@@ -48,7 +47,7 @@ export default abstract class extends Command {
const defaultHeaders = {
...this.heroku.defaults.headers,
accept: 'application/json',
'user-agent': `heroku-cli-plugin-ai/${this.config.version} ${this.config.platform}`,
'user-agent': `heroku-cli-plugin-ai/${process.env.npm_package_version} ${this.config.platform}`,
}
delete defaultHeaders.authorization

