Skip to content

Commit

Permalink
Implemented 'ai:models:call' for image generation
Browse files Browse the repository at this point in the history
  • Loading branch information
sbosio committed Sep 27, 2024
1 parent 90ff2d2 commit e4c2faa
Show file tree
Hide file tree
Showing 9 changed files with 583 additions and 162 deletions.
11 changes: 6 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -101,28 +101,29 @@ make an inference request to a specific AI model resource

```
USAGE
$ heroku ai:models:call [MODEL_RESOURCE] -a <value> -p <value> [-j] [--optfile <value>] [--opts <value>] [-o
<value>] [-r <value>]
$ heroku ai:models:call [MODEL_RESOURCE] -p <value> [-a <value>] [--browser <value>] [-j] [--optfile <value>]
[--opts <value>] [-o <value>] [-r <value>]
ARGUMENTS
MODEL_RESOURCE The resource ID or alias of the model to call.
FLAGS
-a, --app=<value> (required) app to run command against
-a, --app=<value> app to run command against
-j, --json Output response as JSON
-o, --output=<value> The file path where the command writes the model response.
-p, --prompt=<value> (required) The input prompt for the model.
-r, --remote=<value> git remote of app to use
--browser=<value> browser to open images with (example: "firefox", "safari")
--optfile=<value> Additional options for model inference, provided as a JSON config file.
--opts=<value> Additional options for model inference, provided as a JSON string.
DESCRIPTION
make an inference request to a specific AI model resource
EXAMPLES
$ heroku ai:models:call my_llm --prompt "What is the meaning of life?"
$ heroku ai:models:call my_llm --app my-app --prompt "What is the meaning of life?"
$ heroku ai:models:call sdxl --prompt "Generate an image of a sunset" --opts '{"quality": "hd"}'
$ heroku ai:models:call sdxl --app my-app --prompt "Generate an image of a sunset" --opts '{"quality":"hd"}' -o sunset.png
```

