MechGraph: Neuro-Symbolic Molecular Mechanism Prediction


🧪 Overview

MechGraph is a neuro-symbolic architecture designed to predict complex organic reaction mechanisms. It bridges the gap between molecular graph representations and large language models (LLMs) to generate step-by-step mechanistic explanations.

Standard LLMs trained on text representations of molecules (such as SMILES) often fail to capture the 2D/3D topological information that governs reactivity. MechGraph addresses this by using a Graph Neural Network (GNN) to encode molecular structures and a trainable Projector to map the resulting graph embeddings directly into the LLM's token embedding space.

✨ Key Features

  • 🔬 Multimodal Input: Processes 2D molecular graphs (via PyTorch Geometric) and text prompts simultaneously
  • 🧠 GIN Encoder: Graph Isomorphism Network extracts topology-aware molecular features
  • 🔗 Modality Projector: Learnable adapter aligns graph features with LLM token dimensions
  • 🤖 LLM Integration: Compatible with HuggingFace models (Llama-2, Mistral, TinyLlama, etc.)
  • 📊 Comprehensive Evaluation: Built-in metrics for accuracy, validity, BLEU-4, and more

📁 Project Structure

MechGraph/
├── mechgraph/                    # Main Python package
│   ├── __init__.py
│   ├── models/                   # Neural network components
│   │   ├── __init__.py
│   │   ├── graph_encoder.py      # GIN encoder for molecular graphs
│   │   └── mechgraph_model.py    # Main multimodal model
│   ├── data/                     # Data processing
│   │   ├── __init__.py
│   │   ├── processor.py          # SMILES to graph conversion
│   │   └── dataset.py            # Dataset loaders
│   ├── evaluation/               # Evaluation tools
│   │   ├── __init__.py
│   │   ├── metrics.py            # Accuracy, BLEU, validity metrics
│   │   └── logger.py             # Experiment logging
│   └── utils/                    # Utilities
│       ├── __init__.py
│       └── visualization.py      # Training plots
├── scripts/                      # Executable scripts
│   ├── train.py                  # Training script
│   ├── inference.py              # Run predictions
│   ├── evaluate.py               # Model evaluation
│   ├── save_model.py             # Save checkpoints
│   └── generate_tables.py        # Generate paper tables
├── configs/                      # Configuration files
│   └── default.yaml              # Default hyperparameters
├── data/                         # Data files
│   ├── pubchem-10m.txt           # PubChem molecules
│   └── pmechdb_data/             # PMechDB reaction data
├── tables/                       # Generated evaluation tables
├── notebooks/
│   └── MechGraph.ipynb           # Interactive notebook
├── requirements.txt              # Python dependencies
├── setup.py                      # Package installation
└── README.md                     # This file

🚀 Installation

Prerequisites

  • Python 3.9+
  • CUDA 11.8+ (recommended for GPU training)
  • 8GB+ GPU memory (for default LLM)

Quick Install

# Clone the repository
git clone https://github.com/whyujjwal/MechGraph.git
cd MechGraph

# Create virtual environment
conda create -n mechgraph python=3.10
conda activate mechgraph

# Install dependencies
pip install -r requirements.txt

# Install package in development mode
pip install -e .

PyTorch Geometric Installation

If you encounter issues with PyG, install it manually:

# For CUDA 11.8
pip install torch-scatter torch-sparse -f https://data.pyg.org/whl/torch-2.0.0+cu118.html
pip install torch-geometric

See the PyG installation guide for other CUDA versions.

LLM Access

For gated models (e.g., Llama-2), authenticate with HuggingFace:

huggingface-cli login

🏗️ Architecture

MechGraph composes three components; the GraphEncoder and Projector are trained end-to-end while the LLM stays frozen by default:

1. GraphEncoder (GIN)

Input: Node features (atomic numbers) + Edge indices
      ↓
[GIN Conv] → [BatchNorm] → [ReLU] → [Dropout]  (× N layers)
      ↓
[Global Mean Pooling]
      ↓
Output: Graph embedding [B, Graph_Hidden_Dim]
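The GIN update plus pooling above can be sketched in plain PyTorch (a minimal illustration, not the package's graph_encoder.py; the single layer, dimensions, and eps=0.0 are placeholder choices):

```python
import torch
import torch.nn as nn

class MiniGINLayer(nn.Module):
    """One GIN convolution: h_v' = MLP((1 + eps) * h_v + sum of neighbor h_u)."""
    def __init__(self, dim, eps=0.0):
        super().__init__()
        self.eps = eps
        self.mlp = nn.Sequential(nn.Linear(dim, dim), nn.ReLU(), nn.Linear(dim, dim))

    def forward(self, x, edge_index):
        # x: [num_nodes, dim]; edge_index: [2, num_edges] as (src, dst) rows
        src, dst = edge_index
        agg = torch.zeros_like(x).index_add_(0, dst, x[src])  # sum over in-neighbors
        return self.mlp((1 + self.eps) * x + agg)

# Global mean pooling collapses node features into one graph embedding
x = torch.randn(5, 128)                        # 5 atoms, graph_hidden_dim = 128
edge_index = torch.tensor([[0, 1, 1, 2, 3, 4],
                           [1, 0, 2, 1, 4, 3]])
h = MiniGINLayer(128)(x, edge_index)
graph_embedding = h.mean(dim=0, keepdim=True)  # shape [1, 128] (batch of one)
```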

2. Projector

Input: Graph embedding [B, Graph_Hidden_Dim]
      ↓
[Linear Layer]
      ↓
Output: LLM-compatible token [B, 1, LLM_Hidden_Dim]
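This projection is essentially one linear layer plus a reshape. A minimal sketch (dimensions assumed: graph_hidden_dim = 128 from the default config; llm_hidden_dim = 2048 matches TinyLlama-1.1B, but check model.config.hidden_size for your LLM):

```python
import torch
import torch.nn as nn

class Projector(nn.Module):
    """Maps a pooled graph embedding to a single LLM-token-sized vector (sketch)."""
    def __init__(self, graph_hidden_dim=128, llm_hidden_dim=2048):
        super().__init__()
        self.proj = nn.Linear(graph_hidden_dim, llm_hidden_dim)

    def forward(self, graph_emb):
        # [B, graph_hidden_dim] -> [B, 1, llm_hidden_dim]: one "virtual token" per molecule
        return self.proj(graph_emb).unsqueeze(1)

token = Projector()(torch.randn(2, 128))
print(token.shape)  # torch.Size([2, 1, 2048])
```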

3. MechGraphModel

┌─────────────────────────────────────────────────────────────┐
│                                                             │
│  Molecule Graph ──→ [GraphEncoder] ──→ [Projector] ──┐      │
│                                                      │      │
│                                                      ↓      │
│  Text Prompt ──→ [Tokenizer] ──→ [Embeddings] ──→ [Concat]  │
│                                                      │      │
│                                                      ↓      │
│                                               [Frozen LLM]  │
│                                                      │      │
│                                                      ↓      │
│                                            Mechanism Output │
└─────────────────────────────────────────────────────────────┘
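The [Concat] step in the diagram amounts to prepending the projected graph token to the text token embeddings before the frozen LLM's forward pass (a sketch of the fusion idea, not the package's code; dimensions are illustrative):

```python
import torch

def build_multimodal_inputs(graph_token, text_embeds):
    # graph_token: [B, 1, D] from the Projector; text_embeds: [B, T, D] from the
    # LLM's embedding layer. The result would feed the LLM as inputs_embeds.
    return torch.cat([graph_token, text_embeds], dim=1)  # [B, 1 + T, D]

fused = build_multimodal_inputs(torch.randn(2, 1, 2048), torch.randn(2, 7, 2048))
print(fused.shape)  # torch.Size([2, 8, 2048])
```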

📖 Usage

Quick Start

from mechgraph import MechGraphModel, MoleculeProcessor

# Initialize
processor = MoleculeProcessor()
model = MechGraphModel(llm_path="TinyLlama/TinyLlama-1.1B-Chat-v1.0")

# Convert SMILES to graph
graph = processor.smiles_to_graph("c1ccccc1")  # Benzene

# Get graph embedding
embedding = model.get_graph_embedding(graph)

Training

# Stage 1: Alignment (molecule-description pairs)
python scripts/train.py --stage alignment --epochs 2

# Stage 2: Instruction tuning (reaction-mechanism pairs)
python scripts/train.py --stage instruction --epochs 5

# With custom settings
python scripts/train.py \
    --stage instruction \
    --llm_path "TinyLlama/TinyLlama-1.1B-Chat-v1.0" \
    --batch_size 4 \
    --learning_rate 1e-4 \
    --max_samples 1000

Inference

# Run inference on a SMILES string
python scripts/inference.py --smiles "CC(=O)O"

# With trained checkpoint
python scripts/inference.py \
    --smiles "c1ccccc1" \
    --checkpoint checkpoints/mechgraph_epoch_2.pt

Evaluation

# Evaluate model
python scripts/evaluate.py --checkpoint checkpoints/mechgraph_epoch_2.pt

# Generate paper tables
python scripts/generate_tables.py

📊 Data Preparation

Stage 1: Alignment

Uses molecule-description pairs. The included data/pubchem-10m.txt contains SMILES strings from PubChem.

Stage 2: Instruction Tuning

Requires the PMechDB dataset:

  1. Visit PMechDB Download Page
  2. Download the dataset
  3. Place CSV files in data/pmechdb_data/

Expected structure:

data/pmechdb_data/
├── manually_curated_train.csv
├── manually_curated_test_challenging.csv
├── combinatorial_train.csv
├── combinatorial_test.csv
└── combinatorial_all.csv

📈 Results

| Model             | Top-1 Accuracy | Mechanism Validity | BLEU-4 |
|-------------------|----------------|--------------------|--------|
| BioBERT           | 42.5%          | 68.2%              | 0.35   |
| GPT-4 (Zero-shot) | 58.1%          | 84.3%              | 0.61   |
| MolRAG            | 65.3%          | 78.9%              | 0.72   |
| MechGraph (Ours)  | 82.4%          | 91.5%              | 0.88   |

🔧 Configuration

Edit configs/default.yaml to customize:

model:
  llm_path: "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
  freeze_llm: true
  graph_hidden_dim: 128
  num_gnn_layers: 3

training:
  epochs: 2
  batch_size: 2
  learning_rate: 1.0e-4
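The config can be read with PyYAML (assumed to be among the dependencies) and the nested keys accessed as plain dicts; in practice you would pass the path to configs/default.yaml:

```python
import yaml  # PyYAML, an assumed dependency

# Parse the config shown above; in practice:
#   with open("configs/default.yaml") as f: cfg = yaml.safe_load(f)
cfg = yaml.safe_load("""
model:
  llm_path: "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
  freeze_llm: true
  graph_hidden_dim: 128
  num_gnn_layers: 3
training:
  epochs: 2
  batch_size: 2
  learning_rate: 1.0e-4
""")
print(cfg["model"]["graph_hidden_dim"])  # 128
```

Note that the learning rate is written as `1.0e-4` rather than `1e-4`: YAML's float resolver requires the decimal point and signed exponent, so the shorter form would load as a string.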

📚 API Reference

MechGraphModel

MechGraphModel(
    llm_path: str = "meta-llama/Llama-2-7b-hf",
    freeze_llm: bool = True,
    node_feature_dim: int = 1,
    graph_hidden_dim: int = 128,
    num_gnn_layers: int = 3
)

MoleculeProcessor

processor = MoleculeProcessor(add_hydrogens=True)
graph = processor.smiles_to_graph("CCO")  # Ethanol
graphs = processor.batch_smiles_to_graphs(["CCO", "CC(=O)O"])
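For reference, the tensors such a processed graph is expected to carry can be built by hand (a sketch assuming PyG-style conventions; the real processor derives these from the SMILES string via RDKit):

```python
import torch

# Hand-built graph for ethanol ("CCO"): node features are atomic numbers,
# matching node_feature_dim = 1. Each bond is stored in both directions,
# PyG-style, in a [2, num_edges] index tensor.
x = torch.tensor([[6.0], [6.0], [8.0]])     # C, C, O
edge_index = torch.tensor([[0, 1, 1, 2],
                           [1, 0, 2, 1]])   # C-C and C-O bonds
print(x.shape, edge_index.shape)  # torch.Size([3, 1]) torch.Size([2, 4])
```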

Evaluation Metrics

from mechgraph.evaluation import (
    calculate_top1_accuracy,
    calculate_bleu4,
    calculate_mechanism_validity,
    calculate_levenshtein_distance
)
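As a concrete example of one of these metrics, the Levenshtein distance between a predicted and a reference mechanism string is the classic dynamic-programming edit distance (a sketch of the idea; the package's implementation may differ):

```python
def levenshtein_distance(a: str, b: str) -> int:
    """Minimum number of single-character edits turning a into b."""
    prev = list(range(len(b) + 1))
    for i, ca in enumerate(a, start=1):
        curr = [i]
        for j, cb in enumerate(b, start=1):
            curr.append(min(prev[j] + 1,                 # deletion
                            curr[j - 1] + 1,             # insertion
                            prev[j - 1] + (ca != cb)))   # substitution
        prev = curr
    return prev[len(b)]

print(levenshtein_distance("CC(=O)O", "CC(=O)N"))  # 1 (acetic acid vs. acetamide SMILES)
```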

🤝 Contributing

Contributions are welcome! Please:

  1. Fork the repository
  2. Create a feature branch (git checkout -b feature/amazing-feature)
  3. Commit changes (git commit -m 'Add amazing feature')
  4. Push to branch (git push origin feature/amazing-feature)
  5. Open a Pull Request

📄 License

This project is licensed under the MIT License - see the LICENSE file for details.


📖 Citation

If you use MechGraph in your research, please cite:

@article{mechgraph2024,
  title={MechGraph: Neuro-Symbolic Molecular Mechanism Prediction},
  author={MechGraph Team},
  journal={arXiv preprint},
  year={2024}
}

🙏 Acknowledgments
