Decompose, Map, and Interpret any ViT!

Vision transformers (ViTs) are now the go-to architecture for vision-based foundation models, but they can be challenging to interpret and may exhibit unexpected behaviors. Gandelsman et al. were able to interpret CLIP-ViT components using text, but how do we interpret arbitrary ViTs, which may have different architectures (SWIN, MaxViT, DINO, DINOv2) and may be trained with different pretraining objectives (ImageNet classification, self-supervised learning)?

We introduce a three-step procedure to solve this problem (a code sketch of the pipeline follows the list):

  1. RepDecompose: Automatically decompose the final image representation into contributions from different model components.
  2. CompAlign: Linearly map these contributions to CLIP space so they can be interpreted using the CLIP text encoder.
  3. CompAttribute: Rank components by their importance with respect to specific concepts, computed as the variance of the projection of their contributions onto the concept embeddings.
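
Below is a minimal sketch of steps 2 and 3, assuming step 1 (RepDecompose) has already produced a dictionary of per-component contributions; the helper names, the linear map `W_align`, and the tensor shapes are illustrative assumptions rather than the exact implementation from the paper.

```python
import torch

def comp_align(contributions: dict[str, torch.Tensor],
               W_align: torch.Tensor) -> dict[str, torch.Tensor]:
    """CompAlign (sketch): linearly map each component's contribution to CLIP space.

    contributions: {component_name: (N, d_model)} from RepDecompose over N images.
    W_align: (d_model, d_clip) linear map aligning the ViT representation with CLIP.
    """
    return {name: c @ W_align for name, c in contributions.items()}

def comp_attribute(clip_contributions: dict[str, torch.Tensor],
                   concept_embedding: torch.Tensor) -> list[tuple[str, float]]:
    """CompAttribute (sketch): rank components by the variance of the projection
    of their contributions onto a CLIP text embedding of the concept."""
    concept = concept_embedding / concept_embedding.norm()
    scores = {name: (c @ concept).var().item()   # variance across images
              for name, c in clip_contributions.items()}
    # Higher variance => the component carries more information about the concept.
    return sorted(scores.items(), key=lambda kv: kv[1], reverse=True)
```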


Retrieve images with reference to text or another image

  • Model components that are responsible for encoding a particular property can be used to retrieve images that are close to a given probe image with respect to that property!
  • We sum the relevant component contributions (before the projection onto CLIP space) to obtain a property-specific representation of the image, which can then be used to retrieve images sharing that property (see the sketch below).
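
A hedged sketch of this retrieval step, assuming `contributions` holds the pre-projection contribution of every component for each gallery image and `top_components` is the ranking produced by CompAttribute; the helper names are illustrative.

```python
import torch
import torch.nn.functional as F

def property_representation(contributions: dict[str, torch.Tensor],
                            top_components: list[str]) -> torch.Tensor:
    """Sum the contributions (before CLIP projection) of the components
    that encode the target property; returns (N, d_model) for N images."""
    return sum(contributions[name] for name in top_components)

def retrieve(probe_repr: torch.Tensor, gallery_reprs: torch.Tensor, k: int = 5) -> torch.Tensor:
    """Return indices of the k gallery images closest to the probe image
    with respect to the chosen property (cosine similarity)."""
    sims = F.cosine_similarity(probe_repr.unsqueeze(0), gallery_reprs, dim=-1)
    return sims.topk(k).indices
```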


Mitigate spurious correlations

  • The Waterbirds dataset has a spurious correlation between foreground and background attributes: each image contains a bird (foreground) on a background scene, and the task is to classify the bird as a 'waterbird' or a 'landbird'. However, the background scene is highly correlated with the bird species.
  • We can improve model performance on Waterbirds in a zero-shot manner simply by mean-ablating the top-10 model components related to “location” (see the sketch after the table)!
  • There is a significant increase in the worst group accuracy for all models, accompanied by an increase in the average group accuracy as well.
Model name    Worst group accuracy    Average group accuracy
DeiT          0.733 → 0.815           0.874 → 0.913
CLIP          0.507 → 0.744           0.727 → 0.790
DINO          0.800 → 0.911           0.900 → 0.938
DINOv2        0.967 → 0.978           0.983 → 0.986
SWIN          0.834 → 0.871           0.927 → 0.944
MaxViT        0.777 → 0.814           0.875 → 0.887

Worst group accuracy and average group accuracy on the Waterbirds dataset before and after the intervention.
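
A minimal sketch of the intervention, assuming per-image component contributions are available and `components_to_ablate` is the top-10 list returned by CompAttribute for the concept "location"; replacing a component's contribution with its mean over a reference set is the mean-ablation meant here, and the names are illustrative.

```python
import torch

def mean_ablate(contributions: dict[str, torch.Tensor],
                components_to_ablate: set[str]) -> dict[str, torch.Tensor]:
    """Replace the selected components' contributions with their dataset mean."""
    return {
        name: c.mean(dim=0, keepdim=True).expand_as(c) if name in components_to_ablate else c
        for name, c in contributions.items()   # c: (N, d_model)
    }

def recompose(contributions: dict[str, torch.Tensor]) -> torch.Tensor:
    """Recover the (ablated) image representations by summing component contributions."""
    return sum(contributions.values())

# The ablated representation simply replaces the original one in the downstream
# classifier; no retraining is involved, which is why the intervention is zero-shot.
```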


Visualizing token importance heatmaps (and segmenting images)

  • Each component contribution can be further broken down into token-wise contributions. These can be used to visualize the importance of each token in the final representation with respect to a given property and/or component.
  • We can also use this to segment images in a zero-shot manner, similar to other saliency methods such as GradCAM.
  • We outperform competitive saliency-method baselines on segmenting ImageNet classes (see the table below; Chefer et al.'s code does not support MaxViT or SWIN). A sketch of the heatmap computation follows the table.

Algorithm       DeiT                     DINO                     MaxViT                   SWIN
                pixAcc  mIoU    mAP      pixAcc  mIoU    mAP      pixAcc  mIoU    mAP      pixAcc  mIoU    mAP
Chefer et al.   0.7307  0.4785  0.7870   0.7309  0.4541  0.8080   -       -       -        -       -       -
GradCAM         0.6533  0.4625  0.7129   0.7045  0.4309  0.7481   0.4732  0.1705  0.4243   0.5973  0.2360  0.5365
Decompose       0.7719  0.5291  0.8305   0.7577  0.4863  0.8111   0.7163  0.4237  0.7237   0.7136  0.4338  0.7620
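
A sketch of the heatmap and zero-shot segmentation step, assuming `token_contributions` stacks the token-wise contributions of the chosen components into a (num_tokens, d) tensor and `concept` is a unit-norm direction in the same space; thresholding at the mean is an illustrative choice rather than the exact procedure.

```python
import torch

def token_heatmap(token_contributions: torch.Tensor,
                  concept: torch.Tensor,
                  grid_hw: tuple[int, int]) -> torch.Tensor:
    """Project each token's contribution onto a concept direction and
    reshape the scores into an (h, w) patch-grid heatmap."""
    scores = token_contributions @ concept      # (num_tokens,)
    h, w = grid_hw                              # e.g. (14, 14) for a 224px ViT-B/16
    return scores.reshape(h, w)

def segment(heatmap: torch.Tensor) -> torch.Tensor:
    """Binarize the heatmap into a foreground mask (mean threshold, illustrative)."""
    return (heatmap > heatmap.mean()).float()
```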



Read the paper for more details!

BibTeX


@inproceedings{balasubramanian2024decomposing,
  title={Decomposing and Interpreting Image Representations via Text in ViTs Beyond {CLIP}},
  author={Sriram Balasubramanian and Samyadeep Basu and Soheil Feizi},
  booktitle={The Thirty-eighth Annual Conference on Neural Information Processing Systems},
  year={2024},
  url={https://openreview.net/forum?id=Vhh7ONtfvV}
}