Skip to content
This repository was archived by the owner on Feb 2, 2024. It is now read-only.

Commit 7abdf3c

Browse files
authored
Overload df.getitem with boolean series idx (#583)
1 parent c358ba9 commit 7abdf3c

4 files changed

Lines changed: 157 additions & 40 deletions

File tree

sdc/datatypes/hpat_pandas_dataframe_functions.py

Lines changed: 92 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,10 @@
3838
from numba import types
3939
from numba.special import literally
4040
from sdc.hiframes.pd_dataframe_ext import DataFrameType
41-
from sdc.utilities.sdc_typing_utils import TypeChecker
41+
from sdc.hiframes.pd_series_type import SeriesType
42+
from sdc.utilities.sdc_typing_utils import (TypeChecker, check_index_is_numeric,
43+
check_types_comparable,
44+
gen_df_impl_generator)
4245
from sdc.str_arr_ext import StringArrayType
4346

4447
from sdc.hiframes.pd_dataframe_type import DataFrameType
@@ -929,6 +932,32 @@ def sdc_pandas_dataframe_drop_impl(df, _func_name, args, columns):
929932
return sdc_pandas_dataframe_drop_impl(df, _func_name, args, columns)
930933

931934

935+
def df_getitem_bool_series_idx_main_codelines(self, idx):
    """Generate main code lines for df.getitem with idx of boolean series.

    The returned lines form the body of the generated implementation; each
    line carries the body indentation used by the generated function.

    Parameters
    ----------
    self
        DataFrame type; its ``index`` and ``columns`` attributes are read.
    idx
        Type of the boolean series index. It is not inspected here: the
        generated code accesses ``idx._data`` at run time.

    Returns
    -------
    list of str
        Source lines that trim the mask to the frame length, select or build
        the frame index, filter every column through ``pandas.Series`` and
        return the resulting ``pandas.DataFrame``.
    """
    func_lines = [' self_length = len(get_dataframe_data(self, 0))',
                  ' trimmed_idx_data = idx._data[:self_length]']

    # A frame without an explicit index gets a positional one in generated code
    if isinstance(self.index, types.NoneType):
        func_lines += [' self_index = numpy.arange(self_length)']
    else:
        func_lines += [' self_index = self._index']

    results = []
    for i, col in enumerate(self.columns):
        res_data = f'res_data_{i}'
        # repr() yields a valid Python string literal even when the column name
        # contains quotes or backslashes; str() keeps the former stringification
        col_literal = repr(str(col))
        func_lines += [
            f' data_{i} = get_dataframe_data(self, {i})',
            f' series = pandas.Series(data_{i}, index=self_index, name={col_literal})',
            f' {res_data} = series[trimmed_idx_data]',
        ]
        results.append((col_literal, res_data))

    columns_text = ', '.join(f'{literal}: {var}' for literal, var in results)
    func_lines += [f' return pandas.DataFrame({{{columns_text}}}, index=self_index[trimmed_idx_data])']

    return func_lines
959+
960+
932961
def df_index_codelines(self):
933962
"""Generate code lines to get or create index of DF"""
934963
if isinstance(self.index, types.NoneType):
@@ -941,6 +970,11 @@ def df_index_codelines(self):
941970
return func_lines
942971

943972

973+
def df_getitem_key_error_codelines():
    """Return the generated-code lines that raise KeyError for a missing column."""
    raise_statement = ' raise KeyError("Column is not in the DataFrame")'
    return [raise_statement]
976+
977+
944978
def df_getitem_slice_idx_main_codelines(self, idx):
945979
"""Generate main code lines for df.getitem with idx of slice"""
946980
results = []
@@ -978,6 +1012,35 @@ def df_getitem_tuple_idx_main_codelines(self, literal_idx):
9781012
return func_lines
9791013

9801014

1015+
def df_getitem_bool_series_codegen(self, idx):
    """
    Build the source text and globals of df.getitem for a boolean series idx.

    Returns a ``(func_text, global_vars)`` pair: the text defines
    ``_df_getitem_bool_series_idx_impl`` and the dict holds the names that
    text refers to.

    Example of generated implementation with provided index:
        def _df_getitem_bool_series_idx_impl(self, idx):
          self_length = len(get_dataframe_data(self, 0))
          trimmed_idx_data = idx._data[:self_length]
          self_index = self._index
          data_0 = get_dataframe_data(self, 0)
          series = pandas.Series(data_0, index=self_index, name="A")
          res_data_0 = series[trimmed_idx_data]
          data_1 = get_dataframe_data(self, 1)
          series = pandas.Series(data_1, index=self_index, name="B")
          res_data_1 = series[trimmed_idx_data]
          return pandas.DataFrame({"A": res_data_0, "B": res_data_1}, index=self_index[trimmed_idx_data])
    """
    header = 'def _df_getitem_bool_series_idx_impl(self, idx):'
    if self.columns:
        body = df_getitem_bool_series_idx_main_codelines(self, idx)
    else:
        # raise KeyError if input DF is empty
        body = df_getitem_key_error_codelines()

    func_text = '\n'.join([header, *body])
    global_vars = {
        'pandas': pandas,
        'numpy': numpy,
        'get_dataframe_data': get_dataframe_data,
    }

    return func_text, global_vars
1042+
1043+
9811044
def df_getitem_slice_idx_codegen(self, idx):
9821045
"""
9831046
Example of generated implementation with provided index:
@@ -994,7 +1057,7 @@ def _df_getitem_slice_idx_impl(self, idx)
9941057
func_lines += df_getitem_slice_idx_main_codelines(self, idx)
9951058
else:
9961059
# raise KeyError if input DF is empty
997-
func_lines += [' raise KeyError']
1060+
func_lines += df_getitem_key_error_codelines()
9981061

9991062
func_text = '\n'.join(func_lines)
10001063
global_vars = {'pandas': pandas, 'numpy': numpy,
@@ -1022,7 +1085,7 @@ def _df_getitem_tuple_idx_impl(self, idx)
10221085
func_lines += df_getitem_tuple_idx_main_codelines(self, literal_idx)
10231086
else:
10241087
# raise KeyError if input DF is empty or idx is invalid
1025-
func_lines += [' raise KeyError']
1088+
func_lines += df_getitem_key_error_codelines()
10261089

10271090
func_text = '\n'.join(func_lines)
10281091
global_vars = {'pandas': pandas, 'numpy': numpy,
@@ -1031,28 +1094,17 @@ def _df_getitem_tuple_idx_impl(self, idx)
10311094
return func_text, global_vars
10321095

10331096

1034-
def gen_df_getitem_impl_generator(codegen, impl_name):
1035-
"""Generate generator of df.getitem"""
1036-
def _df_getitem_impl_generator(self, idx):
1037-
func_text, global_vars = codegen(self, idx)
1038-
1039-
loc_vars = {}
1040-
exec(func_text, global_vars, loc_vars)
1041-
_impl = loc_vars[impl_name]
1042-
1043-
return _impl
1044-
1045-
return _df_getitem_impl_generator
1046-
1047-
1048-
gen_df_getitem_slice_idx_impl = gen_df_getitem_impl_generator(
1097+
gen_df_getitem_slice_idx_impl = gen_df_impl_generator(
10491098
df_getitem_slice_idx_codegen, '_df_getitem_slice_idx_impl')
1050-
gen_df_getitem_tuple_idx_impl = gen_df_getitem_impl_generator(
1099+
gen_df_getitem_tuple_idx_impl = gen_df_impl_generator(
10511100
df_getitem_tuple_idx_codegen, '_df_getitem_tuple_idx_impl')
1101+
gen_df_getitem_bool_series_idx_impl = gen_df_impl_generator(
1102+
df_getitem_bool_series_codegen, '_df_getitem_bool_series_idx_impl')
10521103

10531104

10541105
@sdc_overload(operator.getitem)
10551106
def sdc_pandas_dataframe_getitem(self, idx):
1107+
ty_checker = TypeChecker('Operator getitem().')
10561108

10571109
if not isinstance(self, DataFrameType):
10581110
return None
@@ -1069,7 +1121,7 @@ def _df_getitem_str_literal_idx_impl(self, idx):
10691121
data = get_dataframe_data(self, col_idx)
10701122
return pandas.Series(data, index=self._index, name=idx)
10711123
else:
1072-
raise KeyError
1124+
raise KeyError('Column is not in the DataFrame')
10731125

10741126
return _df_getitem_str_literal_idx_impl
10751127

@@ -1082,12 +1134,30 @@ def _df_getitem_unicode_idx_impl(self, idx):
10821134
return _df_getitem_unicode_idx_impl
10831135

10841136
if isinstance(idx, types.Tuple):
1085-
return gen_df_getitem_tuple_idx_impl(self, idx)
1137+
if all([isinstance(item, types.StringLiteral) for item in idx]):
1138+
return gen_df_getitem_tuple_idx_impl(self, idx)
10861139

10871140
if isinstance(idx, types.SliceType):
10881141
return gen_df_getitem_slice_idx_impl(self, idx)
10891142

1090-
ty_checker = TypeChecker('Operator getitem().')
1143+
if isinstance(idx, SeriesType) and isinstance(idx.dtype, types.Boolean):
1144+
self_index_is_none = isinstance(self.index, types.NoneType)
1145+
idx_index_is_none = isinstance(idx.index, types.NoneType)
1146+
1147+
if self_index_is_none and not idx_index_is_none:
1148+
if not check_index_is_numeric(idx):
1149+
ty_checker.raise_exc(idx.index.dtype, 'number', 'idx.index.dtype')
1150+
1151+
if not self_index_is_none and idx_index_is_none:
1152+
if not check_index_is_numeric(self):
1153+
ty_checker.raise_exc(idx.index.dtype, self.index.dtype, 'idx.index.dtype')
1154+
1155+
if not self_index_is_none and not idx_index_is_none:
1156+
if not check_types_comparable(self.index, idx.index):
1157+
ty_checker.raise_exc(idx.index.dtype, self.index.dtype, 'idx.index.dtype')
1158+
1159+
return gen_df_getitem_bool_series_idx_impl(self, idx)
1160+
10911161
ty_checker.raise_exc(idx, 'str', 'idx')
10921162

10931163

sdc/hiframes/pd_dataframe_ext.py

Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -349,18 +349,17 @@ def df_getitem_overload(df, ind):
349349
index = df.columns.index(ind.literal_value)
350350
return lambda df, ind: sdc.hiframes.api.init_series(df._data[index])
351351

352-
353-
@infer_global(operator.getitem)
354-
class GetItemDataFrame(AbstractTemplate):
355-
key = operator.getitem
356-
357-
def generic(self, args, kws):
358-
df, idx = args
359-
# df1 = df[df.A > .5]
360-
if (isinstance(df, DataFrameType)
361-
and isinstance(idx, (SeriesType, types.Array))
362-
and idx.dtype == types.bool_):
363-
return signature(df, *args)
352+
@infer_global(operator.getitem)
353+
class GetItemDataFrame(AbstractTemplate):
354+
key = operator.getitem
355+
356+
def generic(self, args, kws):
357+
df, idx = args
358+
# df1 = df[df.A > .5]
359+
if (isinstance(df, DataFrameType)
360+
and isinstance(idx, (SeriesType, types.Array))
361+
and idx.dtype == types.bool_):
362+
return signature(df, *args)
364363

365364

366365
@infer

sdc/tests/test_dataframe.py

Lines changed: 40 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1236,17 +1236,37 @@ def test_impl(df, start, end):
12361236
ref_result = test_impl(df, start, end)
12371237
pd.testing.assert_frame_equal(jit_result, ref_result)
12381238

1239-
@skip_sdc_jit('DF.getitem unsupported Series name')
12401239
def _test_df_getitem_tuple_idx(self, df):
1241-
def test_impl(df):
1242-
# pd.df.getitem does not support idx as a tuple
1243-
return df[['A', 'C']]
1240+
def gen_test_impl(do_jit=False):
1241+
def test_impl(df):
1242+
if do_jit == True: # noqa
1243+
return df[('A', 'C')]
1244+
else:
1245+
return df[['A', 'C']]
12441246

1245-
# SDC pd.df.getitem does not support idx as a list
1246-
sdc_func = self.jit(lambda df: df[('A', 'C')])
1247+
return test_impl
12471248

1249+
test_impl = gen_test_impl()
1250+
sdc_func = self.jit(gen_test_impl(do_jit=True))
1251+
1252+
pd.testing.assert_frame_equal(sdc_func(df), test_impl(df))
1253+
1254+
def _test_df_getitem_bool_series_idx(self, df):
    """Compare jitted and plain results of df[mask] with the mask built from column 'A'."""
    def test_impl(df):
        return df[df['A'] == -1.]

    jitted = self.jit(test_impl)
    actual = jitted(df)
    expected = test_impl(df)
    pd.testing.assert_frame_equal(actual, expected)
12491260

1261+
def _test_df_getitem_bool_series_even_idx(self, df):
    """Compare jitted and plain results of df[series] for an alternating False/True mask of length 10."""
    def test_impl(df, series):
        return df[series]

    mask = pd.Series([False, True] * 5)
    jitted = self.jit(test_impl)

    actual = jitted(df, mask)
    expected = test_impl(df, mask)
    pd.testing.assert_frame_equal(actual, expected)
1269+
12501270
@skip_sdc_jit('DF.getitem unsupported exceptions')
12511271
def test_df_getitem_str_literal_idx_exception_key_error(self):
12521272
def test_impl(df):
@@ -1292,6 +1312,14 @@ def test_df_getitem_idx(self):
12921312
self._test_df_getitem_slice_idx(df)
12931313
self._test_df_getitem_unbox_slice_idx(df, 1, 3)
12941314
self._test_df_getitem_tuple_idx(df)
1315+
self._test_df_getitem_bool_series_idx(df)
1316+
1317+
@skip_sdc_jit('DF.getitem unsupported Series name')
def test_df_getitem_idx_no_index(self):
    """Run the df[bool_series] check on a generated float64 frame and on a frame with one empty column."""
    frames = (
        gen_df(test_global_input_data_float64),
        pd.DataFrame({'A': []}),
    )
    for df in frames:
        # each frame is reported as its own sub-test
        with self.subTest(df=df):
            self._test_df_getitem_bool_series_even_idx(df)
12951323

12961324
@skip_sdc_jit('DF.getitem unsupported Series name')
12971325
def test_df_getitem_idx_multiple_types(self):
@@ -1306,6 +1334,12 @@ def test_df_getitem_idx_multiple_types(self):
13061334
self._test_df_getitem_slice_idx(df)
13071335
self._test_df_getitem_unbox_slice_idx(df, 1, 3)
13081336
self._test_df_getitem_tuple_idx(df)
1337+
self._test_df_getitem_bool_series_even_idx(df)
1338+
1339+
@unittest.skip('DF.getitem df[bool_series] unsupported index')
def test_df_getitem_bool_series_even_idx_with_index(self):
    """Run the df[bool_series] check on a float64 frame generated with with_index=True (currently skipped)."""
    indexed_df = gen_df(test_global_input_data_float64, with_index=True)
    self._test_df_getitem_bool_series_even_idx(indexed_df)
13091343

13101344
@unittest.skip('DF.getitem unsupported integer columns')
13111345
def test_df_getitem_int_literal_idx(self):

sdc/utilities/sdc_typing_utils.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -168,3 +168,17 @@ def find_common_dtype_from_numpy_dtypes(array_types, scalar_types):
168168
numba_common_dtype = numpy_support.from_dtype(np_common_dtype)
169169

170170
return numba_common_dtype
171+
172+
173+
def gen_df_impl_generator(codegen, impl_name):
    """Generate generator of df methods.

    Parameters
    ----------
    codegen : callable
        Called with ``(self, idx)`` plus any extra arguments given to the
        generator; must return a ``(func_text, global_vars)`` pair.
    impl_name : str
        Name of the function that ``func_text`` defines.

    Returns
    -------
    callable
        ``_df_impl_generator(self, idx, *args, **kwargs)`` which executes the
        generated text and returns the function named ``impl_name``.

    Raises
    ------
    KeyError
        From the returned generator, if the executed text does not define
        ``impl_name``.
    """
    def _df_impl_generator(self, idx, *args, **kwargs):
        # Extra arguments are forwarded so codegens taking more than
        # (self, idx) can reuse this generator; two-argument calls are unchanged.
        func_text, global_vars = codegen(self, idx, *args, **kwargs)

        # NOTE: exec runs text produced by the supplied codegen, not external input
        loc_vars = {}
        exec(func_text, global_vars, loc_vars)

        return loc_vars[impl_name]

    return _df_impl_generator

0 commit comments

Comments
 (0)