MITS: Enhanced Tree Search Reasoning for LLMs via Pointwise Mutual Information

[PAKDD 2026 Oral] Official implementation of MITS (Mutual Information Tree Search), a test-time scaling framework for enhancing tree search reasoning in Large Language Models through information-theoretic principles.

🔍 Overview

MITS introduces a scoring function based on Pointwise Mutual Information (PMI). The score evaluates reasoning paths step by step and guides search tree expansion via beam search, without expensive look-ahead simulations.
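As a rough illustration of the scoring idea, the sketch below computes a PMI-style score for one reasoning step from per-token log-probabilities. It uses the textbook definition PMI(q; s) = log p(s | q) - log p(s). The function names, the `length_normalize` switch, and the assumption that both sets of log-probabilities come from a language model are illustrative and not taken from the repository's code.

```python
from typing import Sequence


def sequence_logprob(token_logprobs: Sequence[float], length_normalize: bool = True) -> float:
    """Aggregate per-token log-probs into one sequence-level value.

    Hypothetical helper: sums the log-probs and, optionally, divides by length.
    """
    if not token_logprobs:
        raise ValueError("token_logprobs must not be empty")
    total = sum(token_logprobs)
    return total / len(token_logprobs) if length_normalize else total


def pmi_score(
    logprobs_with_question: Sequence[float],
    logprobs_without_question: Sequence[float],
    length_normalize: bool = True,
) -> float:
    """PMI-style score: log p(step | question) - log p(step).

    A positive value means the question makes the step more likely,
    i.e. the step is informative about this particular question.
    """
    conditional = sequence_logprob(logprobs_with_question, length_normalize)
    prior = sequence_logprob(logprobs_without_question, length_normalize)
    return conditional - prior
```

A step that the model finds equally likely with or without the question scores zero under this definition, which is why PMI can separate question-relevant steps from generic filler text.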

✨ Key Features

  • PMI-Based Scoring: Employs Pointwise Mutual Information as a principled scoring metric for intermediate reasoning steps
  • Dynamic Sampling Strategy: Allocates computation adaptively based on the uncertainty of the current step
  • Weighted Average Voting: Provides robust answer selection through weighted averaging
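One way to picture score-weighted answer selection is the sketch below. It is an assumption-laden illustration, not the repository's voting rule: each candidate path contributes a softmax weight derived from its score, weights are summed per distinct answer, and the answer with the largest total wins. The `weighted_vote` name and the softmax weighting are hypothetical choices.

```python
import math
from collections import defaultdict
from typing import Dict, List, Tuple


def weighted_vote(candidates: List[Tuple[str, float]]) -> str:
    """Pick an answer from (answer, path_score) pairs using softmax weights.

    Hypothetical scheme: higher-scoring paths get exponentially more weight,
    and paths that agree on an answer pool their weights.
    """
    if not candidates:
        raise ValueError("candidates must not be empty")
    # Subtract the max score before exponentiating for numerical stability.
    max_score = max(score for _, score in candidates)
    totals: Dict[str, float] = defaultdict(float)
    for answer, score in candidates:
        totals[answer] += math.exp(score - max_score)
    return max(totals, key=totals.get)
```

Compared with plain majority voting, a scheme like this lets a single high-scoring path outweigh several low-scoring ones, while agreement among paths still counts.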

[Figure: MITS overview]

📦 Installation

Prerequisites

  • Python 3.11
  • CUDA 12.x (for GPU acceleration)

Setup

# Clone the repository
git clone https://github.com/plusnli/MITS.git
cd MITS

# Create conda environment
conda create -n mits python=3.11
conda activate mits

# Install dependencies
pip install -r requirements.txt

Note: The main dependencies include:

  • vllm>=0.8.3 - Efficient LLM inference
  • torch>=2.6.0 - Deep learning framework
  • transformers>=4.51.2 - Hugging Face transformers
  • datasets>=3.5.0 - Dataset loading and processing

🚀 Quick Start

Basic Usage

Run MITS on a dataset with default settings:

python main.py \
  --model Qwen2.5-7B \
  --dataset gsm8k \
  --start_idx 0 \
  --end_idx 100

Advanced Options

Enable dynamic sampling and beam search:

python main.py \
  --model Qwen2.5-7B \
  --prob_model Qwen2.5-3B \
  --dataset strategyqa \
  --num_samples 3 \
  --max_depth 10 \
  --dynamic_sampling \
  --max_dynamic_samples 6 \
  --beam 5 \
  --metric ave \
  --mi_topk 3 \
  --start_idx 0 \
  --end_idx 100

Evaluation Only

Evaluate existing results without running inference:

python main.py \
  --model Qwen2.5-7B \
  --dataset gsm8k \
  --start_idx 0 \
  --end_idx 100 \
  --eval_only

🤖 Supported Models

  • Llama: Llama-3.1-8B, Llama-3.2-{1B,3B}
  • Qwen: Qwen2.5-{0.5B,1.5B,3B,7B,14B,32B}
  • Mistral: Mistral-7B, Ministral-8B
  • Phi: Phi-3.5-mini, Phi-4, Phi-4-mini

📊 Supported Datasets

  • StrategyQA: Multi-hop reasoning
  • ARC-Challenge: Science questions
  • CommonsenseQA: Commonsense reasoning

⚙️ Arguments

Model & Dataset

  • --model: Model name (required)
  • --prob_model: Probability model for scoring (optional, defaults to main model)
  • --dataset: Dataset name (required)

Decoding Parameters

  • --temperature: Sampling temperature (default: 0.85)
  • --top_k: Top-k sampling (default: -1, which considers all tokens)
  • --top_p: Nucleus sampling (default: 1.0)

Tree Search Parameters

  • --num_samples: Number of samples per step (default: 3)
  • --max_depth: Maximum tree depth (default: 10)
  • --max_new_tokens: Max tokens per step (default: 512)
  • --beam: Beam size for beam search (optional)
  • --metric: Scoring metric, 'ave' or 'sum' (default: 'ave')
  • --mi_topk: Top-k paths for reweighting (default: 1)
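To make the role of the beam size concrete, here is a minimal sketch of one beam-pruning step: after every partial path has been expanded and scored, only the highest-scoring paths are kept for the next depth. The path representation (a list of step strings paired with a score) and the `prune_beam` name are assumptions for illustration and do not mirror the repository's `BeamMITS` class.

```python
import heapq
from typing import List, Tuple

Path = Tuple[List[str], float]  # (reasoning steps so far, path score)


def prune_beam(paths: List[Path], beam: int) -> List[Path]:
    """Keep the `beam` highest-scoring paths, best first."""
    if beam < 1:
        raise ValueError("beam must be at least 1")
    return heapq.nlargest(beam, paths, key=lambda path: path[1])


# Example: three expanded paths, beam size 2 keeps the two best.
expanded = [(["step a"], 0.1), (["step b"], 0.9), (["step c"], 0.5)]
kept = prune_beam(expanded, beam=2)
```

Under these assumptions, the number of paths carried forward at each depth is bounded by the beam size, so the cost of one depth scales with the beam size times the number of samples per step rather than with the full tree width.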

Dynamic Sampling

  • --dynamic_sampling: Enable dynamic sampling
  • --max_dynamic_samples: Max samples in dynamic mode (default: 6)
  • --ds_kp: Proportional gain for controller (default: 1.0)
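The options above suggest a proportional controller that raises the sample count when a step is uncertain. The sketch below shows what such a rule could look like. The uncertainty measure, the `target` set point, and the rounding and clipping behavior are all assumptions for illustration, so the actual controller in the repository may differ.

```python
def dynamic_num_samples(
    uncertainty: float,
    target: float,
    base_samples: int = 3,
    max_samples: int = 6,
    kp: float = 1.0,
) -> int:
    """Proportional rule: samples = base + kp * (uncertainty - target), clipped.

    `uncertainty` could be, for example, the entropy of the current step's
    candidates (an assumption); `target` is a hypothetical set point.
    """
    error = uncertainty - target
    proposed = base_samples + round(kp * error)
    # Never drop below one sample and never exceed the configured cap.
    return max(1, min(max_samples, proposed))
```

With the defaults shown here, a step whose uncertainty sits at the target keeps the base sample count, while very uncertain steps saturate at the cap given by `max_samples`.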

Execution

  • --start_idx: Start example index (required)
  • --end_idx: End example index (optional)
  • --num_gpus: Number of GPUs (default: 1)
  • --eval_only: Evaluation mode only

📁 Project Structure

MITS/
├── main.py              # Main entry point
├── model.py             # MITS and BeamMITS implementations
├── utils.py             # Utility functions
├── data_utils/          # Dataset loaders
│   ├── strategyqa.py
│   ├── commonsenseqa.py
│   ├── arc_c.py
│   └── ...
├── figs/                # Figures and visualizations
└── logs/                # Results and logs

📝 Citation

If you find this work useful, please cite:

@article{li2025mits,
  title={MITS: Enhanced Tree Search Reasoning for LLMs via Pointwise Mutual Information},
  author={Li, Jiaxi and Shi, Yucheng and Lu, Jin and Liu, Ninghao},
  journal={arXiv preprint arXiv:2510.03632},
  year={2025}
}

📄 License

This project is licensed under the MIT License.

💬 Contact

For questions or issues, please open an issue on GitHub or contact the authors through the paper.
