Skip to content

Commit

Permalink
Merge pull request #46 from h-0-0/h-0-0-dev
Browse files Browse the repository at this point in the history
Fixed Issue#37
  • Loading branch information
h-0-0 authored May 16, 2024
2 parents 9c1481b + a83e5f7 commit b51728a
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 2 deletions.
11 changes: 9 additions & 2 deletions src/slune/savers/ext.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,7 @@ def __init__(self, logger_instance: BaseLogger, ext: str = '.csv', params: dict
Args:
- logger_instance (BaseLogger): Instance of a logger class that inherits from BaseLogger.
- ext (str): Extension of the file where we will store the results, default is '.csv'.
- params (dict): (key,value) pairs we would like to use for our methods, default is None.
If None, we will create a path using the parameters given in the log.
- params (dict): (key,value) pairs we would like to generate a path for, default is None.
- root_dir (str, optional): Path to the root directory where we will store the '.ext files, default is './slune_results'.
"""
Expand Down Expand Up @@ -231,3 +230,11 @@ def getset_current_path(self, params:dict=None, save:bool=True) -> str:
raise Exception('SaverExt.current_path is None, please provide parameters to create a path.')
return self.current_path

def get_current_params(self) -> dict:
""" Getter function for the current_params attribute.
Returns:
- current_params (dict): (key,value) pairs we would like to use for our methods.
"""
return self.current_params
11 changes: 11 additions & 0 deletions tests/test_savers_csv.py
Original file line number Diff line number Diff line change
Expand Up @@ -661,6 +661,17 @@ def test_nonexistent_params(self):
actual_path = saver.getset_current_path(params)
self.assertEqual(actual_path, expected_path)

class TestSaverCsvGetCurrentParams(unittest.TestCase):

def test_no_params(self):
saver = SaverCsv(LoggerDefault())
self.assertEqual(None, saver.get_current_params())

def test_with_params(self):
params = {'--param1': '1', '--param2': 'True', '--param3': '3'}
saver = SaverCsv(LoggerDefault(), params=params)
self.assertEqual(params, saver.get_current_params())


if __name__ == "__main__":
unittest.main()

0 comments on commit b51728a

Please sign in to comment.