Skip to content

Commit ac86568

Browse files
Joe JevnikGerry Manoim
andauthored
ENH: Add helper for creating modules. (#121)
* ENH: Add helper for creating modules. The created modules help ensure that `import_array` and `py::abi::ensure_compatible_libpy_abi()` are called correctly. This wrapper also guards against C++ exceptions being thrown in the module init. The wrapper also manages many Python 2/3 compatibility issues. Co-authored-by: Gerry Manoim <gmanoim@quantopian.com>
1 parent 31f4086 commit ac86568

8 files changed

Lines changed: 162 additions & 13 deletions

File tree

Makefile

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -195,15 +195,20 @@ src/%.o: src/%.cc .make/all-flags
195195
$(CXX) $(CXXFLAGS) $(INCLUDE) -MD -fPIC -c $< -o $@
196196

197197
.PHONY: test
198-
test: $(PYTHON_TESTS) $(TEST_MODULE)
198+
test: $(PYTHON_TESTS) $(TEST_MODULE) tests/_test_automodule.so
199199
GTEST_OUTPUT=$(GTEST_OUTPUT) \
200200
$(LD_PRELOAD_VAR)="$(TEST_LD_PRELOAD)" \
201201
ASAN_OPTIONS=$(ASAN_OPTIONS) \
202202
LSAN_OPTIONS=$(LSAN_OPTIONS) \
203-
LSAN_OPTIONS=$(LSAN_OPTIONS) \
204203
GTEST_ARGS=--gtest_filter=$(GTEST_FILTER) \
205204
$(PYTEST) tests/ $(PYTEST_ARGS)
206205

206+
tests/_test_automodule.o: tests/_test_automodule.cc .make/all-flags
207+
$(CXX) $(CXXFLAGS) $(INCLUDE) -MD -fPIC -c $< -o $@
208+
209+
tests/_test_automodule.so: tests/_test_automodule.o
210+
$(CXX) -shared -o $@ $< -lpthread $(LDFLAGS)
211+
207212
.PHONY: gdbtest
208213
gdbtest: $(PYTHON_TESTS)
209214
@LD_LIBRARY_PATH=. GTEST_BREAK_ON_FAILURE=$(GTEST_BREAK) \
@@ -215,9 +220,9 @@ tests/%.o: tests/%.cc .make/all-flags
215220
-isystem submodules/googletest/googletest/src \
216221
-MD -fPIC -c $< -o $@
217222

218-
$(TEST_MODULE): gtest.a $(TEST_OBJECTS) $(SONAME)
223+
$(TEST_MODULE): gtest.a $(TEST_OBJECTS) libpy/libpy.so
219224
$(CXX) -shared -o $@ $(TEST_OBJECTS) gtest.a $(TEST_INCLUDE) \
220-
-Wl,-rpath,`pwd` -lpthread -L. $(SONAME) $(LDFLAGS)
225+
-Wl,-rpath,`pwd` -lpthread -L. $(SONAME) $(LDFLAGS)
221226

222227
gtest.o: $(GTEST_SRCS) .make/all-flags
223228
$(CXX) $(filter-out $(WARNINGS),$(CXXFLAGS)) -I $(GTEST_DIR) \

include/libpy/autoclass.h

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -150,32 +150,32 @@ class autoclass_impl {
150150

151151
// dispatch for free function that accepts as a first argument `T`
152152
template<typename R, typename... Args, auto impl>
153-
struct free_function_impl<R(T, Args...), impl>
153+
struct free_function_impl<R(*)(T, Args...), impl>
154154
: public free_function_base<impl, R, Args...> {};
155155

156156
// dispatch for free function that accepts as a first argument `T&`
157157
template<typename R, typename... Args, auto impl>
158-
struct free_function_impl<R(T&, Args...), impl>
158+
struct free_function_impl<R(*)(T&, Args...), impl>
159159
: public free_function_base<impl, R, Args...> {};
160160

161161
// dispatch for free function that accepts as a first argument `const T&`
162162
template<typename R, typename... Args, auto impl>
163-
struct free_function_impl<R(const T&, Args...), impl>
163+
struct free_function_impl<R(*)(const T&, Args...), impl>
164164
: public free_function_base<impl, R, Args...> {};
165165

166166
// dispatch for a noexcept free function that accepts as a first argument `T`
167167
template<typename R, typename... Args, auto impl>
168-
struct free_function_impl<R(T, Args...) noexcept, impl>
168+
struct free_function_impl<R(*)(T, Args...) noexcept, impl>
169169
: public free_function_base<impl, R, Args...> {};
170170

171171
// dispatch for noexcept free function that accepts as a first argument `T&`
172172
template<typename R, typename... Args, auto impl>
173-
struct free_function_impl<R(T&, Args...) noexcept, impl>
173+
struct free_function_impl<R(*)(T&, Args...) noexcept, impl>
174174
: public free_function_base<impl, R, Args...> {};
175175

176176
// dispatch for a noexcept free function that accepts as a first argument `const T&`
177177
template<typename R, typename... Args, auto impl>
178-
struct free_function_impl<R(const T&, Args...) noexcept, impl>
178+
struct free_function_impl<R(*)(const T&, Args...) noexcept, impl>
179179
: public free_function_base<impl, R, Args...> {};
180180

181181
template<auto impl, typename R, typename... Args>
@@ -246,7 +246,7 @@ class autoclass_impl {
246246
static py::owned_ref<PyTypeObject> lookup_type() {
247247
auto type_search = detail::autoclass_type_cache.get().find(typeid(T));
248248
if (type_search != detail::autoclass_type_cache.get().end()) {
249-
PyTypeObject* type = type_search->second->type;
249+
PyTypeObject* type = type_search->second->type.get();
250250
Py_INCREF(type);
251251
return py::owned_ref(type);
252252
}

include/libpy/automodule.h

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
#pragma once
2+
3+
#include <forward_list>
4+
#include <unordered_map>
5+
#include <vector>
6+
7+
#include "libpy/abi.h"
8+
#include "libpy/borrowed_ref.h"
9+
#include "libpy/detail/api.h"
10+
#include "libpy/detail/numpy.h"
11+
#include "libpy/detail/python.h"
12+
#include "libpy/owned_ref.h"
13+
14+
#define _libpy_XSTR(s) #s
15+
#define _libpy_STR(s) _libpy_XSTR(s)
16+
17+
#define _libpy_MODULE_PATH(parent, name) _libpy_STR(parent) "." _libpy_STR(name)
18+
19+
#define _libpy_XCAT(a, b) a##b
20+
#define _libpy_CAT(a, b) _libpy_XCAT(a, b)
21+
22+
#define _libpy_MODINIT_NAME(name) _libpy_CAT(PyInit_, name)
23+
#define _libpy_MODULE_CREATE(path) PyModule_Create(&_libpy_module)
24+
25+
/** Define a Python module.
26+
27+
@param parent A symbol indicating the parent module.
28+
@param name The leaf name of the module.
29+
@param methods ({...}) list of objects representing the functions to add to the
30+
module. Note this list must be surrounded by parentheses.
31+
32+
## Examples
33+
34+
Create a module `my_package.submodule.my_module` with two functions `f` and
35+
`g` and one type `T`.
36+
37+
\code
38+
LIBPY_AUTOMODULE(my_package.submodule,
39+
my_module,
40+
({py::autofunction<f>("f"),
41+
py::autofunction<g>("g")}))
42+
(py::borrowed_ref<> m) {
43+
py::borrowed_ref t = py::autoclass<T>("T").new_().type();
44+
return PyObject_SetAttrString(m.get(), "T", static_cast<PyObject*>(t));
45+
}
46+
/endcode
47+
*/
48+
#define LIBPY_AUTOMODULE(parent, name, methods) \
49+
bool _libpy_user_mod_init(py::borrowed_ref<>); \
50+
PyMODINIT_FUNC _libpy_MODINIT_NAME(name)() LIBPY_EXPORT; \
51+
PyMODINIT_FUNC _libpy_MODINIT_NAME(name)() { \
52+
import_array(); \
53+
if (py::abi::ensure_compatible_libpy_abi()) { \
54+
return nullptr; \
55+
} \
56+
static std::vector<PyMethodDef> ms methods; \
57+
ms.emplace_back(py::end_method_list); \
58+
static PyModuleDef _libpy_module{ \
59+
PyModuleDef_HEAD_INIT, \
60+
_libpy_MODULE_PATH(parent, name), \
61+
nullptr, \
62+
-1, \
63+
ms.data(), \
64+
}; \
65+
py::owned_ref m(_libpy_MODULE_CREATE(_libpy_MODULE_PATH(parent, name))); \
66+
if (!m) { \
67+
return nullptr; \
68+
} \
69+
try { \
70+
if (_libpy_user_mod_init(m)) { \
71+
return nullptr; \
72+
} \
73+
} \
74+
catch (const std::exception& e) { \
75+
py::raise_from_cxx_exception(e); \
76+
return nullptr; \
77+
} \
78+
catch (...) { \
79+
if (!PyErr_Occurred()) { \
80+
py::raise(PyExc_RuntimeError) << "an unknown C++ exception was raised"; \
81+
return nullptr; \
82+
} \
83+
} \
84+
return std::move(m).escape(); \
85+
} \
86+
bool _libpy_user_mod_init

include/libpy/detail/autoclass_cache.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
#include <unordered_map>
99
#include <vector>
1010

11+
#include "libpy/borrowed_ref.h"
1112
#include "libpy/detail/api.h"
1213
#include "libpy/detail/no_destruct_wrapper.h"
1314
#include "libpy/detail/python.h"
@@ -21,7 +22,7 @@ struct autoclass_storage {
2122
unbox_fn unbox;
2223

2324
// Borrowed reference to the type that this struct contains storage for.
24-
PyTypeObject* type;
25+
py::borrowed_ref<PyTypeObject> type;
2526

2627
// The method storage for `type`. We may use a vector because this is just a
2728
// collection of pointers and ints. `PyMethodDef` objects may move around until

include/libpy/from_object.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,7 @@ struct from_object<T&> {
119119
throw invalid_conversion::make<T&>(ob);
120120
}
121121
int res = PyObject_IsInstance(ob.get(),
122-
reinterpret_cast<PyObject*>(
122+
static_cast<PyObject*>(
123123
search->second->type));
124124
if (res < 0) {
125125
throw py::exception{};

tests/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
import libpy # noqa

tests/_test_automodule.cc

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
#include "libpy/autoclass.h"
2+
#include "libpy/automethod.h"
3+
#include "libpy/automodule.h"
4+
5+
bool is_42(int arg) {
6+
return arg == 42;
7+
}
8+
9+
bool is_true(bool arg) {
10+
return arg;
11+
}
12+
13+
using int_float_pair = std::pair<int, float>;
14+
15+
int first(const int_float_pair& ob) {
16+
return ob.first;
17+
}
18+
19+
float second(const int_float_pair& ob) {
20+
return ob.second;
21+
}
22+
23+
LIBPY_AUTOMODULE(tests,
24+
_test_automodule,
25+
({py::autofunction<is_42>("is_42"),
26+
py::autofunction<is_true>("is_true")}))
27+
(py::borrowed_ref<> m) {
28+
py::owned_ref t = py::autoclass<int_float_pair>("_test_automodule.int_float_pair")
29+
.new_<int, float>()
30+
.comparisons<int_float_pair>()
31+
.def<first>("first")
32+
.def<second>("second")
33+
.type();
34+
return PyObject_SetAttrString(m.get(), "int_float_pair", static_cast<PyObject*>(t));
35+
}

tests/test_automodule.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
from . import _test_automodule as mod
2+
3+
4+
def test_modname():
5+
assert mod.__name__ == 'tests._test_automodule'
6+
7+
8+
def test_function():
9+
assert mod.is_42(42)
10+
assert not mod.is_42(~42)
11+
12+
13+
def test_type():
14+
assert isinstance(mod.int_float_pair, type)
15+
a = mod.int_float_pair(1, 2.5)
16+
assert a.first() == 1
17+
assert a.second() == 2.5
18+
b = mod.int_float_pair(1, 2.5)
19+
assert a == b
20+
c = mod.int_float_pair(1, 3.5)
21+
assert a != c

0 commit comments

Comments
 (0)