diff --git a/gptscript/datasets.py b/gptscript/datasets.py
index 360e765..63b8fd1 100644
--- a/gptscript/datasets.py
+++ b/gptscript/datasets.py
@@ -2,6 +2,12 @@
 from pydantic import BaseModel, field_serializer, field_validator, BeforeValidator
 
 
+class DatasetMeta(BaseModel):
+    id: str
+    name: str
+    description: str
+
+
 class DatasetElementMeta(BaseModel):
     name: str
     description: str
diff --git a/gptscript/gptscript.py b/gptscript/gptscript.py
index 2853a53..392f70b 100644
--- a/gptscript/gptscript.py
+++ b/gptscript/gptscript.py
@@ -8,7 +8,7 @@
 
 from gptscript.confirm import AuthResponse
 from gptscript.credentials import Credential, to_credential
-from gptscript.datasets import DatasetElementMeta, DatasetElement
+from gptscript.datasets import DatasetElementMeta, DatasetElement, DatasetMeta
 from gptscript.fileinfo import FileInfo
 from gptscript.frame import RunFrame, CallFrame, PromptFrame, Program
 from gptscript.opts import GlobalOptions
@@ -214,18 +214,24 @@ async def delete_credential(self, context: str = "default", name: str = "") -> str:
         )
 
     # list_datasets returns an array of dataset IDs
-    async def list_datasets(self) -> List[str]:
+    async def list_datasets(self) -> List[DatasetMeta]:
         res = await self._run_basic_command(
             "datasets",
             {
-                "input": json.dumps({"workspaceID": os.getenv("GPTSCRIPT_WORKSPACE_ID")}),
+                "input": "{}",
                 "datasetTool": self.opts.DatasetTool,
                 "env": self.opts.Env
             }
         )
-        return json.loads(res)
+        return [DatasetMeta.model_validate(d) for d in json.loads(res)]
 
-    async def add_dataset_elements(self, elements: List[DatasetElement], datasetID: str = "") -> str:
+    async def add_dataset_elements(
+            self,
+            elements: List[DatasetElement],
+            datasetID: str = "",
+            name: str = "",
+            description: str = ""
+    ) -> str:
         if not elements:
             raise ValueError("elements cannot be empty")
 
@@ -233,8 +239,9 @@ async def add_dataset_elements(self, elements: List[DatasetElement], datasetID:
             "datasets/add-elements",
             {
                 "input": json.dumps({
-                    "workspaceID": os.getenv("GPTSCRIPT_WORKSPACE_ID"),
                     "datasetID": datasetID,
+                    "name": name,
+                    "description": description,
                     "elements": [element.model_dump() for element in elements],
                 }),
                 "datasetTool": self.opts.DatasetTool,
@@ -250,10 +257,7 @@ async def list_dataset_elements(self, datasetID: str) -> List[DatasetElementMeta]:
         res = await self._run_basic_command(
             "datasets/list-elements",
             {
-                "input": json.dumps({
-                    "workspaceID": os.getenv("GPTSCRIPT_WORKSPACE_ID"),
-                    "datasetID": datasetID,
-                }),
+                "input": json.dumps({"datasetID": datasetID}),
                 "datasetTool": self.opts.DatasetTool,
                 "env": self.opts.Env
             }
@@ -270,7 +274,6 @@ async def get_dataset_element(self, datasetID: str, elementName: str) -> DatasetElement:
             "datasets/get-element",
             {
                 "input": json.dumps({
-                    "workspaceID": os.getenv("GPTSCRIPT_WORKSPACE_ID"),
                     "datasetID": datasetID,
                     "name": elementName,
                 }),
diff --git a/tests/test_gptscript.py b/tests/test_gptscript.py
index a5b16cd..6c2aff9 100644
--- a/tests/test_gptscript.py
+++ b/tests/test_gptscript.py
@@ -761,36 +761,40 @@ async def test_credentials(gptscript):
 @pytest.mark.asyncio
 async def test_datasets(gptscript):
     workspace_id = await gptscript.create_workspace("directory")
-    os.environ["GPTSCRIPT_WORKSPACE_ID"] = workspace_id
+
+    new_client = GPTScript(GlobalOptions(
+        apiKey=os.getenv("OPENAI_API_KEY"),
+        env=[f"GPTSCRIPT_WORKSPACE_ID={workspace_id}"],
+    ))
 
     # Create dataset
-    dataset_id = await gptscript.add_dataset_elements([
+    dataset_id = await new_client.add_dataset_elements([
         DatasetElement(name="element1", contents="element1 contents", description="element1 description"),
         DatasetElement(name="element2", binaryContents=b"element2 contents", description="element2 description"),
-    ])
+    ], name="test-dataset", description="test dataset description")
 
     # Add two more elements
-    await gptscript.add_dataset_elements([
+    await new_client.add_dataset_elements([
         DatasetElement(name="element3", contents="element3 contents", description="element3 description"),
         DatasetElement(name="element4", contents="element3 contents", description="element4 description"),
-    ], dataset_id)
+    ], datasetID=dataset_id)
 
     # Get the elements
-    e1 = await gptscript.get_dataset_element(dataset_id, "element1")
+    e1 = await new_client.get_dataset_element(dataset_id, "element1")
     assert e1.name == "element1", "Expected element name to match"
     assert e1.contents == "element1 contents", "Expected element contents to match"
     assert e1.description == "element1 description", "Expected element description to match"
 
-    e2 = await gptscript.get_dataset_element(dataset_id, "element2")
+    e2 = await new_client.get_dataset_element(dataset_id, "element2")
     assert e2.name == "element2", "Expected element name to match"
     assert e2.binaryContents == b"element2 contents", "Expected element contents to match"
     assert e2.description == "element2 description", "Expected element description to match"
 
-    e3 = await gptscript.get_dataset_element(dataset_id, "element3")
+    e3 = await new_client.get_dataset_element(dataset_id, "element3")
     assert e3.name == "element3", "Expected element name to match"
     assert e3.contents == "element3 contents", "Expected element contents to match"
     assert e3.description == "element3 description", "Expected element description to match"
 
     # List elements in the dataset
-    elements = await gptscript.list_dataset_elements(dataset_id)
+    elements = await new_client.list_dataset_elements(dataset_id)
     assert len(elements) == 4, "Expected four elements in the dataset"
     assert elements[0].name == "element1", "Expected element name to match"
     assert elements[0].description == "element1 description", "Expected element description to match"
@@ -802,9 +806,11 @@ async def test_datasets(gptscript):
     assert elements[3].description == "element4 description", "Expected element description to match"
 
     # List datasets
-    dataset_ids = await gptscript.list_datasets()
-    assert len(dataset_ids) > 0, "Expected at least one dataset"
-    assert dataset_ids[0] == dataset_id, "Expected dataset id to match"
+    datasets = await new_client.list_datasets()
+    assert len(datasets) > 0, "Expected at least one dataset"
+    assert datasets[0].id == dataset_id, "Expected dataset id to match"
+    assert datasets[0].name == "test-dataset", "Expected dataset name to match"
+    assert datasets[0].description == "test dataset description", "Expected dataset description to match"
 
     await gptscript.delete_workspace(workspace_id)