Warning
Multi-GPU training is not currently supported by the ImageEmbedder task.
Image Embedder
The Task
Image embedding encodes an image into a feature vector that can be used for downstream tasks such as clustering, similarity search, or classification.
The Flash ImageEmbedder can be trained with Self-Supervised Learning (SSL) to improve the quality of the embeddings it produces for your data.
The ImageEmbedder internally relies on VISSL.
You can read more about our integration with VISSL here: VISSL.
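To make the downstream use concrete, the following is a minimal sketch of similarity search over embedding vectors using cosine similarity. It is plain Python and does not call Flash; the helper names (`cosine_similarity`, `most_similar`) and the toy vectors are hypothetical stand-ins for embeddings produced by the task.

```python
import math
from typing import List, Sequence, Tuple


def cosine_similarity(a: Sequence[float], b: Sequence[float]) -> float:
    """Cosine of the angle between two equal-length vectors."""
    if len(a) != len(b):
        raise ValueError("vectors must have the same length")
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    if norm_a == 0.0 or norm_b == 0.0:
        raise ValueError("cosine similarity is undefined for a zero vector")
    return dot / (norm_a * norm_b)


def most_similar(
    query: Sequence[float], gallery: Sequence[Sequence[float]]
) -> List[Tuple[int, float]]:
    """Rank gallery vectors by similarity to the query, best match first."""
    scores = [(i, cosine_similarity(query, v)) for i, v in enumerate(gallery)]
    return sorted(scores, key=lambda item: item[1], reverse=True)


# Toy 3-dimensional "embeddings" (hypothetical values, for illustration only).
gallery = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.9, 0.1, 0.0]]
ranking = most_similar([1.0, 0.0, 0.0], gallery)
print([index for index, _ in ranking])  # gallery indices, closest first
```

Real embeddings have many more dimensions, and a numerical library would normally replace the pure-Python loops, but the ranking logic stays the same.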
Example
Let’s see how to configure a training strategy for the ImageEmbedder task.
First, we create an ImageClassificationData object using a Dataset from torchvision.
Next, we configure the ImageEmbedder task with training_strategy, backbone, head, and pretraining_transform.
Finally, we construct a Trainer and call fit().
Here’s the full example:
import torch
from torchvision.datasets import CIFAR10
import flash
from flash.core.data.utils import download_data
from flash.image import ImageClassificationData, ImageEmbedder
# 1. Download the data and prepare the datamodule
datamodule = ImageClassificationData.from_datasets(
    train_dataset=CIFAR10(".", download=True),
    batch_size=4,
)
# 2. Build the task
embedder = ImageEmbedder(
    backbone="vision_transformer",
    training_strategy="barlow_twins",
    head="barlow_twins_head",
    pretraining_transform="barlow_twins_transform",
    training_strategy_kwargs={"latent_embedding_dim": 128},
    pretraining_transform_kwargs={"size_crops": [32]},
)
# 3. Create the trainer and pre-train the encoder
trainer = flash.Trainer(max_epochs=1, gpus=torch.cuda.device_count())
trainer.fit(embedder, datamodule=datamodule)
# 4. Save the model!
trainer.save_checkpoint("image_embedder_model.pt")
# 5. Download the downstream prediction dataset and generate embeddings
download_data("https://pl-flash-data.s3.amazonaws.com/hymenoptera_data.zip", "data/")
datamodule = ImageClassificationData.from_files(
    predict_files=[
        "data/hymenoptera_data/predict/153783656_85f9c3ac70.jpg",
        "data/hymenoptera_data/predict/2039585088_c6f47c592e.jpg",
    ],
    batch_size=3,
)
embeddings = trainer.predict(embedder, datamodule=datamodule)
# list of embeddings for images sent to the predict function
print(embeddings)
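The warning at the top of this page states that multi-GPU training is not supported, while `gpus=torch.cuda.device_count()` requests every visible GPU. One way to stay on a single device is to cap the count before passing it to the trainer. This is only a sketch and not part of Flash: `single_device_count` is a hypothetical helper, and the device count is passed in as a plain integer so the snippet runs without a GPU.

```python
def single_device_count(available_gpus: int) -> int:
    """Return 1 if at least one GPU is available, otherwise 0 (CPU)."""
    if available_gpus < 0:
        raise ValueError("available_gpus must be non-negative")
    return min(available_gpus, 1)


# Assumed usage with the example above (requires torch and flash):
#   trainer = flash.Trainer(
#       max_epochs=1,
#       gpus=single_device_count(torch.cuda.device_count()),
#   )
print(single_device_count(4))
```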
To learn how to view the available backbones / heads for this task, see Backbones and Heads.
You can view the available training strategies with the available_training_strategies() method.
Note
The "dino" training strategy only supports single-GPU training with strategy="ddp".
The head and pretraining_transform arguments should match the choice of training_strategy following this table:
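| training_strategy | head | pretraining_transform |
|---|---|---|
| barlow_twins | barlow_twins_head | barlow_twins_transform |
|  |  |  |
|  |  |  |
|  |  |  |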
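A mismatch between these three arguments can be caught before the task is built. The sketch below assumes that names follow the pattern seen in the example on this page (`barlow_twins` with `barlow_twins_head` and `barlow_twins_transform`); that pattern is an assumption, so the table remains the authority for other strategies. Both helper functions are hypothetical and are not part of Flash.

```python
from typing import Dict, List


def expected_components(training_strategy: str) -> Dict[str, str]:
    """Derive head and transform names from a strategy name (ASSUMED pattern)."""
    return {
        "head": f"{training_strategy}_head",
        "pretraining_transform": f"{training_strategy}_transform",
    }


def check_embedder_config(
    training_strategy: str, head: str, pretraining_transform: str
) -> List[str]:
    """Return human-readable mismatches; an empty list means consistent."""
    expected = expected_components(training_strategy)
    actual = {"head": head, "pretraining_transform": pretraining_transform}
    return [
        f"{name}={actual[name]!r} does not match expected {expected[name]!r}"
        for name in expected
        if actual[name] != expected[name]
    ]


# The combination used in the example on this page passes the check.
problems = check_embedder_config(
    "barlow_twins", "barlow_twins_head", "barlow_twins_transform"
)
print(problems)
```

Running the check before constructing the ImageEmbedder surfaces a typo in one argument early, rather than partway through task setup.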