diff --git a/custom_components/powerllm/tools/python_code.py b/custom_components/powerllm/tools/python_code.py index cdf9266..fb8e37d 100644 --- a/custom_components/powerllm/tools/python_code.py +++ b/custom_components/powerllm/tools/python_code.py @@ -21,6 +21,7 @@ full_write_guard, guarded_iter_unpack_sequence, guarded_unpack_sequence, + raise_, safer_getattr, ) @@ -44,14 +45,16 @@ def python_code_execute(hass: HomeAssistant, source: str, data: dict | None = No """Execute python code in a restricted environment. Use this tool for math calculations among other things. - Use `output` dictionary for output and `logger` object for logging. + Use `output` dictionary for output, `logger` object for logging, + and `print` for printing. """ log_handler = MyHandler() logger = logging.getLogger(f"{__name__}.script") logger.setLevel(logging.DEBUG) logger.addHandler(log_handler) try: - output = execute(hass, source, data) + printed = [] + output = execute(hass, source, data, printed) except HomeAssistantError as e: output = {"error": type(e).__name__} if str(e): @@ -59,6 +62,10 @@ def python_code_execute(hass: HomeAssistant, source: str, data: dict | None = No logger.removeHandler(log_handler) result = {"output": output} + if printed: + result["printed"] = "".join(printed) + if not output: + del result["output"] logs = [ {"level": record.levelno, "msg": record.getMessage()} for record in log_handler.logs @@ -153,9 +160,12 @@ def guarded_inplacevar(op: str, target: Any, operand: Any) -> Any: return op_fun(target, operand) -def execute(hass, source, data=None): +def execute(hass, source, data=None, printed: list[str] | None = None): """Execute Python source.""" + if printed is None: + printed = [] + compiled = compile_restricted_exec(source) if compiled.errors: @@ -163,8 +173,13 @@ def execute(hass, source, data=None): logger = logging.getLogger(f"{__name__}.script") - if compiled.warnings: - logger.warning("Warning loading script: %s", ", ".join(compiled.warnings)) + compiled_warnings = [ + w + for w in compiled.warnings.copy() + if "Prints, but never reads 'printed' variable." not in w + ] + if compiled_warnings: + logger.warning("Warning loading script: %s", ", ".join(compiled_warnings)) class ProtectedLogger: """Replacement logger for compatibility.""" @@ -173,6 +188,26 @@ def getLogger(self, *args, **kwargs): """Return a specific logger every time.""" return logger + class PrintCollector: + """Collect written text, and return it when called.""" + + def __init__(self, _getattr_=None): + self._getattr_ = _getattr_ + + def write(self, text): + printed.append(text) + + def __call__(self): + return "".join(printed) + + def _call_print(self, *objects, **kwargs): + if kwargs.get("file", None) is None: + kwargs["file"] = self + else: + self._getattr_(kwargs["file"], "write") + + print(*objects, **kwargs) + def protected_import(name: str, *args, **kwargs): if name.split(".")[0] in ALLOWED_IMPORT: return __import__(name, *args, **kwargs) @@ -186,7 +221,7 @@ def protected_import(name: str, *args, **kwargs): return ProtectedLogger() raise ImportError(f"Module {name} not found") - def protected_getattr(obj, name, default=None): + def protected_getattr(obj, name, default=raise_): """Restricted method to get attributes.""" if name.startswith("async_"): raise ScriptError("Not allowed to access async methods") @@ -235,7 +270,7 @@ def protected_getattr(obj, name, default=None): builtins.update(extra_builtins) restricted_globals = { "__builtins__": builtins, - "_print_": StubPrinter, + "_print_": PrintCollector, "_getattr_": protected_getattr, "_write_": full_write_guard, "_getiter_": iter, @@ -254,16 +289,10 @@ def protected_getattr(obj, name, default=None): # pylint: disable-next=exec-used exec(compiled.code, restricted_globals) # noqa: S102 _LOGGER.debug( - "Output of python_script:\n%s", + "Output of python_script:\n%s\n%s", restricted_globals["output"], + "".join(printed), ) - # Ensure that we're always returning a dictionary - if not isinstance(restricted_globals["output"], dict): - output_type = type(restricted_globals["output"]) - restricted_globals["output"] = {} - raise ScriptError( # noqa: TRY301 - f"Expected `output` to be a dictionary, was {output_type}" - ) except ScriptError as err: raise ServiceValidationError(f"Error executing script: {err}") from err except Exception as err: @@ -274,17 +303,6 @@ def protected_getattr(obj, name, default=None): return restricted_globals["output"] -class StubPrinter: - """Class to handle printing inside scripts.""" - - def __init__(self, _getattr_): - """Initialize our printer.""" - - def _call_print(self, *objects, **kwargs): - """Print text.""" - _LOGGER.warning("Don't use print() inside scripts. Use logger.info() instead") - - class TimeWrapper: """Wrap the time module.""" diff --git a/tests/test_tool_python_code.py b/tests/test_tool_python_code.py index 9d3cda8..1aa13f0 100644 --- a/tests/test_tool_python_code.py +++ b/tests/test_tool_python_code.py @@ -62,3 +62,27 @@ async def test_python_script_tool_import( response = await api.async_call_tool(tool_input) assert response == {"output": {"test": 1.0}} + + +async def test_python_script_tool_print( + hass: HomeAssistant, llm_context: llm.LLMContext, mock_init_component +) -> None: + """Test print in python script tool.""" + api = await llm.async_get_api(hass, "powerllm", llm_context) + + source = """ +print("test1") +def test2(): + print("test2") + +test2() + """ + + tool_input = llm.ToolInput( + tool_name="python_code_execute", + tool_args={"source": source}, + ) + + response = await api.async_call_tool(tool_input) + + assert response == {"printed": "test1\ntest2\n"}