normalizing_flow
Standard model class(es).
- class graphnet.models.normalizing_flow.NormalizingFlow(*args, **kwargs)
Bases: EasySyntax
A model for building (conditional) normalizing flows in GraphNeT.
This model relies on jammy_flows for building and evaluating normalizing flows. See https://thoglu.github.io/jammy_flows/usage/introduction.html for details.
Build NormalizingFlow to learn (conditional) normalizing flows.
NormalizingFlow is able to build, train and evaluate a wide suite of normalizing flows. Instead of optimizing a point-estimate loss function, flows are trained by minimizing the negative log-likelihood of your data under a learned pdf, which provides a posterior distribution for every example instead of point-like predictions.
NormalizingFlow can be conditioned on existing fields in the DataRepresentation or latent representations from Models.
NormalizingFlow is built upon https://github.com/thoglu/jammy_flows, and we refer to their documentation for details on the flows.
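To illustrate what "a learned pdf" means, here is a minimal plain-Python sketch of the change-of-variables rule that normalizing flows rely on: a stack of invertible layers maps a value x to a standard-normal variable z, and log p(x) is the base log-density of z plus the summed log|dz/dx| of every layer. The layer types `affine` and `tanh_bump` are hypothetical stand-ins chosen for readability; they are not the layers jammy_flows builds from a `flow_layers` string such as "gggt", and this is not GraphNeT's implementation.

```python
import math
from typing import Callable, List, Tuple

# A layer maps x -> (z, log|dz/dx|). These are illustrative stand-ins,
# NOT the layers jammy_flows builds from a string such as "gggt".
Layer = Callable[[float], Tuple[float, float]]


def affine(shift: float, log_scale: float) -> Layer:
    """z = (x - shift) / exp(log_scale); log|dz/dx| = -log_scale."""
    def layer(x: float) -> Tuple[float, float]:
        return (x - shift) * math.exp(-log_scale), -log_scale
    return layer


def tanh_bump(a: float) -> Layer:
    """z = x + a * tanh(x); strictly increasing (hence invertible) for a >= 0."""
    if a < 0:
        raise ValueError("a must be non-negative to keep the layer invertible")

    def layer(x: float) -> Tuple[float, float]:
        t = math.tanh(x)
        # dz/dx = 1 + a * (1 - tanh(x)^2) >= 1, so the log is well defined.
        return x + a * t, math.log(1.0 + a * (1.0 - t * t))
    return layer


def flow_log_prob(x: float, layers: List[Layer]) -> float:
    """log p(x) = log N(z; 0, 1) + sum of per-layer log|dz/dx|."""
    z, log_det = x, 0.0
    for layer in layers:
        z, layer_log_det = layer(z)
        log_det += layer_log_det
    return -0.5 * z * z - 0.5 * math.log(2.0 * math.pi) + log_det


if __name__ == "__main__":
    layers = [tanh_bump(0.5), affine(1.0, 0.3), tanh_bump(1.0), affine(0.0, -0.2)]
    # Trapezoidal check that the density defined by the stacked layers integrates to ~1.
    h = 0.01
    xs = [-30.0 + i * h for i in range(6001)]
    ps = [math.exp(flow_log_prob(x, layers)) for x in xs]
    total = h * (sum(ps) - 0.5 * (ps[0] + ps[-1]))
    print(f"integral of p(x): {total:.6f}")
```

Because every stand-in layer is a strictly increasing bijection of the real line, the composed density is properly normalized; training a flow then amounts to adjusting the layer parameters so that the observed data receive high log p(x).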
- Parameters:
graph_definition (GraphDefinition) – The GraphDefinition to train the model on.
target_labels (str) – Name of target(s) to learn the pdf of.
backbone (Optional[GNN], default: None) – Architecture used to produce latent representations of the input data on which the pdf will be conditioned. Defaults to None.
condition_on (Union[str, List[str], None], default: None) – List of fields in Data objects to condition the pdf on. Defaults to None.
flow_layers (str, default: 'gggt') – A string defining the flow layers. See https://thoglu.github.io/jammy_flows/usage/introduction.html for details. Defaults to "gggt".
optimizer_class (Type[Optimizer], default: <class 'torch.optim.adam.Adam'>) – Optimizer to use. Defaults to Adam.
optimizer_kwargs (Optional[Dict], default: None) – Optimizer arguments. Defaults to None.
scheduler_class (Optional[type], default: None) – Learning rate scheduler to use. Defaults to None.
scheduler_kwargs (Optional[Dict], default: None) – Arguments to the learning rate scheduler. Defaults to None.
scheduler_config (Optional[Dict], default: None) – Defaults to None.
args (Any)
kwargs (Any)
- Raises:
ValueError – if both backbone and condition_on are specified.
- Return type:
object
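The constructor accepts either a `backbone` or `condition_on`, and raises ValueError when both are given. A minimal sketch of that validation follows. `resolve_condition_on` is a hypothetical helper, and turning a single field name into a one-element list is an assumption about how the `Union[str, List[str], None]` type might be handled, not GraphNeT's documented behavior.

```python
from typing import List, Optional, Union


def resolve_condition_on(
    backbone: Optional[object],
    condition_on: Union[str, List[str], None],
) -> Optional[List[str]]:
    """Hypothetical helper: validate the two mutually exclusive conditioning options."""
    if backbone is not None and condition_on is not None:
        # Documented constraint: specifying both options raises ValueError.
        raise ValueError("Specify either `backbone` or `condition_on`, not both.")
    if isinstance(condition_on, str):
        # Assumption: a single field name is treated as a one-element list.
        return [condition_on]
    # A list of field names, or None (unconditional or backbone-conditioned flow).
    return condition_on
```

For example, `resolve_condition_on(None, "energy")` returns `["energy"]`, while passing any backbone object together with `["zenith"]` raises ValueError; the field names here are placeholders.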
- forward(data)
Forward pass, chaining model components.
- Return type:
Tensor
- Parameters:
data (Data | List[Data])
Perform shared step.
Applies the forward pass and the subsequent loss calculation; this step is shared between the training and validation steps.
- Return type:
Tensor
- Parameters:
batch (List[Data])
batch_idx (int)
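The split between `forward` (chaining model components) and the shared step (forward pass followed by the loss) can be sketched in plain Python, assuming the loss is the mean negative log-likelihood of the targets, which is the standard training objective for normalizing flows. `backbone`, the event dictionaries and the parameters `w`, `b`, `log_sigma` of a one-dimensional conditional affine flow are all hypothetical; GraphNeT itself operates on batched `torch` tensors and delegates the pdf to jammy_flows.

```python
import math
from typing import Dict, List

# Hypothetical event format: node features under "x" plus a scalar "target".
Event = Dict[str, object]


def backbone(node_features: List[float]) -> float:
    """Toy backbone: mean-pool node features into a one-dimensional latent."""
    return sum(node_features) / len(node_features)


def forward(event: Event, w: float, b: float, log_sigma: float) -> float:
    """Chain backbone -> conditional affine flow; return log p(target | event)."""
    c = backbone(event["x"])  # latent representation used as the condition
    mu = w * c + b            # condition-dependent shift of the flow
    z = (event["target"] - mu) * math.exp(-log_sigma)
    # Change of variables: standard-normal log-density plus log|dz/dtarget|.
    return -0.5 * z * z - 0.5 * math.log(2.0 * math.pi) - log_sigma


def shared_step(batch: List[Event], params: Dict[str, float]) -> float:
    """Loss shared by training and validation: mean negative log-likelihood."""
    return -sum(forward(event, **params) for event in batch) / len(batch)


if __name__ == "__main__":
    batch = [
        {"x": [0.0, 2.0], "target": 3.0},
        {"x": [1.0, 1.0, 1.0], "target": 2.5},
    ]
    params = {"w": 2.0, "b": 1.0, "log_sigma": 0.0}
    print(f"batch NLL: {shared_step(batch, params):.4f}")
```

Keeping the loss in one function means training and validation score the model identically; only the surrounding loop (gradient updates or not) differs.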