Toronto AI Lab
Graph Metanetworks for Processing Diverse Neural Architectures

Derek Lim²
Haggai Maron¹,³
Marc T. Law¹
Jonathan Lorraine¹
James Lucas¹

¹NVIDIA
²MIT CSAIL
³Technion

Spotlight paper at
the International Conference on Learning Representations (ICLR) 2024


What if neural nets could process, analyze, interpret, and edit other neural nets' weights? We design metanets capable of processing diverse input neural net architectures, with provable guarantees.

Abstract: Neural networks efficiently encode learned information within their parameters. Consequently, many tasks can be unified by treating neural networks themselves as input data. When doing so, recent studies demonstrated the importance of accounting for the symmetries and geometry of parameter spaces. However, those works developed architectures tailored to specific networks such as MLPs and CNNs without normalization layers, and generalizing such architectures to other types of networks can be challenging. In this work, we overcome these challenges by building new metanetworks — neural networks that take weights from other neural networks as input. Put simply, we carefully build graphs representing the input neural networks and process the graphs using graph neural networks. Our approach, Graph Metanetworks (GMNs), generalizes to neural architectures where competing methods struggle, such as multi-head attention layers, normalization layers, convolutional layers, ResNet blocks, and group-equivariant linear layers. We prove that GMNs are expressive and equivariant to parameter permutation symmetries that leave the input neural network functions unchanged. We validate the effectiveness of our method on several metanetwork tasks over diverse neural network architectures.




Paper

Derek Lim, Haggai Maron, Marc T. Law, Jonathan Lorraine, James Lucas
Graph Metanetworks for Processing Diverse Neural Architectures
ICLR 2024


[Paper]
[Code coming soon!]
[Bibtex]



Overview


What if neural nets could process, analyze, interpret, and edit other neural nets' weights? Such models, termed metanetworks (metanets), support a wide range of tasks. For instance, metanets have been used for learned optimization (mapping [params, gradients] → new params), 3D data generation (generating INR weights), analysis of parameter spaces (predicting hyperparameters from weights), and more.

We develop Graph Metanetworks (GMNs): metanetworks that operate on neural architectures while respecting parameter symmetries. To achieve this, we:

  1. Convert input neural networks to graphs, where neurons are nodes and parameters are edges.
  2. Process the resulting graph with a graph neural network or graph transformer.

Crucially, our design allows a single metanet to operate on a diverse set of neural architectures.
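
To make the two-step recipe concrete, here is a minimal sketch of step 1 for an MLP, together with a generic message-passing update. It is an illustration only, not the paper's implementation: the helper names (mlp_to_graph, message_passing_step), the choice of numpy, and the fixed tanh update are ours, and a real GMN would use learned message/update functions and richer node and edge features.

    import numpy as np

    def mlp_to_graph(weights, biases):
        """Convert an MLP (lists of weight matrices and bias vectors) into a graph:
        one node per neuron, one edge per weight.
        weights[l] has shape (d_out, d_in); biases[l] has shape (d_out,)."""
        layer_sizes = [weights[0].shape[1]] + [W.shape[0] for W in weights]
        offsets = np.cumsum([0] + layer_sizes)      # first node id of each layer
        num_nodes = offsets[-1]

        node_feat = np.zeros(num_nodes)             # node feature: the neuron's bias
        src, dst, edge_feat = [], [], []
        for l, (W, b) in enumerate(zip(weights, biases)):
            node_feat[offsets[l + 1]:offsets[l + 2]] = b
            d_out, d_in = W.shape
            for j in range(d_out):                  # target neuron in layer l + 1
                for i in range(d_in):               # source neuron in layer l
                    src.append(offsets[l] + i)
                    dst.append(offsets[l + 1] + j)
                    edge_feat.append(W[j, i])       # the weight is the edge feature
        return node_feat, np.array([src, dst]), np.array(edge_feat)

    def message_passing_step(node_feat, edge_index, edge_feat):
        """One generic message-passing update: each node aggregates messages
        formed from its incoming edges' features and source-node features."""
        src, dst = edge_index
        messages = edge_feat * node_feat[src]       # per-edge message
        agg = np.zeros_like(node_feat)
        np.add.at(agg, dst, messages)               # sum messages at target nodes
        return np.tanh(node_feat + agg)             # simple fixed node update

    # A 3 -> 4 -> 2 MLP.
    rng = np.random.default_rng(0)
    weights = [rng.normal(size=(4, 3)), rng.normal(size=(2, 4))]
    biases = [rng.normal(size=4), rng.normal(size=2)]
    node_feat, edge_index, edge_feat = mlp_to_graph(weights, biases)
    print(node_feat.shape, edge_index.shape, edge_feat.shape)   # (9,) (2, 20) (20,)
    print(message_passing_step(node_feat, edge_index, edge_feat).shape)  # (9,)

The key point of the encoding is that every parameter lives on an edge and every neuron on a node, so standard graph learning machinery can consume the weights directly.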



We develop procedures to turn many types of layers into graphs, including multi-head attention, convolutional, linear, residual, group-equivariant linear, spatial grid, and normalization layers. In our experiments, we process small ViTs, ResNets, DeepSets, and other input networks.
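
As one example of such a conversion, consider a convolutional layer. A hedged sketch of one plausible encoding (the paper's exact parameter-graph construction may differ, e.g., it may use one edge per kernel entry): each channel becomes a node, and each (input channel, output channel) pair gets a single edge whose feature is the flattened filter, so spatial weight sharing is carried by the vector edge feature rather than by many duplicated scalar edges. The helper conv_to_edges is hypothetical.

    import numpy as np

    def conv_to_edges(kernel, in_offset, out_offset):
        """Hypothetical encoding of a conv layer: channels are nodes, and each
        (input channel, output channel) pair gets one edge whose feature is the
        flattened k x k filter.
        kernel: array of shape (c_out, c_in, k, k).
        in_offset / out_offset: global node ids of the first in/out channel node."""
        c_out, c_in, k, _ = kernel.shape
        src, dst, edge_feat = [], [], []
        for j in range(c_out):
            for i in range(c_in):
                src.append(in_offset + i)
                dst.append(out_offset + j)
                edge_feat.append(kernel[j, i].reshape(-1))   # k*k-dim edge feature
        return np.array([src, dst]), np.stack(edge_feat)

    # A 3x3 conv with 2 input channels and 4 output channels.
    kernel = np.random.default_rng(0).normal(size=(4, 2, 3, 3))
    edge_index, edge_feat = conv_to_edges(kernel, in_offset=0, out_offset=2)
    print(edge_index.shape, edge_feat.shape)    # (2, 8) (8, 9)

Either way the encoding is done, the important property is preserved: permuting hidden channels of the input network corresponds to relabeling nodes of its graph.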




Theoretical Guarantees



Theoretically, we prove that our GMNs are equivariant/invariant to computation graph automorphisms of the inputs, which correspond to permutation parameter symmetries (e.g., permuting hidden neurons in MLPs or hidden channels in CNNs).
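
As a concrete instance of such a symmetry (a standalone numpy check, not code from the paper): permuting the hidden neurons of a one-hidden-layer ReLU MLP, i.e., applying the same permutation to the rows of W1 (and b1) and the columns of W2, leaves the function computed by the network unchanged.

    import numpy as np

    rng = np.random.default_rng(0)
    W1, b1 = rng.normal(size=(16, 8)), rng.normal(size=16)   # layer 1: 8 -> 16
    W2, b2 = rng.normal(size=(3, 16)), rng.normal(size=3)    # layer 2: 16 -> 3

    def mlp(x, W1, b1, W2, b2):
        return W2 @ np.maximum(W1 @ x + b1, 0.0) + b2        # ReLU MLP

    # Permute the hidden neurons: rows of W1 (and b1), columns of W2.
    perm = rng.permutation(16)
    W1p, b1p, W2p = W1[perm], b1[perm], W2[:, perm]

    x = rng.normal(size=8)
    print(np.allclose(mlp(x, W1, b1, W2, b2), mlp(x, W1p, b1p, W2p, b2)))  # True

On the graph side, this permutation merely relabels the hidden-neuron nodes, producing an isomorphic graph, which is exactly the kind of transformation a GNN treats equivariantly.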


Propositions 1&2 (informal): Graph Metanets are equivariant to parameter permutation symmetries.

We prove that our GMNs are able to express existing metanets.

Proposition 3: On MLP inputs (where parameter graphs and computation graphs coincide), Graph Metanets can express StatNN and NP-NFN.

They can even simulate the forward pass of any network defined as a computation graph. This means they are at least as expressive as the input networks they operate on.

Proposition 4: On computation graph inputs, graph metanets can express the forward pass of any input feedforward neural network.
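
The intuition behind Proposition 4 is that a feedforward layer is itself one round of message passing on the computation graph: every neuron sums weight-times-activation messages over its incoming edges, adds its bias, and applies the nonlinearity. A minimal illustrative sketch (our own, not the paper's construction), checked against the usual matrix-multiply forward pass:

    import numpy as np

    def forward_as_message_passing(x, weights, biases):
        """Run an MLP forward pass purely as per-edge messages on its computation
        graph: message = edge weight * source activation; node update =
        ReLU(sum of incoming messages + bias)."""
        act = x
        for W, b in zip(weights, biases):
            d_out, d_in = W.shape
            new_act = np.zeros(d_out)
            for j in range(d_out):                  # each target neuron...
                for i in range(d_in):               # ...sums its incoming edges
                    new_act[j] += W[j, i] * act[i]
            act = np.maximum(new_act + b, 0.0)      # ReLU node update
        return act

    rng = np.random.default_rng(0)
    weights = [rng.normal(size=(4, 3)), rng.normal(size=(2, 4))]
    biases = [rng.normal(size=4), rng.normal(size=2)]
    x = rng.normal(size=3)

    # Agrees with the usual matrix-multiply forward pass (ReLU after every layer).
    h = np.maximum(weights[0] @ x + biases[0], 0.0)
    ref = np.maximum(weights[1] @ h + biases[1], 0.0)
    print(np.allclose(forward_as_message_passing(x, weights, biases), ref))  # True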




Experimental Results


We empirically show that our GMNs improve over previous metanets on several tasks:

  • Predicting the CIFAR-10 test accuracy of diverse input image classifiers (ViTs, ResNets, DeepSets, etc.) from their weights alone
  • Editing 2D neural implicits
  • Self-supervised learning of representations of neural implicits

Results for predicting the test accuracy of input neural networks trained on CIFAR-10. The top results use a training set of 15,000 uniformly selected input networks, the middle results use 10% of this training set, and the bottom results train only on input networks of low hidden dimension (while testing on networks with strictly higher hidden dimension). Our method performs best in all settings.
                   Varying CNNs                      Diverse Architectures
                   R²               τ                R²               τ
50%   DeepSets     0.778 ± 0.002    0.697 ± 0.002    0.562 ± 0.020    0.559 ± 0.011
      DMC          0.948 ± 0.009    0.876 ± 0.003    0.957 ± 0.009    0.883 ± 0.007
      GMN (Ours)   0.978 ± 0.002    0.915 ± 0.006    0.975 ± 0.002    0.908 ± 0.004
5%    DeepSets     0.692 ± 0.006    0.648 ± 0.002    0.126 ± 0.015    0.290 ± 0.010
      DMC          0.816 ± 0.038    0.762 ± 0.014    0.810 ± 0.046    0.758 ± 0.013
      GMN (Ours)   0.876 ± 0.010    0.797 ± 0.005    0.918 ± 0.002    0.828 ± 0.005
OOD   DeepSets     0.741 ± 0.015    0.683 ± 0.005    0.128 ± 0.071    0.380 ± 0.014
      DMC          0.387 ± 0.229    0.760 ± 0.024   -0.134 ± 0.147    0.566 ± 0.055
      GMN (Ours)   0.891 ± 0.037    0.870 ± 0.010    0.768 ± 0.063    0.780 ± 0.030


Citation


Derek Lim, Haggai Maron, Marc T. Law, Jonathan Lorraine, James Lucas
Graph Metanetworks for Processing Diverse Neural Architectures. ICLR, 2024.


@inproceedings{lim2024graph,
  title={Graph Metanetworks for Processing Diverse Neural Architectures},
  author={Derek Lim and Haggai Maron and Marc T. Law and Jonathan Lorraine and James Lucas},
  booktitle={The Twelfth International Conference on Learning Representations},
  url={https://openreview.net/forum?id=ijK5hyxs0n},
  year={2024},
}


We thank David Acuna for the website template.