The Example
Now that you’ve implemented your task, it’s time to add an example showing how cool it is!
We usually provide one example in examples/.
You can base yours on our template.py example.
The example should:
- download the data (we’ll add the example to our CI later on, so choose a dataset small enough that it runs in reasonable time)
- load the data into a DataModule
- create an instance of the Task
- create a Trainer
- call finetune() or fit() to train your model
- generate predictions for a few examples
- save the checkpoint
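Because the example will eventually run in CI, it can help to time it locally before opening a PR. The helper below is a minimal sketch under our own assumptions: `time_example` is a hypothetical name (not part of Flash), and the 600-second default timeout is an arbitrary placeholder rather than the project's actual CI limit. It runs the script in a fresh interpreter and fails loudly if the script errors or takes too long.

```python
import subprocess
import sys
import time


def time_example(script_path: str, timeout_s: float = 600.0) -> float:
    """Run an example script in a fresh interpreter and return its wall-clock time in seconds.

    Hypothetical helper: timeout_s is an assumed budget, not Flash's real CI limit.
    Raises subprocess.CalledProcessError if the script exits with a non-zero status,
    and subprocess.TimeoutExpired if it runs longer than timeout_s.
    """
    start = time.perf_counter()
    # Use the current interpreter so the example sees the same installed packages.
    subprocess.run([sys.executable, script_path], check=True, timeout=timeout_s)
    return time.perf_counter() - start


if __name__ == "__main__":
    # Usage: python time_example.py examples/template.py
    elapsed = time_example(sys.argv[1])
    print(f"{sys.argv[1]} finished in {elapsed:.1f}s")
```

Running the script in a separate process (rather than importing it) mirrors how CI would execute the example, so import-time side effects and data downloads are counted too.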
For our template example, we don’t have a pretrained backbone, so we can just call fit() rather than finetune().
Here’s the full example (examples/template.py):
import flash
import numpy as np
import torch
from flash.template import TemplateData, TemplateSKLearnClassifier
from sklearn import datasets

# 1. Create the DataModule
datamodule = TemplateData.from_sklearn(
    train_bunch=datasets.load_iris(),
    val_split=0.1,
    batch_size=4,
)

# 2. Build the task
model = TemplateSKLearnClassifier(num_features=datamodule.num_features, num_classes=datamodule.num_classes)

# 3. Create the trainer and train the model
trainer = flash.Trainer(max_epochs=3, gpus=torch.cuda.device_count())
trainer.fit(model, datamodule=datamodule)

# 4. Classify a few examples
datamodule = TemplateData.from_numpy(
    predict_data=[
        np.array([4.9, 3.0, 1.4, 0.2]),
        np.array([6.9, 3.2, 5.7, 2.3]),
        np.array([7.2, 3.0, 5.8, 1.6]),
    ],
    batch_size=4,
)
predictions = trainer.predict(model, datamodule=datamodule, output="classes")
print(predictions)

# 5. Save the model!
trainer.save_checkpoint("template_model.pt")
We get this output:
['setosa', 'virginica', 'versicolor']
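The printed species names line up with class indices 0, 2 and 1 in the order that scikit-learn's `load_iris().target_names` uses (`setosa`, `versicolor`, `virginica`). The snippet below is only an illustration of that index-to-label mapping, assuming the labels follow that order; `indices_to_labels` is a hypothetical helper, not Flash's own output code.

```python
# Class names in the order of sklearn's load_iris().target_names.
IRIS_CLASSES = ["setosa", "versicolor", "virginica"]


def indices_to_labels(indices):
    """Map predicted class indices to their iris species names (illustrative only)."""
    return [IRIS_CLASSES[i] for i in indices]


print(indices_to_labels([0, 2, 1]))  # prints ['setosa', 'virginica', 'versicolor']
```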
Now that you’ve got an example showing your awesome task in action, it’s time to write some tests!