Compare commits

...

1 Commit

Author SHA1 Message Date
Timothy d78473ff20 chore: experiment 2026-03-24 16:02:32 -07:00
5 changed files with 1691 additions and 17 deletions
+101 -1
View File
@@ -137,6 +137,32 @@ def append_episodic_entry(content: str) -> None:
with ep_path.open("a", encoding="utf-8") as f:
f.write(block)
# Immediately create a bare index entry (no enrichment — that happens at
# consolidation time). Wrapped so any indexing failure never interrupts
# the diary write.
try:
_post_append_index_hook(today.strftime("%Y-%m-%d"), timestamp, content.strip())
except Exception:
logger.warning("queen_memory: index hook failed on diary append", exc_info=True)
def _post_append_index_hook(date_str: str, timestamp: str, prose: str) -> None:
    """Insert a bare (unenriched) MemoryEntry for a just-written diary section.

    Enrichment and embedding happen later, at consolidation time; here we
    only make sure the new section is represented in the index at all.
    """
    from framework.agents.queen.queen_memory_index import (
        get_entry,
        index_entry_from_diary_section,
        load_index,
        put_entry,
        save_index,
    )

    idx = load_index()
    eid = f"{date_str}:{timestamp}"
    if get_entry(idx, eid) is not None:
        return  # already indexed — nothing to do
    put_entry(idx, index_entry_from_diary_section(date_str, timestamp, prose))
    save_index(idx)
