seldonian.safety_test.safety_test.SafetyTest

class SafetyTest(safety_dataset, model, parse_trees, regime='supervised_learning', additional_datasets={}, **kwargs)

Bases: object

__init__(safety_dataset, model, parse_trees, regime='supervised_learning', additional_datasets={}, **kwargs)

Object for running the safety test.

Parameters:
  • safety_dataset (DataSet object) – The dataset object containing the safety data, D_s

  • model (SeldonianModel object) – The Seldonian model object

  • parse_trees (List(ParseTree objects)) – List of parse tree objects containing the behavioral constraints

  • regime (str, defaults to 'supervised_learning') – The category of the machine learning algorithm, e.g., supervised_learning or reinforcement_learning

  • additional_datasets (dict, defaults to {}) – Specifies optional additional datasets to use for bounding the base nodes of the parse trees.

__repr__()

Return repr(self).

Methods

evaluate_primary_objective(theta, primary_objective)

Get the value of the primary objective given model weights, theta, and the safety data, D_s. This is a wrapper for primary_objective in which the data are fixed to D_s.

Parameters:
  • theta (numpy.ndarray) – model weights

  • primary_objective – The primary objective function you want to evaluate

Returns:

The value of the primary objective evaluated at theta on D_s

Return type:

float
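To make "a wrapper where the data are fixed" concrete, the minimal sketch below binds a toy safety dataset to an objective with functools.partial, leaving a function of theta alone. The objective `mean_squared_error(theta, X, Y)` and the helper `fix_safety_data` are hypothetical illustrations, not the engine's API; the engine's own objectives may take additional arguments (for example, the model).

```python
from functools import partial
from typing import Callable, Sequence

Vector = Sequence[float]


def mean_squared_error(theta: Vector, X: Sequence[Vector], Y: Vector) -> float:
    """Hypothetical primary objective: MSE of a linear model x . theta on (X, Y)."""
    total = 0.0
    for x, y in zip(X, Y):
        prediction = sum(w * xi for w, xi in zip(theta, x))
        total += (prediction - y) ** 2
    return total / len(Y)


def fix_safety_data(
    primary_objective: Callable[..., float], X_s: Sequence[Vector], Y_s: Vector
) -> Callable[[Vector], float]:
    """Hypothetical helper: hold D_s = (X_s, Y_s) fixed so that only theta varies."""
    return partial(primary_objective, X=X_s, Y=Y_s)


# Toy safety data D_s (illustrative values only).
X_s = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
Y_s = [1.0, 2.0, 3.0]
objective_on_safety = fix_safety_data(mean_squared_error, X_s, Y_s)

print(objective_on_safety([1.0, 2.0]))  # exact fit, prints 0.0
print(objective_on_safety([0.0, 0.0]))  # (1 + 4 + 9) / 3
```

Because the data are bound once, the same callable can be evaluated at many candidate values of theta without passing D_s each time.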

get_importance_weights(theta)

Get an array of importance weights given model weights, theta, and the safety data, D_s. Only relevant for reinforcement learning (RL).

Parameters:

theta (numpy.ndarray) – model weights

Returns:

The importance weights

Return type:

numpy.ndarray
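In off-policy RL, one common definition of an episode's importance weight is the product, over its timesteps, of the probability the new policy (parameterized by theta) assigns to each logged action divided by the probability the behavior policy assigned to it. The sketch below assumes a tabular softmax policy and episodes stored as (observation, action, behavior probability) tuples. Both the data layout and the helper names are hypothetical, and the engine may compute weights per timestep or organize its episodes differently.

```python
import math
from typing import List, Sequence, Tuple

# One logged timestep (hypothetical layout):
# (observation index, action index, behavior-policy probability of that action)
Step = Tuple[int, int, float]


def softmax_policy_prob(theta: Sequence[Sequence[float]], obs: int, action: int) -> float:
    """Probability of `action` in `obs` under a tabular softmax policy with preferences theta[obs]."""
    prefs = theta[obs]
    m = max(prefs)  # subtract the max preference for numerical stability
    exps = [math.exp(p - m) for p in prefs]
    return exps[action] / sum(exps)


def importance_weights(
    theta: Sequence[Sequence[float]], episodes: Sequence[Sequence[Step]]
) -> List[float]:
    """Per-episode weights: product of pi_theta(a|s) / pi_b(a|s) over each episode's timesteps."""
    weights = []
    for episode in episodes:
        w = 1.0
        for obs, action, behavior_prob in episode:
            w *= softmax_policy_prob(theta, obs, action) / behavior_prob
        weights.append(w)
    return weights


# One observation, two actions; theta = 0 gives a uniform (0.5, 0.5) policy.
theta = [[0.0, 0.0]]
episodes = [
    [(0, 0, 0.5), (0, 1, 0.5)],    # behavior policy also uniform: weight 1.0
    [(0, 1, 0.25), (0, 1, 0.25)],  # ratio 0.5 / 0.25 = 2 at each step: weight 4.0
]
print(importance_weights(theta, episodes))  # [1.0, 4.0]
```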

run(solution, batch_size_safety=None, **kwargs)

Loop over parse trees, calculate the bounds on leaf nodes and propagate to the root node. The safety test passes if the upper bounds of all parse tree root nodes are less than or equal to 0.

Parameters:
  • solution (numpy.ndarray) – The candidate solution found by candidate selection

  • batch_size_safety (int) – The number of datapoints to pass through the model in a single forward pass
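The role of batch_size_safety can be pictured with a short, library-independent sketch: split the data into chunks of at most batch_size points, run one forward pass per chunk, and concatenate the outputs. The helper names are hypothetical, and treating None as "all data in a single pass" is an assumption about the default, not documented behavior.

```python
from typing import Callable, Iterator, List, Optional, Sequence, TypeVar

T = TypeVar("T")
R = TypeVar("R")


def iter_batches(data: Sequence[T], batch_size: Optional[int]) -> Iterator[List[T]]:
    """Yield consecutive chunks of at most `batch_size` items (assumption: None means one chunk)."""
    if batch_size is None:
        yield list(data)
        return
    if batch_size <= 0:
        raise ValueError("batch_size must be a positive integer or None")
    for start in range(0, len(data), batch_size):
        yield list(data[start:start + batch_size])


def batched_forward(
    predict: Callable[[List[T]], List[R]], data: Sequence[T], batch_size: Optional[int]
) -> List[R]:
    """Run `predict` on each chunk and concatenate the results in the original order."""
    outputs: List[R] = []
    for batch in iter_batches(data, batch_size):
        outputs.extend(predict(batch))
    return outputs


def double(batch: List[int]) -> List[int]:
    """Toy stand-in for a model's forward pass."""
    return [2 * x for x in batch]


print(batched_forward(double, list(range(10)), 3))  # same result as a single full pass
```

As long as the forward pass treats each data point independently, batching changes only peak memory use, not the predictions.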

Returns:

passed – whether the candidate solution passed the safety test

Return type:

bool
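The loop that run performs can be sketched with toy parse trees: bound each leaf with a confidence interval, propagate the intervals up through the arithmetic nodes, and pass only if every root's upper bound is at most 0. This is not the engine's implementation. The tuple-based tree format and helper names are hypothetical, and the leaf bound swaps in Hoeffding's inequality (samples assumed i.i.d. and bounded in [0, 1]) so that only the standard library is needed; the engine's own bounding methods may differ. Each toy tree has a single leaf, so that leaf receives the whole delta; a tree with several leaves would need to split delta among them.

```python
import math
from typing import Sequence, Tuple

Interval = Tuple[float, float]
# Hypothetical toy nodes: ("leaf", samples) | ("const", value) | ("add", l, r) | ("sub", l, r)
Node = tuple


def hoeffding_interval(
    samples: Sequence[float], delta: float, lo: float = 0.0, hi: float = 1.0
) -> Interval:
    """Two-sided (1 - delta) Hoeffding confidence interval on the mean of i.i.d. samples in [lo, hi]."""
    n = len(samples)
    mean = sum(samples) / n
    half_width = (hi - lo) * math.sqrt(math.log(2.0 / delta) / (2.0 * n))
    return (mean - half_width, mean + half_width)


def propagate(node: Node, delta: float) -> Interval:
    """Interval arithmetic from the leaves up to `node`."""
    kind = node[0]
    if kind == "const":
        return (node[1], node[1])
    if kind == "leaf":
        return hoeffding_interval(node[1], delta)
    left = propagate(node[1], delta)
    right = propagate(node[2], delta)
    if kind == "add":
        return (left[0] + right[0], left[1] + right[1])
    if kind == "sub":
        # [a, b] - [c, d] = [a - d, b - c]
        return (left[0] - right[1], left[1] - right[0])
    raise ValueError(f"unknown node type: {kind!r}")


def run_safety_test(parse_trees: Sequence[Node], delta: float) -> bool:
    """Pass only if the upper bound of every root node is <= 0."""
    return all(propagate(tree, delta)[1] <= 0.0 for tree in parse_trees)


# Constraints of the form "mean(g) - threshold <= 0", using 1000 toy samples in [0, 1] with mean 0.1.
samples = [0.1] * 1000
g_loose = ("sub", ("leaf", samples), ("const", 0.2))   # root upper bound about -0.057
g_tight = ("sub", ("leaf", samples), ("const", 0.12))  # root upper bound about 0.023
print(run_safety_test([g_loose], delta=0.05))           # passes
print(run_safety_test([g_loose, g_tight], delta=0.05))  # fails: one root bound is above 0
```

A single tree whose root upper bound exceeds 0 is enough to fail the whole test, which is why the second call fails even though the first constraint is satisfied.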