Commit a08740f
feat: training data pipeline — agent framework output → JEPA training
Implement the core training data pipeline that feeds ATN agent framework output (conversations, tool calls, execution records) into JEPA training. This is the local encoder side of the latency symmetry: backprop runs locally; only weight deltas cross the network.

Stories implemented:
- 1.1 TextTrainingDataSource: JSONL → byte-tokenized batches with PII scrubbing
- 1.2 TextMasker: span masking for text (3-15 token spans, ~20% mask ratio)
- 1.3 TextJEPA: self-supervised text encoder with masked prediction
- 2.2 train_vljepa_on_task(): unified training for text/visual/multimodal
- 5.1 BehavioralProfile: EMA accumulator for alignment-based inference pricing

64 tests, all passing.

🤖 Generated with [Claude Code](https://claude.com/claude-code)
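To make the masking story concrete, here is a minimal sketch of span masking under the stated parameters (3-15 token spans, ~20% mask ratio); the function name and sampling details are illustrative, not the committed TextMasker API:

import torch

def sample_span_mask(seq_len: int, min_span: int = 3, max_span: int = 15,
                     mask_ratio: float = 0.20) -> torch.Tensor:
    """Boolean mask over token positions; True marks positions to predict."""
    mask = torch.zeros(seq_len, dtype=torch.bool)
    target = int(seq_len * mask_ratio)
    while int(mask.sum()) < target:
        # Sample a span length in [min_span, max_span] and a start position,
        # then mask the span; overlapping spans simply merge.
        span = int(torch.randint(min_span, max_span + 1, (1,)))
        start = int(torch.randint(0, max(1, seq_len - span + 1), (1,)))
        mask[start:start + span] = True
    return mask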
1 parent c6bbf46 commit a08740f

9 files changed: 3,615 additions, 0 deletions

BACKLOG_TRAINING_DATA.md: 692 additions, 0 deletions (large diff not rendered by default)

nodes/common/behavioral_profile.py: 289 additions, 0 deletions
@@ -0,0 +1,289 @@
"""
Behavioral Semantic Profile — EMA accumulator for alignment pricing.

After each training cycle, the node's agent interaction data is encoded into
K-vectors (the same latent representation used by VL-JEPA's SemanticPredictor).
These are accumulated into an exponential moving average (EMA) that represents
the node's behavioral history.

    profile_t = decay * profile_{t-1} + (1 - decay) * current_embeddings

With decay=0.998 and daily updates, old behavior decays to 1/e in ~500 days.
This prevents gaming — you can't flip your agent prompts today and get cheap
inference tomorrow.

The profile is:
- A single [K, D] tensor (~50KB at K=32, D=384)
- Persisted locally across restarts
- Published as a hash on-chain per epoch (privacy-preserving)
- Used at inference time for K-NN alignment scoring (Story 5.3)

Story 5.1 (BACKLOG_TRAINING_DATA.md)
"""
import hashlib
import logging
from pathlib import Path
from typing import Dict, Optional

import torch
import torch.nn.functional as F

logger = logging.getLogger(__name__)


class BehavioralProfile:
    """Accumulated behavioral semantic profile for a node.

    Maintains an EMA over K-vector representations of the node's training
    data. Updated after each training cycle with the mean-pooled embeddings
    from that cycle's data.

    Usage:
        profile = BehavioralProfile(K=32, D=384)
        # After each training cycle:
        profile.update(embeddings)  # embeddings: (num_samples, K, D)
        # Persist:
        profile.save("~/.atn/profile.pt")
        # On-chain attestation:
        profile_hash = profile.hash()
    """

    def __init__(
        self,
        K: int = 32,
        D: int = 384,
        decay: float = 0.998,
        profile_path: Optional[str] = None,
    ):
        """
        Args:
            K: Number of latent vectors (matches SemanticPredictor.num_latent_vectors)
            D: Embedding dimension (matches VLJEPAConfig.embed_dim)
            decay: EMA decay factor. 0.998 with daily updates → decays to 1/e in ~500 days.
            profile_path: Path to load/save persisted profile.
        """
        self.K = K
        self.D = D
        self.decay = decay
        self.profile_path = profile_path

        # The accumulated profile — starts as zeros (no history)
        self._profile: torch.Tensor = torch.zeros(K, D)
        self._initialized: bool = False
        self._update_count: int = 0

        # Try to load persisted profile
        if profile_path:
            self._load(profile_path)

    @property
    def profile(self) -> torch.Tensor:
        """Current behavioral profile tensor [K, D]."""
        return self._profile

    @property
    def initialized(self) -> bool:
        """Whether the profile has received at least one update."""
        return self._initialized

    @property
    def update_count(self) -> int:
        """Number of updates applied to this profile."""
        return self._update_count

    def update(self, embeddings: torch.Tensor) -> None:
        """Update the behavioral profile with new training cycle embeddings.

        Args:
            embeddings: Tensor of shape (N, K, D) or (K, D).
                N = number of samples from this training cycle.
                If (N, K, D), mean-pools over N first.
                K and D must match profile dimensions.
        """
        if embeddings.dim() == 3:
            # (N, K, D) → mean over samples → (K, D)
            current = embeddings.mean(dim=0)
        elif embeddings.dim() == 2:
            current = embeddings
        else:
            raise ValueError(
                f"Expected 2D or 3D tensor, got shape {embeddings.shape}"
            )

        if current.shape != (self.K, self.D):
            raise ValueError(
                f"Embedding shape {current.shape} doesn't match profile "
                f"({self.K}, {self.D})"
            )

        current = current.detach().cpu()

        if not self._initialized:
            # First update: initialize directly (no decay of zeros)
            self._profile = current.clone()
            self._initialized = True
        else:
            # EMA: profile_t = decay * profile_{t-1} + (1 - decay) * current
            self._profile = self.decay * self._profile + (1 - self.decay) * current

        self._update_count += 1
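    # Illustrative update semantics (sketch, decay=0.998):
    #   p = BehavioralProfile(K=2, D=4)
    #   p.update(torch.ones(2, 4))    # first update seeds the profile: all ones
    #   p.update(torch.zeros(2, 4))   # EMA: profile becomes 0.998 * ones
    #   p.update_count                # -> 2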
    def similarity_to(self, other: "BehavioralProfile") -> float:
        """Cosine similarity between this profile and another.

        Args:
            other: Another BehavioralProfile to compare against.

        Returns:
            Cosine similarity in [-1, 1]. Higher = more similar behavior.
        """
        if not self._initialized or not other._initialized:
            return 0.0

        # Flatten to 1D for cosine similarity
        a = self._profile.flatten()
        b = other._profile.flatten()
        return F.cosine_similarity(a.unsqueeze(0), b.unsqueeze(0)).item()

    def distance_to_embedding(self, embedding: torch.Tensor) -> float:
        """Cosine similarity between this profile and a single embedding.

        Used for K-NN alignment scoring at inference time (Story 5.3):
        - profile ↔ jurisdiction standards
        - profile ↔ request semantics

        Args:
            embedding: Tensor of shape (K, D) — e.g. jurisdiction standards
                encoded through the model, or an inference request's K-vectors.

        Returns:
            Cosine similarity in [-1, 1].
        """
        if not self._initialized:
            return 0.0

        if embedding.dim() == 3 and embedding.shape[0] == 1:
            embedding = embedding.squeeze(0)

        a = self._profile.flatten()
        b = embedding.detach().cpu().flatten()
        return F.cosine_similarity(a.unsqueeze(0), b.unsqueeze(0)).item()
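    # Illustrative alignment check (Story 5.3; the jurisdiction encoding step
    # is assumed, not shown in this commit):
    #   jurisdiction_kvecs = encode_standards(...)   # hypothetical, shape (K, D)
    #   score = profile.distance_to_embedding(jurisdiction_kvecs)
    #   well_aligned = score > 0.8                   # threshold is illustrative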
    def hash(self) -> str:
        """Compute a deterministic hash of the profile for on-chain attestation.

        The hash is published on-chain per epoch — it links training activity
        to a behavioral signature without revealing the profile itself.

        Returns:
            Hex string (SHA-256 of profile tensor bytes).
        """
        # Quantize to float16 for deterministic hashing across platforms
        quantized = self._profile.half()
        raw_bytes = quantized.numpy().tobytes()
        return hashlib.sha256(raw_bytes).hexdigest()

    def save(self, path: Optional[str] = None) -> None:
        """Persist profile to disk.

        Args:
            path: File path. Uses self.profile_path if not specified.
        """
        if path is None and self.profile_path is None:
            raise ValueError("No save path specified and no profile_path set")
        # expanduser so paths like "~/.atn/profile.pt" work as documented
        save_path = Path(path or self.profile_path).expanduser()
        save_path.parent.mkdir(parents=True, exist_ok=True)
        torch.save(
            {
                "profile": self._profile,
                "K": self.K,
                "D": self.D,
                "decay": self.decay,
                "initialized": self._initialized,
                "update_count": self._update_count,
            },
            save_path,
        )
        logger.info("Saved behavioral profile to %s (%d updates)", save_path, self._update_count)

    def _load(self, path: str) -> None:
        """Load profile from disk if it exists."""
        p = Path(path).expanduser()
        if not p.exists():
            logger.debug("No persisted profile at %s — starting fresh", path)
            return

        try:
            data = torch.load(p, map_location="cpu", weights_only=True)
            if data["K"] != self.K or data["D"] != self.D:
                logger.warning(
                    "Profile dimensions mismatch: saved (%d, %d) vs expected (%d, %d). "
                    "Starting fresh.",
                    data["K"], data["D"], self.K, self.D,
                )
                return

            self._profile = data["profile"]
            self._initialized = data["initialized"]
            self._update_count = data["update_count"]
            logger.info(
                "Loaded behavioral profile from %s (%d updates, decay=%.4f)",
                path, self._update_count, self.decay,
            )
        except Exception as e:
            logger.warning("Failed to load profile from %s: %s", path, e)

    def to_dict(self) -> Dict:
        """Serialize profile metadata (not the tensor) for reporting."""
        return {
            "K": self.K,
            "D": self.D,
            "decay": self.decay,
            "initialized": self._initialized,
            "update_count": self._update_count,
            "hash": self.hash() if self._initialized else None,
            "profile_norm": self._profile.norm().item(),
        }

def compute_training_embeddings(
    trainer,
    data_source,
    max_batches: int = 50,
) -> torch.Tensor:
    """Extract K-vector embeddings from training data using a trained model.

    After a training cycle completes, this function runs the trained model
    on the same data to produce K-vector embeddings that represent the
    semantic content of the training data. These embeddings are then used
    to update the behavioral profile.

    Args:
        trainer: A TextJEPATrainer (or JEPATrainer) with a trained model.
        data_source: Iterable yielding training batches.
        max_batches: Maximum batches to process (limits compute cost).

    Returns:
        Tensor of shape (N, D) where N = total samples processed, D = embed_dim.
        Each row is the mean-pooled context encoder output for one sample.
    """
    trainer.model.eval()
    all_embeddings = []

    with torch.no_grad():
        for i, batch in enumerate(data_source):
            if i >= max_batches:
                break

            token_ids = batch["token_ids"].to(next(trainer.model.parameters()).device)

            # Get context encoder output (full sequence, no masking)
            embeddings = trainer.model.context_encoder(token_ids)  # (B, S, D)

            # Mean-pool over sequence length → (B, D)
            pooled = embeddings.mean(dim=1)
            all_embeddings.append(pooled.cpu())

    if not all_embeddings:
        return torch.zeros(0)

    return torch.cat(all_embeddings, dim=0)  # (N, D)
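A minimal usage sketch tying the pieces together (illustrative: the (N, K, D) embeddings below are synthetic stand-ins for real SemanticPredictor output):

import torch
from nodes.common.behavioral_profile import BehavioralProfile

profile = BehavioralProfile(K=32, D=384, profile_path="~/.atn/profile.pt")

# After a training cycle: one K-vector bundle per sample (synthetic here).
embeddings = torch.randn(128, 32, 384)  # (N, K, D)
profile.update(embeddings)              # mean-pool over N, then EMA

profile.save()                          # persist across restarts
attestation = profile.hash()            # SHA-256 hex for on-chain publishing
score = profile.distance_to_embedding(torch.randn(32, 384))  # alignment in [-1, 1]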
