Multi-label Text Classification¶
The Task¶
Multi-label classification is the task of assigning any number of labels from a fixed set to each data point, which can be in any modality (text in this case).
Multi-label text classification is supported by the TextClassifier via the multi_label argument.
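In the multi-label setting the model scores each label independently (typically with a per-label sigmoid rather than a softmax over all classes), and every label whose score clears a threshold is assigned. A minimal sketch of that decision rule, using made-up logits and the conventional 0.5 threshold (not real model output):

```python
import math

def sigmoid(x: float) -> float:
    """Map a raw logit to a probability in (0, 1)."""
    return 1.0 / (1.0 + math.exp(-x))

labels = ["toxic", "severe_toxic", "obscene", "threat", "insult", "identity_hate"]

# Hypothetical per-label logits for one comment (for illustration only).
logits = [2.1, -1.3, 0.8, -3.0, 1.5, -2.2]

# Unlike single-label classification (argmax over a softmax), each label is
# accepted or rejected independently, so a comment can receive zero, one,
# or several labels.
predicted = [name for name, z in zip(labels, logits) if sigmoid(z) > 0.5]
print(predicted)  # ['toxic', 'obscene', 'insult']
```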
Example¶
Let’s look at the task of classifying comment toxicity. The data we will use in this example is from the Kaggle Toxic Comment Classification Challenge by Jigsaw: https://www.kaggle.com/c/jigsaw-toxic-comment-classification-challenge. The data is stored in CSV files with this structure:
"id","comment_text","toxic","severe_toxic","obscene","threat","insult","identity_hate"
"0000997932d777bf","...",0,0,0,0,0,0
"0002bcb3da6cb337","...",1,1,1,0,1,0
"0005c987bdfc9d4b","...",1,0,0,0,0,0
...
Once we’ve downloaded the data using download_data(), we create the TextClassificationData.
We select a pre-trained backbone to use for our TextClassifier and finetune on the toxic comments data. The backbone can be any BERT classification model from HuggingFace/transformers.
Note
When changing the backbone, make sure you pass in the same backbone to the TextClassifier and the TextClassificationData!
Next, we use the trained TextClassifier for inference.
Finally, we save the model.
Here’s the full example:
import flash
import torch
from flash.core.data.utils import download_data
from flash.text import TextClassificationData, TextClassifier
# 1. Create the DataModule
# Data from the Kaggle Toxic Comment Classification Challenge:
# https://www.kaggle.com/c/jigsaw-toxic-comment-classification-challenge
download_data("https://pl-flash-data.s3.amazonaws.com/jigsaw_toxic_comments.zip", "./data")
datamodule = TextClassificationData.from_csv(
    "comment_text",
    ["toxic", "severe_toxic", "obscene", "threat", "insult", "identity_hate"],
    train_file="data/jigsaw_toxic_comments/train.csv",
    val_split=0.1,
    batch_size=4,
)
# 2. Build the task
model = TextClassifier(
    backbone="unitary/toxic-bert",
    labels=datamodule.labels,
    multi_label=datamodule.multi_label,
)
# 3. Create the trainer and finetune the model
trainer = flash.Trainer(max_epochs=1, gpus=torch.cuda.device_count())
trainer.finetune(model, datamodule=datamodule, strategy="freeze")
# 4. Generate predictions for a few comments!
datamodule = TextClassificationData.from_lists(
    predict_data=[
        "No, he is an arrogant, self serving, immature idiot. Get it right.",
        "U SUCK HANNAH MONTANA",
        "Would you care to vote? Thx.",
    ],
    batch_size=4,
)
predictions = trainer.predict(model, datamodule=datamodule, output="labels")
print(predictions)
# 5. Save the model!
trainer.save_checkpoint("text_classification_multi_label_model.pt")
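With output="labels", predictions arrive as lists of label names per sample (nested per batch). A sketch of the kind of thresholding that produces such output, using hypothetical per-label probabilities rather than real model scores:

```python
labels = ["toxic", "severe_toxic", "obscene", "threat", "insult", "identity_hate"]

# Hypothetical per-label probabilities for the three example comments
# (illustrative values, not actual TextClassifier output).
probs = [
    [0.93, 0.08, 0.61, 0.02, 0.77, 0.04],
    [0.85, 0.12, 0.30, 0.01, 0.44, 0.03],
    [0.05, 0.01, 0.02, 0.00, 0.03, 0.01],
]

def decode(sample_probs, threshold=0.5):
    # Every label above the threshold is kept; a clean comment
    # simply gets an empty list.
    return [name for name, p in zip(labels, sample_probs) if p > threshold]

print([decode(p) for p in probs])
# [['toxic', 'obscene', 'insult'], ['toxic'], []]
```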
To learn how to view the available backbones / heads for this task, see Backbones and Heads.
Flash Zero¶
The multi-label text classifier can be used directly from the command line with zero code using Flash Zero. You can run the above example with:
flash text_classification from_toxic
To view the configuration options, and the options for running the text classifier with your own data, use:
flash text_classification --help
Serving¶
The TextClassifier is servable. For more information, see Text Classification.