Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable google search for gemini-1.5-pro and gemini-2.0-flash-exp #5063

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
112 changes: 94 additions & 18 deletions api/app/clients/GoogleClient.js
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ const { ChatVertexAI } = require('@langchain/google-vertexai');
const { GoogleVertexAI } = require('@langchain/google-vertexai');
const { ChatGoogleVertexAI } = require('@langchain/google-vertexai');
const { ChatGoogleGenerativeAI } = require('@langchain/google-genai');
const { GoogleGenerativeAI: GenAI } = require('@google/generative-ai');
const { GoogleGenerativeAI: GenAI, DynamicRetrievalMode } = require('@google/generative-ai');
const { AIMessage, HumanMessage, SystemMessage } = require('@langchain/core/messages');
const { encoding_for_model: encodingForModel, get_encoding: getEncoding } = require('tiktoken');
const {
Expand All @@ -27,15 +27,18 @@ const {
truncateText,
} = require('./prompts');
const BaseClient = require('./BaseClient');

// Vertex AI region; falls back to 'us-central1' when GOOGLE_LOC is unset.
const loc = process.env.GOOGLE_LOC || 'us-central1';
const publisher = 'google';
// Regional Vertex AI API host derived from the configured location.
const endpointPrefix = `${loc}-aiplatform.googleapis.com`;
// Module-level cache — presumably tokenizer instances keyed by encoding; confirm against getTokenizer usage.
const tokenizersCache = {};

// Default settings for the Google endpoint.
const settings = endpointSettings[EModelEndpoint.google];
// Models matching this pattern are excluded from the @google/generative-ai (GenAI) client path.
const EXCLUDED_GENAI_MODELS = /gemini-(?:1\.0|1-0|pro)/;


/**
 * Checks whether a model identifier belongs to the Gemini 2.x family,
 * which uses the newer `googleSearch` tool format rather than
 * `googleSearchRetrieval` (see the tool selection in createLLM/getClient).
 * @param {string} model - Model name, e.g. 'gemini-2.0-flash-exp'.
 * @returns {boolean} True when the name contains 'gemini-2.'.
 */
const isNewGeminiModel = (model) => model.indexOf('gemini-2.') !== -1;

class GoogleClient extends BaseClient {
constructor(credentials, options = {}) {
super('apiKey', options);
Expand Down Expand Up @@ -124,7 +127,11 @@ class GoogleClient extends BaseClient {
.filter((ex) => ex)
.filter((obj) => obj.input.content !== '' && obj.output.content !== '');

this.modelOptions = this.options.modelOptions || {};
// Set modelOptions, ensuring enableSearch is included
this.modelOptions = {
...(this.options.modelOptions || {}),
enableSearch: this.options.enableSearch,
};

this.options.attachments?.then((attachments) => this.checkVisionRequest(attachments));

Expand Down Expand Up @@ -420,9 +427,9 @@ class GoogleClient extends BaseClient {

logger.debug('[GoogleClient]', {
orderedMessages,
parentMessageId,
parentMessageId,
});

const formattedMessages = orderedMessages.map((message) => ({
author: message.isCreatedByUser ? this.userLabel : this.modelLabel,
content: message?.content ?? message.text,
Expand Down Expand Up @@ -624,7 +631,29 @@ class GoogleClient extends BaseClient {
return new ChatVertexAI(clientOptions);
} else if (!EXCLUDED_GENAI_MODELS.test(model)) {
logger.debug('Creating GenAI client');
return new GenAI(this.apiKey).getGenerativeModel({ ...clientOptions, model }, requestOptions);
const tools = [];
if (this.modelOptions.enableSearch) {
logger.debug('[GoogleClient] Adding search tool');
if (isNewGeminiModel(model)) {
tools.push({
googleSearch: {}
});
} else {
tools.push({
googleSearchRetrieval: {
dynamicRetrievalConfig: {
mode: DynamicRetrievalMode.MODE_DYNAMIC,
dynamicThreshold: 0.7,
},
},
});
}
}
return new GenAI(this.apiKey).getGenerativeModel({
...clientOptions,
model,
tools,
}, requestOptions );
}

logger.debug('Creating Chat Google Generative AI client');
Expand Down Expand Up @@ -711,17 +740,48 @@ class GoogleClient extends BaseClient {

const delay = modelName.includes('flash') ? 8 : 15;
const result = await client.generateContentStream(requestOptions);
let lastGroundingMetadata = null;

for await (const chunk of result.stream) {
const chunkText = chunk.text();
await this.generateTextStream(chunkText, onProgress, {
delay,
});
reply += chunkText;
await sleep(streamRate);
// Get the text content from the first candidate's content parts
const text = chunk.candidates?.[0]?.content?.parts?.[0]?.text ?? '';

// Store the grounding metadata from the last chunk that has it
if (chunk.candidates?.[0]?.groundingMetadata) {
lastGroundingMetadata = chunk.candidates[0].groundingMetadata;
}

// Only send text content if there is any
if (text) {
await this.generateTextStream(text, onProgress, {
delay,
metadata: lastGroundingMetadata ? { groundingMetadata: lastGroundingMetadata } : undefined
});
reply += text;
await sleep(streamRate);
}
}
return reply;
}

// Send final completion message with metadata
const finalMessage = {
text: reply,
isComplete: true,
metadata: lastGroundingMetadata ? { groundingMetadata: lastGroundingMetadata } : undefined
};

await onProgress(finalMessage);

// Set metadata for BaseClient to save
if (lastGroundingMetadata) {
this.metadata = { groundingMetadata: lastGroundingMetadata };
}

return {
text: reply,
groundingMetadata: lastGroundingMetadata
};

}
const stream = await model.stream(messages, {
signal: abortController.signal,
safetySettings: _payload.safetySettings,
Expand Down Expand Up @@ -890,11 +950,26 @@ class GoogleClient extends BaseClient {
async sendCompletion(payload, opts = {}) {
payload.safetySettings = this.getSafetySettings();

let reply = '';
reply = await this.getCompletion(payload, opts);
return reply.trim();
const response = await this.getCompletion(payload, opts);

// Handle both string and object responses
if (typeof response === 'string') {
return response.trim();
}

// If response is an object with text and metadata
if (response && typeof response === 'object') {
const { text, groundingMetadata } = response;
if (groundingMetadata) {
this.metadata = { groundingMetadata };
}
return text.trim();
}

return '';
}


getSafetySettings() {
return [
{
Expand Down Expand Up @@ -943,3 +1018,4 @@ class GoogleClient extends BaseClient {
}

module.exports = GoogleClient;

9 changes: 8 additions & 1 deletion api/app/clients/TextStream.js
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ class TextStream extends Readable {
this.minChunkSize = options.minChunkSize ?? 2;
this.maxChunkSize = options.maxChunkSize ?? 4;
this.delay = options.delay ?? 20; // Time in milliseconds
this.metadata = options.metadata;
}

_read() {
Expand All @@ -35,7 +36,13 @@ class TextStream extends Readable {
async processTextStream(onProgressCallback) {
const streamPromise = new Promise((resolve, reject) => {
this.on('data', (chunk) => {
onProgressCallback(chunk.toString());
const payload = {
text: chunk.toString(),
};
if (this.metadata) {
payload.metadata = this.metadata;
}
onProgressCallback(payload);
});

this.on('end', () => {
Expand Down
1 change: 1 addition & 0 deletions api/models/Message.js
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,7 @@ async function updateMessage(req, message, metadata) {
isCreatedByUser: updatedMessage.isCreatedByUser,
tokenCount: updatedMessage.tokenCount,
isEdited: true,
groundingMetadata: updatedMessage.groundingMetadata,
};
} catch (err) {
logger.error('Error updating message:', err);
Expand Down
5 changes: 5 additions & 0 deletions api/models/schema/defaults.js
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,11 @@ const conversationPreset = {
max_tokens: {
type: Number,
},
// for google only
enableSearch: {
type: Boolean,
required: false,
},
};

const agentOptions = {
Expand Down
1 change: 1 addition & 0 deletions api/models/schema/messageSchema.js
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@ const messageSchema = mongoose.Schema(
type: String,
},
attachments: { type: [{ type: mongoose.Schema.Types.Mixed }], default: undefined },
groundingMetadata: { type: mongoose.Schema.Types.Mixed, default: undefined },
/*
attachments: {
type: [
Expand Down
22 changes: 19 additions & 3 deletions api/server/controllers/AskController.js
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ const AskController = async (req, res, next, initializeClient, addTitle) => {
let promptTokens;
let userMessageId;
let responseMessageId;
let currentMetadata = null;
const sender = getResponseSender({
...endpointOption,
model: endpointOption.modelOptions.model,
Expand Down Expand Up @@ -60,7 +61,7 @@ const AskController = async (req, res, next, initializeClient, addTitle) => {
const messageCache = getLogStores(CacheKeys.MESSAGES);
const { onProgress: progressCallback, getPartialText } = createOnProgress({
onProgress: throttle(
({ text: partialText }) => {
({ text: partialText, metadata }) => {
/*
const unfinished = endpointOption.endpoint === EModelEndpoint.google ? false : true;
messageCache.set(responseMessageId, {
Expand All @@ -76,6 +77,9 @@ const AskController = async (req, res, next, initializeClient, addTitle) => {
}, Time.FIVE_MINUTES);
*/

if (metadata) {
currentMetadata = metadata;
}
messageCache.set(responseMessageId, partialText, Time.FIVE_MINUTES);
},
3000,
Expand All @@ -94,6 +98,7 @@ const AskController = async (req, res, next, initializeClient, addTitle) => {
text: getPartialText(),
userMessage,
promptTokens,
...(currentMetadata ? { metadata: currentMetadata } : {}),
});

const { abortController, onStart } = createAbortController(req, res, getAbortData, getReqData);
Expand Down Expand Up @@ -131,6 +136,15 @@ const AskController = async (req, res, next, initializeClient, addTitle) => {
let response = await client.sendMessage(text, messageOptions);
response.endpoint = endpointOption.endpoint;

// Add metadata to the final response if available
if (currentMetadata) {
if (currentMetadata.metadata?.groundingMetadata) {
response.groundingMetadata = currentMetadata.metadata.groundingMetadata;
} else if (currentMetadata.groundingMetadata) {
response.groundingMetadata = currentMetadata.groundingMetadata;
}
}

const { conversation = {} } = await client.responsePromise;
conversation.title =
conversation && !conversation.title ? null : conversation?.title || 'New Chat';
Expand All @@ -142,19 +156,21 @@ const AskController = async (req, res, next, initializeClient, addTitle) => {
}

if (!abortController.signal.aborted) {
const finalResponse = { ...response };

sendMessage(res, {
final: true,
conversation,
title: conversation.title,
requestMessage: userMessage,
responseMessage: response,
responseMessage: finalResponse,
});
res.end();

if (!client.savedMessageIds.has(response.messageId)) {
await saveMessage(
req,
{ ...response, user },
{ ...finalResponse, user },
{ context: 'api/server/controllers/AskController.js - response end' },
);
}
Expand Down
3 changes: 3 additions & 0 deletions api/server/services/Endpoints/google/build.js
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,10 @@ const buildOptions = (endpoint, parsedBody) => {
greeting,
spec,
artifacts,
enableSearch,
...modelOptions
} = parsedBody;

const endpointOption = removeNullishValues({
examples,
endpoint,
Expand All @@ -22,6 +24,7 @@ const buildOptions = (endpoint, parsedBody) => {
iconURL,
greeting,
spec,
enableSearch,
modelOptions,
});

Expand Down
28 changes: 27 additions & 1 deletion api/server/utils/handleText.js
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,32 @@ const base = { message: true, initial: true };
const createOnProgress = ({ generation = '', onProgress: _onProgress }) => {
let i = 0;
let tokens = addSpaceIfNeeded(generation);
let currentMetadata = null;

const basePayload = Object.assign({}, base, { text: tokens || '' });

const progressCallback = (chunk, { res, ...rest }) => {
basePayload.text = basePayload.text + chunk;
// If chunk is an object with text and metadata
if (typeof chunk === 'object' && chunk.text !== undefined) {
basePayload.text = basePayload.text + chunk.text;
if (chunk.metadata) {
currentMetadata = chunk.metadata;
}
} else {
// Handle plain text chunks
basePayload.text = basePayload.text + chunk;
}

// Always include current metadata in the payload if available
const payload = Object.assign({}, basePayload, rest);
if (currentMetadata) {
if (currentMetadata.metadata?.groundingMetadata) {
payload.groundingMetadata = currentMetadata.metadata.groundingMetadata;
} else if (currentMetadata.groundingMetadata) {
payload.groundingMetadata = currentMetadata.groundingMetadata;
}
}

sendMessage(res, payload);
if (_onProgress) {
_onProgress(payload);
Expand All @@ -41,6 +60,13 @@ const createOnProgress = ({ generation = '', onProgress: _onProgress }) => {
const sendIntermediateMessage = (res, payload, extraTokens = '') => {
basePayload.text = basePayload.text + extraTokens;
const message = Object.assign({}, basePayload, payload);
if (currentMetadata) {
if (currentMetadata.metadata?.groundingMetadata) {
message.groundingMetadata = currentMetadata.metadata.groundingMetadata;
} else if (currentMetadata.groundingMetadata) {
message.groundingMetadata = currentMetadata.groundingMetadata;
}
}
sendMessage(res, message);
if (i === 0) {
basePayload.initial = false;
Expand Down
15 changes: 13 additions & 2 deletions client/src/components/Chat/Messages/Content/MessageContent.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -80,11 +80,22 @@ const DisplayMessage = ({ text, isCreatedByUser, message, showCursor }: TDisplay
() => message.messageId === latestMessage?.messageId,
[message.messageId, latestMessage?.messageId],
);

let content: React.ReactElement;
if (!isCreatedByUser) {
content = (
<Markdown content={text} showCursor={showCursorState} isLatestMessage={isLatestMessage} />
<>
<Markdown content={text} showCursor={showCursorState} isLatestMessage={isLatestMessage} />
{message.groundingMetadata?.searchEntryPoint?.renderedContent && (
<div className="mt-4 rounded-lg border border-token-border-light bg-token-surface-secondary p-4">
<div
className="prose dark:prose-invert"
dangerouslySetInnerHTML={{
__html: message.groundingMetadata.searchEntryPoint.renderedContent
}}
/>
</div>
)}
</>
);
} else if (enableUserMsgMarkdown) {
content = <MarkdownLite content={text} />;
Expand Down
Loading
Loading