diff --git a/src/middleware.ts b/src/middleware.ts index 3dfc129..1b68da4 100644 --- a/src/middleware.ts +++ b/src/middleware.ts @@ -101,7 +101,7 @@ const getCspHeader = (nonce: string) => { base-uri 'self'; form-action 'self'; frame-ancestors 'none'; - upgrade-insecure-requests; + ${process.env.NODE_ENV === 'production' ? 'upgrade-insecure-requests;' : ''} `; return cspHeader.replace(/\s{2,}/g, ' ').trim(); diff --git a/src/modules/chat/api/useDeleteMessage.ts b/src/modules/chat/api/useDeleteMessage.ts new file mode 100644 index 0000000..13c02d5 --- /dev/null +++ b/src/modules/chat/api/useDeleteMessage.ts @@ -0,0 +1,48 @@ +/** + * Copyright 2024 IBM Corp. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import { useMutation, useQueryClient } from '@tanstack/react-query'; +import { deleteMessage } from '@/app/api/threads-messages'; +import { useProjectContext } from '@/layout/providers/ProjectProvider'; +import { messagesWithFilesQuery } from '../queries'; + +export function useDeleteMessage() { + const { project, organization } = useProjectContext(); + const queryClient = useQueryClient(); + + const mutation = useMutation({ + mutationFn: ({ threadId, messageId }: DeleteMutationParams) => + deleteMessage(organization.id, project.id, threadId, messageId), + onSuccess: (_, { threadId }) => { + queryClient.invalidateQueries({ + queryKey: [ + messagesWithFilesQuery( + organization.id, + project.id, + threadId, + ).queryKey.at(0), + ], + }); + }, + }); + + return mutation; +} + +interface DeleteMutationParams { + threadId: string; + messageId: string; +} diff --git a/src/modules/chat/assistant-plan/PlanStep.tsx b/src/modules/chat/assistant-plan/PlanStep.tsx index ecf877f..7897cb9 100644 --- a/src/modules/chat/assistant-plan/PlanStep.tsx +++ b/src/modules/chat/assistant-plan/PlanStep.tsx @@ -57,6 +57,7 @@ import { ToolApprovalValue } from '../types'; import classes from './PlanStep.module.scss'; import { useUserSetting } from '@/layout/hooks/useUserSetting'; import { useProjectContext } from '@/layout/providers/ProjectProvider'; +import { getToolApproval, getToolReferenceFromToolCall } from './utils'; interface Props { step: AssistantPlanStep; @@ -91,18 +92,7 @@ export function PlanStep({ step, toolCall, allStepsDone }: Props) { const status = getStepStatus(step, run); const toolKey = toolCall.type; - const tool = - toolKey === 'system' - ? { - type: toolKey, - id: toolCall.toolId, - } - : toolKey === 'user' - ? { - type: toolKey, - id: toolCall.toolId, - } - : { type: toolKey, id: toolKey }; + const tool = getToolReferenceFromToolCall(toolCall); const { toolName, toolIcon } = useToolInfo({ organization, @@ -115,18 +105,10 @@ export function PlanStep({ step, toolCall, allStepsDone }: Props) { const { setExpandedStep } = useExpandedStepActions(); const expanded = expandedState?.stepId === step.id; - const toolApproval = ( - run?.status === 'requires_action' && - run.required_action?.type === 'submit_tool_approvals' - ? run.required_action.submit_tool_approvals.tool_calls - : [] - ) - .map((tool) => ({ - id: tool.id, - toolId: getToolApprovalId(tool), - type: tool.type, - })) - .find((toolApproval) => toolApproval.toolId === tool.id); + const toolApproval = useMemo( + () => getToolApproval(toolCall, run), + [run, toolCall], + ); const handleToolApprovalSubmit = (value: ToolApprovalValue) => { if (value === 'always' && thread && toolApproval?.toolId) { diff --git a/src/modules/chat/assistant-plan/PlanWithSources.tsx b/src/modules/chat/assistant-plan/PlanWithSources.tsx index d110714..e84eea3 100644 --- a/src/modules/chat/assistant-plan/PlanWithSources.tsx +++ b/src/modules/chat/assistant-plan/PlanWithSources.tsx @@ -34,7 +34,11 @@ import { runStepsQuery } from '../queries'; import { BotChatMessage } from '../types'; import { PlanView } from './PlanView'; import classes from './PlanWithSources.module.scss'; -import { updatePlanWithRunStep } from './utils'; +import { + getToolApproval, + getToolReferenceFromToolCall, + updatePlanWithRunStep, +} from './utils'; import { TraceData } from '../trace/types'; import { useBuildTraceData } from '../trace/useBuildTraceData'; import { TraceInfoView } from '../trace/TraceInfoView'; @@ -108,6 +112,25 @@ function PlanWithSourcesComponent({ message, inView }: Props) { threadId: thread?.id, }); + useEffect(() => { + if (isOpen) return; + + const steps = plan.steps || []; + const hasToolApprovalRequest = steps.some((step) => { + if (step.status === 'in_progress') { + const lastToolCall = step.toolCalls.at(-1); + if (lastToolCall) { + const toolApproval = getToolApproval(lastToolCall, message.run); + return Boolean(toolApproval); + } + } + return false; + }); + if (hasToolApprovalRequest) { + setIsOpen(true); + } + }, [isOpen, message.run, plan.steps]); + useEffect(() => { setShowButton(!debugMode && !plan.pending); }, [debugMode, plan.pending]); diff --git a/src/modules/chat/assistant-plan/utils.ts b/src/modules/chat/assistant-plan/utils.ts index b6c4295..726eec5 100644 --- a/src/modules/chat/assistant-plan/utils.ts +++ b/src/modules/chat/assistant-plan/utils.ts @@ -21,6 +21,7 @@ import { RunStepDeltaEventResponse, StepToolCall, SystemToolResult, + ThreadRun, } from '@/app/api/threads-runs/types'; import { isArXivToolResult, @@ -28,6 +29,7 @@ import { isWebSearchToolResult, isWikipediaToolResult, } from '@/app/api/threads-runs/utils'; +import { getToolApprovalId } from '@/modules/tools/utils'; import { isNotNull } from '@/utils/helpers'; import { v4 as uuid } from 'uuid'; @@ -219,3 +221,37 @@ export function updatePlanWithRunStepDelta( } return plan; } + +export function getToolReferenceFromToolCall(toolCall: StepToolCall) { + const toolKey = toolCall.type; + return toolKey === 'system' + ? { + type: toolKey, + id: toolCall.toolId, + } + : toolKey === 'user' + ? { + type: toolKey, + id: toolCall.toolId, + } + : { type: toolKey, id: toolKey }; +} + +export function getToolApproval(toolCall: StepToolCall, run?: ThreadRun) { + const tool = getToolReferenceFromToolCall(toolCall); + + if ( + run?.status === 'requires_action' && + run.required_action?.type === 'submit_tool_approvals' + ) { + return run.required_action.submit_tool_approvals.tool_calls + .map((tool) => ({ + id: tool.id, + toolId: getToolApprovalId(tool), + type: tool.type, + })) + .find((toolApproval) => toolApproval.toolId === tool.id); + } + + return null; +} diff --git a/src/modules/chat/layout/InputBar.tsx b/src/modules/chat/layout/InputBar.tsx index 5ad018e..ef687dd 100644 --- a/src/modules/chat/layout/InputBar.tsx +++ b/src/modules/chat/layout/InputBar.tsx @@ -40,7 +40,7 @@ import { FilesMenu } from './FilesMenu'; import classes from './InputBar.module.scss'; import { PromptSuggestions } from './PromptSuggestions'; import { ThreadSettings } from './ThreadSettings'; -import { useMessages } from '../providers/useMessages'; +import { UserChatMessage } from '../types'; interface Props { showSuggestions?: boolean; @@ -100,6 +100,13 @@ export const InputBar = memo(function InputBar({ inputRef.current?.focus(); }; + const handleAfterRemoveSentMessage = useCallback( + (message: UserChatMessage) => { + setValue('input', message.content, { shouldValidate: true }); + }, + [setValue], + ); + const isPending = status !== 'ready'; const inputValue = watch('input'); const isFileUploadEnabled = isFeatureEnabled(FeatureName.Files); @@ -130,7 +137,10 @@ export const InputBar = memo(function InputBar({ handleSubmit(({ input }) => { onMessageSubmit?.(); resetForm(); - sendMessage(input).then((result) => { + + sendMessage(input, { + onAfterRemoveSentMessage: handleAfterRemoveSentMessage, + }).then((result) => { onMessageSent?.(result); }); })(); @@ -212,10 +222,8 @@ export const InputBar = memo(function InputBar({ size="sm" hasIconOnly iconDescription="Cancel" - disabled={status === 'waiting'} - onClick={() => { - cancel(); - }} + disabled={status === 'waiting' || status === 'aborting'} + onClick={() => cancel()} /> )} diff --git a/src/modules/chat/providers/ChatProvider.tsx b/src/modules/chat/providers/ChatProvider.tsx index 23dd8c7..30e472d 100644 --- a/src/modules/chat/providers/ChatProvider.tsx +++ b/src/modules/chat/providers/ChatProvider.tsx @@ -93,13 +93,15 @@ import { useModal } from '@/layout/providers/ModalProvider'; import { ApiError } from '@/app/api/errors'; import { UsageLimitModal } from '@/components/UsageLimitModal/UsageLimitModal'; import { PLAN_STEPS_QUERY_PARAMS } from '../assistant-plan/PlanWithSources'; +import { useDeleteMessage } from '../api/useDeleteMessage'; +import { Control } from 'react-hook-form'; interface CancelRunParams { threadId: string; runId: string; } -export type ChatStatus = 'ready' | 'fetching' | 'waiting'; +export type ChatStatus = 'ready' | 'fetching' | 'waiting' | 'aborting'; export interface RunController { abortController: AbortController | null; status: ChatStatus; @@ -162,8 +164,13 @@ export function ChatProvider({ const queryClient = useQueryClient(); const threadSettingsButtonRef = useRef(null); + const { mutateAsync: mutateDeleteMessage } = useDeleteMessage(); + const threadAssistant = useGetThreadAssistant(thread, initialThreadAssistant); - const [getMessages, setMessages] = useMessages({ + const { + messages: [getMessages, setMessages], + refetch: refetchMessages, + } = useMessages({ thread, initialData, }); @@ -285,7 +292,8 @@ export function ChatProvider({ const cancel = useCallback(() => { controllerRef.current.abortController?.abort(); - }, [controllerRef]); + setController((controller) => ({ ...controller, status: 'aborting' })); + }, [controllerRef, setController]); const reset = useCallback( (messages: ChatMessage[]) => { @@ -374,7 +382,7 @@ export function ChatProvider({ }); }, [controllerRef, mutateCancel, threadRef]); - const handlRunCompleted = useCallback(() => { + const handleRunCompleted = useCallback(() => { const lastMessage = getMessages().at(-1); setController(RUN_CONTROLLER_DEFAULT); @@ -443,7 +451,7 @@ export function ChatProvider({ } catch (err) { handleError(err, { toast: false }); } finally { - handlRunCompleted(); + handleRunCompleted(); } const aborted = controller.abortController?.signal.aborted; @@ -456,7 +464,7 @@ export function ChatProvider({ chatStream, controller.abortController?.signal.aborted, controllerRef, - handlRunCompleted, + handleRunCompleted, handleCancelCurrentRun, handleError, setController, @@ -515,7 +523,10 @@ export function ChatProvider({ ]); const sendMessage = useCallback( - async (input: string, { regenerate }: SendMessageOptions = {}) => { + async ( + input: string, + { regenerate, onAfterRemoveSentMessage }: SendMessageOptions = {}, + ) => { if (controllerRef.current.status !== 'ready') { return { aborted: true, thread: null }; } @@ -528,35 +539,50 @@ export function ChatProvider({ }); // Remove last bot message if it was empty, and also last user message - function removeLastMessagePair(ignoreError?: boolean) { + async function removeLastMessagePair(ignoreError?: boolean) { + // synchronize messages before removing + await refetchMessages(); + setMessages((messages) => { let message = messages.at(-1); - let shouldRemoveUserMessage = true; - if (message?.role === 'assistant') { - if (!isBotMessage(message)) { - throw new Error('Unexpected last message found.'); + if (!isBotMessage(message)) { + throw new Error('Unexpected last message found.'); + } + + if (message.plan) { + message.plan.steps = message.plan.steps.map((step) => + step.status === 'in_progress' + ? { ...step, status: 'cancelled' } + : step, + ); + } + + if ( + !regenerate && + messages.length > 2 && + !message.content && + message.plan == null && + (ignoreError || message.error == null) + ) { + messages.pop(); + if (thread && message.id) { + mutateDeleteMessage({ + threadId: thread?.id, + messageId: message.id, + }); } - if ( - !message.content && - message.plan == null && - (ignoreError || message.error == null) - ) { + + message = messages.at(-1); + if (message?.role === 'user') { messages.pop(); - message = messages.at(-1); - } else { - shouldRemoveUserMessage = false; - if (message.plan) { - message.plan.steps = message.plan.steps.map((step) => - step.status === 'in_progress' - ? { ...step, status: 'cancelled' } - : step, - ); - } + if (thread && message.id) + mutateDeleteMessage({ + threadId: thread?.id, + messageId: message.id, + }); + onAfterRemoveSentMessage?.(message); } } - if (message?.role === 'user' && shouldRemoveUserMessage) { - messages.pop(); - } }); } @@ -657,44 +683,46 @@ export function ChatProvider({ const { tools, toolApprovals } = getUsedTools(thread); - await chatStream({ - action: { - id: 'create-run', - body: { - assistant_id: assistant.id, - tools, - tool_approvals: toolApprovals, - uiMetadata: { - resources: getRunResources(thread, assistant), + if (!abortController.signal.aborted) { + await chatStream({ + action: { + id: 'create-run', + body: { + assistant_id: assistant.id, + tools, + tool_approvals: toolApprovals, + uiMetadata: { + resources: getRunResources(thread, assistant), + }, }, }, - }, - onMessageCompleted: (response) => { - setMessagesWithFilesQueryData( - thread?.id, - response.data, - response.data?.run_id, - ); - - if (files.length > 0) { - queryClient.invalidateQueries({ - queryKey: threadsQuery(organization.id, project.id).queryKey, - }); + onMessageCompleted: (response) => { + setMessagesWithFilesQueryData( + thread?.id, + response.data, + response.data?.run_id, + ); + + if (files.length > 0) { + queryClient.invalidateQueries({ + queryKey: threadsQuery(organization.id, project.id).queryKey, + }); - queryClient.invalidateQueries({ - queryKey: threadQuery( - organization.id, - project.id, - thread?.id ?? '', - ).queryKey, - }); - } - }, - }); + queryClient.invalidateQueries({ + queryKey: threadQuery( + organization.id, + project.id, + thread?.id ?? '', + ).queryKey, + }); + } + }, + }); + } } catch (err) { handleChatError(err); } finally { - handlRunCompleted(); + handleRunCompleted(); } const aborted = abortController.signal.aborted; @@ -710,7 +738,9 @@ export function ChatProvider({ [ controllerRef, setController, + refetchMessages, setMessages, + mutateDeleteMessage, handleCancelCurrentRun, openModal, handleError, @@ -722,12 +752,12 @@ export function ChatProvider({ assistant, ensureThread, getUsedTools, - chatStream, onBeforePostMessage, getMessages, setMessagesWithFilesQueryData, + chatStream, queryClient, - handlRunCompleted, + handleRunCompleted, ], ); @@ -804,6 +834,7 @@ export function ChatProvider({ export type SendMessageOptions = { regenerate?: boolean; + onAfterRemoveSentMessage?: (message: UserChatMessage) => void; }; export type SendMessageResult = { diff --git a/src/modules/chat/providers/useMessages.ts b/src/modules/chat/providers/useMessages.ts index 1eb07fa..02964d7 100644 --- a/src/modules/chat/providers/useMessages.ts +++ b/src/modules/chat/providers/useMessages.ts @@ -33,7 +33,7 @@ export function useMessages({ }) { const { project, organization } = useProjectContext(); - const { data } = useQuery({ + const { data, refetch } = useQuery({ ...messagesWithFilesQuery(organization.id, project.id, thread?.id || '', { limit: MESSAGES_PAGE_SIZE, }), @@ -48,7 +48,10 @@ export function useMessages({ enabled: Boolean(thread), }); - return useImmerWithGetter( - thread ? getMessagesFromThreadMessages(data ?? []) : [], - ); + return { + messages: useImmerWithGetter( + thread ? getMessagesFromThreadMessages(data ?? []) : [], + ), + refetch, + }; }