Train a Process Reward Model, use it to steer large-model code translation, and validate the resulting parallel programs on HeCBench.
This repository contains the full research pipeline for latent-guided parallel code translation. We train a Process Reward Model (PRM) on validator-derived reward signals, use that PRM to choose among multiple latent or code candidates during large-model inference, and then validate the generated translations with the HeCBench toolchain.
The system is organized around one practical question:
Can a reward model trained from compiler/runtime feedback guide a large model toward more reliable parallel code translations?
This method is built for experiments involving serial C/C++, OpenMP, CUDA, and cross-API translation directions such as CUDA -> OpenMP or serial -> CUDA.
flowchart LR
A[Build / score code data] --> B[Train PRM]
B --> C[PRM Branch Selection Test]
C --> D[Run PRM-guided inference]
D --> E[Validate generated translations]
E --> F[validators_stats_*.txt]
Most users follow one of these paths:
| Goal | Start here | What you run |
|---|---|---|
| Reproduce results from a released checkpoint | Inference | coconut_large_model_inference/run_inference_w_prm_modal_code_parallel.sh |
| Train your own PRM | PRM training | PRM/run_pqm_qwen_code.sh or PRM/run_pqm_qwen_code_for_valid.sh |
| Rebuild the training data | Dataset rebuild | dataset/run_build_hecbench_dataset_modal.sh |
Parallax/
├── PRM/ # PRM training, validation, vectors, checkpoints
│   ├── train_main.py # Main PRM train/validate entry point
│   ├── run_pqm_qwen_code.sh # Train on full code training data
│   ├── run_pqm_qwen_code_for_valid.sh
│   │   # Train on filtered data without validation leakage
│   ├── run_pqm_qwen_validate_code.sh
│   │   # Run the paper's branch-selection PRM validation test
│   ├── loss_graph.py # Plot training loss
│   └── plot_distributions.py # Plot PRM score distributions
├── coconut_large_model_inference/ # Large-model inference with optional PRM guidance
│   ├── run_inference_w_prm_modal_code_parallel.sh
│   │   # Modal inference over translation directions
│   ├── inference_modal_with_code_option_with_retry.py
│   ├── coconut_w_prm.py
│   └── test_code_a2.jsonl # Default code inference set
└── dataset/ # HeCBench data building, validation, and scoring
    ├── full_eval_validation_modal.py
    ├── run_reeval.sh
    ├── run_build_hecbench_dataset_modal.sh
    └── run_make_clean_run_parallel.sh
- Train a reward model over code-translation trajectories and validator scores.
- Use latent vector caches instead of re-tokenizing or recomputing every sample.
- Support full training data and a filtered split that removes validation examples.
- Run two-phase training with early backbone freezing and later unfreezing.
- Use PRM/run_pqm_qwen_validate_code.sh for the branch-selection test reported in the paper.
- Validate checkpoints against PRM/branch_selection_test.jsonl.
- Control tolerance-band metrics through the validation wrapper.
- Keep validation vectors separate in PRM/vectors_code_validate_new/.
- Generate multiple candidates per latent/code step.
- Score candidates with the PRM during inference.
- Run multiple translation directions in parallel on Modal GPUs.
- Produce inference_result_*.jsonl files for downstream HeCBench validation.
- Clean generated code, build inference training rows, and run validators.
- Produce per-run validator logs and validators_stats_*.txt summaries.
- Support manual make/clean checks for borderline or missing compiled files.
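The candidate scoring step listed above amounts to best-of-N selection. The following is a minimal sketch of that idea, not the repository's implementation: `select_best_candidate` and `toy_score` are hypothetical names, and `toy_score` is only a stand-in for a real PRM forward pass.

```python
from typing import Callable, Sequence, Tuple


def select_best_candidate(
    candidates: Sequence[str],
    score_fn: Callable[[str], float],
) -> Tuple[int, str, float]:
    """Return (index, candidate, score) of the highest-scoring candidate.

    Hypothetical helper: score_fn stands in for the PRM scoring call.
    Ties are resolved in favor of the earliest candidate.
    """
    if not candidates:
        raise ValueError("at least one candidate is required")
    scores = [score_fn(c) for c in candidates]
    best = max(range(len(candidates)), key=scores.__getitem__)
    return best, candidates[best], scores[best]


def toy_score(code: str) -> float:
    """Toy stand-in for the PRM: rewards code that contains an OpenMP pragma."""
    return 1.0 if "#pragma omp" in code else 0.0
```

With more candidates the scorer is called more often, which is why a larger candidate count costs more time per step.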
Use this path when you already have a trained PRM checkpoint.
Checkpoints are not bundled with the repository. Put the checkpoint under
PRM/checkpoints/ and point scripts at the concrete
checkpoint-* directory:
PRM/checkpoints/<run-name>/checkpoint-568/
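If a run directory holds several checkpoints, a small helper can resolve the concrete checkpoint-* directory to pass to the scripts. This is a sketch under the assumption that checkpoint directories are named checkpoint-<step>, as in the example path above; `latest_checkpoint` is a hypothetical helper, not part of the repository.

```python
import re
from pathlib import Path
from typing import Optional


def latest_checkpoint(run_dir: str) -> Optional[Path]:
    """Return the checkpoint-<step> subdirectory with the highest step, or None.

    Steps are compared numerically, so checkpoint-568 ranks above checkpoint-90.
    """
    pattern = re.compile(r"^checkpoint-(\d+)$")
    found = []
    for child in Path(run_dir).iterdir():
        match = pattern.match(child.name)
        if child.is_dir() and match:
            found.append((int(match.group(1)), child))
    if not found:
        return None
    return max(found, key=lambda item: item[0])[1]
```

The highest step is not necessarily the best checkpoint; pick the step you actually want to evaluate if it differs.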
cd coconut_large_model_inference
bash run_inference_w_prm_modal_code_parallel.sh

Before running, update the important variables in coconut_large_model_inference/run_inference_w_prm_modal_code_parallel.sh:
| Variable | What to set |
|---|---|
| PRM_CHECKPOINT | Path to your PRM checkpoint, usually ../PRM/checkpoints/<run>/checkpoint-*. |
| PRM_MODEL_ID | Backbone model used to train the PRM. Current runs use Qwen/Qwen2.5-Coder-7B-Instruct. |
| N_CANDIDATES | Number of candidates scored by the PRM. Higher values are slower but give more choices. |
| OUTPUT_PATH | Name of the inference JSONL to create. Use a descriptive suffix. |
| DIRECTIONS | Translation directions, for example cuda:omp, omp:cuda, serial:omp, serial:cuda. |
| MODAL_GPU_SPEC | Modal GPU type/count, for example H200:2 or A100-80GB:4. |
The default inference set is coconut_large_model_inference/test_code_a2.jsonl.
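The from:to direction strings can be checked before launching a long Modal run. The sketch below assumes the directions are given as a comma- or space-separated string; how the wrapper actually stores DIRECTIONS is not shown here, and `parse_directions` is a hypothetical helper.

```python
from typing import List, Tuple


def parse_directions(spec: str) -> List[Tuple[str, str]]:
    """Split 'cuda:omp, omp:cuda' into [('cuda', 'omp'), ('omp', 'cuda')].

    Raises ValueError for items that are not of the form <from>:<to>.
    """
    pairs = []
    for item in spec.replace(",", " ").split():
        source, sep, target = item.partition(":")
        if not sep or not source or not target:
            raise ValueError(f"expected <from>:<to>, got {item!r}")
        pairs.append((source, target))
    return pairs
```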
Copy the generated inference_result_*.jsonl file into:
dataset/data/Datasets/HeCBench/
cd dataset
python full_eval_validation_modal.py

Important: dataset/full_eval_validation_modal.py currently selects the input file through the hard-coded suffix near the top of the file. Set it so this path exists:
data/Datasets/HeCBench/inference_result_modal<suffix>.jsonl
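One way to make sure the expected path exists is to copy the inference output under the name the validator looks for. The following is a sketch only: `stage_inference_result` is a hypothetical helper, and it assumes the file name pattern inference_result_modal<suffix>.jsonl shown above.

```python
import shutil
from pathlib import Path


def stage_inference_result(src: str, suffix: str, hecbench_dir: str) -> Path:
    """Copy src to <hecbench_dir>/inference_result_modal<suffix>.jsonl.

    Returns the destination path so it can be checked or logged.
    """
    dest = Path(hecbench_dir) / f"inference_result_modal{suffix}.jsonl"
    dest.parent.mkdir(parents=True, exist_ok=True)
    shutil.copyfile(src, dest)
    return dest
```

The suffix passed here must match the hard-coded suffix in dataset/full_eval_validation_modal.py.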
Validation produces:
- dataset/data/Datasets/HeCBench/inference_output<suffix>.jsonl
- dataset/logs/validators_results_<run_id>.txt
- dataset/logs/validators_stats_<run_id>.txt
The final reported numbers are based on the generated validators_stats_*.txt
files.
Some translations that appear as failures in validators_stats_*.txt may still
be worth checking manually, especially when the validator log shows missing or
borderline compiled outputs. For those cases, take the relevant generated run
directories from the validator statistics/logs and add them to the LOCATIONS
array in dataset/run_make_clean_run_parallel.sh.
Then rerun the generated code directly:
cd dataset
bash run_make_clean_run_parallel.sh

The script runs make clean run for each listed generated directory and writes
the combined output to dataset/logs/make_clean_<suffix>.txt. Inspect that log
by hand and update your final accounting for cases that compile and run
successfully despite being missed or marked incorrectly by the automated stats.
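For reference, the loop that the script performs can be sketched in Python. This is a simplified, sequential illustration and not the script itself (which runs in parallel): `run_in_each` is a hypothetical name, and the command is a parameter so the sketch can be tried without a Makefile.

```python
import subprocess
from typing import Dict, Iterable, Sequence


def run_in_each(
    locations: Iterable[str],
    log_path: str,
    command: Sequence[str] = ("make", "clean", "run"),
) -> Dict[str, int]:
    """Run command in every directory, append all output to one log, return exit codes."""
    codes: Dict[str, int] = {}
    with open(log_path, "w", encoding="utf-8") as log:
        for loc in locations:
            log.write(f"===== {loc} =====\n")
            proc = subprocess.run(
                list(command),
                cwd=loc,
                stdout=subprocess.PIPE,
                stderr=subprocess.STDOUT,  # keep compiler errors next to program output
                text=True,
            )
            log.write(proc.stdout)
            log.write(f"[exit code {proc.returncode}]\n")
            codes[str(loc)] = proc.returncode
    return codes
```

A zero exit code only shows that the build and run completed; whether the output is correct still needs the manual inspection described above.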
All PRM training and validation lives in PRM/. This section is based on the Process Q Model. We use it only as the base of our code, not as part of our method.
cd PRM
bash run_pqm_qwen_code.sh

This uses PRM/train_code_a3.jsonl.
Use this when you want the strongest model from all available training data. Do not use this checkpoint for a clean validation report if validation examples are included in the full training set.
cd PRM
bash run_pqm_qwen_code_for_valid.sh

This uses PRM/train_code_a3_filtered.jsonl, which removes validation examples from training. This is the right path when you want validation on PRM/branch_selection_test.jsonl to be meaningful.
The shell wrappers call PRM/train_main.py. Edit the wrapper
first; only edit train_main.py when you are changing the training code itself.
| Parameter | Where | What it controls |
|---|---|---|
| MODEL_PATH / --model-path | wrapper + train_main.py | Hugging Face model ID or local model directory for the PRM backbone. |
| CUDA_VISIBLE_DEVICES | wrapper | GPU IDs allocated to the run. NPROC_PER_NODE is computed from this list. |
| --train-jsonl | wrapper | Training JSONL. Use train_code_a3.jsonl for full training or train_code_a3_filtered.jsonl for validation-clean training. |
| --vector-base-dir | wrapper | Directory used to resolve vector files referenced in the JSONL. Code PRM runs usually use PRM/vectors_code. |
| --save-path | wrapper | Output directory for checkpoints and profiler outputs. Change it for each important experiment. |
| --checkpoint-path | wrapper | Existing checkpoint to resume/load. Remove for a fresh run, or point to a concrete checkpoint-* directory. |
| --loss-type | wrapper | Objective: mse, huber, rank, orm, or bce. Current code PRM runs use mse. |
| --zeta | wrapper | Objective-specific shaping/scaling parameter used by the current loss. |
| --two-phase | wrapper | Freezes part of the model early, then unfreezes later. Useful for vector-input PRM training. |
| --unfreeze-step | wrapper | Step where the backbone is unfrozen during two-phase training. |
| --backbone-lr-factor | wrapper | Multiplier applied to the backbone learning rate after unfreezing. |
| --num-epochs | wrapper | Number of training epochs. |
| --effective-batch-size | wrapper | Total effective batch size across all GPUs after gradient accumulation. |
| --per-device-batch-size | wrapper | Samples per GPU per forward pass. Increase only if VRAM allows it. |
| --score-threshold | wrapper | Score below this value is treated as a negative step. Current scripts use 0.5. |
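The batch-size parameters are linked by simple arithmetic: effective batch size = per-device batch size × number of GPUs × gradient-accumulation steps. The sketch below shows that relation under the assumption that the GPU count is the length of the CUDA_VISIBLE_DEVICES list; how train_main.py derives the accumulation steps internally is not confirmed here, and `grad_accum_steps` is a hypothetical helper.

```python
def grad_accum_steps(
    effective_batch_size: int,
    per_device_batch_size: int,
    cuda_visible_devices: str,
) -> int:
    """Return the accumulation steps needed to reach the effective batch size.

    cuda_visible_devices is the comma-separated GPU list, e.g. "0,1,2".
    """
    n_gpus = len([d for d in cuda_visible_devices.split(",") if d.strip()])
    if n_gpus == 0:
        raise ValueError("no GPUs listed")
    per_step = per_device_batch_size * n_gpus  # samples per optimizer micro-step
    if effective_batch_size % per_step != 0:
        raise ValueError(
            f"effective batch size {effective_batch_size} is not a multiple of {per_step}"
        )
    return effective_batch_size // per_step
```

For example, with three GPUs and an illustrative per-device batch size of 4, an effective batch size of 48 corresponds to 4 accumulation steps.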
Good defaults to review before every run:
MODEL_PATH="Qwen/Qwen2.5-Coder-7B-Instruct"
export CUDA_VISIBLE_DEVICES=0,1,2
--train-jsonl "${SCRIPT_DIR}/train_code_a3_filtered.jsonl"
--vector-base-dir "${SCRIPT_DIR}/vectors_code"
--save-path "${SCRIPT_DIR}/checkpoints/<new-run-name>"

cd PRM
bash run_pqm_qwen_validate_code.sh

This wrapper is the PRM validation / branch-selection test used in the paper. With the checked-in defaults it evaluates a checkpoint on the released validation JSONL and vector cache. If you rebuild the branch-selection dataset from scratch in the section below, point this same wrapper at that rebuilt JSONL and vectors directory to reproduce the paper numbers end to end.
Update these values in PRM/run_pqm_qwen_validate_code.sh:
| Variable / argument | What it should point to |
|---|---|
| VAL_JSONL | JSONL scored by the branch-selection wrapper. The checked-in default is PRM/branch_selection_test.jsonl. |
| CHECKPOINT_PATH | Checkpoint to evaluate, usually from the filtered training run. |
| --vector-base-dir | Vector cache for the same branch-selection dataset. The checked-in default is PRM/vectors_code_validate_new. |
| TOLERANCE / --tolerance | Tolerance band for the validation metric. |
| CUDA_VISIBLE_DEVICES | GPUs used for validation. |
The validation wrapper calls PRM/train_main.py with
--validate, so it evaluates a checkpoint rather than continuing training.
This is the same entry point used for the branch-selection numbers discussed in
the dataset rebuild section below.
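This README does not define the tolerance-band metric, so the following is only one plausible reading, given as an illustration: the PRM picks the branch with the highest PRM score, and the pick counts as a hit if its validator score is within the tolerance of the best validator score in that group. The names `selection_hit` and `hit_rate` and the (prm_score, validator_score) layout are assumptions.

```python
from typing import Sequence, Tuple

Branch = Tuple[float, float]  # (prm_score, validator_score), an assumed layout


def selection_hit(branches: Sequence[Branch], tolerance: float) -> bool:
    """True if the PRM's top-scored branch is within tolerance of the best branch."""
    picked = max(branches, key=lambda b: b[0])
    best_true = max(b[1] for b in branches)
    return picked[1] >= best_true - tolerance


def hit_rate(groups: Sequence[Sequence[Branch]], tolerance: float) -> float:
    """Fraction of groups (for example, one per kernel and direction) that are hits."""
    if not groups:
        raise ValueError("no groups to evaluate")
    return sum(selection_hit(g, tolerance) for g in groups) / len(groups)
```

Check PRM/train_main.py for the metric actually used before comparing numbers against this sketch.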
Two utility scripts help inspect a run:
| Script | Purpose |
|---|---|
| PRM/loss_graph.py | Plot the training loss curve. |
| PRM/plot_distributions.py | Plot PRM score distributions. |
Use these to check whether training is stable and whether the PRM scores separate successful and unsuccessful candidates in a useful way.
Most users can skip this section. The released training JSONL files are enough for PRM training and validation.
cd dataset
bash run_reeval.sh

Before launching, check dataset/run_reeval.sh:
| Parameter | Meaning |
|---|---|
| GPUS | Local GPU IDs available for validation. |
| TOTAL | Number of entries to process. |
| MAX_PARALLEL | Number of validator jobs to run per GPU. |
Compressed assets may need to be decompressed before use:
dataset/data/Datasets/HeCBench/translations_redo_a2_full.jsonl.gz
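A .gz asset can be unpacked with gunzip -k on the command line, or with the Python standard library as sketched below. `gunzip_file` is a hypothetical helper name; it writes the decompressed file next to the archive and keeps the archive.

```python
import gzip
import shutil
from pathlib import Path


def gunzip_file(src: str) -> Path:
    """Decompress <name>.gz to <name> in the same directory and return the new path."""
    src_path = Path(src)
    if src_path.suffix != ".gz":
        raise ValueError(f"expected a .gz file, got {src_path.name}")
    dest = src_path.with_suffix("")  # strips only the trailing .gz
    with gzip.open(src_path, "rb") as fin, open(dest, "wb") as fout:
        shutil.copyfileobj(fin, fout)  # streams, so large JSONL files are fine
    return dest
```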
cd dataset
bash run_build_hecbench_dataset_modal.sh

This is the slowest part of the pipeline. It runs generation on Modal and local HeCBench validation on your machine or cluster.
Parameters likely to change in dataset/run_build_hecbench_dataset_modal.sh:
| Parameter | Meaning |
|---|---|
| COCONUT_CONFIG | Path to the Coconut code-translation config. After the folder rename, point into coconut_large_model_inference/. |
| CUDA_VISIBLE_DEVICES | Local GPU used for validation. |
| FROM_API / TO_API | Translation direction for dataset generation. |
| MODAL_GPU | Modal GPU spec used for generation. |
| --split | Dataset split passed to scripts/run_build_dataset_hecbench.py. |
| --num-resamples | Number of generated alternatives per item. |
| --max-workers | Local parallelism for build/validation. |
The branch-selection experiment measures whether the PRM can pick the best continuation when the large model branches off at every latent step. To reproduce that dataset, run the 4-direction wrapper:
cd dataset
bash run_build_validation_dir_4runs_modal.sh
# resume an interrupted run:
bash run_build_validation_dir_4runs_modal.sh --resume

What the script does:
- Runs Coconut on Modal for 4 translation directions in parallel (cuda → omp, omp → cuda, serial → cuda, serial → omp), LIMIT kernels each.
- For every kernel it produces the original translation and one continuation per (latent_vector_index, resample) branch-off.
- With SAVE_CONTINUATION_VECTORS=1 (default), the full latent trajectory of each continuation is dumped as vectors/{capture_id}_v{i}_r{j}.pt alongside the original vectors/{capture_id}.pt. These are the inputs the PRM scores in the branch-selection test.
- Local HeCBench validation is skipped (--skip-validation); scoring those branches is done later by the PRM, not the compiler.
Output layout under dataset/data/validation_dir/:
data/validation_dir/
├── dataset.jsonl # one row per kernel/direction
├── vectors/
│   ├── <capture_id>.pt # original translation latent vectors [K, dim]
│   └── <capture_id>_v{i}_r{j}.pt # continuation latent vectors [K, dim]
└── logs/validation_dir_<ts>/build_<from>_<to>.log
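After a build, it can help to check which continuations exist for each capture. The sketch below groups file names by capture ID using the naming scheme from the layout above; it works on plain file names and does not load any tensors, and `group_branches` is a hypothetical helper.

```python
import re
from typing import Dict, Iterable, List, Tuple

# Matches continuation files such as "<capture_id>_v3_r1.pt".
_CONTINUATION = re.compile(r"^(?P<capture_id>.+)_v(?P<i>\d+)_r(?P<j>\d+)\.pt$")


def group_branches(filenames: Iterable[str]) -> Dict[str, List[Tuple[int, int]]]:
    """Map capture_id -> sorted (latent_vector_index, resample) pairs found on disk.

    Original translations (<capture_id>.pt) create an entry with no pairs.
    Files that do not end in .pt are ignored.
    """
    groups: Dict[str, List[Tuple[int, int]]] = {}
    for name in filenames:
        match = _CONTINUATION.match(name)
        if match:
            pair = (int(match.group("i")), int(match.group("j")))
            groups.setdefault(match.group("capture_id"), []).append(pair)
        elif name.endswith(".pt"):
            groups.setdefault(name[:-3], [])
    return {cid: sorted(pairs) for cid, pairs in groups.items()}
```

A capture with an empty list has an original translation but no saved continuations, which usually means SAVE_CONTINUATION_VECTORS was 0 or the run was interrupted.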
Variables you will probably edit in dataset/run_build_validation_dir_4runs_modal.sh:
| Variable | Meaning |
|---|---|
| COCONUT_CONFIG | Coconut config. Default ../coconut_large_model_inference/args/code_translation_70b.yaml. Pass 7b.yaml/14b.yaml/30b.yaml/70b.yaml and MEM_PER_GPU is auto-derived. |
| MODAL_GPU | Modal GPU spec, e.g. A100-80GB:4 (default), H200:2, H200:3. Per-GPU memory cap is computed from this. |
| SAVE_CONTINUATION_VECTORS | 1 (default) saves continuation .pt files; 0 skips them (you almost never want this for branch-selection data). |
| LIMIT | Kernels per direction. Default 15 (matches the released branch-selection set). |
| OUTPUT_DIR | Output root, default data/validation_dir. |
| MAX_ATTEMPTS / RETRY_DELAY_SECS | Retry policy for Modal OOM (RETRY_EXIT_CODE=33). |
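The retry policy can be pictured as a loop that reruns the job only when it exits with the retry code. This is a sketch of the general pattern, not the wrapper's code: `run_with_retry` is a hypothetical name and the default values for max_attempts and retry_delay_secs are placeholders, since the script's defaults are not stated here.

```python
import subprocess
import time
from typing import Sequence

RETRY_EXIT_CODE = 33  # exit code the README associates with a Modal OOM


def run_with_retry(
    command: Sequence[str],
    max_attempts: int = 3,          # placeholder default, not the script's value
    retry_delay_secs: float = 0.0,  # placeholder default, not the script's value
) -> int:
    """Run command; rerun only when it exits with RETRY_EXIT_CODE.

    Returns the final exit code: any other code (including 0) ends the loop at once.
    """
    if max_attempts < 1:
        raise ValueError("max_attempts must be at least 1")
    code = RETRY_EXIT_CODE
    for attempt in range(1, max_attempts + 1):
        code = subprocess.run(list(command)).returncode
        if code != RETRY_EXIT_CODE:
            return code
        if attempt < max_attempts:
            time.sleep(retry_delay_secs)
    return code
```

Retrying only on one dedicated exit code keeps genuine failures, such as a bad config path, from being retried pointlessly.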
Override at the command line, for example:
COCONUT_CONFIG=../coconut_large_model_inference/args/code_translation_30b.yaml \
MODAL_GPU=H200:2 \
SAVE_CONTINUATION_VECTORS=1 \
bash run_build_validation_dir_4runs_modal.sh

After the build finishes, point the PRM validation wrapper at the produced vectors directory (set --vector-base-dir in PRM/run_pqm_qwen_validate_code.sh to dataset/data/validation_dir/vectors) and at the corresponding JSONL to score every branch and reproduce the branch-selection numbers.
Prerequisites: the coconut conda environment (the script runs conda activate LGPRM) and a working Modal account (modal token new). The Modal volume hf-model-cache is created automatically on first run and caches the model weights between Modal cold starts.
| Artifact | Produced by | Used for |
|---|---|---|
| PRM/checkpoints/<run>/checkpoint-* | PRM training | Validation and inference-time PRM scoring. |
| PRM/vectors_code/ | Dataset/vector preparation | Training vector lookup. |
| PRM/vectors_code_validate_new/ | Validation vector preparation | Validation vector lookup. |
| coconut_large_model_inference/inference_result_*.jsonl | PRM-guided inference | Input to HeCBench validation. |
| dataset/data/Datasets/HeCBench/inference_output*.jsonl | Final validation | Per-sample scored outputs. |
| dataset/logs/validators_stats_*.txt | Final validation | Reported run statistics. |
Some runs require one manual post-check: copy the compiled-file list from the validation log into dataset/run_make_clean_run_parallel.sh and compare its output to identify missing files that still qualify.
| Symptom | What to check |
|---|---|
| PRM checkpoint not found | Make sure scripts point to PRM/checkpoints/<run>/checkpoint-*, not just the parent run directory. |
| Validation uses the wrong inference file | Update the hard-coded suffix in dataset/full_eval_validation_modal.py. |
| Inference still points to old folders | Replace old ../My_PRM/... paths with ../PRM/... in inference wrappers. |
| Dataset build config is missing | Update old coconut_large_model paths to coconut_large_model_inference. |
| CUDA out of memory | Reduce --per-device-batch-size, reduce candidates, or use fewer/larger GPUs depending on the stage. |
| Validation results look too optimistic | Train with PRM/train_code_a3_filtered.jsonl before validating on PRM/branch_selection_test.jsonl. |
To reproduce an existing run from a released PRM checkpoint:
- Place the checkpoint under PRM/checkpoints/.
- Update PRM_CHECKPOINT in coconut_large_model_inference/run_inference_w_prm_modal_code_parallel.sh.
- Run PRM-guided inference.
- Copy the resulting inference_result_*.jsonl into dataset/data/Datasets/HeCBench/.
- Update suffix in dataset/full_eval_validation_modal.py.
- Run validation and use the resulting validators_stats_*.txt files.
- Run the manual make/clean check when the validator logs indicate missing compiled files that may still qualify.
See LICENSE.
