The Task
Once you’ve implemented a Flash DataModule and some backbones, you should implement your Task in model.py.
The Task is responsible for: setting up the backbone, performing the forward pass of the model, and calculating the loss and any metrics.
Remember that, under the hood, the Flash Task is simply a LightningModule with some helpful defaults.
To build your task, you can start by overriding the base Task or any of the existing Task implementations.
For example, in our scikit-learn example, we can just override ClassificationTask, which provides good defaults for classification.
You should attach your backbones registry as a class attribute like this:
class TemplateSKLearnClassifier(ClassificationTask):

    backbones: FlashRegistry = TEMPLATE_BACKBONES
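For context, the registry itself could be defined in backbones.py along these lines (a hedged sketch; the exact contents of TEMPLATE_BACKBONES and the load_mlp_128 helper are illustrative assumptions, not a verbatim copy):

from typing import Tuple

from torch import nn

from flash.core.registry import FlashRegistry

# A registry mapping backbone names to functions that build them.
TEMPLATE_BACKBONES = FlashRegistry("backbones")


@TEMPLATE_BACKBONES(name="mlp-128", namespace="template/classification")
def load_mlp_128(num_features: int, **_) -> Tuple[nn.Module, int]:
    """An MLP backbone with 128 hidden units; returns the module and its output size."""
    return (
        nn.Sequential(nn.Linear(num_features, 128), nn.ReLU(True)),
        128,
    )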
Model architecture and hyper-parameters
In the __init__(), you will need to configure defaults for the:
loss function
optimizer
metrics
backbone / model
You will also need to create the backbone from the registry and create the model head. Here’s the code:
def __init__(
    self,
    num_features: int,
    num_classes: Optional[int] = None,
    labels: Optional[List[str]] = None,
    backbone: Union[str, Tuple[nn.Module, int]] = "mlp-128",
    backbone_kwargs: Optional[Dict] = None,
    loss_fn: LOSS_FN_TYPE = None,
    optimizer: OPTIMIZER_TYPE = "Adam",
    lr_scheduler: LR_SCHEDULER_TYPE = None,
    metrics: METRICS_TYPE = None,
    learning_rate: Optional[float] = None,
    multi_label: bool = False,
):
    self.save_hyperparameters()

    if labels is not None and num_classes is None:
        num_classes = len(labels)

    super().__init__(
        model=None,
        loss_fn=loss_fn,
        optimizer=optimizer,
        lr_scheduler=lr_scheduler,
        metrics=metrics,
        learning_rate=learning_rate,
        multi_label=multi_label,
        num_classes=num_classes,
        labels=labels,
    )

    if not backbone_kwargs:
        backbone_kwargs = {}

    if isinstance(backbone, tuple):
        self.backbone, out_features = backbone
    else:
        self.backbone, out_features = self.backbones.get(backbone)(num_features=num_features, **backbone_kwargs)

    self.head = nn.Linear(out_features, num_classes)
Note
We call save_hyperparameters() to log the arguments to the __init__ as hyperparameters. Read more here.
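As a quick illustration of what this gives you (a hypothetical usage sketch; this is standard LightningModule behaviour rather than anything specific to the template, and the argument values are illustrative):

model = TemplateSKLearnClassifier(num_features=20, num_classes=3)
print(model.hparams.backbone)     # "mlp-128", captured by save_hyperparameters()
print(model.hparams.num_classes)  # 3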
Adding the model routines
You should override the {train,val,test,predict}_step methods.
The default {train,val,test,predict}_step implementations in Task expect a tuple containing the input (to be passed to the model) and target (to be used when computing the loss), and should be suitable for most applications.
In our template example, we just extract the input and target from the input mapping and forward them to the super methods.
Here’s the code for the training_step:
def training_step(self, batch: Any, batch_idx: int) -> Any:
    """For the training step, we just extract the :attr:`~flash.core.data.io.input.DataKeys.INPUT` and
    :attr:`~flash.core.data.io.input.DataKeys.TARGET` keys from the input and forward them to the
    :meth:`~flash.core.model.Task.training_step`."""
    batch = (batch[DataKeys.INPUT], batch[DataKeys.TARGET])
    return super().training_step(batch, batch_idx)
We use the same code for the validation_step and test_step, as shown below.
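Concretely, they follow the same pattern as training_step (the abridged docstrings here are ours):

def validation_step(self, batch: Any, batch_idx: int) -> Any:
    """Extract the input and target and forward them to Task.validation_step."""
    batch = (batch[DataKeys.INPUT], batch[DataKeys.TARGET])
    return super().validation_step(batch, batch_idx)


def test_step(self, batch: Any, batch_idx: int) -> Any:
    """Extract the input and target and forward them to Task.test_step."""
    batch = (batch[DataKeys.INPUT], batch[DataKeys.TARGET])
    return super().test_step(batch, batch_idx)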
For predict_step we don’t need the targets, so our code looks like this:
def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any:
    """For the predict step, we just extract the :attr:`~flash.core.data.io.input.DataKeys.INPUT` key from the
    input and forward it to the :meth:`~flash.core.model.Task.predict_step`."""
    batch = batch[DataKeys.INPUT]
    return super().predict_step(batch, batch_idx, dataloader_idx=dataloader_idx)
Note
You can completely replace the {train,val,test,predict}_step methods (that is, without a call to super) if you need more custom behaviour for your Task at a particular stage.
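For example, a fully custom training_step that bypasses the defaults entirely might look like this (a hypothetical sketch assuming a standard multi-class classification loss; it is not part of the template):

import torch.nn.functional as F  # assumed import for the loss below


def training_step(self, batch: Any, batch_idx: int) -> torch.Tensor:
    x, y = batch[DataKeys.INPUT], batch[DataKeys.TARGET]
    y_hat = self(x)
    # Compute and log the loss ourselves instead of calling super().
    loss = F.cross_entropy(y_hat, y)
    self.log("train_loss", loss)
    return loss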
Finally, we use our backbone and head in a custom forward pass:
def forward(self, x) -> torch.Tensor:
    """First call the backbone, then the model head."""
    x = self.backbone(x)
    return self.head(x)
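As a quick sanity check, you can run a random batch through the model (a hypothetical usage sketch; the sizes are illustrative and assume the "mlp-128" backbone accepts a flat feature vector):

model = TemplateSKLearnClassifier(num_features=20, num_classes=3)
out = model(torch.rand(8, 20))  # a batch of 8 examples with 20 features each
assert out.shape == (8, 3)      # one logit per class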
Now that you’ve got your task, take a look at some optional advanced features you can add, or go ahead and create some examples showing your task in action!