diff --git a/gptscript/datasets.py b/gptscript/datasets.py
new file mode 100644
index 0000000..e9de278
--- /dev/null
+++ b/gptscript/datasets.py
@@ -0,0 +1,25 @@
+from typing import Dict
+from pydantic import BaseModel
+
+class DatasetElementMeta(BaseModel):
+    name: str
+    description: str
+
+
+class DatasetElement(BaseModel):
+    name: str
+    description: str
+    contents: str
+
+
+class DatasetMeta(BaseModel):
+    id: str
+    name: str
+    description: str
+
+
+class Dataset(BaseModel):
+    id: str
+    name: str
+    description: str
+    elements: Dict[str, DatasetElementMeta]
diff --git a/gptscript/gptscript.py b/gptscript/gptscript.py
index d7891a8..1946361 100644
--- a/gptscript/gptscript.py
+++ b/gptscript/gptscript.py
@@ -7,6 +7,7 @@
 
 from gptscript.confirm import AuthResponse
 from gptscript.credentials import Credential, to_credential
+from gptscript.datasets import DatasetMeta, Dataset, DatasetElementMeta, DatasetElement
 from gptscript.frame import RunFrame, CallFrame, PromptFrame, Program
 from gptscript.opts import GlobalOptions
 from gptscript.prompt import PromptResponse
@@ -210,6 +211,86 @@ async def delete_credential(self, context: str = "default", name: str = "") -> str:
             {"context": [context], "name": name}
         )
 
+    async def list_datasets(self, workspace: str) -> List[DatasetMeta]:
+        if workspace == "":
+            workspace = os.environ["GPTSCRIPT_WORKSPACE_DIR"]
+
+        res = await self._run_basic_command(
+            "datasets",
+            {"input": "{}", "workspace": workspace, "datasetToolRepo": self.opts.DatasetToolRepo}
+        )
+        return [DatasetMeta.model_validate(d) for d in json.loads(res)]
+
+    async def create_dataset(self, workspace: str, name: str, description: str = "") -> Dataset:
+        if workspace == "":
+            workspace = os.environ["GPTSCRIPT_WORKSPACE_DIR"]
+
+        if name == "":
+            raise ValueError("name cannot be empty")
+
+        res = await self._run_basic_command(
+            "datasets/create",
+            {"input": json.dumps({"datasetName": name, "datasetDescription": description}),
+             "workspace": workspace,
+             "datasetToolRepo": self.opts.DatasetToolRepo}
+        )
+        return Dataset.model_validate_json(res)
+
+    async def add_dataset_element(self, workspace: str, datasetID: str, elementName: str, elementContent: str,
+                                  elementDescription: str = "") -> DatasetElementMeta:
+        if workspace == "":
+            workspace = os.environ["GPTSCRIPT_WORKSPACE_DIR"]
+
+        if datasetID == "":
+            raise ValueError("datasetID cannot be empty")
+        elif elementName == "":
+            raise ValueError("elementName cannot be empty")
+        elif elementContent == "":
+            raise ValueError("elementContent cannot be empty")
+
+        res = await self._run_basic_command(
+            "datasets/add-element",
+            {"input": json.dumps({"datasetID": datasetID,
+                                  "elementName": elementName,
+                                  "elementContent": elementContent,
+                                  "elementDescription": elementDescription}),
+             "workspace": workspace,
+             "datasetToolRepo": self.opts.DatasetToolRepo}
+        )
+        return DatasetElementMeta.model_validate_json(res)
+
+    async def list_dataset_elements(self, workspace: str, datasetID: str) -> List[DatasetElementMeta]:
+        if workspace == "":
+            workspace = os.environ["GPTSCRIPT_WORKSPACE_DIR"]
+
+        if datasetID == "":
+            raise ValueError("datasetID cannot be empty")
+
+        res = await self._run_basic_command(
+            "datasets/list-elements",
+            {"input": json.dumps({"datasetID": datasetID}),
+             "workspace": workspace,
+             "datasetToolRepo": self.opts.DatasetToolRepo}
+        )
+        return [DatasetElementMeta.model_validate(d) for d in json.loads(res)]
+
+    async def get_dataset_element(self, workspace: str, datasetID: str, elementName: str) -> DatasetElement:
+        if workspace == "":
+            workspace = os.environ["GPTSCRIPT_WORKSPACE_DIR"]
+
+        if datasetID == "":
+            raise ValueError("datasetID cannot be empty")
+        elif elementName == "":
+            raise ValueError("elementName cannot be empty")
+
+        res = await self._run_basic_command(
+            "datasets/get-element",
+            {"input": json.dumps({"datasetID": datasetID, "element": elementName}),
+             "workspace": workspace,
+             "datasetToolRepo": self.opts.DatasetToolRepo}
+        )
+        return DatasetElement.model_validate_json(res)
+
 
 def _get_command():
     if os.getenv("GPTSCRIPT_BIN") is not None:
diff --git a/gptscript/opts.py b/gptscript/opts.py
index 7094dc1..c6eabca 100644
--- a/gptscript/opts.py
+++ b/gptscript/opts.py
@@ -12,6 +12,7 @@ def __init__(
             defaultModelProvider: str = "",
             defaultModel: str = "",
             cacheDir: str = "",
+            datasetToolRepo: str = "",
             env: list[str] = None,
     ):
         self.URL = url
@@ -21,6 +22,7 @@
         self.DefaultModel = defaultModel
         self.DefaultModelProvider = defaultModelProvider
         self.CacheDir = cacheDir
+        self.DatasetToolRepo = datasetToolRepo
         if env is None:
             env = [f"{k}={v}" for k, v in os.environ.items()]
         elif isinstance(env, dict):
@@ -38,6 +40,7 @@ def merge(self, other: Self) -> Self:
         cp.DefaultModel = other.DefaultModel if other.DefaultModel != "" else self.DefaultModel
         cp.DefaultModelProvider = other.DefaultModelProvider if other.DefaultModelProvider != "" else self.DefaultModelProvider
         cp.CacheDir = other.CacheDir if other.CacheDir != "" else self.CacheDir
+        cp.DatasetToolRepo = other.DatasetToolRepo if other.DatasetToolRepo != "" else self.DatasetToolRepo
         cp.Env = (other.Env or [])
         cp.Env.extend(self.Env or [])
         return cp
@@ -77,8 +80,9 @@ def __init__(self,
                  defaultModelProvider: str = "",
                  defaultModel: str = "",
                  cacheDir: str = "",
+                 datasetToolRepo: str = "",
                  ):
-        super().__init__(url, token, apiKey, baseURL, defaultModelProvider, defaultModel, cacheDir, env)
+        super().__init__(url, token, apiKey, baseURL, defaultModelProvider, defaultModel, cacheDir, datasetToolRepo, env)
         self.input = input
         self.disableCache = disableCache
         self.subTool = subTool
diff --git a/requirements.txt b/requirements.txt
index 03e1509..4e5a79b 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -10,4 +10,5 @@ setuptools==69.1.1
 twine==5.0.0
 build==1.1.1
 httpx==0.27.0
-pywin32==306; sys_platform == 'win32'
\ No newline at end of file
+pydantic==2.9.2
+pywin32==306; sys_platform == 'win32'
diff --git a/tests/test_gptscript.py b/tests/test_gptscript.py
index 1232dc3..f23cb6e 100644
--- a/tests/test_gptscript.py
+++ b/tests/test_gptscript.py
@@ -4,6 +4,7 @@
 import os
 import platform
 import subprocess
+import tempfile
 from datetime import datetime, timedelta, timezone
 from time import sleep
 
@@ -755,3 +756,39 @@ async def test_credentials(gptscript):
     res = await gptscript.delete_credential(name=name)
     assert not res.startswith("an error occurred"), "Unexpected error deleting credential: " + res
 
+
+@pytest.mark.asyncio
+async def test_datasets(gptscript):
+    with tempfile.TemporaryDirectory(prefix="py-gptscript_") as tempdir:
+        dataset_name = str(os.urandom(8).hex())
+
+        # Create dataset
+        dataset = await gptscript.create_dataset(tempdir, dataset_name, "this is a test dataset")
+        assert dataset.id != "", "Expected dataset id to be set"
+        assert dataset.name == dataset_name, "Expected dataset name to match"
+        assert dataset.description == "this is a test dataset", "Expected dataset description to match"
+        assert len(dataset.elements) == 0, "Expected dataset elements to be empty"
+
+        # Add an element
+        element_meta = await gptscript.add_dataset_element(tempdir, dataset.id, "element1", "element1 contents", "element1 description")
+        assert element_meta.name == "element1", "Expected element name to match"
+        assert element_meta.description == "element1 description", "Expected element description to match"
+
+        # Get the element
+        element = await gptscript.get_dataset_element(tempdir, dataset.id, "element1")
+        assert element.name == "element1", "Expected element name to match"
+        assert element.contents == "element1 contents", "Expected element contents to match"
+        assert element.description == "element1 description", "Expected element description to match"
+
+        # List elements in the dataset
+        elements = await gptscript.list_dataset_elements(tempdir, dataset.id)
+        assert len(elements) == 1, "Expected one element in the dataset"
+        assert elements[0].name == "element1", "Expected element name to match"
+        assert elements[0].description == "element1 description", "Expected element description to match"
+
+        # List datasets
+        datasets = await gptscript.list_datasets(tempdir)
+        assert len(datasets) > 0, "Expected at least one dataset"
+        assert datasets[0].id == dataset.id, "Expected dataset id to match"
+        assert datasets[0].name == dataset_name, "Expected dataset name to match"
+        assert datasets[0].description == "this is a test dataset", "Expected dataset description to match"
diff --git a/tox.ini b/tox.ini
index da1435a..d048575 100644
--- a/tox.ini
+++ b/tox.ini
@@ -6,6 +6,7 @@ deps =
     httpx
     pytest
     pytest-asyncio
+    pydantic
 passenv =
     OPENAI_API_KEY
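
For reviewers: a minimal sketch of how the new dataset methods are expected to be called end to end, mirroring the test above. It assumes the `GPTScript` client constructor and `close()` behave as in the rest of the SDK; the workspace directory and the dataset/element names here are illustrative only.

```python
import asyncio
import tempfile

from gptscript.gptscript import GPTScript
from gptscript.opts import GlobalOptions


async def main():
    # datasetToolRepo is optional; leaving it empty lets gptscript fall back
    # to its default dataset tool.
    g = GPTScript(GlobalOptions(datasetToolRepo=""))
    try:
        with tempfile.TemporaryDirectory() as workspace:
            # Create a dataset, add one element, then read it back.
            dataset = await g.create_dataset(workspace, "my-dataset", "demo dataset")
            await g.add_dataset_element(workspace, dataset.id, "greeting", "hello world")

            element = await g.get_dataset_element(workspace, dataset.id, "greeting")
            print(element.contents)  # prints "hello world"

            # Enumerate the datasets that now exist in this workspace.
            for meta in await g.list_datasets(workspace):
                print(meta.id, meta.name)
    finally:
        g.close()


asyncio.run(main())
```

Passing `workspace=""` instead would make each call fall back to the `GPTSCRIPT_WORKSPACE_DIR` environment variable, per the guard at the top of each new method.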