ChemGPT is a transformer-based Large Language Model (LLM) designed to understand and generate chemical reaction data. It uses a GPT architecture trained from scratch on SMILES representations of chemical reactions and functional groups.
- Custom Tokenization: Uses a Byte-Pair Encoding (BPE) tokenizer specifically trained on chemical SMILES strings.
- GPT Architecture: Implements a standard decoder-only transformer with Multi-Head Flash Attention.
- Distributed Training: Supports Distributed Data Parallel (DDP) for efficient training across multiple GPUs.
- Checkpointing: Automatically saves model states, optimizer states, and loss history after every epoch.
- Live Monitoring: Real-time loss tracking and validation metrics via `tqdm`.
This project uses `uv` for dependency management.

- Clone the repository:

  ```bash
  git clone https://github.com/yourusername/ChemGPT.git
  cd ChemGPT
  ```

- Install dependencies:

  ```bash
  uv sync
  ```

  Alternatively, if not using `uv`:

  ```bash
  pip install torch tokenizers tqdm matplotlib
  ```
The model expects a pickle file located at `data/pretraining_data.pickle`. This file should contain a dictionary with two lists:

- `reactions`: List of SMILES strings representing reactions.
- `groups`: List of SMILES strings representing associated groups.

**Input Format:** The model processes data in the following format:

```
Reaction:"<Reaction_SMILES>" Group:"<Group_SMILES>"
```
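As a sketch, such a pickle file could be assembled like this (the SMILES strings below are placeholder examples, not project data):

```python
import os
import pickle

# Placeholder SMILES strings for illustration only.
data = {
    "reactions": ["CCO>>CC=O", "C=C.O>>CCO"],
    "groups": ["[OH]", "[OH]"],
}

# Both lists must have the same length: entry i of `groups`
# is the group associated with reaction i.
assert len(data["reactions"]) == len(data["groups"])

os.makedirs("data", exist_ok=True)
with open("data/pretraining_data.pickle", "wb") as f:
    pickle.dump(data, f)

# One training sample rendered in the expected input format:
sample = f'Reaction:"{data["reactions"][0]}" Group:"{data["groups"][0]}"'
print(sample)  # Reaction:"CCO>>CC=O" Group:"[OH]"
```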
If you need to retrain the tokenizer on new data:

```bash
uv run python src/train_tokenizer.py
```

This saves the tokenizer to `ChemBPETokenizer/tokenizer.json`.
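For intuition, BPE repeatedly merges the most frequent adjacent symbol pair into a single vocabulary token. A toy, character-level illustration of one merge step on SMILES strings (not the project's actual tokenizer, which `src/train_tokenizer.py` trains with the `tokenizers` library):

```python
from collections import Counter

def most_frequent_pair(sequences):
    """Count adjacent symbol pairs across all sequences and return the top one."""
    pairs = Counter()
    for seq in sequences:
        for a, b in zip(seq, seq[1:]):
            pairs[(a, b)] += 1
    return pairs.most_common(1)[0][0] if pairs else None

def merge_pair(sequences, pair):
    """Replace every occurrence of `pair` with a single merged symbol."""
    merged = []
    for seq in sequences:
        out, i = [], 0
        while i < len(seq):
            if i + 1 < len(seq) and (seq[i], seq[i + 1]) == pair:
                out.append(seq[i] + seq[i + 1])
                i += 2
            else:
                out.append(seq[i])
                i += 1
        merged.append(out)
    return merged

# Start from character-level symbols for a few toy SMILES strings.
corpus = [list(s) for s in ["CCO", "CC=O", "CCN"]]
pair = most_frequent_pair(corpus)  # ('C', 'C') appears in all three strings
corpus = merge_pair(corpus, pair)
print(corpus)  # [['CC', 'O'], ['CC', '=', 'O'], ['CC', 'N']]
```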
**Single GPU / CPU:**

```bash
uv run python src/pretrain.py
```

**Multi-GPU (DDP):** To utilize multiple GPUs via Distributed Data Parallel:

```bash
uv run python src/pretrain.py --ddp
```

Checkpoints are saved automatically to the `checkpoints/` directory:

- `checkpoint_epoch_N.pt`: State at the end of epoch N.
- `checkpoint_latest.pt`: Always contains the most recent state.
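The naming scheme above can be sketched as follows. This is a dependency-free illustration using `pickle`; the actual training script would save model and optimizer state dicts with `torch.save` instead:

```python
import pickle
import shutil
from pathlib import Path

def save_checkpoint(state, epoch, ckpt_dir="checkpoints"):
    """Write checkpoint_epoch_N.pt and refresh checkpoint_latest.pt.

    `state` stands in for the model state, optimizer state, and loss
    history; pickle is used here only to keep the sketch self-contained.
    """
    ckpt_dir = Path(ckpt_dir)
    ckpt_dir.mkdir(exist_ok=True)
    epoch_path = ckpt_dir / f"checkpoint_epoch_{epoch}.pt"
    with open(epoch_path, "wb") as f:
        pickle.dump(state, f)
    # checkpoint_latest.pt always mirrors the most recent epoch.
    shutil.copyfile(epoch_path, ckpt_dir / "checkpoint_latest.pt")
    return epoch_path

path = save_checkpoint({"epoch": 3, "loss_history": [2.1, 1.8, 1.6]}, epoch=3)
print(path)
```

Keeping a separate `checkpoint_latest.pt` means resuming never requires knowing the last completed epoch number.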
Run the test suite to ensure the context window and tokenizer are configured correctly:
```bash
uv run pytest tests/
```

Project structure:

```
ChemGPT/
├── ChemBPETokenizer/    # Trained tokenizer artifacts
├── checkpoints/         # Saved model checkpoints
├── data/                # Input data (pickle)
├── src/
│   ├── Attention.py     # Multi-Head Attention implementation
│   ├── GPT.py           # Main model architecture
│   ├── GPTDataloader.py # Dataset and Dataloader logic
│   ├── Tokenizer.py     # Tokenizer wrapper
│   ├── Transformer.py   # Transformer block implementation
│   ├── config.py        # Model configuration
│   └── pretrain.py      # Main training script (Single & DDP)
├── tests/               # Unit tests
├── pyproject.toml       # Dependency configuration
└── README.md
```
Based on the architecture and training loops from "Build a Large Language Model From Scratch" by Sebastian Raschka.