Skip to content

Commit

Permalink
Allow import in python script tool
Browse files Browse the repository at this point in the history
  • Loading branch information
Shulyaka committed Dec 10, 2024
1 parent 045444a commit 6f93417
Show file tree
Hide file tree
Showing 2 changed files with 79 additions and 27 deletions.
83 changes: 57 additions & 26 deletions custom_components/powerllm/tools/python_code.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
full_write_guard,
guarded_iter_unpack_sequence,
guarded_unpack_sequence,
safer_getattr,
)

from ..llm_tools import llm_tool
Expand All @@ -46,11 +47,11 @@ def python_code_execute(hass: HomeAssistant, source: str, data: dict | None = No
Use `output` dictionary for output and `logger` object for logging.
"""
log_handler = MyHandler()
logger = logging.getLogger("homeassistant.components.python_script.source")
logger = logging.getLogger(f"{__name__}.script")
logger.setLevel(logging.DEBUG)
logger.addHandler(log_handler)
try:
output = execute(hass, "source", source, data, return_response=True)
output = execute(hass, source, data)
except HomeAssistantError as e:
output = {"error": type(e).__name__}
if str(e):
Expand Down Expand Up @@ -102,6 +103,20 @@ def python_code_execute(hass: HomeAssistant, source: str, data: dict | None = No
"parse_date",
"get_age",
}
ALLOWED_IMPORT = {
"math",
"random",
"itertools",
"functools",
"collections",
"json",
"csv",
"re",
"string",
"operator",
"enum",
"types",
}


class ScriptError(HomeAssistantError):
Expand Down Expand Up @@ -138,21 +153,38 @@ def guarded_inplacevar(op: str, target: Any, operand: Any) -> Any:
return op_fun(target, operand)


def execute(hass, filename, source, data=None, return_response=False):
def execute(hass, source, data=None):
"""Execute Python source."""

compiled = compile_restricted_exec(source, filename=filename)
compiled = compile_restricted_exec(source)

if compiled.errors:
_LOGGER.error(
"Error loading script %s: %s", filename, ", ".join(compiled.errors)
)
return None
raise ScriptError("Compile error: %s", ", ".join(compiled.errors))

logger = logging.getLogger(f"{__name__}.script")

if compiled.warnings:
_LOGGER.warning(
"Warning loading script %s: %s", filename, ", ".join(compiled.warnings)
)
logger.warning("Warning loading script: %s", ", ".join(compiled.warnings))

class ProtectedLogger:
"""Replacement logger for compatibility."""

def getLogger(self, *args, **kwargs):
"""Return a specific logger every time."""
return logger

def protected_import(name: str, *args, **kwargs):
if name.split(".")[0] in ALLOWED_IMPORT:
return __import__(name, *args, **kwargs)
if name == "datetime":
return datetime
if name == "time":
return TimeWrapper()
if name == "dt_util":
return dt_util
if name == "logging":
return ProtectedLogger()
raise ImportError(f"Module {name} not found")

def protected_getattr(obj, name, default=None):
"""Restricted method to get attributes."""
Expand All @@ -176,9 +208,10 @@ def protected_getattr(obj, name, default=None):
):
raise ScriptError(f"Not allowed to access {obj.__class__.__name__}.{name}")

return getattr(obj, name, default)
return safer_getattr(obj, name, default)

extra_builtins = {
"__import__": protected_import,
"datetime": datetime,
"sorted": sorted,
"time": TimeWrapper(),
Expand All @@ -189,12 +222,17 @@ def protected_getattr(obj, name, default=None):
"any": any,
"all": all,
"enumerate": enumerate,
"type": type,
"map": map,
"filter": filter,
"reversed": reversed,
"getattr": getattr,
"hasattr": hasattr,
}
builtins = safe_builtins.copy()
builtins.update(utility_builtins)
builtins.update(limited_builtins)
builtins.update(extra_builtins)
logger = logging.getLogger(f"{__name__}.{filename}")
restricted_globals = {
"__builtins__": builtins,
"_print_": StubPrinter,
Expand All @@ -212,12 +250,11 @@ def protected_getattr(obj, name, default=None):
}

try:
_LOGGER.info("Executing %s: %s", filename, data)
_LOGGER.info("Executing script: %s", data)
# pylint: disable-next=exec-used
exec(compiled.code, restricted_globals) # noqa: S102
_LOGGER.debug(
"Output of python_script: `%s`:\n%s",
filename,
"Output of python_script:\n%s",
restricted_globals["output"],
)
# Ensure that we're always returning a dictionary
Expand All @@ -228,17 +265,11 @@ def protected_getattr(obj, name, default=None):
f"Expected `output` to be a dictionary, was {output_type}"
)
except ScriptError as err:
if return_response:
raise ServiceValidationError(f"Error executing script: {err}") from err
logger.error("Error executing script: %s", err)
return None
raise ServiceValidationError(f"Error executing script: {err}") from err
except Exception as err:
if return_response:
raise HomeAssistantError(
f"Error executing script ({type(err).__name__}): {err}"
) from err
logger.exception("Error executing script")
return None
raise HomeAssistantError(
f"Error executing script ({type(err).__name__}): {err}"
) from err

return restricted_globals["output"]

Expand Down
23 changes: 22 additions & 1 deletion tests/test_tool_python_code.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def test_test(hass):
async def test_python_script_tool(
hass: HomeAssistant, llm_context: llm.LLMContext, mock_init_component
) -> None:
"""Test function tools with async function."""
"""Test python script tool."""
api = await llm.async_get_api(hass, "powerllm", llm_context)

source = """
Expand All @@ -41,3 +41,24 @@ async def test_python_script_tool(
response = await api.async_call_tool(tool_input)

assert response == {"output": {"test": "passed", "test2": "passed2"}}


async def test_python_script_tool_import(
hass: HomeAssistant, llm_context: llm.LLMContext, mock_init_component
) -> None:
"""Test python script tool with import."""
api = await llm.async_get_api(hass, "powerllm", llm_context)

source = """
import math
output["test"] = math.cos(0)
"""

tool_input = llm.ToolInput(
tool_name="python_code_execute",
tool_args={"source": source},
)

response = await api.async_call_tool(tool_input)

assert response == {"output": {"test": 1.0}}

0 comments on commit 6f93417

Please sign in to comment.