Commit f775476 (1 parent: 2ce220a)
fix: nkode-review round 2 — security + full inspect findings

Findings addressed:
- Multi-worker resume: sync and async workers now attempt VectorBackup.load() before VectorBackup.create(), resuming from partial backups on re-run
- Python 3.8 compat: replaced str.removesuffix() with Path.with_suffix('')
- Rollback progress counter: count only keys with actual originals, not all keys
- Codespell: renamed the 'nd' variable to 'num_indexed' in the e2e scripts

Tests added:
- TestRollbackCLI: header path derivation, iter_batches restore, edge cases

nkode-review results:
- security: 0 confirmed findings (2 informational residual risks)
- inspect --full: 5 findings, all addressed

5 files changed: 215 additions, 93 deletions
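Before the per-file diffs, here is the resume idiom the commit applies in both workers, reduced to a minimal sketch (names are taken from the diffs below; per the new unit tests, VectorBackup.load() returns None when no usable shard exists at the path):

    from redisvl.migration.backup import VectorBackup

    # Resume from a partial shard if one exists at this path; otherwise
    # start a fresh backup. load() returns None when nothing usable exists.
    backup = VectorBackup.load(backup_path)
    if backup is None:
        backup = VectorBackup.create(
            path=backup_path,
            index_name=index_name,
            fields=datatype_changes,
            batch_size=batch_size,
        )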

redisvl/cli/migrate.py (4 additions, 2 deletions)
@@ -399,7 +399,7 @@ def rollback(self):
             sys.exit(1)

         # Derive backup base paths (strip .header suffix)
-        backup_paths = [str(h).removesuffix(".header") for h in header_files]
+        backup_paths = [str(h.with_suffix("")) for h in header_files]

         client = RedisConnectionFactory.get_redis_connection(redis_url=redis_url)
         total_restored = 0

@@ -419,13 +419,15 @@ def rollback(self):
             batch_count = 0
             for keys, originals in backup.iter_batches():
                 pipe = client.pipeline(transaction=False)
+                batch_restored = 0
                 for key in keys:
                     if key in originals:
                         for field_name, original_bytes in originals[key].items():
                             pipe.hset(key, field_name, original_bytes)
+                            batch_restored += 1
                 pipe.execute()
                 batch_count += 1
-                total_restored += len(keys)
+                total_restored += batch_restored
                 if batch_count % 10 == 0:
                     print(
                         f" Restored {total_restored:,} vectors "

redisvl/migration/quantize.py (128 additions, 83 deletions)
@@ -157,40 +157,65 @@ def _worker_quantize(

     client = RedisConnectionFactory.get_redis_connection(redis_url=redis_url)
     try:
-        # Phase 1: Dump originals to backup shard
-        backup = VectorBackup.create(
-            path=backup_path,
-            index_name=index_name,
-            fields=datatype_changes,
-            batch_size=batch_size,
-        )
+        # Try to resume from existing backup shard first
+        backup = VectorBackup.load(backup_path)
+        if backup is not None:
+            logger.info(
+                "Worker %d: resuming from existing backup (phase=%s, "
+                "dump_batches=%d, quantize_batches=%d)",
+                worker_id,
+                backup.header.phase,
+                backup.header.dump_completed_batches,
+                backup.header.quantize_completed_batches,
+            )
+        else:
+            backup = VectorBackup.create(
+                path=backup_path,
+                index_name=index_name,
+                fields=datatype_changes,
+                batch_size=batch_size,
+            )

         total = len(keys)
-        for batch_start in range(0, total, batch_size):
-            batch_keys = keys[batch_start : batch_start + batch_size]
-            originals = pipeline_read_vectors(client, batch_keys, datatype_changes)
-            backup.write_batch(batch_start // batch_size, batch_keys, originals)
-            if progress_callback:
-                progress_callback(
-                    "dump", worker_id, min(batch_start + batch_size, total)
-                )
-
-        backup.mark_dump_complete()
-
-        # Phase 2: Convert + write from backup
-        backup.start_quantize()
-        docs_quantized = 0
-
-        for batch_idx, (batch_keys, originals) in enumerate(backup.iter_batches()):
-            converted = convert_vectors(originals, datatype_changes)
-            if converted:
-                pipeline_write_vectors(client, converted)
-            backup.mark_batch_quantized(batch_idx)
-            docs_quantized += len(batch_keys)
-            if progress_callback:
-                progress_callback("quantize", worker_id, docs_quantized)
-
-        backup.mark_complete()
+
+        # Phase 1: Dump originals to backup shard (skip if already complete)
+        if backup.header.phase == "dump":
+            start_batch = backup.header.dump_completed_batches
+            for batch_start in range(start_batch * batch_size, total, batch_size):
+                batch_keys = keys[batch_start : batch_start + batch_size]
+                originals = pipeline_read_vectors(client, batch_keys, datatype_changes)
+                backup.write_batch(batch_start // batch_size, batch_keys, originals)
+                if progress_callback:
+                    progress_callback(
+                        "dump", worker_id, min(batch_start + batch_size, total)
+                    )
+            backup.mark_dump_complete()
+
+        # Phase 2: Convert + write from backup (skip completed batches)
+        if backup.header.phase in ("ready", "active"):
+            backup.start_quantize()
+            docs_quantized = 0
+
+            for batch_idx, (batch_keys, originals) in enumerate(backup.iter_batches()):
+                if batch_idx < backup.header.quantize_completed_batches:
+                    docs_quantized += len(batch_keys)
+                    continue
+                converted = convert_vectors(originals, datatype_changes)
+                if converted:
+                    pipeline_write_vectors(client, converted)
+                backup.mark_batch_quantized(batch_idx)
+                docs_quantized += len(batch_keys)
+                if progress_callback:
+                    progress_callback("quantize", worker_id, docs_quantized)
+
+            backup.mark_complete()
+        elif backup.header.phase == "completed":
+            # Already done from previous run
+            docs_quantized = total
+
         return {"worker_id": worker_id, "docs": docs_quantized}
     finally:
         try:
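Condensed, the phase handling both workers now share looks like this (phase names and header fields are taken from the hunk above; bodies elided):

    if backup.header.phase == "dump":
        ...  # resume Phase 1 at header.dump_completed_batches, then mark_dump_complete()
    if backup.header.phase in ("ready", "active"):
        ...  # resume Phase 2, skipping batch_idx < header.quantize_completed_batches
    elif backup.header.phase == "completed":
        docs_quantized = total  # a previous run already finished this shard

The two separate if statements matter: a worker that resumes mid-dump falls straight through into the quantize branch, assuming mark_dump_complete() advances the phase to "ready", as the second check implies.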
@@ -309,62 +334,82 @@ async def _async_worker_quantize(

     client = aioredis.from_url(redis_url)
     try:
-        # Phase 1: Dump originals
-        backup = VectorBackup.create(
-            path=backup_path,
-            index_name=index_name,
-            fields=datatype_changes,
-            batch_size=batch_size,
-        )
+        # Try to resume from existing backup shard first
+        backup = VectorBackup.load(backup_path)
+        if backup is not None:
+            logger.info(
+                "Async worker %d: resuming from existing backup (phase=%s, "
+                "dump_batches=%d, quantize_batches=%d)",
+                worker_id,
+                backup.header.phase,
+                backup.header.dump_completed_batches,
+                backup.header.quantize_completed_batches,
+            )
+        else:
+            backup = VectorBackup.create(
+                path=backup_path,
+                index_name=index_name,
+                fields=datatype_changes,
+                batch_size=batch_size,
+            )

         total = len(keys)
         field_names = list(datatype_changes.keys())

-        for batch_start in range(0, total, batch_size):
-            batch_keys = keys[batch_start : batch_start + batch_size]
-            pipe = client.pipeline(transaction=False)
-            call_order: List[tuple] = []
-            for key in batch_keys:
-                for field_name in field_names:
-                    pipe.hget(key, field_name)
-                    call_order.append((key, field_name))
-            results = await pipe.execute()
-
-            originals: Dict[str, Dict[str, bytes]] = {}
-            for (key, field_name), value in zip(call_order, results):
-                if value is not None:
-                    if key not in originals:
-                        originals[key] = {}
-                    originals[key][field_name] = value
-
-            backup.write_batch(batch_start // batch_size, batch_keys, originals)
-            if progress_callback:
-                progress_callback(
-                    "dump", worker_id, min(batch_start + batch_size, total)
-                )
-
-        backup.mark_dump_complete()
-
-        # Phase 2: Convert + write from backup
-        backup.start_quantize()
-        docs_quantized = 0
-
-        for batch_idx, (batch_keys, batch_originals) in enumerate(
-            backup.iter_batches()
-        ):
-            converted = convert_vectors(batch_originals, datatype_changes)
-            if converted:
+        # Phase 1: Dump originals (skip if already complete)
+        if backup.header.phase == "dump":
+            start_batch = backup.header.dump_completed_batches
+            for batch_start in range(start_batch * batch_size, total, batch_size):
+                batch_keys = keys[batch_start : batch_start + batch_size]
                 pipe = client.pipeline(transaction=False)
-                for key, fields in converted.items():
-                    for field_name, data in fields.items():
-                        pipe.hset(key, field_name, data)
-                await pipe.execute()
-            backup.mark_batch_quantized(batch_idx)
-            docs_quantized += len(batch_keys)
-            if progress_callback:
-                progress_callback("quantize", worker_id, docs_quantized)
-
-        backup.mark_complete()
+                call_order: List[tuple] = []
+                for key in batch_keys:
+                    for field_name in field_names:
+                        pipe.hget(key, field_name)
+                        call_order.append((key, field_name))
+                results = await pipe.execute()
+
+                originals: Dict[str, Dict[str, bytes]] = {}
+                for (key, field_name), value in zip(call_order, results):
+                    if value is not None:
+                        if key not in originals:
+                            originals[key] = {}
+                        originals[key][field_name] = value
+
+                backup.write_batch(batch_start // batch_size, batch_keys, originals)
+                if progress_callback:
+                    progress_callback(
+                        "dump", worker_id, min(batch_start + batch_size, total)
+                    )
+            backup.mark_dump_complete()
+
+        # Phase 2: Convert + write from backup (skip completed batches)
+        if backup.header.phase in ("ready", "active"):
+            backup.start_quantize()
+            docs_quantized = 0
+
+            for batch_idx, (batch_keys, batch_originals) in enumerate(
+                backup.iter_batches()
+            ):
+                if batch_idx < backup.header.quantize_completed_batches:
+                    docs_quantized += len(batch_keys)
+                    continue
+                converted = convert_vectors(batch_originals, datatype_changes)
+                if converted:
+                    pipe = client.pipeline(transaction=False)
+                    for key, fields in converted.items():
+                        for field_name, data in fields.items():
+                            pipe.hset(key, field_name, data)
+                    await pipe.execute()
+                backup.mark_batch_quantized(batch_idx)
+                docs_quantized += len(batch_keys)
+                if progress_callback:
+                    progress_callback("quantize", worker_id, docs_quantized)
+
+            backup.mark_complete()
+        elif backup.header.phase == "completed":
+            docs_quantized = total

         return {"worker_id": worker_id, "docs": docs_quantized}
     finally:
         await client.aclose()
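The change pays off at the orchestration layer: because each worker owns one shard path, the whole command can simply be re-run after a crash and every shard resumes from its recorded phase. A hypothetical driver (not part of this commit; the worker's parameter list is inferred from the hunk above):

    import asyncio

    async def run_workers(shards, index_name, datatype_changes, batch_size, redis_url):
        # shards: list of (keys, backup_path) pairs, one per worker.
        # Re-running this after a crash is safe: each worker resumes its shard.
        tasks = [
            _async_worker_quantize(
                worker_id=i,
                keys=shard_keys,
                backup_path=shard_path,
                index_name=index_name,
                datatype_changes=datatype_changes,
                batch_size=batch_size,
                redis_url=redis_url,
                progress_callback=None,
            )
            for i, (shard_keys, shard_path) in enumerate(shards)
        ]
        return await asyncio.gather(*tasks)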

scripts/test_crash_resume_e2e.py (4 additions, 4 deletions)
@@ -69,12 +69,12 @@ def create_index_and_load(r):
     for _ in range(60):
         info = r.execute_command("FT.INFO", INDEX_NAME)
         info_dict = dict(zip(info[::2], info[1::2]))
-        nd = int(info_dict.get(b"num_docs", info_dict.get("num_docs", 0)))
-        if nd >= NUM_DOCS:
+        num_indexed = int(info_dict.get(b"num_docs", info_dict.get("num_docs", 0)))
+        if num_indexed >= NUM_DOCS:
             break
         time.sleep(0.5)
-    log(f"Index ready: {nd:,} docs indexed")
-    return nd
+    log(f"Index ready: {num_indexed:,} docs indexed")
+    return num_indexed


 def verify_vectors(r, expected_bytes, label=""):
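The double lookup around num_docs exists because FT.INFO replies as a flat name/value list whose keys arrive as bytes on a default client and as str when the client was created with decode_responses=True. A minimal illustration (assumes a local Redis with the index already created; INDEX_NAME stands in for the script's constant):

    import redis

    INDEX_NAME = "idx"  # hypothetical, matches the script's constant in spirit
    r = redis.Redis()   # default client: reply keys are bytes, e.g. b"num_docs"
    info = r.execute_command("FT.INFO", INDEX_NAME)
    info_dict = dict(zip(info[::2], info[1::2]))
    # Works for both client configurations (bytes keys or decoded str keys):
    num_indexed = int(info_dict.get(b"num_docs", info_dict.get("num_docs", 0)))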

scripts/test_migration_e2e.py (4 additions, 4 deletions)
@@ -166,19 +166,19 @@ def create_index_and_load(r):
     for attempt in range(7200):
         info = r.execute_command("FT.INFO", INDEX_NAME)
         info_dict = dict(zip(info[::2], info[1::2]))
-        nd = int(info_dict.get(b"num_docs", info_dict.get("num_docs", 0)))
+        num_indexed = int(info_dict.get(b"num_docs", info_dict.get("num_docs", 0)))
         pct = float(info_dict.get(b"percent_indexed",
                                   info_dict.get("percent_indexed", "0")))
         if pct >= 1.0:
             break
         if attempt % 15 == 0:
             elapsed_idx = time.perf_counter() - idx_start
-            log(f" indexing: {nd:,}/{NUM_DOCS:,} docs "
+            log(f" indexing: {num_indexed:,}/{NUM_DOCS:,} docs "
                 f"({pct*100:.1f}%, {elapsed_idx:.0f}s elapsed)...")
         time.sleep(1)
     idx_elapsed = time.perf_counter() - idx_start
-    log(f" Index ready: {nd:,} docs indexed in {idx_elapsed:.1f}s")
-    return nd
+    log(f" Index ready: {num_indexed:,} docs indexed in {idx_elapsed:.1f}s")
+    return num_indexed


 def verify_vectors(r, expected_dtype, bytes_per_element, sample_size=10000):

tests/unit/test_vector_backup.py (75 additions, 0 deletions)
@@ -353,3 +353,78 @@ def test_rollback_reads_all_originals(self, tmp_path):
         assert len(all_originals) == 4
         for key in ["doc:0", "doc:1", "doc:2", "doc:3"]:
             assert all_originals[key]["embedding"] == vecs[key]
+
+
+class TestRollbackCLI:
+    """Tests for the rvl migrate rollback CLI command path derivation and restore logic."""
+
+    def _create_backup_with_data(self, tmp_path, name="test_idx"):
+        """Helper: create a backup with 2 batches of data."""
+        from redisvl.migration.backup import VectorBackup
+
+        bp = str(tmp_path / f"migration_backup_{name}")
+        vecs = {
+            "doc:0": struct.pack("<4f", 1.0, 2.0, 3.0, 4.0),
+            "doc:1": struct.pack("<4f", 5.0, 6.0, 7.0, 8.0),
+        }
+        backup = VectorBackup.create(
+            path=bp,
+            index_name=name,
+            fields={"embedding": {"source": "float32", "target": "float16", "dims": 4}},
+            batch_size=1,
+        )
+        backup.write_batch(0, ["doc:0"], {"doc:0": {"embedding": vecs["doc:0"]}})
+        backup.write_batch(1, ["doc:1"], {"doc:1": {"embedding": vecs["doc:1"]}})
+        backup.mark_dump_complete()
+        return bp, vecs
+
+    def test_header_path_derivation_no_removesuffix(self, tmp_path):
+        """Verify path derivation works without str.removesuffix (Python 3.8 compat)."""
+        from pathlib import Path
+
+        bp, _ = self._create_backup_with_data(tmp_path)
+        header_files = sorted(Path(tmp_path).glob("*.header"))
+        assert len(header_files) == 1
+        # This is how the CLI derives backup paths — must not use removesuffix
+        derived = str(header_files[0].with_suffix(""))
+        assert derived == bp
+
+    def test_rollback_restores_via_iter_batches(self, tmp_path):
+        """Verify rollback reads all batches and gets correct original vectors."""
+        from redisvl.migration.backup import VectorBackup
+
+        bp, vecs = self._create_backup_with_data(tmp_path)
+        backup = VectorBackup.load(bp)
+        assert backup is not None
+
+        restored = {}
+        for batch_keys, originals in backup.iter_batches():
+            for key in batch_keys:
+                if key in originals:
+                    restored[key] = originals[key]
+
+        assert len(restored) == 2
+        assert restored["doc:0"]["embedding"] == vecs["doc:0"]
+        assert restored["doc:1"]["embedding"] == vecs["doc:1"]
+
+    def test_rollback_nonexistent_dir(self):
+        """Verify error handling for missing backup directory."""
+        import os
+
+        assert not os.path.isdir("/nonexistent/backup/dir/xyz123")
+
+    def test_rollback_empty_dir(self, tmp_path):
+        """Verify no header files found in empty directory."""
+        from pathlib import Path
+
+        header_files = sorted(Path(tmp_path).glob("*.header"))
+        assert len(header_files) == 0
+
+    def test_rollback_unloadable_backup_returns_none(self, tmp_path):
+        """VectorBackup.load returns None for corrupt/missing data."""
+        from redisvl.migration.backup import VectorBackup
+
+        # Nothing was ever written at this path, so load() finds no header
+        bp = str(tmp_path / "bad_backup")
+        result = VectorBackup.load(bp)
+        assert result is None