diff --git a/frontend/package.json b/frontend/package.json index 3ad5157275b..cd9ec2f2169 100644 --- a/frontend/package.json +++ b/frontend/package.json @@ -78,6 +78,7 @@ "@types/react-grid-layout": "^1.3.5", "@uidotdev/usehooks": "^2.4.1", "@uiw/codemirror-extensions-langs": "^4.23.0", + "@uiw/codemirror-extensions-mentions": "^4.23.0", "@uiw/react-codemirror": "^4.23.0", "@valtown/codemirror-codeium": "^1.1.1", "@xterm/addon-attach": "^0.11.0", diff --git a/frontend/pnpm-lock.yaml b/frontend/pnpm-lock.yaml index 3f7a0cd4fdb..355d30435ec 100644 --- a/frontend/pnpm-lock.yaml +++ b/frontend/pnpm-lock.yaml @@ -200,6 +200,9 @@ dependencies: '@uiw/codemirror-extensions-langs': specifier: ^4.23.0 version: 4.23.0(@codemirror/autocomplete@6.18.0)(@codemirror/language-data@6.5.1)(@codemirror/language@6.10.2)(@codemirror/legacy-modes@6.4.0)(@codemirror/state@6.4.1)(@codemirror/view@6.32.0)(@lezer/common@1.2.1)(@lezer/highlight@1.2.1)(@lezer/javascript@1.4.14)(@lezer/lr@1.4.2) + '@uiw/codemirror-extensions-mentions': + specifier: ^4.23.0 + version: 4.23.0(@codemirror/state@6.4.1)(@codemirror/view@6.32.0) '@uiw/react-codemirror': specifier: ^4.23.0 version: 4.23.0(@babel/runtime@7.24.4)(@codemirror/autocomplete@6.18.0)(@codemirror/language@6.10.2)(@codemirror/lint@6.8.1)(@codemirror/search@6.5.6)(@codemirror/state@6.4.1)(@codemirror/theme-one-dark@6.1.2)(@codemirror/view@6.32.0)(codemirror@6.0.1)(react-dom@18.3.1)(react@18.3.1) @@ -8765,6 +8768,16 @@ packages: - '@lezer/lr' dev: false + /@uiw/codemirror-extensions-mentions@4.23.0(@codemirror/state@6.4.1)(@codemirror/view@6.32.0): + resolution: {integrity: sha512-l+UwQcnpXvf/r8m3LLx27+5oB41Uh/6hv0JkOy1kHfNj8bZqGBMBgrnObGFfkC/oYOJHeOXvtCb0VFShTDcBRw==} + peerDependencies: + '@codemirror/state': '>=6.0.0' + '@codemirror/view': '>=6.0.0' + dependencies: + '@codemirror/state': 6.4.1 + '@codemirror/view': 6.32.0 + dev: false + /@uiw/react-codemirror@4.23.0(@babel/runtime@7.24.4)(@codemirror/autocomplete@6.18.0)(@codemirror/language@6.10.2)(@codemirror/lint@6.8.1)(@codemirror/search@6.5.6)(@codemirror/state@6.4.1)(@codemirror/theme-one-dark@6.1.2)(@codemirror/view@6.32.0)(codemirror@6.0.1)(react-dom@18.3.1)(react@18.3.1): resolution: {integrity: sha512-MnqTXfgeLA3fsUUQjqjJgemEuNyoGALgsExVm0NQAllAAi1wfj+IoKFeK+h3XXMlTFRCFYOUh4AHDv0YXJLsOg==} peerDependencies: diff --git a/frontend/src/components/editor/ai/add-cell-with-ai.tsx b/frontend/src/components/editor/ai/add-cell-with-ai.tsx index 3a6af39ab12..d08e9f2c5a4 100644 --- a/frontend/src/components/editor/ai/add-cell-with-ai.tsx +++ b/frontend/src/components/editor/ai/add-cell-with-ai.tsx @@ -2,17 +2,83 @@ import { useCellActions } from "../../../core/cells/cells"; import { cn } from "@/utils/cn"; import { Button } from "@/components/ui/button"; -import { Loader2Icon, SparklesIcon, XIcon } from "lucide-react"; +import { ChevronsUpDown, Loader2Icon, SparklesIcon, XIcon } from "lucide-react"; import { toast } from "@/components/ui/use-toast"; import { getCodes } from "@/core/codemirror/copilot/getCodes"; import { API } from "@/core/network/api"; import { prettyError } from "@/utils/errors"; import { useCompletion } from "ai/react"; -import ReactCodeMirror, { EditorView } from "@uiw/react-codemirror"; +import ReactCodeMirror, { + EditorView, + keymap, + minimalSetup, +} from "@uiw/react-codemirror"; +import { Prec } from "@codemirror/state"; import { customPythonLanguageSupport } from "@/core/codemirror/language/python"; import { asURL } from "@/utils/url"; +import { mentions } from "@uiw/codemirror-extensions-mentions"; +import { useMemo, useState } from "react"; +import { store } from "@/core/state/jotai"; +import { datasetTablesAtom } from "@/core/datasets/state"; +import { Logger } from "@/utils/Logger"; +import { Maps } from "@/utils/maps"; +import type { DataTable } from "@/core/kernel/messages"; +import { useAtom, useAtomValue } from "jotai"; +import type { Completion } from "@codemirror/autocomplete"; +import { + DropdownMenu, + DropdownMenuTrigger, + DropdownMenuContent, + DropdownMenuItem, +} from "@/components/ui/dropdown-menu"; +import { sql } from "@codemirror/lang-sql"; +import { SQLLanguageAdapter } from "@/core/codemirror/language/sql"; +import { atomWithStorage } from "jotai/utils"; -const extensions = [customPythonLanguageSupport(), EditorView.lineWrapping]; +const pythonExtensions = [ + customPythonLanguageSupport(), + EditorView.lineWrapping, +]; +const sqlExtensions = [sql(), EditorView.lineWrapping]; + +function getCompletionBody(input: string): object { + const datasets = extractDatasets(input); + Logger.debug("Included datasets", datasets); + + return { + includeOtherCode: getCodes(""), + context: { + schema: datasets.map((dataset) => ({ + name: dataset.name, + columns: dataset.columns.map((column) => ({ + name: column.name, + type: column.type, + })), + })), + }, + code: "", + }; +} + +function extractDatasets(input: string): DataTable[] { + const datasets = store.get(datasetTablesAtom); + const existingDatasets = Maps.keyBy(datasets, (dataset) => dataset.name); + + // Extract dataset mentions from the input + const mentionedDatasets = input.match(/@(\w+)/g) || []; + + // Filter to only include datasets that exist + return mentionedDatasets + .map((mention) => mention.slice(1)) + .map((name) => existingDatasets.get(name)) + .filter(Boolean); +} + +// Persist across sessions +const languageAtom = atomWithStorage<"python" | "sql">( + "marimo:ai-language", + "python", +); /** * Add a cell with AI. @@ -21,6 +87,8 @@ export const AddCellWithAI: React.FC<{ onClose: () => void; }> = ({ onClose }) => { const { createNewCell } = useCellActions(); + const [completionBody, setCompletionBody] = useState({}); + const [language, setLanguage] = useAtom(languageAtom); const { completion, @@ -28,15 +96,15 @@ export const AddCellWithAI: React.FC<{ stop, isLoading, setCompletion, - handleInputChange, + setInput, handleSubmit, } = useCompletion({ api: asURL("api/ai/completion").toString(), headers: API.headers(), streamMode: "text", body: { - includeOtherCode: getCodes(""), - code: "", + ...completionBody, + language: language, }, onError: (error) => { toast({ @@ -47,25 +115,40 @@ export const AddCellWithAI: React.FC<{ }); const inputComponent = ( -
+
- + + + + + setLanguage("python")}> + Python + + setLanguage("sql")}> + SQL + + + + { + setCompletion(""); + onClose(); + }} value={input} - onChange={handleInputChange} - onKeyDown={(e) => { - if (e.key === "Enter") { - e.preventDefault(); - handleSubmit(e as unknown as React.FormEvent); - } - if (e.key === "Escape") { - e.preventDefault(); - setCompletion(""); - onClose(); - } + onChange={(newValue) => { + setInput(newValue); + setCompletionBody(getCompletionBody(newValue)); }} - placeholder="Generate with AI" + onSubmit={handleSubmit} /> {isLoading && (