seldonian.spec.RLSpec
- class RLSpec(dataset, model, parse_trees, frac_data_in_safety=0.6, primary_objective=None, initial_solution_fn=None, base_node_bound_method_dict={}, use_builtin_primary_gradient_fn=True, custom_primary_gradient_fn=None, optimization_technique='gradient_descent', optimizer='adam', optimization_hyperparams={'alpha_lamb': 0.005, 'alpha_theta': 0.005, 'beta_rmsprop': 0.95, 'beta_velocity': 0.9, 'gradient_library': 'autograd', 'hyper_search': None, 'lambda_init': 0.5, 'num_iters': 200, 'use_batches': False, 'verbose': False}, regularization_hyperparams={}, batch_size_safety=None, candidate_dataset=None, safety_dataset=None, additional_datasets={}, verbose=False)
Bases: Spec
Specification object for running RL Seldonian algorithms
- Parameters:
dataset (DataSet object) – The dataset object containing safety data
model – The RL_Model object
parse_trees (List(ParseTree objects)) – List of parse tree objects containing the behavioral constraints
frac_data_in_safety (float) – Fraction of data used in safety test. The remaining fraction will be used in candidate selection
primary_objective (function or class method) – The objective function that would be solely optimized in the absence of behavioral constraints, i.e. the loss function
initial_solution_fn (function) – Function to provide initial model weights in candidate selection
base_node_bound_method_dict (dict, defaults to {}) – A dictionary specifying the bounding method to use for each base node
use_builtin_primary_gradient_fn (bool, defaults to True) – Whether to use the built-in function for the gradient of the primary objective, if one exists. If False, uses autograd
custom_primary_gradient_fn (function, defaults to None) – A function for computing the gradient of the primary objective. If None, falls back on builtin function or autograd
optimization_technique (str, defaults to 'gradient_descent') – The method for optimization during candidate selection. E.g. ‘gradient_descent’, ‘barrier_function’
optimizer (str, defaults to 'adam') – The string name of the optimizer used during candidate selection
optimization_hyperparams (dict) – Hyperparameters for optimization during candidate selection. See Candidate Selection.
regularization_hyperparams (dict) – Hyperparameters for regularization during candidate selection. See Candidate Selection.
batch_size_safety (int, defaults to None) – The number of samples that are forward passed through the model at a time during the safety test. The value does not change the result, but batching is sometimes necessary to avoid running out of memory when the dataset is large.
candidate_dataset (DataSet, defaults to None) – A dataset to use explicitly for candidate selection. If provided, overrides the data splitting and dataset is not used.
safety_dataset (DataSet, defaults to None) – A dataset to use explicitly for the safety test. If provided in conjunction with candidate_dataset, overrides the data splitting and dataset is not used.
additional_datasets (dict, defaults to {}) – Specifies optional additional datasets to use for bounding the base nodes of the parse trees.
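To make the effect of frac_data_in_safety concrete, the following is a minimal sketch of how a dataset size could be divided between candidate selection and the safety test. The helper name split_candidate_safety and the round-to-nearest rule are assumptions made for illustration; the library's own splitting code may round or order the data differently.

```python
def split_candidate_safety(num_points, frac_data_in_safety=0.6):
    """Return (n_candidate, n_safety) for a dataset with num_points items.

    Hypothetical helper, not part of the seldonian package.
    """
    if not 0.0 < frac_data_in_safety < 1.0:
        raise ValueError("frac_data_in_safety must be strictly between 0 and 1")
    # Assumption: the safety count is rounded to the nearest integer.
    n_safety = int(round(frac_data_in_safety * num_points))
    # Whatever is not used for the safety test goes to candidate selection.
    n_candidate = num_points - n_safety
    return n_candidate, n_safety


n_candidate, n_safety = split_candidate_safety(1000, frac_data_in_safety=0.6)
print(n_candidate, n_safety)
```

With the default fraction of 0.6, more data goes to the safety test than to candidate selection. Passing both candidate_dataset and safety_dataset bypasses this split entirely, as described above.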
- __init__(dataset, model, parse_trees, frac_data_in_safety=0.6, primary_objective=None, initial_solution_fn=None, base_node_bound_method_dict={}, use_builtin_primary_gradient_fn=True, custom_primary_gradient_fn=None, optimization_technique='gradient_descent', optimizer='adam', optimization_hyperparams={'alpha_lamb': 0.005, 'alpha_theta': 0.005, 'beta_rmsprop': 0.95, 'beta_velocity': 0.9, 'gradient_library': 'autograd', 'hyper_search': None, 'lambda_init': 0.5, 'num_iters': 200, 'use_batches': False, 'verbose': False}, regularization_hyperparams={}, batch_size_safety=None, candidate_dataset=None, safety_dataset=None, additional_datasets={}, verbose=False)
- __repr__()
Return repr(self).
Methods
- validate_additional_datasets(additional_datasets)
Ensure that the additional datasets dict is valid. It is valid if 1) the strings for the parse trees and base nodes match what is in the parse trees, and 2) either a “dataset” key is present or BOTH “candidate_dataset” and “safety_dataset” are present in each subdict.
Also, for any missing parse trees and base nodes, fill those entries with the primary dataset, or with the candidate/safety split if provided.
- Parameters:
additional_datasets (dict, defaults to {}) – Specifies optional additional datasets to use for bounding the base nodes of the parse trees.
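The two validity rules and the fill-in step can be sketched in plain Python. This is not the library's implementation. It assumes the dict is nested as {parse tree string: {base node string: subdict}}, uses a hypothetical tree_base_nodes mapping in place of real ParseTree objects, uses plain strings in place of DataSet objects, and raises ValueError as a stand-in for whatever exception the library raises.

```python
def validate_additional_datasets_sketch(
    additional_datasets,
    tree_base_nodes,
    dataset=None,
    candidate_dataset=None,
    safety_dataset=None,
):
    """Check an additional_datasets dict and fill in missing entries.

    tree_base_nodes is a hypothetical stand-in for the parse trees:
    {constraint string: [base node strings]}.
    """
    validated = {}
    for tree_str, node_dict in additional_datasets.items():
        # Rule 1a: every parse tree string must match a known parse tree.
        if tree_str not in tree_base_nodes:
            raise ValueError(f"Unknown parse tree: {tree_str!r}")
        validated[tree_str] = {}
        for node_str, subdict in node_dict.items():
            # Rule 1b: every base node string must belong to that tree.
            if node_str not in tree_base_nodes[tree_str]:
                raise ValueError(f"Unknown base node: {node_str!r}")
            # Rule 2: "dataset", or BOTH "candidate_dataset" and "safety_dataset".
            has_single = "dataset" in subdict
            has_pair = "candidate_dataset" in subdict and "safety_dataset" in subdict
            if not (has_single or has_pair):
                raise ValueError(f"Incomplete entry for base node {node_str!r}")
            validated[tree_str][node_str] = dict(subdict)

    # Fill step: default to the candidate/safety split if given, else the primary dataset.
    if candidate_dataset is not None and safety_dataset is not None:
        default = {
            "candidate_dataset": candidate_dataset,
            "safety_dataset": safety_dataset,
        }
    else:
        default = {"dataset": dataset}
    for tree_str, nodes in tree_base_nodes.items():
        tree_entry = validated.setdefault(tree_str, {})
        for node_str in nodes:
            tree_entry.setdefault(node_str, dict(default))
    return validated


# Placeholder names; real keys would be constraint and base node strings.
trees = {"constraint_A": ["node_1"], "constraint_B": ["node_1", "node_2"]}
user_dict = {"constraint_A": {"node_1": {"dataset": "extra_data"}}}
filled = validate_additional_datasets_sketch(user_dict, trees, dataset="primary_data")
print(filled["constraint_B"]["node_2"])
```

In this sketch, entries the user supplies are kept as given, and every base node the user did not mention receives the default data source.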
- validate_custom_datasets(candidate_dataset, safety_dataset)
Ensure that if either candidate_dataset or safety_dataset is specified, both are specified.
- Parameters:
candidate_dataset – The dataset provided by the user to be used for candidate selection.
safety_dataset – The dataset provided by the user to be used for the safety test.
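The both-or-neither rule can be expressed in a few lines. The sketch below is a hypothetical stand-alone function, not the method's actual body. The ValueError and the boolean return value are assumptions chosen for illustration, and plain strings stand in for DataSet objects.

```python
def validate_custom_datasets_sketch(candidate_dataset=None, safety_dataset=None):
    """Return True if a custom candidate/safety pair was supplied, False if neither was.

    Raises ValueError when exactly one of the two datasets is given.
    """
    given = [candidate_dataset is not None, safety_dataset is not None]
    if any(given) and not all(given):
        raise ValueError(
            "candidate_dataset and safety_dataset must be specified together"
        )
    return all(given)


print(validate_custom_datasets_sketch())                # neither given
print(validate_custom_datasets_sketch("cand", "safe"))  # both given
```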