A comprehensive framework for knowledge distillation of the Segment Anything Model (SAM) to create lightweight, efficient versions suitable for resource-constrained environments.
This project implements a knowledge distillation framework for the Segment Anything Model (SAM) to create efficient variants like Tiny MobileSAM. The framework supports various distillation strategies, multiple datasets, and provides tools for fine-tuning and evaluation.
```bash
git clone git@github.com:gintmr/MaskGuide.git

# Install dependencies
bash install.sh
```

Edit the configuration in `Distill_Tiny_MSAM.py` or use the YAML config:
```bash
# Set paths to your datasets
--train_data_IMC /path/to/your/data
--train_anno_IMC /path/to/your/annotations
```

```bash
python Distill_Tiny_MSAM.py \
    --T_model vit_t \
    --S_model tiny_msam \
    --T_checkpoint_path /path/to/teacher/checkpoint \
    --S_checkpoint_path /path/to/student/checkpoint \
    --batch_size 8 \
    --epochs 20 \
    --learning_rate 5.0e-4
```

Supported distillation targets:

- `Img_Encoder`: Image Encoder distillation
- `Mask_Decoder`: Mask Decoder distillation
- `Prompt_Encoder`: Prompt Encoder distillation
Distillation modes:

- `only_distill`: Only perform distillation
- `add_distill`: Add distillation to regular training
- `mask&unmask`: Combined mask and unmask distillation
Key configuration options:

- `distill`: Distillation mode (`"mask&unmask_v1"`, `"ori"`)
- `INFERENCE_MODE`: Inference mode (`"test"`, `"train"`)
- `MODEL_MODE`: Model mode (`"test"`)
- `test_prompts`: Prompt type (`"bbox"`, `"point"`)
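The image-encoder branch of the distillation can be sketched as a feature-matching loss between teacher and student. This is a minimal illustration only; the function name and tensor shapes are assumptions, and the framework's actual losses live in `OceanSegmentationLoss`:

```python
import torch
import torch.nn.functional as F

def encoder_distill_loss(teacher_feats: torch.Tensor,
                         student_feats: torch.Tensor) -> torch.Tensor:
    """MSE between teacher and student image-encoder feature maps.

    Illustrative sketch only; names and shapes are assumptions,
    not the framework's actual API.
    """
    # Upsample the student features if the lightweight encoder
    # produces a smaller spatial grid than the ViT teacher.
    if student_feats.shape[-2:] != teacher_feats.shape[-2:]:
        student_feats = F.interpolate(
            student_feats, size=teacher_feats.shape[-2:],
            mode="bilinear", align_corners=False)
    # The teacher is frozen: detach so no gradients flow into it.
    return F.mse_loss(student_feats, teacher_feats.detach())
```

In `add_distill` mode this term would simply be added, with a weight, to the regular segmentation loss.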
- MIMC: Marine Image Multi-classification
- UIIS: Underwater Image Segmentation
- COCO: Common Objects in Context
- VOC: PASCAL Visual Object Classes
Datasets must be in COCO format with the following structure:
```json
{
    "images": [...],
    "categories": [...],
    "annotations": [...]
}
```

Training parameters can be configured through:
- Command-line arguments in `Distill_Tiny_MSAM.py`
- YAML configurations in `Tools_configs/`
- Environment variables
Key parameters include:
- `--batch_size`: Training batch size
- `--image_size`: Input image size (default: 1024)
- `--learning_rate`: Learning rate (default: 5.0e-4)
- `--epochs`: Number of training epochs
- `--num_points`: Number of random points for supervision
- `--length`: Number of masks to process
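A YAML file in `Tools_configs/` can mirror the same parameters. The fragment below is a hypothetical sketch; the key names and the `num_points`/`length` values are illustrative, so check the shipped configs for the exact schema:

```yaml
# Hypothetical config sketch -- key names are illustrative
T_model: vit_t
S_model: tiny_msam
batch_size: 8
image_size: 1024
learning_rate: 5.0e-4
epochs: 20
num_points: 5        # random points for supervision (example value)
length: 10           # masks to process (example value)
```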
Set the GPU devices using environment variables:
```bash
export CUDA_VISIBLE_DEVICES='0,1,2,3'
export NCCL_P2P_DISABLE='1'
```

The framework supports automatic mixed precision for faster training:
```python
from torch.cuda.amp import autocast

with autocast():
    # Forward pass runs in mixed precision
    loss = model(batch)
```

The `OceanSegmentationLoss` class provides custom loss calculations:
- Dice loss
- Focal loss
- IoU loss
- Distillation losses
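The loss terms above are typically combined as a weighted sum. A minimal sketch, assuming per-pixel probability masks and hypothetical weights; the real `OceanSegmentationLoss` may differ in both signatures and weighting:

```python
import torch
import torch.nn.functional as F

def dice_loss(pred, target, eps=1e-6):
    # pred/target: (N, H, W) probabilities and binary masks
    inter = (pred * target).sum(dim=(-2, -1))
    total = pred.sum(dim=(-2, -1)) + target.sum(dim=(-2, -1))
    return (1.0 - (2 * inter + eps) / (total + eps)).mean()

def focal_loss(pred, target, alpha=0.25, gamma=2.0):
    bce = F.binary_cross_entropy(pred, target, reduction="none")
    p_t = pred * target + (1 - pred) * (1 - target)  # prob. of true class
    return (alpha * (1 - p_t) ** gamma * bce).mean()

def iou_loss(pred, target, eps=1e-6):
    inter = (pred * target).sum(dim=(-2, -1))
    union = pred.sum(dim=(-2, -1)) + target.sum(dim=(-2, -1)) - inter
    return (1.0 - (inter + eps) / (union + eps)).mean()

def combined_seg_loss(pred, target, w_dice=1.0, w_focal=1.0, w_iou=1.0):
    # Weighted sum of the segmentation terms; the distillation terms
    # (teacher vs. student outputs) would be added on top the same way.
    return (w_dice * dice_loss(pred, target)
            + w_focal * focal_loss(pred, target)
            + w_iou * iou_loss(pred, target))
```

A perfect prediction drives all three terms to zero, while each term penalizes a different failure mode (region overlap, hard pixels, boundary quality).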
Run evaluation on trained models:
```bash
python Tools_metrics/eval.py --checkpoint_path /path/to/model
```