Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
2 changes: 1 addition & 1 deletion src/core/entropy_dynamics/analysis_combined.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,4 +256,4 @@ def main():


if __name__ == "__main__":
main()
main()
4 changes: 2 additions & 2 deletions src/core/entropy_dynamics/analyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@

from pathlib import Path

import matplotlib.pyplot as plt
#import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
#import seaborn as sns


def load_results(path: str | Path) -> pd.DataFrame:
Expand Down
2 changes: 1 addition & 1 deletion src/core/entropy_dynamics/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,4 +52,4 @@ def out_path(self) -> Path:
def results_filename(self) -> str:
"""Unique filename based on teacher source and role to prevent overwrites."""
stem = Path(self.teacher_reasoning_path).stem
return f"entropy_dynamics_{self.role.value}_{stem}.parquet"
return f"entropy_dynamics_{self.role.value}_{stem}.parquet"
5 changes: 5 additions & 0 deletions src/core/entropy_dynamics/prompt_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,11 @@ def _tokenize_with_assistant_prefix(
base_ids = tokenizer.apply_chat_template(
messages, tokenize=True, add_generation_prompt=True
)

# FIX: Flatten BatchEncoding if necessary
if not isinstance(base_ids, list):
base_ids = base_ids.input_ids[0] if isinstance(base_ids.input_ids[0], list) else list(base_ids.input_ids)

prefix_ids = tokenizer.encode(assistant_prefix, add_special_tokens=False)
return base_ids + prefix_ids

Expand Down
12 changes: 7 additions & 5 deletions src/core/entropy_dynamics/reasoning_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,16 @@ def load_teacher_reasoning(
tokenizer: PreTrainedTokenizer,
min_thinking_tokens: int = 16,
) -> list[TeacherReasoning]:
"""Load and tokenize teacher reasoning chains.
"""
"""Load and tokenize teacher reasoning chains."""
df = pd.read_parquet(path)

# Маппим distill_reasoning в ожидаемый колонку thinking, если пришел новый датасет
if "distill_reasoning" in df.columns and "thinking" not in df.columns:
df = df.rename(columns={"distill_reasoning": "thinking"})

if "input" in df.columns and "output" in df.columns:
records = _parse_synth_aug(df)
elif "thinking" in df.columns:
elif "thinking" in df.columns: # Сюда теперь зайдет и ваш датасет
records = _parse_flat(df)
else:
raise ValueError(
Expand All @@ -57,7 +60,6 @@ def load_teacher_reasoning(

return results


def _parse_synth_aug(df: pd.DataFrame) -> list[TeacherReasoning]:
records: list[TeacherReasoning] = []

Expand Down Expand Up @@ -125,4 +127,4 @@ def _safe_literal_eval(s) -> list:
result = ast.literal_eval(str(s))
return list(result) if isinstance(result, (list, tuple)) else []
except Exception:
return []
return []
17 changes: 9 additions & 8 deletions src/core/entropy_dynamics/run_experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,22 +32,23 @@
import argparse
import sys

from core.entropy_dynamics.analyzer import run_full_analysis
#from core.entropy_dynamics.analyzer import run_full_analysis
from core.entropy_dynamics.config import (
EntropyDynamicsConfig, ExperimentRole, InferenceMode, StudentModelConfig,
)
from core.entropy_dynamics.runner import EntropyDynamicsRunner

# ── Model presets by role ──
STUDENT_MODELS = [
StudentModelConfig(model_id="/home/dviazhev/complexity-aware-fine-tuning-old/src/models/Qwen2.5-3B-Instruct", label="qwen_3b"),
StudentModelConfig(model_id="/home/dviazhev/qa_finetune/Phi-4-mini-instruct", label="phi4_mini"),
# StudentModelConfig(model_id="meta-llama/Llama-3.2-3B-Instruct", label="llama_3b"),
StudentModelConfig(model_id="Qwen/Qwen2.5-3B", label="qwen_3b"),
StudentModelConfig(model_id="microsoft/Phi-4-mini-instruct", label="phi4_mini"),
StudentModelConfig(model_id="meta-llama/Llama-3.2-3B-Instruct", label="llama_3b"),
]

PROXY_MODELS = [
StudentModelConfig(model_id="/home/dviazhev/recursive_caft/models/Qwen2.5-32B-Instruct", label="qwen_32b"),
StudentModelConfig(model_id="/home/dviazhev/recursive_caft/models/Mistral-Small-24B-Instruct-2501", label="mistral_24b"),
StudentModelConfig(model_id="/mnt/data198/LLM/models/Qwen2.5-32B-Instruct", label="qwen_32b"),
StudentModelConfig(model_id="/mnt/data198/LLM/models/Qwen2.5-14B-Instruct", label="qwen_14b"),
StudentModelConfig(model_id="/mnt/data198/LLM/models/Mistral-Small-3.2-24B-Instruct-2506", label="mistral_24b"),
]


Expand Down Expand Up @@ -138,10 +139,10 @@ def main():
results_path = config.out_path / config.results_filename
if results_path.exists():
print(f"\nRunning analysis on {results_path}...")
run_full_analysis(results_path, config.out_path / "analysis")
# run_full_analysis(results_path, config.out_path / "analysis")
else:
print(f"Results file not found at {results_path}. Run inference first.")


if __name__ == "__main__":
main()
main()
87 changes: 46 additions & 41 deletions src/core/entropy_dynamics/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,16 +21,17 @@

@dataclass
class StepResult:
"""Single measurement: one student × one question × one k."""
"""Single measurement: one model × one question × one k."""
question_id: str
student_label: str
model_label: str # was "student_label" — now generic
role: str # "student" or "proxy"
k: int
num_reasoning_tokens: int
total_reasoning_tokens: int
mode: str
answer_entropy: float
student_answer: str
student_correct: bool
model_answer: str # was "student_answer"
model_correct: bool # was "student_correct"
gold_answer: str


Expand All @@ -50,18 +51,16 @@ def save(self, path: Path):


class EntropyDynamicsRunner:
"""Runs the full entropy dynamics experiment."""
"""Runs the full entropy dynamics experiment for any role (student or proxy)."""

def __init__(self, config: EntropyDynamicsConfig):
self.config = config

def run(self) -> pd.DataFrame:
set_seed()

# Load teacher reasoning with a lightweight tokenizer
# (any tokenizer sharing the vocab works for slicing)
first_student_id = self.config.students[0].model_id
loader_tokenizer = AutoTokenizer.from_pretrained(first_student_id)
first_model_id = self.config.students[0].model_id
loader_tokenizer = AutoTokenizer.from_pretrained(first_model_id)

print(f"Loading teacher reasoning from {self.config.teacher_reasoning_path}...")
samples = load_teacher_reasoning(
Expand All @@ -73,48 +72,53 @@ def run(self) -> pd.DataFrame:

all_results = ExperimentResults()

for student_cfg in self.config.students:
self._run_single_student(student_cfg, samples, all_results)
for model_cfg in self.config.students:
self._run_single_model(model_cfg, samples, all_results)

out_path = self.config.out_path / "entropy_dynamics_results.parquet"
# ── Use unique filename to prevent overwrites ──
out_path = self.config.out_path / self.config.results_filename
all_results.save(out_path)
print(f"All results saved to {out_path}")

return all_results.to_dataframe()

def _run_single_student(
def _run_single_model(
self,
student_cfg: StudentModelConfig,
model_cfg: StudentModelConfig,
samples: list[TeacherReasoning],
results: ExperimentResults,
):
role = self.config.role.value
print(f"\n{'='*60}")
print(f"Student: {student_cfg.label} ({student_cfg.model_id})")
print(f"[{role}] {model_cfg.label} ({model_cfg.model_id})")
print(f"{'='*60}")

tokenizer = AutoTokenizer.from_pretrained(student_cfg.model_id)
tokenizer = AutoTokenizer.from_pretrained(model_cfg.model_id)
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token

model = AutoModelForCausalLM.from_pretrained(
student_cfg.model_id,
model_cfg.model_id,
device_map=DEVICE_MAP,
torch_dtype=torch.bfloat16,
)
model.eval()

# ── Checkpoint includes role to avoid collision ──
checkpoint_path = (
self.config.out_path / f"checkpoint_{student_cfg.label}_{self.config.mode.value}.parquet"
self.config.out_path
/ f"checkpoint_{role}_{model_cfg.label}_{self.config.mode.value}.parquet"
)
processed_ids = _load_processed_ids(checkpoint_path)
print(f"Resuming: {len(processed_ids)} questions already processed.")

t_start = time.perf_counter()

for i, sample in enumerate(tqdm(samples, desc=f"[{student_cfg.label}]")):
for i, sample in enumerate(tqdm(samples, desc=f"[{role}/{model_cfg.label}]")):
if sample.question_id in processed_ids:
continue

# Re-tokenize with this model's tokenizer
sample.thinking_token_ids = tokenizer.encode(
sample.thinking_text, add_special_tokens=False
)
Expand All @@ -128,19 +132,19 @@ def _run_single_student(
)

for prompt in prefixed_prompts:
step_result = self._run_single_step(model, tokenizer, prompt, student_cfg.label)
step_result = self._run_single_step(
model, tokenizer, prompt, model_cfg.label, role
)
results.append(step_result)

# Checkpoint
if (i + 1) % self.config.batch_save_every == 0:
results.save(checkpoint_path)
elapsed = time.perf_counter() - t_start
print(f" [{student_cfg.label}] {i+1}/{len(samples)} "
print(f" [{role}/{model_cfg.label}] {i+1}/{len(samples)} "
f"({elapsed:.0f}s elapsed)")

results.save(checkpoint_path)

# Free
del model
gc.collect()
if torch.cuda.is_available():
Expand All @@ -152,25 +156,25 @@ def _run_single_step(
model,
tokenizer: PreTrainedTokenizer,
prompt: PrefixedPrompt,
student_label: str,
model_label: str,
role: str,
) -> StepResult:
"""Run one forward pass and extract entropy."""

input_ids = torch.tensor([prompt.input_ids], device=DEVICE)
attention_mask = torch.ones_like(input_ids)

if prompt.mode == InferenceMode.FORCED:
return self._step_forced(
model, tokenizer, input_ids, attention_mask, prompt, student_label
model, tokenizer, input_ids, attention_mask, prompt, model_label, role
)
else:
return self._step_continuation(
model, tokenizer, input_ids, attention_mask, prompt, student_label
model, tokenizer, input_ids, attention_mask, prompt, model_label, role
)

def _step_forced(
self, model, tokenizer, input_ids, attention_mask,
prompt: PrefixedPrompt, student_label: str,
prompt: PrefixedPrompt, model_label: str, role: str,
) -> StepResult:
"""Mode A: generate 1 token, measure its entropy."""
outputs = model.generate(
Expand All @@ -188,26 +192,27 @@ def _step_forced(
entropy = compute_entropy_from_logits(first_token_logits).item()

generated_id = outputs.sequences[0, input_ids.shape[1]].item()
student_answer = tokenizer.decode([generated_id]).strip().lower()
answer = tokenizer.decode([generated_id]).strip().lower()

return StepResult(
question_id=prompt.question_id,
student_label=student_label,
model_label=model_label,
role=role,
k=prompt.k,
num_reasoning_tokens=prompt.num_reasoning_tokens,
total_reasoning_tokens=prompt.total_reasoning_tokens,
mode=prompt.mode.value,
answer_entropy=entropy,
student_answer=student_answer,
student_correct=(student_answer == prompt.gold_answer),
model_answer=answer,
model_correct=(answer == prompt.gold_answer),
gold_answer=prompt.gold_answer,
)

def _step_continuation(
self, model, tokenizer, input_ids, attention_mask,
prompt: PrefixedPrompt, student_label: str,
prompt: PrefixedPrompt, model_label: str, role: str,
) -> StepResult:
"""Mode B: student continues generating, measure avg entropy of tail."""
"""Mode B: model continues generating, measure avg entropy of tail."""
outputs = model.generate(
input_ids=input_ids,
attention_mask=attention_mask,
Expand All @@ -222,11 +227,11 @@ def _step_continuation(
gen_ids = outputs.sequences[0, input_ids.shape[1]:]
gen_text = tokenizer.decode(gen_ids, skip_special_tokens=True)

student_answer = ""
marker_pos = gen_text.find(answer_marker[1]) # find "]]"
answer = ""
marker_pos = gen_text.find(answer_marker[1])
marker_start = gen_text.rfind(answer_marker[0], 0, marker_pos if marker_pos != -1 else None)
if marker_start != -1 and marker_pos != -1:
student_answer = gen_text[marker_start + len(answer_marker[0]):marker_pos].strip().lower()
answer = gen_text[marker_start + len(answer_marker[0]):marker_pos].strip().lower()

scores = outputs.scores
if not scores:
Expand All @@ -243,20 +248,20 @@ def _step_continuation(

return StepResult(
question_id=prompt.question_id,
student_label=student_label,
model_label=model_label,
role=role,
k=prompt.k,
num_reasoning_tokens=prompt.num_reasoning_tokens,
total_reasoning_tokens=prompt.total_reasoning_tokens,
mode=prompt.mode.value,
answer_entropy=entropy,
student_answer=student_answer,
student_correct=(student_answer == prompt.gold_answer),
model_answer=answer,
model_correct=(answer == prompt.gold_answer),
gold_answer=prompt.gold_answer,
)


def _load_processed_ids(checkpoint_path: Path) -> set[str]:
"""Load question_ids already processed from a checkpoint parquet."""
if not checkpoint_path.exists():
return set()
try:
Expand Down
Loading