Skip to content

Commit

Permalink
feat: Add support for @ mentions of dataframes and sql tables in AI a…
Browse files Browse the repository at this point in the history
…ssistant (#2181)
  • Loading branch information
mscolnick authored Aug 30, 2024
1 parent e2cfacd commit bd58bf9
Show file tree
Hide file tree
Showing 13 changed files with 546 additions and 66 deletions.
1 change: 1 addition & 0 deletions frontend/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@
"@types/react-grid-layout": "^1.3.5",
"@uidotdev/usehooks": "^2.4.1",
"@uiw/codemirror-extensions-langs": "^4.23.0",
"@uiw/codemirror-extensions-mentions": "^4.23.0",
"@uiw/react-codemirror": "^4.23.0",
"@valtown/codemirror-codeium": "^1.1.1",
"@xterm/addon-attach": "^0.11.0",
Expand Down
13 changes: 13 additions & 0 deletions frontend/pnpm-lock.yaml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

262 changes: 237 additions & 25 deletions frontend/src/components/editor/ai/add-cell-with-ai.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,83 @@
import { useCellActions } from "../../../core/cells/cells";
import { cn } from "@/utils/cn";
import { Button } from "@/components/ui/button";
import { Loader2Icon, SparklesIcon, XIcon } from "lucide-react";
import { ChevronsUpDown, Loader2Icon, SparklesIcon, XIcon } from "lucide-react";
import { toast } from "@/components/ui/use-toast";
import { getCodes } from "@/core/codemirror/copilot/getCodes";
import { API } from "@/core/network/api";
import { prettyError } from "@/utils/errors";
import { useCompletion } from "ai/react";
import ReactCodeMirror, { EditorView } from "@uiw/react-codemirror";
import ReactCodeMirror, {
EditorView,
keymap,
minimalSetup,
} from "@uiw/react-codemirror";
import { Prec } from "@codemirror/state";
import { customPythonLanguageSupport } from "@/core/codemirror/language/python";
import { asURL } from "@/utils/url";
import { mentions } from "@uiw/codemirror-extensions-mentions";
import { useMemo, useState } from "react";
import { store } from "@/core/state/jotai";
import { datasetTablesAtom } from "@/core/datasets/state";
import { Logger } from "@/utils/Logger";
import { Maps } from "@/utils/maps";
import type { DataTable } from "@/core/kernel/messages";
import { useAtom, useAtomValue } from "jotai";
import type { Completion } from "@codemirror/autocomplete";
import {
DropdownMenu,
DropdownMenuTrigger,
DropdownMenuContent,
DropdownMenuItem,
} from "@/components/ui/dropdown-menu";
import { sql } from "@codemirror/lang-sql";
import { SQLLanguageAdapter } from "@/core/codemirror/language/sql";
import { atomWithStorage } from "jotai/utils";

const extensions = [customPythonLanguageSupport(), EditorView.lineWrapping];
const pythonExtensions = [
customPythonLanguageSupport(),
EditorView.lineWrapping,
];
const sqlExtensions = [sql(), EditorView.lineWrapping];

function getCompletionBody(input: string): object {
const datasets = extractDatasets(input);
Logger.debug("Included datasets", datasets);

return {
includeOtherCode: getCodes(""),
context: {
schema: datasets.map((dataset) => ({
name: dataset.name,
columns: dataset.columns.map((column) => ({
name: column.name,
type: column.type,
})),
})),
},
code: "",
};
}

function extractDatasets(input: string): DataTable[] {
const datasets = store.get(datasetTablesAtom);
const existingDatasets = Maps.keyBy(datasets, (dataset) => dataset.name);

// Extract dataset mentions from the input
const mentionedDatasets = input.match(/@(\w+)/g) || [];

// Filter to only include datasets that exist
return mentionedDatasets
.map((mention) => mention.slice(1))
.map((name) => existingDatasets.get(name))
.filter(Boolean);
}

// Persist across sessions
const languageAtom = atomWithStorage<"python" | "sql">(
"marimo:ai-language",
"python",
);

/**
* Add a cell with AI.
Expand All @@ -21,22 +87,24 @@ export const AddCellWithAI: React.FC<{
onClose: () => void;
}> = ({ onClose }) => {
const { createNewCell } = useCellActions();
const [completionBody, setCompletionBody] = useState<object>({});
const [language, setLanguage] = useAtom(languageAtom);

const {
completion,
input,
stop,
isLoading,
setCompletion,
handleInputChange,
setInput,
handleSubmit,
} = useCompletion({
api: asURL("api/ai/completion").toString(),
headers: API.headers(),
streamMode: "text",
body: {
includeOtherCode: getCodes(""),
code: "",
...completionBody,
language: language,
},
onError: (error) => {
toast({
Expand All @@ -47,25 +115,40 @@ export const AddCellWithAI: React.FC<{
});

const inputComponent = (
<div className="flex items-center gap-3 px-3 py-2">
<div className="flex items-center px-3">
<SparklesIcon className="size-4 text-[var(--blue-11)]" />
<input
className="h-8 outline-none text-base focus-visible:shadow-none flex-1 rounded-none border-none focus:border-none"
autoFocus={true}
<DropdownMenu modal={false}>
<DropdownMenuTrigger asChild={true}>
<Button
variant="text"
className="ml-2"
size="xs"
data-testid="language-button"
>
{language === "python" ? "Python" : "SQL"}
<ChevronsUpDown className="ml-1 h-3.5 w-3.5 text-muted-foreground/70" />
</Button>
</DropdownMenuTrigger>
<DropdownMenuContent align="center">
<DropdownMenuItem onClick={() => setLanguage("python")}>
Python
</DropdownMenuItem>
<DropdownMenuItem onClick={() => setLanguage("sql")}>
SQL
</DropdownMenuItem>
</DropdownMenuContent>
</DropdownMenu>
<PromptInput
onClose={() => {
setCompletion("");
onClose();
}}
value={input}
onChange={handleInputChange}
onKeyDown={(e) => {
if (e.key === "Enter") {
e.preventDefault();
handleSubmit(e as unknown as React.FormEvent<HTMLFormElement>);
}
if (e.key === "Escape") {
e.preventDefault();
setCompletion("");
onClose();
}
onChange={(newValue) => {
setInput(newValue);
setCompletionBody(getCompletionBody(newValue));
}}
placeholder="Generate with AI"
onSubmit={handleSubmit}
/>
{isLoading && (
<Button
Expand All @@ -90,7 +173,10 @@ export const AddCellWithAI: React.FC<{
createNewCell({
cellId: "__end__",
before: false,
code: completion,
code:
language === "python"
? completion
: SQLLanguageAdapter.fromQuery(completion),
});
setCompletion("");
onClose();
Expand All @@ -106,16 +192,142 @@ export const AddCellWithAI: React.FC<{
);

return (
<div className={cn("flex flex-col w-full")}>
<div className={cn("flex flex-col w-full gap-2 py-2")}>
{inputComponent}
{!completion && (
<span className="text-xs text-muted-foreground px-3">
You can mention{" "}
<span className="text-[var(--cyan-11)]">@dataframe</span> or{" "}
<span className="text-[var(--cyan-11)]">@sql_table</span> to pull
additional context such as column names.
</span>
)}
{completion && (
<ReactCodeMirror
value={completion}
className="cm border-t"
onChange={setCompletion}
extensions={extensions}
extensions={language === "python" ? pythonExtensions : sqlExtensions}
/>
)}
</div>
);
};

interface PromptInputProps {
value: string;
onClose: () => void;
onChange: (value: string) => void;
onSubmit: () => void;
}

const PromptInput = ({
value,
onChange,
onSubmit,
onClose,
}: PromptInputProps) => {
const handleSubmit = onSubmit;
const handleEscape = onClose;
const tables = useAtomValue(datasetTablesAtom);

const extensions = useMemo(() => {
const completions = tables.map(
(table): Completion => ({
label: `@${table.name}`,
info: () => {
const shape = [
table.num_rows == null ? undefined : `${table.num_rows} rows`,
table.num_columns == null
? undefined
: `${table.num_columns} columns`,
]
.filter(Boolean)
.join(", ");

const infoContainer = document.createElement("div");
infoContainer.classList.add("prose", "prose-sm", "dark:prose-invert");

if (shape) {
const shapeElement = document.createElement("div");
shapeElement.textContent = shape;
shapeElement.style.fontWeight = "bold";
infoContainer.append(shapeElement);
}

if (table.source) {
const sourceElement = document.createElement("figcaption");
sourceElement.textContent = `Source: ${table.source}`;
infoContainer.append(sourceElement);
}

if (table.columns) {
const columnsTable = document.createElement("table");
const headerRow = columnsTable.insertRow();
const nameHeader = headerRow.insertCell();
nameHeader.textContent = "Column";
nameHeader.style.fontWeight = "bold";
const typeHeader = headerRow.insertCell();
typeHeader.textContent = "Type";
typeHeader.style.fontWeight = "bold";

table.columns.forEach((column) => {
const row = columnsTable.insertRow();
const nameCell = row.insertCell();
nameCell.textContent = column.name;
const typeCell = row.insertCell();
typeCell.textContent = column.type;
});

infoContainer.append(columnsTable);
}

return infoContainer;
},
}),
);

return [
mentions(completions),
EditorView.lineWrapping,
minimalSetup(),
Prec.highest(
keymap.of([
{
key: "Enter",
preventDefault: true,
stopPropagation: true,
run: () => {
handleSubmit();
return true;
},
},
]),
),
keymap.of([
{
key: "Escape",
preventDefault: true,
stopPropagation: true,
run: () => {
handleEscape();
return true;
},
},
]),
];
}, [tables, handleSubmit, handleEscape]);

return (
<ReactCodeMirror
className="flex-1 font-sans"
autoFocus={true}
width="100%"
value={value}
basicSetup={false}
extensions={extensions}
onChange={onChange}
placeholder={"Generate with AI"}
/>
);
};
3 changes: 2 additions & 1 deletion frontend/src/components/editor/renderers/CellArray.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -252,7 +252,8 @@ const AddCellButtons: React.FC = () => {
className={cn(
"shadow-sm border border-border rounded transition-all duration-200 overflow-hidden divide-x divide-border flex",
!isAiButtonOpen && "w-fit",
isAiButtonOpen && "opacity-100 w-full max-w-4xl shadow-lg",
isAiButtonOpen &&
"opacity-100 w-full max-w-4xl shadow-lg shadow-[var(--blue-3)]",
)}
>
{renderBody()}
Expand Down
1 change: 1 addition & 0 deletions frontend/src/core/codemirror/language/sql.ts
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ const regexes = pairs.map(
export class SQLLanguageAdapter implements LanguageAdapter {
readonly type = "sql";
readonly defaultCode = `_df = mo.sql(f"""SELECT * FROM """)`;
static fromQuery = (query: string) => `_df = mo.sql(f"""${query.trim()}""")`;

dataframeName = "_df";
lastQuotePrefix: QuotePrefixKind = "f";
Expand Down
3 changes: 1 addition & 2 deletions frontend/src/css/app/codemirror.css
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,9 @@
/* -- Tooltips: code completion and type hints -- */

#root .cm-tooltip {
border: none;
border-radius: 4px;

@apply bg-popover shadow-smSolid shadow-shade;
@apply bg-popover shadow-sm shadow-shade border-border;
}

/*
Expand Down
Loading

0 comments on commit bd58bf9

Please sign in to comment.