Formatting Classification Targets¶
This guide details the different target formats supported by classification tasks in Flash.
By default, the target format and any additional metadata (labels, num_classes, multi_label) will be inferred from your training data.
You can override this behaviour by passing your own TargetFormatter using the target_formatter argument.
Single Label¶
Classification targets are described as single label (DataModule.multi_label = False) if each data sample corresponds to a single class.
Class Indexes¶
Targets formatted as class indexes are represented by a single number, e.g. train_targets = [0, 1, 0].
No labels will be inferred.
The inferred num_classes is the maximum index plus one (we assume that class indexes are zero-based).
Here’s an example:
>>> from flash import Trainer
>>> from flash.image import ImageClassifier, ImageClassificationData
>>> datamodule = ImageClassificationData.from_files(
... train_files=["image_1.png", "image_2.png", "image_3.png"],
... train_targets=[0, 1, 0],
... transform_kwargs=dict(image_size=(128, 128)),
... batch_size=2,
... )
>>> datamodule.num_classes
2
>>> datamodule.labels is None
True
>>> datamodule.multi_label
False
Alternatively, you can provide a SingleNumericTargetFormatter to override the behaviour.
Here’s an example:
>>> from flash import Trainer
>>> from flash.image import ImageClassifier, ImageClassificationData
>>> from flash.core.data.utilities.classification import SingleNumericTargetFormatter
>>> datamodule = ImageClassificationData.from_files(
... train_files=["image_1.png", "image_2.png", "image_3.png"],
... train_targets=[0, 1, 0],
... target_formatter=SingleNumericTargetFormatter(labels=["dog", "cat", "rabbit"]),
... transform_kwargs=dict(image_size=(128, 128)),
... batch_size=2,
... )
>>> datamodule.num_classes
3
>>> datamodule.labels
['dog', 'cat', 'rabbit']
>>> datamodule.multi_label
False
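The two outcomes above (2 classes when inferred, 3 when labels are supplied) can be reproduced with a small plain-Python sketch. This is not Flash's implementation: num_classes_for_indexes is a hypothetical helper, and the rule that explicit labels fix the class count is an assumption read off the example output.

```python
from typing import Optional, Sequence


def num_classes_for_indexes(
    targets: Sequence[int], labels: Optional[Sequence[str]] = None
) -> int:
    """Hypothetical helper: class count for zero-based class-index targets."""
    if labels is not None:
        # Assumption: explicit labels define one class each, even if unseen.
        return len(labels)
    if not targets:
        raise ValueError("Cannot infer num_classes from an empty target list.")
    # Zero-based indexes: the largest index seen is the last class.
    return max(targets) + 1


print(num_classes_for_indexes([0, 1, 0]))  # 2
print(num_classes_for_indexes([0, 1, 0], labels=["dog", "cat", "rabbit"]))  # 3
```

Under this rule, a class that never appears in the training targets is only counted if it is named explicitly or has a lower index than some class that does appear.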
Labels¶
Targets formatted as labels are represented by a single string, e.g. train_targets = ["cat", "dog", "cat"].
The inferred labels will be the unique labels in the train targets sorted alphanumerically.
The inferred num_classes is the number of labels.
Here’s an example:
>>> from flash import Trainer
>>> from flash.image import ImageClassifier, ImageClassificationData
>>> datamodule = ImageClassificationData.from_files(
... train_files=["image_1.png", "image_2.png", "image_3.png"],
... train_targets=["cat", "dog", "cat"],
... transform_kwargs=dict(image_size=(128, 128)),
... batch_size=2,
... )
>>> datamodule.num_classes
2
>>> datamodule.labels
['cat', 'dog']
>>> datamodule.multi_label
False
Alternatively, you can provide a SingleLabelTargetFormatter to override the behaviour.
Here’s an example:
>>> from flash import Trainer
>>> from flash.image import ImageClassifier, ImageClassificationData
>>> from flash.core.data.utilities.classification import SingleLabelTargetFormatter
>>> datamodule = ImageClassificationData.from_files(
... train_files=["image_1.png", "image_2.png", "image_3.png"],
... train_targets=["cat", "dog", "cat"],
... target_formatter=SingleLabelTargetFormatter(labels=["dog", "cat", "rabbit"]),
... transform_kwargs=dict(image_size=(128, 128)),
... batch_size=2,
... )
>>> datamodule.num_classes
3
>>> datamodule.labels
['dog', 'cat', 'rabbit']
>>> datamodule.multi_label
False
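To see why the label order matters, the sketch below builds a label-to-index mapping for both cases. It is an illustration only: build_label_index is a hypothetical helper, the idea that a label's class index is its position in the list is an assumption, and Python's built-in sorted() stands in for the alphanumeric sort (the two may order labels containing digits differently).

```python
from typing import Dict, Optional, Sequence


def build_label_index(
    targets: Sequence[str], labels: Optional[Sequence[str]] = None
) -> Dict[str, int]:
    """Hypothetical helper: map each label string to a class index."""
    if labels is None:
        # Inferred case: unique labels in sorted order (stand-in for the
        # alphanumeric sort described in the text).
        labels = sorted(set(targets))
    # Assumption: the class index of a label is its position in the list.
    return {label: index for index, label in enumerate(labels)}


inferred = build_label_index(["cat", "dog", "cat"])
explicit = build_label_index(["cat", "dog", "cat"], labels=["dog", "cat", "rabbit"])
print(inferred)  # {'cat': 0, 'dog': 1}
print(explicit)  # {'dog': 0, 'cat': 1, 'rabbit': 2}
```

In this sketch "cat" moves from index 0 to index 1 once the labels are given explicitly, which is worth keeping in mind when the class order has to match an existing model.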
One-hot Binaries¶
Targets formatted as one-hot binaries are represented by a binary list with a single index (the target class index) set to 1, e.g. train_targets = [[1, 0], [0, 1], [1, 0]].
No labels will be inferred.
The inferred num_classes is the length of the binary list.
Here’s an example:
>>> from flash import Trainer
>>> from flash.image import ImageClassifier, ImageClassificationData
>>> datamodule = ImageClassificationData.from_files(
... train_files=["image_1.png", "image_2.png", "image_3.png"],
... train_targets=[[1, 0], [0, 1], [1, 0]],
... transform_kwargs=dict(image_size=(128, 128)),
... batch_size=2,
... )
>>> datamodule.num_classes
2
>>> datamodule.labels is None
True
>>> datamodule.multi_label
False
Alternatively, you can provide a SingleBinaryTargetFormatter to override the behaviour.
Here’s an example:
>>> from flash import Trainer
>>> from flash.image import ImageClassifier, ImageClassificationData
>>> from flash.core.data.utilities.classification import SingleBinaryTargetFormatter
>>> datamodule = ImageClassificationData.from_files(
... train_files=["image_1.png", "image_2.png", "image_3.png"],
... train_targets=[[1, 0], [0, 1], [1, 0]],
... target_formatter=SingleBinaryTargetFormatter(labels=["dog", "cat"]),
... transform_kwargs=dict(image_size=(128, 128)),
... batch_size=2,
... )
>>> datamodule.num_classes
2
>>> datamodule.labels
['dog', 'cat']
>>> datamodule.multi_label
False
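The relationship between one-hot binaries and class indexes can be sketched in plain Python as follows. The helper one_hot_to_index is hypothetical and is not part of Flash; it only illustrates that the position of the single 1 is the class index and that the list length gives the number of classes.

```python
from typing import Sequence


def one_hot_to_index(target: Sequence[int]) -> int:
    """Hypothetical helper: return the class index encoded by a one-hot list."""
    # A valid one-hot list contains only 0s and 1s, with exactly one 1.
    if any(value not in (0, 1) for value in target) or sum(target) != 1:
        raise ValueError(f"Not a one-hot target: {list(target)}")
    return list(target).index(1)


train_targets = [[1, 0], [0, 1], [1, 0]]
num_classes = len(train_targets[0])  # length of the binary list
indexes = [one_hot_to_index(target) for target in train_targets]
print(num_classes)  # 2
print(indexes)  # [0, 1, 0]
```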
Multi Label¶
Classification targets are described as multi label (DataModule.multi_label = True) if each data sample corresponds to zero or more (and perhaps many) classes.
Class Indexes¶
Targets formatted as multi label class indexes are represented by a list of class indexes, e.g. train_targets = [[0], [0, 1], [1, 2]].
No labels will be inferred.
The inferred num_classes is the maximum target value plus one (we assume that targets are zero-based).
Here’s an example:
>>> from flash import Trainer
>>> from flash.image import ImageClassifier, ImageClassificationData
>>> datamodule = ImageClassificationData.from_files(
... train_files=["image_1.png", "image_2.png", "image_3.png"],
... train_targets=[[0], [0, 1], [1, 2]],
... transform_kwargs=dict(image_size=(128, 128)),
... batch_size=2,
... )
>>> datamodule.num_classes
3
>>> datamodule.labels is None
True
>>> datamodule.multi_label
True
Alternatively, you can provide a MultiNumericTargetFormatter to override the behaviour.
Here’s an example:
>>> from flash import Trainer
>>> from flash.image import ImageClassifier, ImageClassificationData
>>> from flash.core.data.utilities.classification import MultiNumericTargetFormatter
>>> datamodule = ImageClassificationData.from_files(
... train_files=["image_1.png", "image_2.png", "image_3.png"],
... train_targets=[[0], [0, 1], [1, 2]],
... target_formatter=MultiNumericTargetFormatter(labels=["dog", "cat", "rabbit"]),
... transform_kwargs=dict(image_size=(128, 128)),
... batch_size=2,
... )
>>> datamodule.num_classes
3
>>> datamodule.labels
['dog', 'cat', 'rabbit']
>>> datamodule.multi_label
True
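As a rough illustration of the inference rule for multi label class indexes, the sketch below computes the class count and expands each sample into a multi-hot list. Both helpers are hypothetical names chosen for this example, and the multi-hot encoding is shown only to make the format concrete; it is not a description of how Flash stores targets internally.

```python
from typing import List, Sequence


def infer_num_classes(targets: Sequence[Sequence[int]]) -> int:
    """Hypothetical helper: maximum zero-based index over all samples, plus one."""
    flat = [index for sample in targets for index in sample]
    if not flat:
        raise ValueError("Cannot infer num_classes when no sample has a class.")
    return max(flat) + 1


def to_multi_hot(sample: Sequence[int], num_classes: int) -> List[int]:
    """Hypothetical helper: set a 1 at every class index present in the sample."""
    encoded = [0] * num_classes
    for index in sample:
        encoded[index] = 1
    return encoded


train_targets = [[0], [0, 1], [1, 2]]
num_classes = infer_num_classes(train_targets)
multi_hot = [to_multi_hot(sample, num_classes) for sample in train_targets]
print(num_classes)  # 3
print(multi_hot)  # [[1, 0, 0], [1, 1, 0], [0, 1, 1]]
```

A sample with no classes, written as an empty list, becomes an all-zero list in this encoding.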
Labels¶
Targets formatted as multi label labels are represented by a list of strings, e.g. train_targets = [["cat"], ["cat", "dog"], ["dog", "rabbit"]].
The inferred labels will be the unique labels in the train targets sorted alphanumerically.
The inferred num_classes is the number of labels.
Here’s an example:
>>> from flash import Trainer
>>> from flash.image import ImageClassifier, ImageClassificationData
>>> datamodule = ImageClassificationData.from_files(
... train_files=["image_1.png", "image_2.png", "image_3.png"],
... train_targets=[["cat"], ["cat", "dog"], ["dog", "rabbit"]],
... transform_kwargs=dict(image_size=(128, 128)),
... batch_size=2,
... )
>>> datamodule.num_classes
3
>>> datamodule.labels
['cat', 'dog', 'rabbit']
>>> datamodule.multi_label
True
Alternatively, you can provide a MultiLabelTargetFormatter to override the behaviour.
Here’s an example:
>>> from flash import Trainer
>>> from flash.image import ImageClassifier, ImageClassificationData
>>> from flash.core.data.utilities.classification import MultiLabelTargetFormatter
>>> datamodule = ImageClassificationData.from_files(
... train_files=["image_1.png", "image_2.png", "image_3.png"],
... train_targets=[["cat"], ["cat", "dog"], ["dog", "rabbit"]],
... target_formatter=MultiLabelTargetFormatter(labels=["dog", "cat", "rabbit"]),
... transform_kwargs=dict(image_size=(128, 128)),
... batch_size=2,
... )
>>> datamodule.num_classes
3
>>> datamodule.labels
['dog', 'cat', 'rabbit']
>>> datamodule.multi_label
True
Comma Delimited¶
Targets formatted as comma delimited multi label are given as comma delimited strings, e.g. train_targets = ["cat", "cat,dog", "dog,rabbit"].
The inferred labels will be the unique labels in the train targets sorted alphanumerically.
The inferred num_classes is the number of labels.
Here’s an example:
>>> from flash import Trainer
>>> from flash.image import ImageClassifier, ImageClassificationData
>>> datamodule = ImageClassificationData.from_files(
... train_files=["image_1.png", "image_2.png", "image_3.png"],
... train_targets=["cat", "cat,dog", "dog,rabbit"],
... transform_kwargs=dict(image_size=(128, 128)),
... batch_size=2,
... )
>>> datamodule.num_classes
3
>>> datamodule.labels
['cat', 'dog', 'rabbit']
>>> datamodule.multi_label
True
Alternatively, you can provide a CommaDelimitedMultiLabelTargetFormatter to override the behaviour.
Here’s an example:
>>> from flash import Trainer
>>> from flash.image import ImageClassifier, ImageClassificationData
>>> from flash.core.data.utilities.classification import CommaDelimitedMultiLabelTargetFormatter
>>> datamodule = ImageClassificationData.from_files(
... train_files=["image_1.png", "image_2.png", "image_3.png"],
... train_targets=["cat", "cat,dog", "dog,rabbit"],
... target_formatter=CommaDelimitedMultiLabelTargetFormatter(labels=["dog", "cat", "rabbit"]),
... transform_kwargs=dict(image_size=(128, 128)),
... batch_size=2,
... )
>>> datamodule.num_classes
3
>>> datamodule.labels
['dog', 'cat', 'rabbit']
>>> datamodule.multi_label
True
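The comma delimited format can be thought of as the list-of-strings format after splitting on commas. The following plain-Python sketch makes that concrete; split_comma_delimited is a hypothetical helper, and stripping whitespace around each label is an assumption of this sketch rather than documented Flash behaviour.

```python
from typing import List


def split_comma_delimited(target: str) -> List[str]:
    """Hypothetical helper: split one comma delimited target into labels."""
    # Assumption: surrounding whitespace is ignored and empty parts are dropped.
    return [part.strip() for part in target.split(",") if part.strip()]


train_targets = ["cat", "cat,dog", "dog,rabbit"]
split_targets = [split_comma_delimited(target) for target in train_targets]
# Unique labels in sorted order (stand-in for the alphanumeric sort).
labels = sorted({label for sample in split_targets for label in sample})
print(split_targets)  # [['cat'], ['cat', 'dog'], ['dog', 'rabbit']]
print(labels)  # ['cat', 'dog', 'rabbit']
print(len(labels))  # 3
```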
Space Delimited¶
Targets formatted as space delimited multi label are given as space delimited strings, e.g. train_targets = ["cat", "cat dog", "dog rabbit"].
The inferred labels will be the unique labels in the train targets sorted alphanumerically.
The inferred num_classes is the number of labels.
Here’s an example:
>>> from flash import Trainer
>>> from flash.image import ImageClassifier, ImageClassificationData
>>> datamodule = ImageClassificationData.from_files(
... train_files=["image_1.png", "image_2.png", "image_3.png"],
... train_targets=["cat", "cat dog", "dog rabbit"],
... transform_kwargs=dict(image_size=(128, 128)),
... batch_size=2,
... )
>>> datamodule.num_classes
3
>>> datamodule.labels
['cat', 'dog', 'rabbit']
>>> datamodule.multi_label
True
Alternatively, you can provide a SpaceDelimitedTargetFormatter to override the behaviour.
Here’s an example:
>>> from flash import Trainer
>>> from flash.image import ImageClassifier, ImageClassificationData
>>> from flash.core.data.utilities.classification import SpaceDelimitedTargetFormatter
>>> datamodule = ImageClassificationData.from_files(
... train_files=["image_1.png", "image_2.png", "image_3.png"],
... train_targets=["cat", "cat dog", "dog rabbit"],
... target_formatter=SpaceDelimitedTargetFormatter(labels=["dog", "cat", "rabbit"]),
... transform_kwargs=dict(image_size=(128, 128)),
... batch_size=2,
... )
>>> datamodule.num_classes
3
>>> datamodule.labels
['dog', 'cat', 'rabbit']
>>> datamodule.multi_label
True
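A minimal sketch of how space delimited targets could be encoded against an explicit label order is given below. The helper encode_space_delimited is hypothetical, and the assumption that position i of the encoded list corresponds to labels[i] is made for illustration only.

```python
from typing import List, Sequence


def encode_space_delimited(target: str, labels: Sequence[str]) -> List[int]:
    """Hypothetical helper: multi-hot encode one space delimited target."""
    # str.split() with no arguments splits on any run of whitespace.
    present = set(target.split())
    unknown = present - set(labels)
    if unknown:
        raise ValueError(f"Unknown labels: {sorted(unknown)}")
    # Assumption: position i of the result corresponds to labels[i].
    return [1 if label in present else 0 for label in labels]


labels = ["dog", "cat", "rabbit"]
train_targets = ["cat", "cat dog", "dog rabbit"]
encoded = [encode_space_delimited(target, labels) for target in train_targets]
print(encoded)  # [[0, 1, 0], [1, 1, 0], [1, 0, 1]]
```

Because the space is the separator, a label that itself contains a space cannot be expressed in this format; the comma delimited or list-of-strings formats are better suited to such labels.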
Multi-hot Binaries¶
Targets formatted as multi-hot binaries are represented by a binary list with zero or more indices (the target class indices) set to 1, e.g. train_targets = [[1, 0, 0], [1, 1, 0], [0, 1, 1]].
No labels will be inferred.
The inferred num_classes is the length of the binary list.
Here’s an example:
>>> from flash import Trainer
>>> from flash.image import ImageClassifier, ImageClassificationData
>>> datamodule = ImageClassificationData.from_files(
... train_files=["image_1.png", "image_2.png", "image_3.png"],
... train_targets=[[1, 0, 0], [1, 1, 0], [0, 1, 1]],
... transform_kwargs=dict(image_size=(128, 128)),
... batch_size=2,
... )
>>> datamodule.num_classes
3
>>> datamodule.labels is None
True
>>> datamodule.multi_label
True
Alternatively, you can provide a MultiBinaryTargetFormatter to override the behaviour.
Here’s an example:
>>> from flash import Trainer
>>> from flash.image import ImageClassifier, ImageClassificationData
>>> from flash.core.data.utilities.classification import MultiBinaryTargetFormatter
>>> datamodule = ImageClassificationData.from_files(
... train_files=["image_1.png", "image_2.png", "image_3.png"],
... train_targets=[[1, 0, 0], [1, 1, 0], [0, 1, 1]],
... target_formatter=MultiBinaryTargetFormatter(labels=["dog", "cat", "rabbit"]),
... transform_kwargs=dict(image_size=(128, 128)),
... batch_size=2,
... )
>>> datamodule.num_classes
3
>>> datamodule.labels
['dog', 'cat', 'rabbit']
>>> datamodule.multi_label
True
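When labels are supplied for multi-hot targets, each position in the binary list can be read as one named class. The sketch below decodes the example targets under that assumption; decode_multi_hot is a hypothetical helper and not a Flash function.

```python
from typing import List, Sequence


def decode_multi_hot(target: Sequence[int], labels: Sequence[str]) -> List[str]:
    """Hypothetical helper: list the label names whose position is set to 1."""
    if len(target) != len(labels):
        raise ValueError("Target length must equal the number of labels.")
    # Assumption: position i of the binary list corresponds to labels[i].
    return [label for flag, label in zip(target, labels) if flag == 1]


labels = ["dog", "cat", "rabbit"]
train_targets = [[1, 0, 0], [1, 1, 0], [0, 1, 1]]
decoded = [decode_multi_hot(target, labels) for target in train_targets]
print(decoded)  # [['dog'], ['dog', 'cat'], ['cat', 'rabbit']]
```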