lingeringlight/START

START: A Generalized State Space Model with Saliency-Driven Token-Aware Transformation [NeurIPS 2024]

Environment for Training

  • Python 3.10.13

    • conda create -n your_env_name python=3.10.13
  • torch 2.1.1 + cu118

    • pip install torch==2.1.1 torchvision==0.16.1 torchaudio==2.1.1 --index-url https://download.pytorch.org/whl/cu118
  • Requirements: vim_requirements.txt

    • pip install -r vim/vim_requirements.txt
  • Install causal_conv1d and mamba

    • pip install -e causal_conv1d (version 1.1.0 or later; an unquoted ">=1.1.0" would be interpreted by the shell as a redirection)
    • pip install -e mamba-1p1p1

DataSets

Please download the PACS dataset from here. Make sure you use the official train/val/test split from the PACS paper. Taking /data/DataSets/ as the save directory, for example:

images -> /data/DataSets/PACS/kfold/art_painting/dog/pic_001.jpg, ...
splits -> /data/DataSets/PACS/pacs_label/art_painting_crossval_kfold.txt, ...

Then set "data_root" to "/data/DataSets/" and "data" to "PACS" in "main_dg.py".

You can also set "data_root" and "data" directly in "ft-vmamba-t.sh" when training the model.
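Before training, it can save time to verify that the dataset is laid out as in the example above. The following is a small sanity-check sketch (not part of the repository); the four domain names follow the standard PACS split, and the paths mirror the example:

```python
import os

# The four standard PACS domains; paths follow the layout example above.
DOMAINS = ["art_painting", "cartoon", "photo", "sketch"]

def check_pacs_layout(data_root):
    """Return a list of expected PACS paths missing under data_root."""
    missing = []
    for d in DOMAINS:
        img_dir = os.path.join(data_root, "PACS", "kfold", d)
        split_file = os.path.join(
            data_root, "PACS", "pacs_label", f"{d}_crossval_kfold.txt"
        )
        for p in (img_dir, split_file):
            if not os.path.exists(p):
                missing.append(p)
    return missing
```

With the example setup, `check_pacs_layout("/data/DataSets/")` should return an empty list.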

Training

First, download the VMamba-T model pretrained on ImageNet from here and save it to /pretrained_model. To run START-M, run the following command. Note that the --data_root argument needs to be changed to match your folder.

bash scripts/START-M.sh

You can also train the START-X model by running:

bash scripts/START-X.sh

Evaluation

To evaluate the performance of the models, you can download our models trained on PACS (accuracy in %):

Method    Photo   Art     Cartoon   Sketch   Avg.
START-M   99.22   93.95   87.84     87.68    92.17
START-X   99.16   92.97   88.40     87.45    92.00
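Each column reports accuracy on the held-out target domain under the leave-one-domain-out protocol; Avg. is the mean over the four domains. A quick arithmetic check of the reported averages:

```python
# Per-domain accuracies (%) from the table above: Photo, Art, Cartoon, Sketch.
results = {
    "START-M": [99.22, 93.95, 87.84, 87.68],
    "START-X": [99.16, 92.97, 88.40, 87.45],
}

# The Avg. column is the plain mean of the four per-domain accuracies.
for name, accs in results.items():
    avg = sum(accs) / len(accs)
    print(f"{name}: {avg:.2f}")
```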

Please set --eval to 1, --target to the domain index, and --resume to the saved path of the downloaded model, e.g., /trained/model/path/photo/model.pt, in "scripts/test_model_performance.sh". Then you can directly run:

bash scripts/test_model_performance.sh

Alternatively, run the evaluation command directly:

CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node=4 --master_port=29000 --use_env ../main_dg.py \
--model vmamba_tiny \
--batch-size 32 \
--seed 0 \
--num_workers 16 \
--no_amp \
--data "PACS" \
--data_root [dataset_path] \
--target [domain_index, e.g., 0 for photo] \
--eval 1 \
--resume "/trained/model/path/photo/checkpoint.pth"
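To script evaluation over several targets, the launch command above can be assembled programmatically. A minimal sketch (the helper name is ours, and the only domain index confirmed by the source is 0 for photo; check main_dg.py for the full index order):

```python
def build_eval_cmd(data_root, target, resume, n_gpus=4, port=29000):
    """Assemble the distributed evaluation command shown above."""
    gpus = ",".join(str(i) for i in range(n_gpus))
    return (
        f"CUDA_VISIBLE_DEVICES={gpus} "
        f"python -m torch.distributed.launch --nproc_per_node={n_gpus} "
        f"--master_port={port} --use_env ../main_dg.py "
        "--model vmamba_tiny --batch-size 32 --seed 0 --num_workers 16 "
        "--no_amp --data PACS "
        f"--data_root {data_root} --target {target} --eval 1 --resume {resume}"
    )
```

For example, `build_eval_cmd("/data/DataSets/", 0, "/trained/model/path/photo/checkpoint.pth")` reproduces the command above for the photo domain.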

Citations

@inproceedings{guo2024start,
  title={START: A Generalized State Space Model with Saliency-Driven Token-Aware Transformation},
  author={Guo, Jintao and Qi, Lei and Shi, Yinghuan and Gao, Yang},
  booktitle={The Thirty-Eighth Annual Conference on Neural Information Processing Systems},
  year={2024}
}

Acknowledgement

Part of our code is derived from the following repositories:

  • VMamba: "Vmamba: Visual state space model", NeurIPS 2024
  • Vim: "Vision mamba: Efficient visual representation learning with bidirectional state space model", ICML 2024

We thank the authors for releasing their code. Please also consider citing their work.
