Skip to content

Commit

Permalink
fix(chat): cancel run (#168)
Browse files Browse the repository at this point in the history
Signed-off-by: Petr Kadlec <[email protected]>
  • Loading branch information
kapetr authored Jan 7, 2025
1 parent d794c1f commit f8b1201
Show file tree
Hide file tree
Showing 8 changed files with 231 additions and 100 deletions.
2 changes: 1 addition & 1 deletion src/middleware.ts
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ const getCspHeader = (nonce: string) => {
base-uri 'self';
form-action 'self';
frame-ancestors 'none';
upgrade-insecure-requests;
${process.env.NODE_ENV === 'production' ? 'upgrade-insecure-requests;' : ''}
`;

return cspHeader.replace(/\s{2,}/g, ' ').trim();
Expand Down
48 changes: 48 additions & 0 deletions src/modules/chat/api/useDeleteMessage.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
/**
* Copyright 2024 IBM Corp.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

import { useMutation, useQueryClient } from '@tanstack/react-query';
import { deleteMessage } from '@/app/api/threads-messages';
import { useProjectContext } from '@/layout/providers/ProjectProvider';
import { messagesWithFilesQuery } from '../queries';

export function useDeleteMessage() {
const { project, organization } = useProjectContext();
const queryClient = useQueryClient();

const mutation = useMutation({
mutationFn: ({ threadId, messageId }: DeleteMutationParams) =>
deleteMessage(organization.id, project.id, threadId, messageId),
onSuccess: (_, { threadId }) => {
queryClient.invalidateQueries({
queryKey: [
messagesWithFilesQuery(
organization.id,
project.id,
threadId,
).queryKey.at(0),
],
});
},
});

return mutation;
}

interface DeleteMutationParams {
threadId: string;
messageId: string;
}
30 changes: 6 additions & 24 deletions src/modules/chat/assistant-plan/PlanStep.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ import { ToolApprovalValue } from '../types';
import classes from './PlanStep.module.scss';
import { useUserSetting } from '@/layout/hooks/useUserSetting';
import { useProjectContext } from '@/layout/providers/ProjectProvider';
import { getToolApproval, getToolReferenceFromToolCall } from './utils';

interface Props {
step: AssistantPlanStep;
Expand Down Expand Up @@ -91,18 +92,7 @@ export function PlanStep({ step, toolCall, allStepsDone }: Props) {
const status = getStepStatus(step, run);

const toolKey = toolCall.type;
const tool =
toolKey === 'system'
? {
type: toolKey,
id: toolCall.toolId,
}
: toolKey === 'user'
? {
type: toolKey,
id: toolCall.toolId,
}
: { type: toolKey, id: toolKey };
const tool = getToolReferenceFromToolCall(toolCall);

const { toolName, toolIcon } = useToolInfo({
organization,
Expand All @@ -115,18 +105,10 @@ export function PlanStep({ step, toolCall, allStepsDone }: Props) {
const { setExpandedStep } = useExpandedStepActions();
const expanded = expandedState?.stepId === step.id;

const toolApproval = (
run?.status === 'requires_action' &&
run.required_action?.type === 'submit_tool_approvals'
? run.required_action.submit_tool_approvals.tool_calls
: []
)
.map((tool) => ({
id: tool.id,
toolId: getToolApprovalId(tool),
type: tool.type,
}))
.find((toolApproval) => toolApproval.toolId === tool.id);
const toolApproval = useMemo(
() => getToolApproval(toolCall, run),
[run, toolCall],
);

const handleToolApprovalSubmit = (value: ToolApprovalValue) => {
if (value === 'always' && thread && toolApproval?.toolId) {
Expand Down
25 changes: 24 additions & 1 deletion src/modules/chat/assistant-plan/PlanWithSources.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,11 @@ import { runStepsQuery } from '../queries';
import { BotChatMessage } from '../types';
import { PlanView } from './PlanView';
import classes from './PlanWithSources.module.scss';
import { updatePlanWithRunStep } from './utils';
import {
getToolApproval,
getToolReferenceFromToolCall,
updatePlanWithRunStep,
} from './utils';
import { TraceData } from '../trace/types';
import { useBuildTraceData } from '../trace/useBuildTraceData';
import { TraceInfoView } from '../trace/TraceInfoView';
Expand Down Expand Up @@ -108,6 +112,25 @@ function PlanWithSourcesComponent({ message, inView }: Props) {
threadId: thread?.id,
});

useEffect(() => {
if (isOpen) return;

const steps = plan.steps || [];
const hasToolApprovalRequest = steps.some((step) => {
if (step.status === 'in_progress') {
const lastToolCall = step.toolCalls.at(-1);
if (lastToolCall) {
const toolApproval = getToolApproval(lastToolCall, message.run);
return Boolean(toolApproval);
}
}
return false;
});
if (hasToolApprovalRequest) {
setIsOpen(true);
}
}, [isOpen, message.run, plan.steps]);

useEffect(() => {
setShowButton(!debugMode && !plan.pending);
}, [debugMode, plan.pending]);
Expand Down
36 changes: 36 additions & 0 deletions src/modules/chat/assistant-plan/utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,15 @@ import {
RunStepDeltaEventResponse,
StepToolCall,
SystemToolResult,
ThreadRun,
} from '@/app/api/threads-runs/types';
import {
isArXivToolResult,
isRunStepDeltaEventDetailsThought,
isWebSearchToolResult,
isWikipediaToolResult,
} from '@/app/api/threads-runs/utils';
import { getToolApprovalId } from '@/modules/tools/utils';
import { isNotNull } from '@/utils/helpers';
import { v4 as uuid } from 'uuid';

Expand Down Expand Up @@ -219,3 +221,37 @@ export function updatePlanWithRunStepDelta(
}
return plan;
}

export function getToolReferenceFromToolCall(toolCall: StepToolCall) {
const toolKey = toolCall.type;
return toolKey === 'system'
? {
type: toolKey,
id: toolCall.toolId,
}
: toolKey === 'user'
? {
type: toolKey,
id: toolCall.toolId,
}
: { type: toolKey, id: toolKey };
}

export function getToolApproval(toolCall: StepToolCall, run?: ThreadRun) {
const tool = getToolReferenceFromToolCall(toolCall);

if (
run?.status === 'requires_action' &&
run.required_action?.type === 'submit_tool_approvals'
) {
return run.required_action.submit_tool_approvals.tool_calls
.map((tool) => ({
id: tool.id,
toolId: getToolApprovalId(tool),
type: tool.type,
}))
.find((toolApproval) => toolApproval.toolId === tool.id);
}

return null;
}
20 changes: 14 additions & 6 deletions src/modules/chat/layout/InputBar.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ import { FilesMenu } from './FilesMenu';
import classes from './InputBar.module.scss';
import { PromptSuggestions } from './PromptSuggestions';
import { ThreadSettings } from './ThreadSettings';
import { useMessages } from '../providers/useMessages';
import { UserChatMessage } from '../types';

interface Props {
showSuggestions?: boolean;
Expand Down Expand Up @@ -100,6 +100,13 @@ export const InputBar = memo(function InputBar({
inputRef.current?.focus();
};

const handleAfterRemoveSentMessage = useCallback(
(message: UserChatMessage) => {
setValue('input', message.content, { shouldValidate: true });
},
[setValue],
);

const isPending = status !== 'ready';
const inputValue = watch('input');
const isFileUploadEnabled = isFeatureEnabled(FeatureName.Files);
Expand Down Expand Up @@ -130,7 +137,10 @@ export const InputBar = memo(function InputBar({
handleSubmit(({ input }) => {
onMessageSubmit?.();
resetForm();
sendMessage(input).then((result) => {

sendMessage(input, {
onAfterRemoveSentMessage: handleAfterRemoveSentMessage,
}).then((result) => {
onMessageSent?.(result);
});
})();
Expand Down Expand Up @@ -212,10 +222,8 @@ export const InputBar = memo(function InputBar({
size="sm"
hasIconOnly
iconDescription="Cancel"
disabled={status === 'waiting'}
onClick={() => {
cancel();
}}
disabled={status === 'waiting' || status === 'aborting'}
onClick={() => cancel()}
/>
)}
</div>
Expand Down
Loading

0 comments on commit f8b1201

Please sign in to comment.