Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
107 changes: 107 additions & 0 deletions docs/internal/specs/special-token-dataloaders.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
# Special Token Dataloaders

## Problem

`IndexedDataset` already supports `Config.tokens = "special"` and returns one activation per example for the CLS token. `OrderedDataLoader` and `ShuffledDataLoader` still reject that mode, even though the public docs and downstream plans already assume CLS-only loading is available.

This blocks a simple workflow for training or analyzing models on only special tokens, for example training an SAE on CLS activations.

## Goal

Add `tokens = "special"` support to the ordered and shuffled activation loaders for a fixed transformer layer.

## Non-goals

- No `tokens = "all"` support in either loader.
- No support for `layer = "all"` in either loader.
- No patch-label filtering for special tokens.
- No change to on-disk shard layout or indexing semantics.

## Requirements

### Functional

1. `OrderedConfig.tokens` must accept `"special"` as well as `"content"`.

2. `OrderedDataLoader` and `ShuffledDataLoader` must both accept:
- `tokens = "special"`
- `layer = <int>` where the layer is present in metadata

3. In special-token mode, each yielded sample corresponds to exactly one example:
- `example_idx` is the example index
- `token_idx` is `-1`
- `act` is the activation stored at token position `0` in the shard

4. Epoch sizes in special-token mode:
- `n_samples == metadata.n_examples`
- `len(loader)` follows the existing `batch_size` and `drop_last` logic

5. Ordered loader ordering in special-token mode:
- samples are yielded in increasing `example_idx`
- `token_idx` is always `-1`

6. Shuffled loader semantics in special-token mode:
- each example appears once per epoch
- order remains deterministic for a fixed seed
- batches still expose the same keys: `act`, `example_idx`, `token_idx`

7. Token labels:
- ordered loader must not attach `token_labels` for special tokens, even if `labels.bin` exists
- shuffled loader must reject `ignore_labels` when `tokens != "content"`

### Non-functional

1. Reuse the existing shard protocol. Special-token mode must read token position `0` from each example when `metadata.cls_token` is true.

2. Keep the implementation small. The existing content-token path should remain unchanged except where a shared branch is cleaner than duplicate code.

3. Preserve the current meaning of `token_idx = -1` for special tokens so loader outputs match `IndexedDataset` and `shards.IndexMap`.

## Design

### Ordered loader

Continue to use `shards.IndexMap` to translate a global sample index into a shard location. This already knows that special tokens map to:

- `content_token_idx = -1`
- `token_idx_in_shard = 0`

The ordered manager only needs two behavior changes:

1. permit `tokens = "special"` in the fixed-layer path
2. skip label lookup when `content_token_idx < 0`

### Shuffled loader

The shuffled loader currently iterates over every content token in a shard chunk. In special-token mode it should instead emit exactly one activation per example in the chunk:

- activation source: `mmap[start:end, layer_i, 0]`
- metadata:
- column 0: global example indices
- column 1: `-1`

`ignore_labels` remains content-token-only because `labels.bin` is defined over content tokens.

## Test Plan

Add red tests before implementation:

1. Ordered loader special-token smoke test on fake shards:
- batch iterates successfully
- all `token_idx == -1`
- first batch has sequential `example_idx`

2. Ordered loader matches `IndexedDataset` in special-token mode.

3. Shuffled loader special-token epoch test on fake shards:
- all `token_idx == -1`
- every example appears exactly once in a full epoch
- activations match `IndexedDataset` for the sampled `example_idx`

4. Shuffled loader rejects `ignore_labels` when `tokens = "special"`.

## Acceptance Criteria

- The new tests fail before the implementation change.
- The new tests pass after the implementation change.
- Existing content-token tests continue to pass.
28 changes: 20 additions & 8 deletions src/saev/data/ordered.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ class Config:
"""

shards: pathlib.Path = pathlib.Path("$SAEV_SCRATCH/saev/shards/abcdefg")
tokens: tp.Literal["content"] = "content"
tokens: tp.Literal["special", "content"] = "content"
layer: int | tp.Literal["all"] = -2
Comment on lines 60 to 63
batch_size: int = 1024 * 16
batch_timeout_s: float = 30.0
Expand Down Expand Up @@ -92,9 +92,9 @@ def _manager_main(
)

# 0. PRE-CONDITIONS
if cfg.tokens != "content" or not isinstance(cfg.layer, int):
if cfg.tokens not in ("special", "content") or not isinstance(cfg.layer, int):
raise NotImplementedError(
"High-throughput loader only supports `content` and fixed `layer` mode for now."
"High-throughput loader only supports `special` or `content` with fixed `layer` mode for now."
)

assert cfg.layer in md.layers, f"Layer {cfg.layer} not in {md.layers}"
Expand All @@ -107,7 +107,7 @@ def _manager_main(
# Check if labels.bin exists
labels_mmap = None
labels_path = cfg.shards / "labels.bin"
if labels_path.exists():
if labels_path.exists() and cfg.tokens == "content":
labels_mmap = np.memmap(
labels_path,
mode="r",
Expand All @@ -121,7 +121,7 @@ def _manager_main(
assert shard.n_examples == shard_info[0].n_examples == md.examples_per_shard

# Calculate total number of samples
n_samples = md.n_examples * md.content_tokens_per_example
n_samples = len(index_map)

logger.debug("Found %d samples.", n_samples)

Expand Down Expand Up @@ -162,7 +162,7 @@ def _manager_main(
batch_token_i.append(idx.content_token_idx)

# Add patch label if available
if labels_mmap is not None:
if labels_mmap is not None and idx.content_token_idx >= 0:
batch_token_labels.append(
labels_mmap[idx.example_idx, idx.content_token_idx]
)
Expand All @@ -176,7 +176,7 @@ def _manager_main(
}

# Add labels if available
if labels_mmap is not None:
if labels_mmap is not None and batch_token_labels:
batch["token_labels"] = torch.tensor(
batch_token_labels, dtype=torch.long
)
Expand Down Expand Up @@ -218,6 +218,10 @@ def __init__(self, cfg: Config):
self.cfg = cfg
if not os.path.isdir(self.cfg.shards):
raise RuntimeError(f"Activations are not saved at '{self.cfg.shards}'.")
if self.cfg.layer == "all":
raise NotImplementedError(
"High-throughput loader only supports a fixed integer `layer`."
)

self.md = shards.Metadata.load(self.cfg.shards)

Expand Down Expand Up @@ -279,9 +283,13 @@ def _start_manager(self):
)
self.manager_proc.start()

def __iter__(self) -> collections.abc.Iterable[ExampleBatch]:
def __iter__(self) -> collections.abc.Iterator[ExampleBatch]:
"""Yields batches in order."""
self._start_manager()
msg = "Manager state did not initialize correctly."
assert self.batch_queue is not None, msg
assert self.err_queue is not None, msg
assert self.manager_proc is not None, msg
n = 0

try:
Expand Down Expand Up @@ -352,6 +360,10 @@ def __del__(self):

def _calculate_n_samples(self) -> int:
"""Helper to calculate total number of examples based on config."""
if self.cfg.tokens == "special":
msg = "tokens='special' requires shards with a CLS token."
assert self.md.cls_token, msg

match (self.cfg.tokens, self.cfg.layer):
case ("special", "all"):
return self.md.n_examples * len(self.md.layers)
Comment on lines +365 to 369
Expand Down
6 changes: 3 additions & 3 deletions src/saev/data/shards.py
Original file line number Diff line number Diff line change
Expand Up @@ -995,7 +995,7 @@ class IndexMap:

md: Metadata
tokens: tp.Literal["special", "content", "all"]
layer: int
layer: int | tp.Literal["all"]
layer_idx_lookup: dict[int, int]

def __init__(
Expand Down Expand Up @@ -1029,7 +1029,7 @@ def from_global(self, idx: int | np.int_) -> Index:
# [CLS] tokens only right now
example_idx = idx
shard_idx = idx // self.md.examples_per_shard
example_idx_in_shard = idx // self.md.examples_per_shard
example_idx_in_shard = idx % self.md.examples_per_shard
return Index(
idx=idx,
example_idx=example_idx,
Expand Down Expand Up @@ -1101,4 +1101,4 @@ def __len__(self) -> int:
* self.md.tokens_per_example
)
case _:
tp.assert_never((self.cfg.tokens, self.cfg.layer))
tp.assert_never((self.tokens, self.layer))
107 changes: 70 additions & 37 deletions src/saev/data/shuffled.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ def _io_worker(
shard_info = shards.ShardInfo.load(shards_path)

# Pre-conditions
assert cfg.tokens == "content"
assert cfg.tokens in ("special", "content")
assert isinstance(cfg.layer, int)

# If we need to filter by labels, ensure we have the labels
Expand All @@ -174,6 +174,53 @@ def _io_worker(

chunk_size = min(1024, math.ceil(cfg.batch_size * cfg.buffer_size / cfg.n_threads))

def put_chunk(
acts: Tensor,
meta: Int[Tensor, " n 2"],
*,
shard_i: int,
t0: float,
t1: float,
) -> None:
nonlocal bytes_sent, n_reads, t_last_report

n_examples = acts.shape[0]
msg = f"{n_examples} != {meta.shape[0]}"
assert n_examples == meta.shape[0], msg
msg = f"Expected metadata shape {(n_examples, 2)}, got {tuple(meta.shape)}"
assert tuple(meta.shape) == (n_examples, 2), msg

last_ex_i = int(meta[:, 0].max().item())
if last_ex_i >= md.n_examples:
err = ExampleOutOfBoundsError(md, last_ex_i)
logger.warning(err.message)
raise err

fill_before = reservoir.fill()
reservoir.put(acts, meta)
t2 = time.perf_counter()
fill_after = reservoir.fill()

n_reads += 1
bytes_sent += (
acts.numel() * acts.element_size() + meta.numel() * meta.element_size()
)

now = time.time()
if now - t_last_report < cfg.log_every_s:
return

logger.debug(
"shard=%s mb_sent=%.1f read_ms=%.2f put_ms=%.2f fill-before=%.3f fill-after=%.3f",
shard_i,
bytes_sent / 1e6,
(t1 - t0) * 1e3,
(t2 - t1) * 1e3,
fill_before,
fill_after,
)
t_last_report = now

reason = ""

while not stop_event.is_set():
Expand All @@ -183,7 +230,6 @@ def _io_worker(
logger.debug("Got 'None' from work_queue; exiting.")
reason = "poison_pill"
break
t1 = time.perf_counter()

fname = f"acts{shard_i:06}.bin"
logger.info("Opening %s.", fname)
Expand All @@ -194,12 +240,21 @@ def _io_worker(
mmap = np.memmap(
acts_fpath, mode="r", dtype=np.float32, shape=md.shard_shape
)
t2 = time.perf_counter()

# Only iterate over the actual number of examples in this shard
for start, end in helpers.batched_idx(
shard_info[shard_i].n_examples, chunk_size
):
if cfg.tokens == "special":
t0 = time.perf_counter()
acts = torch.from_numpy(mmap[start:end, layer_i, 0])
t1 = time.perf_counter()

meta = torch.full((end - start, 2), -1, dtype=torch.int32)
meta[:, 0] = ex_i_offset + torch.arange(start, end)
put_chunk(acts, meta, shard_i=shard_i, t0=t0, t1=t1)
continue

for t in range(md.content_tokens_per_example):
token_idx = t + int(md.cls_token)

Expand Down Expand Up @@ -240,35 +295,7 @@ def _io_worker(
meta = torch.full((end - start, 2), t, dtype=torch.int32)
meta[:, 0] = ex_i_offset + torch.arange(start, end)

last_ex_i = meta[:, 0].max().item()
if last_ex_i >= md.n_examples:
err = ExampleOutOfBoundsError(md, last_ex_i)
logger.warning(err.message)
raise err

fill_before = reservoir.fill()
reservoir.put(acts, meta)
t2 = time.perf_counter()
fill_after = reservoir.fill()

n_reads += 1
bytes_sent += (
acts.numel() * acts.element_size()
+ meta.numel() * meta.element_size()
)

now = time.time()
if now - t_last_report >= cfg.log_every_s:
logger.debug(
"shard=%s mb_sent=%.1f read_ms=%.2f put_ms=%.2f fill-before=%.3f fill-after=%.3f",
shard_i,
bytes_sent / 1e6,
(t1 - t0) * 1e3,
(t2 - t1) * 1e3,
fill_before,
fill_after,
)
t_last_report = now
put_chunk(acts, meta, shard_i=shard_i, t0=t0, t1=t1)
except queue.Empty:
# Wait 0.1 seconds for new data.
time.sleep(0.1)
Expand Down Expand Up @@ -315,9 +342,9 @@ def _manager_main(
)

# 0. PRE-CONDITIONS
if cfg.tokens != "content" or not isinstance(cfg.layer, int):
if cfg.tokens not in ("special", "content") or not isinstance(cfg.layer, int):
raise NotImplementedError(
"High-throughput loader only supports `content` and fixed `layer` mode for now."
"High-throughput loader only supports `special` or `content` with fixed `layer` mode for now."
)

assert cfg.layer in metadata.layers, f"Layer {cfg.layer} not in {metadata.layers}"
Expand Down Expand Up @@ -506,6 +533,10 @@ def _start_manager(self):
def __iter__(self) -> collections.abc.Iterator[ExampleBatch]:
"""Yields batches."""
self._start_manager()
msg = "Manager state did not initialize correctly."
assert self.reservoir is not None, msg
assert self.err_queue is not None, msg
assert self.manager_proc is not None, msg
n, b = 0, 0

try:
Expand Down Expand Up @@ -641,12 +672,14 @@ def _calculate_n_samples(self) -> int:
When ignore_labels is specified, this counts the actual number of patches
that remain after filtering out the ignored labels.
"""
if self.cfg.tokens == "special":
msg = "tokens='special' requires shards with a CLS token."
assert self.metadata.cls_token, msg

# First calculate the maximum possible samples
max_samples = 0
match (self.cfg.tokens, self.cfg.layer):
case ("cls", "all"):
max_samples = self.metadata.n_examples * len(self.metadata.layers)
case ("cls", int()):
case ("special", int()):
max_samples = self.metadata.n_examples
case ("content", int()):
max_samples = (
Expand Down
Loading