[PAKDD 2026 Oral] Official implementation of MITS (Mutual Information Tree Search), a test-time scaling framework for enhancing tree search reasoning in Large Language Models through information-theoretic principles.
MITS introduces an effective scoring function based on Pointwise Mutual Information (PMI) that enables step-wise evaluation of reasoning paths and guides search tree expansion via beam search without expensive look-ahead simulations.
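The step-wise scoring and beam pruning described above can be sketched as follows. This is a minimal illustration, not the code in `model.py`. It makes three assumptions. First, the PMI contribution of a step s_i is log p(s_i | q, s_<i) - log p(s_i | s_<i). Second, the per-token log-probabilities have already been obtained from the probability model (`--prob_model`). Third, `--metric ave` and `--metric sum` correspond to the mean and the total of the step PMIs. The helper names `step_pmi`, `path_score`, and `beam_select` are hypothetical.

```python
from typing import Dict, List, Sequence


def step_pmi(logp_with_question: Sequence[float],
             logp_without_question: Sequence[float]) -> float:
    """PMI of one reasoning step: log p(step | q, prev) - log p(step | prev).

    Both inputs are per-token log-probabilities of the SAME step tokens,
    scored once with the question in the context and once without it.
    """
    if len(logp_with_question) != len(logp_without_question):
        raise ValueError("both scorings must cover the same step tokens")
    return sum(logp_with_question) - sum(logp_without_question)


def path_score(step_pmis: Sequence[float], metric: str = "ave") -> float:
    """Aggregate per-step PMI values into one path score (assumed semantics)."""
    if not step_pmis:
        return 0.0
    total = sum(step_pmis)
    if metric == "sum":
        return total
    if metric == "ave":
        return total / len(step_pmis)
    raise ValueError(f"unknown metric: {metric!r}")


def beam_select(paths: List[Dict], beam: int, metric: str = "ave") -> List[Dict]:
    """Keep the `beam` best partial paths by the scores of steps generated so far.

    Each path is a dict with a "pmis" list. No look-ahead rollouts are used.
    """
    ranked = sorted(paths, key=lambda p: path_score(p["pmis"], metric), reverse=True)
    return ranked[:beam]


if __name__ == "__main__":
    # The question raises the step's likelihood, so its PMI is positive (about 1.7).
    pmi = step_pmi([-0.5, -0.2], [-1.5, -0.9])
    frontier = [
        {"id": "a", "pmis": [1.0, 2.0]},
        {"id": "b", "pmis": [3.0]},
        {"id": "c", "pmis": [0.5, 0.5]},
    ]
    print(round(pmi, 3), [p["id"] for p in beam_select(frontier, beam=2)])
```

Averaging keeps scores comparable across paths of different depths, whereas summing favors or penalizes longer paths depending on the sign of the step PMIs.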
- PMI-Based Scoring: Employs Pointwise Mutual Information as a principled scoring metric for intermediate reasoning steps
- Dynamic Sampling Strategy: Allocates computation adaptively based on the uncertainty of the current reasoning step
- Weighted Average Voting: Selects the final answer robustly by weighted-average voting over candidate reasoning paths
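The README does not specify how step uncertainty is measured or how the dynamic-sampling controller is defined, so the sketch below is hypothetical. It assumes two things. Uncertainty is taken as the normalized entropy of a softmax over the current candidates' scores. `--ds_kp` then acts as a proportional gain that scales how many extra samples are drawn, from `--num_samples` up to the `--max_dynamic_samples` cap. Only those flag names come from this README. The functions `normalized_entropy` and `next_num_samples` are illustrative.

```python
import math
from typing import Sequence


def normalized_entropy(scores: Sequence[float]) -> float:
    """Entropy of a softmax over candidate scores, scaled to [0, 1].

    1.0 means the candidates are indistinguishable (high uncertainty);
    values near 0 mean one candidate clearly dominates.
    """
    if len(scores) < 2:
        return 0.0
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]  # subtract max for numerical stability
    z = sum(exps)
    probs = [e / z for e in exps]
    h = -sum(p * math.log(p) for p in probs if p > 0)
    return h / math.log(len(scores))


def next_num_samples(uncertainty: float,
                     num_samples: int = 3,
                     max_dynamic_samples: int = 6,
                     ds_kp: float = 1.0) -> int:
    """Proportional rule: extra samples grow with uncertainty, capped at the max."""
    headroom = max_dynamic_samples - num_samples
    n = num_samples + round(ds_kp * uncertainty * headroom)
    return max(1, min(max_dynamic_samples, n))
```

In a search loop, one would compute `normalized_entropy` over the scores of the current step's candidates and request `next_num_samples(...)` candidates for the next expansion. Confident steps then stay at the base sample count, and ambiguous ones receive more samples.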
Requirements:
- Python 3.11
- CUDA 12.x (for GPU acceleration)
# Clone the repository
git clone https://github.com/plusnli/MITS.git
cd MITS
# Create conda environment
conda create -n mits python=3.11
conda activate mits
# Install dependencies
pip install -r requirements.txt
Note: The main dependencies include:
- vllm>=0.8.3 - Efficient LLM inference
- torch>=2.6.0 - Deep learning framework
- transformers>=4.51.2 - Hugging Face Transformers
- datasets>=3.5.0 - Dataset loading and processing
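To confirm that an existing environment meets these minimum versions, a small standard-library check like the following may help. It is a sketch, not part of the repository. `check_requirements`, `is_older`, and `release_tuple` are hypothetical helpers, and the version parsing is deliberately simplified: it compares only the numeric release part, so local tags such as `+cu124` and pre-release tags such as `rc1` are ignored. For exact PEP 440 comparisons, `packaging.version.Version` is the more robust choice.

```python
import re
from importlib.metadata import PackageNotFoundError, version
from typing import Dict, Tuple

# Minimum versions listed in this README.
MIN_VERSIONS = {
    "vllm": "0.8.3",
    "torch": "2.6.0",
    "transformers": "4.51.2",
    "datasets": "3.5.0",
}


def release_tuple(v: str) -> Tuple[int, ...]:
    """Numeric release part of a version string, e.g. '2.6.0+cu124' -> (2, 6, 0)."""
    m = re.match(r"\d+(?:\.\d+)*", v)
    return tuple(int(x) for x in m.group(0).split(".")) if m else ()


def is_older(installed: str, minimum: str) -> bool:
    """True if `installed` is below `minimum`, padding so '2.6' equals '2.6.0'."""
    a, b = release_tuple(installed), release_tuple(minimum)
    width = max(len(a), len(b))
    a += (0,) * (width - len(a))
    b += (0,) * (width - len(b))
    return a < b


def check_requirements(minimums: Dict[str, str] = MIN_VERSIONS) -> Dict[str, str]:
    """Return {package: problem} for packages that are missing or too old."""
    problems = {}
    for name, minimum in minimums.items():
        try:
            installed = version(name)
        except PackageNotFoundError:
            problems[name] = "not installed"
            continue
        if is_older(installed, minimum):
            problems[name] = f"{installed} is older than {minimum}"
    return problems


if __name__ == "__main__":
    issues = check_requirements()
    for pkg, problem in issues.items():
        print(f"{pkg}: {problem}")
    if not issues:
        print("All listed dependencies satisfy the minimum versions.")
```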
Run MITS on a dataset with default settings:
python main.py \
--model Qwen2.5-7B \
--dataset gsm8k \
--start_idx 0 \
--end_idx 100
Enable dynamic sampling and beam search:
python main.py \
--model Qwen2.5-7B \
--prob_model Qwen2.5-3B \
--dataset strategyqa \
--num_samples 3 \
--max_depth 10 \
--dynamic_sampling \
--max_dynamic_samples 6 \
--beam 5 \
--metric ave \
--mi_topk 3 \
--start_idx 0 \
--end_idx 100
Evaluate existing results without running inference:
python main.py \
--model Qwen2.5-7B \
--dataset gsm8k \
--start_idx 0 \
--end_idx 100 \
--eval_only
Supported models:
- Llama: Llama-3.1-8B, Llama-3.2-{1B,3B}
- Qwen: Qwen2.5-{0.5B,1.5B,3B,7B,14B,32B}
- Mistral: Mistral-7B, Ministral-8B
- Phi: Phi-3.5-mini, Phi-4, Phi-4-mini
Supported datasets include:
- StrategyQA: Multi-hop reasoning
- ARC-Challenge: Science questions
- CommonsenseQA: Commonsense reasoning
Key arguments:
- --model: Model name (required)
- --prob_model: Probability model used for PMI scoring (optional; defaults to the main model)
- --dataset: Dataset name (required)
- --temperature: Sampling temperature (default: 0.85)
- --top_k: Top-k sampling (default: -1, i.e., all tokens)
- --top_p: Nucleus sampling threshold (default: 1.0)
- --num_samples: Number of candidate samples per step (default: 3)
- --max_depth: Maximum tree depth (default: 10)
- --max_new_tokens: Maximum new tokens per step (default: 512)
- --beam: Beam size for beam search (optional)
- --metric: Scoring metric, 'ave' or 'sum' (default: 'ave')
- --mi_topk: Number of top-k paths used for reweighting (default: 1)
- --dynamic_sampling: Enable dynamic sampling
- --max_dynamic_samples: Maximum samples per step in dynamic mode (default: 6)
- --ds_kp: Proportional gain of the dynamic-sampling controller (default: 1.0)
- --start_idx: Start example index (required)
- --end_idx: End example index (optional)
- --num_gpus: Number of GPUs (default: 1)
- --eval_only: Evaluate existing results without running inference
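The README describes final answer selection only as weighted average voting, with `--mi_topk` controlling how many top paths are used for reweighting. One plausible reading is sketched below. It is hypothetical, not the repository's formula. Each candidate answer has a vote share and a quality, the mean of its `mi_topk` best min-max-normalized path scores. The two are combined as a weighted average. The mixing weight `vote_weight` is an assumed parameter with no corresponding CLI flag.

```python
from collections import defaultdict
from typing import Dict, List, Sequence, Tuple


def weighted_vote(paths: Sequence[Tuple[str, float]],
                  mi_topk: int = 1,
                  vote_weight: float = 0.5) -> Tuple[str, Dict[str, float]]:
    """Pick a final answer from completed (answer, path_score) pairs.

    combined(a) = vote_weight * share(a) + (1 - vote_weight) * quality(a), where
    share(a) is the fraction of paths ending in `a` and quality(a) is the mean
    of the `mi_topk` highest min-max-normalized scores among those paths.
    """
    if not paths:
        raise ValueError("no completed paths to vote over")
    if mi_topk < 1:
        raise ValueError("mi_topk must be at least 1")
    scores = [s for _, s in paths]
    lo, hi = min(scores), max(scores)

    def normalize(s: float) -> float:
        # Map scores to [0, 1]; identical scores all count as 1.0.
        return 1.0 if hi == lo else (s - lo) / (hi - lo)

    by_answer: Dict[str, List[float]] = defaultdict(list)
    for answer, s in paths:
        by_answer[answer].append(normalize(s))

    combined: Dict[str, float] = {}
    for answer, normed in by_answer.items():
        top = sorted(normed, reverse=True)[:mi_topk]
        quality = sum(top) / len(top)
        share = len(normed) / len(paths)
        combined[answer] = vote_weight * share + (1 - vote_weight) * quality
    best = max(combined, key=combined.get)
    return best, combined
```

Consider the paths `[("A", 2.0), ("A", 0.0), ("B", 2.5)]`. The defaults pick "A", which has two votes and one strong path. With `mi_topk=2`, the weak second path lowers A's quality and "B" wins instead.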
Project structure:
MITS/
├── main.py # Main entry point
├── model.py # MITS and BeamMITS implementations
├── utils.py # Utility functions
├── data_utils/ # Dataset loaders
│ ├── strategyqa.py
│ ├── commonsenseqa.py
│ ├── arc_c.py
│ └── ...
├── figs/ # Figures and visualizations
└── logs/ # Results and logs
If you find this work useful, please cite:
@article{li2025mits,
title={MITS: Enhanced Tree Search Reasoning for LLMs via Pointwise Mutual Information},
author={Li, Jiaxi and Shi, Yucheng and Lu, Jin and Liu, Ninghao},
journal={arXiv preprint arXiv:2510.03632},
year={2025}
}
This project is licensed under the MIT License.
For questions or issues, please open an issue on GitHub or contact the authors using the contact details in the paper.
