3030from sqlmesh .dbt .test import TestConfig
3131from sqlmesh .dbt .util import DBT_VERSION
3232from sqlmesh .utils import AttributeDict
33+ from sqlmesh .utils .dag import find_path_with_dfs
3334from sqlmesh .utils .errors import ConfigError
3435from sqlmesh .utils .pydantic import field_validator
3536
@@ -270,9 +271,10 @@ def remove_tests_with_invalid_refs(self, context: DbtContext) -> None:
270271
271272 def fix_circular_test_refs (self , context : DbtContext ) -> None :
272273 """
273- Checks for direct circular references between two models and moves the test to the downstream
274- model if found. This addresses the most common circular reference - relationship tests in both
275- directions. In the future, we may want to increase coverage by checking for indirect circular references.
274+ Checks for circular references between models and moves tests to break cycles.
275+ This handles both direct circular references (A -> B -> A) and indirect circular
276+ references (A -> B -> C -> A). Tests are moved to the model that appears latest
277+ in the dependency chain to ensure the cycle is broken.
276278
277279 Args:
278280 context: The dbt context this model resides within.
@@ -284,16 +286,91 @@ def fix_circular_test_refs(self, context: DbtContext) -> None:
284286 for ref in test .dependencies .refs :
285287 if ref == self .name or ref in self .dependencies .refs :
286288 continue
287- model = context .refs [ref ]
288- if (
289- self .name in model .dependencies .refs
290- or self .name in model .tests_ref_source_dependencies .refs
291- ):
289+
290+ # Check if moving this test would create or maintain a cycle
291+ cycle_path = self ._find_circular_path (ref , context , set ())
292+ if cycle_path :
293+ # Find the model in the cycle that should receive the test
294+ # We want to move to the model that appears latest in the dependency chain
295+ target_model_name = self ._select_target_model_for_test (cycle_path , context )
296+ target_model = context .refs [target_model_name ]
297+
292298 logger .info (
293- f"Moving test '{ test .name } ' from model '{ self .name } ' to '{ model .name } ' to avoid circular reference."
299+ f"Moving test '{ test .name } ' from model '{ self .name } ' to '{ target_model_name } ' "
300+ f"to avoid circular reference through path: { ' -> ' .join (cycle_path )} "
294301 )
295- model .tests .append (test )
302+ target_model .tests .append (test )
296303 self .tests .remove (test )
304+ break
305+
306+ def _find_circular_path (
307+ self , ref : str , context : DbtContext , visited : t .Set [str ]
308+ ) -> t .Optional [t .List [str ]]:
309+ """
310+ Find if there's a circular dependency path from ref back to this model.
311+
312+ Args:
313+ ref: The model name to start searching from
314+ context: The dbt context
315+ visited: Set of model names already visited in this path
316+
317+ Returns:
318+ List of model names forming the circular path, or None if no cycle exists
319+ """
320+ # Build a graph of all models and their dependencies from the context
321+ graph : t .Dict [str , t .Set [str ]] = {}
322+
323+ def build_graph_from_node (node_name : str , current_visited : t .Set [str ]) -> None :
324+ if node_name in current_visited or node_name in graph :
325+ return
326+ current_visited .add (node_name )
327+
328+ model = context .refs [node_name ]
329+ # Include both direct model dependencies and test dependencies
330+ all_refs = model .dependencies .refs | model .tests_ref_source_dependencies .refs
331+ graph [node_name ] = all_refs .copy ()
332+
333+ # Recursively build graph for dependencies
334+ for dep in all_refs :
335+ build_graph_from_node (dep , current_visited )
336+
337+ # Build the graph starting from the ref, including visited nodes to avoid infinite recursion
338+ build_graph_from_node (ref , visited .copy ())
339+
340+ # Add self.name to the graph if it's not already there
341+ if self .name not in graph :
342+ graph [self .name ] = set ()
343+
344+ # Use the shared DFS function to find path from ref to self.name
345+ return find_path_with_dfs (graph , start_node = ref , target_node = self .name )
346+
347+ def _select_target_model_for_test (self , cycle_path : t .List [str ], context : DbtContext ) -> str :
348+ """
349+ Select which model in the cycle should receive the test.
350+ We select the model that has the most downstream dependencies in the cycle
351+
352+ Args:
353+ cycle_path: List of model names in the circular dependency path
354+ context: The dbt context
355+
356+ Returns:
357+ Name of the model that should receive the test
358+ """
359+ # Count how many other models in the cycle each model depends on
360+ dependency_counts = {}
361+
362+ for model_name in cycle_path :
363+ model = context .refs [model_name ]
364+ all_refs = model .dependencies .refs | model .tests_ref_source_dependencies .refs
365+ count = len ([ref for ref in all_refs if ref in cycle_path ])
366+ dependency_counts [model_name ] = count
367+
368+ # Return the model with the fewest dependencies within the cycle
369+ # (i.e., the most downstream model in the cycle)
370+ if dependency_counts :
371+ return min (dependency_counts , key = dependency_counts .get ) # type: ignore
372+ # Fallback to the last model in the path
373+ return cycle_path [- 1 ]
297374
298375 @property
299376 def sqlmesh_config_fields (self ) -> t .Set [str ]:
0 commit comments