[BUG] Repartition IR may be passed into a sort_actor in cudf_polars + rapidsmpf #22050

@mroeschke

Description

(Discovered in the Polars test suite so sorry it's not more minimal)

In [1]: from functools import partialmethod
   ...: from typing import Any
   ...: import polars
   ...: from cudf_polars.utils.config import Runtime, StreamingFallbackMode
   ...: executor = "streaming"
   ...: executor_options: dict[str, Any] = {}
   ...: executor_options["max_rows_per_partition"] = 4
   ...: executor_options["target_partition_size"] = 10
   ...: # We expect many tests to fall back, so silence the warnings
   ...: executor_options["fallback_mode"] = StreamingFallbackMode.SILENT
   ...: executor_options["runtime"] = Runtime.RAPIDSMPF
   ...: collect = polars.LazyFrame.collect
   ...: engine = polars.GPUEngine(executor=executor, executor_options=executor_options)
   ...: polars.LazyFrame.collect = partialmethod(collect, engine=engine)

In [3]: import polars as pl

In [4]:     df = pl.DataFrame(
   ...:         {
   ...:             "id": [1, 2, 3, 4, 5, 6],
   ...:             "category": ["A", "A", "B", "B", "C", "C"],
   ...:             "value": [100, 100, 200, 200, 300, 300],
   ...:         }
   ...:     )

In [5]: query="""
   ...:             SELECT DISTINCT category, value
   ...:             FROM df
   ...:             QUALIFY value = MAX(value) OVER (PARTITION BY category)
   ...:             ORDER BY category
   ...:         """

In [6]: with pl.SQLContext(frames={"df": df}, eager=True) as ctx:
   ...:     ctx.execute(query=query, eager=True)

...
/cudf/python/cudf_polars/cudf_polars/experimental/rapidsmpf/collectives/sort.py:399, in sort_actor(context, comm, ir, ir_context, ch_in, ch_out, by, num_partitions, executor, collective_ids)
    397 """Streaming sort actor."""
    398 local_sort_ir = ir.children[0]
--> 399 assert isinstance(local_sort_ir, Sort), f"ShuffleSorted must have a Sort child, got {local_sort_ir}"
    400 ch_replay = context.create_channel()
    401 async with shutdown_on_error(
    402     context, ch_in, ch_out, ch_replay, trace_ir=ir, ir_context=ir_context
    403 ) as tracer:

AssertionError: ShuffleSorted must have a Sort child, got Repartition({'category': <DataType(polars=String, plc=<type_id.STRING: 23>)>, 'value': <DataType(polars=Int64, plc=<type_id.INT64: 4>)>}, Sort({'category': <DataType(polars=String, plc=<type_id.STRING: 23>)>, 'value': <DataType(polars=Int64, plc=<type_id.INT64: 4>)>}, (NamedExpr(category, Col(<DataType(polars=String, plc=<type_id.STRING: 23>)>, 'category')),), (<order.ASCENDING: 0>,), (<null_order.AFTER: 0>,), False, None, ShuffleSorted({'category': <DataType(polars=String, plc=<type_id.STRING: 23>)>, 'value': <DataType(polars=Int64, plc=<type_id.INT64: 4>)>}, (NamedExpr(category, Col(<DataType(polars=String, plc=<type_id.STRING: 23>)>, 'category')),), (<order.ASCENDING: 0>,), (<null_order.AFTER: 0>,), <ShuffleMethod._RAPIDSMPF_SINGLE: 'rapidsmpf-single'>, Sort({'category': <DataType(polars=String, plc=<type_id.STRING: 23>)>, 'value': <DataType(polars=Int64, plc=<type_id.INT64: 4>)>}, (NamedExpr(category, Col(<DataType(polars=String, plc=<type_id.STRING: 23>)>, 'category')),), (<order.ASCENDING: 0>,), (<null_order.AFTER: 0>,), False, None, HStack({'category': <DataType(polars=String, plc=<type_id.STRING: 23>)>, 'value': <DataType(polars=Int64, plc=<type_id.INT64: 4>)>}, (NamedExpr(category, Col(<DataType(polars=String, plc=<type_id.STRING: 23>)>, 'category')), NamedExpr(value, Col(<DataType(polars=Int64, plc=<type_id.INT64: 4>)>, 'value'))), True, DataFrameScan({'category': <DataType(polars=String, plc=<type_id.STRING: 23>)>, 'value': <DataType(polars=Int64, plc=<type_id.INT64: 4>)>}, <builtins.PyDataFrame object at 0x7a56974c8300>, ('category', 'value')))))))

I amended the assertion to show the input IR (pretty-printed below). The Repartition wraps a nested Sort -> ShuffleSorted -> Sort chain. This is likely an edge case not covered in #21690. cc @rjzamora

Repartition(
  {'category': DataType(polars=String, plc=STRING), 'value': DataType(polars=Int64, plc=INT64)},
  Sort(
    {'category': DataType(polars=String, plc=STRING), 'value': DataType(polars=Int64, plc=INT64)},
    (NamedExpr(category, Col(DataType(polars=String, plc=STRING), 'category')),),
    (order.ASCENDING,),
    (null_order.AFTER,),
    False,
    None,
    ShuffleSorted(
      {'category': DataType(polars=String, plc=STRING), 'value': DataType(polars=Int64, plc=INT64)},
      (NamedExpr(category, Col(DataType(polars=String, plc=STRING), 'category')),),
      (order.ASCENDING,),
      (null_order.AFTER,),
      ShuffleMethod._RAPIDSMPF_SINGLE,
      Sort(
        {'category': DataType(polars=String, plc=STRING), 'value': DataType(polars=Int64, plc=INT64)},
        (NamedExpr(category, Col(DataType(polars=String, plc=STRING), 'category')),),
        (order.ASCENDING,),
        (null_order.AFTER,),
        False,
        None,
        HStack(
          {'category': DataType(polars=String, plc=STRING), 'value': DataType(polars=Int64, plc=INT64)},
          (
            NamedExpr(category, Col(DataType(polars=String, plc=STRING), 'category')),
            NamedExpr(value, Col(DataType(polars=Int64, plc=INT64), 'value'))
          ),
          True,
          DataFrameScan(
            {'category': DataType(polars=String, plc=STRING), 'value': DataType(polars=Int64, plc=INT64)},
            PyDataFrame(...),
            ('category', 'value')
          )
        )
      )
    )
  )
)
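
To make the nesting above easier to see at a glance, here is a minimal sketch of a helper that follows `children[0]` down a plan and lists the node type names. The classes below are hypothetical stand-ins, not the real cudf_polars IR classes; the only thing borrowed from the traceback is that a node exposes a `children` sequence. `Node` and `first_child_chain` are names made up for this illustration.

```python
from dataclasses import dataclass


@dataclass
class Node:
    """Hypothetical stand-in for an IR node; only `children` mirrors the traceback."""

    children: tuple = ()


# Stand-in node types named after the ones in the assertion message.
class Repartition(Node): ...
class Sort(Node): ...
class ShuffleSorted(Node): ...
class HStack(Node): ...
class DataFrameScan(Node): ...


def first_child_chain(node):
    """Return the type names found by following children[0] down to a leaf."""
    names = []
    while True:
        names.append(type(node).__name__)
        if not node.children:
            return names
        node = node.children[0]


# Rebuild the shape reported in the assertion message, leaf first.
scan = DataFrameScan()
hstack = HStack((scan,))
inner_sort = Sort((hstack,))
shuffle = ShuffleSorted((inner_sort,))
outer_sort = Sort((shuffle,))
reported_child = Repartition((outer_sort,))

print(" -> ".join(first_child_chain(reported_child)))
# Repartition -> Sort -> ShuffleSorted -> Sort -> HStack -> DataFrameScan
```

The first name in the chain is what `sort_actor` sees as `ir.children[0]`: a `Repartition` rather than the `Sort` the assertion expects. Assuming the real IR nodes also expose `children`, the same walk could be applied to them directly, but that is untested here.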

Labels: bug, cudf-polars

Status: In Progress
