diff --git a/config/client_text.yaml b/config/client_text.yaml index 8ddcaab..5464308 100644 --- a/config/client_text.yaml +++ b/config/client_text.yaml @@ -1,6 +1,6 @@ huri_url: ws://localhost:8000/session -topic_list: [question, rag_response] +topic_list: [question, token] senders: text: @@ -12,4 +12,7 @@ modules: args: language: en tone: formal - logging: INFO + max_history_turns: 6 + temperature: 0.7 + response_format: short + persona: "" diff --git a/config/huri.yaml b/config/huri.yaml index 593496c..8dd83e8 100644 --- a/config/huri.yaml +++ b/config/huri.yaml @@ -40,8 +40,15 @@ applications: import_path: src.app:app runtime_env: env_vars: + OLLAMA_BASE_URL: "http://localhost:11434" + OLLAMA_API_KEY: "ollama" + RAY_COLOR_PREFIX: "1" + # Make the CosyVoice repo (and its Matcha-TTS submodule) importable + # inside every replica, regardless of cwd. + PYTHONPATH: "/home/fifster/Tek/Eip/HuRI/assets/cosyvoice:/home/fifster/Tek/Eip/HuRI/assets/cosyvoice/third_party/Matcha-TTS" + # --- Gesture sliding-window defaults (run in the HuRI CPU actor) --- HURI_GESTURE_CONTEXT_SEC: "2.0" HURI_GESTURE_MIN_CHUNK_SEC: "0.5" @@ -56,19 +63,19 @@ applications: # it as an instruction and intermittently speaks it (prompt leakage). HURI_VOICE_TRANSCRIPT: "You are a helpful assistant.<|endofprompt|>Instinct creates its own oppressors and bids us rise up against them." # From .Values.models.cosytts.env (mountPath/modelId) — edit for local layout. - HURI_MODEL_PATH: /models/cosytts/FunAudioLLM/Fun-CosyVoice3-0.5B-2512 + HURI_MODEL_PATH: assets/cosyvoice_model # Path to the CosyVoice repo root containing third_party/Matcha-TTS. - HURI_COSY_DIR: /app/cosyvoice + HURI_COSY_DIR: assets/cosyvoice # From .Values.voiceAssets.env — the reference voice sample. - HURI_VOICE_SAMPLE_PATH: /assets/voice.wav + HURI_VOICE_SAMPLE_PATH: assets/voice.wav # --- STT (faster-whisper) --- # From .Values.models.whisper.env (mountPath/repoId) — edit for local layout. - HURI_STT_MODEL_PATH: /models/whisper/Systran/faster-whisper-base + # HURI_STT_MODEL_PATH: /models/whisper/Systran/faster-whisper-base # --- Gesture (EMAGE) --- # From .Values.models.emage.env (mountPath/repoId) — edit for local layout. - HURI_EMAGE_REPO: /models/emage/H-Liu1997/emage_audio + HURI_EMAGE_REPO: assets/emage_audio # --- GPU-vendor runtime env (Helm puts these on the worker containers) --- NVIDIA_VISIBLE_DEVICES: "all" @@ -88,7 +95,7 @@ applications: num_replicas: 1 ray_actor_options: num_cpus: 1 - num_gpus: 0.5 + num_gpus: 0 # RAG: embeddings (API) + LLM client. No GPU needed. - name: RAGHandle @@ -98,7 +105,10 @@ applications: num_gpus: 0 user_config: embedding_model: "bge-large-en-v1.5-gguf-Q4_K_M" - llm_model: "Qwen3.5-4B-GGUF" + llm_model: "mistral:7b" + llm_base_url: "http://localhost:11434/v1" # Ollama's OpenAI-compatible endpoint + llm_api_key: "ollama" + memory_maintenance_check_hours: 0.5 # GPU split (manual override knob): num_gpus are Ray *scheduling* # fractions that let replicas pack onto the same device and bias the diff --git a/src/core/huri.py b/src/core/huri.py index 5fa8038..a6ae715 100644 --- a/src/core/huri.py +++ b/src/core/huri.py @@ -125,5 +125,17 @@ async def receive_loop(session: Session, ws: WebSocket): finally: print(f"Client {user_id} disconnected") - await receive_loop(self.clients[session_id], ws) - del self.clients[session_id] + try: + await receive_loop(self.clients[session_id], ws) + finally: + # Persist per-session state (e.g. conversation memory) on disconnect. + for module in modules: + fin = getattr(module, "finalize", None) + if fin is None: + continue + try: + await fin() + except Exception: + import traceback + print(f"[HuRI] finalize failed for {type(module).__name__}:\n{traceback.format_exc()}") + self.clients.pop(session_id, None) \ No newline at end of file diff --git a/src/modules/rag/memory_inspect.py b/src/modules/rag/memory_inspect.py new file mode 100644 index 0000000..f912b16 --- /dev/null +++ b/src/modules/rag/memory_inspect.py @@ -0,0 +1,90 @@ +"""Inspect conversation memories: current strength, decay projection, fate. + +Usage: + python -m src.modules.rag.memory.memory_inspect + python -m src.modules.rag.memory.memory_inspect --user-id +""" +import argparse +from datetime import datetime, timedelta + +try: + from .qdrant_utils import make_qdrant_client +except ImportError: + from qdrant_utils import make_qdrant_client + +HALF_LIFE_DAYS = 5.0 +DELETE_BELOW, CONSOLIDATE_BELOW = 0.05, 0.30 + + +def strength(payload: dict, at: datetime | None = None) -> float: + """Query-independent strength: recency * importance (matches maintenance).""" + at = at or datetime.now() + imp = payload.get("importance", 3) + half = max(HALF_LIFE_DAYS * (imp / 5.0), 0.5) + try: + last = datetime.fromisoformat(payload.get("last_accessed") or payload["created_at"]) + age = (at - last).total_seconds() / 86400.0 + except Exception: + age = 0.0 + return (0.5 ** (max(age, 0) / half)) * (imp / 10.0) + + +def fate(s: float) -> str: + if s < DELETE_BELOW: + return "DELETE" + if s < CONSOLIDATE_BELOW: + return "CONSOLIDATE" + return "KEEP" + + +def main(): + ap = argparse.ArgumentParser() + ap.add_argument("--qdrant-url", default="http://localhost:6333") + ap.add_argument("--collection", default="conversations") + ap.add_argument("--user-id", default=None) + args = ap.parse_args() + + qdrant = make_qdrant_client(args.qdrant_url) + points, offset = [], None + while True: + batch, offset = qdrant.scroll(collection_name=args.collection, limit=200, + offset=offset, with_payload=True, with_vectors=False) + points.extend(batch) + if offset is None: + break + + now = datetime.now() + rows = [] + for p in points: + pl = p.payload + if pl.get("type") == "maintenance_marker": + print(f"[marker] last maintenance run: {pl.get('last_run')}\n") + continue + if args.user_id and pl.get("_user_id") != args.user_id: + continue + s_now = strength(pl, now) + rows.append({ + "text": pl.get("text", "")[:60].replace("\n", " "), + "type": pl.get("type", "?"), + "imp": pl.get("importance", "?"), + "acc": pl.get("access_count", 0), + "age_d": round((now - datetime.fromisoformat( + pl.get("last_accessed") or pl["created_at"])).total_seconds() / 86400, 1), + "now": round(s_now, 3), + "+5d": round(strength(pl, now + timedelta(days=5)), 3), + "+15d": round(strength(pl, now + timedelta(days=15)), 3), + "fate": fate(s_now), + }) + + rows.sort(key=lambda r: r["now"], reverse=True) + hdr = f"{'strength':>8} {'+5d':>6} {'+15d':>6} {'imp':>3} {'acc':>3} {'age':>5} {'fate':<12} text" + print(hdr) + print("-" * len(hdr)) + for r in rows: + print(f"{r['now']:>8} {r['+5d']:>6} {r['+15d']:>6} {r['imp']:>3} " + f"{r['acc']:>3} {r['age_d']:>5} {r['fate']:<12} {r['text']}") + print(f"\n{len(rows)} memories") + + +if __name__ == "__main__": + main() diff --git a/src/modules/rag/memory_maintenance.py b/src/modules/rag/memory_maintenance.py new file mode 100644 index 0000000..d20cc5a --- /dev/null +++ b/src/modules/rag/memory_maintenance.py @@ -0,0 +1,119 @@ +"""Periodic memory maintenance: decay-based pruning + consolidation. + +Run every N days (cron/systemd timer): + python -m src.modules.rag.memory_maintenance +Strong memories are kept, weak ones are merged into a consolidated memory, +dead ones are deleted. Decay itself is computed lazily at query time in +rag.py; this job only prunes and compresses. +""" +import argparse +import json +import uuid +from collections import defaultdict +from datetime import datetime + +import httpx +from qdrant_client.models import Distance, PointIdsList, PointStruct, VectorParams + +try: + from .qdrant_utils import make_qdrant_client +except ImportError: + from qdrant_utils import make_qdrant_client + +DELETE_BELOW = 0.05 +CONSOLIDATE_BELOW = 0.30 +HALF_LIFE_DAYS = 5.0 + + +def strength(payload: dict) -> float: + """Query-independent strength: recency * importance (no relevance term).""" + importance = payload.get("importance", 3) + half_life = max(HALF_LIFE_DAYS * (importance / 5.0), 0.5) + try: + last = datetime.fromisoformat(payload.get("last_accessed") or payload["created_at"]) + age_days = (datetime.now() - last).total_seconds() / 86400.0 + except Exception: + age_days = 0.0 + recency = 0.5 ** (age_days / half_life) + return recency * (importance / 10.0) + + +def embed(client: httpx.Client, url: str, model: str, text: str) -> list[float]: + r = client.post(f"{url}/v1/embeddings", json={"model": model, "input": text}) + r.raise_for_status() + return r.json()["data"][0]["embedding"] + + +def llm(client: httpx.Client, url: str, model: str, prompt: str) -> str: + r = client.post(f"{url}/api/chat", json={ + "model": model, "stream": False, + "messages": [{"role": "user", "content": prompt}], + "options": {"num_predict": 300}, + }) + r.raise_for_status() + return r.json()["message"]["content"] + + +def main(): + ap = argparse.ArgumentParser() + ap.add_argument("--qdrant-url", default="http://localhost:6333") + ap.add_argument("--collection", default="conversations") + ap.add_argument("--ollama-url", default="http://localhost:11434") + ap.add_argument("--embedding-model", default="bge-large-en-v1.5-gguf-Q4_K_M") + ap.add_argument("--llm-model", default="mistral:7b") + ap.add_argument("--dry-run", action="store_true") + args = ap.parse_args() + + qdrant = make_qdrant_client(args.qdrant_url) + http = httpx.Client(timeout=180.0) + + points, offset = [], None + while True: + batch, offset = qdrant.scroll(collection_name=args.collection, limit=200, + offset=offset, with_payload=True, with_vectors=False) + points.extend(batch) + if offset is None: + break + print(f"{len(points)} memories in '{args.collection}'") + + to_delete, weak_by_user = [], defaultdict(list) + for p in points: + s = strength(p.payload) + if s < DELETE_BELOW: + to_delete.append(p) + elif s < CONSOLIDATE_BELOW: + weak_by_user[p.payload.get("_user_id", "anonymous")].append(p) + print(f"delete: {len(to_delete)}, consolidate candidates: {sum(map(len, weak_by_user.values()))}") + + if args.dry_run: + return + + for user, weak in weak_by_user.items(): + if len(weak) < 3: + continue # not worth merging yet; keep decaying + texts = [p.payload["text"] for p in weak] + merged = llm(http, args.ollama_url, args.llm_model, + "These are old memories about conversations with the same person. " + "Merge them into a single 3-5 sentence memory keeping only durable " + "facts, preferences and recurring themes. Drop one-off small talk.\n\n" + + "\n---\n".join(texts)).strip() + vec = embed(http, args.ollama_url, args.embedding_model, merged) + now = datetime.now().isoformat() + imp = min(max(p.payload.get("importance", 3) for p in weak) + 1, 10) + qdrant.upsert(collection_name=args.collection, points=[PointStruct( + id=str(uuid.uuid4()), vector=vec, payload={ + "text": merged, "_user_id": user, "type": "conversation_consolidated", + "created_at": now, "last_accessed": now, "access_count": 0, + "importance": imp, + })]) + to_delete.extend(weak) + print(f"[{user}] consolidated {len(weak)} → 1 (importance={imp})") + + if to_delete: + qdrant.delete(collection_name=args.collection, + points_selector=PointIdsList(points=[p.id for p in to_delete])) + print(f"deleted {len(to_delete)} points") + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/src/modules/rag/rag.py b/src/modules/rag/rag.py index 0e65bbe..69f39b3 100644 --- a/src/modules/rag/rag.py +++ b/src/modules/rag/rag.py @@ -1,7 +1,10 @@ import json import os import traceback +import uuid +import asyncio from dataclasses import dataclass, field +from datetime import datetime from typing import Any, AsyncGenerator from pydantic import BaseModel @@ -14,7 +17,7 @@ import httpx -from qdrant_client.models import FieldCondition, Filter, MatchValue +from qdrant_client.models import FieldCondition, Filter, MatchValue, PointStruct, PointIdsList from .qdrant_utils import make_qdrant_client # Default character persona. Overridable per session via the `persona` key in the @@ -41,6 +44,15 @@ class RAGDeploymentConfig(BaseModel): verify_ssl: bool = True top_k: int = 5 score_threshold: float = 0.5 + + memory_collection: str = "conversations" + memory_top_k: int = 3 + memory_half_life_days: float = 5.0 + memory_w_relevance: float = 0.5 + memory_w_recency: float = 0.3 + memory_w_importance: float = 0.2 + memory_maintenance_days: float = 5.0 + memory_maintenance_check_hours: float = 6.0 @dataclass @@ -60,6 +72,7 @@ class RAGQuery: class RAGHandle: """Stateless RAG processor. Streams LLM tokens to the caller.""" + _MAINTENANCE_MARKER_ID = "00000000-0000-0000-0000-00000000feed" def __init__(self, **kwargs): self._cfg = RAGDeploymentConfig(**kwargs) self._apply_config() @@ -75,6 +88,10 @@ def _apply_config(self) -> None: print(f"[RAGHandle] Connected to Qdrant at {cfg.qdrant_url}") self._embed_client = httpx.AsyncClient(timeout=30.0, verify=cfg.verify_ssl) self._llm_client = httpx.AsyncClient(timeout=120.0, verify=cfg.verify_ssl) + if not hasattr(self, "_maintenance_task") or self._maintenance_task.done(): + self._maintenance_task = asyncio.get_event_loop().create_task( + self._maintenance_loop() + ) def _resolve_user_context(self, _user_id: str) -> tuple[str, dict | None]: collection = self._cfg.default_collection @@ -128,6 +145,227 @@ def _get_profile(self, collection: str, _user_id: str) -> list[str]: return [] return [p.payload.get("text", "") for p in points if p.payload.get("text")] + def _ensure_memory_collection(self, vector_size: int) -> None: + from qdrant_client.models import Distance, VectorParams + names = [c.name for c in self._qdrant.get_collections().collections] + if self._cfg.memory_collection not in names: + self._qdrant.create_collection( + collection_name=self._cfg.memory_collection, + vectors_config=VectorParams(size=vector_size, distance=Distance.COSINE), + ) + + async def _llm_complete(self, system_prompt: str, user_prompt: str, max_tokens: int = 300) -> str: + """Non-streamed convenience wrapper over _llm_stream.""" + parts = [] + async for d in self._llm_stream(system_prompt, user_prompt, {"max_length": max_tokens}, None): + parts.append(d) + return "".join(parts) + + def _memory_strength(self, payload: dict, relevance: float) -> float: + importance = payload.get("importance", 3) + half_life = max(self._cfg.memory_half_life_days * (importance / 5.0), 0.5) + try: + last = datetime.fromisoformat(payload.get("last_accessed") or payload["created_at"]) + age_days = (datetime.now() - last).total_seconds() / 86400.0 + except Exception: + age_days = 0.0 + recency = 0.5 ** (age_days / half_life) + cfg = self._cfg + return (cfg.memory_w_relevance * relevance + + cfg.memory_w_recency * recency + + cfg.memory_w_importance * (importance / 10.0)) + + def _search_memories(self, query_vector: list[float], _user_id: str) -> list[str]: + """Retrieve, re-rank (relevance+recency+importance), reinforce, return texts.""" + try: + hits = self._qdrant.query_points( + collection_name=self._cfg.memory_collection, + query=query_vector, + query_filter=Filter( + must=[FieldCondition(key="_user_id", match=MatchValue(value=_user_id))], + must_not=[FieldCondition(key="type", match=MatchValue(value="maintenance_marker"))], + ), + limit=10, + score_threshold=0.2, # permissive; real filtering is the re-rank + ).points + except Exception: + return [] # collection missing / qdrant down → just no memories + + scored = sorted(hits, key=lambda p: self._memory_strength(p.payload, p.score), reverse=True) + top = scored[: self._cfg.memory_top_k] + + # MemoryBank-style reinforcement: recalled memories decay slower. + now = datetime.now().isoformat() + for p in top: + try: + self._qdrant.set_payload( + collection_name=self._cfg.memory_collection, + payload={"last_accessed": now, + "access_count": p.payload.get("access_count", 0) + 1}, + points=[p.id], + ) + except Exception: + pass + return [p.payload.get("text", "") for p in top if p.payload.get("text")] + + async def save_conversation(self, _user_id: str, history: list) -> None: + """Summarize a finished session into one memory point. Called at disconnect.""" + if not history or len(history) < 2: + return + transcript = "\n".join(f"{m['role']}: {m['content']}" for m in history)[-6000:] + prompt = ( + "Summarize this conversation in 2-4 sentences, keeping any personal " + "facts, preferences, names, or commitments mentioned. Then rate 1-10 " + "how important it is to remember (small talk=1-2, personal facts or " + "preferences=7-10). Reply ONLY with JSON, no markdown: " + '{"summary": "...", "importance": N}\n\n' + transcript + ) + try: + raw = await self._llm_complete("You are a memory summarizer.", prompt) + cleaned = raw.strip().removeprefix("```json").removeprefix("```").removesuffix("```").strip() + data = json.loads(cleaned) + summary = str(data["summary"]) + importance = max(1, min(int(data["importance"]), 10)) + except Exception: + print(f"[RAG] memory summarization failed, storing raw tail:\n{traceback.format_exc()}") + summary, importance = transcript[-500:], 3 + + if importance <= 1: + print("[RAG] Session judged not memorable, skipping save") + return + + vector = await self._embed(summary) + self._ensure_memory_collection(len(vector)) + now = datetime.now().isoformat() + self._qdrant.upsert( + collection_name=self._cfg.memory_collection, + points=[PointStruct(id=str(uuid.uuid4()), vector=vector, payload={ + "text": summary, "_user_id": _user_id, "type": "conversation", + "created_at": now, "last_accessed": now, + "access_count": 0, "importance": importance, + })], + ) + print(f"[RAG] Saved conversation memory (importance={importance}): {summary[:80]}...") + + + def _last_maintenance(self) -> datetime | None: + try: + pts = self._qdrant.retrieve( + collection_name=self._cfg.memory_collection, + ids=[self._MAINTENANCE_MARKER_ID], with_payload=True, with_vectors=False, + ) + if pts: + return datetime.fromisoformat(pts[0].payload["last_run"]) + except Exception: + pass + return None + + def _mark_maintenance_done(self, vector_size: int) -> None: + # Marker point: zero vector, type=maintenance_marker. Filtered out of + # retrieval automatically (zero vector never scores) but be explicit anyway. + self._qdrant.upsert( + collection_name=self._cfg.memory_collection, + points=[PointStruct( + id=self._MAINTENANCE_MARKER_ID, + vector=[0.0] * vector_size, + payload={"type": "maintenance_marker", + "last_run": datetime.now().isoformat()}, + )], + ) + + async def _maintenance_loop(self) -> None: + import asyncio + check_secs = self._cfg.memory_maintenance_check_hours * 3600 + while True: + try: + last = self._last_maintenance() + due = (last is None or + (datetime.now() - last).total_seconds() + >= self._cfg.memory_maintenance_days * 86400) + if due: + print("[RAG] Running memory maintenance...") + await self._run_maintenance() + except Exception: + print(f"[RAG] maintenance loop error:\n{traceback.format_exc()}") + await asyncio.sleep(check_secs) + + async def _run_maintenance(self) -> None: + """Decay-based pruning + consolidation. Same logic as memory_maintenance.py.""" + from collections import defaultdict + DELETE_BELOW, CONSOLIDATE_BELOW = 0.05, 0.30 + + # scroll everything + points, offset = [], None + try: + while True: + batch, offset = self._qdrant.scroll( + collection_name=self._cfg.memory_collection, limit=200, + offset=offset, with_payload=True, with_vectors=False) + points.extend(batch) + if offset is None: + break + except Exception: + return # collection doesn't exist yet — nothing to do + + def base_strength(payload: dict) -> float: + # query-independent: recency * importance + imp = payload.get("importance", 3) + half = max(self._cfg.memory_half_life_days * (imp / 5.0), 0.5) + try: + last = datetime.fromisoformat(payload.get("last_accessed") or payload["created_at"]) + age = (datetime.now() - last).total_seconds() / 86400.0 + except Exception: + age = 0.0 + return (0.5 ** (age / half)) * (imp / 10.0) + + to_delete, weak_by_user = [], defaultdict(list) + vector_size = None + for p in points: + if p.payload.get("type") == "maintenance_marker": + continue + s = base_strength(p.payload) + if s < DELETE_BELOW: + to_delete.append(p) + elif s < CONSOLIDATE_BELOW: + weak_by_user[p.payload.get("_user_id", "anonymous")].append(p) + + for user, weak in weak_by_user.items(): + if len(weak) < 3: + continue + texts = [p.payload["text"] for p in weak] + merged = (await self._llm_complete( + "You are a memory consolidator.", + "These are old memories about conversations with the same person. " + "Merge them into a single 3-5 sentence memory keeping only durable " + "facts, preferences and recurring themes. Drop one-off small talk.\n\n" + + "\n---\n".join(texts))).strip() + if not merged: + continue + vec = await self._embed(merged) + vector_size = len(vec) + now = datetime.now().isoformat() + imp = min(max(p.payload.get("importance", 3) for p in weak) + 1, 10) + self._qdrant.upsert(collection_name=self._cfg.memory_collection, + points=[PointStruct(id=str(uuid.uuid4()), vector=vec, payload={ + "text": merged, "_user_id": user, + "type": "conversation_consolidated", + "created_at": now, "last_accessed": now, + "access_count": 0, "importance": imp, + })]) + to_delete.extend(weak) + print(f"[RAG] Consolidated {len(weak)} memories → 1 for user {user}") + + if to_delete: + self._qdrant.delete(collection_name=self._cfg.memory_collection, + points_selector=PointIdsList(points=[p.id for p in to_delete])) + print(f"[RAG] Deleted {len(to_delete)} decayed memories") + + if vector_size is None: + vector_size = len(await self._embed("marker")) + self._ensure_memory_collection(vector_size) + self._mark_maintenance_done(vector_size) + print("[RAG] Memory maintenance complete") + def _search( self, qdrant, @@ -168,6 +406,7 @@ def _build_prompt( chunks: list[dict], preferences: dict, profile_facts: list[str] | None = None, + memories: list[str] | None = None, ) -> tuple[str, str]: persona = preferences.get("persona", _DEFAULT_PERSONA) parts = [persona] @@ -190,19 +429,27 @@ def _build_prompt( parts.append(preferences["extra_instructions"]) parts.append( - "Use the context in the user's message to inform your answers when " - "it is relevant, but always answer in character. If you don't know " - "something, improvise in character rather than admitting you lack " - "information or breaking character. " - "IMPORTANT: Reply in 1-3 short sentences maximum. Be extremely concise. No lists, no emojis, no long explanations." - ) + "Use the context and memories in the user's message to inform your " + "answers when relevant, but always answer in character. If you have " + "relevant memories of past conversations, use them naturally. Only if " + "you genuinely know nothing relevant, improvise in character rather " + "than admitting you lack information or breaking character. " ) system_prompt = " ".join(parts) + memory_block = "" + if memories: + memory_block = ( + "Things you remember from previous conversations with this person:\n- " + + "\n- ".join(memories) + + "\n\n" + ) + if not chunks: user_prompt = ( - "No relevant context was found.\n\n" + memory_block + + "No relevant context was found.\n\n" f"Question: {question}\n\n" - "Answer based on general knowledge." + "Answer based on your memories above if relevant, otherwise general knowledge." ) else: context_parts = [] @@ -214,9 +461,10 @@ def _build_prompt( ) context_block = "\n\n".join(context_parts) user_prompt = ( - f"Context:\n{context_block}\n\n" + memory_block + + f"Context:\n{context_block}\n\n" f"Question: {question}\n\n" - "Answer based on the context above. " + "Answer based on the context and your memories above. " "Don't speak about the sources, just use them to answer." ) @@ -342,11 +590,14 @@ async def stream(self, query: RAGQuery) -> AsyncGenerator[str, None]: raise print(f"[RAG] Found {len(chunks)} chunks") + memories = self._search_memories(query_vector, query._user_id) + if memories: + print(f"[RAG] Recalled {len(memories)} memory(ies)") profile_facts = self._get_profile(collection, query._user_id) if profile_facts: print(f"[RAG] Loaded {len(profile_facts)} profile fact(s)") system_prompt, user_prompt = self._build_prompt( - query.question, chunks, query.preferences, profile_facts + query.question, chunks, query.preferences, profile_facts, memories ) print( @@ -441,3 +692,14 @@ def _record_turn(self, question: str, answer: str) -> None: def update_preferences(self, new_preferences: dict): self.preferences.update(new_preferences) + + async def finalize(self) -> None: + """Called when the session ends — persist this conversation as a memory.""" + if not self.history: + return + try: + await self._handle.save_conversation.remote( + self._user_id or "anonymous", list(self.history) + ) + except Exception: + print(f"[RAG] finalize failed:\n{traceback.format_exc()}")