grit
Implementation of GRIT, a graph transformer model.
Original author: Liheng Ma
Original code: https://github.com/LiamMa/GRIT
Paper: "Graph Inductive Biases in Transformers without Message Passing"
Adapted by: Philip Weigel
- class graphnet.models.gnn.grit.GRIT(*args, **kwargs)
Bases: GNN
GRIT is a graph transformer model.
Original code: https://github.com/LiamMa/GRIT/blob/main/grit/network/grit_model.py
Construct GRIT model.
- Parameters:
  - nb_inputs (int) – Number of inputs.
  - hidden_dim (int) – Size of hidden dimension.
  - nb_outputs (int, default: 1) – Size of output dimension.
  - ksteps (int, default: 21) – Number of random walk steps.
  - n_layers (int, default: 10) – Number of GRIT layers.
  - n_heads (int, default: 8) – Number of heads in MHA.
  - pad_to_full_graph (bool, default: True) – Pad to form a fully-connected graph.
  - add_node_attr_as_self_loop (bool, default: False) – Add node attributes as a self-edge.
  - dropout (float, default: 0.0) – Dropout probability.
  - fill_value (float, default: 0.0) – Padding value.
  - norm (Module, default: torch.nn.BatchNorm1d) – Uninstantiated normalization layer. Either torch.nn.BatchNorm1d or torch.nn.LayerNorm.
  - attn_dropout (float, default: 0.2) – Attention dropout probability.
  - edge_enhance (bool, default: True) – Apply a learnable weight matrix to node pairs when computing MHA output nodes.
  - update_edges (bool, default: True) – Update edge values after each GRIT layer.
  - attn_clamp (float, default: 5.0) – Clamp the absolute value of attention scores to this value.
  - activation (Module, default: torch.nn.ReLU) – Uninstantiated activation function, e.g. torch.nn.ReLU.
  - attn_activation (Module, default: torch.nn.ReLU) – Uninstantiated attention activation function, e.g. torch.nn.ReLU.
  - norm_edges (bool, default: True) – Apply the normalization layer to edges.
  - enable_edge_transform (bool, default: True) – Apply a transformation to edges.
  - pred_head_layers (int, default: 2) – Number of layers in the prediction head.
  - pred_head_activation (Module, default: torch.nn.ReLU) – Uninstantiated prediction head activation function, e.g. torch.nn.ReLU.
  - pred_head_pooling (str, default: 'mean') – Pooling function to use in the prediction head, either "mean" (default) or "add".
  - position_encoding (str, default: 'NoPE') – Method of position encoding.
  - args (Any)
  - kwargs (Any)
- Return type:
object
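A minimal instantiation sketch, assuming graphnet is installed and GRIT is importable from graphnet.models.gnn.grit; the argument values shown are illustrative choices, not recommendations, and most keyword arguments can be left at their defaults listed above.

```python
# Minimal sketch: constructing a GRIT model with a few of the documented
# arguments. The values below are illustrative only.
import torch
from graphnet.models.gnn.grit import GRIT

model = GRIT(
    nb_inputs=4,                  # number of node input features
    hidden_dim=64,                # hidden dimension of the transformer layers
    nb_outputs=1,                 # size of the output dimension
    n_layers=10,                  # number of GRIT layers (default)
    n_heads=8,                    # attention heads in MHA (default)
    norm=torch.nn.BatchNorm1d,    # uninstantiated normalization layer
    activation=torch.nn.ReLU,     # uninstantiated activation function
    pred_head_pooling="mean",     # pooling used in the prediction head
)
```

Note that norm, activation, attn_activation, and pred_head_activation take uninstantiated classes (e.g. torch.nn.ReLU, not torch.nn.ReLU()); the model instantiates them internally.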