Skip to content

Commit

Permalink
[2.5] Support config dict (#2604)
Browse files Browse the repository at this point in the history
* add dict support and test cases

* more test cases

* more test cases

* added more docstrings
  • Loading branch information
yanchengnv authored Jun 3, 2024
1 parent 3e20b14 commit 0b314f8
Show file tree
Hide file tree
Showing 2 changed files with 360 additions and 40 deletions.
182 changes: 142 additions & 40 deletions nvflare/fuel/utils/config_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import json
import logging
import os
from typing import Dict, List, Optional, Union
Expand Down Expand Up @@ -126,8 +127,30 @@ def initialize(cls, section_files: Dict[str, str], config_path: List[str], parse
raise ValueError(f"parsed_args must be argparse.Namespace but got {type(parsed_args)}")
cls._cmd_args = dict(parsed_args.__dict__)

@classmethod
def reset(cls):
"""Reset the ConfigServer to its initial state. All registered sections and cached var values
are cleared. This method is mainly used for test purpose.
Returns:
"""
cls._sections = {}
cls._config_path = []
cls._cmd_args = None
cls._var_dict = None
cls._var_values = {}

@classmethod
def get_section(cls, name: str):
"""Get the specified section.
Args:
name: name of the section
Returns: the section of the specified name, or None if the section is not found.
"""
return cls._sections.get(name)

@classmethod
Expand All @@ -140,7 +163,7 @@ def add_section(cls, section_name: str, data: dict, overwrite_existing: bool = T
data: data of the section
overwrite_existing: if section already exists, whether to overwrite
Returns:
Returns: None
"""
if not isinstance(section_name, str):
Expand All @@ -153,6 +176,15 @@ def add_section(cls, section_name: str, data: dict, overwrite_existing: bool = T

@classmethod
def load_configuration(cls, file_basename: str) -> Optional[Config]:
"""Load config data from the specified file basename.
The full name of the config file will be determined by ConfigFactory.
Args:
file_basename: the basename of the config file.
Returns: config data loaded, or None if the config file is not found.
"""
return ConfigFactory.load_config(file_basename, cls._config_path)

@classmethod
Expand All @@ -161,6 +193,7 @@ def load_config_dict(
) -> Optional[Dict]:
"""
Load a specified config file ( ignore extension)
Args:
raise_exception: if True raise exception when error occurs
file_basename: base name of the config file to be loaded.
Expand Down Expand Up @@ -208,6 +241,25 @@ def find_file(cls, file_basename: str) -> Union[None, str]:
raise TypeError(f"file_basename must be str but got {type(file_basename)}")
return search_file(file_basename, cls._config_path)

@classmethod
def _get_from_config(cls, func, name: str, conf, default):
v, src = cls._get_var_from_source(name, conf)
cls.logger.debug(f"got var {name} from {src}")
if v is None:
return default

# convert to right data type
return func(name, v)

@classmethod
def _any_var(cls, func, name, conf, default):
if name in cls._var_values:
return cls._var_values.get(name)
v = cls._get_from_config(func, name, conf, default)
if v is not None:
cls._var_values[name] = v
return v

@staticmethod
def _get_var_from_os_env(name: str):
if not name.startswith(ENV_VAR_PREFIX):
Expand Down Expand Up @@ -266,53 +318,49 @@ def _get_var_from_source(cls, name: str, conf):
return cls._get_var_from_os_env(name), "env"

@classmethod
def _get_var(cls, name: str, conf):
value, src = cls._get_var_from_source(name, conf)
# print(f"#### VAR from {src}: {name}={value}")
return value

@classmethod
def _int_var(cls, name: str, conf=None, default=None):
v = cls._get_var(name, conf)
if v is None:
return default
def _to_int(cls, name: str, v):
try:
return int(v)
except Exception as e:
raise ValueError(f"var {name}'s value '{v}' cannot be converted to int: {e}")

@classmethod
def _any_var(cls, func, name, conf, default):
if name in cls._var_values:
return cls._var_values.get(name)
v = func(name, conf, default)
if v is not None:
cls._var_values[name] = v
return v

@classmethod
def get_int_var(cls, name: str, conf=None, default=None):
return cls._any_var(cls._int_var, name, conf, default)
"""Get configured int value of the specified var
Args:
name: name of the var
conf: source config
default: value to return if the var is not found
Returns: configured value of the var, or the default value if var is not configured
"""
return cls._any_var(cls._to_int, name, conf, default)

@classmethod
def _float_var(cls, name: str, conf=None, default=None):
v = cls._get_var(name, conf)
if v is None:
return default
def _to_float(cls, name: str, v):
try:
return float(v)
except:
raise ValueError(f"var {name}'s value '{v}' cannot be converted to float")
except Exception as e:
raise ValueError(f"var {name}'s value '{v}' cannot be converted to float: {e}")

@classmethod
def get_float_var(cls, name: str, conf=None, default=None):
return cls._any_var(cls._float_var, name, conf, default)
"""Get configured float value of the specified var
Args:
name: name of the var
conf: source config
default: value to return if the var is not found
Returns: configured value of the var, or the default value if var is not configured
"""
return cls._any_var(cls._to_float, name, conf, default)

@classmethod
def _bool_var(cls, name: str, conf=None, default=None):
v = cls._get_var(name, conf)
if v is None:
return default
def _to_bool(cls, name: str, v):
if isinstance(v, bool):
return v
if isinstance(v, int):
Expand All @@ -324,22 +372,76 @@ def _bool_var(cls, name: str, conf=None, default=None):

@classmethod
def get_bool_var(cls, name: str, conf=None, default=None):
return cls._any_var(cls._bool_var, name, conf, default)
"""Get configured bool value of the specified var
Args:
name: name of the var
conf: source config
default: value to return if the var is not found
Returns: configured value of the var, or the default value if var is not configured
"""
return cls._any_var(cls._to_bool, name, conf, default)

@classmethod
def _str_var(cls, name: str, conf=None, default=None):
v = cls._get_var(name, conf)
if v is None:
return default
def _to_str(cls, name: str, v):
try:
return str(v)
except:
raise ValueError(f"var {name}'s value '{v}' cannot be converted to str")
except Exception as e:
raise ValueError(f"var {name}'s value '{v}' cannot be converted to str: {e}")

@classmethod
def get_str_var(cls, name: str, conf=None, default=None):
return cls._any_var(cls._str_var, name, conf, default)
"""Get configured str value of the specified var
Args:
name: name of the var
conf: source config
default: value to return if the var is not found
Returns: configured value of the var, or the default value if var is not configured
"""
return cls._any_var(cls._to_str, name, conf, default)

@classmethod
def _to_dict(cls, name: str, v):
if isinstance(v, dict):
return v

if isinstance(v, str):
# assume it's a json str
try:
v2 = json.loads(v)
except Exception as e:
raise ValueError(f"var {name}'s value '{v}' cannot be converted to dict: {e}")

if not isinstance(v2, dict):
raise ValueError(f"var {name}'s value '{v}' does not represent a dict")
return v2
else:
raise ValueError(f"var {name}'s value '{v}' does not represent a dict")

@classmethod
def get_dict_var(cls, name: str, conf=None, default=None):
"""Get configured dict value of the specified var
Args:
name: name of the var
conf: source config
default: value to return if the var is not found
Returns: configured value of the var, or the default value if var is not configured
"""
return cls._any_var(cls._to_dict, name, conf, default)

@classmethod
def get_var_values(cls):
"""Get cached var values.
Returns:
"""
return cls._var_values
Loading

0 comments on commit 0b314f8

Please sign in to comment.