seldonian.safety_test.safety_test.SafetyTest
- class SafetyTest(safety_dataset, model, parse_trees, regime='supervised_learning', additional_datasets={}, **kwargs)
Bases: object
- __init__(safety_dataset, model, parse_trees, regime='supervised_learning', additional_datasets={}, **kwargs)
Object for running the safety test on a candidate solution using the safety dataset.
- Parameters:
safety_dataset (DataSet object) – The dataset object containing the safety data
model (SeldonianModel object) – The Seldonian model object
parse_trees (List(ParseTree objects)) – List of parse tree objects containing the behavioral constraints
regime (str, defaults to 'supervised_learning') – The category of the machine learning algorithm, e.g., supervised_learning or reinforcement_learning
additional_datasets (dict, defaults to {}) – Specifies optional additional datasets to use for bounding the base nodes of the parse trees
- __repr__()
Return repr(self).
Methods
- evaluate_primary_objective(theta, primary_objective)
Get the value of the primary objective given model weights, theta, and the safety data, D_s. This is a wrapper for primary_objective in which the data are fixed to D_s, so only theta varies.
- Parameters:
theta (numpy.ndarray) – model weights
primary_objective – The primary objective function you want to evaluate
- Returns:
The primary objective function evaluated at theta, D_s
- Return type:
float
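Fixing the data in this way amounts to partial application: D_s is bound once, leaving a function of theta alone. The sketch below illustrates only that idea in plain Python. The objective signature `objective(theta, X, Y)`, the linear model, and the names `mean_squared_error`, `X_s`, and `Y_s` are hypothetical and are not taken from the Seldonian API.

```python
from functools import partial

def mean_squared_error(theta, X, Y):
    """Hypothetical primary objective: MSE of a linear model y_hat = theta . x."""
    residuals = [
        sum(t * x for t, x in zip(theta, row)) - y
        for row, y in zip(X, Y)
    ]
    return sum(r * r for r in residuals) / len(residuals)

# Hypothetical safety data D_s: three datapoints, two features each
X_s = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
Y_s = [1.0, 2.0, 3.0]

# Bind D_s once; the resulting callable depends on theta only,
# which is the role the wrapper plays during the safety test.
objective_on_safety_data = partial(mean_squared_error, X=X_s, Y=Y_s)

print(objective_on_safety_data([1.0, 2.0]))  # 0.0 (perfect fit)
print(objective_on_safety_data([0.0, 0.0]))
```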
- get_importance_weights(theta)
Get an array of importance weights given model weights, theta, and the safety data, D_s. Only relevant for reinforcement learning (RL).
- Parameters:
theta (numpy.ndarray) – model weights
- Returns:
The importance weights
- Return type:
numpy.ndarray
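In off-policy RL, an importance weight typically compares the probability that the new policy, parameterized by theta, assigns to a logged action with the probability the behavior policy assigned when the data were collected. The following is a minimal sketch under stated assumptions: a tabular softmax policy where `theta[state]` holds one preference per action, episodes stored as hypothetical `(state, action, behavior_prob)` tuples, and one ordinary importance-sampling weight per episode (the product of per-step ratios). This page does not say whether SafetyTest returns per-step or per-episode weights, or which policy parameterization it uses.

```python
import math

def softmax_prob(theta, state, action):
    """Probability of `action` in `state` under an assumed tabular softmax policy."""
    prefs = theta[state]
    m = max(prefs)  # subtract the max preference for numerical stability
    exps = [math.exp(p - m) for p in prefs]
    return exps[action] / sum(exps)

def episode_importance_weights(theta, episodes):
    """One ordinary importance-sampling weight per episode.

    Each episode is a list of (state, action, behavior_prob) tuples, where
    behavior_prob is pi_b(action | state) recorded when the data were collected.
    """
    weights = []
    for episode in episodes:
        w = 1.0
        for state, action, behavior_prob in episode:
            w *= softmax_prob(theta, state, action) / behavior_prob
        weights.append(w)
    return weights

# Two states, two actions; all-zero theta gives a uniform new policy (0.5 each)
theta = [[0.0, 0.0], [0.0, 0.0]]
episodes = [
    [(0, 1, 0.5), (1, 0, 0.25)],  # (0.5 / 0.5) * (0.5 / 0.25) = 2.0
    [(0, 0, 0.5)],                # 0.5 / 0.5 = 1.0
]
print(episode_importance_weights(theta, episodes))  # [2.0, 1.0]
```

A weight above 1 means the new policy favors that trajectory more than the behavior policy did; such weights are what let returns observed under the behavior policy be reweighted to estimate the performance of theta.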
- run(solution, batch_size_safety=None, **kwargs)
Loop over the parse trees, calculate the bounds on the leaf (base) nodes, and propagate them up to each root node. The safety test passes if the upper bound of every parse tree's root node is less than or equal to 0.
- Parameters:
solution (numpy.ndarray) – The candidate solution found by candidate selection
batch_size_safety (int, defaults to None) – The number of datapoints to pass through the model in a single forward pass
- Returns:
passed – Whether the candidate solution passed the safety test
- Return type:
bool
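The logic described for run (bound the base nodes, propagate to the root, pass only if every root upper bound is at most 0) can be sketched in plain Python. This is a hypothetical illustration, not the library's implementation. Several choices are assumptions: Hoeffding's inequality is swapped in for the base-node bound (chosen because it needs only the standard library; the toolkit may use a different bound, such as Student's t), each side of the interval uses the full `delta` (how the confidence level is split across bounds and base nodes is omitted), parse trees are encoded as nested tuples, and `batches` only mimics what a batch size of None versus an integer might mean for forward passes.

```python
import math

def batches(data, batch_size=None):
    """Yield slices of data; batch_size=None means a single pass over everything."""
    if batch_size is None:
        batch_size = len(data)
    for start in range(0, len(data), batch_size):
        yield data[start:start + batch_size]

def hoeffding_interval(samples, a, b, delta):
    """Lower and upper Hoeffding bounds on the mean of samples bounded in [a, b].

    Each side holds with probability at least 1 - delta (assumption: no split).
    """
    n = len(samples)
    mean = sum(samples) / n
    width = (b - a) * math.sqrt(math.log(1 / delta) / (2 * n))
    return mean - width, mean + width

def propagate(node):
    """Return an interval (lo, hi) for a parse-tree node via interval arithmetic."""
    kind = node[0]
    if kind == "base":    # ("base", lo, hi): bounds computed from data
        return node[1], node[2]
    if kind == "const":   # ("const", c)
        return node[1], node[1]
    if kind == "sub":     # ("sub", left, right)
        l_lo, l_hi = propagate(node[1])
        r_lo, r_hi = propagate(node[2])
        return l_lo - r_hi, l_hi - r_lo
    if kind == "abs":     # ("abs", child)
        lo, hi = propagate(node[1])
        if lo >= 0:
            return lo, hi
        if hi <= 0:
            return -hi, -lo
        return 0.0, max(-lo, hi)
    raise ValueError(f"unknown node type: {kind}")

def safety_test_passes(root_nodes):
    """Pass only if the upper bound of every root node is <= 0."""
    return all(propagate(root)[1] <= 0 for root in root_nodes)

# Hypothetical safety data: 200 labels and predictions, 10% misclassified
labels = [1, 0, 1, 1, 0, 1, 0, 0, 1, 1] * 20
predictions = [1, 0, 1, 0, 0, 1, 0, 0, 1, 1] * 20

# Compute 0/1 errors batch by batch, as a batched forward pass would
errors = []
for batch in batches(list(zip(predictions, labels)), batch_size=64):
    errors.extend(float(p != y) for p, y in batch)

lo, hi = hoeffding_interval(errors, 0.0, 1.0, delta=0.05)
# Behavioral constraint g = error_rate - 0.3 <= 0
g = ("sub", ("base", lo, hi), ("const", 0.3))
print(propagate(g))
print(safety_test_passes([g]))  # True
```

Interval arithmetic keeps the propagation conservative: if the base-node intervals hold, the root interval contains the true value of g, so a root upper bound of at most 0 means the constraint holds with the stated confidence.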