@@ -8,7 +8,7 @@ use std::{collections::HashMap, ffi::CString, sync::Mutex};
88
99use lazy_static:: lazy_static;
1010use pyo3:: exceptions:: PyBaseException ;
11- use pyo3:: types:: { PyAnyMethods , PyDictMethods } ;
11+ use pyo3:: types:: { PyAnyMethods , PyDictMethods , PyList , PyListMethods } ;
1212use pyo3:: PyErr ;
1313use pyo3:: { marker, types:: PyDict , Py , PyAny , PyResult } ;
1414
@@ -26,6 +26,11 @@ pub fn init_python() -> PyResult<()> {
2626 let code = py_main_import:: read_at_startup ( ) ;
2727 let c_code = CString :: new ( code) . expect ( "error loading python" ) ;
2828 marker:: Python :: with_gil ( |py| -> PyResult < ( ) > {
29+ let syspath = py
30+ . import ( "sys" ) ?
31+ . getattr ( "path" ) ?
32+ . downcast_into :: < PyList > ( ) ?;
33+ syspath. insert ( 0 , py_main_import:: get_py_path ( ) . to_str ( ) ) ?;
2934 let globals = GLOBALS . lock ( ) . unwrap ( ) . clone_ref ( py) . into_bound ( py) ;
3035 py. run ( & c_code, Some ( & globals) , None )
3136 } )
@@ -39,7 +44,7 @@ pub fn run_python(payload: StringRequest) -> PyResult<()> {
3944 } )
4045}
4146pub fn register_function ( payload : RegisterRequest ) -> PyResult < ( ) > {
42- let fn_name = payload. function_name ;
47+ let fn_name = payload. python_function_call ;
4348 // TODO, check actual function signature
4449 if INIT_BLOCKED . load ( std:: sync:: atomic:: Ordering :: Relaxed ) {
4550 return Err ( pyo3:: exceptions:: PyException :: new_err (
@@ -48,14 +53,20 @@ pub fn register_function(payload: RegisterRequest) -> PyResult<()> {
4853 }
4954 marker:: Python :: with_gil ( |py| -> PyResult < ( ) > {
5055 let globals = GLOBALS . lock ( ) . unwrap ( ) . clone_ref ( py) . into_bound ( py) ;
51- let app = globals. get_item ( & fn_name) ?;
56+
57+ let fn_dot_split: Vec < & str > = fn_name. split ( "." ) . collect ( ) ;
58+ let app = globals. get_item ( & fn_dot_split[ 0 ] ) ?;
5259 if app. is_none ( ) {
5360 return Err ( pyo3:: exceptions:: PyException :: new_err ( format ! (
5461 "{} not found" ,
5562 & fn_name
5663 ) ) ) ;
5764 }
58- let app = app. unwrap ( ) ;
65+ let app = if fn_dot_split. len ( ) > 1 {
66+ app. unwrap ( ) . getattr ( fn_dot_split. get ( 1 ) . unwrap ( ) ) ?
67+ } else {
68+ app. unwrap ( )
69+ } ;
5970 if !app. is_callable ( ) {
6071 return Err ( pyo3:: exceptions:: PyException :: new_err ( format ! (
6172 "{} not a callable function" ,
0 commit comments