Graph Classification
The Task
This task consists of classifying graphs: the model predicts which class a whole graph belongs to. A class is a label that indicates the kind of graph. For example, a label may indicate whether one molecule interacts with another.
The GraphClassifier and GraphClassificationData classes internally rely on pytorch-geometric.
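To make the task concrete, the sketch below represents a graph as per-node feature vectors plus an edge list, averages each node with its neighbors once, pools the nodes into a single graph-level vector, and scores that vector against one weight vector per class. It is a plain-Python toy under stated assumptions, not how Flash or pytorch-geometric implement graph classification; the names `Graph`, `propagate`, `readout`, and `classify`, as well as the weights, are hypothetical and chosen only for illustration.

```python
from dataclasses import dataclass
from typing import List, Tuple


@dataclass
class Graph:
    node_features: List[List[float]]  # one feature vector per node
    edges: List[Tuple[int, int]]      # undirected edges as pairs of node indices


def propagate(graph: Graph) -> List[List[float]]:
    """Replace each node's features with the mean over itself and its neighbors."""
    n = len(graph.node_features)
    neighbors = [[i] for i in range(n)]  # every node counts itself
    for u, v in graph.edges:
        neighbors[u].append(v)
        neighbors[v].append(u)
    dim = len(graph.node_features[0])
    return [
        [sum(graph.node_features[j][d] for j in nbrs) / len(nbrs) for d in range(dim)]
        for nbrs in neighbors
    ]


def readout(node_features: List[List[float]]) -> List[float]:
    """Mean-pool node features into one graph-level vector."""
    n = len(node_features)
    dim = len(node_features[0])
    return [sum(f[d] for f in node_features) / n for d in range(dim)]


def classify(graph: Graph, class_weights: List[List[float]]) -> int:
    """Return the index of the class whose weight vector scores highest."""
    pooled = readout(propagate(graph))
    scores = [sum(w * x for w, x in zip(weights, pooled)) for weights in class_weights]
    return max(range(len(scores)), key=scores.__getitem__)


# A triangle whose nodes all carry the feature vector [1.0, 0.0].
triangle = Graph(node_features=[[1.0, 0.0]] * 3, edges=[(0, 1), (1, 2), (0, 2)])
# Hand-picked (hypothetical) weights: class 0 favors the first feature, class 1 the second.
class_weights = [[1.0, 0.0], [0.0, 1.0]]
predicted = classify(triangle, class_weights)
print(predicted)
```

The key point is that the prediction is made once per graph, not once per node: the pooling step is what turns a variable number of nodes into a fixed-size input for the classifier. In a trained model the weights would be learned rather than hand-picked.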
Example
Let’s look at the task of classifying graphs from the KKI data set from TU Dortmund University. Once we’ve created the TUDataset, we create the GraphClassificationData. We then create our GraphClassifier and train on the KKI data. Next, we use the trained GraphClassifier for inference. Finally, we save the model.
Here’s the full example:
import torch
import flash
from flash.core.utilities.imports import example_requires
from flash.graph import GraphClassificationData, GraphClassifier
example_requires("graph")
from torch_geometric.datasets import TUDataset # noqa: E402
# 1. Create the DataModule
dataset = TUDataset(root="data", name="KKI")
datamodule = GraphClassificationData.from_datasets(
    train_dataset=dataset,
    val_split=0.1,
    batch_size=4,
)
# 2. Build the task
backbone_kwargs = {"hidden_channels": 512, "num_layers": 4}
model = GraphClassifier(
    num_features=datamodule.num_features,
    num_classes=datamodule.num_classes,
    backbone_kwargs=backbone_kwargs,
)
# 3. Create the trainer and fit the model
trainer = flash.Trainer(max_epochs=3, gpus=torch.cuda.device_count())
trainer.fit(model, datamodule=datamodule)
# 4. Classify some graphs!
datamodule = GraphClassificationData.from_datasets(
    predict_dataset=dataset[:3],
    batch_size=4,
)
predictions = trainer.predict(model, datamodule=datamodule, output="classes")
print(predictions)
# 5. Save the model!
trainer.save_checkpoint("graph_classification_model.pt")
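The `val_split=0.1` argument in step 1 holds out a fraction of the training dataset for validation. The sketch below shows one way such a fractional split can work, assuming the indices are shuffled and the validation count is rounded down. Flash's actual splitting logic may differ; `split_indices`, its `seed` parameter, and the dataset size of 50 are hypothetical.

```python
import random
from typing import List, Tuple


def split_indices(num_samples: int, val_split: float, seed: int = 0) -> Tuple[List[int], List[int]]:
    """Shuffle sample indices and hold out a fraction of them for validation."""
    indices = list(range(num_samples))
    random.Random(seed).shuffle(indices)      # seeded so the split is reproducible
    num_val = int(num_samples * val_split)    # assumption: round the validation count down
    return indices[num_val:], indices[:num_val]


# Hypothetical dataset of 50 graphs with 10% held out for validation.
train_idx, val_idx = split_indices(num_samples=50, val_split=0.1)
print(len(train_idx), len(val_idx))
```

Every index lands in exactly one of the two lists, so no graph is used for both training and validation.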
Flash Zero
The graph classifier can be used directly from the command line with zero code using Flash Zero. You can run the above example with:
flash graph_classification
To view configuration options and options for running the graph classifier with your own data, use:
flash graph_classification --help