experiments.headless_example.HeadlessExample

class HeadlessExample(spec)

Bases: BaseExample

__init__(spec)

Class for running headless experiments.

__repr__()

Return repr(self).

Methods

run(full_pretraining_model, headless_pretraining_model, head_layer_names, latent_feature_shape, loss_func_pretraining, learning_rate_pretraining, pretraining_device, batch_epoch_dict_pretraining, safety_batch_size_pretraining, n_trials, data_fracs, results_dir, perf_eval_fn, perf_eval_kwargs, constraint_eval_kwargs, n_workers=1, batch_epoch_dict={}, datagen_method='resample', verbose=False, baselines=[], performance_label='performance', performance_yscale='linear', plot_savename=None, plot_fontsize=12, legend_fontsize=8, model_label_dict={})

Run the experiment for this example. Runs any baseline models included in the baselines parameter first, then produces the three plots.

Parameters:
  • full_pretraining_model – The model with head intact

  • headless_pretraining_model – The model with head removed

  • head_layer_names – List of names of the layers to be tuned.

  • latent_feature_shape – Shape of the latent features (the output shape of the last layer of the headless model)

  • loss_func_pretraining – Loss function to use for pretraining

  • learning_rate_pretraining – Learning rate for pretraining

  • pretraining_device – Torch device for pretraining

  • batch_epoch_dict_pretraining – Dictionary mapping each data fraction to a (batch_size, n_epochs) tuple used for pretraining

  • safety_batch_size_pretraining – The number of samples to forward pass at a time in the safety test. Changing this does not change the result, but setting it too large can cause out-of-memory errors.

  • n_trials – The number of trials for the experiments

  • data_fracs – The data fractions for the experiments

  • results_dir – Directory for saving results files

  • perf_eval_fn – Performance evaluation function

  • perf_eval_kwargs – Keyword arguments to pass to the performance evaluation function

  • constraint_eval_kwargs (dict) – Extra keyword arguments to pass to the constraint_eval_fns

  • n_workers – Number of parallel worker processes to use

  • batch_epoch_dict (dict) – Dictionary specifying the batch size and number of epochs to use for each data fraction

  • datagen_method – Method for generating the trial data

  • baselines – List of baseline models to include

  • performance_label (str) – Label to use on the performance plot (left-most plot)

  • performance_yscale (str) – How to scale the y-axis on the performance plot. Options are “linear” and “log”

  • plot_savename – If provided, the filepath where the three plots will be saved

  • legend_fontsize – Font size for the legend

  • model_label_dict – Dictionary mapping model names (see model.model_name) to the display names used in the legends of the three plots.
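The relationship between the full model, the headless model, and the latent features, and the claim that safety_batch_size_pretraining affects only memory use rather than the result, can be illustrated with a small pure-Python sketch. Everything in it is hypothetical: the toy functions headless_model, head, full_model, and predict_in_batches, and the values in batch_epoch_dict_pretraining, stand in for the PyTorch models and settings a real experiment would use. It is not the library's implementation.

```python
from typing import Callable, Dict, List, Sequence, Tuple

Vector = List[float]


def headless_model(x: Vector) -> Vector:
    """Hypothetical headless model: maps a 3-feature input to a latent
    vector of length 2, i.e. latent_feature_shape would be (2,)."""
    return [x[0] + x[1], x[1] * x[2]]


def head(z: Vector, weights: Sequence[float] = (0.5, -1.0), bias: float = 0.1) -> float:
    """Hypothetical linear head; in a headless experiment only the
    head layers are tuned, while the headless part stays fixed."""
    return sum(w * zi for w, zi in zip(weights, z)) + bias


def full_model(x: Vector) -> float:
    """The full model is the head applied to the headless model's output."""
    return head(headless_model(x))


def predict_in_batches(model: Callable[[Vector], float], X: List[Vector], batch_size: int) -> List[float]:
    """Forward-pass X in chunks of batch_size (the role played by
    safety_batch_size_pretraining). Chunking only limits how many
    samples are processed at once; each sample's prediction is unchanged."""
    if batch_size < 1:
        raise ValueError("batch_size must be >= 1")
    preds: List[float] = []
    for start in range(0, len(X), batch_size):
        batch = X[start:start + batch_size]
        preds.extend(model(x) for x in batch)
    return preds


# Hypothetical settings: data fraction -> (batch_size, n_epochs)
batch_epoch_dict_pretraining: Dict[float, Tuple[int, int]] = {
    0.1: (16, 5),
    0.5: (64, 10),
    1.0: (128, 20),
}

X = [[float(i), float(i % 3), 1.0] for i in range(10)]
all_at_once = predict_in_batches(full_model, X, batch_size=len(X))
chunked = predict_in_batches(full_model, X, batch_size=3)
batch_size, n_epochs = batch_epoch_dict_pretraining[0.5]
```

In this toy every sample is evaluated independently, so the chunked and single-pass predictions match exactly; with real batched tensor operations, tiny floating-point differences are possible, but the batch size is still a memory setting rather than a modeling choice.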