"""Experiment P6: Multi-turn conversation simulation. Simulate a realistic multi-day conversation scenario: - Day 1: User discusses database issues - Day 2: User works on deployment - Day 3: User comes back with a related question → should recall Day 1 context - Day 4: User asks about something mentioned in passing on Day 1 Test: cross-session recall, context accumulation, multi-hop across days. """ import sys import time from pathlib import Path import torch import torch.nn as nn import numpy as np sys.path.insert(0, str(Path(__file__).parent.parent / "src")) sys.path.insert(0, str(Path(__file__).parent.parent)) from nuonuo.hippocampus import HippocampalMemory from llm import generate_paraphrases_heuristic DEVICE = "cuda" def load_model(): from sentence_transformers import SentenceTransformer return SentenceTransformer("all-MiniLM-L6-v2", device=DEVICE) def emb(model, text): return model.encode([text], convert_to_tensor=True, normalize_embeddings=True, device=DEVICE)[0] def store_with_augmentation(mem, model, cue, target, timestamp=0.0): """Store a memory with heuristic paraphrases.""" cue_emb = emb(model, cue) target_emb = emb(model, target) paras = generate_paraphrases_heuristic(cue, n=3) para_embs = [emb(model, p) for p in paras] if paras else None return mem.store(cue_emb, target_emb, cue_variants=para_embs, metadata={"cue": cue, "target": target}, timestamp=timestamp) def test_recall(mem, model, query, expected_target_substr): """Test if recall contains expected substring.""" results = mem.recall(emb(model, query), top_k=3) for r in results: if expected_target_substr.lower() in r.metadata.get("target", "").lower(): return True, r.similarity, r.metadata["target"] return False, 0.0, results[0].metadata.get("target", "???") if results else "no results" def main(): print("=" * 60) print("Experiment P6: Multi-turn Conversation") print("=" * 60) model = load_model() mem = HippocampalMemory(embed_dim=384) # ===== Day 1: Database troubleshooting session ===== print("\n--- Day 1: Database Troubleshooting ---") day1_memories = [ ("The database is really slow", "The users table is missing an index on created_at"), ("What's the query that's slow?", "SELECT * FROM users WHERE created_at > ? ORDER BY created_at"), ("How many rows in the users table?", "About 2.3 million rows, growing 10K per day"), ("Who has access to the database?", "Only the backend team: Alice, Bob, and Charlie"), ("What's the database host?", "PostgreSQL on db.internal:5432, running version 15.2"), ] for cue, target in day1_memories: store_with_augmentation(mem, model, cue, target, timestamp=1.0) # ===== Day 2: Deployment work ===== print("--- Day 2: Deployment ---") day2_memories = [ ("How do we deploy?", "Blue-green deployment via GitHub Actions, config in .github/workflows/deploy.yml"), ("What's the rollback procedure?", "Switch the load balancer back to the previous blue/green slot"), ("Where are the deployment logs?", "GitHub Actions logs, also mirrored to Loki at loki.internal:3100"), ("Who approves production deploys?", "Requires approval from Alice or David in the #deploys channel"), ] for cue, target in day2_memories: store_with_augmentation(mem, model, cue, target, timestamp=2.0) # ===== Day 3: Monitoring setup ===== print("--- Day 3: Monitoring ---") day3_memories = [ ("Set up monitoring for the database", "Prometheus scrapes pg_exporter on db.internal:9187, dashboard in Grafana"), ("What alerts do we have?", "PagerDuty alerts for: CPU>80%, disk>90%, replication lag>30s"), ("Where's the Grafana dashboard?", "grafana.internal/d/postgres-overview, login with SSO"), ] for cue, target in day3_memories: store_with_augmentation(mem, model, cue, target, timestamp=3.0) print(f"\nTotal memories: {mem.stats()}") # ===== Test: Cross-session recall ===== print("\n=== Cross-session Recall Tests ===\n") tests = [ # (query, expected_substring, description) # Day 1 recall ("DB is slow again", "index", "Day 1: DB slow → index"), ("How big is the users table?", "million", "Day 1: table size"), ("Who can access the database?", "Alice", "Day 1: DB access"), ("What Postgres version?", "15.2", "Day 1: PG version"), # Day 2 recall ("How to deploy the new version?", "blue-green", "Day 2: deploy method"), ("How to rollback?", "load balancer", "Day 2: rollback"), ("Who approves deploys?", "Alice", "Day 2: deploy approval"), # Day 3 recall ("Where's the monitoring dashboard?", "grafana", "Day 3: Grafana URL"), ("What alerts are configured?", "PagerDuty", "Day 3: alerts"), # Cross-day inference ("The database is slow, what index is missing?", "created_at", "Cross: DB slow → specific index"), ("I need to check deploy logs", "Loki", "Cross: deploy logs → Loki"), ("Database monitoring exporter", "pg_exporter", "Cross: DB + monitoring"), ] correct = 0 for query, expected, desc in tests: found, sim, got = test_recall(mem, model, query, expected) status = "✓" if found else "✗" if found: correct += 1 print(f" {status} [{sim:.2f}] {desc}") if not found: print(f" Expected '{expected}', got: '{got[:50]}...'") n = len(tests) print(f"\n Total: {correct}/{n} ({correct/n:.0%})") # ===== Test: Multi-hop across days ===== print("\n=== Multi-hop Across Days ===\n") # Store explicit chains across days # Day 1: "DB slow" → "missing index" # Day 3: "monitoring DB" → "pg_exporter" # Chain: "DB slow" → (hop1) "missing index" → ... can we reach monitoring? # Actually, multi-hop needs explicit chain links. Let's store some: store_with_augmentation(mem, model, "The missing index caused the slow query", "Added index and set up monitoring to prevent recurrence", timestamp=3.5) chain = mem.recall_chain(emb(model, "database is slow"), hops=3) print(" Chain from 'database is slow':") for r in chain: print(f" hop {r.hop_distance}: {r.metadata.get('target', '?')[:60]}...") # ===== Test: Memory conflicts ===== print("\n=== Memory Update / Conflict ===\n") # Store contradicting info store_with_augmentation(mem, model, "What Postgres version?", "Upgraded to PostgreSQL 16.1 last night", timestamp=4.0) # Which version does it recall? results = mem.recall(emb(model, "What Postgres version are we running?"), top_k=2) print(" Query: 'What Postgres version?'") for r in results: print(f" [{r.similarity:.2f}] {r.metadata.get('target', '?')}") print(" Note: Both old (15.2) and new (16.1) returned — recency sorting needed") if __name__ == "__main__": main()