From 6d58f1978edd4e0f1de462478ba333dd6246eb33 Mon Sep 17 00:00:00 2001 From: zhushuang Date: Thu, 4 Jun 2026 09:26:23 +0000 Subject: [PATCH] refactor InfiniOps cpu runtime through InfiniRT --- CMakeLists.txt | 10 ++ scripts/generate_wrappers.py | 84 +++++++++++++++- src/CMakeLists.txt | 67 ++++++++++++- src/data_type.h | 183 +++-------------------------------- src/device.h | 109 +-------------------- src/native/cpu/data_type_.h | 17 +--- src/native/cpu/device_.h | 9 +- src/native/cpu/runtime_.h | 30 +----- src/runtime.h | 44 +-------- src/tensor.h | 149 +--------------------------- 10 files changed, 179 insertions(+), 523 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 4cdc54878..ad2857682 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -37,6 +37,10 @@ option(AUTO_DETECT_BACKENDS "Automatically detect available backends" OFF) option(GENERATE_OPERATOR_CALL_INSTANTIATIONS "Generate explicit operator call instantiations" ON) option(GENERATE_PYTHON_BINDINGS "Generate Python bindings" OFF) +option(USE_EXISTING_GENERATED_WRAPPERS + "Build from existing generated wrapper sources instead of regenerating them" OFF) +option(INFINIOPS_MINIMAL_ADD_BINDINGS + "Build a minimal Python module exposing only the add op" OFF) set(_DEFAULT_HYGON_DTK_ROOT "/opt/dtk") @@ -476,6 +480,7 @@ endif() # If all other platforms are not enabled, CPU is enabled by default. if(NOT WITH_NVIDIA AND NOT WITH_ILUVATAR AND NOT WITH_HYGON AND NOT WITH_METAX AND NOT WITH_MOORE AND NOT WITH_CAMBRICON AND NOT WITH_ASCEND) + set(WITH_CPU ON CACHE BOOL "Enable CPU backend" FORCE) add_compile_definitions(WITH_CPU=1) endif() @@ -487,6 +492,11 @@ if(WITH_HYGON AND NOT EXISTS "${DTK_ROOT}/llvm/lib/LLVMgold.so") set(PYBIND11_ENABLE_EXTRAS OFF) endif() +set(INFINIRT_SOURCE_DIR "${PROJECT_SOURCE_DIR}/../InfiniRT" CACHE PATH "InfiniRT source directory") +if(NOT TARGET infinirt) + add_subdirectory("${INFINIRT_SOURCE_DIR}" "${CMAKE_BINARY_DIR}/InfiniRT") +endif() + add_subdirectory(src) if(NOT GENERATE_PYTHON_BINDINGS) diff --git a/scripts/generate_wrappers.py b/scripts/generate_wrappers.py index f5734b7ad..0e199f22c 100644 --- a/scripts/generate_wrappers.py +++ b/scripts/generate_wrappers.py @@ -32,6 +32,18 @@ _INDENTATION = " " +def _get_infinirt_include_flags(): + infinirt_source_dir = pathlib.Path( + os.environ.get("INFINIRT_SOURCE_DIR", "../InfiniRT") + ) + infinirt_include_dir = infinirt_source_dir / "src" + + if not infinirt_include_dir.exists(): + return () + + return ("-I", str(infinirt_include_dir)) + + @functools.lru_cache(maxsize=1) def _get_system_include_flags(): """Probe the system C++ compiler for default include paths so libclang @@ -85,7 +97,7 @@ def __call__(self, op_name): "src", "-I", str(_GENERATION_DIR), - ) + _get_system_include_flags() + ) + _get_infinirt_include_flags() + _get_system_include_flags() translation_unit = index.parse(str(_find_base_header(op_name)), args=args) nodes = tuple(type(self)._find(translation_unit.cursor, op_name)) @@ -160,6 +172,18 @@ def _find_vector_int64_params(op_name): return set(re.findall(r"std::vector\s+(\w+)", source)) +def _find_tensor_params(op_name): + source = _find_base_header(op_name).read_text() + + params = set() + params.update( + re.findall(r"(?:^|[,(]\s*)(?:const\s+)?Tensor\s+(\w+)", source) + ) + params.update(_find_optional_tensor_params(op_name)) + params.update(_find_vector_tensor_params(op_name)) + return params + + def _generate_pybind11(operator): optional_tensor_params = _find_optional_tensor_params(operator.name) vector_tensor_params = _find_vector_tensor_params(operator.name) @@ -537,6 +561,7 @@ def _generate_tensor_caster(name, is_data=False): def _generate_generated_dispatch_entries(operator): optional_tensor_params = _find_optional_tensor_params(operator.name) + tensor_params = _find_tensor_params(operator.name) vector_tensor_params = _find_vector_tensor_params(operator.name) vector_int64_params = _find_vector_int64_params(operator.name) @@ -555,6 +580,12 @@ def _is_vector_tensor(arg): def _is_vector_int64(arg): return arg.spelling in vector_int64_params + def _is_tensor(arg): + if arg.spelling in tensor_params: + return True + + return "Tensor" in arg.type.spelling or "TensorView" in arg.type.spelling + def _generate_params(node): parts = [] @@ -568,6 +599,8 @@ def _generate_params(node): parts.append(f"std::vector {arg.spelling}") elif _is_vector_int64(arg): parts.append(f"std::vector {arg.spelling}") + elif _is_tensor(arg): + parts.append(f"Tensor {arg.spelling}") else: parts.append(f"{arg.type.spelling} {arg.spelling}") @@ -712,16 +745,61 @@ def _strip_top_level_const(type_spelling): def _generate_operator_call_instantiation_entries(operator): + optional_tensor_params = _find_optional_tensor_params(operator.name) + tensor_params = _find_tensor_params(operator.name) + vector_tensor_params = _find_vector_tensor_params(operator.name) + vector_int64_params = _find_vector_int64_params(operator.name) + + def _is_optional_tensor(arg): + if arg.spelling in optional_tensor_params: + return True + + return "std::optional" in arg.type.spelling and ( + "Tensor" in arg.type.spelling or "TensorView" in arg.type.spelling + ) + + def _is_vector_tensor(arg): + if arg.spelling in vector_tensor_params: + return True + + return "std::vector" in arg.type.spelling and ( + "Tensor" in arg.type.spelling or "TensorView" in arg.type.spelling + ) + + def _is_vector_int64(arg): + return arg.spelling in vector_int64_params + + def _is_tensor(arg): + if arg.spelling in tensor_params: + return True + + return "Tensor" in arg.type.spelling or "TensorView" in arg.type.spelling + + def _normalized_type(arg): + if _is_optional_tensor(arg): + return "std::optional" + + if _is_vector_tensor(arg): + return "std::vector" + + if _is_vector_int64(arg): + return "std::vector" + + if _is_tensor(arg): + return "Tensor" + + return _strip_top_level_const(arg.type.spelling) + def _generate_template_arguments(node): return ", ".join( - _strip_top_level_const(arg.type.spelling) + _normalized_type(arg) for arg in node.get_arguments() if arg.spelling != "stream" ) def _generate_parameters(node): return ", ".join( - f"const {_strip_top_level_const(arg.type.spelling)}& {arg.spelling}" + f"const {_normalized_type(arg)}& {arg.spelling}" for arg in node.get_arguments() if arg.spelling != "stream" ) diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 4b0ca3028..a2be1d041 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -3,8 +3,16 @@ add_library(infiniops SHARED) include(GNUInstallDirs) file(GLOB BASE_SRCS CONFIGURE_DEPENDS "*.cc") +list(FILTER BASE_SRCS EXCLUDE REGEX ".*tensor\\.cc$") target_sources(infiniops PRIVATE ${BASE_SRCS}) +set(INFINIOPS_EMPTY_SOURCE "${CMAKE_CURRENT_BINARY_DIR}/infiniops_empty.cc") +file(WRITE "${INFINIOPS_EMPTY_SOURCE}" + "namespace infini::ops { void infiniops_link_anchor() {} }\n") +target_sources(infiniops PRIVATE "${INFINIOPS_EMPTY_SOURCE}") + +target_link_libraries(infiniops PUBLIC infinirt) + set(DEVICE_LIST "") if(WITH_CPU) @@ -15,6 +23,7 @@ if(WITH_CPU) file(GLOB_RECURSE CPU_SOURCES CONFIGURE_DEPENDS ${CPU_PATTERNS}) list(APPEND CORE_SOURCES ${CPU_SOURCES}) + target_sources(infiniops PRIVATE ${CPU_SOURCES}) target_compile_definitions(infiniops PUBLIC WITH_CPU=1) @@ -510,7 +519,9 @@ target_include_directories(infiniops $ ) -if(GENERATE_OPERATOR_CALL_INSTANTIATIONS OR GENERATE_PYTHON_BINDINGS) +if((GENERATE_OPERATOR_CALL_INSTANTIATIONS OR GENERATE_PYTHON_BINDINGS) + AND NOT USE_EXISTING_GENERATED_WRAPPERS + AND NOT INFINIOPS_MINIMAL_ADD_BINDINGS) find_package(Python COMPONENTS Interpreter REQUIRED) # Always regenerate wrappers so emitted call instantiations and bindings # match the active device list. Stale generated files would omit @@ -526,7 +537,11 @@ if(GENERATE_OPERATOR_CALL_INSTANTIATIONS OR GENERATE_PYTHON_BINDINGS) endif() execute_process( - COMMAND ${Python_EXECUTABLE} ${PROJECT_SOURCE_DIR}/scripts/generate_wrappers.py ${GENERATOR_ARGS} + COMMAND ${CMAKE_COMMAND} -E env + INFINIRT_SOURCE_DIR=${INFINIRT_SOURCE_DIR} + ${Python_EXECUTABLE} + ${PROJECT_SOURCE_DIR}/scripts/generate_wrappers.py + ${GENERATOR_ARGS} WORKING_DIRECTORY ${PROJECT_SOURCE_DIR} RESULT_VARIABLE script_result ) @@ -536,6 +551,15 @@ if(GENERATE_OPERATOR_CALL_INSTANTIATIONS OR GENERATE_PYTHON_BINDINGS) else() message(STATUS "Generating wrappers - done") endif() +elseif(USE_EXISTING_GENERATED_WRAPPERS) + if(NOT EXISTS "${PROJECT_SOURCE_DIR}/generated/include" OR + NOT EXISTS "${PROJECT_SOURCE_DIR}/generated/src" OR + NOT EXISTS "${PROJECT_SOURCE_DIR}/generated/bindings") + message(FATAL_ERROR + "`USE_EXISTING_GENERATED_WRAPPERS` is ON but generated wrapper " + "sources are missing under `${PROJECT_SOURCE_DIR}/generated`.") + endif() + message(STATUS "Using existing generated wrapper sources") endif() if(GENERATE_OPERATOR_CALL_INSTANTIATIONS) @@ -606,6 +630,37 @@ if(GENERATE_OPERATOR_CALL_INSTANTIATIONS) endif() if(GENERATE_PYTHON_BINDINGS) + if(INFINIOPS_MINIMAL_ADD_BINDINGS) + find_package(Python COMPONENTS Interpreter Development REQUIRED) + + if(NOT pybind11_DIR) + execute_process( + COMMAND ${Python_EXECUTABLE} -m pybind11 --cmakedir + OUTPUT_VARIABLE _pybind11_cmake_dir + OUTPUT_STRIP_TRAILING_WHITESPACE + RESULT_VARIABLE _pybind11_result + ) + + if(_pybind11_result EQUAL 0) + set(pybind11_DIR "${_pybind11_cmake_dir}" CACHE PATH "pybind11 CMake directory") + endif() + endif() + + find_package(pybind11 CONFIG REQUIRED) + pybind11_add_module(ops python/add_module.cc) + target_include_directories(ops PRIVATE + ${PROJECT_SOURCE_DIR} + ${PROJECT_SOURCE_DIR}/include + ${PROJECT_SOURCE_DIR}/generated + ${PROJECT_SOURCE_DIR}/generated/include + ${CMAKE_CURRENT_SOURCE_DIR} + $) + target_compile_definitions(ops PRIVATE + $) + target_link_libraries(ops PRIVATE infiniops) + return() + endif() + file(GLOB_RECURSE PYBIND11_SOURCES CONFIGURE_DEPENDS "${PROJECT_SOURCE_DIR}/generated/bindings/*.cc") @@ -792,7 +847,7 @@ if(GENERATE_PYTHON_BINDINGS) set_target_properties(infiniops PROPERTIES INSTALL_RPATH "${_INFINIOPS_INSTALL_RPATH}") set_target_properties(ops PROPERTIES INSTALL_RPATH "${_INFINIOPS_INSTALL_RPATH}") - install(TARGETS infiniops ops DESTINATION .) + install(TARGETS infinirt infiniops ops DESTINATION .) file(WRITE "${CMAKE_CURRENT_BINARY_DIR}/__init__.py" "") install(FILES "${CMAKE_CURRENT_BINARY_DIR}/__init__.py" DESTINATION .) @@ -811,6 +866,12 @@ install(TARGETS infiniops RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR} ) +install(TARGETS infinirt + LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} + ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR} + RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR} +) + install(DIRECTORY ${PROJECT_SOURCE_DIR}/include/ DESTINATION ${CMAKE_INSTALL_INCLUDEDIR} ) diff --git a/src/data_type.h b/src/data_type.h index 75483d2b8..1c8bc5e0a 100644 --- a/src/data_type.h +++ b/src/data_type.h @@ -1,192 +1,31 @@ #ifndef INFINI_OPS_DATA_TYPE_H_ #define INFINI_OPS_DATA_TYPE_H_ -#include -#include -#include - -#include "common/constexpr_map.h" #include "common/traits.h" #include "device.h" +#include "infini_rt/data_type.h" namespace infini::ops { -enum class DataType : std::int8_t { - kInt8, - kInt16, - kInt32, - kInt64, - kUInt8, - kUInt16, - kUInt32, - kUInt64, - kFloat16, - kBFloat16, - kFloat32, - kFloat64 -}; - -constexpr ConstexprMap kDataTypeToSize{{{ - {DataType::kInt8, 1}, - {DataType::kInt16, 2}, - {DataType::kInt32, 4}, - {DataType::kInt64, 8}, - {DataType::kUInt8, 1}, - {DataType::kUInt16, 2}, - {DataType::kUInt32, 4}, - {DataType::kUInt64, 8}, - {DataType::kFloat16, 2}, - {DataType::kBFloat16, 2}, - {DataType::kFloat32, 4}, - {DataType::kFloat64, 8}, -}}}; - -constexpr ConstexprMap kDataTypeToDesc{{{ - {DataType::kInt8, "int8"}, - {DataType::kInt16, "int16"}, - {DataType::kInt32, "int32"}, - {DataType::kInt64, "int64"}, - {DataType::kUInt8, "uint8"}, - {DataType::kUInt16, "uint16"}, - {DataType::kUInt32, "uint32"}, - {DataType::kUInt64, "uint64"}, - {DataType::kFloat16, "float16"}, - {DataType::kBFloat16, "bfloat16"}, - {DataType::kFloat32, "float32"}, - {DataType::kFloat64, "float64"}, -}}}; - -constexpr ConstexprMap kStringToDataType{{{ - {"int8", DataType::kInt8}, - {"int16", DataType::kInt16}, - {"int32", DataType::kInt32}, - {"int64", DataType::kInt64}, - {"uint8", DataType::kUInt8}, - {"uint16", DataType::kUInt16}, - {"uint32", DataType::kUInt32}, - {"uint64", DataType::kUInt64}, - {"float16", DataType::kFloat16}, - {"bfloat16", DataType::kBFloat16}, - {"float32", DataType::kFloat32}, - {"float64", DataType::kFloat64}, -}}}; - -struct Float16 { - std::uint16_t bits; - - static inline Float16 FromFloat(float val) { - std::uint32_t f32; - std::memcpy(&f32, &val, sizeof(f32)); - std::uint16_t sign = (f32 >> 16) & 0x8000; - std::int32_t exponent = ((f32 >> 23) & 0xFF) - 127; - std::uint32_t mantissa = f32 & 0x7FFFFF; - - if (exponent >= 16) { - // NaN - if (exponent == 128 && mantissa != 0) { - return {static_cast(sign | 0x7E00)}; - } - // Inf - return {static_cast(sign | 0x7C00)}; - } else if (exponent >= -14) { - return {static_cast(sign | ((exponent + 15) << 10) | - (mantissa >> 13))}; - } else if (exponent >= -24) { - mantissa |= 0x800000; - mantissa >>= (-14 - exponent); - return {static_cast(sign | (mantissa >> 13))}; - } - // Too small for subnormal: return signed zero. - return {sign}; - } - - inline float ToFloat() const { - std::uint32_t sign = (bits & 0x8000) << 16; - std::int32_t exponent = (bits >> 10) & 0x1F; - std::uint32_t mantissa = bits & 0x3FF; - std::uint32_t f32_bits; - - if (exponent == 31) { - f32_bits = sign | 0x7F800000 | (mantissa << 13); - } else if (exponent == 0) { - if (mantissa == 0) { - f32_bits = sign; - } else { - exponent = -14; - while ((mantissa & 0x400) == 0) { - mantissa <<= 1; - exponent--; - } - mantissa &= 0x3FF; - f32_bits = sign | ((exponent + 127) << 23) | (mantissa << 13); - } - } else { - f32_bits = sign | ((exponent + 127 - 15) << 23) | (mantissa << 13); - } - - float result; - std::memcpy(&result, &f32_bits, sizeof(result)); - return result; - } -}; - -struct BFloat16 { - std::uint16_t bits; - - static inline BFloat16 FromFloat(float val) { - std::uint32_t bits32; - std::memcpy(&bits32, &val, sizeof(bits32)); - - const std::uint32_t rounding_bias = 0x00007FFF + ((bits32 >> 16) & 1); - std::uint16_t bf16_bits = - static_cast((bits32 + rounding_bias) >> 16); - return {bf16_bits}; - } - - inline float ToFloat() const { - std::uint32_t bits32 = static_cast(bits) << 16; - float result; - std::memcpy(&result, &bits32, sizeof(result)); - return result; - } -}; +using infini::rt::BFloat16; +using infini::rt::DataType; +using infini::rt::Float16; +using infini::rt::kDataTypeToDesc; +using infini::rt::kDataTypeToSize; +using infini::rt::kStringToDataType; template -struct TypeMap; +using TypeMap = infini::rt::TypeMap; template -using TypeMapType = typename TypeMap::type; - -#define DEFINE_DATA_TYPE_MAPPING(ENUM_VALUE, CPP_TYPE) \ - template \ - struct TypeMap { \ - using type = CPP_TYPE; \ - }; - -DEFINE_DATA_TYPE_MAPPING(kUInt8, std::uint8_t) -DEFINE_DATA_TYPE_MAPPING(kInt8, std::int8_t) -DEFINE_DATA_TYPE_MAPPING(kUInt16, std::uint16_t) -DEFINE_DATA_TYPE_MAPPING(kInt16, std::int16_t) -DEFINE_DATA_TYPE_MAPPING(kUInt32, std::uint32_t) -DEFINE_DATA_TYPE_MAPPING(kInt32, std::int32_t) -DEFINE_DATA_TYPE_MAPPING(kUInt64, std::uint64_t) -DEFINE_DATA_TYPE_MAPPING(kInt64, std::int64_t) -DEFINE_DATA_TYPE_MAPPING(kFloat32, float) -DEFINE_DATA_TYPE_MAPPING(kFloat64, double) -#undef DEFINE_DATA_TYPE_MAPPING +using TypeMapType = infini::rt::TypeMapType; -// Checks whether a C++ type is the bfloat16 or float16 type for the given -// device. Full specializations for each device's float16/bfloat16 types are -// provided in the corresponding platform-specific device type headers. template -inline constexpr bool IsBFloat16 = - std::is_same_v>; +inline constexpr bool IsBFloat16 = infini::rt::IsBFloat16; template -inline constexpr bool IsFP16 = - std::is_same_v>; +inline constexpr bool IsFP16 = infini::rt::IsFP16; -// Defines the common categories of data types using List. using FloatTypes = List; using ReducedFloatTypes = List; using IntTypes = diff --git a/src/device.h b/src/device.h index 688cd0dc2..8c18efc5c 100644 --- a/src/device.h +++ b/src/device.h @@ -1,110 +1,22 @@ #ifndef INFINI_OPS_DEVICE_H_ #define INFINI_OPS_DEVICE_H_ -#include -#include - -#include "common/constexpr_map.h" #include "common/traits.h" -#include "hash.h" +#include "infini_rt/device.h" namespace infini::ops { -class Device { - public: - enum class Type { - kCpu = 0, - kNvidia = 1, - kCambricon = 2, - kAscend = 3, - kMetax = 4, - kMoore = 5, - kIluvatar = 6, - kKunlun = 7, - kHygon = 8, - kQy = 9, - kCount - }; - - Device() = default; - - Device(const Type& type, const int& index = 0) : type_{type}, index_{index} {} - - static const Type TypeFromString(const std::string& name) { - return kDescToDevice.at(name); - } - - static const std::string_view StringFromType(const Type& type) { - return kDeviceToDesc.at(type); - } - - const Type& type() const { return type_; } - - const int& index() const { return index_; } - - std::string ToString() const { - return std::string{StringFromType(type_)} + ":" + std::to_string(index_); - } - - bool operator==(const Device& other) const { - return type_ == other.type_ && index_ == other.index_; - } - - bool operator!=(const Device& other) const { return !(*this == other); } +using Device = infini::rt::Device; - private: - Type type_{Type::kCpu}; +template +using DeviceEnabled = infini::rt::DeviceEnabled; - static constexpr ConstexprMap(Device::Type::kCount)> - kDeviceToDesc{{{ - {Type::kCpu, "cpu"}, - {Type::kNvidia, "nvidia"}, - {Type::kCambricon, "cambricon"}, - {Type::kAscend, "ascend"}, - {Type::kMetax, "metax"}, - {Type::kMoore, "moore"}, - {Type::kIluvatar, "iluvatar"}, - {Type::kKunlun, "kunlun"}, - {Type::kHygon, "hygon"}, - {Type::kQy, "qy"}, - }}}; - - static constexpr ConstexprMap(Device::Type::kCount)> - kDescToDevice{{{ - {"cpu", Type::kCpu}, - {"nvidia", Type::kNvidia}, - {"cambricon", Type::kCambricon}, - {"ascend", Type::kAscend}, - {"metax", Type::kMetax}, - {"moore", Type::kMoore}, - {"iluvatar", Type::kIluvatar}, - {"kunlun", Type::kKunlun}, - {"hygon", Type::kHygon}, - {"qy", Type::kQy}, - }}}; - - int index_{0}; -}; - -// Primary template: Devices are disabled by default. Platform-specific -// headers (e.g. `cpu/device_.h`) specialize this to `std::true_type`. -template -struct DeviceEnabled : std::false_type {}; - -// Defines the common categories of devices using List. using AllDeviceTypes = List; -// Deferred computation of active devices. The `Filter` and `FilterList` -// evaluation are nested inside a class template so that `DeviceEnabled` -// specializations from platform `device_.h` headers are visible at -// instantiation time. Use with a dependent type parameter -// (e.g. `ActiveDevices`) to ensure deferred instantiation. template struct ActiveDevicesImpl { struct Filter { @@ -121,17 +33,4 @@ using ActiveDevices = typename ActiveDevicesImpl::type; } // namespace infini::ops -template <> -struct std::hash { - std::size_t operator()(const infini::ops::Device& device) const { - std::size_t seed{0}; - - HashCombine(seed, device.type()); - - HashCombine(seed, device.index()); - - return seed; - } -}; - #endif diff --git a/src/native/cpu/data_type_.h b/src/native/cpu/data_type_.h index 36231db51..d07d60c3e 100644 --- a/src/native/cpu/data_type_.h +++ b/src/native/cpu/data_type_.h @@ -1,21 +1,6 @@ #ifndef INFINI_OPS_CPU_DATA_TYPE__H_ #define INFINI_OPS_CPU_DATA_TYPE__H_ -#include "data_type.h" -#include "native/cpu/device_.h" - -namespace infini::ops { - -template <> -struct TypeMap { - using type = Float16; -}; - -template <> -struct TypeMap { - using type = BFloat16; -}; - -} // namespace infini::ops +#include "infini_rt/cpu/data_type_.h" #endif diff --git a/src/native/cpu/device_.h b/src/native/cpu/device_.h index e5e7d85a3..92047ac3c 100644 --- a/src/native/cpu/device_.h +++ b/src/native/cpu/device_.h @@ -1,13 +1,6 @@ #ifndef INFINI_OPS_CPU_DEVICE__H_ #define INFINI_OPS_CPU_DEVICE__H_ -#include "device.h" - -namespace infini::ops { - -template <> -struct DeviceEnabled : std::true_type {}; - -} // namespace infini::ops +#include "infini_rt/cpu/device_.h" #endif diff --git a/src/native/cpu/runtime_.h b/src/native/cpu/runtime_.h index cb6176ba1..ecd4305a4 100644 --- a/src/native/cpu/runtime_.h +++ b/src/native/cpu/runtime_.h @@ -1,34 +1,6 @@ #ifndef INFINI_OPS_CPU_RUNTIME_H_ #define INFINI_OPS_CPU_RUNTIME_H_ -#include -#include - -#include "runtime.h" - -namespace infini::ops { - -template <> -struct Runtime : RuntimeBase> { - static constexpr Device::Type kDeviceType = Device::Type::kCpu; - - static void Malloc(void** ptr, std::size_t size) { *ptr = std::malloc(size); } - - static void Free(void* ptr) { std::free(ptr); } - - static void Memcpy(void* dst, const void* src, std::size_t size, int) { - std::memcpy(dst, src, size); - } - - static constexpr auto Memset = std::memset; - - static constexpr int MemcpyHostToDevice = 0; - - static constexpr int MemcpyDeviceToHost = 1; -}; - -static_assert(Runtime::Validate()); - -} // namespace infini::ops +#include "infini_rt/cpu/runtime_.h" #endif diff --git a/src/runtime.h b/src/runtime.h index 38257893c..34e3c4ba6 100644 --- a/src/runtime.h +++ b/src/runtime.h @@ -1,54 +1,18 @@ #ifndef INFINI_OPS_RUNTIME_H_ #define INFINI_OPS_RUNTIME_H_ -#include - -#include "device.h" +#include "infini_rt/runtime.h" namespace infini::ops { template -struct Runtime; - -/// ## Interface enforcement via CRTP. -/// -/// Inherit from the appropriate base to declare which interface level a -/// `Runtime` specialization implements. After the struct is fully defined, call -/// `static_assert(Runtime<...>::Validate())`. The chained `Validate()` checks -/// every required member's existence and signature at compile time, analogous -/// to how `override` catches signature mismatches for virtual functions. -/// -/// - `RuntimeBase`: `kDeviceType` only (e.g. CPU). -/// - `DeviceRuntime`: adds `Stream`, `Malloc`, and `Free` (e.g. Cambricon). +using Runtime = infini::rt::Runtime; -/// Every Runtime must provide `static constexpr Device::Type kDeviceType`. template -struct RuntimeBase { - static constexpr bool Validate() { - static_assert( - std::is_same_v, - Device::Type>, - "`Runtime` must define `static constexpr Device::Type kDeviceType`."); - return true; - } -}; +using RuntimeBase = infini::rt::RuntimeBase; -/// Runtimes with device memory must additionally provide `Stream`, `Malloc`, -/// and `Free`. template -struct DeviceRuntime : RuntimeBase { - static constexpr bool Validate() { - RuntimeBase::Validate(); - static_assert(sizeof(typename Derived::Stream) > 0, - "`Runtime` must define a `Stream` type alias."); - static_assert( - std::is_invocable_v, - "`Runtime::Malloc` must be callable with `(void**, size_t)`."); - static_assert(std::is_invocable_v, - "`Runtime::Free` must be callable with `(void*)`."); - return true; - } -}; +using DeviceRuntime = infini::rt::DeviceRuntime; } // namespace infini::ops diff --git a/src/tensor.h b/src/tensor.h index 290e3cf96..576e0ba2d 100644 --- a/src/tensor.h +++ b/src/tensor.h @@ -1,157 +1,12 @@ #ifndef INFINI_OPS_TENSOR_H_ #define INFINI_OPS_TENSOR_H_ -#include -#include -#include - -#include "data_type.h" -#include "device.h" -#include "hash.h" +#include "infini_rt/tensor_view.h" namespace infini::ops { -class Tensor { - public: - using Size = std::size_t; - - using Stride = std::ptrdiff_t; - - using Index = Stride; - - using Shape = std::vector; - - using Strides = std::vector; - - template - Tensor(void* data, const Shape& shape) - : data_{data}, - shape_{shape}, - dtype_{DefaultDataType()}, - device_{DefaultDevice()}, - strides_{DefaultStrides(shape)} {} - - template - Tensor(void* data, const Shape& shape, const DataType& dtype) - : data_{data}, - shape_{shape}, - dtype_{dtype}, - device_{DefaultDevice()}, - strides_{DefaultStrides(shape)} {} - - template - Tensor(void* data, const Shape& shape, const Device& device) - : data_{data}, - shape_{shape}, - dtype_{DefaultDataType()}, - device_{device}, - strides_{DefaultStrides(shape)} {} - - template - Tensor(void* data, const Shape& shape, const DataType& dtype, - const Device& device) - : data_{data}, - shape_{shape}, - dtype_{dtype}, - device_{device}, - strides_{DefaultStrides(shape)} {} - - template - Tensor(void* data, const Shape& shape, const DataType& dtype, - const Device& device, const Strides& strides) - : data_{data}, - shape_{shape}, - dtype_{dtype}, - device_{device}, - strides_{strides} {} - - Tensor(void* data, std::initializer_list shape, const DataType& dtype, - const Device& device, std::initializer_list strides); - - Tensor operator[](const Index& index) const; - - void*& data(); - - const void* data() const; - - const DataType& dtype() const; - - const Device& device() const; - - const Shape& shape() const; - - const Strides& strides() const; - - Size size(const Index& index) const; - - Stride stride(const Index& index) const; - - Size ndim() const; - - Size element_size() const; - - Size numel() const; - - Tensor T() const; - - std::string ToString() const; - - bool HasBroadcastDim() const; - - bool IsContiguous() const; - - private: - static const DataType DefaultDataType(); - - static Device DefaultDevice(); - - static Strides DefaultStrides(const Shape& shape); - - std::string ToStringHelper() const; - - bool IsMergeable(Size dim_start, Size dim_end) const; - - void* data_{nullptr}; - - Shape shape_; - - const DataType dtype_; - - Device device_; - - Strides strides_; -}; +using Tensor = infini::rt::TensorView; } // namespace infini::ops -template <> -struct std::hash { - std::size_t operator()(const infini::ops::Tensor& tensor) const { - std::size_t seed{0}; - - for (const auto& size : tensor.shape()) { - HashCombine(seed, size); - } - - HashCombine(seed, tensor.dtype()); - - HashCombine(seed, tensor.device()); - - for (const auto& stride : tensor.strides()) { - HashCombine(seed, stride); - } - - return seed; - } -}; - -template <> -struct std::equal_to { - bool operator()(const infini::ops::Tensor& a, - const infini::ops::Tensor& b) const { - return a.dtype() == b.dtype() && a.device() == b.device() && - a.shape() == b.shape() && a.strides() == b.strides(); - } -}; - #endif