Source code for gaggle.problem.problem_factory

from gaggle.problem import Problem
from gaggle.arguments import ProblemArgs, SysArgs
from gaggle.problem.environment.environment_factory import EnvironmentFactory
from gaggle.problem.environment.rl_problem import RLProblem
from gaggle.problem.leap.leap_problem import convert_leap_problem
from gaggle.problem.dataset.classification_problem import ClassificationProblem
from gaggle.problem.dataset.dataset_factory import DatasetFactory
from gaggle.utils.special_print import print_warning
from typing import Callable


[docs]class ProblemFactory: r"""Factory that generates pre-existing available Problems. ProblemFactory.problems stores said Problems as a dictionary. It stores problems as 4 types: classification, rl, leap and custom. Each is a key that holds a dictionary of available problems/datasets/environments. """ problems = { "classification": DatasetFactory.datasets, "rl": EnvironmentFactory.environments, "leap": {}, "custom": {} } registrable_problems = ["leap", "custom"]
[docs] @classmethod def register_problem(cls, problem_type: str, problem_name: str, problem: Callable, *args, **kwargs): """Register a new problem to the factory of pre-existing available Problems. Said problem has to be a problem that is registrable. ProblemFactory.registrable_problems stores the list of problem types that are allowed to be registrables. For the other types, please check their respective documentation to see how to register them (for classification: ClassificationProblem and DatasetFactory, for rl: RLProblem and EnvironmentFactory). If custom arguments need to be passed at initialization time of the Problem (for example when ProblemFactory.from_problem_args is called), they can also be passed as *args and **kwargs and they will be used to initialize the problem. Args: problem_type: The key to ProblemFactory.problems. problem_name: The new problem's name as a str. problem: The uninitialized problem. *args: The args that need to be passed to the problem at initialization time **kwargs: the kwargs that need to be passed to the problem at initialization time. Returns: """ assert isinstance(problem, Callable) if problem_type not in cls.registrable_problems: if problem_type == "classification": print_warning(f"To register a new classification problem that has the same evaluate as " f"ClassificationProblem but a different dataset," f" please use DatasetFactory.update with the updated dataset.\n" f"If the behavior is modified, then register the new problem as 'custom'") elif problem_type == "rl": print_warning(f"To register a new rl problem that has the same evaluate as RLProblem but a different " f"environment, please use EnvironmentFactory.update with the updated environment.\n" f"If the behavior is modified, then register the new problem as 'custom'") else: print(f"Can only register new {cls.registrable_problems} problems") raise ValueError(problem_type) else: if problem_name in list(cls.problems[problem_type].keys()): print_warning(f"{problem_name} is an already existing {problem_type} problem and will be overwritten") cls.problems[problem_type][problem_name] = (problem, args, kwargs)
[docs] @classmethod def from_problem_args(cls, problem_args: ProblemArgs = None, sys_args: SysArgs = None) -> Problem: r"""Initializes the requested Problem from the dictionary of available Problems. This is done by using the attributes problem_args.problem_type and problem_args.problem_name as the lookup keys to ProblemFactory.problems. Args: problem_args: sys_args: Returns: The initialized Problem requested. """ problem_args = problem_args if problem_args is not None else ProblemArgs() sys_args = SysArgs() if sys_args is None else sys_args problem_type = None for key, value in cls.problems.items(): if problem_args.problem_name in list(value.keys()): problem_type = key break if problem_type is None: raise ValueError(problem_args.problem_name) if problem_type == "classification": return ClassificationProblem(problem_args, sys_args) elif problem_type == "rl": return RLProblem(problem_args, sys_args) else: problem, args, kwargs = cls.problems[problem_type][problem_args.problem_name] return problem(problem_args, sys_args, *args, **kwargs)
[docs] @classmethod def convert_and_register_leap_problem(cls, problem_name: str, leap_problem: Callable, *args, **kwargs): """Shortcut method that both converts the leap problem to our Problem class and registers it Args: problem_name: leap_problem: *args: **kwargs: Returns: """ registrable_leap_problem = convert_leap_problem(leap_problem, *args, **kwargs) cls.register_problem(problem_type="leap", problem_name=problem_name, problem=registrable_leap_problem, *args, **kwargs)