Skip to content

Commit

Permalink
xAI/Grok Support (#55)
Browse files Browse the repository at this point in the history
  • Loading branch information
pushpak1300 authored Nov 10, 2024
1 parent 6ad1789 commit a594bcc
Show file tree
Hide file tree
Showing 20 changed files with 815 additions and 0 deletions.
4 changes: 4 additions & 0 deletions config/prism.php
Original file line number Diff line number Diff line change
Expand Up @@ -25,5 +25,9 @@
'api_key' => env('GROQ_API_KEY', ''),
'url' => env('GROQ_URL', 'https://api.groq.com/openai/v1'),
],
'xai' => [
'api_key' => env('XAI_API_KEY', ''),
'url' => env('XAI_URL', 'https://api.x.ai/v1'),
],
],
];
4 changes: 4 additions & 0 deletions docs/.vitepress/config.mts
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,10 @@ export default defineConfig({
text: "OpenAI",
link: "/providers/openai",
},
{
text: "XAI",
link: "/providers/xai",
},
],
},
{
Expand Down
1 change: 1 addition & 0 deletions docs/getting-started/introduction.md
Original file line number Diff line number Diff line change
Expand Up @@ -65,5 +65,6 @@ We currently offer first-party support for these leading AI providers:
- [Mistral](https://mistral.ai)
- [Ollama](https://ollama.com)
- [OpenAI](https://openai.com)
- [xAI](https://x.ai/)

Each provider brings its own strengths to the table, and Prism makes it easy to use them all through a consistent, elegajnt interface.
14 changes: 14 additions & 0 deletions docs/providers/xai.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
# xAI
## Configuration

```php
'xai' => [
'api_key' => env('XAI_API_KEY', ''),
'url' => env('XAI_URL', 'https://api.x.ai/v1'),
],
```

## Limitations
### Image Support

xAI does not support image inputs.
1 change: 1 addition & 0 deletions src/Enums/Provider.php
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,5 @@ enum Provider: string
case OpenAI = 'openai';
case Mistral = 'mistral';
case Groq = 'groq';
case XAI = 'xai';
}
12 changes: 12 additions & 0 deletions src/PrismManager.php
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
use EchoLabs\Prism\Providers\Mistral\Mistral;
use EchoLabs\Prism\Providers\Ollama\Ollama;
use EchoLabs\Prism\Providers\OpenAI\OpenAI;
use EchoLabs\Prism\Providers\XAI\XAI;
use Illuminate\Contracts\Foundation\Application;
use InvalidArgumentException;
use RuntimeException;
Expand Down Expand Up @@ -147,4 +148,15 @@ protected function createGroqProvider(array $config): Groq
apiKey: $config['api_key'],
);
}

/**
* @param array<string, string> $config
*/
protected function createXaiProvider(array $config): XAI
{
return new XAI(
url: $config['url'],
apiKey: $config['api_key'],
);
}
}
58 changes: 58 additions & 0 deletions src/Providers/XAI/Client.php
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
<?php

declare(strict_types=1);

namespace EchoLabs\Prism\Providers\XAI;

use Illuminate\Http\Client\PendingRequest;
use Illuminate\Http\Client\Response;
use Illuminate\Support\Facades\Http;

class Client
{
protected PendingRequest $client;

/**
* @param array<string, mixed> $options
*/
public function __construct(
public readonly string $url,
public readonly string $apiKey,
public readonly array $options = [],
) {
$this->client = Http::withHeaders(array_filter([
'Authorization' => sprintf('Bearer %s', $this->apiKey),
]))
->withOptions($this->options)
->baseUrl($this->url);
}

/**
* @param array<int, mixed> $messages
* @param array<int, mixed>|null $tools
* @param array<string, mixed>|string|null $toolChoice
*/
public function messages(
string $model,
array $messages,
?int $maxTokens,
int|float|null $temperature,
int|float|null $topP,
?array $tools,
string|array|null $toolChoice,
): Response {
return $this->client->post(
'chat/completions',
array_merge([
'model' => $model,
'messages' => $messages,
'max_tokens' => $maxTokens ?? 2048,
], array_filter([
'temperature' => $temperature,
'top_p' => $topP,
'tools' => $tools,
'tool_choice' => $toolChoice,
]))
);
}
}
105 changes: 105 additions & 0 deletions src/Providers/XAI/MessageMap.php
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
<?php

declare(strict_types=1);

namespace EchoLabs\Prism\Providers\XAI;

use EchoLabs\Prism\Contracts\Message;
use EchoLabs\Prism\ValueObjects\Messages\AssistantMessage;
use EchoLabs\Prism\ValueObjects\Messages\SystemMessage;
use EchoLabs\Prism\ValueObjects\Messages\ToolResultMessage;
use EchoLabs\Prism\ValueObjects\Messages\UserMessage;
use EchoLabs\Prism\ValueObjects\ToolCall;
use Exception;

class MessageMap
{
/** @var array<int, mixed> */
protected $mappedMessages = [];

/**
* @param array<int, Message> $messages
*/
public function __construct(
protected array $messages,
protected string $systemPrompt
) {
if ($systemPrompt !== '' && $systemPrompt !== '0') {
$this->messages = array_merge(
[new SystemMessage($systemPrompt)],
$this->messages
);
}
}

/**
* @return array<int, mixed>
*/
public function __invoke(): array
{
array_map(
fn (Message $message) => $this->mapMessage($message),
$this->messages
);

return $this->mappedMessages;
}

public function mapMessage(Message $message): void
{
match ($message::class) {
UserMessage::class => $this->mapUserMessage($message),
AssistantMessage::class => $this->mapAssistantMessage($message),
ToolResultMessage::class => $this->mapToolResultMessage($message),
SystemMessage::class => $this->mapSystemMessage($message),
default => throw new Exception('Could not map message type '.$message::class),
};
}

protected function mapSystemMessage(SystemMessage $message): void
{
$this->mappedMessages[] = [
'role' => 'system',
'content' => $message->content,
];
}

protected function mapToolResultMessage(ToolResultMessage $message): void
{
foreach ($message->toolResults as $toolResult) {
$this->mappedMessages[] = [
'role' => 'tool',
'tool_call_id' => $toolResult->toolCallId,
'content' => $toolResult->result,
];
}
}

protected function mapUserMessage(UserMessage $message): void
{
$this->mappedMessages[] = [
'role' => 'user',
'content' => [
['type' => 'text', 'text' => $message->text()],
],
];
}

protected function mapAssistantMessage(AssistantMessage $message): void
{
$toolCalls = array_map(fn (ToolCall $toolCall): array => [
'id' => $toolCall->id,
'type' => 'function',
'function' => [
'name' => $toolCall->name,
'arguments' => json_encode($toolCall->arguments()),
],
], $message->toolCalls);

$this->mappedMessages[] = array_filter([
'role' => 'assistant',
'content' => $message->content,
'tool_calls' => $toolCalls,
]);
}
}
34 changes: 34 additions & 0 deletions src/Providers/XAI/Tool.php
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
<?php

declare(strict_types=1);

namespace EchoLabs\Prism\Providers\XAI;

use EchoLabs\Prism\Providers\ProviderTool;
use EchoLabs\Prism\Tool as PrismTool;

class Tool extends ProviderTool
{
#[\Override]
public static function toArray(PrismTool $tool): array
{
return [
'type' => 'function',
'function' => [
'name' => $tool->name(),
'description' => $tool->description(),
'parameters' => [
'type' => 'object',
'properties' => collect($tool->parameters())
->keyBy('name')
->map(fn (array $field): array => [
'description' => $field['description'],
'type' => $field['type'],
])
->toArray(),
'required' => $tool->requiredParameters(),
],
],
];
}
}
Loading

0 comments on commit a594bcc

Please sign in to comment.