START: A Generalized State Space Model with Saliency-Driven Token-Aware Transformation [NeurIPS 2024]
-
Python 3.10.13
conda create -n your_env_name python=3.10.13
-
torch 2.1.1 + cu118
pip install torch==2.1.1 torchvision==0.16.1 torchaudio==2.1.1 --index-url https://download.pytorch.org/whl/cu118
-
Requirements: vim_requirements.txt
pip install -r vim/vim_requirements.txt
-
Install causal_conv1d and mamba:
pip install -e causal_conv1d>=1.1.0
pip install -e mamba-1p1p1
Please download the PACS dataset from here.
Make sure you use the official train/val/test split from the PACS paper.
Taking /data/DataSets/ as the save directory, for example:
images -> /data/DataSets/PACS/kfold/art_painting/dog/pic_001.jpg, ...
splits -> /data/DataSets/PACS/pacs_label/art_painting_crossval_kfold.txt, ...
Then set the "data_root" as "/data/DataSets/" and "data" as "PACS" in "main_dg.py".
You can directly set the "data_root" and "data" in "ft-vmamba-t.sh" for training the model.
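As a sanity check on the layout above, a small helper (hypothetical, not part of this repository) can verify that the expected kfold/ image folders and pacs_label/ split files exist under data_root; the domain names and the crossval split filename pattern are taken from the example paths above:

```python
from pathlib import Path

# The four PACS domains, matching the folder names under kfold/.
DOMAINS = ["photo", "art_painting", "cartoon", "sketch"]

def check_pacs_layout(data_root: str) -> list[str]:
    """Return a list of missing pieces of the expected PACS layout (empty if OK)."""
    root = Path(data_root) / "PACS"
    problems = []
    for domain in DOMAINS:
        # Images: /data_root/PACS/kfold/<domain>/<class>/<image>.jpg
        if not (root / "kfold" / domain).is_dir():
            problems.append(f"missing images dir: kfold/{domain}")
        # Splits: /data_root/PACS/pacs_label/<domain>_crossval_kfold.txt
        split = root / "pacs_label" / f"{domain}_crossval_kfold.txt"
        if not split.is_file():
            problems.append(f"missing split file: pacs_label/{split.name}")
    return problems

if __name__ == "__main__":
    for problem in check_pacs_layout("/data/DataSets/"):
        print(problem)
```

If the script prints nothing, the directory structure matches what main_dg.py expects from the example above.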
First, download the VMamba-T model pretrained on ImageNet from here and save it to /pretrained_model. To run START-M, run the following command; note that the --data_root argument needs to be changed to match your folder.
bash scripts/START-M.sh
You can also train the START-X model by running:
bash scripts/START-X.sh
To evaluate the models, you can download our models trained on PACS; their accuracies (%) are listed below:
| Methods | Photo | Art | Cartoon | Sketch | Avg. |
|---|---|---|---|---|---|
| START-M | 99.22 | 93.95 | 87.84 | 87.68 | 92.17 |
| START-X | 99.16 | 92.97 | 88.40 | 87.45 | 92.00 |
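The Avg. column is the unweighted mean of the four per-domain accuracies. A quick sanity check (added here for illustration, not part of the repository):

```python
# Per-domain PACS accuracies (%) from the table above: Photo, Art, Cartoon, Sketch.
results = {
    "START-M": [99.22, 93.95, 87.84, 87.68],
    "START-X": [99.16, 92.97, 88.40, 87.45],
}

for method, accs in results.items():
    avg = sum(accs) / len(accs)
    # Matches the Avg. column up to rounding to two decimals.
    print(f"{method}: Avg. = {avg:.2f}")
```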
In "scripts/test_model_performance.sh", please set --eval to 1, --target to the domain index, and --resume to the saved path of the downloaded model, e.g., /trained/model/path/photo/model.pt. Then you can directly run:
bash scripts/test_model_performance.sh
You can also run the following code:
CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node=4 --master_port=29000 --use_env ../main_dg.py \
--model vmamba_tiny \
--batch-size 32 \
--seed 0 \
--num_workers 16 \
--no_amp \
--data "PACS" \
--data_root [dataset_path] \
--target [domain_index, e.g., 0 for photo] \
--eval 1 \
--resume "/trained/model/path/photo/checkpoint.pth"
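Note that `torch.distributed.launch` is deprecated in recent PyTorch releases in favor of `torchrun`. With torch 2.1, an equivalent invocation might look like the following (a command sketch, untested here; `torchrun` always exports the rendezvous environment variables, so the `--use_env` flag is dropped, and all script flags are assumed unchanged):

```shell
CUDA_VISIBLE_DEVICES=0,1,2,3 torchrun --nproc_per_node=4 --master_port=29000 ../main_dg.py \
    --model vmamba_tiny \
    --batch-size 32 \
    --seed 0 \
    --num_workers 16 \
    --no_amp \
    --data "PACS" \
    --data_root [dataset_path] \
    --target [domain_index, e.g., 0 for photo] \
    --eval 1 \
    --resume "/trained/model/path/photo/checkpoint.pth"
```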
@inproceedings{guo2024start,
title={START: A Generalized State Space Model with Saliency-Driven Token-Aware Transformation},
author={Guo, Jintao and Qi, Lei and Shi, Yinghuan and Gao, Yang},
booktitle={The Thirty-Eighth Annual Conference on Neural Information Processing Systems},
year={2024}
}
Part of our code is derived from the following repositories:
- VMamba: "Vmamba: Visual state space model", NeurIPS 2024
- Vim: "Vision mamba: Efficient visual representation learning with bidirectional state space model", ICML 2024
We thank the authors for releasing their code. Please also consider citing their work.