Optional Extras
Organize your transforms in input_transform.py
It can be useful to define your InputTransform in an input_transform.py file.
Here’s an example from image/classification/input_transform.py:
@dataclass
class ImageClassificationInputTransform(InputTransform):

    image_size: Tuple[int, int] = (196, 196)
    mean: Union[float, Tuple[float, float, float]] = (0.485, 0.456, 0.406)
    std: Union[float, Tuple[float, float, float]] = (0.229, 0.224, 0.225)

    def per_sample_transform(self):
        return T.Compose(
            [
                ApplyToKeys(
                    DataKeys.INPUT,
                    T.Compose([T.ToTensor(), T.Resize(self.image_size), T.Normalize(self.mean, self.std)]),
                ),
                ApplyToKeys(DataKeys.TARGET, torch.as_tensor),
            ]
        )

    def train_per_sample_transform(self):
        return T.Compose(
            [
                ApplyToKeys(
                    DataKeys.INPUT,
                    T.Compose(
                        [
                            T.ToTensor(),
                            T.Resize(self.image_size),
                            T.Normalize(self.mean, self.std),
                            T.RandomHorizontalFlip(),
                        ]
                    ),
                ),
                ApplyToKeys(DataKeys.TARGET, torch.as_tensor),
            ]
        )
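The example above applies one chain of transforms to the input and another to the target of each sample. The following is a minimal, dependency-free sketch of that key-wise pattern, not the Flash implementation: the helpers apply_to_key, compose, scale and normalize are hypothetical names introduced here for illustration, and plain lists stand in for image tensors.

```python
from typing import Any, Callable, Dict, List

Sample = Dict[str, Any]


def apply_to_key(key: str, fn: Callable[[Any], Any]) -> Callable[[Sample], Sample]:
    """Hypothetical helper: apply `fn` to sample[key], leaving other keys untouched."""

    def wrapped(sample: Sample) -> Sample:
        updated = dict(sample)  # shallow copy so the caller's sample is not mutated
        updated[key] = fn(sample[key])
        return updated

    return wrapped


def compose(steps: List[Callable[[Any], Any]]) -> Callable[[Any], Any]:
    """Hypothetical helper: chain transforms so they run in list order."""

    def run(value: Any) -> Any:
        for step in steps:
            value = step(value)
        return value

    return run


def scale(pixels: List[int]) -> List[float]:
    """Map 8-bit pixel values to the range [0, 1]."""
    return [p / 255 for p in pixels]


def normalize(pixels: List[float], mean: float = 0.5, std: float = 0.5) -> List[float]:
    """Subtract the mean and divide by the standard deviation."""
    return [(p - mean) / std for p in pixels]


# One chain for the input, a separate conversion for the target.
per_sample = compose(
    [
        apply_to_key("input", compose([scale, normalize])),
        apply_to_key("target", int),
    ]
)

sample = {"input": [0, 255], "target": "3"}
result = per_sample(sample)
print(result)
```

Keeping the input and target chains separate is what allows a training-only augmentation, such as the random flip above, to be added to the input chain without touching how targets are converted.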
Add outputs to your Task
We recommend that you do most of the heavy lifting in the OutputTransform.
Specifically, it should include any formatting and transforms that should always be applied to the predictions.
If you want to support different use cases that require different prediction formats, you should add some Output implementations in an output.py file.
Some good examples are in flash/core/classification.py.
Here’s the ClassesOutput Output:
@CLASSIFICATION_OUTPUTS(name="classes")
class ClassesOutput(PredsClassificationOutput):
    """A :class:`.Output` which applies an argmax to the model outputs (either logits or probabilities) and
    converts to a list.

    Args:
        multi_label: If true, treats outputs as multi label logits.
        threshold: The threshold to use for multi_label classification.
    """

    def __init__(self, multi_label: bool = False, threshold: float = 0.5):
        super().__init__(multi_label)

        self.threshold = threshold

    def transform(self, sample: Any) -> Union[int, List[int]]:
        sample = super().transform(sample)
        if self.multi_label:
            one_hot = (sample.sigmoid() > self.threshold).int().tolist()
            result = []
            for index, value in enumerate(one_hot):
                if value == 1:
                    result.append(index)
            return result
        return torch.argmax(sample, -1).tolist()
Alternatively, here’s the LogitsOutput Output:
@CLASSIFICATION_OUTPUTS(name="logits")
class LogitsOutput(PredsClassificationOutput):
    """A :class:`.Output` which simply converts the model outputs (assumed to be logits) to a list."""

    def transform(self, sample: Any) -> Any:
        return super().transform(sample).tolist()
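To make the difference between the two prediction formats concrete, here is a small sketch in plain Python that mirrors the argmax and threshold logic shown above and registers each format under a name. It is an illustration only: the OUTPUTS dictionary, the register decorator and both functions are hypothetical stand-ins, not the Flash registry or the Flash Output classes, and lists of floats stand in for tensors.

```python
import math
from typing import Callable, Dict, List, Union

# Hypothetical name-to-function registry, standing in for a registry of outputs.
OUTPUTS: Dict[str, Callable] = {}


def register(name: str) -> Callable[[Callable], Callable]:
    """Hypothetical decorator: store the decorated function under `name`."""

    def decorator(fn: Callable) -> Callable:
        OUTPUTS[name] = fn
        return fn

    return decorator


@register("logits")
def logits_output(logits: List[float]) -> List[float]:
    """Return the raw model outputs as a plain list."""
    return list(logits)


@register("classes")
def classes_output(
    logits: List[float], multi_label: bool = False, threshold: float = 0.5
) -> Union[int, List[int]]:
    """Return the predicted class index, or all indices above the threshold."""
    if multi_label:
        # Sigmoid each logit independently, then keep indices above the threshold.
        probabilities = [1 / (1 + math.exp(-x)) for x in logits]
        return [i for i, p in enumerate(probabilities) if p > threshold]
    # Single-label case: index of the largest logit (argmax).
    return max(range(len(logits)), key=lambda i: logits[i])


logits = [0.2, 2.0, -1.0]
print(OUTPUTS["classes"](logits))
print(OUTPUTS["classes"](logits, multi_label=True))
print(OUTPUTS["logits"](logits))
```

The same model outputs yield a single index, a list of indices, or the raw values depending only on which format is selected, which is why format-specific logic fits better in separate outputs than in the transform that always runs.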
Take a look at Predictions (inference) to learn more.
Once you’ve added any optional extras, it’s time to create some examples showing your task in action!