from dataclasses import dataclass, field
import yaml
from gaggle.arguments.sys_args import SysArgs
from gaggle.arguments.individual_args import IndividualArgs
from gaggle.arguments.outdir_args import OutdirArgs
from gaggle.arguments.ga_args import GAArgs
from gaggle.arguments.problem_args import ProblemArgs
from gaggle.utils.special_print import print_warning
import transformers
def parse_args():
    """Parse the command line into initialized argument objects.

    Builds a ``transformers.HfArgumentParser`` over every argument class of the
    package and lets it populate one object per class from the CLI values.

    Returns:
        The value returned by ``HfArgumentParser.parse_args_into_dataclasses()``
        for the classes (OutdirArgs, SysArgs, IndividualArgs, GAArgs, ProblemArgs,
        ConfigArgs), passed to the parser in that order.
    """
    arg_classes = (OutdirArgs, SysArgs, IndividualArgs, GAArgs, ProblemArgs, ConfigArgs)
    cli_parser = transformers.HfArgumentParser(arg_classes)
    parsed_args = cli_parser.parse_args_into_dataclasses()
    return parsed_args
@dataclass
class ConfigArgs:
    """Argument class that combines all the other argument classes and can fill
    them from a YAML configuration file for experiments.

    Notes:
        ``args_to_config`` is a class-level dictionary. The argument objects it
        holds, and any values loaded into them from a config file, are therefore
        shared by every ``ConfigArgs`` instance.
    """

    config_path: str = field(default=None, metadata={
        "help": "path to the yaml configuration file (*.yml)"
    })

    def exists(self):
        """Return True when a configuration file path was given."""
        return self.config_path is not None

    # Maps each top-level key of the *.yml file to the argument object that receives
    # the values of that section. Un-annotated, so @dataclass treats it as a plain
    # class attribute rather than a field.
    args_to_config = {
        SysArgs.CONFIG_KEY: SysArgs(),
        IndividualArgs.CONFIG_KEY: IndividualArgs(),
        OutdirArgs.CONFIG_KEY: OutdirArgs(),
        GAArgs.CONFIG_KEY: GAArgs(),
        ProblemArgs.CONFIG_KEY: ProblemArgs(),
    }

    @classmethod
    def get_keys(cls):
        """
        Returns: the list of config keys that will be read in the *.yml file
        """
        return list(cls.args_to_config)

    @classmethod
    def update(cls, config_key, arg_subclass):
        r"""Replace one of the argument classes in args_to_config that will be
        read in the *.yml file.

        Args:
            config_key: key of the argument class to be replaced; it must already
                be one of the keys of args_to_config
            arg_subclass: argument class that will be used for the given config_key;
                it must be a subclass of the class currently registered for that key

        Notes:
            arg_subclass needs to be the class itself (not an instance), as the update
            instantiates it with no arguments.
            When either check fails, a diagnostic is printed and args_to_config is
            left unchanged.
        """
        # Explicit checks instead of `assert`: assertions are removed when Python
        # runs with -O, which would let invalid input through silently.
        if config_key not in cls.args_to_config:
            print(f"Config Key {config_key} is not a valid config key")
            print(f"Valid Config Key: {list(cls.args_to_config.keys())}")
            return

        current_type = type(cls.args_to_config[config_key])
        if not issubclass(arg_subclass, current_type):
            print("Given class needs to be a subclass of the replaced arg")
            print(arg_subclass)
            print(current_type)
            return

        cls.args_to_config[config_key] = arg_subclass()

    def get_args(self):
        """Return all argument objects as a tuple.

        Returns:
            (outdir args, sys args, individual args, problem args, ga args), in
            that order.
        """
        return (
            self.get_outdir_args(),
            self.get_sys_args(),
            self.get_individual_args(),
            self.get_problem_args(),
            self.get_ga_args(),
        )

    def get_sys_args(self) -> SysArgs:
        """Return the argument object registered under SysArgs.CONFIG_KEY."""
        return self.args_to_config[SysArgs.CONFIG_KEY]

    def get_problem_args(self) -> ProblemArgs:
        """Return the argument object registered under ProblemArgs.CONFIG_KEY."""
        return self.args_to_config[ProblemArgs.CONFIG_KEY]

    def get_individual_args(self) -> IndividualArgs:
        """Return the argument object registered under IndividualArgs.CONFIG_KEY."""
        return self.args_to_config[IndividualArgs.CONFIG_KEY]

    def get_outdir_args(self) -> OutdirArgs:
        """Return the argument object registered under OutdirArgs.CONFIG_KEY."""
        return self.args_to_config[OutdirArgs.CONFIG_KEY]

    def get_ga_args(self) -> GAArgs:
        """Return the argument object registered under GAArgs.CONFIG_KEY."""
        return self.args_to_config[GAArgs.CONFIG_KEY]

    def __post_init__(self):
        """Load the YAML file at config_path, if one was given, into the argument objects.

        Each top-level key of the file selects an entry of args_to_config, and every
        key/value pair below it is written into that argument object. Keys that the
        argument object does not already have are still written, and are reported
        together in a single warning.

        Raises:
            KeyError: if the file contains a top-level key that is not in args_to_config.
        """
        if self.config_path is None:
            return

        with open(self.config_path, "r") as f:
            # safe_load returns None for an empty document; treat that as "no sections".
            data = yaml.safe_load(f) or {}

        self.keys = list(data.keys())

        # load arguments
        keys_not_found = []
        for entry, values in data.items():
            try:
                target = self.args_to_config[entry]
            except KeyError:
                raise KeyError(
                    f"Unknown config section {entry!r} in {self.config_path}; "
                    f"valid sections: {self.get_keys()}"
                ) from None

            target_attrs = vars(target)
            # A section written without a body (e.g. "section:") loads as None.
            for key, value in (values or {}).items():
                if key not in target_attrs:
                    keys_not_found.append((entry, key))
                # Stored even when the key is unknown; the warning below reports it.
                target_attrs[key] = value

        if keys_not_found:
            print_warning(f"Could not find these keys: {keys_not_found}. Make sure they exist.")