# CLIP-Vision-BERT Multilingual Pre-trained Model
CLIP-Vision-BERT is pre-trained on translated [Conceptual-12M](https://github.com/google-research-datasets/conceptual-12m) image-text pairs using a masked language modeling (MLM) objective. The 10M cleaned image-text pairs are translated with the [mBART-50 one-to-many model](https://huggingface.co/facebook/mbart-large-50-one-to-many-mmt) into 2.5M examples each in English, French, German and Spanish. This model is based on VisualBERT, which was introduced in
[this paper](https://arxiv.org/abs/1908.03557) and first released in
[this repository](https://github.com/uclanlp/visualbert). We trained the CLIP-Vision-BERT model during the community week hosted by Hugging Face 🤗 using JAX/Flax.
This checkpoint is pre-trained for 70k steps.
## Model description
CLIP-Vision-BERT is a modified BERT model that takes visual embeddings from the CLIP-Vision transformer and concatenates them with BERT's textual embeddings before passing them to BERT's self-attention layers. This allows deep cross-modal interaction between the two modalities.
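Conceptually, the fusion is a concatenation along the sequence axis. Below is a minimal, illustrative sketch; the function name and shapes (e.g. 50 visual positions) are assumptions and do not reproduce the internals of `FlaxCLIPVisionBertForMaskedLM`:

```python
import numpy as np

def fuse_embeddings(text_embeds, visual_embeds):
    """Concatenate textual and visual embeddings along the sequence axis so that
    BERT self-attention can attend across both modalities.

    text_embeds:   (batch, text_len, hidden)
    visual_embeds: (batch, num_visual, hidden), already projected to BERT's hidden size
    """
    return np.concatenate([text_embeds, visual_embeds], axis=1)

# Example with assumed shapes: 78 text positions + 50 visual positions = 128 total.
fused = fuse_embeddings(np.zeros((1, 78, 768)), np.zeros((1, 50, 768)))
print(fused.shape)  # (1, 128, 768)
```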
## Intended uses & limitations❗️
You can use the raw model for masked language modeling, but it's mostly intended to be fine-tuned on a downstream task.
Note that this model is primarily aimed at being fine-tuned on tasks such as visuo-linguistic sequence classification or visual question answering. We fine-tuned this model on a multilingual, translated version of the visual question answering task - [VQA v2](https://visualqa.org/challenge.html). Since Conceptual-12M is a dataset scraped from the internet, it contains some biases which will also affect all fine-tuned versions of this model.
### How to use❓
You can use this model directly for masked language modeling. You will need to clone the modeling code from [here](https://github.com/gchhablani/multilingual-vqa). An example of usage is shown below:
```python
>>> from torchvision.io import read_image
>>> import numpy as np
>>> import os
>>> from transformers import CLIPProcessor, BertTokenizerFast
>>> from model.flax_clip_vision_bert.modeling_clip_vision_bert import FlaxCLIPVisionBertForMaskedLM
>>> image_path = os.path.join('images/val2014', os.listdir('images/val2014')[0])
>>> img = read_image(image_path)
>>> clip_processor = CLIPProcessor.from_pretrained('openai/clip-vit-base-patch32')
ftfy or spacy is not installed using BERT BasicTokenizer instead of ftfy.
>>> clip_outputs = clip_processor(images=img)
>>> clip_outputs['pixel_values'][0] = clip_outputs['pixel_values'][0].transpose(1,2,0) # The model expects channel-last images, so transpose (C, H, W) -> (H, W, C)
>>> tokenizer = BertTokenizerFast.from_pretrained('bert-base-multilingual-uncased')
>>> model = FlaxCLIPVisionBertForMaskedLM.from_pretrained('flax-community/clip-vision-bert-cc12m-70k')
>>> text = "Three teddy [MASK] in a showcase."
>>> tokens = tokenizer([text], return_tensors="np")
>>> pixel_values = np.concatenate([clip_outputs['pixel_values']])
>>> outputs = model(pixel_values=pixel_values, **tokens)
>>> indices = np.where(tokens['input_ids']==tokenizer.mask_token_id)
>>> preds = outputs.logits[indices][0]
>>> sorted_indices = np.argsort(preds)[::-1] # Token indices sorted by descending score
>>> top_5_indices = sorted_indices[:5]
>>> top_5_tokens = tokenizer.convert_ids_to_tokens(top_5_indices)
>>> top_5_scores = preds[top_5_indices]
>>> print(dict(zip(top_5_tokens, top_5_scores)))
{'bears': 19.400345, 'bear': 17.866995, 'animals': 14.453735, 'dogs': 14.427426, 'girls': 14.097499}
```
## Training data 🏋🏻‍♂️
The CLIP-Vision-BERT model was pre-trained on a translated version of the Conceptual-12m dataset in four languages using mBART-50: English, French, German and Spanish, with 2.5M image-text pairs in each.
The dataset captions and image URLs can be downloaded from [flax-community/conceptual-12m-mbart-50-multilingual](https://huggingface.co/datasets/flax-community/conceptual-12m-mbart-50-multilingual).
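For reference, a single caption can be translated with the mBART-50 one-to-many checkpoint roughly as follows. This is a sketch based on the public usage of `facebook/mbart-large-50-one-to-many-mmt`; the actual translation pipeline used to build the dataset may differ:

```python
from transformers import MBartForConditionalGeneration, MBart50TokenizerFast

model = MBartForConditionalGeneration.from_pretrained("facebook/mbart-large-50-one-to-many-mmt")
tokenizer = MBart50TokenizerFast.from_pretrained("facebook/mbart-large-50-one-to-many-mmt", src_lang="en_XX")

caption_en = "Three teddy bears in a showcase."
encoded = tokenizer(caption_en, return_tensors="pt")

# Translate the English caption to French by forcing the target language token.
generated = model.generate(**encoded, forced_bos_token_id=tokenizer.lang_code_to_id["fr_XX"])
print(tokenizer.batch_decode(generated, skip_special_tokens=True))
```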
## Data Cleaning 🧹
Though the original dataset contains 12M image-text pairs, many of the URLs are now invalid, and in some cases the images are corrupt or broken. We removed such examples, which left us with approximately 10M image-text pairs.
**Splits**
We used 99% of the 10M examples as the train set and the remaining ~100K examples as our validation set.
## Training procedure 👨🏻‍💻
### Preprocessing
The texts are lowercased and tokenized using WordPiece with a shared vocabulary of approximately 110,000 tokens. The beginning of a new document is marked with `[CLS]` and the end of one with `[SEP]`.
The details of the masking procedure for each sentence are the following (a sketch follows the list):
- 15% of the tokens are masked.
- In 80% of the cases, the masked tokens are replaced by `[MASK]`.
- In 10% of the cases, the masked tokens are replaced by a random token (different) from the one they replace.
- In the 10% remaining cases, the masked tokens are left as is.
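A minimal sketch of this 80/10/10 scheme in NumPy (illustrative only; the actual data collator used during pre-training may differ, e.g. it would exclude special tokens from masking):

```python
import numpy as np

def mask_tokens(input_ids, mask_token_id, vocab_size, rng, mlm_prob=0.15):
    """BERT-style MLM masking: 15% of tokens selected; 80% -> [MASK], 10% -> random, 10% unchanged."""
    input_ids = input_ids.copy()
    labels = np.full_like(input_ids, -100)  # positions with -100 are ignored by the MLM loss

    selected = rng.random(input_ids.shape) < mlm_prob
    labels[selected] = input_ids[selected]

    to_mask = selected & (rng.random(input_ids.shape) < 0.8)               # 80% of selected -> [MASK]
    to_random = selected & ~to_mask & (rng.random(input_ids.shape) < 0.5)  # half of the rest -> random token
    input_ids[to_mask] = mask_token_id
    input_ids[to_random] = rng.integers(0, vocab_size, size=int(to_random.sum()))
    # the remaining ~10% of the selected tokens are left unchanged
    return input_ids, labels

rng = np.random.default_rng(0)
ids = np.array([[101, 2023, 2003, 1037, 7279, 102]])
masked_ids, labels = mask_tokens(ids, mask_token_id=103, vocab_size=110_000, rng=rng)
```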
The visual embeddings are taken from the CLIP-Vision model and combined with the textual embeddings inside the BERT embedding layer. The padding embeddings are placed between the textual and visual embeddings, i.e. in the middle of the sequence. Here is an example of what the embeddings look like:
```
[CLS Emb] [Textual Embs] [SEP Emb] [Pad Embs] [Visual Embs]
```
A total length of 128 tokens, including the visual embeddings, is used. The texts are truncated or padded accordingly.
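As a rough illustration of this layout, here is a hypothetical helper (the real embedding layer assembles this internally; the 50 visual positions are an assumption):

```python
import numpy as np

MAX_LENGTH = 128  # total positions, visual embeddings included

def layout_embeddings(cls_emb, text_embs, sep_emb, pad_emb, visual_embs):
    """Arrange embeddings as [CLS] + text + [SEP] + padding + visual, MAX_LENGTH positions in total."""
    num_pad = MAX_LENGTH - (1 + text_embs.shape[0] + 1 + visual_embs.shape[0])
    assert num_pad >= 0, "truncate the text so that text + visual fit in MAX_LENGTH positions"
    pad_embs = np.tile(pad_emb[None, :], (num_pad, 1))
    return np.concatenate([cls_emb[None, :], text_embs, sep_emb[None, :], pad_embs, visual_embs], axis=0)

# Example: 20 text tokens + 50 visual embeddings leaves 56 padding positions.
out = layout_embeddings(np.zeros(768), np.zeros((20, 768)),
                        np.zeros(768), np.zeros(768), np.zeros((50, 768)))
print(out.shape)  # (128, 768)
```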
### Pretraining
The checkpoint was trained on a Google Cloud Engine TPU v3-8 machine (335 GB of RAM, 1000 GB of disk, 96 CPU cores, **8 TPU v3 cores**) for 70k steps with a per-device batch size of 64 and a maximum sequence length of 128. The optimizer used is Adafactor with a learning rate of 1e-4, learning-rate warmup for 1,000 steps, and linear decay of the learning rate afterwards.
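A hedged sketch of this optimizer setup written with optax (the actual training script may differ in details such as weight decay or Adafactor-specific settings):

```python
import optax

total_steps = 70_000
warmup_steps = 1_000
peak_lr = 1e-4

# Linear warmup to the peak learning rate, then linear decay to zero.
warmup_fn = optax.linear_schedule(init_value=0.0, end_value=peak_lr, transition_steps=warmup_steps)
decay_fn = optax.linear_schedule(init_value=peak_lr, end_value=0.0,
                                 transition_steps=total_steps - warmup_steps)
lr_schedule = optax.join_schedules(schedules=[warmup_fn, decay_fn], boundaries=[warmup_steps])

optimizer = optax.adafactor(learning_rate=lr_schedule)
```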
We tracked experiments using TensorBoard. Here is the link to the main dashboard: [CLIP Vision BERT CC12M Pre-training Dashboard](https://huggingface.co/flax-community/multilingual-vqa-pt-ckpts/tensorboard)
#### **Pretraining Results 📊**
At this checkpoint, the model reached an **eval accuracy of 67.85%**, with a **train loss of 1.756** and an **eval loss of 1.706**.
## Team Members
- Gunjan Chhablani [@gchhablani](https://hf.co/gchhablani)
- Bhavitvya Malik [@bhavitvyamalik](https://hf.co/bhavitvyamalik)
## Acknowledgements
We thank [Nilakshan Kunananthaseelan](https://huggingface.co/knilakshan20) for helping us whenever he could get a chance. We also thank [Abheesht Sharma](https://huggingface.co/abheesht) for helping in the discussions in the initial phases. [Luke Melas](https://github.com/lukemelas) helped us get the CC-12M data on our TPU-VMs and we are very grateful to him.
This project would not be possible without the help of [Patrick](https://huggingface.co/patrickvonplaten) and [Suraj](https://huggingface.co/valhalla) who met with us frequently and helped review our approach and guided us throughout the project.
Huge thanks to Hugging Face 🤗 & the Google JAX/Flax team for such a wonderful community week, for answering our queries on the Slack channel, and for providing us with the TPU-VMs.
<img src="https://pbs.twimg.com/media/E443fPjX0AY1BsR.jpg:large">