|
54 | 54 | argsort = get_xp(cp)(_aliases.argsort) |
55 | 55 | sort = get_xp(cp)(_aliases.sort) |
56 | 56 | nonzero = get_xp(cp)(_aliases.nonzero) |
57 | | -ceil = get_xp(cp)(_aliases.ceil) |
58 | | -floor = get_xp(cp)(_aliases.floor) |
59 | | -trunc = get_xp(cp)(_aliases.trunc) |
60 | 57 | matmul = get_xp(cp)(_aliases.matmul) |
61 | 58 | matrix_transpose = get_xp(cp)(_aliases.matrix_transpose) |
62 | 59 | tensordot = get_xp(cp)(_aliases.tensordot) |
@@ -123,6 +120,25 @@ def count_nonzero( |
123 | 120 | return cp.expand_dims(result, axis) |
124 | 121 | return result |
125 | 122 |
|
| 123 | +# ceil, floor, and trunc return integers for integer inputs |
| 124 | + |
| 125 | +def ceil(x: Array, /) -> Array: |
| 126 | + if cp.issubdtype(x.dtype, cp.integer): |
| 127 | + return x.copy() |
| 128 | + return cp.ceil(x) |
| 129 | + |
| 130 | + |
| 131 | +def floor(x: Array, /) -> Array: |
| 132 | + if cp.issubdtype(x.dtype, cp.integer): |
| 133 | + return x.copy() |
| 134 | + return cp.floor(x) |
| 135 | + |
| 136 | + |
| 137 | +def trunc(x: Array, /) -> Array: |
| 138 | + if cp.issubdtype(x.dtype, cp.integer): |
| 139 | + return x.copy() |
| 140 | + return cp.trunc(x) |
| 141 | + |
126 | 142 |
|
127 | 143 | # take_along_axis: axis defaults to -1 but in cupy (and numpy) axis is a required arg |
128 | 144 | def take_along_axis(x: Array, indices: Array, /, *, axis: int = -1): |
@@ -151,6 +167,6 @@ def take_along_axis(x: Array, indices: Array, /, *, axis: int = -1): |
151 | 167 | 'atan2', 'atanh', 'bitwise_left_shift', |
152 | 168 | 'bitwise_invert', 'bitwise_right_shift', |
153 | 169 | 'bool', 'concat', 'count_nonzero', 'pow', 'sign', |
154 | | - 'take_along_axis'] |
| 170 | + 'ceil', 'floor', 'trunc', 'take_along_axis'] |
155 | 171 |
|
156 | 172 | _all_ignore = ['cp', 'get_xp'] |
0 commit comments