Skip to content

Commit

Permalink
Allow print in python script tool
Browse files Browse the repository at this point in the history
  • Loading branch information
Shulyaka committed Dec 10, 2024
1 parent 6f93417 commit 0f27812
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 26 deletions.
70 changes: 44 additions & 26 deletions custom_components/powerllm/tools/python_code.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
full_write_guard,
guarded_iter_unpack_sequence,
guarded_unpack_sequence,
raise_,
safer_getattr,
)

Expand All @@ -44,21 +45,27 @@ def python_code_execute(hass: HomeAssistant, source: str, data: dict | None = No
"""Execute python code in a restricted environment.
Use this tool for math calculations among other things.
Use `output` dictionary for output and `logger` object for logging.
Use `output` dictionary for output, `logger` object for logging,
and `print` for printing.
"""
log_handler = MyHandler()
logger = logging.getLogger(f"{__name__}.script")
logger.setLevel(logging.DEBUG)
logger.addHandler(log_handler)
try:
output = execute(hass, source, data)
printed = []
output = execute(hass, source, data, printed)
except HomeAssistantError as e:
output = {"error": type(e).__name__}
if str(e):
output["error_text"] = str(e)
logger.removeHandler(log_handler)

result = {"output": output}
if printed:
result["printed"] = "".join(printed)
if not output:
del result["output"]
logs = [
{"level": record.levelno, "msg": record.getMessage()}
for record in log_handler.logs
Expand Down Expand Up @@ -153,18 +160,26 @@ def guarded_inplacevar(op: str, target: Any, operand: Any) -> Any:
return op_fun(target, operand)


def execute(hass, source, data=None):
def execute(hass, source, data=None, printed: list[str] | None = None):
"""Execute Python source."""

if printed is None:
printed = []

compiled = compile_restricted_exec(source)

if compiled.errors:
raise ScriptError("Compile error: %s", ", ".join(compiled.errors))

logger = logging.getLogger(f"{__name__}.script")

if compiled.warnings:
logger.warning("Warning loading script: %s", ", ".join(compiled.warnings))
compiled_warnings = [
w
for w in compiled.warnings.copy()
if "Prints, but never reads 'printed' variable." not in w
]
if compiled_warnings:
logger.warning("Warning loading script: %s", ", ".join(compiled_warnings))

class ProtectedLogger:
"""Replacement logger for compatibility."""
Expand All @@ -173,6 +188,26 @@ def getLogger(self, *args, **kwargs):
"""Return a specific logger every time."""
return logger

class PrintCollector:
"""Collect written text, and return it when called."""

def __init__(self, _getattr_=None):
self._getattr_ = _getattr_

def write(self, text):
printed.append(text)

def __call__(self):
return "".join(printed)

def _call_print(self, *objects, **kwargs):
if kwargs.get("file", None) is None:
kwargs["file"] = self
else:
self._getattr_(kwargs["file"], "write")

print(*objects, **kwargs)

def protected_import(name: str, *args, **kwargs):
if name.split(".")[0] in ALLOWED_IMPORT:
return __import__(name, *args, **kwargs)
Expand All @@ -186,7 +221,7 @@ def protected_import(name: str, *args, **kwargs):
return ProtectedLogger()
raise ImportError(f"Module {name} not found")

def protected_getattr(obj, name, default=None):
def protected_getattr(obj, name, default=raise_):
"""Restricted method to get attributes."""
if name.startswith("async_"):
raise ScriptError("Not allowed to access async methods")
Expand Down Expand Up @@ -235,7 +270,7 @@ def protected_getattr(obj, name, default=None):
builtins.update(extra_builtins)
restricted_globals = {
"__builtins__": builtins,
"_print_": StubPrinter,
"_print_": PrintCollector,
"_getattr_": protected_getattr,
"_write_": full_write_guard,
"_getiter_": iter,
Expand All @@ -254,16 +289,10 @@ def protected_getattr(obj, name, default=None):
# pylint: disable-next=exec-used
exec(compiled.code, restricted_globals) # noqa: S102
_LOGGER.debug(
"Output of python_script:\n%s",
"Output of python_script:\n%s\n%s",
restricted_globals["output"],
"".join(printed),
)
# Ensure that we're always returning a dictionary
if not isinstance(restricted_globals["output"], dict):
output_type = type(restricted_globals["output"])
restricted_globals["output"] = {}
raise ScriptError( # noqa: TRY301
f"Expected `output` to be a dictionary, was {output_type}"
)
except ScriptError as err:
raise ServiceValidationError(f"Error executing script: {err}") from err
except Exception as err:
Expand All @@ -274,17 +303,6 @@ def protected_getattr(obj, name, default=None):
return restricted_globals["output"]


class StubPrinter:
"""Class to handle printing inside scripts."""

def __init__(self, _getattr_):
"""Initialize our printer."""

def _call_print(self, *objects, **kwargs):
"""Print text."""
_LOGGER.warning("Don't use print() inside scripts. Use logger.info() instead")


class TimeWrapper:
"""Wrap the time module."""

Expand Down
24 changes: 24 additions & 0 deletions tests/test_tool_python_code.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,3 +62,27 @@ async def test_python_script_tool_import(
response = await api.async_call_tool(tool_input)

assert response == {"output": {"test": 1.0}}


async def test_python_script_tool_print(
hass: HomeAssistant, llm_context: llm.LLMContext, mock_init_component
) -> None:
"""Test print in python script tool."""
api = await llm.async_get_api(hass, "powerllm", llm_context)

source = """
print("test1")
def test2():
print("test2")
test2()
"""

tool_input = llm.ToolInput(
tool_name="python_code_execute",
tool_args={"source": source},
)

response = await api.async_call_tool(tool_input)

assert response == {"printed": "test1\ntest2\n"}

0 comments on commit 0f27812

Please sign in to comment.