experiments.headless_example.HeadlessExample
- class HeadlessExample(spec)
Bases: BaseExample
- __init__(spec)
Class for running headless experiments.
- __repr__()
Return repr(self).
Methods
- run(full_pretraining_model, headless_pretraining_model, head_layer_names, latent_feature_shape, loss_func_pretraining, learning_rate_pretraining, pretraining_device, batch_epoch_dict_pretraining, safety_batch_size_pretraining, n_trials, data_fracs, results_dir, perf_eval_fn, perf_eval_kwargs, constraint_eval_kwargs, n_workers=1, batch_epoch_dict={}, datagen_method='resample', verbose=False, baselines=[], performance_label='performance', performance_yscale='linear', plot_savename=None, plot_fontsize=12, legend_fontsize=8, model_label_dict={})
Run the experiment for this example. Any baseline models listed in the baselines parameter are run first, and then the three plots are produced.
- Parameters:
full_pretraining_model – The model with head intact
headless_pretraining_model – The model with head removed
head_layer_names – List of names of the layers to be tuned.
latent_feature_shape – Shape of the latent features (the output shape of the last layer of headless model)
loss_func_pretraining – Loss function to use for pretraining
learning_rate_pretraining – Learning rate for pretraining
pretraining_device – Torch device for pretraining
batch_epoch_dict_pretraining – Dictionary mapping each data fraction to a (batch_size, n_epochs) tuple to use for pretraining
safety_batch_size_pretraining – The number of samples to forward pass at a time in the safety test. Changing this does not change the result, but a value that is too large can exhaust available memory.
n_trials – The number of trials for the experiments
data_fracs – The data fractions for the experiments
results_dir – Directory for saving results files
perf_eval_fn – Performance evaluation function
perf_eval_kwargs – Keyword arguments to pass to the performance evaluation function
constraint_eval_kwargs (dict) – Extra keyword arguments to pass to the constraint_eval_fns
n_workers – Number of parallel processors to use
batch_epoch_dict (dict) – Dictionary specifying the batch size and n_epochs to use for each data fraction
datagen_method – Method for generating the trial data
baselines – List of baseline models to include
performance_label (str) – Label to use on the performance plot (left-most plot)
performance_yscale (str) – How to scale the y-axis on the performance plot. Options are “linear” and “log”
plot_savename – If provided, the filepath where the three plots will be saved
legend_fontsize – Font size for legend
model_label_dict – Dictionary mapping model names (see model.model_name) to the display names used in the legends of the three plots.
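The parameter list describes batch_epoch_dict_pretraining as a dictionary mapping data fraction to (batch_size, n_epochs). The following is a minimal sketch of how such a dictionary might be built alongside data_fracs. The helper make_batch_epoch_dict, the dataset size n_total, and the rule for choosing the number of epochs are all hypothetical and are not part of the library. The sketch also assumes that the dictionary keys are the same float values that appear in data_fracs, which this page does not state explicitly.

```python
import math


def make_batch_epoch_dict(data_fracs, n_total, batch_size=32,
                          min_epochs=5, target_steps=500):
    """Hypothetical helper (not a library function).

    Maps each data fraction to a (batch_size, n_epochs) tuple so that
    small fractions train for more epochs and every fraction sees
    roughly `target_steps` gradient steps, with at least `min_epochs`.
    """
    result = {}
    for frac in data_fracs:
        # Number of samples available at this data fraction
        n_samples = max(1, int(round(frac * n_total)))
        # Never use a batch larger than the available data
        bs = min(batch_size, n_samples)
        steps_per_epoch = math.ceil(n_samples / bs)
        n_epochs = max(min_epochs, math.ceil(target_steps / steps_per_epoch))
        result[frac] = (bs, n_epochs)
    return result


# Assumed values for illustration only
data_fracs = [0.01, 0.1, 0.5, 1.0]
batch_epoch_dict = make_batch_epoch_dict(data_fracs, n_total=10000)

for frac, (bs, n_epochs) in batch_epoch_dict.items():
    print(f"data_frac={frac}: batch_size={bs}, n_epochs={n_epochs}")
```

A dictionary built this way could then be passed as batch_epoch_dict_pretraining (and, if the same layout applies, as batch_epoch_dict) together with the matching data_fracs list. Check the expected key format against the library source before relying on it.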