The Data¶
The first step to contributing a task is to implement the classes we need to load some data. Inside data.py you should implement:
- some Input classes (optional)
- a BaseVisualization (optional)
- an OutputTransform (optional)
Input¶
The Input class contains the logic for data loading from different sources such as folders, files, tensors, etc.
Every Flash DataModule can be instantiated with from_datasets().
For each additional way you want the user to be able to instantiate your DataModule, you’ll need to create an Input.
Each Input has 2 methods:
- load_data() takes some dataset metadata (e.g. a folder name) as input and produces a sequence or iterable of samples or sample metadata.
- load_sample() then takes as input a single element from the output of load_data and returns a sample.
By default these methods just return their input, so you don’t need both a load_data() and a load_sample() to create an Input.
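As a rough illustration of this two-method contract, here is a framework-free sketch. SketchInput and UppercaseSketchInput are hypothetical stand-ins, not Flash classes; they only mimic the pass-through defaults described above.

```python
from typing import Any, Iterable, List


class SketchInput:
    """Hypothetical stand-in for an Input; not part of Flash."""

    def load_data(self, data: Any) -> Iterable[Any]:
        # Default: return the dataset metadata unchanged.
        return data

    def load_sample(self, sample: Any) -> Any:
        # Default: return the single element unchanged.
        return sample

    def samples(self, data: Any) -> List[Any]:
        # load_data runs once; load_sample runs once per element.
        return [self.load_sample(item) for item in self.load_data(data)]


class UppercaseSketchInput(SketchInput):
    """Overrides only load_sample; load_data keeps its default."""

    def load_sample(self, sample: str) -> str:
        return sample.upper()


passthrough = SketchInput().samples(["a", "b"])
transformed = UppercaseSketchInput().samples(["a", "b"])
```

Because both defaults are pass-throughs, the subclass only has to override the method whose behaviour it actually changes.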
Where possible, you should override one of our existing Input classes.
Let’s start by implementing a TemplateNumpyClassificationInput, which overrides ClassificationInputMixin.
The main Input method that we have to implement is load_data().
ClassificationInputMixin provides utilities for handling targets within Flash, which need to be called from load_data() and load_sample().
In this Input, we’ll also set the num_features attribute so that we can access it later.
Here’s the code for our TemplateNumpyClassificationInput.load_data method:
def load_data(
    self,
    examples: Collection[np.ndarray],
    targets: Optional[Sequence[Any]] = None,
    target_formatter: Optional[TargetFormatter] = None,
) -> Sequence[Dict[str, Any]]:
    """Sets the ``num_features`` attribute and calls ``super().load_data``.

    Args:
        examples: The ``np.ndarray`` (num_examples x num_features).
        targets: Associated targets.
        target_formatter: Optionally provide a ``TargetFormatter`` to control how targets are formatted.

    Returns:
        A sequence of samples / sample metadata.
    """
    if not self.predicting and isinstance(examples, np.ndarray):
        self.num_features = examples.shape[1]
    if targets is not None:
        self.load_target_metadata(targets, target_formatter=target_formatter)
    return to_samples(examples, targets)
And here’s the code for the TemplateNumpyClassificationInput.load_sample method:
def load_sample(self, sample: Dict[str, Any]) -> Any:
    if DataKeys.TARGET in sample:
        sample[DataKeys.TARGET] = self.format_target(sample[DataKeys.TARGET])
    return sample
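To make the shape of the data concrete, the sketch below mimics what these two methods pass between them. It assumes that to_samples pairs each example with its target in a dictionary. The helper to_samples_sketch, the string keys, and the label-to-index format_target_sketch are hypothetical stand-ins for to_samples, DataKeys, and format_target, not the Flash implementations.

```python
from typing import Any, Dict, List, Optional, Sequence

INPUT, TARGET = "input", "target"  # stand-ins for DataKeys.INPUT / DataKeys.TARGET


def to_samples_sketch(
    examples: Sequence[Any], targets: Optional[Sequence[Any]] = None
) -> List[Dict[str, Any]]:
    """Pair each example with its target (if any) in a sample dictionary."""
    if targets is None:
        return [{INPUT: example} for example in examples]
    return [{INPUT: e, TARGET: t} for e, t in zip(examples, targets)]


def format_target_sketch(target: str, labels: List[str]) -> int:
    """Map a label to its class index (one possible target format)."""
    return labels.index(target)


labels = ["cat", "dog"]
samples = to_samples_sketch([[0.1, 0.2], [0.3, 0.4]], ["dog", "cat"])
for sample in samples:
    if TARGET in sample:  # mirrors the check in load_sample
        sample[TARGET] = format_target_sketch(sample[TARGET], labels)

predict_samples = to_samples_sketch([[0.5, 0.6]])  # no targets when predicting
```

The membership check matters because prediction samples carry no target key at all.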
Note
Later, when we add our DataModule implementation, we’ll make num_features available to the user.
For our template Task, it would be cool if the user could provide a scikit-learn Bunch as the data source.
To achieve this, we’ll add a TemplateSKLearnClassificationInput whose load_data expects a Bunch as input.
We override our TemplateNumpyClassificationInput so that we can call super with the data and targets extracted from the Bunch.
We perform two additional steps here to improve the user experience:
- We set the num_classes attribute on the dataset. If num_classes is set, it is automatically made available as a property of the DataModule.
- We create and set a ClassificationState. The labels provided here will be shared with the Labels output, so the user doesn’t need to provide them.
Here’s the code for the TemplateSKLearnClassificationInput.load_data method:
def load_data(self, data: Bunch, target_formatter: Optional[TargetFormatter] = None) -> Sequence[Dict[str, Any]]:
    """Gets the ``data`` and ``target`` attributes from the ``Bunch`` and passes them to ``super().load_data``.

    Args:
        data: The scikit-learn data ``Bunch``.
        target_formatter: Optionally provide a ``TargetFormatter`` to control how targets are formatted.

    Returns:
        A sequence of samples / sample metadata.
    """
    return super().load_data(data.data, data.target, target_formatter=target_formatter)
We can customize the behaviour of our load_data() for different stages by prefixing the method name with train, val, test, or predict.
For our TemplateSKLearnClassificationInput, we don’t want to provide any targets to the model when predicting.
We can implement predict_load_data like this:
def predict_load_data(self, data: Bunch) -> Sequence[Dict[str, Any]]:
    """Avoid including targets when predicting.

    Args:
        data: The scikit-learn data ``Bunch``.

    Returns:
        A sequence of samples / sample metadata.
    """
    return super().load_data(data.data)
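One way to picture the stage-specific lookup is a getattr with a fallback. The sketch below only illustrates that idea; resolve_hook and SketchInput are hypothetical and do not reflect how Flash resolves these methods internally.

```python
from typing import Any, Callable, List, Optional


class SketchInput:
    """Hypothetical input with a generic and a predict-specific loader."""

    def load_data(self, data: List[float], targets: Optional[List[int]] = None) -> List[dict]:
        if targets is None:
            return [{"input": x} for x in data]
        return [{"input": x, "target": y} for x, y in zip(data, targets)]

    def predict_load_data(self, data: List[float]) -> List[dict]:
        # Never attach targets when predicting.
        return self.load_data(data)


def resolve_hook(obj: Any, stage: str, name: str = "load_data") -> Callable:
    """Prefer ``<stage>_<name>`` if it exists, otherwise fall back to ``<name>``."""
    return getattr(obj, f"{stage}_{name}", getattr(obj, name))


source = SketchInput()
train_hook = resolve_hook(source, "train")      # falls back to load_data
predict_hook = resolve_hook(source, "predict")  # finds predict_load_data

train_samples = train_hook([1.0, 2.0], [0, 1])
predict_samples = predict_hook([3.0])
```

With this kind of lookup, a stage without its own override simply uses the generic method.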
InputTransform¶
The InputTransform object contains all the data transforms.
Internally we inject the InputTransform transforms at several points along the pipeline.
Defining the standard transforms for your InputTransform (typically at least a per_sample_transform should be defined) involves simply overriding the required hook to return a callable transform.
For our TemplateInputTransform, we’ll just configure a per_sample_transform.
Let’s first define a to_tensor transform as a staticmethod:
@staticmethod
def to_tensor(sample: Dict[str, Any]):
    """Transform which converts the sample to a tensor."""
    sample[DataKeys.INPUT] = torch.from_numpy(sample[DataKeys.INPUT]).float()
    if DataKeys.TARGET in sample:
        sample[DataKeys.TARGET] = torch.as_tensor(sample[DataKeys.TARGET]).long()
    return sample
Now in our per_sample_transform hook, we return the transform:
def per_sample_transform(self) -> Callable:
    return self.to_tensor
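The hook pattern separates choosing a transform from applying it: the hook returns a callable, and the pipeline calls that callable on each sample. Here is a minimal, framework-free sketch of that idea. SketchTransform, apply_per_sample, and the conversion to plain Python floats (instead of tensors) are hypothetical stand-ins, not Flash code.

```python
from typing import Any, Callable, Dict, List


class SketchTransform:
    """Hypothetical transform holder with a single hook."""

    @staticmethod
    def to_float(sample: Dict[str, Any]) -> Dict[str, Any]:
        # Convert the inputs to floats and the target (if present) to an int.
        sample["input"] = [float(x) for x in sample["input"]]
        if "target" in sample:
            sample["target"] = int(sample["target"])
        return sample

    def per_sample_transform(self) -> Callable:
        # The hook returns the transform; it does not apply it.
        return self.to_float


def apply_per_sample(transform: SketchTransform, samples: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    # Ask the hook for the callable once, then apply it to every sample.
    fn = transform.per_sample_transform()
    return [fn(sample) for sample in samples]


out = apply_per_sample(SketchTransform(), [{"input": [1, 2], "target": "1"}, {"input": [3, 4]}])
```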
DataModule¶
The DataModule is responsible for creating the DataLoader and injecting the transforms for each stage.
When the user calls a from_* method (such as from_numpy()), the following steps take place:
1. The from_() method is called with the name of the Input to use and the inputs to provide to load_data() for each stage.
2. The InputTransform is created from cls.input_transform_cls (if it wasn’t provided by the user) with any provided transforms.
3. The Input of the provided name is retrieved from the InputTransform.
4. A BaseAutoDataset is created from the Input for each stage.
5. The DataModule is instantiated with the data sets.
To create our TemplateData DataModule, we first need to attach our input transform class like this:
input_transform_cls = TemplateInputTransform
Since we provided a NUMPY Input in the TemplateInputTransform, from_numpy() will now work with our TemplateData.
If you’ve defined a fully custom Input (like our TemplateSKLearnClassificationInput), then you will need to write a from_* method for each.
Here’s the from_sklearn method for our TemplateData:
@classmethod
def from_sklearn(
    cls,
    train_bunch: Optional[Bunch] = None,
    val_bunch: Optional[Bunch] = None,
    test_bunch: Optional[Bunch] = None,
    predict_bunch: Optional[Bunch] = None,
    input_cls: Type[Input] = TemplateSKLearnClassificationInput,
    transform: INPUT_TRANSFORM_TYPE = TemplateInputTransform,
    transform_kwargs: Optional[Dict] = None,
    **data_module_kwargs: Any,
) -> "TemplateData":
    """This is our custom ``from_*`` method. It expects scikit-learn ``Bunch`` objects as input and creates the
    ``TemplateData`` with them.

    Args:
        train_bunch: The scikit-learn ``Bunch`` containing the train data.
        val_bunch: The scikit-learn ``Bunch`` containing the validation data.
        test_bunch: The scikit-learn ``Bunch`` containing the test data.
        predict_bunch: The scikit-learn ``Bunch`` containing the predict data.
        input_cls: The :class:`~flash.core.data.io.input.Input` type to use for loading the data.
        transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use.
        transform_kwargs: Dict of keyword arguments to be provided when instantiating the transforms.
        data_module_kwargs: Additional keyword arguments to provide to the
            :class:`~flash.core.data.data_module.DataModule` constructor.

    Returns:
        The constructed data module.
    """
    ds_kw = dict()

    train_input = input_cls(RunningStage.TRAINING, train_bunch, **ds_kw)
    target_formatter = getattr(train_input, "target_formatter", None)

    return cls(
        train_input,
        input_cls(RunningStage.VALIDATING, val_bunch, target_formatter=target_formatter, **ds_kw),
        input_cls(RunningStage.TESTING, test_bunch, target_formatter=target_formatter, **ds_kw),
        input_cls(RunningStage.PREDICTING, predict_bunch, **ds_kw),
        transform=transform,
        transform_kwargs=transform_kwargs,
        **data_module_kwargs,
    )
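The from_sklearn method reads the target_formatter from the training input and passes it to the validation and test inputs. A plausible reason, sketched below under that assumption, is to keep the label-to-index mapping identical across stages. LabelFormatterSketch is a hypothetical stand-in for a TargetFormatter, not the Flash class.

```python
from typing import Dict, List, Sequence


class LabelFormatterSketch:
    """Hypothetical formatter: maps labels to indices learned from one split."""

    def __init__(self, labels: Sequence[str]) -> None:
        ordered = sorted(set(labels))
        self.index: Dict[str, int] = {label: i for i, label in enumerate(ordered)}

    def __call__(self, label: str) -> int:
        return self.index[label]


train_labels = ["dog", "cat", "bird"]
val_labels = ["dog", "cat"]  # the validation split happens to lack "bird"

# Reusing the formatter built from the training labels keeps indices aligned.
shared = LabelFormatterSketch(train_labels)
val_with_shared: List[int] = [shared(label) for label in val_labels]

# Building a separate formatter from the validation labels alone would
# assign different indices to the same classes.
separate = LabelFormatterSketch(val_labels)
val_with_separate: List[int] = [separate(label) for label in val_labels]
```

In this sketch the shared formatter maps "dog" to the same index in both splits, while the separately built one does not.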
The final step is to implement the num_features property for our TemplateData.
This is just a convenience for the user that finds the num_features attribute on any of the data sets and returns it.
Here’s the code:
@property
def num_features(self) -> Optional[int]:
    """Tries to get the ``num_features`` from each dataset in turn and returns the output."""
    n_fts_train = getattr(self.train_dataset, "num_features", None)
    n_fts_val = getattr(self.val_dataset, "num_features", None)
    n_fts_test = getattr(self.test_dataset, "num_features", None)
    return n_fts_train or n_fts_val or n_fts_test
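The getattr(..., None) calls combined with the or chain return the first truthy value. A small sketch with hypothetical stand-in datasets shows the fallback. Note that or also skips falsy values such as 0, which is harmless here as long as a feature count is positive.

```python
from types import SimpleNamespace
from typing import Any, Optional


def first_num_features(train: Any, val: Any, test: Any) -> Optional[int]:
    """Return the first truthy ``num_features`` among the datasets, else ``None``."""
    n_train = getattr(train, "num_features", None)
    n_val = getattr(val, "num_features", None)
    n_test = getattr(test, "num_features", None)
    return n_train or n_val or n_test


no_train = None                             # no training data was provided
val_set = SimpleNamespace(num_features=4)   # validation input recorded 4 features
bare_set = SimpleNamespace()                # attribute was never set

found = first_num_features(no_train, val_set, bare_set)
missing = first_num_features(no_train, bare_set, bare_set)
```

Using getattr with a default means a missing dataset (None) and a dataset without the attribute are handled the same way.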
BaseVisualization¶
An optional step is to implement a BaseVisualization.
The BaseVisualization lets you control how data at various points in the pipeline can be visualized.
This is extremely useful for debugging purposes, allowing users to view their data and understand the impact of their transforms.
Note
Don’t worry about implementing it right away; you can always come back and add it later!
Here’s the code for our TemplateVisualization, which just prints the data:
class TemplateVisualization(BaseVisualization):
    """The ``TemplateVisualization`` class is a :class:`~flash.core.data.callbacks.BaseVisualization` that just
    prints the data.

    If you want to provide a visualization with your task, you can override these hooks.
    """

    def show_load_sample(
        self,
        samples: List[Any],
        running_stage: RunningStage,
        limit_nb_samples: int = None,
        figsize: Tuple[int, int] = (6.4, 4.8),
    ):
        print(samples)

    def show_per_sample_transform(
        self,
        samples: List[Any],
        running_stage: RunningStage,
        limit_nb_samples: int = None,
        figsize: Tuple[int, int] = (6.4, 4.8),
    ):
        print(samples)
We can configure our custom visualization in the TemplateData using configure_data_fetcher() like this:
@staticmethod
def configure_data_fetcher(*args, **kwargs) -> BaseDataFetcher:
    """We can, *optionally*, provide a data visualization callback using the ``configure_data_fetcher``
    method."""
    return TemplateVisualization(*args, **kwargs)
OutputTransform¶
OutputTransform contains any transforms that need to be applied after the model.
You may want to use it for: converting tokens back into text, applying an inverse normalization to an output image, resizing a generated image back to the size of the input, etc.
As an example, here’s the SemanticSegmentationOutputTransform, which resizes the predictions and the input back to the original image size stored in the sample metadata:
class SemanticSegmentationOutputTransform(OutputTransform):
    def per_sample_transform(self, sample: Any) -> Any:
        resize = T.Resize(sample[DataKeys.METADATA]["size"], interpolation=InterpolationMode.NEAREST)
        sample[DataKeys.PREDS] = resize(sample[DataKeys.PREDS])
        sample[DataKeys.INPUT] = resize(sample[DataKeys.INPUT])
        return super().per_sample_transform(sample)
In your Input or InputTransform, you can add metadata to the sample using the METADATA key.
Your OutputTransform can then use this metadata in its transforms.
You should use this approach if your postprocessing depends on the state of the input before the InputTransform transforms.
For example, if you want to resize the predictions to the original size of the inputs, you should add the original image size to the METADATA.
Here’s an example from the ImageInput:
@requires("image")
def load_sample(self, sample: Dict[str, Any]) -> Dict[str, Any]:
    w, h = sample[DataKeys.INPUT].size  # W x H
    if DataKeys.METADATA not in sample:
        sample[DataKeys.METADATA] = {}
    sample[DataKeys.METADATA].update(
        {
            "size": (h, w),
            "height": h,
            "width": w,
        }
    )
    return sample
The METADATA can now be referenced in your OutputTransform.
For example, here’s the code for the per_sample_transform method of the SemanticSegmentationOutputTransform:
def per_sample_transform(self, sample: Any) -> Any:
    resize = T.Resize(sample[DataKeys.METADATA]["size"], interpolation=InterpolationMode.NEAREST)
    sample[DataKeys.PREDS] = resize(sample[DataKeys.PREDS])
    sample[DataKeys.INPUT] = resize(sample[DataKeys.INPUT])
    return super().per_sample_transform(sample)
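The metadata round trip can be sketched without any framework: the loader records the original size, a transform shrinks the input, and the output step uses the recorded size to resize the prediction back. Everything here (load_sample_sketch, output_transform_sketch, resize_nearest, the string keys, and nested lists standing in for images) is a hypothetical illustration, not the Flash or torchvision implementation.

```python
from typing import Any, Dict, List

Image = List[List[int]]


def resize_nearest(image: Image, height: int, width: int) -> Image:
    """Nearest-neighbour resize of a 2D nested list."""
    src_h, src_w = len(image), len(image[0])
    return [
        [image[r * src_h // height][c * src_w // width] for c in range(width)]
        for r in range(height)
    ]


def load_sample_sketch(sample: Dict[str, Any]) -> Dict[str, Any]:
    # Record the original (height, width) before any transform changes it.
    image = sample["input"]
    sample.setdefault("metadata", {})["size"] = (len(image), len(image[0]))
    return sample


def output_transform_sketch(sample: Dict[str, Any]) -> Dict[str, Any]:
    # Resize the prediction back to the size stored in the metadata.
    height, width = sample["metadata"]["size"]
    sample["preds"] = resize_nearest(sample["preds"], height, width)
    return sample


sample = load_sample_sketch({"input": [[0, 0, 1, 1]] * 4})  # 4 x 4 image
sample["input"] = resize_nearest(sample["input"], 2, 2)     # input transform shrinks it
sample["preds"] = [[0, 1], [0, 1]]                          # pretend model output, 2 x 2
sample = output_transform_sketch(sample)
```

The size has to be captured in the loader, because by the time the output step runs, the input itself no longer has its original dimensions.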
Now that you’ve got some data, it’s time to add some backbones for your task!