_See code: [dist/commands/ai/models/call.ts](https://github.com/heroku/heroku-cli-plugin-integration/blob/v0.0.0/dist/commands/ai/models/call.ts)_
Expand Down
35 changes: 2 additions & 33 deletions src/commands/ai/docs.ts
Original file line number Diff line number Diff line change
@@ -1,8 +1,5 @@
import color from '@heroku-cli/color'
import {flags} from '@heroku-cli/command'
import {ux} from '@oclif/core'
import {CLIError} from '@oclif/core/lib/errors'
import open from 'open'
import {openUrl} from '../../lib/open-url'
import Command from '../../lib/base'

export default class Docs extends Command {
Expand All @@ -12,39 +9,11 @@ export default class Docs extends Command {
browser: flags.string({description: 'browser to open docs with (example: "firefox", "safari")'}),
}

static urlOpener: (...args: Parameters<typeof open>) => ReturnType<typeof open> = open

public async run(): Promise<void> {
const {flags} = await this.parse(Docs)
const browser = flags.browser
const url = process.env.HEROKU_AI_DOCS_URL || Docs.defaultUrl

let browserErrorShown = false
const showBrowserError = (browser?: string) => {
if (browserErrorShown) return

ux.warn(`Unable to open ${browser ? browser : 'your default'} browser. Please visit ${color.cyan(url)} to view the documentation.`)
browserErrorShown = true
}

ux.log(`Opening ${color.cyan(url)} in ${browser ? browser : 'your default'} browser…`)

try {
await ux.anykey(
`Press any key to open up the browser to show Heroku AI documentation, or ${color.yellow('q')} to exit`
)
} catch (error) {
const {message, oclif} = error as CLIError
ux.error(message, {exit: oclif?.exit || 1})
}

const cp = await Docs.urlOpener(url, {wait: false, ...(browser ? {app: {name: browser}} : {})})
cp.on('error', (err: Error) => {
ux.warn(err)
showBrowserError(browser)
})
cp.on('close', (code: number) => {
if (code !== 0) showBrowserError(browser)
})
await openUrl(url, browser, 'view the documentation')
}
}
62 changes: 51 additions & 11 deletions src/commands/ai/models/call.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@ import color from '@heroku-cli/color'
import {flags} from '@heroku-cli/command'
import {Args, ux} from '@oclif/core'
import fs from 'node:fs'
// import path from 'node:path'
import {ChatCompletionResponse, ModelList} from '../../../lib/ai/types'
import {ChatCompletionResponse, ImageResponse, ModelList} from '../../../lib/ai/types'
import Command from '../../../lib/base'
import {openUrl} from '../../../lib/open-url'

export default class Call extends Command {
static args = {
Expand All @@ -16,17 +16,18 @@ export default class Call extends Command {

static description = 'make an inference request to a specific AI model resource'
static examples = [
'heroku ai:models:call my_llm --prompt "What is the meaning of life?"',
'heroku ai:models:call sdxl --prompt "Generate an image of a sunset" --opts \'{"quality": "hd"}\'',
'heroku ai:models:call my_llm --app my-app --prompt "What is the meaning of life?"',
'heroku ai:models:call sdxl --app my-app --prompt "Generate an image of a sunset" --opts \'{"quality":"hd"}\' -o sunset.png',
]

static flags = {
app: flags.app({required: true}),
app: flags.app({required: false}),
// interactive: flags.boolean({
// char: 'i',
// description: 'Use interactive mode for conversation beyond the initial prompt (not available on all models)',
// default: false,
// }),
browser: flags.string({description: 'browser to open images with (example: "firefox", "safari")'}),
json: flags.boolean({char: 'j', description: 'Output response as JSON'}),
optfile: flags.string({
description: 'Additional options for model inference, provided as a JSON config file.',
Expand All @@ -53,7 +54,7 @@ export default class Call extends Command {
public async run(): Promise<void> {
const {args, flags} = await this.parse(Call)
const {model_resource: modelResource} = args
const {app, json, optfile, opts, output, prompt} = flags
const {app, browser, json, optfile, opts, output, prompt} = flags

// Initially, configure the default client to fetch the available model classes
await this.configureHerokuAIClient()
Expand All @@ -69,12 +70,15 @@ export default class Call extends Command {
case 'Embedding':
break

case 'Text-to-Image':
case 'Text-to-Image': {
const image = await this.generateImage(prompt, options)
await this.displayImageResult(image, output, browser, json)
break
}

case 'Text-to-Text': {
const completion = await this.createChatCompletion(prompt, options)
this.displayChatCompletion(completion, output, json)
await this.displayChatCompletion(completion, output, json)
break
}

Expand Down Expand Up @@ -146,13 +150,49 @@ export default class Call extends Command {
return chatCompletionResponse
}

private displayChatCompletion(completion: ChatCompletionResponse, output?: string, json = false) {
const content = json ? JSON.stringify(completion, null, 2) : completion.choices[0].message.content || ''
private async displayChatCompletion(completion: ChatCompletionResponse, output?: string, json = false) {
const content = completion.choices[0].message.content || ''

if (output) {
fs.writeFileSync(output, content)
fs.writeFileSync(output, json ? JSON.stringify(completion, null, 2) : content)
} else {
json ? ux.styledJSON(completion) : ux.log(content)
}
}

/**
 * Request an image generation from the configured model resource.
 *
 * @param prompt The input prompt to generate an image for.
 * @param options Additional inference options merged into the request body.
 * @returns The raw image generation response from the service.
 */
private async generateImage(prompt: string, options = {}) {
  // Caller-supplied options come first so the resolved model and prompt
  // always take precedence over anything passed via --opts/--optfile.
  const requestBody = {
    ...options,
    model: this.apiModelId,
    prompt,
  }

  const response = await this.herokuAI.post<ImageResponse>('/v1/images/generations', {
    body: requestBody,
    headers: {authorization: `Bearer ${this.apiKey}`},
  })

  return response.body
}

/**
 * Render an image generation result to a file, stdout, or the browser.
 *
 * Precedence: base64 payload (`b64_json`) is handled before `url`; within
 * each, `--output` wins over `--json`, which wins over the default display.
 *
 * @param image The image generation response to display.
 * @param output Optional file path to write the result to.
 * @param browser Optional browser name used when opening a result URL.
 * @param json When true, emit the full response as JSON instead of the image payload.
 */
private async displayImageResult(image: ImageResponse, output?: string, browser?: string, json = false) {
  // Guard against an empty `data` array: fall through to the error below
  // instead of throwing a TypeError on `image.data[0]`.
  const [firstImage] = image.data ?? []

  if (firstImage?.b64_json) {
    if (output) {
      const content = json ? JSON.stringify(image, null, 2) : Buffer.from(firstImage.b64_json, 'base64')
      fs.writeFileSync(output, content)
    } else if (json) {
      ux.styledJSON(image)
    } else {
      // Raw base64 on stdout so the result can be piped/redirected.
      process.stdout.write(firstImage.b64_json)
    }

    return
  }

  if (firstImage?.url) {
    if (output) {
      fs.writeFileSync(output, json ? JSON.stringify(image, null, 2) : firstImage.url)
    } else if (json) {
      ux.styledJSON(image)
    } else {
      await openUrl(firstImage.url, browser, 'view the image')
    }

    return
  }

  // Neither b64_json nor url present (or no image data at all).
  ux.error('Unexpected response format')
}
}
22 changes: 22 additions & 0 deletions src/lib/ai/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -150,3 +150,25 @@ export type ChatCompletionResponse = {
} | null
}
}

/**
* Image schema
*/
export type Image = {
  /** Base64-encoded image data, present when 'response_format' is 'b64_json' (decoded with Buffer.from(…, 'base64') by consumers) */
  readonly b64_json?: string | null
  /** The prompt that was used to generate the image, if there was any revision to the prompt */
  // NOTE(review): declared required, but the comment implies it may be absent — confirm whether this should be optional
  readonly revised_prompt: string
  /** The URL of the generated image, if 'response_format' is 'url' (default) */
  readonly url?: string | null
}

/**
* Image response schema.
*/
export type ImageResponse = {
  /** The Unix timestamp (in seconds) of when the image was generated */
  readonly created: number
  /** The generated images; consumers currently read only the first element */
  readonly data: Array<Image>
}
36 changes: 36 additions & 0 deletions src/lib/open-url.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
import color from '@heroku-cli/color'
import {ux} from '@oclif/core'
import {CLIError} from '@oclif/core/lib/errors'
import open from 'open'

// Indirection over `open` — NOTE(review): presumably exported so tests can stub the URL opener; confirm against the test suite.
export const urlOpener: (...args: Parameters<typeof open>) => ReturnType<typeof open> = open

/**
 * Prompt the user, then open the given URL in a browser.
 *
 * Warns (at most once) with a fallback link if the browser process
 * fails to spawn or exits non-zero.
 *
 * @param url The URL to open.
 * @param browser Optional browser name (example: "firefox", "safari"); defaults to the system browser.
 * @param action Optional description of what visiting the URL accomplishes, used in the messages.
 */
export async function openUrl(url: string, browser?: string, action?: string) {
  let warned = false

  // Emit the fallback warning at most once, whether triggered by an
  // 'error' event, a non-zero exit, or both.
  const warnOnce = (targetBrowser?: string) => {
    if (warned) return

    warned = true
    ux.warn(`Unable to open ${targetBrowser ? targetBrowser : 'your default'} browser. Please visit ${color.cyan(url)}${action ? ` to ${action}` : ''}.`)
  }

  ux.log(`Opening ${color.cyan(url)} in ${browser ? browser : 'your default'} browser…`)

  try {
    await ux.anykey(
      `Press any key to open up the browser${action ? ` to ${action}` : ''}, or ${color.yellow('q')} to exit`
    )
  } catch (error) {
    // 'q' (or ctrl-c) rejects; surface the message and exit with the oclif code.
    const {message, oclif} = error as CLIError
    ux.error(message, {exit: oclif?.exit || 1})
  }

  const child = await urlOpener(url, {wait: false, ...(browser ? {app: {name: browser}} : {})})

  child.on('error', (spawnError: Error) => {
    ux.warn(spawnError)
    warnOnce(browser)
  })

  child.on('close', (exitCode: number) => {
    if (exitCode !== 0) warnOnce(browser)
  })
}
Loading

0 comments on commit e4c2faa

Please sign in to comment.