55import fnmatch
66import uuid
77from pathlib import Path
8- from typing import Any , List , Optional
8+ from typing import Any , List , Optional , Tuple
99
1010import redis .exceptions
1111import yaml
@@ -86,30 +86,14 @@ def create_batch_plan(
8686 requires_quantization = False
8787
8888 for index_name in index_names :
89- entry = self ._check_index_applicability (
89+ entry , has_quantization = self ._check_index_applicability (
9090 index_name = index_name ,
9191 shared_patch = shared_patch ,
9292 redis_client = client ,
9393 )
9494 batch_entries .append (entry )
95-
96- # Check if any applicable index requires quantization
97- if entry .applicable :
98- try :
99- plan = self ._single_planner .create_plan_from_patch (
100- index_name ,
101- schema_patch = shared_patch ,
102- redis_client = client ,
103- )
104- datatype_changes = MigrationPlanner .get_vector_datatype_changes (
105- plan .source .schema_snapshot ,
106- plan .merged_target_schema ,
107- rename_operations = plan .rename_operations ,
108- )
109- if datatype_changes :
110- requires_quantization = True
111- except Exception :
112- pass # Already handled in applicability check
95+ if has_quantization :
96+ requires_quantization = True
11397
11498 batch_id = f"batch_{ uuid .uuid4 ().hex [:12 ]} "
11599
@@ -171,8 +155,12 @@ def _check_index_applicability(
171155 index_name : str ,
172156 shared_patch : SchemaPatch ,
173157 redis_client : Any ,
174- ) -> BatchIndexEntry :
175- """Check if the shared patch can be applied to a specific index."""
158+ ) -> Tuple [BatchIndexEntry , bool ]:
159+ """Check if the shared patch can be applied to a specific index.
160+
161+ Returns:
162+ Tuple of (BatchIndexEntry, requires_quantization).
163+ """
176164 try :
177165 index = SearchIndex .from_existing (index_name , redis_client = redis_client )
178166 schema_dict = index .schema .to_dict ()
@@ -193,10 +181,13 @@ def _check_index_applicability(
193181 missing_fields .append (field_update .name )
194182
195183 if missing_fields :
196- return BatchIndexEntry (
197- name = index_name ,
198- applicable = False ,
199- skip_reason = f"Missing fields: { ', ' .join (missing_fields )} " ,
184+ return (
185+ BatchIndexEntry (
186+ name = index_name ,
187+ applicable = False ,
188+ skip_reason = f"Missing fields: { ', ' .join (missing_fields )} " ,
189+ ),
190+ False ,
200191 )
201192
202193 # Validate rename targets don't collide with each other or
@@ -213,10 +204,13 @@ def _check_index_applicability(
213204 seen_targets [t ] = seen_targets .get (t , 0 ) + 1
214205 duplicates = [t for t , c in seen_targets .items () if c > 1 ]
215206 if duplicates :
216- return BatchIndexEntry (
217- name = index_name ,
218- applicable = False ,
219- skip_reason = f"Rename targets collide: { ', ' .join (duplicates )} " ,
207+ return (
208+ BatchIndexEntry (
209+ name = index_name ,
210+ applicable = False ,
211+ skip_reason = f"Rename targets collide: { ', ' .join (duplicates )} " ,
212+ ),
213+ False ,
220214 )
221215 # Check if any rename target already exists and isn't itself being renamed away
222216 collisions = [
@@ -225,10 +219,13 @@ def _check_index_applicability(
225219 if t in field_names and t not in rename_sources
226220 ]
227221 if collisions :
228- return BatchIndexEntry (
229- name = index_name ,
230- applicable = False ,
231- skip_reason = f"Rename targets already exist: { ', ' .join (collisions )} " ,
222+ return (
223+ BatchIndexEntry (
224+ name = index_name ,
225+ applicable = False ,
226+ skip_reason = f"Rename targets already exist: { ', ' .join (collisions )} " ,
227+ ),
228+ False ,
232229 )
233230
234231 # Check that add_fields don't already exist.
@@ -242,10 +239,13 @@ def _check_index_applicability(
242239 existing_adds .append (field_name )
243240
244241 if existing_adds :
245- return BatchIndexEntry (
246- name = index_name ,
247- applicable = False ,
248- skip_reason = f"Fields already exist: { ', ' .join (existing_adds )} " ,
242+ return (
243+ BatchIndexEntry (
244+ name = index_name ,
245+ applicable = False ,
246+ skip_reason = f"Fields already exist: { ', ' .join (existing_adds )} " ,
247+ ),
248+ False ,
249249 )
250250
251251 # Try creating a plan to check for blocked changes
@@ -256,17 +256,29 @@ def _check_index_applicability(
256256 )
257257
258258 if not plan .diff_classification .supported :
259- return BatchIndexEntry (
260- name = index_name ,
261- applicable = False ,
262- skip_reason = (
263- plan .diff_classification .blocked_reasons [0 ]
264- if plan .diff_classification .blocked_reasons
265- else "Unsupported changes"
259+ return (
260+ BatchIndexEntry (
261+ name = index_name ,
262+ applicable = False ,
263+ skip_reason = (
264+ plan .diff_classification .blocked_reasons [0 ]
265+ if plan .diff_classification .blocked_reasons
266+ else "Unsupported changes"
267+ ),
266268 ),
269+ False ,
267270 )
268271
269- return BatchIndexEntry (name = index_name , applicable = True )
272+ # Detect quantization from the plan we already created
273+ has_quantization = bool (
274+ MigrationPlanner .get_vector_datatype_changes (
275+ plan .source .schema_snapshot ,
276+ plan .merged_target_schema ,
277+ rename_operations = plan .rename_operations ,
278+ )
279+ )
280+
281+ return BatchIndexEntry (name = index_name , applicable = True ), has_quantization
270282
271283 except (
272284 ConnectionError ,
@@ -278,10 +290,13 @@ def _check_index_applicability(
278290 # treated as "not applicable".
279291 raise
280292 except Exception as e :
281- return BatchIndexEntry (
282- name = index_name ,
283- applicable = False ,
284- skip_reason = str (e ),
293+ return (
294+ BatchIndexEntry (
295+ name = index_name ,
296+ applicable = False ,
297+ skip_reason = str (e ),
298+ ),
299+ False ,
285300 )
286301
287302 def write_batch_plan (self , batch_plan : BatchPlan , path : str ) -> None :
0 commit comments