Toronto AI Lab
Forecasting Model Search
Improving Hyperparameter Optimization with Checkpointed Model Weights

Nikhil Mehta1
Jonathan Lorraine1,2,3
Steve Masson1
Ramanathan Arunachalam1

Zaid Pervaiz Bhat1
James Lucas1
Arun George Zachariah1

1 NVIDIA
2 University of Toronto
3 Vector Institute


We introduce Forecasting Model Search (FMS), which builds on DyHPO, a multi-fidelity Bayesian optimization method for hyperparameter optimization. Notably, we condition on model weights logged at checkpoints during training to improve performance and data efficiency.

Abstract: When training deep learning models, performance depends largely on the selected hyperparameters. However, hyperparameter optimization (HPO) is often one of the most expensive parts of model design. Classical HPO methods treat this as a black-box optimization problem. However, gray-box HPO methods, which incorporate more information about the setup, have emerged as a promising direction for more efficient optimization. For example, intermediate loss evaluations can be used to terminate bad selections. In this work, we propose an HPO method for neural networks that uses logged checkpoints of the trained weights to guide future hyperparameter selections. Our Forecasting Model Search (FMS) method embeds the weights into a Gaussian process deep kernel surrogate model, making it data-efficient with the logged network weights. To facilitate reproducibility and further research, we open-source our code.




Overview

Forecasting Model Search (FMS) is a hyperparameter optimization method designed to make selecting hyperparameters for deep learning models more efficient, with a focus on choosing pretrained models from a model hub. By leveraging weights logged at checkpoints, FMS provides a more informed basis for guiding hyperparameter selection, improving upon traditional methods.


FMS builds on the DyHPO multi-fidelity Bayesian optimization HPO method by incorporating logged checkpoints of model weights, a rich source of information that implicitly encodes the architecture, dataset, loss, and optimization process. Specifically, checkpointed weights are featurized by a permutation-invariant graph metanetwork (PIGMN) and fed as input to a deep kernel Gaussian process. Permutation invariance makes FMS more data-efficient, while the graph metanetwork lets us leverage checkpoints from varied architectures to guide hyperparameter selection.
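The surrogate described above can be sketched minimally as follows. This is an illustration, not the paper's implementation: `embed_weights` is a toy stand-in for the PIGMN featurizer (simple permutation-invariant statistics), and a plain RBF kernel replaces the learned deep kernel. The Gaussian process then regresses observed performance on the joint (hyperparameter, weight-embedding) features.

```python
import numpy as np

def embed_weights(weights, dim=4):
    # Toy stand-in for the permutation-invariant graph metanetwork (PIGMN):
    # per-tensor statistics that are invariant to neuron permutations.
    stats = []
    for w in weights:
        stats.extend([w.mean(), w.std(), np.abs(w).mean(), (w ** 2).mean()])
    return np.array(stats[:dim])

def deep_kernel(x1, x2, lengthscale=1.0):
    # RBF kernel on the joint (hyperparameter, weight-embedding) features;
    # in FMS this would be a learned deep kernel.
    d = np.sum((x1 - x2) ** 2)
    return np.exp(-d / (2 * lengthscale ** 2))

def gp_posterior(X_train, y_train, X_test, noise=1e-4):
    # Standard GP regression posterior mean/variance under the kernel above.
    K = np.array([[deep_kernel(a, b) for b in X_train] for a in X_train])
    K += noise * np.eye(len(X_train))
    K_s = np.array([[deep_kernel(a, b) for b in X_train] for a in X_test])
    K_ss = np.array([deep_kernel(a, a) for a in X_test])
    K_inv = np.linalg.inv(K)
    mean = K_s @ K_inv @ y_train
    var = K_ss - np.sum((K_s @ K_inv) * K_s, axis=1)
    return mean, var

# Hypothetical usage: each candidate pairs a hyperparameter vector with
# checkpointed weight tensors; the GP predicts performance of new candidates.
rng = np.random.default_rng(0)
hypers = [np.array([0.1]), np.array([0.5]), np.array([0.9])]
checkpoints = [[rng.normal(0, s, 16)] for s in (0.5, 1.0, 2.0)]
X = np.array([np.concatenate([h, embed_weights(w)])
              for h, w in zip(hypers, checkpoints)])
y = np.array([0.3, 0.5, 0.9])  # observed validation losses (made up)
mean, var = gp_posterior(X[:2], y[:2], X[2:])
```

The posterior mean and variance would then feed an acquisition function (as in DyHPO) to pick the next configuration and fidelity to evaluate.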



Paper

Nikhil Mehta, Jonathan Lorraine, Steve Masson,
Ramanathan Arunachalam, Zaid Pervaiz Bhat,
James Lucas, Arun George Zachariah

Improving Hyperparameter Optimization with
Checkpointed Model Weights


[Paper]
[Code]
[Bibtex]



Experimental Results


Our experiments evaluate FMS's compute-budget-versus-quality trade-off and its generalization to unseen datasets and architectures. We focus on scenarios where users select which pre-trained model to fine-tune. Each plot below shows regret against the compute budget across different model hubs, with each hyperparameter optimization method in a different color. Regret is the gap between the achieved performance and the best possible performance over time; lower regret is better.
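Concretely, a regret curve of the kind plotted can be computed as follows (a minimal sketch; `best_possible` stands for the best score attainable in the benchmark, and the score values are made up):

```python
import numpy as np

def regret_curve(observed_scores, best_possible):
    # Regret at step t: gap between the best score found so far and the
    # best achievable score. Lower is better; 0 means the optimum was found.
    best_so_far = np.maximum.accumulate(observed_scores)
    return best_possible - best_so_far

scores = np.array([0.60, 0.72, 0.68, 0.80, 0.79])  # per-step evaluations
print(regret_curve(scores, best_possible=0.85))
# -> [0.25 0.13 0.13 0.05 0.05]
```

Note the curve is non-increasing by construction, which is why the plots flatten once a method stops improving.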


We evaluate FMS's ability to generalize to new datasets and architectures. FMS-GMN with generalization is trained on multiple datasets, while FMS-GMN without generalization is trained only on the target dataset. The generalization setup's regret is consistently lower than the non-generalization setup's, showing that our model effectively transfers knowledge between tasks. By leveraging the additional datasets, FMS converges faster to a potentially higher-quality solution.




Additional Ablations


Below, we show regret against the compute budget for several variants of FMS to ablate key design choices, including variants without learning-curve features and variants that flatten the network weights into vectors. The results show that FMS performs consistently well across configurations. FMS-NFN does not support diverse architectures, so it runs only on the Simple CNN hub.




Citation



Mehta, N., Lorraine, J., Masson, S., Arunachalam, R., Bhat, Z. P., Lucas, J., & Zachariah, A. G. (2024). Improving Hyperparameter Optimization with Checkpointed Model Weights. arXiv preprint arXiv:2406.18630.


@article{mehta2024fms,
  title={Improving Hyperparameter Optimization with Checkpointed Model Weights},
  author={Nikhil Mehta and Jonathan Lorraine and Steve Masson and Ramanathan Arunachalam and Zaid Pervaiz Bhat and James Lucas and Arun George Zachariah},
  journal={arXiv preprint arXiv:2406.18630},
  url={https://arxiv.org/abs/2406.18630},
  year={2024},
}


We thank David Acuna for the website template.