The Backbones
Now that you’ve got a way of loading data, you should implement some backbones to use with your Task.
Create a FlashRegistry to use with your Task in backbones.py.
The registry allows you to register backbones for your task that can be selected by the user.
The backbones can come from anywhere as long as you can register a function that loads the backbone.
Furthermore, users can add their own models to the existing backbones without having to write their own Task!
You can create a registry like this:
from flash.core.registry import FlashRegistry

TEMPLATE_BACKBONES = FlashRegistry("backbones")
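To make the registry idea concrete, here is a minimal, hypothetical stand-in written with only the standard library. It is not how FlashRegistry is implemented, and ToyRegistry, get, metadata and available_keys are illustrative names for this sketch only. It shows the pattern this section relies on: a decorator stores each loader function under a name together with extra metadata such as namespace, and the loader can later be looked up by name, including loaders that a user registers from their own code.

```python
from typing import Any, Callable, Dict, List, Optional


class ToyRegistry:
    """Hypothetical minimal registry; illustrates the pattern, not FlashRegistry itself."""

    def __init__(self, name: str) -> None:
        self.name = name
        # Each entry maps a backbone name to its loader function and metadata.
        self._entries: Dict[str, Dict[str, Any]] = {}

    def __call__(self, name: Optional[str] = None, **metadata: Any) -> Callable[[Callable], Callable]:
        # Decorator factory, used as @registry(name="...", namespace="...").
        def decorator(fn: Callable) -> Callable:
            key = name or fn.__name__
            if key in self._entries:
                raise ValueError(f"{key!r} is already registered in {self.name!r}")
            self._entries[key] = {"fn": fn, "metadata": metadata}
            return fn  # the decorated function stays directly callable

        return decorator

    def get(self, key: str) -> Callable:
        return self._entries[key]["fn"]

    def metadata(self, key: str) -> Dict[str, Any]:
        return dict(self._entries[key]["metadata"])

    def available_keys(self) -> List[str]:
        return sorted(self._entries)


BACKBONES = ToyRegistry("backbones")


# A user can register an extra backbone from their own code, no new Task required.
@BACKBONES(name="toy-identity", namespace="template/classification")
def load_toy_identity(num_features, **_):
    # Stand-in "backbone": the identity function, so the output size equals the input size.
    return (lambda x: x), num_features


backbone, out_size = BACKBONES.get("toy-identity")(num_features=4)
print(BACKBONES.available_keys(), out_size)  # ['toy-identity'] 4
```

Returning the function unchanged from the decorator keeps load_toy_identity usable as a plain function, while the registry entry makes it selectable by name.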
Let’s add a simple MLP backbone to our registry.
We need a function that creates the backbone and returns it along with the output size (so that we can create the model head in our Task).
You can use any name for the function, although by convention we use load_{model name}.
You also need to provide the name and namespace of the backbone.
The standard for namespace is data_type/task_type, so for an image classification task the namespace will be image/classification.
Here’s the code:
from torch import nn


@TEMPLATE_BACKBONES(name="mlp-128", namespace="template/classification")
def load_mlp_128(num_features, **_):
    """A simple MLP backbone with 128 hidden units."""
    return (
        nn.Sequential(
            nn.Linear(num_features, 128),
            nn.ReLU(True),
            nn.BatchNorm1d(128),
        ),
        128,  # output size of the backbone, used to build the head
    )
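Why return the output size? A Task typically builds its head on top of the backbone, and the head's input dimension has to match the backbone's output size. The pure-Python sketch below mimics that wiring without PyTorch: load_toy_repeat, make_linear_head and build_model are hypothetical names, and make_linear_head stands in for the role a layer such as nn.Linear(output_size, num_classes) would play in a real Task.

```python
import random
from typing import Callable, List, Sequence, Tuple

Vector = List[float]
Layer = Callable[[Sequence[float]], Vector]


def load_toy_repeat(num_features: int, **_) -> Tuple[Layer, int]:
    """Hypothetical loader following the (backbone, output size) convention."""

    def backbone(x: Sequence[float]) -> Vector:
        # Repeat the input once, so the features have 2 * num_features entries.
        return list(x) * 2

    return backbone, 2 * num_features


def make_linear_head(in_features: int, num_classes: int, seed: int = 0) -> Layer:
    """Pure-Python linear layer with a num_classes x in_features weight matrix."""
    rng = random.Random(seed)
    weights = [[rng.uniform(-1.0, 1.0) for _ in range(in_features)] for _ in range(num_classes)]

    def head(features: Sequence[float]) -> Vector:
        if len(features) != in_features:
            raise ValueError(f"head expects {in_features} features, got {len(features)}")
        return [sum(w * f for w, f in zip(row, features)) for row in weights]

    return head


def build_model(loader: Callable[..., Tuple[Layer, int]], num_features: int, num_classes: int) -> Layer:
    # The reported output size is exactly what the head needs to know.
    backbone, out_size = loader(num_features=num_features)
    head = make_linear_head(out_size, num_classes)
    return lambda x: head(backbone(x))


model = build_model(load_toy_repeat, num_features=3, num_classes=5)
print(len(model([1.0, 2.0, 3.0])))  # 5
```

Because the loader reports its output size up front, the head can be sized before any data flows through the model; a wrong value would only surface later as a shape mismatch.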
Here’s another example with a slightly more complex model:
@TEMPLATE_BACKBONES(name="mlp-128-256", namespace="template/classification")
def load_mlp_128_256(num_features, **_):
    """Two layer MLP backbone with 128 and 256 hidden units respectively."""
    return (
        nn.Sequential(
            nn.Linear(num_features, 128),
            nn.ReLU(True),
            nn.BatchNorm1d(128),
            nn.Linear(128, 256),
            nn.ReLU(True),
            nn.BatchNorm1d(256),
        ),
        256,
    )
Here’s another example, which adds the DINO pretrained model from PyTorch Hub to the IMAGE_CLASSIFIER_BACKBONES, from flash/image/classification/backbones/transformers.py:
import torch


def dino_vitb16(*_, **__):
    backbone = torch.hub.load("facebookresearch/dino:main", "dino_vitb16")
    return backbone, 768  # ViT-B/16 produces 768-dimensional embeddings
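The loaders above accept arguments they do not use: **_ in the MLP loaders and *_, **__ in dino_vitb16. One plausible reason, assumed here rather than taken from the Flash source, is that a Task can then call every registered loader with the same keyword arguments and each loader picks out only what it needs. The names below (load_needs_features, load_fixed_size, load_all, and the pretrained argument) are hypothetical and exist only for this illustration; strings stand in for real model objects.

```python
from typing import Any, Callable, Dict, Tuple

Loader = Callable[..., Tuple[str, int]]


def load_needs_features(num_features: int, **_: Any) -> Tuple[str, int]:
    # Uses num_features and ignores any other keyword arguments.
    return f"mlp({num_features}->128)", 128


def load_fixed_size(*_: Any, **__: Any) -> Tuple[str, int]:
    # Ignores every argument, like a pretrained model with a fixed embedding size.
    return "pretrained-vit", 768


def load_all(loaders: Dict[str, Loader], **task_kwargs: Any) -> Dict[str, Tuple[str, int]]:
    # Hypothetical Task-side code: every loader receives the same keyword arguments.
    return {name: loader(**task_kwargs) for name, loader in loaders.items()}


results = load_all(
    {"mlp": load_needs_features, "vit": load_fixed_size},
    num_features=10,
    pretrained=True,
)
print(results)  # {'mlp': ('mlp(10->128)', 128), 'vit': ('pretrained-vit', 768)}
```

Without the catch-all parameters, passing pretrained=True to a loader declared as load_needs_features(num_features) would raise a TypeError for an unexpected keyword argument.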
Once you’ve got some data and some backbones, implement your task!