diff --git a/custom_components/powerllm/tools/python_code.py b/custom_components/powerllm/tools/python_code.py index 21f0984..cdf9266 100644 --- a/custom_components/powerllm/tools/python_code.py +++ b/custom_components/powerllm/tools/python_code.py @@ -21,6 +21,7 @@ full_write_guard, guarded_iter_unpack_sequence, guarded_unpack_sequence, + safer_getattr, ) from ..llm_tools import llm_tool @@ -46,11 +47,11 @@ def python_code_execute(hass: HomeAssistant, source: str, data: dict | None = No Use `output` dictionary for output and `logger` object for logging. """ log_handler = MyHandler() - logger = logging.getLogger("homeassistant.components.python_script.source") + logger = logging.getLogger(f"{__name__}.script") logger.setLevel(logging.DEBUG) logger.addHandler(log_handler) try: - output = execute(hass, "source", source, data, return_response=True) + output = execute(hass, source, data) except HomeAssistantError as e: output = {"error": type(e).__name__} if str(e): @@ -102,6 +103,20 @@ def python_code_execute(hass: HomeAssistant, source: str, data: dict | None = No "parse_date", "get_age", } +ALLOWED_IMPORT = { + "math", + "random", + "itertools", + "functools", + "collections", + "json", + "csv", + "re", + "string", + "operator", + "enum", + "types", +} class ScriptError(HomeAssistantError): @@ -138,21 +153,38 @@ def guarded_inplacevar(op: str, target: Any, operand: Any) -> Any: return op_fun(target, operand) -def execute(hass, filename, source, data=None, return_response=False): +def execute(hass, source, data=None): """Execute Python source.""" - compiled = compile_restricted_exec(source, filename=filename) + compiled = compile_restricted_exec(source) if compiled.errors: - _LOGGER.error( - "Error loading script %s: %s", filename, ", ".join(compiled.errors) - ) - return None + raise ScriptError("Compile error: %s", ", ".join(compiled.errors)) + + logger = logging.getLogger(f"{__name__}.script") if compiled.warnings: - _LOGGER.warning( - "Warning loading script %s: %s", filename, ", ".join(compiled.warnings) - ) + logger.warning("Warning loading script: %s", ", ".join(compiled.warnings)) + + class ProtectedLogger: + """Replacement logger for compatibility.""" + + def getLogger(self, *args, **kwargs): + """Return a specific logger every time.""" + return logger + + def protected_import(name: str, *args, **kwargs): + if name.split(".")[0] in ALLOWED_IMPORT: + return __import__(name, *args, **kwargs) + if name == "datetime": + return datetime + if name == "time": + return TimeWrapper() + if name == "dt_util": + return dt_util + if name == "logging": + return ProtectedLogger() + raise ImportError(f"Module {name} not found") def protected_getattr(obj, name, default=None): """Restricted method to get attributes.""" @@ -176,9 +208,10 @@ def protected_getattr(obj, name, default=None): ): raise ScriptError(f"Not allowed to access {obj.__class__.__name__}.{name}") - return getattr(obj, name, default) + return safer_getattr(obj, name, default) extra_builtins = { + "__import__": protected_import, "datetime": datetime, "sorted": sorted, "time": TimeWrapper(), @@ -189,12 +222,17 @@ def protected_getattr(obj, name, default=None): "any": any, "all": all, "enumerate": enumerate, + "type": type, + "map": map, + "filter": filter, + "reversed": reversed, + "getattr": getattr, + "hasattr": hasattr, } builtins = safe_builtins.copy() builtins.update(utility_builtins) builtins.update(limited_builtins) builtins.update(extra_builtins) - logger = logging.getLogger(f"{__name__}.{filename}") restricted_globals = { "__builtins__": builtins, "_print_": StubPrinter, @@ -212,12 +250,11 @@ def protected_getattr(obj, name, default=None): } try: - _LOGGER.info("Executing %s: %s", filename, data) + _LOGGER.info("Executing script: %s", data) # pylint: disable-next=exec-used exec(compiled.code, restricted_globals) # noqa: S102 _LOGGER.debug( - "Output of python_script: `%s`:\n%s", - filename, + "Output of python_script:\n%s", restricted_globals["output"], ) # Ensure that we're always returning a dictionary @@ -228,17 +265,11 @@ def protected_getattr(obj, name, default=None): f"Expected `output` to be a dictionary, was {output_type}" ) except ScriptError as err: - if return_response: - raise ServiceValidationError(f"Error executing script: {err}") from err - logger.error("Error executing script: %s", err) - return None + raise ServiceValidationError(f"Error executing script: {err}") from err except Exception as err: - if return_response: - raise HomeAssistantError( - f"Error executing script ({type(err).__name__}): {err}" - ) from err - logger.exception("Error executing script") - return None + raise HomeAssistantError( + f"Error executing script ({type(err).__name__}): {err}" + ) from err return restricted_globals["output"] diff --git a/tests/test_tool_python_code.py b/tests/test_tool_python_code.py index 009cb1d..9d3cda8 100644 --- a/tests/test_tool_python_code.py +++ b/tests/test_tool_python_code.py @@ -25,7 +25,7 @@ def test_test(hass): async def test_python_script_tool( hass: HomeAssistant, llm_context: llm.LLMContext, mock_init_component ) -> None: - """Test function tools with async function.""" + """Test python script tool.""" api = await llm.async_get_api(hass, "powerllm", llm_context) source = """ @@ -41,3 +41,24 @@ async def test_python_script_tool( response = await api.async_call_tool(tool_input) assert response == {"output": {"test": "passed", "test2": "passed2"}} + + +async def test_python_script_tool_import( + hass: HomeAssistant, llm_context: llm.LLMContext, mock_init_component +) -> None: + """Test python script tool with import.""" + api = await llm.async_get_api(hass, "powerllm", llm_context) + + source = """ +import math +output["test"] = math.cos(0) + """ + + tool_input = llm.ToolInput( + tool_name="python_code_execute", + tool_args={"source": source}, + ) + + response = await api.async_call_tool(tool_input) + + assert response == {"output": {"test": 1.0}}