diff --git a/frontend/package.json b/frontend/package.json index 92bef6b0499..0344c3fbbc3 100644 --- a/frontend/package.json +++ b/frontend/package.json @@ -37,6 +37,7 @@ "@lezer/common": "^1.2.1", "@lezer/highlight": "^1.2.1", "@lezer/lr": "^1.4.2", + "@lezer/python": "^1.1.15", "@marimo-team/marimo-api": "file:../openapi", "@marimo-team/react-slotz": "^0.1.8", "@open-rpc/client-js": "^1.8.1", diff --git a/frontend/pnpm-lock.yaml b/frontend/pnpm-lock.yaml index f4a3888f769..05d944e15e1 100644 --- a/frontend/pnpm-lock.yaml +++ b/frontend/pnpm-lock.yaml @@ -87,6 +87,9 @@ importers: '@lezer/lr': specifier: ^1.4.2 version: 1.4.2 + '@lezer/python': + specifier: ^1.1.15 + version: 1.1.15 '@marimo-team/marimo-api': specifier: file:../openapi version: file:../openapi @@ -1472,8 +1475,8 @@ packages: '@lezer/php@1.0.1': resolution: {integrity: sha512-aqdCQJOXJ66De22vzdwnuC502hIaG9EnPK2rSi+ebXyUd+j7GAX1mRjWZOVOmf3GST1YUfUCu6WXDiEgDGOVwA==} - '@lezer/python@1.1.7': - resolution: {integrity: sha512-RbhKQ9+Y/r/Xv6OcJmETEM5tBFdpdAJRqrgi3akJkWBLCuiAaLP/jKdYzu+ICljaSXPCQeznrv+r9HUEnjq3HQ==} + '@lezer/python@1.1.15': + resolution: {integrity: sha512-aVQ43m2zk4FZYedCqL0KHPEUsqZOrmAvRhkhHlVPnDD1HODDyyQv5BRIuod4DadkgBEZd53vQOtXTonNbEgjrQ==} '@lezer/rust@1.0.1': resolution: {integrity: sha512-j+ToFKM6Wpglv3OQ4ebHYdYIMT2dh0ziCCV0rTf47AWiHOVhR0WjaKrBq+yuvDQNEhr5sxPxVI7+naJIgpqcsQ==} @@ -9477,7 +9480,7 @@ snapshots: '@codemirror/language': 6.10.3 '@codemirror/state': 6.4.1 '@lezer/common': 1.2.1 - '@lezer/python': 1.1.7 + '@lezer/python': 1.1.15 transitivePeerDependencies: - '@codemirror/view' @@ -10080,8 +10083,9 @@ snapshots: '@lezer/highlight': 1.2.1 '@lezer/lr': 1.4.2 - '@lezer/python@1.1.7': + '@lezer/python@1.1.15': dependencies: + '@lezer/common': 1.2.1 '@lezer/highlight': 1.2.1 '@lezer/lr': 1.4.2 diff --git a/frontend/src/components/editor/Output.tsx b/frontend/src/components/editor/Output.tsx index d049875ad83..c758ec5cd2b 100644 --- a/frontend/src/components/editor/Output.tsx +++ b/frontend/src/components/editor/Output.tsx @@ -48,9 +48,10 @@ type MimeBundleOrTuple = MimeBundle | [MimeBundle, { [key: string]: unknown }]; */ export const OutputRenderer: React.FC<{ message: Pick; + cellId?: CellId; onRefactorWithAI?: (opts: { prompt: string }) => void; }> = memo((props) => { - const { message, onRefactorWithAI } = props; + const { message, onRefactorWithAI, cellId } = props; const { theme } = useTheme(); // Memoize parsing the json data @@ -123,7 +124,7 @@ export const OutputRenderer: React.FC<{ case "application/vnd.marimo+error": invariant(Array.isArray(data), "Expected array data"); - return ; + return ; case "application/vnd.marimo+traceback": invariant( @@ -190,7 +191,8 @@ OutputRenderer.displayName = "OutputRenderer"; const MimeBundleOutputRenderer: React.FC<{ channel: OutputMessage["channel"]; data: MimeBundleOrTuple; -}> = memo(({ data, channel }) => { + cellId?: CellId; +}> = memo(({ data, channel, cellId }) => { const mimebundle = Array.isArray(data) ? data[0] : data; // If there is none, return null @@ -203,6 +205,7 @@ const MimeBundleOutputRenderer: React.FC<{ if (Object.keys(mimebundle).length === 1) { return ( - + ); diff --git a/frontend/src/components/editor/chrome/panels/error-panel.tsx b/frontend/src/components/editor/chrome/panels/error-panel.tsx index 619769f7f7e..d87f82f1490 100644 --- a/frontend/src/components/editor/chrome/panels/error-panel.tsx +++ b/frontend/src/components/editor/chrome/panels/error-panel.tsx @@ -21,7 +21,11 @@ export const ErrorsPanel: React.FC = () => {
- +
))} diff --git a/frontend/src/components/editor/errors/auto-fix.tsx b/frontend/src/components/editor/errors/auto-fix.tsx new file mode 100644 index 00000000000..df3336ffdef --- /dev/null +++ b/frontend/src/components/editor/errors/auto-fix.tsx @@ -0,0 +1,55 @@ +/* Copyright 2024 Marimo. All rights reserved. */ +import { Button } from "@/components/ui/button"; +import { Tooltip } from "@/components/ui/tooltip"; +import { useCellActions, notebookAtom } from "@/core/cells/cells"; +import type { CellId } from "@/core/cells/ids"; +import { getAutoFixes } from "@/core/errors/errors"; +import type { MarimoError } from "@/core/kernel/messages"; +import { store } from "@/core/state/jotai"; +import { LightbulbIcon } from "lucide-react"; + +export const AutoFixButton = ({ + errors, + cellId, +}: { errors: MarimoError[]; cellId: CellId }) => { + const { createNewCell } = useCellActions(); + const autoFixes = errors.flatMap((error) => getAutoFixes(error)); + + if (autoFixes.length === 0) { + return null; + } + + // TODO: Add a dropdown menu with the auto-fixes, when we need to support + // multiple fixes. + const firstFix = autoFixes[0]; + + return ( + + + + ); +}; diff --git a/frontend/src/components/editor/output/ConsoleOutput.tsx b/frontend/src/components/editor/output/ConsoleOutput.tsx index 93052fbac4a..7e07f7e986b 100644 --- a/frontend/src/components/editor/output/ConsoleOutput.tsx +++ b/frontend/src/components/editor/output/ConsoleOutput.tsx @@ -138,6 +138,7 @@ export const ConsoleOutput = (props: Props): React.ReactNode => { return ( diff --git a/frontend/src/components/editor/output/MarimoErrorOutput.tsx b/frontend/src/components/editor/output/MarimoErrorOutput.tsx index 36d4126638f..b2cc454005e 100644 --- a/frontend/src/components/editor/output/MarimoErrorOutput.tsx +++ b/frontend/src/components/editor/output/MarimoErrorOutput.tsx @@ -15,6 +15,7 @@ import { import { Fragment } from "react"; import { CellLinkError } from "../links/cell-link"; import type { CellId } from "@/core/cells/ids"; +import { AutoFixButton } from "../errors/auto-fix"; const Tip = (props: { className?: string; @@ -31,6 +32,7 @@ const Tip = (props: { }; interface Props { + cellId: CellId | undefined; errors: MarimoError[]; className?: string; } @@ -40,6 +42,7 @@ interface Props { */ export const MarimoErrorOutput = ({ errors, + cellId, className, }: Props): JSX.Element => { let titleContents = "This cell wasn't run because it has errors"; @@ -55,7 +58,7 @@ export const MarimoErrorOutput = ({ case "cycle": return ( -

{"This cell is in a cycle:"}

+

{"This cell is in a cycle:"}

    {error.edges_with_vars.map((edge) => (
  • @@ -76,9 +79,7 @@ export const MarimoErrorOutput = ({ case "multiple-defs": return ( -

    - {`The variable '${error.name}' was defined by another cell:`} -

    +

    {`The variable '${error.name}' was defined by another cell:`}

      {error.cells.map((cid) => (
    • @@ -97,7 +98,7 @@ export const MarimoErrorOutput = ({ case "delete-nonlocal": return ( -
      +
      {`The variable '${error.name}' can't be deleted because it was defined by another cell (`} {")"} @@ -195,7 +196,7 @@ export const MarimoErrorOutput = ({ }); const title = ( - + {titleContents} ); @@ -204,7 +205,7 @@ export const MarimoErrorOutput = ({ @@ -212,6 +213,7 @@ export const MarimoErrorOutput = ({
        {msgs}
      + {cellId && }
      ); }; diff --git a/frontend/src/core/cells/__tests__/runs.test.ts b/frontend/src/core/cells/__tests__/runs.test.ts index f0e9f70757a..e5145d5e9fd 100644 --- a/frontend/src/core/cells/__tests__/runs.test.ts +++ b/frontend/src/core/cells/__tests__/runs.test.ts @@ -15,7 +15,7 @@ const { reducer, initialState, isPureMarkdown } = exportedForTesting; function first(map: Map | undefined): T { invariant(map, "Map is undefined"); - return map.values().next().value; + return map.values().next().value as T; } describe("RunsState Reducer", () => { diff --git a/frontend/src/core/errors/__tests__/errors.test.ts b/frontend/src/core/errors/__tests__/errors.test.ts new file mode 100644 index 00000000000..32da1c60873 --- /dev/null +++ b/frontend/src/core/errors/__tests__/errors.test.ts @@ -0,0 +1,62 @@ +/* Copyright 2024 Marimo. All rights reserved. */ +import { describe, it, expect } from "vitest"; +import { getAutoFixes, getImportCode } from "../errors"; +import type { MarimoError } from "@/core/kernel/messages"; + +describe("getImportCode", () => { + it("returns simple import for same name", () => { + expect(getImportCode("json")).toBe("import json"); + expect(getImportCode("math")).toBe("import math"); + }); + + it("returns aliased import for different names", () => { + expect(getImportCode("np")).toBe("import numpy as np"); + expect(getImportCode("pd")).toBe("import pandas as pd"); + expect(getImportCode("plt")).toBe("import matplotlib.pyplot as plt"); + }); +}); + +describe("getAutoFixes", () => { + it("returns wrap in function fix for multiple-defs error", () => { + const error: MarimoError = { + type: "multiple-defs", + name: "foo", + cells: ["foo"], + }; + + const fixes = getAutoFixes(error); + expect(fixes).toHaveLength(1); + expect(fixes[0].title).toBe("Wrap in a function"); + }); + + it("returns import fix for NameError with known import", () => { + const error: MarimoError = { + type: "exception", + exception_type: "NameError", + msg: "name 'np' is not defined", + }; + + const fixes = getAutoFixes(error); + expect(fixes).toHaveLength(1); + expect(fixes[0].title).toBe("Add 'import numpy as np'"); + }); + + it("returns no fixes for NameError with unknown import", () => { + const error: MarimoError = { + type: "exception", + exception_type: "NameError", + msg: "name 'unknown_module' is not defined", + }; + + expect(getAutoFixes(error)).toHaveLength(0); + }); + + it("returns no fixes for other error types", () => { + const error: MarimoError = { + type: "syntax", + msg: "invalid syntax", + }; + + expect(getAutoFixes(error)).toHaveLength(0); + }); +}); diff --git a/frontend/src/core/errors/__tests__/utils.test.ts b/frontend/src/core/errors/__tests__/utils.test.ts new file mode 100644 index 00000000000..0dfc68db9fc --- /dev/null +++ b/frontend/src/core/errors/__tests__/utils.test.ts @@ -0,0 +1,108 @@ +/* Copyright 2024 Marimo. All rights reserved. */ +import { describe, test, expect } from "vitest"; +import { wrapInFunction } from "../utils"; + +describe("wrapInFunction", () => { + test("wraps single line expression", () => { + const input = "1 + 2"; + const expected = `def _(): + return 1 + 2 + + +_()`; + expect(wrapInFunction(input)).toBe(expected); + }); + + test("wraps multiline expression", () => { + const input = `(1 + +2)`; + const expected = `def _(): + return (1 + + 2) + + +_()`; + expect(wrapInFunction(input)).toBe(expected); + }); + + test("wraps complex multiline expression", () => { + const input = `foo( + bar(1, 2), + baz(3, 4) +)`; + const expected = `def _(): + return foo( + bar(1, 2), + baz(3, 4) + ) + + +_()`; + expect(wrapInFunction(input)).toBe(expected); + }); + + test("preserves existing indentation", () => { + const input = `def foo(): + x = 1 + y = 2`; + const expected = `def _(): + def foo(): + x = 1 + y = 2 + return + + +_()`; + expect(wrapInFunction(input)).toBe(expected); + }); + + test("multi-line parentheses", () => { + const input = `x = "foo" +y = "bar" +(alt.Chart(df).mark_line().encode( + x="x", + y="y", +))`; + const expected = `def _(): + x = "foo" + y = "bar" + return (alt.Chart(df).mark_line().encode( + x="x", + y="y", + )) + + +_()`; + expect(wrapInFunction(input)).toBe(expected); + }); + + test("preserves empty lines", () => { + const input = `x = 1 + +y = 2`; + const expected = `def _(): + x = 1 + + y = 2 + return + + +_()`; + expect(wrapInFunction(input)).toBe(expected); + }); + + test("handles code with existing indentation", () => { + const input = `if True: + x = 1 + y = 2`; + const expected = `def _(): + if True: + x = 1 + y = 2 + return + + +_()`; + expect(wrapInFunction(input)).toBe(expected); + }); +}); diff --git a/frontend/src/core/errors/errors.ts b/frontend/src/core/errors/errors.ts new file mode 100644 index 00000000000..369266db426 --- /dev/null +++ b/frontend/src/core/errors/errors.ts @@ -0,0 +1,96 @@ +/* Copyright 2024 Marimo. All rights reserved. */ +import type { EditorView } from "@codemirror/view"; +import type { CellId } from "../cells/ids"; +import type { MarimoError } from "../kernel/messages"; +import { wrapInFunction } from "./utils"; +import { invariant } from "@/utils/invariant"; + +export interface AutoFix { + title: string; + description: string; + onFix: (ctx: { + addCodeBelow: (code: string) => void; + editor: EditorView | undefined; + cellId: CellId; + }) => Promise; +} + +export function getAutoFixes(error: MarimoError): AutoFix[] { + if (error.type === "multiple-defs") { + return [ + { + title: "Wrap in a function", + description: + "Make this cell's variables local by wrapping the cell in a function.", + onFix: async (ctx) => { + invariant(ctx.editor, "Editor is null"); + const code = wrapInFunction(ctx.editor.state.doc.toString()); + ctx.editor.dispatch({ + changes: { + from: 0, + to: ctx.editor.state.doc.length, + insert: code, + }, + }); + }, + }, + ]; + } + + if (error.type === "exception" && error.exception_type === "NameError") { + const name = error.msg.match(/name '(.+)' is not defined/)?.[1]; + + if (!name || !(name in IMPORT_MAPPING)) { + return []; + } + + const cellCode = getImportCode(name); + + return [ + { + title: `Add '${cellCode}'`, + description: "Add a new cell for the missing import", + onFix: async (ctx) => { + ctx.addCodeBelow(cellCode); + }, + }, + ]; + } + + return []; +} + +export function getImportCode(name: string): string { + const moduleName = IMPORT_MAPPING[name]; + return moduleName === name + ? `import ${moduleName}` + : `import ${moduleName} as ${name}`; +} + +const IMPORT_MAPPING: Record = { + // libraries + mo: "marimo", + alt: "altair", + bokeh: "bokeh", + dask: "dask", + np: "numpy", + pd: "pandas", + pl: "polars", + plotly: "plotly", + plt: "matplotlib.pyplot", + px: "plotly.express", + scipy: "scipy", + sk: "sklearn", + sns: "seaborn", + stats: "scipy.stats", + tf: "tensorflow", + torch: "torch", + xr: "xarray", + // built-ins + dt: "datetime", + json: "json", + math: "math", + os: "os", + re: "re", + sys: "sys", +}; diff --git a/frontend/src/core/errors/utils.ts b/frontend/src/core/errors/utils.ts new file mode 100644 index 00000000000..bfc3c2c6666 --- /dev/null +++ b/frontend/src/core/errors/utils.ts @@ -0,0 +1,67 @@ +/* Copyright 2024 Marimo. All rights reserved. */ + +import { parser } from "@lezer/python"; +import type { SyntaxNode, Tree } from "@lezer/common"; + +function getLastStatement(tree: Tree): SyntaxNode | null { + let lastStmt: SyntaxNode | null = null; + const cursor = tree.cursor(); + + do { + if (cursor.name === "ExpressionStatement") { + lastStmt = cursor.node; + } + } while (cursor.next()); + + return lastStmt; +} + +export function wrapInFunction(code: string) { + const lines = code.split("\n"); + const indentation = " "; + + const tree = parser.parse(code); + const lastStmt = getLastStatement(tree); + + if (!lastStmt) { + return [ + "def _():", + ...indentLines(lines, indentation), + `${indentation}return`, + "", + "", + "_()", + ].join("\n"); + } + + const codeBeforeLastStmt = code.slice(0, lastStmt.from).trim(); + const codeRest = code.slice(lastStmt.from).trim(); + const linesBeforeLastStmt = codeBeforeLastStmt.split("\n"); + const linesRest = codeRest.split("\n"); + + return [ + "def _():", + ...indentLines(linesBeforeLastStmt, indentation), + `${indentation}return ${linesRest[0]}`, + ...indentLines(linesRest.slice(1), indentation), + "", + "", + "_()", + ].join("\n"); +} + +function indentLines(lines: string[], indentation: string): string[] { + if (lines.length === 1 && lines[0] === "") { + return []; + } + + const indentedLines = []; + for (const line of lines) { + if (line === "") { + indentedLines.push(""); + } else { + indentedLines.push(indentation + line); + } + } + return indentedLines; +} diff --git a/marimo/_smoke_tests/errors/autofix.py b/marimo/_smoke_tests/errors/autofix.py new file mode 100644 index 00000000000..9a8e813422d --- /dev/null +++ b/marimo/_smoke_tests/errors/autofix.py @@ -0,0 +1,39 @@ +import marimo + +__generated_with = "0.10.9" +app = marimo.App(width="medium") + + +@app.cell +def _(): + x = 1 + return (x,) + + +@app.cell +def _(): + x = 2 + x + return (x,) + + +@app.cell +def _(): + x = 3 + return (x,) + + +@app.cell +def _(mo): + mo.md() + return + + +@app.cell +def _(alt): + alt.Chart() + return + + +if __name__ == "__main__": + app.run() diff --git a/tests/_runtime/snapshots/docstrings_class.txt b/tests/_runtime/snapshots/docstrings_class.txt index fce203d2687..b3138a86aa4 100644 --- a/tests/_runtime/snapshots/docstrings_class.txt +++ b/tests/_runtime/snapshots/docstrings_class.txt @@ -1,4 +1,4 @@ -
      class MyClass()
      +
      class MyClass()
       
      Some docstring for the class. diff --git a/tests/_runtime/snapshots/docstrings_function.txt b/tests/_runtime/snapshots/docstrings_function.txt index cfc19f32421..f7d462a2e8e 100644 --- a/tests/_runtime/snapshots/docstrings_function.txt +++ b/tests/_runtime/snapshots/docstrings_function.txt @@ -1,4 +1,4 @@ -
      def my_func(arg1, arg2)
      +
      def my_func(arg1, arg2)
       
      This is a simple docstring for a function. \ No newline at end of file diff --git a/tests/_runtime/snapshots/docstrings_function_external.txt b/tests/_runtime/snapshots/docstrings_function_external.txt index 4c191ca40d8..f97673b1d31 100644 --- a/tests/_runtime/snapshots/docstrings_function_external.txt +++ b/tests/_runtime/snapshots/docstrings_function_external.txt @@ -1,4 +1,4 @@ -
      def my_func(arg1, arg2)
      +
      def my_func(arg1, arg2)
       
      diff --git a/tests/_runtime/snapshots/docstrings_function_google.txt b/tests/_runtime/snapshots/docstrings_function_google.txt index 326801eeceb..bde9313829a 100644 --- a/tests/_runtime/snapshots/docstrings_function_google.txt +++ b/tests/_runtime/snapshots/docstrings_function_google.txt @@ -1,4 +1,4 @@ -
      def my_func(arg1, arg2)
      +
      def my_func(arg1, arg2)
       

      Arguments