def seed_if_missing() -> None:
"""Create MEMORY.md with a blank template if it doesn't exist yet."""
@@ -311,9 +337,10 @@ async def consolidate_queen_memory(
llm: LLMProvider instance (must support acomplete()).
"""
try:
logger.info("queen_memory: consolidation triggered for session %s", session_id)
session_context = read_session_context(session_dir)
if not session_context:
logger.debug("queen_memory: no session context, skipping consolidation")
logger.info("queen_memory: no session context found, skipping")
return
logger.info("queen_memory: consolidating memory for session %s ...", session_id)
@@ -388,6 +415,14 @@ async def consolidate_queen_memory(
len(diary_entry),
)
# Update the memory index for today's entries: enrich, embed, link,
# and optionally evolve neighbour metadata. Wrapped so failures never
# block or disrupt the main consolidation path.
try:
await _update_index_after_consolidation(today.strftime("%Y-%m-%d"), llm)
except Exception:
logger.warning("queen_memory: index update failed after consolidation", exc_info=True)
except Exception:
tb = traceback.format_exc()
logger.exception("queen_memory: consolidation failed")
@@ -401,3 +436,68 @@ async def consolidate_queen_memory(
)
except OSError:
pass # Cannot write error file; original exception already logged
async def _update_index_after_consolidation(date_str: str, llm: object) -> None:
    """Enrich, embed, link, and evolve today's memory index entries.

    Called after the main semantic/diary LLM writes complete.  The caller
    wraps this in try/except so failures never disrupt consolidation.

    Args:
        date_str: "YYYY-MM-DD" day whose diary sections should be indexed.
        llm: consolidation LLM provider (must support acomplete()).
    """
    from framework.agents.queen.queen_memory_index import (
        embed_text,
        embeddings_enabled,
        get_embed_model,
        link_entry,
        load_index,
        maybe_evolve_neighbors,
        put_entry,
        rebuild_index_for_date,
        save_index,
    )

    # Phase 1 — ensure all diary sections are in the index and enriched
    await rebuild_index_for_date(date_str, llm=llm)
    if not embeddings_enabled():
        logger.debug("queen_memory: embeddings not configured, skipping embed/link/evolve")
        return  # Phases 2-5 require embeddings
    logger.info("queen_memory: running embed/link/evolve for %s", date_str)
    # Phases 2-5 — embed, link, evolve any entries still missing vectors
    index = load_index()
    entries = index.get("entries", {})
    newly_embedded: list[str] = []
    for entry_id, raw in entries.items():
        # Ids are "YYYY-MM-DD:HH:MM", so a prefix test selects today's entries.
        if not entry_id.startswith(date_str):
            continue
        if raw.get("embedding") is not None:
            continue  # already embedded on a previous run
        prose = raw.get("summary", "")
        if not prose:
            continue
        vec = await embed_text(prose)
        if vec is not None:
            raw["embedding"] = vec
            # Record which model/dimension produced the stored vectors.
            index["embed_model"] = get_embed_model()
            index["embed_dim"] = len(vec)
            newly_embedded.append(entry_id)
    if newly_embedded:
        # Persist vectors before linking so a crash mid-linking loses nothing.
        save_index(index)
    # Phase 3 — cross-reference linking for newly embedded entries
    for entry_id in newly_embedded:
        linked = link_entry(index, entry_id)
        # Phase 5 — memory evolution for top neighbours
        if linked:
            await maybe_evolve_neighbors(entry_id, linked, index, llm)
    if newly_embedded:
        # Second save captures the related[] links / evolved metadata.
        save_index(index)
        logger.debug(
            "queen_memory: indexed %d new embedding(s) for %s",
            len(newly_embedded),
            date_str,
        )
@@ -0,0 +1,788 @@
"""Structured index for queen episodic memory entries.
Attaches rich metadata, embedding vectors, cross-reference links, and
retrieval counts to every diary entry. The index lives at:
~/.hive/queen/memories/index.json
It is a *sidecar* to the existing markdown diary files — those files are
never modified by this module.
Configuration
-------------
Set ``HIVE_EMBED_MODEL`` to an embedding model name supported by litellm
(e.g. ``text-embedding-3-small``) to enable semantic search. When unset
the system degrades gracefully: enrichment (keywords/tags/category) still
works via the consolidation LLM, and recall_diary falls back to substring
matching.
Phases implemented
------------------
Phase 1 - Index I/O + semantic enrichment (keywords, category, tags)
Phase 2 - Embedding storage + semantic search via cosine similarity
Phase 3 - Cross-reference linking (bidirectional related[] links)
Phase 4 - Importance tracking (retrieval counts + recency decay)
Phase 5 - Memory evolution (LLM-driven neighbour metadata refinement)
"""
from __future__ import annotations
import json
import logging
import math
import re
from dataclasses import asdict, dataclass, field
from datetime import date, datetime, timedelta
from pathlib import Path
from typing import Any
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Category vocabulary — fixed to prevent unbounded drift
# ---------------------------------------------------------------------------

# Closed set of entry categories. Enrichment and evolution results outside
# this list are coerced to "other" (see enrich_entry / maybe_evolve_neighbors).
_CATEGORIES = [
    "agent_build",
    "infrastructure",
    "user_preference",
    "communication_style",
    "diagnostic_learning",
    "milestone",
    "pipeline",
    "data_processing",
    "other",
]
# ---------------------------------------------------------------------------
# MemoryEntry dataclass
# ---------------------------------------------------------------------------
@dataclass
class MemoryEntry:
    """Rich metadata record for a single diary section (one ### HH:MM block).

    Stored serialised in index.json keyed by ``id``; the full prose stays in
    the markdown diary file (see resolve_prose).
    """

    # Identity — "YYYY-MM-DD:HH:MM" matches the diary ### timestamp
    id: str
    date: str  # "YYYY-MM-DD"
    timestamp: str  # "HH:MM"
    # Content preview (not full prose — just enough for search result context)
    summary: str  # first 300 chars of the section's prose
    # Phase 1 — semantic enrichment
    keywords: list[str] = field(default_factory=list)
    category: str = "other"  # constrained to _CATEGORIES; unknowns coerce to "other"
    tags: list[str] = field(default_factory=list)
    # Phase 3 — cross-reference links (bidirectional entry ids)
    related: list[str] = field(default_factory=list)
    # Phase 4 — importance tracking
    retrieval_count: int = 0
    last_retrieved: str | None = None  # ISO-format datetime string
    # Phase 2 — embedding vector (None when no embedding model is configured)
    embedding: list[float] | None = None
    # Whether enrichment has been applied (used to skip re-enrichment)
    enriched: bool = False
# ---------------------------------------------------------------------------
# Index I/O
# ---------------------------------------------------------------------------
# Template for a brand-new index. load_index() returns a shallow copy of this
# (with a fresh "entries" dict) when the file is missing or unreadable.
_EMPTY_INDEX: dict[str, Any] = {
    "version": 1,  # schema version for future migrations
    "embed_model": None,  # model that produced the stored vectors
    "embed_dim": None,  # dimensionality of those vectors
    "entries": {},  # entry_id -> serialised MemoryEntry
}
def _queen_memories_dir() -> Path:
return Path.home() / ".hive" / "queen" / "memories"
def index_path() -> Path:
    """Full path of the sidecar index file (index.json)."""
    return _queen_memories_dir().joinpath("index.json")
def load_index() -> dict[str, Any]:
    """Load the index from disk; any read/parse error yields a fresh empty index."""
    location = index_path()
    if not location.exists():
        return {**_EMPTY_INDEX, "entries": {}}
    try:
        data = json.loads(location.read_text(encoding="utf-8"))
        if not isinstance(data, dict) or "entries" not in data:
            raise ValueError("Malformed index")
    except Exception as exc:
        logger.warning("queen_memory_index: index.json unreadable (%s), starting fresh", exc)
        return {**_EMPTY_INDEX, "entries": {}}
    return data
def save_index(index: dict[str, Any]) -> None:
    """Atomically persist the index: write a temp file, then rename over the target."""
    target = index_path()
    target.parent.mkdir(parents=True, exist_ok=True)
    payload = json.dumps(index, ensure_ascii=False)
    scratch = target.with_suffix(".json.tmp")
    scratch.write_text(payload, encoding="utf-8")
    scratch.replace(target)
def get_entry(index: dict[str, Any], entry_id: str) -> MemoryEntry | None:
    """Deserialise one entry from the index dict, or None if missing/corrupt."""
    raw = index.get("entries", {}).get(entry_id)
    if raw is None:
        return None
    known = MemoryEntry.__dataclass_fields__
    try:
        return MemoryEntry(**{name: value for name, value in raw.items() if name in known})
    except Exception as exc:
        logger.warning("queen_memory_index: failed to deserialise entry %s: %s", entry_id, exc)
        return None
def put_entry(index: dict[str, Any], entry: MemoryEntry) -> None:
    """Serialise *entry* into the index's entries dict, mutating it in place."""
    bucket = index.setdefault("entries", {})
    bucket[entry.id] = asdict(entry)
# ---------------------------------------------------------------------------
# Configuration helpers
# ---------------------------------------------------------------------------
def get_embed_model() -> str | None:
    """Return the configured embedding model (e.g. 'openai/text-embedding-3-small').

    Reads the ``embedding`` section of ~/.hive/configuration.json, with the
    ``HIVE_EMBED_MODEL`` env var as a backward-compatible fallback.
    """
    from framework.config import get_embed_model as _resolve

    return _resolve()
def embeddings_enabled() -> bool:
    """True when an embedding model is configured (semantic features available)."""
    model = get_embed_model()
    return model is not None and model != ""
def _detect_model_change(index: dict[str, Any]) -> bool:
    """True if the index's recorded embed model differs from the configured one."""
    return index.get("embed_model") != get_embed_model()
def _clear_embeddings(index: dict[str, Any]) -> None:
    """Drop every cached vector after an embedding-model change (mutates in place)."""
    for record in index.get("entries", {}).values():
        record["embedding"] = None
    index["embed_model"] = get_embed_model()
    index["embed_dim"] = None
    logger.info("queen_memory_index: embedding model changed — cleared cached vectors")
# ---------------------------------------------------------------------------
# Embedding calls (Phase 2)
# ---------------------------------------------------------------------------
def _embed_kwargs() -> dict[str, Any]:
    """Assemble optional api_key/api_base kwargs for litellm.aembedding()."""
    from framework.config import get_embed_api_base, get_embed_api_key

    extra: dict[str, Any] = {}
    key = get_embed_api_key()
    base = get_embed_api_base()
    if key:
        extra["api_key"] = key
    if base:
        extra["api_base"] = base
    return extra
async def embed_text(text: str) -> list[float] | None:
    """Embed *text* via litellm.aembedding().

    Returns None (with a WARNING log) on any failure or when no embedding
    model is configured.
    """
    model = get_embed_model()
    if not model:
        return None
    try:
        import litellm  # already a project dependency

        logger.info("queen_memory_index: embedding text (%d chars) via %s", len(text), model)
        response = await litellm.aembedding(model=model, input=[text], **_embed_kwargs())
        vector: list[float] = response.data[0]["embedding"]
        logger.info("queen_memory_index: embedding complete (dim=%d)", len(vector))
    except Exception as exc:
        logger.warning("queen_memory_index: embed_text failed (%s)", exc)
        return None
    return vector
async def embed_batch(texts: list[str]) -> list[list[float] | None]:
    """Embed several texts at once; on batch failure, retry one at a time.

    The returned list is parallel to *texts*; individual failures (or a
    missing model) yield None in the corresponding slot.
    """
    model = get_embed_model()
    if not model:
        return [None for _ in texts]
    try:
        import litellm

        logger.info(
            "queen_memory_index: batch embedding %d text(s) via %s", len(texts), model
        )
        resp = await litellm.aembedding(model=model, input=texts, **_embed_kwargs())
        vecs = [item["embedding"] for item in resp.data]
        logger.info(
            "queen_memory_index: batch embedding complete (dim=%d)", len(vecs[0]) if vecs else 0
        )
        return vecs
    except Exception as exc:
        logger.warning("queen_memory_index: embed_batch failed (%s), retrying individually", exc)
        # Fall back to one embed_text call per input.
        return [await embed_text(piece) for piece in texts]
# ---------------------------------------------------------------------------
# Vector math (Phase 2)
# ---------------------------------------------------------------------------
def cosine_similarity(a: list[float] | None, b: list[float] | None) -> float:
    """Return cosine similarity in [-1, 1]. Returns 0.0 on null or zero-norm inputs.

    (The previous docstring claimed [0, 1], but anti-parallel vectors yield -1
    — the unit test for opposite vectors relies on that.)  Any exception
    (e.g. ragged input) is swallowed and reported as 0.0.
    """
    if not a or not b:
        return 0.0
    try:
        import numpy as np  # already a project dependency

        va = np.array(a, dtype=np.float32)
        vb = np.array(b, dtype=np.float32)
        norm_a = float(np.linalg.norm(va))
        norm_b = float(np.linalg.norm(vb))
        # Zero-norm vectors have no direction; treat as "no similarity".
        if norm_a == 0.0 or norm_b == 0.0:
            return 0.0
        return float(np.dot(va, vb) / (norm_a * norm_b))
    except Exception:
        return 0.0
def find_knn(
    query_vec: list[float],
    index: dict[str, Any],
    k: int = 5,
    exclude_id: str | None = None,
) -> list[tuple[str, float]]:
    """Return up to *k* nearest neighbours as (entry_id, similarity) pairs, descending.

    Entries without a stored embedding (and the excluded id) are skipped.
    """
    scored = [
        (eid, cosine_similarity(query_vec, raw.get("embedding")))
        for eid, raw in index.get("entries", {}).items()
        if eid != exclude_id and raw.get("embedding")
    ]
    scored.sort(key=lambda pair: pair[1], reverse=True)
    return scored[:k]
# ---------------------------------------------------------------------------
# Semantic search (Phase 2)
# ---------------------------------------------------------------------------
async def semantic_search(
    query: str,
    index: dict[str, Any],
    *,
    k: int = 20,
    date_range: tuple[str, str] | None = None,
) -> list[tuple[str, float]]:
    """Embed *query* and rank index entries by cosine similarity; return top-k.

    Returns [] if embeddings are disabled or the embed call fails.
    date_range is an inclusive (YYYY-MM-DD, YYYY-MM-DD) filter applied
    before ranking.
    """
    if not embeddings_enabled():
        return []
    qvec = await embed_text(query)
    if qvec is None:
        return []

    def _within_range(raw: dict[str, Any]) -> bool:
        # ISO dates compare correctly as strings.
        if date_range is None:
            return True
        d = raw.get("date", "")
        return date_range[0] <= d <= date_range[1]

    ranked = [
        (eid, cosine_similarity(qvec, raw.get("embedding")))
        for eid, raw in index.get("entries", {}).items()
        if _within_range(raw) and raw.get("embedding")
    ]
    ranked.sort(key=lambda pair: pair[1], reverse=True)
    return ranked[:k]
# ---------------------------------------------------------------------------
# Importance tracking (Phase 4)
# ---------------------------------------------------------------------------
def importance_score(entry: MemoryEntry, now: datetime | None = None) -> float:
    """Composite importance: log1p(retrieval_count) * exponential recency decay.

    Decay uses a 30-day e-folding time. Returns 0.0 for entries never
    retrieved, and for entries with a missing or unparseable
    last_retrieved timestamp.
    """
    if entry.retrieval_count == 0:
        return 0.0
    if not entry.last_retrieved:
        return 0.0
    try:
        last_seen = datetime.fromisoformat(entry.last_retrieved)
    except ValueError:
        return 0.0
    reference = now if now is not None else datetime.now()
    days_since = (reference - last_seen).total_seconds() / 86400
    return math.log1p(entry.retrieval_count) * math.exp(-days_since / 30)
def record_retrieval(
    index: dict[str, Any],
    entry_ids: list[str],
    *,
    auto_save: bool = True,
) -> None:
    """Bump retrieval_count and refresh last_retrieved for each given entry id.

    Ids absent from the index are ignored. When auto_save is True the index
    is persisted afterwards; a failed save is logged and swallowed.
    """
    stamp = datetime.now().isoformat()
    known = index.get("entries", {})
    for entry_id in entry_ids:
        record = known.get(entry_id)
        if record is None:
            continue
        record["retrieval_count"] = record.get("retrieval_count", 0) + 1
        record["last_retrieved"] = stamp
    if not auto_save:
        return
    try:
        save_index(index)
    except Exception as exc:
        logger.warning("queen_memory_index: failed to save index after retrieval: %s", exc)
# ---------------------------------------------------------------------------
# Hybrid re-ranking (Phase 4)
# ---------------------------------------------------------------------------
def hybrid_search(
    query: str,
    index: dict[str, Any],
    candidate_ids: list[str],
    semantic_scores: dict[str, float],
    *,
    keyword_weight: float = 0.3,
    semantic_weight: float = 0.7,
) -> list[tuple[str, float]]:
    """Re-rank candidates combining semantic cosine, keyword overlap, and importance.

    Combined score = semantic_weight * cosine
                   + keyword_weight * keyword_overlap
                   + 0.1 * normalised_importance

    keyword_overlap = |query_terms ∩ entry.keywords| / max(1, |entry.keywords|)
    normalised_importance is scaled to [0, 1] relative to the highest importance
    in the candidate set.

    Args:
        query: free-text query, tokenised into lowercase word terms.
        index: full index dict; candidates are looked up by id.
        candidate_ids: ids to re-rank; ids missing from the index are dropped.
        semantic_scores: entry_id -> cosine score (missing ids score 0.0).
        keyword_weight / semantic_weight: linear blend weights.

    Returns:
        (entry_id, combined_score) pairs sorted descending by score.
    """
    query_terms = set(re.findall(r"\w+", query.lower()))
    now = datetime.now()
    # Fix: was mis-annotated list[tuple[str, float]] although 4-tuples are stored.
    raw_scores: list[tuple[str, float, float, float]] = []
    imp_values: list[float] = []
    for eid in candidate_ids:
        entry = get_entry(index, eid)
        if entry is None:
            continue
        sem = semantic_scores.get(eid, 0.0)
        kw_list = [k.lower() for k in entry.keywords]
        overlap = len(query_terms & set(kw_list)) / max(1, len(kw_list))
        imp = importance_score(entry, now)
        imp_values.append(imp)
        raw_scores.append((eid, sem, overlap, imp))
    # Normalise importance to [0, 1]; guard against empty/all-zero candidate sets.
    max_imp = max(imp_values) if imp_values else 1.0
    if max_imp == 0.0:
        max_imp = 1.0
    ranked: list[tuple[str, float]] = []
    for eid, sem, overlap, imp in raw_scores:
        score = (
            semantic_weight * sem
            + keyword_weight * overlap
            + 0.1 * (imp / max_imp)
        )
        ranked.append((eid, score))
    ranked.sort(key=lambda x: x[1], reverse=True)
    return ranked
# ---------------------------------------------------------------------------
# Cross-reference linking (Phase 3)
# ---------------------------------------------------------------------------
def link_entry(
    index: dict[str, Any],
    entry_id: str,
    similarity_threshold: float = 0.85,
) -> list[str]:
    """Add bidirectional related[] links between *entry_id* and close neighbours.

    Neighbours come from a k-NN search (k=10) filtered by the similarity
    threshold. Mutates the index in place; returns the neighbour ids that
    cleared the threshold (possibly empty).
    """
    entries = index.get("entries", {})
    record = entries.get(entry_id)
    if not record or not record.get("embedding"):
        return []
    linked: list[str] = []
    for neighbour_id, similarity in find_knn(record["embedding"], index, k=10, exclude_id=entry_id):
        if similarity < similarity_threshold:
            break  # results are sorted descending — nothing further can qualify
        linked.append(neighbour_id)
        # Forward link on this entry.
        forward = record.setdefault("related", [])
        if neighbour_id not in forward:
            forward.append(neighbour_id)
        # Backward link on the neighbour.
        other = entries.get(neighbour_id)
        if other is not None:
            backward = other.setdefault("related", [])
            if entry_id not in backward:
                backward.append(entry_id)
    return linked
# ---------------------------------------------------------------------------
# Prompt constants for LLM calls
# ---------------------------------------------------------------------------
_ENRICHMENT_SYSTEM = """\
Analyse the following diary entry from an AI assistant's episodic memory.
Extract structured metadata and return it as a JSON object with exactly these keys:
"keywords": list of 5-8 important terms (nouns, verbs, proper names)
"category": exactly one string from this list: agent_build, infrastructure,
user_preference, communication_style, diagnostic_learning, milestone,
pipeline, data_processing, other
"tags": list of 3-5 freeform topic labels (short phrases)
Return ONLY the JSON object. No explanation, no code fences.
"""
_EVOLUTION_SYSTEM = """\
You are refining the metadata of an older memory entry based on a newly discovered
related memory entry.
Given the TWO entries below, decide if the OLDER entry's tags or category should be
updated to better reflect the thematic connection.
Rules:
- Only suggest changes if the connection reveals a clearly missing tag or a category
correction. When in doubt, return {}.
- You may only modify "tags" and "category" never the prose, never keywords.
- Return a JSON object with only the keys you are changing: {"tags": [...], "category": "..."}
or {} if no change is warranted.
Return ONLY the JSON object. No explanation, no code fences.
"""
# ---------------------------------------------------------------------------
# Phase 1 — enrichment helpers
# ---------------------------------------------------------------------------
def _parse_diary_sections(content: str) -> list[tuple[str, str]]:
"""Return (timestamp, prose) pairs from a diary file's ### HH:MM blocks.
The date heading (# ...) is stripped. Non-timestamped content before the
first ### block is ignored.
"""
sections: list[tuple[str, str]] = []
# Split on ### HH:MM markers
parts = re.split(r"###\s*(\d{2}:\d{2})\b", content)
# parts = [pre_text, ts1, prose1, ts2, prose2, ...]
i = 1
while i + 1 < len(parts):
ts = parts[i].strip()
prose = parts[i + 1].strip()
if prose:
sections.append((ts, prose))
i += 2
return sections
def index_entry_from_diary_section(
    date_str: str,
    timestamp: str,
    prose: str,
) -> MemoryEntry:
    """Build a bare MemoryEntry (no enrichment, no embedding) for a diary section.

    The summary is a single-line preview: the first 300 characters of the
    prose with newlines flattened to spaces.
    """
    preview = prose[:300].replace("\n", " ")
    return MemoryEntry(
        id=f"{date_str}:{timestamp}",
        date=date_str,
        timestamp=timestamp,
        summary=preview,
    )
async def enrich_entry(
    entry_text: str,
    llm: object,
) -> tuple[list[str], str, list[str]]:
    """Ask the consolidation LLM for keywords, category, and tags of one entry.

    Returns ([], "other", []) on any failure so the caller can continue.
    """
    try:
        resp = await llm.acomplete(
            messages=[{"role": "user", "content": entry_text}],
            system=_ENRICHMENT_SYSTEM,
            max_tokens=256,
            json_mode=True,
        )
        parsed = json.loads(resp.content)
        keywords = [str(kw) for kw in parsed.get("keywords", [])][:8]
        candidate = str(parsed.get("category", "other"))
        tags = [str(tag) for tag in parsed.get("tags", [])][:5]
        # Coerce anything outside the closed vocabulary back to "other".
        return keywords, (candidate if candidate in _CATEGORIES else "other"), tags
    except Exception as exc:
        logger.warning("queen_memory_index: enrich_entry failed (%s)", exc)
        return [], "other", []
# ---------------------------------------------------------------------------
# Phase 5 — memory evolution
# ---------------------------------------------------------------------------
async def maybe_evolve_neighbors(
    new_entry_id: str,
    neighbor_ids: list[str],
    index: dict[str, Any],
    llm: object,
    *,
    max_neighbors_to_evolve: int = 2,
) -> None:
    """Potentially refine the tags/category of neighbour entries.

    Only mutates metadata (tags, category) — never prose, never embeddings.
    Failures are logged and silently skipped.

    Args:
        new_entry_id: id of the freshly indexed entry driving the evolution.
        neighbor_ids: related entry ids (strongest first, per link_entry).
        index: full index dict; neighbour records are mutated in place.
        llm: provider supporting acomplete(); one call per evolved neighbour.
        max_neighbors_to_evolve: cap on LLM calls per invocation.
    """
    if not neighbor_ids:
        return
    new_raw = index.get("entries", {}).get(new_entry_id)
    if not new_raw:
        return
    for nid in neighbor_ids[:max_neighbors_to_evolve]:
        neighbor_raw = index.get("entries", {}).get(nid)
        if not neighbor_raw:
            continue
        try:
            # Show the LLM both entries so it can judge the thematic link.
            prompt = (
                f"NEWER entry ({new_entry_id}):\n"
                f"Summary: {new_raw.get('summary', '')}\n"
                f"Keywords: {', '.join(new_raw.get('keywords', []))}\n"
                f"Tags: {', '.join(new_raw.get('tags', []))}\n\n"
                f"OLDER entry ({nid}):\n"
                f"Summary: {neighbor_raw.get('summary', '')}\n"
                f"Keywords: {', '.join(neighbor_raw.get('keywords', []))}\n"
                f"Tags: {', '.join(neighbor_raw.get('tags', []))}\n"
                f"Category: {neighbor_raw.get('category', 'other')}"
            )
            resp = await llm.acomplete(
                messages=[{"role": "user", "content": prompt}],
                system=_EVOLUTION_SYSTEM,
                max_tokens=128,
                json_mode=True,
            )
            updates = json.loads(resp.content)
            if not updates:
                continue  # {} means the LLM judged no change warranted
            if "tags" in updates and isinstance(updates["tags"], list):
                # Cap at 5 tags, mirroring enrich_entry.
                neighbor_raw["tags"] = [str(t) for t in updates["tags"]][:5]
            if "category" in updates:
                raw_cat = str(updates["category"])
                # Coerce unknown categories back to "other" (closed vocabulary).
                neighbor_raw["category"] = raw_cat if raw_cat in _CATEGORIES else "other"
            logger.debug("queen_memory_index: evolved metadata for entry %s", nid)
        except Exception as exc:
            logger.warning("queen_memory_index: evolution failed for %s: %s", nid, exc)
# ---------------------------------------------------------------------------
# Index rebuild / backfill
# ---------------------------------------------------------------------------
async def rebuild_index_for_date(
    date_str: str,
    llm: object | None = None,
) -> int:
    """Parse one day's diary file and index any sections not yet in the index.

    Optionally enriches un-enriched entries when *llm* is provided, and
    embeds entries still missing a vector when an embedding model is
    configured.

    Args:
        date_str: "YYYY-MM-DD" of the diary file to (re)index.
        llm: optional provider supporting acomplete() for enrichment.

    Returns:
        The count of new entries added (pre-existing entries may still get
        enriched/embedded but do not count as new).
    """
    from framework.agents.queen.queen_memory import episodic_memory_path
    from datetime import date as _date

    try:
        year, month, day = map(int, date_str.split("-"))
        d = _date(year, month, day)
    except ValueError:
        logger.warning("queen_memory_index: invalid date_str %r", date_str)
        return 0
    ep_path = episodic_memory_path(d)
    if not ep_path.exists():
        return 0
    content = ep_path.read_text(encoding="utf-8")
    sections = _parse_diary_sections(content)
    if not sections:
        return 0
    index = load_index()
    # Detect embedding model change and clear stale vectors
    if embeddings_enabled() and _detect_model_change(index):
        _clear_embeddings(index)
    added = 0
    for ts, prose in sections:
        entry_id = f"{date_str}:{ts}"
        existing = get_entry(index, entry_id)
        if existing is None:
            entry = index_entry_from_diary_section(date_str, ts, prose)
        elif existing.enriched:
            # Already fully processed; update embedding only if missing
            entry = existing
        else:
            entry = existing
        # Enrich if LLM provided and not yet enriched
        if llm is not None and not entry.enriched:
            keywords, category, tags = await enrich_entry(prose, llm)
            entry.keywords = keywords
            entry.category = category
            entry.tags = tags
            entry.enriched = True
        # Embed if model is configured and vector is missing
        if embeddings_enabled() and entry.embedding is None:
            vec = await embed_text(prose[:1500])  # cap input length
            if vec is not None:
                entry.embedding = vec
                index["embed_model"] = get_embed_model()
                index["embed_dim"] = len(vec)
        put_entry(index, entry)
        if existing is None:
            added += 1
    save_index(index)
    logger.debug(
        "queen_memory_index: indexed %d section(s) for %s, %d new", len(sections), date_str, added
    )
    return added
async def backfill_index(
    llm: object | None = None,
    embed: bool = True,
) -> dict[str, int]:
    """Walk all MEMORY-YYYY-MM-DD.md files and index unindexed entries.

    One-shot utility: call it once after initial deployment to catch up
    historical diary files. Not called automatically.

    Usage:
        uv run python -c "
        import asyncio
        from framework.agents.queen.queen_memory_index import backfill_index
        print(asyncio.run(backfill_index()))
        "

    Returns:
        {"dates_processed": N, "entries_added": M}
    """
    root = _queen_memories_dir()
    if not root.exists():
        return {"dates_processed": 0, "entries_added": 0}
    entries_added = 0
    dates_processed = 0
    date_shape = re.compile(r"\d{4}-\d{2}-\d{2}")
    for diary_file in sorted(root.glob("MEMORY-????-??-??.md")):
        stamp = diary_file.stem.removeprefix("MEMORY-")
        # Glob wildcards match any character, so re-validate the digits.
        if not date_shape.fullmatch(stamp):
            continue
        entries_added += await rebuild_index_for_date(stamp, llm=llm)
        dates_processed += 1
    logger.info(
        "queen_memory_index: backfill complete — %d dates, %d entries added",
        dates_processed,
        entries_added,
    )
    return {"dates_processed": dates_processed, "entries_added": entries_added}
# ---------------------------------------------------------------------------
# Resolve full prose from diary file by entry_id
# ---------------------------------------------------------------------------
def resolve_prose(entry_id: str) -> str:
    """Read the source diary file and return the full prose for *entry_id*.

    Args:
        entry_id: "YYYY-MM-DD:HH:MM" identifier of a diary section.

    Returns:
        The section's full prose, or "" when the id is malformed, the diary
        file is missing, or no section with that timestamp exists.
        (The previous docstring claimed a fallback to the index summary,
        but no such fallback exists — callers wanting one must consult the
        index themselves.)
    """
    from framework.agents.queen.queen_memory import episodic_memory_path
    from datetime import date as _date

    try:
        date_str, ts = entry_id.split(":", 1)
        year, month, day = map(int, date_str.split("-"))
        d = _date(year, month, day)
    except ValueError:
        return ""  # malformed id or non-numeric/impossible date
    ep_path = episodic_memory_path(d)
    if not ep_path.exists():
        return ""
    content = ep_path.read_text(encoding="utf-8")
    for section_ts, prose in _parse_diary_sections(content):
        if section_ts == ts:
            return prose
    return ""
@@ -0,0 +1,656 @@
"""Unit tests for queen_memory_index.py.
All tests run without HIVE_EMBED_MODEL set. Embedding behaviour is tested
via a lightweight mock that injects deterministic fixed vectors.
"""
from __future__ import annotations
import json
import math
from dataclasses import asdict
from datetime import datetime
from pathlib import Path
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
import pytest_asyncio
from framework.agents.queen.queen_memory_index import (
MemoryEntry,
_CATEGORIES,
_parse_diary_sections,
backfill_index,
cosine_similarity,
embed_text,
embeddings_enabled,
enrich_entry,
find_knn,
get_embed_model,
get_entry,
hybrid_search,
importance_score,
index_entry_from_diary_section,
index_path,
link_entry,
load_index,
maybe_evolve_neighbors,
put_entry,
rebuild_index_for_date,
record_retrieval,
resolve_prose,
save_index,
semantic_search,
)
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _make_index(*entries: MemoryEntry) -> dict:
    """Build a fresh in-memory index dict pre-populated with *entries*."""
    fresh = {"version": 1, "embed_model": None, "embed_dim": None, "entries": {}}
    for item in entries:
        put_entry(fresh, item)
    return fresh
def _entry(
    date_str: str = "2026-03-01",
    ts: str = "10:00",
    summary: str = "test summary",
    keywords: list[str] | None = None,
    tags: list[str] | None = None,
    category: str = "other",
    embedding: list[float] | None = None,
    retrieval_count: int = 0,
    last_retrieved: str | None = None,
    related: list[str] | None = None,
) -> MemoryEntry:
    """Factory for a MemoryEntry with sensible test defaults.

    None for the list parameters becomes a fresh empty list; a provided list
    is used as-is (not copied), so tests may mutate it through the entry.
    """
    return MemoryEntry(
        id=f"{date_str}:{ts}",
        date=date_str,
        timestamp=ts,
        summary=summary,
        keywords=keywords or [],
        tags=tags or [],
        category=category,
        embedding=embedding,
        retrieval_count=retrieval_count,
        last_retrieved=last_retrieved,
        related=related or [],
    )
# ---------------------------------------------------------------------------
# cosine_similarity
# ---------------------------------------------------------------------------
class TestCosineSimilarity:
    """cosine_similarity: exact values, degenerate inputs, and sign behaviour."""

    def test_identical_vectors(self):
        v = [1.0, 0.0, 0.0]
        assert cosine_similarity(v, v) == pytest.approx(1.0)

    def test_orthogonal_vectors(self):
        assert cosine_similarity([1.0, 0.0], [0.0, 1.0]) == pytest.approx(0.0)

    def test_opposite_vectors(self):
        # cosine of 180° is -1: the function's range is [-1, 1], not [0, 1].
        result = cosine_similarity([1.0, 0.0], [-1.0, 0.0])
        assert result == pytest.approx(-1.0)

    def test_none_inputs(self):
        assert cosine_similarity(None, [1.0]) == 0.0
        assert cosine_similarity([1.0], None) == 0.0
        assert cosine_similarity(None, None) == 0.0

    def test_zero_vector(self):
        # Zero-norm vectors have no direction, so similarity collapses to 0.
        assert cosine_similarity([0.0, 0.0], [1.0, 0.0]) == 0.0

    def test_known_similarity(self):
        # [1, 1] vs [1, 0] → cos(45°) ≈ 0.707
        result = cosine_similarity([1.0, 1.0], [1.0, 0.0])
        assert result == pytest.approx(math.sqrt(2) / 2, abs=1e-4)
# ---------------------------------------------------------------------------
# find_knn
# ---------------------------------------------------------------------------
class TestFindKnn:
    """find_knn: ordering, self-exclusion, null-embedding skipping, and k."""

    def test_returns_sorted_descending(self):
        idx = _make_index(
            _entry("2026-03-01", "09:00", embedding=[1.0, 0.0]),
            _entry("2026-03-01", "10:00", embedding=[0.9, 0.1]),
            _entry("2026-03-01", "11:00", embedding=[0.0, 1.0]),
        )
        hits = find_knn([1.0, 0.0], idx, k=3)
        assert hits[0][0] == "2026-03-01:09:00"  # exact match ranks first
        assert hits[0][1] == pytest.approx(1.0)
        sims = [score for _, score in hits]
        assert sims == sorted(sims, reverse=True)

    def test_excludes_self(self):
        idx = _make_index(_entry("2026-03-01", "09:00", embedding=[1.0, 0.0]))
        assert find_knn([1.0, 0.0], idx, k=5, exclude_id="2026-03-01:09:00") == []

    def test_skips_null_embeddings(self):
        idx = _make_index(
            _entry("2026-03-01", "09:00", embedding=None),
            _entry("2026-03-01", "10:00", embedding=[1.0, 0.0]),
        )
        found = [eid for eid, _ in find_knn([1.0, 0.0], idx, k=5)]
        assert "2026-03-01:09:00" not in found
        assert "2026-03-01:10:00" in found

    def test_respects_k(self):
        samples = [
            _entry("2026-03-01", f"0{i}:00", embedding=[float(i), 0.0]) for i in range(5)
        ]
        idx = _make_index(*samples)
        assert len(find_knn([1.0, 0.0], idx, k=2)) <= 2
# ---------------------------------------------------------------------------
# load_index / save_index (round-trip and atomic write)
# ---------------------------------------------------------------------------
class TestIndexIO:
    """load_index / save_index: round-trip, error recovery, atomic write."""

    @staticmethod
    def _redirect(monkeypatch, tmp_path):
        # Point the module's storage directory at the per-test tmp dir.
        monkeypatch.setattr(
            "framework.agents.queen.queen_memory_index._queen_memories_dir",
            lambda: tmp_path,
        )

    def test_round_trip(self, tmp_path, monkeypatch):
        self._redirect(monkeypatch, tmp_path)
        idx = _make_index(_entry())
        idx["embed_model"] = "test-model"
        save_index(idx)
        reloaded = load_index()
        assert reloaded["embed_model"] == "test-model"
        assert "2026-03-01:10:00" in reloaded["entries"]

    def test_missing_file_returns_empty(self, tmp_path, monkeypatch):
        self._redirect(monkeypatch, tmp_path)
        fresh = load_index()
        assert fresh["entries"] == {}
        assert fresh["version"] == 1

    def test_corrupt_file_returns_empty(self, tmp_path, monkeypatch):
        self._redirect(monkeypatch, tmp_path)
        (tmp_path / "index.json").write_text("not json at all", encoding="utf-8")
        assert load_index()["entries"] == {}

    def test_atomic_write_uses_tmp_then_rename(self, tmp_path, monkeypatch):
        self._redirect(monkeypatch, tmp_path)
        save_index(_make_index())
        # tmp file should be gone after rename
        assert not (tmp_path / "index.json.tmp").exists()
        assert (tmp_path / "index.json").exists()
# ---------------------------------------------------------------------------
# get_entry / put_entry
# ---------------------------------------------------------------------------
class TestGetPutEntry:
    """Basic get/put CRUD on individual index entries."""

    def test_put_and_get_roundtrip(self):
        index = _make_index()
        stored = _entry(keywords=["foo", "bar"], tags=["t1"], category="milestone")
        put_entry(index, stored)
        fetched = get_entry(index, stored.id)
        assert fetched is not None
        assert fetched.keywords == ["foo", "bar"]
        assert fetched.category == "milestone"

    def test_get_missing_returns_none(self):
        assert get_entry(_make_index(), "no-such-id") is None

    def test_put_overwrites_existing(self):
        first = _entry(summary="original")
        index = _make_index(first)
        # Same id (default date/timestamp), new summary — must replace.
        put_entry(index, _entry(summary="updated"))
        assert get_entry(index, first.id).summary == "updated"
# ---------------------------------------------------------------------------
# index_entry_from_diary_section
# ---------------------------------------------------------------------------
class TestIndexEntryFromDiarySection:
    """Bare (un-enriched) entries built straight from a diary section."""

    def test_id_format(self):
        entry = index_entry_from_diary_section("2026-03-01", "14:30", "Some prose here.")
        assert entry.id == "2026-03-01:14:30"
        assert entry.date == "2026-03-01"
        assert entry.timestamp == "14:30"

    def test_summary_truncated_to_300(self):
        long_prose = "x" * 500
        entry = index_entry_from_diary_section("2026-03-01", "14:30", long_prose)
        assert len(entry.summary) == 300

    def test_defaults_empty_enrichment(self):
        entry = index_entry_from_diary_section("2026-03-01", "14:30", "text")
        # Enrichment fields stay empty — they are filled in later, at
        # consolidation time.
        assert entry.keywords == []
        assert entry.tags == []
        assert entry.category == "other"
        assert entry.embedding is None
        assert not entry.enriched
# ---------------------------------------------------------------------------
# _parse_diary_sections
# ---------------------------------------------------------------------------
class TestParseDiarySections:
    """Splitting a diary markdown document into (timestamp, prose) pairs."""

    def test_parses_two_sections(self):
        raw = "# March 1, 2026\n\n### 09:00\n\nFirst entry.\n\n### 14:30\n\nSecond entry."
        parsed = _parse_diary_sections(raw)
        assert parsed == [("09:00", "First entry."), ("14:30", "Second entry.")]

    def test_ignores_content_before_first_timestamp(self):
        raw = "# Heading\n\nIntro text.\n\n### 10:00\n\nEntry."
        parsed = _parse_diary_sections(raw)
        assert len(parsed) == 1
        assert parsed[0][0] == "10:00"

    def test_empty_content(self):
        assert _parse_diary_sections("") == []

    def test_no_timestamp_sections(self):
        assert _parse_diary_sections("# Just a heading\n\nSome text.") == []
# ---------------------------------------------------------------------------
# record_retrieval
# ---------------------------------------------------------------------------
class TestRecordRetrieval:
    """Retrieval bookkeeping: counters, timestamps, and unknown ids."""

    @staticmethod
    def _redirect_memories_dir(monkeypatch, tmp_path):
        # Keep any incidental saves inside the pytest tmp dir.
        monkeypatch.setattr(
            "framework.agents.queen.queen_memory_index._queen_memories_dir",
            lambda: tmp_path,
        )

    def test_increments_count(self, tmp_path, monkeypatch):
        self._redirect_memories_dir(monkeypatch, tmp_path)
        entry = _entry(retrieval_count=2)
        index = _make_index(entry)
        record_retrieval(index, [entry.id], auto_save=False)
        assert index["entries"][entry.id]["retrieval_count"] == 3

    def test_sets_last_retrieved(self, tmp_path, monkeypatch):
        self._redirect_memories_dir(monkeypatch, tmp_path)
        entry = _entry()
        index = _make_index(entry)
        record_retrieval(index, [entry.id], auto_save=False)
        assert index["entries"][entry.id]["last_retrieved"] is not None

    def test_ignores_missing_ids(self, tmp_path, monkeypatch):
        self._redirect_memories_dir(monkeypatch, tmp_path)
        # Unknown ids are silently skipped — must not raise.
        record_retrieval(_make_index(), ["nonexistent:00:00"], auto_save=False)
# ---------------------------------------------------------------------------
# importance_score
# ---------------------------------------------------------------------------
class TestImportanceScore:
    """Importance scoring: retrieval frequency weighted by recency decay."""

    def test_zero_for_never_retrieved(self):
        assert importance_score(_entry(retrieval_count=0)) == 0.0

    def test_positive_for_retrieved_recently(self):
        moment = datetime.now()
        entry = _entry(retrieval_count=5, last_retrieved=moment.isoformat())
        assert importance_score(entry, now=moment) > 0.0

    def test_decays_over_time(self):
        from datetime import timedelta

        moment = datetime.now()
        fresh = _entry("2026-03-01", "10:00", retrieval_count=5,
                       last_retrieved=moment.isoformat())
        stale = _entry("2026-03-01", "11:00", retrieval_count=5,
                       last_retrieved=(moment - timedelta(days=60)).isoformat())
        # Same count, older retrieval → lower score.
        assert importance_score(fresh, now=moment) > importance_score(stale, now=moment)

    def test_higher_count_higher_score(self):
        moment = datetime.now()
        rarely = _entry("2026-03-01", "10:00", retrieval_count=1,
                        last_retrieved=moment.isoformat())
        often = _entry("2026-03-01", "11:00", retrieval_count=10,
                       last_retrieved=moment.isoformat())
        assert importance_score(often, now=moment) > importance_score(rarely, now=moment)
# ---------------------------------------------------------------------------
# link_entry (Phase 3)
# ---------------------------------------------------------------------------
class TestLinkEntry:
    """Phase 3: similarity-based links between index entries."""

    def test_links_above_threshold(self):
        # Two nearly identical vectors must be linked.
        anchor = _entry("2026-03-01", "09:00", embedding=[1.0, 0.0, 0.0])
        near = _entry("2026-03-01", "10:00", embedding=[0.99, 0.01, 0.0])
        index = _make_index(anchor, near)
        assert near.id in link_entry(index, anchor.id, similarity_threshold=0.90)

    def test_bidirectional_links(self):
        a = _entry("2026-03-01", "09:00", embedding=[1.0, 0.0])
        b = _entry("2026-03-01", "10:00", embedding=[1.0, 0.0])
        index = _make_index(a, b)
        link_entry(index, a.id, similarity_threshold=0.90)
        # A link is recorded on both ends.
        assert b.id in index["entries"][a.id]["related"]
        assert a.id in index["entries"][b.id]["related"]

    def test_does_not_link_below_threshold(self):
        anchor = _entry("2026-03-01", "09:00", embedding=[1.0, 0.0])
        orthogonal = _entry("2026-03-01", "10:00", embedding=[0.0, 1.0])
        index = _make_index(anchor, orthogonal)
        assert link_entry(index, anchor.id, similarity_threshold=0.90) == []

    def test_skips_entry_without_embedding(self):
        bare = _entry("2026-03-01", "09:00", embedding=None)
        index = _make_index(bare)
        assert link_entry(index, bare.id) == []
# ---------------------------------------------------------------------------
# hybrid_search (Phase 4)
# ---------------------------------------------------------------------------
class TestHybridSearch:
    """Phase 4: hybrid ranking combining semantic score and keyword overlap."""

    def test_semantic_score_dominates(self):
        strong_sem = _entry("2026-03-01", "09:00", keywords=["unrelated"])
        weak_sem = _entry("2026-03-01", "10:00", keywords=["pipeline", "agent"])
        index = _make_index(strong_sem, weak_sem)
        scores = {strong_sem.id: 0.95, weak_sem.id: 0.40}
        ranked = hybrid_search("pipeline", index, [strong_sem.id, weak_sem.id], scores)
        # A large semantic gap outweighs the keyword match.
        assert ranked[0][0] == strong_sem.id

    def test_keyword_overlap_breaks_tie(self):
        overlapping = _entry("2026-03-01", "09:00", keywords=["pipeline", "agent", "workflow"])
        disjoint = _entry("2026-03-01", "10:00", keywords=["unrelated", "other"])
        index = _make_index(overlapping, disjoint)
        scores = {overlapping.id: 0.80, disjoint.id: 0.80}  # tied semantically
        ranked = hybrid_search("pipeline agent", index, [overlapping.id, disjoint.id], scores)
        assert ranked[0][0] == overlapping.id

    def test_returns_sorted_descending(self):
        batch = [_entry("2026-03-01", f"0{i}:00") for i in range(3)]
        index = _make_index(*batch)
        scores = {e.id: float(i) / 10 for i, e in enumerate(batch)}
        ranked = hybrid_search("query", index, [e.id for e in batch], scores)
        values = [score for _, score in ranked]
        assert values == sorted(values, reverse=True)
# ---------------------------------------------------------------------------
# embeddings_enabled / get_embed_model
# ---------------------------------------------------------------------------
class TestEmbeddingsEnabled:
    """Feature-flag plumbing via the HIVE_EMBED_MODEL environment variable."""

    def test_disabled_when_env_unset(self, monkeypatch):
        monkeypatch.delenv("HIVE_EMBED_MODEL", raising=False)
        assert not embeddings_enabled()
        assert get_embed_model() is None

    def test_enabled_when_env_set(self, monkeypatch):
        monkeypatch.setenv("HIVE_EMBED_MODEL", "text-embedding-3-small")
        assert embeddings_enabled()
        assert get_embed_model() == "text-embedding-3-small"
# ---------------------------------------------------------------------------
# embed_text — mocked
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
class TestEmbedText:
    """embed_text with litellm.aembedding mocked out."""

    async def test_returns_none_when_disabled(self, monkeypatch):
        monkeypatch.delenv("HIVE_EMBED_MODEL", raising=False)
        assert await embed_text("hello") is None

    async def test_returns_vector_when_enabled(self, monkeypatch):
        monkeypatch.setenv("HIVE_EMBED_MODEL", "text-embedding-3-small")
        vector = [0.1, 0.2, 0.3]
        response = MagicMock()
        response.data = [{"embedding": vector}]
        with patch("litellm.aembedding", new=AsyncMock(return_value=response)):
            assert await embed_text("hello world") == vector

    async def test_returns_none_on_api_failure(self, monkeypatch):
        monkeypatch.setenv("HIVE_EMBED_MODEL", "text-embedding-3-small")
        failing = AsyncMock(side_effect=RuntimeError("API down"))
        # API errors are swallowed and reported as "no embedding".
        with patch("litellm.aembedding", new=failing):
            assert await embed_text("hello") is None
# ---------------------------------------------------------------------------
# semantic_search — mocked embeddings
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
class TestSemanticSearch:
    """semantic_search with the embedding API mocked."""

    async def test_returns_empty_when_disabled(self, monkeypatch):
        monkeypatch.delenv("HIVE_EMBED_MODEL", raising=False)
        index = _make_index(_entry(embedding=[1.0, 0.0]))
        assert await semantic_search("query", index) == []

    async def test_finds_nearest_neighbours(self, monkeypatch):
        monkeypatch.setenv("HIVE_EMBED_MODEL", "text-embedding-3-small")
        aligned = _entry("2026-03-01", "09:00", embedding=[1.0, 0.0])
        orthogonal = _entry("2026-03-01", "10:00", embedding=[0.0, 1.0])
        index = _make_index(aligned, orthogonal)
        response = MagicMock()
        response.data = [{"embedding": [1.0, 0.0]}]
        with patch("litellm.aembedding", new=AsyncMock(return_value=response)):
            hits = await semantic_search("query", index, k=2)
        # The entry parallel to the mocked query vector ranks first.
        assert hits[0][0] == aligned.id

    async def test_date_range_filter(self, monkeypatch):
        monkeypatch.setenv("HIVE_EMBED_MODEL", "text-embedding-3-small")
        inside = _entry("2026-03-15", "09:00", embedding=[1.0, 0.0])
        outside = _entry("2026-02-01", "09:00", embedding=[1.0, 0.0])
        index = _make_index(inside, outside)
        response = MagicMock()
        response.data = [{"embedding": [1.0, 0.0]}]
        with patch("litellm.aembedding", new=AsyncMock(return_value=response)):
            hits = await semantic_search(
                "query", index, k=10, date_range=("2026-03-01", "2026-03-31")
            )
        hit_ids = [entry_id for entry_id, _ in hits]
        assert inside.id in hit_ids
        assert outside.id not in hit_ids
# ---------------------------------------------------------------------------
# enrich_entry — mocked LLM
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
class TestEnrichEntry:
    """enrich_entry with the LLM mocked."""

    @staticmethod
    def _llm_returning(payload):
        # Build a mock LLM whose acomplete() yields `payload` as JSON.
        llm = MagicMock()
        response = MagicMock()
        response.content = json.dumps(payload)
        llm.acomplete = AsyncMock(return_value=response)
        return llm

    async def test_parses_llm_response(self):
        llm = self._llm_returning(
            {"keywords": ["pipeline", "agent"], "category": "pipeline", "tags": ["build", "test"]}
        )
        keywords, category, tags = await enrich_entry("Some diary text", llm)
        assert "pipeline" in keywords
        assert category == "pipeline"
        assert "build" in tags

    async def test_rejects_invalid_category(self):
        llm = self._llm_returning(
            {"keywords": [], "category": "invented_category", "tags": []}
        )
        _, category, _ = await enrich_entry("text", llm)
        # Categories outside the known set fall back to "other".
        assert category == "other"

    async def test_returns_defaults_on_failure(self):
        llm = MagicMock()
        llm.acomplete = AsyncMock(side_effect=RuntimeError("LLM down"))
        keywords, category, tags = await enrich_entry("text", llm)
        assert (keywords, category, tags) == ([], "other", [])
# ---------------------------------------------------------------------------
# maybe_evolve_neighbors — mocked LLM
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
class TestMaybeEvolveNeighbors:
    """maybe_evolve_neighbors with the LLM mocked."""

    @staticmethod
    def _llm_returning(payload):
        # Mock LLM whose acomplete() returns `payload` serialized as JSON.
        llm = MagicMock()
        response = MagicMock()
        response.content = json.dumps(payload)
        llm.acomplete = AsyncMock(return_value=response)
        return llm

    async def test_updates_tags_on_non_empty_response(self):
        llm = self._llm_returning({"tags": ["new_tag", "updated"]})
        fresh = _entry("2026-03-01", "10:00", keywords=["new"], tags=["tag_a"])
        neighbour = _entry("2026-03-01", "09:00", keywords=["old"], tags=["old_tag"])
        index = _make_index(fresh, neighbour)
        await maybe_evolve_neighbors(fresh.id, [neighbour.id], index, llm)
        assert "new_tag" in index["entries"][neighbour.id]["tags"]

    async def test_no_op_on_empty_response(self):
        llm = self._llm_returning({})
        fresh = _entry("2026-03-01", "10:00")
        neighbour = _entry("2026-03-01", "09:00", tags=["original"])
        index = _make_index(fresh, neighbour)
        await maybe_evolve_neighbors(fresh.id, [neighbour.id], index, llm)
        # An empty JSON object means "leave the neighbour unchanged".
        assert index["entries"][neighbour.id]["tags"] == ["original"]

    async def test_silently_handles_llm_failure(self):
        llm = MagicMock()
        llm.acomplete = AsyncMock(side_effect=RuntimeError("down"))
        fresh = _entry("2026-03-01", "10:00")
        neighbour = _entry("2026-03-01", "09:00")
        index = _make_index(fresh, neighbour)
        # Failures are swallowed — this call must not raise.
        await maybe_evolve_neighbors(fresh.id, [neighbour.id], index, llm)

    async def test_respects_max_neighbors_cap(self):
        llm = self._llm_returning({})
        fresh = _entry("2026-03-01", "10:00")
        neighbours = [_entry("2026-03-01", f"0{i}:00") for i in range(5)]
        index = _make_index(fresh, *neighbours)
        await maybe_evolve_neighbors(
            fresh.id, [n.id for n in neighbours], index, llm, max_neighbors_to_evolve=2
        )
        # Only the first two neighbours should have been sent to the LLM.
        assert llm.acomplete.call_count == 2
# ---------------------------------------------------------------------------
# recall_diary — semantic path and fallback (integration-style)
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
class TestRecallDiary:
    # Integration-style coverage of recall_diary: with embeddings disabled
    # it must fall back to plain substring matching over the diary files.

    async def test_substring_fallback_when_embeddings_disabled(
        self, tmp_path, monkeypatch
    ):
        """When HIVE_EMBED_MODEL is not set, recall_diary uses substring matching."""
        monkeypatch.delenv("HIVE_EMBED_MODEL", raising=False)
        # Write a fake diary file under the layout the queen memory code uses.
        memories_dir = tmp_path / ".hive" / "queen" / "memories"
        memories_dir.mkdir(parents=True)
        today_str = "2026-03-24"
        (memories_dir / f"MEMORY-{today_str}.md").write_text(
            "# March 24, 2026\n\n### 09:00\n\nWorked on the pipeline agent today.\n",
            encoding="utf-8",
        )
        # Patch the path functions.
        # NOTE(review): the patched resolver ignores its date argument and
        # always returns the 2026-03-24 file, so the hard-coded today_str is
        # safe even though recall_diary runs at the real "today" — confirm
        # read_episodic_memory goes through this module-level function.
        import framework.agents.queen.queen_memory as qm
        monkeypatch.setattr(qm, "episodic_memory_path", lambda d=None: memories_dir / f"MEMORY-{today_str}.md")
        from framework.tools.queen_memory_tools import recall_diary
        result = await recall_diary(query="pipeline", days_back=1)
        assert "pipeline agent" in result

    async def test_no_results_message(self, monkeypatch):
        """Returns a helpful message when nothing matches."""
        monkeypatch.delenv("HIVE_EMBED_MODEL", raising=False)
        import framework.agents.queen.queen_memory as qm
        # Point to a non-existent path so no diary content is found.
        monkeypatch.setattr(
            qm, "episodic_memory_path", lambda d=None: Path("/nonexistent/MEMORY.md")
        )
        from framework.tools.queen_memory_tools import recall_diary
        result = await recall_diary(query="nonexistent topic", days_back=1)
        assert "No diary entries" in result
+43 -1
View File
@@ -362,6 +362,48 @@ def get_antigravity_client_secret() -> str | None:
return secret
def get_embed_model() -> str | None:
    """Return the configured embedding model string, or None if not set.

    Reads from the ``embedding`` section of ~/.hive/configuration.json:

        {
          "embedding": {
            "provider": "openai",
            "model": "text-embedding-3-small",
            "api_key_env_var": "OPENAI_API_KEY"
          }
        }

    Returns a litellm-compatible ``"provider/model"`` string, e.g.
    ``"openai/text-embedding-3-small"``.

    Falls back to the ``HIVE_EMBED_MODEL`` environment variable for
    backward compatibility.
    """
    section = get_hive_config().get("embedding", {})
    provider = str(section.get("provider") or "").strip()
    model = str(section.get("model") or "").strip()
    if provider and model:
        return f"{provider}/{model}"
    # Legacy env-var fallback; empty string normalizes to None.
    return os.environ.get("HIVE_EMBED_MODEL") or None
def get_embed_api_key() -> str | None:
    """Return the API key for the embedding provider, or None if not set."""
    env_var = get_hive_config().get("embedding", {}).get("api_key_env_var")
    # The config names the env var; the key itself lives in the environment.
    return os.environ.get(env_var) if env_var else None
def get_embed_api_base() -> str | None:
    """Return a custom api_base for the embedding provider, or None."""
    base = get_hive_config().get("embedding", {}).get("api_base")
    # Falsy values (missing key, empty string) normalize to None.
    return base or None
def get_gcu_enabled() -> bool:
    """Return whether GCU (browser automation) is enabled in user config."""
    config = get_hive_config()
    # Enabled by default when the key is absent.
    return config.get("gcu_enabled", True)
@@ -436,7 +478,7 @@ def get_llm_extra_kwargs() -> dict[str, Any]:
# ---------------------------------------------------------------------------
# RuntimeConfig shared across agent templates
# RuntimeConfig - shared across agent templates
# ---------------------------------------------------------------------------
+103 -15
View File
@@ -8,8 +8,11 @@ written by the queen directly.
from __future__ import annotations
import logging
from typing import TYPE_CHECKING
logger = logging.getLogger(__name__)
if TYPE_CHECKING:
from framework.runner.tool_registry import ToolRegistry
@@ -34,7 +37,7 @@ def write_to_diary(entry: str) -> str:
return "Diary entry recorded."
def recall_diary(query: str = "", days_back: int = 7) -> str:
async def recall_diary(query: str = "", days_back: int = 7) -> str:
"""Search recent diary entries (episodic memory).
Use this when the user asks about what happened in the past "what did we
@@ -45,26 +48,112 @@ def recall_diary(query: str = "", days_back: int = 7) -> str:
Args:
query: Optional keyword or phrase to filter entries. If empty, all
recent entries are returned.
days_back: How many days to look back (130). Defaults to 7.
days_back: How many days to look back (1-30). Defaults to 7.
"""
from datetime import date, timedelta
from framework.agents.queen.queen_memory import read_episodic_memory
from framework.agents.queen.queen_memory_index import (
embeddings_enabled,
hybrid_search,
load_index,
record_retrieval,
resolve_prose,
semantic_search,
)
days_back = max(1, min(days_back, 30))
today = date.today()
results: list[str] = []
total_chars = 0
char_budget = 12_000
# ------------------------------------------------------------------
# Semantic path — used when embedding model is configured and query given
# ------------------------------------------------------------------
if query and embeddings_enabled():
logger.info("queen_memory: semantic recall — query=%r days_back=%d", query, days_back)
oldest = (today - timedelta(days=days_back - 1)).strftime("%Y-%m-%d")
newest = today.strftime("%Y-%m-%d")
index = load_index()
sem_results = await semantic_search(
query, index, k=30, date_range=(oldest, newest)
)
if sem_results:
sem_scores = dict(sem_results)
candidate_ids = [eid for eid, _ in sem_results]
ranked = hybrid_search(query, index, candidate_ids, sem_scores)
results: list[str] = []
total_chars = 0
returned_ids: list[str] = []
for entry_id, _score in ranked:
date_str, ts = entry_id.split(":", 1)
prose = resolve_prose(entry_id)
if not prose:
continue
# Format label from date_str
try:
y, m, d_int = map(int, date_str.split("-"))
d = date(y, m, d_int)
label = d.strftime("%B %-d, %Y")
if d == today:
label = f"Today — {label}"
except ValueError:
label = date_str
section = f"## {label} ({ts})\n\n{prose}"
# Also include linked neighbours (Phase 3 expansion)
raw = index.get("entries", {}).get(entry_id, {})
related_prose_parts: list[str] = []
for related_id in raw.get("related", [])[:2]:
if related_id in (eid for eid, _ in ranked):
continue # will appear in main results
rp = resolve_prose(related_id)
if rp:
r_date_str, r_ts = related_id.split(":", 1)
try:
ry, rm, rd = map(int, r_date_str.split("-"))
r_label = date(ry, rm, rd).strftime("%B %-d, %Y")
except ValueError:
r_label = r_date_str
related_prose_parts.append(
f"_Related ({r_label} {r_ts}):_ {rp[:300]}"
)
if related_prose_parts:
section += "\n\n" + "\n\n".join(related_prose_parts)
if total_chars + len(section) > char_budget:
remaining = char_budget - total_chars
if remaining > 200:
section = section[: remaining - 100] + "\n\n…(truncated)"
results.append(section)
returned_ids.append(entry_id)
break
results.append(section)
returned_ids.append(entry_id)
total_chars += len(section)
if results:
record_retrieval(index, returned_ids)
return "\n\n---\n\n".join(results)
# Fall through to substring if semantic found nothing useful
# ------------------------------------------------------------------
# Substring fallback — original behaviour, unchanged
# ------------------------------------------------------------------
results_fb: list[str] = []
total_chars_fb = 0
for offset in range(days_back):
d = today - timedelta(days=offset)
content = read_episodic_memory(d)
if not content:
continue
# If a query is given, only include entries that mention it
if query:
# Check each section (split by ###) for relevance
sections = content.split("### ")
matched = [s for s in sections if query.lower() in s.lower()]
if not matched:
@@ -74,24 +163,23 @@ def recall_diary(query: str = "", days_back: int = 7) -> str:
if d == today:
label = f"Today — {label}"
entry = f"## {label}\n\n{content}"
if total_chars + len(entry) > char_budget:
remaining = char_budget - total_chars
if total_chars_fb + len(entry) > char_budget:
remaining = char_budget - total_chars_fb
if remaining > 200:
# Fit a partial entry within budget
trimmed = content[: remaining - 100] + "\n\n…(truncated)"
results.append(f"## {label}\n\n{trimmed}")
results_fb.append(f"## {label}\n\n{trimmed}")
else:
results.append(f"## {label}\n\n(truncated — hit size limit)")
results_fb.append(f"## {label}\n\n(truncated — hit size limit)")
break
results.append(entry)
total_chars += len(entry)
results_fb.append(entry)
total_chars_fb += len(entry)
if not results:
if not results_fb:
if query:
return f"No diary entries matching '{query}' in the last {days_back} days."
return f"No diary entries found in the last {days_back} days."
return "\n\n---\n\n".join(results)
return "\n\n---\n\n".join(results_fb)
def register_queen_memory_tools(registry: ToolRegistry) -> None: