Skip to content

Commit f950025

Browse files
committed
feat: LLM backbone encoder for JEPA training with local model selection
Add support for using open-weights LLMs (Llama, Qwen, Gemma, Mistral, Phi) as frozen feature extractors for the JEPA training pipeline. The LLM's hidden states provide rich semantic embeddings; only a small projection head (~2-8M params) is trained and distributed via FedAvg. The LLM weights never leave the node. New files: - nodes/common/local_models.py: Model catalog (11 models), host hardware probe (RAM, VRAM, disk, MPS), compatibility filtering, recommendation - nodes/common/llm_backbone.py: LLMBackboneEncoder (frozen LLM + trainable projection head), BackboneJEPATrainer (JEPA training with LLM features), learned layer mixing, EMA target encoder Integration: - ml.py: New backbone training path in train_vljepa_on_task() when backbone_model_id is set in task_spec - text_data.py: Include raw_texts in batch dict for LLM re-tokenization - training_feed.py: Pass backbone_model_id through to task spec - config.py: Add backbone_model_id to ModelConfig - service.py: Wire backbone_model_id from config to training feed, fix governance variable reference bug - ws_server.py: Add local_models and set_local_model WS endpoints - pyproject.toml: Add transformers + accelerate to [network] deps
1 parent 88580b6 commit f950025

9 files changed

Lines changed: 1380 additions & 3 deletions

File tree

atn/ws_server.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -900,6 +900,39 @@ async def _do_oauth() -> None:
900900
except Exception as e:
901901
return {"msg_id": msg_id, "ok": False, "error": str(e)}
902902

903+
# Local model selection (LLM backbone for JEPA training)
904+
if msg_type == "local_models":
905+
try:
906+
from nodes.common.local_models import host_status
907+
return {"msg_id": msg_id, "ok": True, "result": host_status()}
908+
except ImportError:
909+
return {"msg_id": msg_id, "ok": False, "error": "Network package not installed"}
910+
except Exception as e:
911+
return {"msg_id": msg_id, "ok": False, "error": str(e)}
912+
913+
if msg_type == "set_local_model":
914+
model_id = msg.get("model_id", "")
915+
try:
916+
from nodes.common.local_models import get_model_spec
917+
if model_id and not get_model_spec(model_id):
918+
return {"msg_id": msg_id, "ok": False, "error": f"Unknown model: {model_id}"}
919+
# Update the autonet config
920+
if hasattr(self.runtime, "autonet") and self.runtime.autonet:
921+
bridge = self.runtime.autonet
922+
if hasattr(bridge, "_autonet_config") and bridge._autonet_config:
923+
bridge._autonet_config.model.backbone_model_id = model_id
924+
# Update active training feed config
925+
if bridge._service and bridge._service._training_feed:
926+
bridge._service._training_feed.config.backbone_model_id = model_id
927+
return {
928+
"msg_id": msg_id, "ok": True,
929+
"result": {"model_id": model_id, "status": "set" if model_id else "cleared"},
930+
}
931+
except ImportError:
932+
return {"msg_id": msg_id, "ok": False, "error": "Network package not installed"}
933+
except Exception as e:
934+
return {"msg_id": msg_id, "ok": False, "error": str(e)}
935+
903936
# New conversation: reset conversation history without changing model
904937
if msg_type == "new_conversation":
905938
await self.runtime.new_conversation()

nodes/common/config.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,14 +50,16 @@ class BlobStoreConfig:
5050
@dataclass
5151
class ModelConfig:
5252
"""Model architecture config."""
53-
architecture: str = "jepa" # "simplenet", "jepa", "vl_jepa"
53+
architecture: str = "jepa" # "simplenet", "jepa", "vl_jepa", "backbone"
5454
image_size: int = 32
5555
patch_size: int = 4
5656
embed_dim: int = 192
5757
num_heads: int = 3
5858
encoder_depth: int = 6
5959
predictor_depth: int = 3
6060
predictor_embed_dim: int = 96
61+
# LLM backbone (when architecture="backbone")
62+
backbone_model_id: str = "" # HuggingFace repo ID, e.g. "Qwen/Qwen3-1.7B"
6163

6264

6365
@dataclass

0 commit comments

Comments
 (0)