From 7585c41502821bbe2bc0c3a127d1cab87ff4e423 Mon Sep 17 00:00:00 2001 From: Jiacheng Huang Date: Thu, 16 Apr 2026 03:06:52 +0000 Subject: [PATCH 01/22] chore: add `.clang-format` --- .clang-format | 3 +++ 1 file changed, 3 insertions(+) create mode 100644 .clang-format diff --git a/.clang-format b/.clang-format new file mode 100644 index 0000000..2296f7d --- /dev/null +++ b/.clang-format @@ -0,0 +1,3 @@ +--- +BasedOnStyle: Google +... From db27e33ec6954b3de9c70f7731a1c10a1c33397c Mon Sep 17 00:00:00 2001 From: Jiacheng Huang Date: Thu, 16 Apr 2026 03:08:19 +0000 Subject: [PATCH 02/22] feat: add common utilities --- src/common/constexpr_map.h | 32 +++++++ src/common/generic_utils.h | 26 ++++++ src/common/traits.h | 170 +++++++++++++++++++++++++++++++++++++ src/hash.h | 12 +++ 4 files changed, 240 insertions(+) create mode 100644 src/common/constexpr_map.h create mode 100644 src/common/generic_utils.h create mode 100644 src/common/traits.h create mode 100644 src/hash.h diff --git a/src/common/constexpr_map.h b/src/common/constexpr_map.h new file mode 100644 index 0000000..0a01e2a --- /dev/null +++ b/src/common/constexpr_map.h @@ -0,0 +1,32 @@ +#ifndef INFINI_RT_COMMON_CONSTEXPR_MAP_H_ +#define INFINI_RT_COMMON_CONSTEXPR_MAP_H_ + +#include +#include +#include +#include + +namespace infini::rt { + +template +struct ConstexprMap { + constexpr ConstexprMap(std::array, size> data) + : data_(data) {} + + constexpr Value at(Key key) const { + for (const auto& pr : data_) { + if (pr.first == key) return pr.second; + } + // TODO(lzm): change to logging. + assert("the key is not found in the `ConstexprMap`"); + // Unreachable, provided to satisfy the compiler's requirement. + std::abort(); + } + + private: + std::array, size> data_; +}; + +} // namespace infini::rt + +#endif diff --git a/src/common/generic_utils.h b/src/common/generic_utils.h new file mode 100644 index 0000000..bda6cde --- /dev/null +++ b/src/common/generic_utils.h @@ -0,0 +1,26 @@ +#ifndef INFINI_RT_COMMON_GENERIC_UTILS_H_ +#define INFINI_RT_COMMON_GENERIC_UTILS_H_ + +#include + +namespace infini::rt::utils { + +std::size_t IndexToOffset(std::size_t flat_index, std::size_t ndim, + const std::size_t* shape, + const std::ptrdiff_t* strides) { + std::size_t res = 0; + for (std::size_t i = ndim; i-- > 0;) { + res += (flat_index % shape[i]) * strides[i]; + flat_index /= shape[i]; + } + return res; +} + +template +constexpr auto CeilDiv(const X& x, const Y& y) { + return (x + y - 1) / y; +} + +} // namespace infini::rt::utils + +#endif diff --git a/src/common/traits.h b/src/common/traits.h new file mode 100644 index 0000000..6459bda --- /dev/null +++ b/src/common/traits.h @@ -0,0 +1,170 @@ +#ifndef INFINI_RT_COMMON_TRAITS_H_ +#define INFINI_RT_COMMON_TRAITS_H_ + +#include +#include + +namespace infini::rt { + +// --------------------- List and TypePack --------------------- +// A generic container for a sequence of compile-time values. +template +struct List {}; + +// `ListGet(List{})` extracts the `i`th value from a `List` +// tag. +template +constexpr auto ListGetImpl(List) { + if constexpr (index == 0) + return head; + else + return ListGetImpl(List{}); +} + +template +constexpr auto ListGet(List list) { + return ListGetImpl(list); +} + +template +struct TypePack {}; + +// ----------------------------------------------------------------------------- +// Tags +// ----------------------------------------------------------------------------- +// Tags are passed as regular function arguments to user functors instead of +// template parameters. This lets users write plain C++17 `[](auto tag)` lambdas +// rather than C++20 template lambdas (`[]()`). + +// `TypeTag`: carries a C++ type. Recover with `typename +// decltype(tag)::type`. +template +struct TypeTag { + using type = T; +}; + +// `ValueTag`: carries a compile-time value. Recover with +// `decltype(tag)::value`. +template +struct ValueTag { + using value_type = decltype(v); + static constexpr auto value = v; +}; + +// ----------------------------------------------------------------------------- +// List Queries +// ----------------------------------------------------------------------------- + +// Check at compile-time if a value exists within a construct (e.g., `List<>`). +// Example: `static_assert(ContainsValue)`; +template +struct Contains; + +template +struct Contains, value> + : std::disjunction...> {}; + +template +inline constexpr bool ContainsValue = Contains::value; + +// Check at compile-time if a type `T` is present in a variadic list of types +// `Ts`. +// Example: `static_assert(IsTypeInList)`; +template +inline constexpr bool IsTypeInList = (std::is_same_v || ...); + +// Trait to detect whether `T` is a `List<...>` specialization. +template +struct IsListType : std::false_type {}; + +template +struct IsListType> : std::true_type {}; + +// ----------------------------------------------------------------------------- +// List Operations +// ----------------------------------------------------------------------------- + +// Concatenates two List types into a single `List`. +// Example: `ConcatType, List<3, 4>>` is `List<1, 2, 3, 4>`. +template +struct Concat; + +template +struct Concat, List> { + using type = List; +}; + +template +using ConcatType = typename Concat::type; + +template +struct Flatten; + +template +struct Flatten> { + using type = List; +}; + +template +struct Flatten { + using type = typename Flatten, Rest...>::type; +}; + +// ----------------------------------------------------------------------------- +// Invocability Detection (SFINAE) +// ----------------------------------------------------------------------------- + +// Checks if a `Functor` can be called with a `ValueTag` and `Args...`. +template +struct IsInvocable : std::false_type {}; + +template +struct IsInvocable()( + ValueTag{}, std::declval()...))>, + Args...> : std::true_type {}; + +template +inline constexpr bool IsInvocableValue = + IsInvocable::value; + +// ----------------------------------------------------------------------------- +// Filtering Logic +// ----------------------------------------------------------------------------- + +// Recursive template to filter values based on `Functor` support at +// compile-time. +template +struct Filter; + +// Base case: All values processed. +template +struct Filter, List> { + using type = List; +}; + +// Recursive step: Test the `head` value and accumulate if supported. +template +struct Filter, List, head, tail...> { + using type = typename std::conditional_t< + IsInvocableValue && + !ContainsValue, head>, + Filter, List, tail...>, + Filter, List, tail...>>::type; +}; + +// Interface to filter a `List` type directly. +template +struct FilterList; + +template +struct FilterList, List> { + using type = + typename Filter, List<>, items...>::type; +}; + +} // namespace infini::rt + +#endif diff --git a/src/hash.h b/src/hash.h new file mode 100644 index 0000000..b3a6598 --- /dev/null +++ b/src/hash.h @@ -0,0 +1,12 @@ +#ifndef INFINI_RT_HASH_H_ +#define INFINI_RT_HASH_H_ + +#include + +template +inline void HashCombine(std::size_t& seed, const T& v) { + std::hash> hasher; + seed ^= hasher(v) + 0x9e3779b9 + (seed << 6) + (seed >> 2); +} + +#endif From db218ff327097db1ac3b13d7a56ff36f287aa267 Mon Sep 17 00:00:00 2001 From: Jiacheng Huang Date: Thu, 16 Apr 2026 03:09:05 +0000 Subject: [PATCH 03/22] feat: add device abstraction --- src/device.h | 134 +++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 134 insertions(+) create mode 100644 src/device.h diff --git a/src/device.h b/src/device.h new file mode 100644 index 0000000..d4d9fcb --- /dev/null +++ b/src/device.h @@ -0,0 +1,134 @@ +#ifndef INFINI_RT_DEVICE_H_ +#define INFINI_RT_DEVICE_H_ + +#include "common/constexpr_map.h" +#include "common/traits.h" +#include "hash.h" + +namespace infini::rt { + +class Device { + public: + enum class Type { + kCpu = 0, + kNvidia = 1, + kCambricon = 2, + kAscend = 3, + kMetax = 4, + kMoore = 5, + kIluvatar = 6, + kKunlun = 7, + kHygon = 8, + kQy = 9, + kCount + }; + + Device() = default; + + Device(const Type& type, const int& index = 0) : type_{type}, index_{index} {} + + static const Type TypeFromString(const std::string& name) { + return kDescToDevice.at(name); + } + + static const std::string_view StringFromType(const Type& type) { + return kDeviceToDesc.at(type); + } + + const Type& type() const { return type_; } + + const int& index() const { return index_; } + + std::string ToString() const { + return std::string{StringFromType(type_)} + ":" + std::to_string(index_); + } + + bool operator==(const Device& other) const { + return type_ == other.type_ && index_ == other.index_; + } + + bool operator!=(const Device& other) const { return !(*this == other); } + + private: + Type type_{Type::kCpu}; + + static constexpr ConstexprMap(Device::Type::kCount)> + kDeviceToDesc{{{ + {Type::kCpu, "cpu"}, + {Type::kNvidia, "nvidia"}, + {Type::kCambricon, "cambricon"}, + {Type::kAscend, "ascend"}, + {Type::kMetax, "metax"}, + {Type::kMoore, "moore"}, + {Type::kIluvatar, "iluvatar"}, + {Type::kKunlun, "kunlun"}, + {Type::kHygon, "hygon"}, + {Type::kQy, "qy"}, + }}}; + + static constexpr ConstexprMap(Device::Type::kCount)> + kDescToDevice{{{ + {"cpu", Type::kCpu}, + {"nvidia", Type::kNvidia}, + {"cambricon", Type::kCambricon}, + {"ascend", Type::kAscend}, + {"metax", Type::kMetax}, + {"moore", Type::kMoore}, + {"iluvatar", Type::kIluvatar}, + {"kunlun", Type::kKunlun}, + {"hygon", Type::kHygon}, + {"qy", Type::kQy}, + }}}; + + int index_{0}; +}; + +// Primary template: Devices are disabled by default. Platform-specific +// headers (e.g. `cpu/device_.h`) specialize this to `std::true_type`. +template +struct DeviceEnabled : std::false_type {}; + +// Defines the common categories of devices using List. +using AllDeviceTypes = + List; + +// Deferred computation of active devices. The `Filter` and `FilterList` +// evaluation are nested inside a class template so that `DeviceEnabled` +// specializations from platform `device_.h` headers are visible at +// instantiation time. Use with a dependent type parameter +// (e.g. `ActiveDevices`) to ensure deferred instantiation. +template +struct ActiveDevicesImpl { + struct Filter { + template + std::enable_if_t::value> operator()( + ValueTag) const {} + }; + + using type = typename FilterList, AllDeviceTypes>::type; +}; + +template +using ActiveDevices = typename ActiveDevicesImpl::type; + +} // namespace infini::rt + +template <> +struct std::hash { + std::size_t operator()(const infini::rt::Device& device) const { + std::size_t seed{0}; + + HashCombine(seed, device.type()); + + HashCombine(seed, device.index()); + + return seed; + } +}; + +#endif From 4d20cffdb98174860eb4d93d40a22912806d6857 Mon Sep 17 00:00:00 2001 From: Jiacheng Huang Date: Thu, 16 Apr 2026 03:10:09 +0000 Subject: [PATCH 04/22] feat: add data type system --- src/data_type.h | 211 ++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 211 insertions(+) create mode 100644 src/data_type.h diff --git a/src/data_type.h b/src/data_type.h new file mode 100644 index 0000000..c71d540 --- /dev/null +++ b/src/data_type.h @@ -0,0 +1,211 @@ +#ifndef INFINI_RT_DATA_TYPE_H_ +#define INFINI_RT_DATA_TYPE_H_ + +#include +#include +#include + +#include "common/constexpr_map.h" +#include "common/traits.h" +#include "device.h" + +namespace infini::rt { + +enum class DataType : std::int8_t { + kInt8, + kInt16, + kInt32, + kInt64, + kUInt8, + kUInt16, + kUInt32, + kUInt64, + kFloat16, + kBFloat16, + kFloat32, + kFloat64 +}; + +constexpr ConstexprMap kDataTypeToSize{{{ + {DataType::kInt8, 1}, + {DataType::kInt16, 2}, + {DataType::kInt32, 4}, + {DataType::kInt64, 8}, + {DataType::kUInt8, 1}, + {DataType::kUInt16, 2}, + {DataType::kUInt32, 4}, + {DataType::kUInt64, 8}, + {DataType::kFloat16, 2}, + {DataType::kBFloat16, 2}, + {DataType::kFloat32, 4}, + {DataType::kFloat64, 8}, +}}}; + +constexpr ConstexprMap kDataTypeToDesc{{{ + {DataType::kInt8, "int8"}, + {DataType::kInt16, "int16"}, + {DataType::kInt32, "int32"}, + {DataType::kInt64, "int64"}, + {DataType::kUInt8, "uint8"}, + {DataType::kUInt16, "uint16"}, + {DataType::kUInt32, "uint32"}, + {DataType::kUInt64, "uint64"}, + {DataType::kFloat16, "float16"}, + {DataType::kBFloat16, "bfloat16"}, + {DataType::kFloat32, "float32"}, + {DataType::kFloat64, "float64"}, +}}}; + +constexpr ConstexprMap kStringToDataType{{{ + {"int8", DataType::kInt8}, + {"int16", DataType::kInt16}, + {"int32", DataType::kInt32}, + {"int64", DataType::kInt64}, + {"uint8", DataType::kUInt8}, + {"uint16", DataType::kUInt16}, + {"uint32", DataType::kUInt32}, + {"uint64", DataType::kUInt64}, + {"float16", DataType::kFloat16}, + {"bfloat16", DataType::kBFloat16}, + {"float32", DataType::kFloat32}, + {"float64", DataType::kFloat64}, +}}}; + +struct Float16 { + std::uint16_t bits; + + static inline Float16 FromFloat(float val) { + std::uint32_t f32; + std::memcpy(&f32, &val, sizeof(f32)); + std::uint16_t sign = (f32 >> 16) & 0x8000; + std::int32_t exponent = ((f32 >> 23) & 0xFF) - 127; + std::uint32_t mantissa = f32 & 0x7FFFFF; + + if (exponent >= 16) { + // NaN + if (exponent == 128 && mantissa != 0) { + return {static_cast(sign | 0x7E00)}; + } + // Inf + return {static_cast(sign | 0x7C00)}; + } else if (exponent >= -14) { + return {static_cast(sign | ((exponent + 15) << 10) | + (mantissa >> 13))}; + } else if (exponent >= -24) { + mantissa |= 0x800000; + mantissa >>= (-14 - exponent); + return {static_cast(sign | (mantissa >> 13))}; + } + // Too small for subnormal: return signed zero. + return {sign}; + } + + inline float ToFloat() const { + std::uint32_t sign = (bits & 0x8000) << 16; + std::int32_t exponent = (bits >> 10) & 0x1F; + std::uint32_t mantissa = bits & 0x3FF; + std::uint32_t f32_bits; + + if (exponent == 31) { + f32_bits = sign | 0x7F800000 | (mantissa << 13); + } else if (exponent == 0) { + if (mantissa == 0) { + f32_bits = sign; + } else { + exponent = -14; + while ((mantissa & 0x400) == 0) { + mantissa <<= 1; + exponent--; + } + mantissa &= 0x3FF; + f32_bits = sign | ((exponent + 127) << 23) | (mantissa << 13); + } + } else { + f32_bits = sign | ((exponent + 127 - 15) << 23) | (mantissa << 13); + } + + float result; + std::memcpy(&result, &f32_bits, sizeof(result)); + return result; + } +}; + +struct BFloat16 { + std::uint16_t bits; + + static inline BFloat16 FromFloat(float val) { + std::uint32_t bits32; + std::memcpy(&bits32, &val, sizeof(bits32)); + + const std::uint32_t rounding_bias = 0x00007FFF + ((bits32 >> 16) & 1); + std::uint16_t bf16_bits = + static_cast((bits32 + rounding_bias) >> 16); + return {bf16_bits}; + } + + inline float ToFloat() const { + std::uint32_t bits32 = static_cast(bits) << 16; + float result; + std::memcpy(&result, &bits32, sizeof(result)); + return result; + } +}; + +template +struct TypeMap; + +template +using TypeMapType = typename TypeMap::type; + +#define DEFINE_DATA_TYPE_MAPPING(ENUM_VALUE, CPP_TYPE) \ + template \ + struct TypeMap { \ + using type = CPP_TYPE; \ + }; + +DEFINE_DATA_TYPE_MAPPING(kUInt8, std::uint8_t) +DEFINE_DATA_TYPE_MAPPING(kInt8, std::int8_t) +DEFINE_DATA_TYPE_MAPPING(kUInt16, std::uint16_t) +DEFINE_DATA_TYPE_MAPPING(kInt16, std::int16_t) +DEFINE_DATA_TYPE_MAPPING(kUInt32, std::uint32_t) +DEFINE_DATA_TYPE_MAPPING(kInt32, std::int32_t) +DEFINE_DATA_TYPE_MAPPING(kUInt64, std::uint64_t) +DEFINE_DATA_TYPE_MAPPING(kInt64, std::int64_t) +DEFINE_DATA_TYPE_MAPPING(kFloat32, float) +DEFINE_DATA_TYPE_MAPPING(kFloat64, double) +#undef DEFINE_DATA_TYPE_MAPPING + +// Checks whether a C++ type is the bfloat16 or float16 type for the given +// device. Full specializations for each device's float16/bfloat16 types are +// provided in the corresponding platform-specific device type headers. +template +inline constexpr bool IsBFloat16 = + std::is_same_v>; + +template +inline constexpr bool IsFP16 = + std::is_same_v>; + +// Defines the common categories of data types using List. +using FloatTypes = List; +using ReducedFloatTypes = List; +using IntTypes = + List; +using UIntTypes = List; + +using BitTypes8 = List; +using BitTypes16 = List; +using BitTypes32 = + List; +using BitTypes64 = + List; + +using AllFloatTypes = ConcatType; +using AllIntTypes = ConcatType; +using AllTypes = ConcatType; + +} // namespace infini::rt + +#endif From d0d25fe3a936e152d2aa20c60bb11fad874ba8e8 Mon Sep 17 00:00:00 2001 From: Jiacheng Huang Date: Thu, 16 Apr 2026 03:11:00 +0000 Subject: [PATCH 05/22] feat: add runtime CRTP interfaces --- src/cuda/runtime.h | 29 +++++++++++++++++++++ src/cuda/runtime_utils.h | 25 ++++++++++++++++++ src/runtime.h | 55 ++++++++++++++++++++++++++++++++++++++++ 3 files changed, 109 insertions(+) create mode 100644 src/cuda/runtime.h create mode 100644 src/cuda/runtime_utils.h create mode 100644 src/runtime.h diff --git a/src/cuda/runtime.h b/src/cuda/runtime.h new file mode 100644 index 0000000..2b6cb32 --- /dev/null +++ b/src/cuda/runtime.h @@ -0,0 +1,29 @@ +#ifndef INFINI_RT_CUDA_RUNTIME_H_ +#define INFINI_RT_CUDA_RUNTIME_H_ + +#include + +#include "../runtime.h" + +namespace infini::rt { + +/// ## CUDA-like runtime interface enforcement via CRTP. +/// +/// `CudaRuntime` extends `DeviceRuntime` for backends that mirror +/// `cuda_runtime.h`-style memory copy APIs. +template +struct CudaRuntime : DeviceRuntime { + static constexpr bool Validate() { + DeviceRuntime::Validate(); + static_assert( + std::is_invocable_v, + "`Runtime::Memcpy` must be callable with " + "`(void*, const void*, size_t, MemcpyHostToDevice)`."); + return true; + } +}; + +} // namespace infini::rt + +#endif diff --git a/src/cuda/runtime_utils.h b/src/cuda/runtime_utils.h new file mode 100644 index 0000000..f85eace --- /dev/null +++ b/src/cuda/runtime_utils.h @@ -0,0 +1,25 @@ +#ifndef INFINI_RT_CUDA_RUNTIME_UTILS_H_ +#define INFINI_RT_CUDA_RUNTIME_UTILS_H_ + +#include "device.h" + +namespace infini::rt { + +template +struct RuntimeUtils; + +template +struct CudaRuntimeUtils { + static int GetOptimalBlockSize() { + int max_threads = QueryMaxThreadsPerBlockFn(); + if (max_threads >= 2048) return 2048; + if (max_threads >= 1024) return 1024; + if (max_threads >= 512) return 512; + if (max_threads >= 256) return 256; + return 128; + } +}; + +} // namespace infini::rt + +#endif diff --git a/src/runtime.h b/src/runtime.h new file mode 100644 index 0000000..839477b --- /dev/null +++ b/src/runtime.h @@ -0,0 +1,55 @@ +#ifndef INFINI_RT_RUNTIME_H_ +#define INFINI_RT_RUNTIME_H_ + +#include + +#include "device.h" + +namespace infini::rt { + +template +struct Runtime; + +/// ## Interface enforcement via CRTP. +/// +/// Inherit from the appropriate base to declare which interface level a +/// `Runtime` specialization implements. After the struct is fully defined, call +/// `static_assert(Runtime<...>::Validate())`. The chained `Validate()` checks +/// every required member's existence and signature at compile time, analogous +/// to how `override` catches signature mismatches for virtual functions. +/// +/// - `RuntimeBase`: `kDeviceType` only (e.g. CPU). +/// - `DeviceRuntime`: adds `Stream`, `Malloc`, and `Free` (e.g. Cambricon). + +/// Every Runtime must provide `static constexpr Device::Type kDeviceType`. +template +struct RuntimeBase { + static constexpr bool Validate() { + static_assert( + std::is_same_v, + Device::Type>, + "`Runtime` must define `static constexpr Device::Type kDeviceType`."); + return true; + } +}; + +/// Runtimes with device memory must additionally provide `Stream`, `Malloc`, +/// and `Free`. +template +struct DeviceRuntime : RuntimeBase { + static constexpr bool Validate() { + RuntimeBase::Validate(); + static_assert(sizeof(typename Derived::Stream) > 0, + "`Runtime` must define a `Stream` type alias."); + static_assert( + std::is_invocable_v, + "`Runtime::Malloc` must be callable with `(void**, size_t)`."); + static_assert(std::is_invocable_v, + "`Runtime::Free` must be callable with `(void*)`."); + return true; + } +}; + +} // namespace infini::rt + +#endif From 4d89719ae6d5bcca86d2d855edca8dd9bce92506 Mon Sep 17 00:00:00 2001 From: Jiacheng Huang Date: Thu, 16 Apr 2026 03:12:45 +0000 Subject: [PATCH 06/22] feat: add `Handle` and `Config` --- src/config.h | 22 ++++++++++++++++++++++ src/handle.h | 36 ++++++++++++++++++++++++++++++++++++ 2 files changed, 58 insertions(+) create mode 100644 src/config.h create mode 100644 src/handle.h diff --git a/src/config.h b/src/config.h new file mode 100644 index 0000000..d877077 --- /dev/null +++ b/src/config.h @@ -0,0 +1,22 @@ +#ifndef INFINI_RT_CONFIG_H_ +#define INFINI_RT_CONFIG_H_ + +#include + +namespace infini::rt { + +class Config { + public: + std::size_t implementation_index() const { return implementation_index_; } + + void set_implementation_index(std::size_t implementation_index) { + implementation_index_ = implementation_index; + } + + private: + std::size_t implementation_index_{0}; +}; + +} // namespace infini::rt + +#endif diff --git a/src/handle.h b/src/handle.h new file mode 100644 index 0000000..2be7bc7 --- /dev/null +++ b/src/handle.h @@ -0,0 +1,36 @@ +#ifndef INFINI_RT_HANDLE_H_ +#define INFINI_RT_HANDLE_H_ + +#include + +namespace infini::rt { + +class Handle { + public: + void* stream() const { return stream_; } + + void* workspace() const { return workspace_; } + + std::size_t workspace_size_in_bytes() const { + return workspace_size_in_bytes_; + } + + void set_stream(void* stream) { stream_ = stream; } + + void set_workspace(void* workspace) { workspace_ = workspace; } + + void set_workspace_size_in_bytes(std::size_t workspace_size_in_bytes) { + workspace_size_in_bytes_ = workspace_size_in_bytes; + } + + private: + void* stream_{nullptr}; + + void* workspace_{nullptr}; + + std::size_t workspace_size_in_bytes_{0}; +}; + +} // namespace infini::rt + +#endif From b8e51a70513a949a47121438b061dbcdc174982d Mon Sep 17 00:00:00 2001 From: Jiacheng Huang Date: Thu, 16 Apr 2026 03:13:55 +0000 Subject: [PATCH 07/22] Revert "feat: add `Handle` and `Config`" This reverts commit 4d89719ae6d5bcca86d2d855edca8dd9bce92506. --- src/config.h | 22 ---------------------- src/handle.h | 36 ------------------------------------ 2 files changed, 58 deletions(-) delete mode 100644 src/config.h delete mode 100644 src/handle.h diff --git a/src/config.h b/src/config.h deleted file mode 100644 index d877077..0000000 --- a/src/config.h +++ /dev/null @@ -1,22 +0,0 @@ -#ifndef INFINI_RT_CONFIG_H_ -#define INFINI_RT_CONFIG_H_ - -#include - -namespace infini::rt { - -class Config { - public: - std::size_t implementation_index() const { return implementation_index_; } - - void set_implementation_index(std::size_t implementation_index) { - implementation_index_ = implementation_index; - } - - private: - std::size_t implementation_index_{0}; -}; - -} // namespace infini::rt - -#endif diff --git a/src/handle.h b/src/handle.h deleted file mode 100644 index 2be7bc7..0000000 --- a/src/handle.h +++ /dev/null @@ -1,36 +0,0 @@ -#ifndef INFINI_RT_HANDLE_H_ -#define INFINI_RT_HANDLE_H_ - -#include - -namespace infini::rt { - -class Handle { - public: - void* stream() const { return stream_; } - - void* workspace() const { return workspace_; } - - std::size_t workspace_size_in_bytes() const { - return workspace_size_in_bytes_; - } - - void set_stream(void* stream) { stream_ = stream; } - - void set_workspace(void* workspace) { workspace_ = workspace; } - - void set_workspace_size_in_bytes(std::size_t workspace_size_in_bytes) { - workspace_size_in_bytes_ = workspace_size_in_bytes; - } - - private: - void* stream_{nullptr}; - - void* workspace_{nullptr}; - - std::size_t workspace_size_in_bytes_{0}; -}; - -} // namespace infini::rt - -#endif From a66d15f9078007b2857e728b3149c04bc5fee093 Mon Sep 17 00:00:00 2001 From: Jiacheng Huang Date: Thu, 16 Apr 2026 03:20:08 +0000 Subject: [PATCH 08/22] feat: add `DispatchFunc` --- src/dispatcher.h | 341 +++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 341 insertions(+) create mode 100644 src/dispatcher.h diff --git a/src/dispatcher.h b/src/dispatcher.h new file mode 100644 index 0000000..e277e5d --- /dev/null +++ b/src/dispatcher.h @@ -0,0 +1,341 @@ +#ifndef INFINI_RT_DISPATCHER_H_ +#define INFINI_RT_DISPATCHER_H_ + +#include +#include +#include +#include + +#include "common/traits.h" +#include "data_type.h" +#include "device.h" + +namespace infini::rt { + +// ----------------------------------------------------------------------------- +// Core Generic Runtime Dispatchers +// ----------------------------------------------------------------------------- + +namespace detail { + +// Implements the dispatch body over a resolved `List`. +template +auto DispatchFuncImpl(ValueType value, Functor&& func, + std::string_view context_str, List, + Args&&... args) { + using ReturnType = decltype(std::forward(func)( + ValueTag(head)>{}, std::forward(args)...)); + + // Path for void functions. + if constexpr (std::is_void_v) { + bool handled = ((value == static_cast(tail) + ? (std::forward(func)( + ValueTag{}, std::forward(args)...), + true) + : false) || + ... || + (value == static_cast(head) + ? (std::forward(func)( + ValueTag{}, std::forward(args)...), + true) + : false)); + + if (!handled) { + // TODO(lzm): change to logging. + std::cerr << "dispatch error (void): value " << static_cast(value) + << " not supported in the context: " << context_str << "\n"; + std::abort(); + } + } + // Path for non-void functions. + else { + std::optional result; + bool handled = ((value == static_cast(tail) + ? (result.emplace(std::forward(func)( + ValueTag{}, std::forward(args)...)), + true) + : false) || + ... || + (value == static_cast(head) + ? (result.emplace(std::forward(func)( + ValueTag{}, std::forward(args)...)), + true) + : false)); + + if (handled) { + return *result; + } + // TODO(lzm): change to logging. + std::cerr << "dispatch error (non-void): value " << static_cast(value) + << " not supported in the context: " << context_str << "\n"; + std::abort(); + return ReturnType{}; + } +} + +// Deduces `head`/`tail` from a `List` type via partial specialization, +// then forwards to `DispatchFuncImpl`. +template +struct DispatchFuncUnwrap; + +template +struct DispatchFuncUnwrap, + std::tuple> { + static auto call(ValueType value, Functor&& func, + std::string_view context_str, Args&&... args) { + return DispatchFuncImpl(value, std::forward(func), context_str, + List{}, std::forward(args)...); + } +}; + +// Empty-list specialization +template +struct DispatchFuncUnwrap, std::tuple> { + static auto call(ValueType value, Functor&&, std::string_view context_str, + Args&&...) { + // TODO(lzm): change to logging. + std::cerr << "dispatch error: no allowed values registered for value " + << static_cast(value) + << " in the context: " << context_str << "\n"; + std::abort(); + } +}; + +} // namespace detail + +// (Single Dispatch) Dispatches a runtime value to a compile-time functor. +template +auto DispatchFunc(ValueType value, Functor&& func, + std::string_view context_str = "", Args&&... args) { + using FilteredPack = typename Filter, List<>, + all_values...>::type; + + return detail::DispatchFuncUnwrap< + ValueType, Functor, FilteredPack, + std::tuple>::call(value, std::forward(func), + context_str, std::forward(args)...); +} + +// (Multi-Dispatch) Dispatches a vector of runtime values to a compile-time +// functor. +// Base Case: All Dimensions Resolved +template +auto DispatchFunc(const std::vector& values, size_t /*index*/, + Functor&& func, std::string_view /*context_str*/, + List, Args&&... args) { + return std::forward(func)(List{}, + std::forward(args)...); +} + +// Forward declaration of the recursive multi-dispatch overload. +template +auto DispatchFunc(const std::vector& values, size_t index, + Functor&& func, std::string_view context_str, List, + Args&&... args); + +// Adapter used in the recursive multi-dispatch case: given a resolved value +// `val` recurse into the next dimension. +template +struct MultiDispatchRecurseAdapter; + +template +struct MultiDispatchRecurseAdapter, Functor, items...> { + const std::vector& values; + size_t next_index; + Functor& func; + std::string_view context_str; + + template + auto operator()(ValueTag, Args&&... args) const { + return DispatchFunc(values, next_index, func, context_str, + List{}, + std::forward(args)...); + } +}; + +template +auto MultiDispatchFirstDim(const std::vector& values, size_t index, + Functor& func, std::string_view context_str, + List, List, Args&&... args) { + static_assert(sizeof...(allowed) > 0, + "`DispatchFunc` dimension list is empty"); + using EnumType = std::common_type_t; + + MultiDispatchRecurseAdapter adapter{ + values, index + 1, func, context_str}; + + return DispatchFunc( + static_cast(values.at(index)), adapter, context_str, + std::forward(args)...); +} + +// (Multi-Dispatch) Recursive Case +template +auto DispatchFunc(const std::vector& values, size_t index, + Functor&& func, std::string_view context_str, List, + Args&&... args) { + return MultiDispatchFirstDim>( + values, index, func, context_str, List{}, FirstList{}, + std::forward(args)...); +} + +// ----------------------------------------------------------------------------- +// High-Level Specialized Dispatchers +// ----------------------------------------------------------------------------- +// These provide cleaner and more convenient APIs for common InfiniRT types. + +namespace detail { + +// Bridges the generic value dispatch layer to the `DataType`-specific type +// dispatch layer. +template +struct DataTypeAdapter { + Functor& func; + + template + auto operator()(ValueTag, Args&&... args) const { + using T = TypeMapType(dtype)>; + return func(TypeTag{}, std::forward(args)...); + } +}; + +template +struct DataTypeMultiAdapter { + Functor& func; + + template + auto operator()(List, Args&&... args) const { + return func(TypeTag(dtypes)>>{}..., + std::forward(args)...); + } +}; + +template +struct DeviceAdapter { + Functor& func; + + template + auto operator()(ValueTag, Args&&... args) const { + return func(ValueTag{}, std::forward(args)...); + } +}; + +template +struct DeviceMultiAdapter { + Functor& func; + + template + auto operator()(List, Args&&... args) const { + return func(ValueTag{}..., std::forward(args)...); + } +}; + +} // namespace detail + +// `DataType` Dispatch +template +auto DispatchFunc(DataType dtype, Functor&& func, + std::string_view context_str = "", Args&&... args) { + detail::DataTypeAdapter> adapter{func}; + return DispatchFunc(dtype, adapter, context_str, + std::forward(args)...); +} + +// `DataType` Multi-Dispatch +template +auto DispatchFunc(std::initializer_list dtypes, Functor&& func, + std::string_view context_str = "", Args&&... args) { + std::vector v; + for (auto d : dtypes) v.push_back(static_cast(d)); + + detail::DataTypeMultiAdapter> adapter{ + func}; + return DispatchFunc(v, 0, adapter, context_str, List<>{}, + std::forward(args)...); +} + +// `Device` Dispatch +template +auto DispatchFunc(Device::Type device, Functor&& func, + std::string_view context_str = "", Args&&... args) { + detail::DeviceAdapter> adapter{func}; + return DispatchFunc(allowed_devices)...>( + device, adapter, context_str, std::forward(args)...); +} + +// `Device` Multi-Dispatch +template +auto DispatchFunc(std::initializer_list devices, Functor&& func, + std::string_view context_str = "", Args&&... args) { + std::vector v; + for (auto d : devices) v.push_back(static_cast(d)); + + detail::DeviceMultiAdapter> adapter{func}; + return DispatchFunc(v, 0, adapter, context_str, List<>{}, + std::forward(args)...); +} + +template +auto DispatchFuncListAliasImpl(ValueType value, Functor&& func, + std::string_view context_str, List, + Args&&... args) { + return DispatchFunc>(items)...>( + value, std::forward(func), context_str, + std::forward(args)...); +} + +template +auto DispatchFuncListAliasImpl(ValueType value, Functor&& func, + std::string_view context_str, List, + Args&&... args) { + return DispatchFunc>(items)...>( + value, std::forward(func), context_str, + std::forward(args)...); +} + +// Interface for Generic `List` Aliases (for non-DataType dispatch, e.g. Device) +template ::value>> +auto DispatchFunc(ValueType value, Functor&& func, + std::string_view context_str = "", Args&&... args) { + return DispatchFuncListAliasImpl(value, std::forward(func), + context_str, ListType{}, + std::forward(args)...); +} + +// Interface for Generic `List` Aliases (for DataType dispatch with device type) +template ::value>> +auto DispatchFunc(ValueType value, Functor&& func, + std::string_view context_str = "", Args&&... args) { + return DispatchFuncListAliasImpl(value, std::forward(func), + context_str, ListType{}, + std::forward(args)...); +} + +// Interface for Any `int64_t`-Convertible Types +template +auto DispatchFunc(std::initializer_list keys, Functor&& func, + std::string_view context_str = "", Args&&... args) { + std::vector v_keys(keys); + return DispatchFunc(v_keys, 0, std::forward(func), + context_str, List<>{}, + std::forward(args)...); +} + +} // namespace infini::rt + +#endif From 6c2e5c96ca4e0c8e3753cf324533d52e6bcaa807 Mon Sep 17 00:00:00 2001 From: Jiacheng Huang Date: Thu, 16 Apr 2026 03:21:25 +0000 Subject: [PATCH 09/22] feat: add `Tensor` --- src/tensor.cc | 154 +++++++++++++++++++++++++++++++++++++++++++++++++ src/tensor.h | 157 ++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 311 insertions(+) create mode 100644 src/tensor.cc create mode 100644 src/tensor.h diff --git a/src/tensor.cc b/src/tensor.cc new file mode 100644 index 0000000..13f5521 --- /dev/null +++ b/src/tensor.cc @@ -0,0 +1,154 @@ +#include "tensor.h" + +#include +#include +#include + +#include "dispatcher.h" + +namespace infini::rt { + +static Tensor::Index GetEffectiveIndex(Tensor::Index index, Tensor::Size size) { + return index < 0 ? index + size : index; +} + +Tensor::Tensor(void* data, std::initializer_list shape, + const DataType& dtype, const Device& device, + std::initializer_list strides) + : Tensor{data, decltype(shape_){shape}, dtype, device, + decltype(strides_){strides}} {} + +Tensor Tensor::operator[](const Index& index) const { + return { + reinterpret_cast( + reinterpret_cast(data_) + + GetEffectiveIndex(index, shape_[0]) * strides_[0] * element_size()), + Shape{shape_.cbegin() + 1, shape_.cend()}, dtype_, device_, + Strides{strides_.cbegin() + 1, strides_.cend()}}; +} + +void*& Tensor::data() { return data_; } + +const void* Tensor::data() const { return data_; } + +const Tensor::Shape& Tensor::shape() const { return shape_; } + +const DataType& Tensor::dtype() const { return dtype_; } + +const Device& Tensor::device() const { return device_; } + +const Tensor::Strides& Tensor::strides() const { return strides_; } + +Tensor::Size Tensor::size(const Index& index) const { + return shape_[GetEffectiveIndex(index, shape_.size())]; +} + +Tensor::Stride Tensor::stride(const Index& index) const { + return strides_[GetEffectiveIndex(index, strides_.size())]; +} + +Tensor::Size Tensor::ndim() const { return shape_.size(); } + +Tensor::Size Tensor::element_size() const { return kDataTypeToSize.at(dtype_); } + +Tensor::Size Tensor::numel() const { + return std::accumulate(shape_.begin(), shape_.end(), + static_cast(1), + [](Tensor::Size a, Tensor::Size b) { return a * b; }); +} + +Tensor Tensor::T() const { + return {data_, + {shape_[1], shape_[0]}, + dtype_, + device_, + {strides_[1], strides_[0]}}; +} + +std::string Tensor::ToString() const { + return "tensor(" + ToStringHelper() + + ", dtype=" + std::string(kDataTypeToDesc.at(dtype_)) + ", device='" + + device_.ToString() + "')"; +} + +bool Tensor::HasBroadcastDim() const { + return std::any_of(shape_.begin(), shape_.end(), + [&, i = 0](const auto&) mutable { + return shape_[i] != 1 && strides_[i++] == 0; + }); +} + +bool Tensor::IsContiguous() const { + if (ndim() == 0) { + return true; + } + + if (!IsMergeable(0, ndim() - 1)) { + return false; + } + + return stride(ndim() - 1) == 1; +} + +const DataType Tensor::DefaultDataType() { return DataType::kFloat32; } + +Device Tensor::DefaultDevice() { return Device{Device::Type::kCpu}; } + +Tensor::Strides Tensor::DefaultStrides(const Shape& shape) { + if (shape.empty()) { + return {}; + } + + Strides strides(shape.size()); + + strides.back() = 1; + + for (auto i{shape.size() - 2}; i != -1; --i) { + strides[i] = strides[i + 1] * shape[i + 1]; + } + + return strides; +} + +std::string Tensor::ToStringHelper() const { + if (ndim() == 0) { + return DispatchFunc>( + dtype_, + [&](auto tag) { + using T = typename decltype(tag)::type; + return std::to_string(*static_cast(data_)); + }, + "Tensor::ToStringHelper()"); + } + + std::string result{"["}; + + for (auto i{Index{0}}; i < shape_[0]; ++i) { + result += operator[](i).ToStringHelper() + ", "; + } + + result.pop_back(); + result.back() = ']'; + + return result; +} + +bool Tensor::IsMergeable(Tensor::Size dim_start, Tensor::Size dim_end) const { + if (dim_start == dim_end) { + return true; + } + + for (Tensor::Size i = dim_start; i < dim_end; ++i) { + if (size(i) == 1 && stride(i) == 0) { + return false; + } + if (stride(i) != size(i + 1) * stride(i + 1)) { + return false; + } + } + + return true; +} + +} // namespace infini::rt diff --git a/src/tensor.h b/src/tensor.h new file mode 100644 index 0000000..c8a51c4 --- /dev/null +++ b/src/tensor.h @@ -0,0 +1,157 @@ +#ifndef INFINI_RT_TENSOR_H_ +#define INFINI_RT_TENSOR_H_ + +#include +#include +#include + +#include "data_type.h" +#include "device.h" +#include "hash.h" + +namespace infini::rt { + +class Tensor { + public: + using Size = std::size_t; + + using Stride = std::ptrdiff_t; + + using Index = Stride; + + using Shape = std::vector; + + using Strides = std::vector; + + template + Tensor(void* data, const Shape& shape) + : data_{data}, + shape_{shape}, + dtype_{DefaultDataType()}, + device_{DefaultDevice()}, + strides_{DefaultStrides(shape)} {} + + template + Tensor(void* data, const Shape& shape, const DataType& dtype) + : data_{data}, + shape_{shape}, + dtype_{dtype}, + device_{DefaultDevice()}, + strides_{DefaultStrides(shape)} {} + + template + Tensor(void* data, const Shape& shape, const Device& device) + : data_{data}, + shape_{shape}, + dtype_{DefaultDataType()}, + device_{device}, + strides_{DefaultStrides(shape)} {} + + template + Tensor(void* data, const Shape& shape, const DataType& dtype, + const Device& device) + : data_{data}, + shape_{shape}, + dtype_{dtype}, + device_{device}, + strides_{DefaultStrides(shape)} {} + + template + Tensor(void* data, const Shape& shape, const DataType& dtype, + const Device& device, const Strides& strides) + : data_{data}, + shape_{shape}, + dtype_{dtype}, + device_{device}, + strides_{strides} {} + + Tensor(void* data, std::initializer_list shape, const DataType& dtype, + const Device& device, std::initializer_list strides); + + Tensor operator[](const Index& index) const; + + void*& data(); + + const void* data() const; + + const DataType& dtype() const; + + const Device& device() const; + + const Shape& shape() const; + + const Strides& strides() const; + + Size size(const Index& index) const; + + Stride stride(const Index& index) const; + + Size ndim() const; + + Size element_size() const; + + Size numel() const; + + Tensor T() const; + + std::string ToString() const; + + bool HasBroadcastDim() const; + + bool IsContiguous() const; + + private: + static const DataType DefaultDataType(); + + static Device DefaultDevice(); + + static Strides DefaultStrides(const Shape& shape); + + std::string ToStringHelper() const; + + bool IsMergeable(Size dim_start, Size dim_end) const; + + void* data_{nullptr}; + + Shape shape_; + + const DataType dtype_; + + Device device_; + + Strides strides_; +}; + +} // namespace infini::rt + +template <> +struct std::hash { + std::size_t operator()(const infini::rt::Tensor& tensor) const { + std::size_t seed{0}; + + for (const auto& size : tensor.shape()) { + HashCombine(seed, size); + } + + HashCombine(seed, tensor.dtype()); + + HashCombine(seed, tensor.device()); + + for (const auto& stride : tensor.strides()) { + HashCombine(seed, stride); + } + + return seed; + } +}; + +template <> +struct std::equal_to { + bool operator()(const infini::rt::Tensor& a, + const infini::rt::Tensor& b) const { + return a.dtype() == b.dtype() && a.device() == b.device() && + a.shape() == b.shape() && a.strides() == b.strides(); + } +}; + +#endif From 45b7a977a6d6c3605bfa0c43d4e747fc455b2660 Mon Sep 17 00:00:00 2001 From: Jiacheng Huang Date: Thu, 16 Apr 2026 03:21:48 +0000 Subject: [PATCH 10/22] feat: add `Caster` template declaration --- src/caster.h | 13 +++++++++++++ 1 file changed, 13 insertions(+) create mode 100644 src/caster.h diff --git a/src/caster.h b/src/caster.h new file mode 100644 index 0000000..58612a9 --- /dev/null +++ b/src/caster.h @@ -0,0 +1,13 @@ +#ifndef INFINI_RT_CASTER_H_ +#define INFINI_RT_CASTER_H_ + +#include "device.h" + +namespace infini::rt { + +template +struct Caster; + +} // namespace infini::rt + +#endif From 81b4a8c04b2b010f24d6c98c9d57f97734a86454 Mon Sep 17 00:00:00 2001 From: Jiacheng Huang Date: Thu, 16 Apr 2026 03:23:42 +0000 Subject: [PATCH 11/22] feat: add CPU runtime --- src/cpu/caster_.h | 74 ++++++++++++++++++++++++++++++++++++++++++++ src/cpu/data_type_.h | 21 +++++++++++++ src/cpu/device_.h | 13 ++++++++ src/cpu/runtime_.h | 34 ++++++++++++++++++++ 4 files changed, 142 insertions(+) create mode 100644 src/cpu/caster_.h create mode 100644 src/cpu/data_type_.h create mode 100644 src/cpu/device_.h create mode 100644 src/cpu/runtime_.h diff --git a/src/cpu/caster_.h b/src/cpu/caster_.h new file mode 100644 index 0000000..15f081b --- /dev/null +++ b/src/cpu/caster_.h @@ -0,0 +1,74 @@ +#ifndef INFINI_RT_CPU_CASTER__H_ +#define INFINI_RT_CPU_CASTER__H_ + +#include + +#include "caster.h" +#include "cpu/data_type_.h" + +namespace infini::rt { + +template <> +struct Caster { + template + static Dst Cast(Src&& x) { + static_assert(!std::is_reference_v, + "`Cast` cannot return reference types"); + + using PureDst = std::remove_cv_t>; + using PureSrc = std::remove_cv_t>; + + if constexpr (std::is_same_v) { + return std::forward(x); + } + + constexpr bool src_is_custom = IsBFloat16 || + IsFP16; + constexpr bool dst_is_custom = IsBFloat16 || + IsFP16; + + if constexpr (!src_is_custom && !dst_is_custom) { + return static_cast(std::forward(x)); + } else { + return FromFloatHelper(ToFloatHelper(std::forward(x))); + } + } + + private: + template + struct HasToFloat : std::false_type {}; + + template + struct HasToFloat().ToFloat())>> + : std::true_type {}; + + template + struct HasFromFloat : std::false_type {}; + + template + struct HasFromFloat< + T, std::void_t()))>> + : std::true_type {}; + + template + static constexpr float ToFloatHelper(T&& x) { + if constexpr (HasToFloat::value) { + return std::forward(x).ToFloat(); + } else { + return static_cast(x); + } + } + + template + static constexpr PureDst FromFloatHelper(float f) { + if constexpr (HasFromFloat::value) { + return PureDst::FromFloat(f); + } else { + return static_cast(f); + } + } +}; + +} // namespace infini::rt + +#endif diff --git a/src/cpu/data_type_.h b/src/cpu/data_type_.h new file mode 100644 index 0000000..dd6c080 --- /dev/null +++ b/src/cpu/data_type_.h @@ -0,0 +1,21 @@ +#ifndef INFINI_RT_CPU_DATA_TYPE__H_ +#define INFINI_RT_CPU_DATA_TYPE__H_ + +#include "cpu/device_.h" +#include "data_type.h" + +namespace infini::rt { + +template <> +struct TypeMap { + using type = Float16; +}; + +template <> +struct TypeMap { + using type = BFloat16; +}; + +} // namespace infini::rt + +#endif diff --git a/src/cpu/device_.h b/src/cpu/device_.h new file mode 100644 index 0000000..78d4899 --- /dev/null +++ b/src/cpu/device_.h @@ -0,0 +1,13 @@ +#ifndef INFINI_RT_CPU_DEVICE__H_ +#define INFINI_RT_CPU_DEVICE__H_ + +#include "device.h" + +namespace infini::rt { + +template <> +struct DeviceEnabled : std::true_type {}; + +} // namespace infini::rt + +#endif diff --git a/src/cpu/runtime_.h b/src/cpu/runtime_.h new file mode 100644 index 0000000..29219b8 --- /dev/null +++ b/src/cpu/runtime_.h @@ -0,0 +1,34 @@ +#ifndef INFINI_RT_CPU_RUNTIME__H_ +#define INFINI_RT_CPU_RUNTIME__H_ + +#include +#include + +#include "runtime.h" + +namespace infini::rt { + +template <> +struct Runtime : RuntimeBase> { + static constexpr Device::Type kDeviceType = Device::Type::kCpu; + + static void Malloc(void** ptr, std::size_t size) { *ptr = std::malloc(size); } + + static void Free(void* ptr) { std::free(ptr); } + + static void Memcpy(void* dst, const void* src, std::size_t size, int) { + std::memcpy(dst, src, size); + } + + static constexpr auto Memset = std::memset; + + static constexpr int MemcpyHostToDevice = 0; + + static constexpr int MemcpyDeviceToHost = 1; +}; + +static_assert(Runtime::Validate()); + +} // namespace infini::rt + +#endif From 32207ec43ed29b0b42946effcab2f2e80e25220b Mon Sep 17 00:00:00 2001 From: Jiacheng Huang Date: Thu, 16 Apr 2026 03:24:40 +0000 Subject: [PATCH 12/22] feat: add NVIDIA runtime --- src/nvidia/data_type_.h | 30 ++++++++++++++++++++++++++ src/nvidia/device_.h | 13 +++++++++++ src/nvidia/device_property.h | 42 ++++++++++++++++++++++++++++++++++++ src/nvidia/runtime_.h | 42 ++++++++++++++++++++++++++++++++++++ src/nvidia/runtime_utils.h | 15 +++++++++++++ 5 files changed, 142 insertions(+) create mode 100644 src/nvidia/data_type_.h create mode 100644 src/nvidia/device_.h create mode 100644 src/nvidia/device_property.h create mode 100644 src/nvidia/runtime_.h create mode 100644 src/nvidia/runtime_utils.h diff --git a/src/nvidia/data_type_.h b/src/nvidia/data_type_.h new file mode 100644 index 0000000..e9afbcd --- /dev/null +++ b/src/nvidia/data_type_.h @@ -0,0 +1,30 @@ +#ifndef INFINI_RT_NVIDIA_DATA_TYPE__H_ +#define INFINI_RT_NVIDIA_DATA_TYPE__H_ + +// clang-format off +#include +#include +// clang-format on + +#include "data_type.h" +#include "nvidia/device_.h" + +namespace infini::rt { + +using cuda_bfloat16 = nv_bfloat16; + +using cuda_bfloat162 = nv_bfloat162; + +template <> +struct TypeMap { + using type = half; +}; + +template <> +struct TypeMap { + using type = __nv_bfloat16; +}; + +} // namespace infini::rt + +#endif diff --git a/src/nvidia/device_.h b/src/nvidia/device_.h new file mode 100644 index 0000000..b89a4e9 --- /dev/null +++ b/src/nvidia/device_.h @@ -0,0 +1,13 @@ +#ifndef INFINI_RT_NVIDIA_DEVICE__H_ +#define INFINI_RT_NVIDIA_DEVICE__H_ + +#include "device.h" + +namespace infini::rt { + +template <> +struct DeviceEnabled : std::true_type {}; + +} // namespace infini::rt + +#endif diff --git a/src/nvidia/device_property.h b/src/nvidia/device_property.h new file mode 100644 index 0000000..2557cb3 --- /dev/null +++ b/src/nvidia/device_property.h @@ -0,0 +1,42 @@ +#ifndef INFINI_RT_NVIDIA_DEVICE_PROPERTY_H_ +#define INFINI_RT_NVIDIA_DEVICE_PROPERTY_H_ + +#include + +#include +#include + +namespace infini::rt { + +class DevicePropertyCache { + public: + static const cudaDeviceProp& GetCurrentDeviceProps() { + int device_id = 0; + cudaGetDevice(&device_id); + return GetDeviceProps(device_id); + } + + static const cudaDeviceProp& GetDeviceProps(int device_id) { + static std::vector cache = []() { + int count = 0; + cudaGetDeviceCount(&count); + if (count == 0) return std::vector{}; + std::vector props(count); + for (int i = 0; i < count; ++i) { + cudaGetDeviceProperties(&props[i], i); + } + return props; + }(); + + assert(device_id >= 0 && device_id < static_cast(cache.size())); + return cache[device_id]; + } +}; + +inline int QueryMaxThreadsPerBlock() { + return DevicePropertyCache::GetCurrentDeviceProps().maxThreadsPerBlock; +} + +} // namespace infini::rt + +#endif diff --git a/src/nvidia/runtime_.h b/src/nvidia/runtime_.h new file mode 100644 index 0000000..f10cc6d --- /dev/null +++ b/src/nvidia/runtime_.h @@ -0,0 +1,42 @@ +#ifndef INFINI_RT_NVIDIA_RUNTIME__H_ +#define INFINI_RT_NVIDIA_RUNTIME__H_ + +#include + +// clang-format off +#include +// clang-format on + +#include "cuda/runtime.h" +#include "nvidia/device_.h" +#include "nvidia/runtime_utils.h" + +namespace infini::rt { + +template <> +struct Runtime + : CudaRuntime> { + using Stream = cudaStream_t; + + static constexpr Device::Type kDeviceType = Device::Type::kNvidia; + + static constexpr auto Malloc = [](auto&&... args) { + return cudaMalloc(std::forward(args)...); + }; + + static constexpr auto Memcpy = cudaMemcpy; + + static constexpr auto Free = cudaFree; + + static constexpr auto MemcpyHostToDevice = cudaMemcpyHostToDevice; + + static constexpr auto MemcpyDeviceToHost = cudaMemcpyDeviceToHost; + + static constexpr auto Memset = cudaMemset; +}; + +static_assert(Runtime::Validate()); + +} // namespace infini::rt + +#endif diff --git a/src/nvidia/runtime_utils.h b/src/nvidia/runtime_utils.h new file mode 100644 index 0000000..783f71a --- /dev/null +++ b/src/nvidia/runtime_utils.h @@ -0,0 +1,15 @@ +#ifndef INFINI_RT_NVIDIA_RUNTIME_UTILS_H_ +#define INFINI_RT_NVIDIA_RUNTIME_UTILS_H_ + +#include "cuda/runtime_utils.h" +#include "nvidia/device_property.h" + +namespace infini::rt { + +template <> +struct RuntimeUtils + : CudaRuntimeUtils {}; + +} // namespace infini::rt + +#endif From 5df088da1daee7309b4d0a9641b48f5ead37c342 Mon Sep 17 00:00:00 2001 From: Jiacheng Huang Date: Thu, 16 Apr 2026 03:25:13 +0000 Subject: [PATCH 13/22] feat: add Cambricon runtime --- src/cambricon/data_type_.h | 23 +++++++++++++++++++++++ src/cambricon/device_.h | 13 +++++++++++++ src/cambricon/runtime_.h | 35 +++++++++++++++++++++++++++++++++++ 3 files changed, 71 insertions(+) create mode 100644 src/cambricon/data_type_.h create mode 100644 src/cambricon/device_.h create mode 100644 src/cambricon/runtime_.h diff --git a/src/cambricon/data_type_.h b/src/cambricon/data_type_.h new file mode 100644 index 0000000..f1b8574 --- /dev/null +++ b/src/cambricon/data_type_.h @@ -0,0 +1,23 @@ +#ifndef INFINI_RT_CAMBRICON_DATA_TYPE__H_ +#define INFINI_RT_CAMBRICON_DATA_TYPE__H_ + +#include "bang_bf16.h" +#include "bang_fp16.h" +#include "cambricon/device_.h" +#include "data_type.h" + +namespace infini::rt { + +template <> +struct TypeMap { + using type = __half; +}; + +template <> +struct TypeMap { + using type = __bang_bfloat16; +}; + +} // namespace infini::rt + +#endif diff --git a/src/cambricon/device_.h b/src/cambricon/device_.h new file mode 100644 index 0000000..30bda29 --- /dev/null +++ b/src/cambricon/device_.h @@ -0,0 +1,13 @@ +#ifndef INFINI_RT_CAMBRICON_DEVICE__H_ +#define INFINI_RT_CAMBRICON_DEVICE__H_ + +#include "device.h" + +namespace infini::rt { + +template <> +struct DeviceEnabled : std::true_type {}; + +} // namespace infini::rt + +#endif diff --git a/src/cambricon/runtime_.h b/src/cambricon/runtime_.h new file mode 100644 index 0000000..b6ff200 --- /dev/null +++ b/src/cambricon/runtime_.h @@ -0,0 +1,35 @@ +#ifndef INFINI_RT_CAMBRICON_RUNTIME__H_ +#define INFINI_RT_CAMBRICON_RUNTIME__H_ + +#include + +#include "cambricon/device_.h" +#include "runtime.h" + +namespace infini::rt { + +template <> +struct Runtime + : DeviceRuntime> { + using Stream = cnrtQueue_t; + + static constexpr Device::Type kDeviceType = Device::Type::kCambricon; + + static constexpr auto Malloc = cnrtMalloc; + + static constexpr auto Free = cnrtFree; + + static constexpr auto Memcpy = cnrtMemcpy; + + static constexpr auto MemcpyHostToDevice = cnrtMemcpyHostToDev; + + static constexpr auto MemcpyDeviceToHost = cnrtMemcpyDevToHost; + + static constexpr auto Memset = cnrtMemset; +}; + +static_assert(Runtime::Validate()); + +} // namespace infini::rt + +#endif From c77a73f3e70d3d68c77162e063a933b33756f6c2 Mon Sep 17 00:00:00 2001 From: Jiacheng Huang Date: Thu, 16 Apr 2026 03:25:51 +0000 Subject: [PATCH 14/22] feat: add Ascend runtime --- src/ascend/data_type_.h | 61 +++++++++++++++++++++++++++++++++++++++++ src/ascend/device_.h | 13 +++++++++ src/ascend/runtime_.h | 44 +++++++++++++++++++++++++++++ 3 files changed, 118 insertions(+) create mode 100644 src/ascend/data_type_.h create mode 100644 src/ascend/device_.h create mode 100644 src/ascend/runtime_.h diff --git a/src/ascend/data_type_.h b/src/ascend/data_type_.h new file mode 100644 index 0000000..d83a393 --- /dev/null +++ b/src/ascend/data_type_.h @@ -0,0 +1,61 @@ +#ifndef INFINI_RT_ASCEND_DATA_TYPE__H_ +#define INFINI_RT_ASCEND_DATA_TYPE__H_ + +#include + +#include "acl/acl.h" +#include "ascend/device_.h" +#include "data_type.h" + +namespace infini::rt::ascend { + +inline aclDataType ToAclDtype(DataType dt) { + switch (dt) { + case DataType::kInt8: + return ACL_INT8; + case DataType::kInt16: + return ACL_INT16; + case DataType::kInt32: + return ACL_INT32; + case DataType::kInt64: + return ACL_INT64; + case DataType::kUInt8: + return ACL_UINT8; + case DataType::kUInt16: + return ACL_UINT16; + case DataType::kUInt32: + return ACL_UINT32; + case DataType::kUInt64: + return ACL_UINT64; + case DataType::kFloat16: + return ACL_FLOAT16; + case DataType::kBFloat16: + return ACL_BF16; + case DataType::kFloat32: + return ACL_FLOAT; + default: + assert(false && "Unsupported dtype for Ascend backend."); + return ACL_DT_UNDEFINED; + } +} + +// Returns true for integer (signed or unsigned) `DataType` values. +inline bool IsIntegerDtype(DataType dt) { + switch (dt) { + case DataType::kInt8: + case DataType::kInt16: + case DataType::kInt32: + case DataType::kInt64: + case DataType::kUInt8: + case DataType::kUInt16: + case DataType::kUInt32: + case DataType::kUInt64: + return true; + default: + return false; + } +} + +} // namespace infini::rt::ascend + +#endif diff --git a/src/ascend/device_.h b/src/ascend/device_.h new file mode 100644 index 0000000..ffdec5a --- /dev/null +++ b/src/ascend/device_.h @@ -0,0 +1,13 @@ +#ifndef INFINI_RT_ASCEND_DEVICE__H_ +#define INFINI_RT_ASCEND_DEVICE__H_ + +#include "device.h" + +namespace infini::rt { + +template <> +struct DeviceEnabled : std::true_type {}; + +} // namespace infini::rt + +#endif diff --git a/src/ascend/runtime_.h b/src/ascend/runtime_.h new file mode 100644 index 0000000..32595f6 --- /dev/null +++ b/src/ascend/runtime_.h @@ -0,0 +1,44 @@ +#ifndef INFINI_RT_ASCEND_RUNTIME__H_ +#define INFINI_RT_ASCEND_RUNTIME__H_ + +// clang-format off +#include "acl/acl.h" +// clang-format on + +#include "ascend/device_.h" +#include "runtime.h" + +namespace infini::rt { + +template <> +struct Runtime + : DeviceRuntime> { + using Stream = aclrtStream; + + static constexpr Device::Type kDeviceType = Device::Type::kAscend; + + static constexpr auto Malloc = [](void** ptr, size_t size) { + return aclrtMalloc(ptr, size, ACL_MEM_MALLOC_HUGE_FIRST); + }; + + static constexpr auto Free = aclrtFree; + + static constexpr auto Memcpy = [](void* dst, const void* src, size_t count, + aclrtMemcpyKind kind) { + return aclrtMemcpy(dst, count, src, count, kind); + }; + + static constexpr auto MemcpyHostToDevice = ACL_MEMCPY_HOST_TO_DEVICE; + + static constexpr auto MemcpyDeviceToHost = ACL_MEMCPY_DEVICE_TO_HOST; + + static constexpr auto Memset = [](void* ptr, int value, size_t count) { + return aclrtMemset(ptr, count, value, count); + }; +}; + +static_assert(Runtime::Validate()); + +} // namespace infini::rt + +#endif From 913b47f2be339ae72fcaddf75d7c3f84df91ed9b Mon Sep 17 00:00:00 2001 From: Jiacheng Huang Date: Thu, 16 Apr 2026 03:26:36 +0000 Subject: [PATCH 15/22] feat: add MetaX runtime --- src/metax/data_type_.h | 29 +++++++++++++++++++++++++++++ src/metax/device_.h | 13 +++++++++++++ src/metax/device_property.h | 11 +++++++++++ src/metax/runtime_.h | 36 ++++++++++++++++++++++++++++++++++++ src/metax/runtime_utils.h | 15 +++++++++++++++ 5 files changed, 104 insertions(+) create mode 100644 src/metax/data_type_.h create mode 100644 src/metax/device_.h create mode 100644 src/metax/device_property.h create mode 100644 src/metax/runtime_.h create mode 100644 src/metax/runtime_utils.h diff --git a/src/metax/data_type_.h b/src/metax/data_type_.h new file mode 100644 index 0000000..3c1b932 --- /dev/null +++ b/src/metax/data_type_.h @@ -0,0 +1,29 @@ +#ifndef INFINI_RT_METAX_DATA_TYPE__H_ +#define INFINI_RT_METAX_DATA_TYPE__H_ + +#include +#include +#include + +#include "data_type.h" +#include "metax/device_.h" + +namespace infini::rt { + +using cuda_bfloat16 = maca_bfloat16; + +using cuda_bfloat162 = maca_bfloat162; + +template <> +struct TypeMap { + using type = __half; +}; + +template <> +struct TypeMap { + using type = __maca_bfloat16; +}; + +} // namespace infini::rt + +#endif diff --git a/src/metax/device_.h b/src/metax/device_.h new file mode 100644 index 0000000..4fc8825 --- /dev/null +++ b/src/metax/device_.h @@ -0,0 +1,13 @@ +#ifndef INFINI_RT_METAX_DEVICE__H_ +#define INFINI_RT_METAX_DEVICE__H_ + +#include "device.h" + +namespace infini::rt { + +template <> +struct DeviceEnabled : std::true_type {}; + +} // namespace infini::rt + +#endif diff --git a/src/metax/device_property.h b/src/metax/device_property.h new file mode 100644 index 0000000..5ceaed6 --- /dev/null +++ b/src/metax/device_property.h @@ -0,0 +1,11 @@ +#ifndef INFINI_RT_METAX_DEVICE_PROPERTY_H_ +#define INFINI_RT_METAX_DEVICE_PROPERTY_H_ + +namespace infini::rt { + +// TODO: Add MCR device properties query for Metax. +inline int QueryMaxThreadsPerBlock() { return 256; } + +} // namespace infini::rt + +#endif diff --git a/src/metax/runtime_.h b/src/metax/runtime_.h new file mode 100644 index 0000000..885d8f1 --- /dev/null +++ b/src/metax/runtime_.h @@ -0,0 +1,36 @@ +#ifndef INFINI_RT_METAX_RUNTIME__H_ +#define INFINI_RT_METAX_RUNTIME__H_ + +#include + +#include "cuda/runtime.h" +#include "metax/device_.h" +#include "metax/runtime_utils.h" + +namespace infini::rt { + +template <> +struct Runtime + : CudaRuntime> { + using Stream = mcStream_t; + + static constexpr Device::Type kDeviceType = Device::Type::kMetax; + + static constexpr auto Malloc = mcMalloc; + + static constexpr auto Memcpy = mcMemcpy; + + static constexpr auto Free = mcFree; + + static constexpr auto MemcpyHostToDevice = mcMemcpyHostToDevice; + + static constexpr auto MemcpyDeviceToHost = mcMemcpyDeviceToHost; + + static constexpr auto Memset = mcMemset; +}; + +static_assert(Runtime::Validate()); + +} // namespace infini::rt + +#endif diff --git a/src/metax/runtime_utils.h b/src/metax/runtime_utils.h new file mode 100644 index 0000000..2527124 --- /dev/null +++ b/src/metax/runtime_utils.h @@ -0,0 +1,15 @@ +#ifndef INFINI_RT_METAX_RUNTIME_UTILS_H_ +#define INFINI_RT_METAX_RUNTIME_UTILS_H_ + +#include "cuda/runtime_utils.h" +#include "metax/device_property.h" + +namespace infini::rt { + +template <> +struct RuntimeUtils + : CudaRuntimeUtils {}; + +} // namespace infini::rt + +#endif From a140f75c573e3a463c741b6e55b024553f41a326 Mon Sep 17 00:00:00 2001 From: Jiacheng Huang Date: Thu, 16 Apr 2026 03:27:21 +0000 Subject: [PATCH 16/22] feat: add Moore runtime --- src/moore/data_type_.h | 28 +++++++++++++++++++++++ src/moore/device_.h | 13 +++++++++++ src/moore/device_property.h | 18 +++++++++++++++ src/moore/runtime_.h | 44 +++++++++++++++++++++++++++++++++++++ src/moore/runtime_utils.h | 15 +++++++++++++ 5 files changed, 118 insertions(+) create mode 100644 src/moore/data_type_.h create mode 100644 src/moore/device_.h create mode 100644 src/moore/device_property.h create mode 100644 src/moore/runtime_.h create mode 100644 src/moore/runtime_utils.h diff --git a/src/moore/data_type_.h b/src/moore/data_type_.h new file mode 100644 index 0000000..0fdeab9 --- /dev/null +++ b/src/moore/data_type_.h @@ -0,0 +1,28 @@ +#ifndef INFINI_RT_MOORE_DATA_TYPE__H_ +#define INFINI_RT_MOORE_DATA_TYPE__H_ + +#include +#include + +#include "data_type.h" +#include "moore/device_.h" + +namespace infini::rt { + +using cuda_bfloat16 = __mt_bfloat16; + +using cuda_bfloat162 = __mt_bfloat162; + +template <> +struct TypeMap { + using type = half; +}; + +template <> +struct TypeMap { + using type = __mt_bfloat16; +}; + +} // namespace infini::rt + +#endif diff --git a/src/moore/device_.h b/src/moore/device_.h new file mode 100644 index 0000000..2bb52a9 --- /dev/null +++ b/src/moore/device_.h @@ -0,0 +1,13 @@ +#ifndef INFINI_RT_MOORE_DEVICE__H_ +#define INFINI_RT_MOORE_DEVICE__H_ + +#include "device.h" + +namespace infini::rt { + +template <> +struct DeviceEnabled : std::true_type {}; + +} // namespace infini::rt + +#endif diff --git a/src/moore/device_property.h b/src/moore/device_property.h new file mode 100644 index 0000000..c9eac81 --- /dev/null +++ b/src/moore/device_property.h @@ -0,0 +1,18 @@ +#ifndef INFINI_RT_MOORE_DEVICE_PROPERTY_H_ +#define INFINI_RT_MOORE_DEVICE_PROPERTY_H_ + +#include + +namespace infini::rt { + +inline int QueryMaxThreadsPerBlock() { + int device = 0; + musaGetDevice(&device); + musaDeviceProp prop; + musaGetDeviceProperties(&prop, device); + return prop.maxThreadsPerBlock; +} + +} // namespace infini::rt + +#endif diff --git a/src/moore/runtime_.h b/src/moore/runtime_.h new file mode 100644 index 0000000..076f436 --- /dev/null +++ b/src/moore/runtime_.h @@ -0,0 +1,44 @@ +#ifndef INFINI_RT_MOORE_RUNTIME__H_ +#define INFINI_RT_MOORE_RUNTIME__H_ + +#include + +#include + +#include "cuda/runtime.h" +#include "moore/device_.h" +#include "moore/runtime_utils.h" + +namespace infini::rt { + +template <> +struct Runtime + : CudaRuntime> { + using Stream = musaStream_t; + + static constexpr Device::Type kDeviceType = Device::Type::kMoore; + + static constexpr auto Malloc = [](auto&&... args) { + return musaMalloc(std::forward(args)...); + }; + + static constexpr auto Memcpy = [](auto&&... args) { + return musaMemcpy(std::forward(args)...); + }; + + static constexpr auto Free = [](auto&&... args) { + return musaFree(std::forward(args)...); + }; + + static constexpr auto MemcpyHostToDevice = musaMemcpyHostToDevice; + + static constexpr auto MemcpyDeviceToHost = musaMemcpyDeviceToHost; + + static constexpr auto Memset = musaMemset; +}; + +static_assert(Runtime::Validate()); + +} // namespace infini::rt + +#endif diff --git a/src/moore/runtime_utils.h b/src/moore/runtime_utils.h new file mode 100644 index 0000000..053146b --- /dev/null +++ b/src/moore/runtime_utils.h @@ -0,0 +1,15 @@ +#ifndef INFINI_RT_MOORE_RUNTIME_UTILS_H_ +#define INFINI_RT_MOORE_RUNTIME_UTILS_H_ + +#include "cuda/runtime_utils.h" +#include "moore/device_property.h" + +namespace infini::rt { + +template <> +struct RuntimeUtils + : CudaRuntimeUtils {}; + +} // namespace infini::rt + +#endif From 79d92f30056c49139eb48c47edf69db31f32a5f7 Mon Sep 17 00:00:00 2001 From: Jiacheng Huang Date: Thu, 16 Apr 2026 03:28:07 +0000 Subject: [PATCH 17/22] feat: add Iluvatar runtime --- src/iluvatar/data_type_.h | 30 ++++++++++++++++++++++++ src/iluvatar/device_.h | 13 +++++++++++ src/iluvatar/device_property.h | 42 ++++++++++++++++++++++++++++++++++ src/iluvatar/runtime_.h | 42 ++++++++++++++++++++++++++++++++++ src/iluvatar/runtime_utils.h | 15 ++++++++++++ 5 files changed, 142 insertions(+) create mode 100644 src/iluvatar/data_type_.h create mode 100644 src/iluvatar/device_.h create mode 100644 src/iluvatar/device_property.h create mode 100644 src/iluvatar/runtime_.h create mode 100644 src/iluvatar/runtime_utils.h diff --git a/src/iluvatar/data_type_.h b/src/iluvatar/data_type_.h new file mode 100644 index 0000000..4511394 --- /dev/null +++ b/src/iluvatar/data_type_.h @@ -0,0 +1,30 @@ +#ifndef INFINI_RT_ILUVATAR_DATA_TYPE__H_ +#define INFINI_RT_ILUVATAR_DATA_TYPE__H_ + +// clang-format off +#include +#include +// clang-format on + +#include "data_type.h" +#include "iluvatar/device_.h" + +namespace infini::rt { + +using cuda_bfloat16 = nv_bfloat16; + +using cuda_bfloat162 = nv_bfloat162; + +template <> +struct TypeMap { + using type = half; +}; + +template <> +struct TypeMap { + using type = __nv_bfloat16; +}; + +} // namespace infini::rt + +#endif diff --git a/src/iluvatar/device_.h b/src/iluvatar/device_.h new file mode 100644 index 0000000..dc0bb88 --- /dev/null +++ b/src/iluvatar/device_.h @@ -0,0 +1,13 @@ +#ifndef INFINI_RT_ILUVATAR_DEVICE__H_ +#define INFINI_RT_ILUVATAR_DEVICE__H_ + +#include "device.h" + +namespace infini::rt { + +template <> +struct DeviceEnabled : std::true_type {}; + +} // namespace infini::rt + +#endif diff --git a/src/iluvatar/device_property.h b/src/iluvatar/device_property.h new file mode 100644 index 0000000..1605603 --- /dev/null +++ b/src/iluvatar/device_property.h @@ -0,0 +1,42 @@ +#ifndef INFINI_RT_ILUVATAR_DEVICE_PROPERTY_H_ +#define INFINI_RT_ILUVATAR_DEVICE_PROPERTY_H_ + +#include + +#include +#include + +namespace infini::rt { + +class DevicePropertyCache { + public: + static const cudaDeviceProp& GetCurrentDeviceProps() { + int device_id = 0; + cudaGetDevice(&device_id); + return GetDeviceProps(device_id); + } + + static const cudaDeviceProp& GetDeviceProps(int device_id) { + static std::vector cache = []() { + int count = 0; + cudaGetDeviceCount(&count); + if (count == 0) return std::vector{}; + std::vector props(count); + for (int i = 0; i < count; ++i) { + cudaGetDeviceProperties(&props[i], i); + } + return props; + }(); + + assert(device_id >= 0 && device_id < static_cast(cache.size())); + return cache[device_id]; + } +}; + +inline int QueryMaxThreadsPerBlock() { + return DevicePropertyCache::GetCurrentDeviceProps().maxThreadsPerBlock; +} + +} // namespace infini::rt + +#endif diff --git a/src/iluvatar/runtime_.h b/src/iluvatar/runtime_.h new file mode 100644 index 0000000..db2eeaf --- /dev/null +++ b/src/iluvatar/runtime_.h @@ -0,0 +1,42 @@ +#ifndef INFINI_RT_ILUVATAR_RUNTIME__H_ +#define INFINI_RT_ILUVATAR_RUNTIME__H_ + +#include + +// clang-format off +#include +// clang-format on + +#include "cuda/runtime.h" +#include "iluvatar/device_.h" +#include "iluvatar/runtime_utils.h" + +namespace infini::rt { + +template <> +struct Runtime + : CudaRuntime> { + using Stream = cudaStream_t; + + static constexpr Device::Type kDeviceType = Device::Type::kIluvatar; + + static constexpr auto Malloc = [](auto&&... args) { + return cudaMalloc(std::forward(args)...); + }; + + static constexpr auto Memcpy = cudaMemcpy; + + static constexpr auto Free = cudaFree; + + static constexpr auto MemcpyHostToDevice = cudaMemcpyHostToDevice; + + static constexpr auto MemcpyDeviceToHost = cudaMemcpyDeviceToHost; + + static constexpr auto Memset = cudaMemset; +}; + +static_assert(Runtime::Validate()); + +} // namespace infini::rt + +#endif diff --git a/src/iluvatar/runtime_utils.h b/src/iluvatar/runtime_utils.h new file mode 100644 index 0000000..fcda918 --- /dev/null +++ b/src/iluvatar/runtime_utils.h @@ -0,0 +1,15 @@ +#ifndef INFINI_RT_ILUVATAR_RUNTIME_UTILS_H_ +#define INFINI_RT_ILUVATAR_RUNTIME_UTILS_H_ + +#include "cuda/runtime_utils.h" +#include "iluvatar/device_property.h" + +namespace infini::rt { + +template <> +struct RuntimeUtils + : CudaRuntimeUtils {}; + +} // namespace infini::rt + +#endif From 49c90af91fa6de78be4f32e015398fea349f0fab Mon Sep 17 00:00:00 2001 From: Jiacheng Huang Date: Thu, 16 Apr 2026 04:37:53 +0000 Subject: [PATCH 18/22] feat: add CMake build system --- CMakeLists.txt | 196 +++++++++++++++++++++++++++++++++++++++++++++ src/CMakeLists.txt | 96 ++++++++++++++++++++++ 2 files changed, 292 insertions(+) create mode 100644 CMakeLists.txt create mode 100644 src/CMakeLists.txt diff --git a/CMakeLists.txt b/CMakeLists.txt new file mode 100644 index 0000000..445834b --- /dev/null +++ b/CMakeLists.txt @@ -0,0 +1,196 @@ +cmake_minimum_required(VERSION 3.18) +project(InfiniRT LANGUAGES CXX) + +set(CMAKE_CXX_STANDARD 17) +set(CMAKE_CXX_STANDARD_REQUIRED ON) + +# Options for backends. +option(WITH_CPU "Enable CPU backend" OFF) +option(WITH_NVIDIA "Enable CUDA backend" OFF) +option(WITH_ILUVATAR "Enable Iluvatar GPU backend" OFF) +option(WITH_METAX "Enable MetaX backend" OFF) +option(WITH_CAMBRICON "Enable Cambricon backend" OFF) +option(WITH_MOORE "Enable Moore backend" OFF) +option(WITH_ASCEND "Enable Ascend backend" OFF) + +option(AUTO_DETECT_DEVICES "Automatically detect available devices" OFF) + +if(AUTO_DETECT_DEVICES) + message(STATUS "Auto-detecting available devices...") + + set(WITH_CPU ON) + + file(GLOB NVIDIA_DEV_FILES "/dev/nvidia*") + + if(NVIDIA_DEV_FILES) + set(WITH_NVIDIA ON) + message(STATUS "Auto-detected NVIDIA environment.") + endif() + + file(GLOB ILUVATAR_DEV_FILES "/dev/iluvatar*") + + if(ILUVATAR_DEV_FILES) + set(WITH_ILUVATAR ON) + message(STATUS "Auto-detected Iluvatar environment.") + endif() + + if(DEFINED ENV{MACA_PATH}) + set(WITH_METAX ON) + message(STATUS "Auto-detected MetaX environment from MACA_PATH") + else() + execute_process( + COMMAND sh -c "grep -h 9999 /sys/bus/pci/devices/*/vendor 2>/dev/null" + OUTPUT_VARIABLE _pci_vendor_output + OUTPUT_STRIP_TRAILING_WHITESPACE + ) + + string(FIND "${_pci_vendor_output}" "9999" _found_pos) + + if(_found_pos GREATER -1) + set(WITH_METAX ON) + message(STATUS "Detected MetaX GPU from PCI vendor ID 0x9999") + else() + set(WITH_METAX OFF) + message(STATUS "No MetaX GPU detected") + endif() + endif() + + if(DEFINED ENV{NEUWARE_HOME}) + set(WITH_CAMBRICON ON) + message(STATUS "Auto-detected Cambricon environment.") + endif() + + if(DEFINED ENV{MUSA_ROOT} OR DEFINED ENV{MUSA_HOME} OR DEFINED ENV{MUSA_PATH}) + set(WITH_MOORE ON) + set(WITH_MOORE ON CACHE BOOL "Enable Moore backend" FORCE) + message(STATUS "Auto-detected Moore environment.") + else() + set(WITH_MOORE OFF) + set(WITH_MOORE OFF CACHE BOOL "Enable Moore backend" FORCE) + endif() + + if(DEFINED ENV{ASCEND_HOME_PATH} OR EXISTS "/dev/davinci0") + set(WITH_ASCEND ON) + message(STATUS "Auto-detected Ascend environment.") + endif() +endif() + +include_directories(${CMAKE_CURRENT_SOURCE_DIR}/src) + +# Only one CUDA-like GPU backend can be enabled at a time. +set(_gpu_backend_count 0) +foreach(_gpu_backend WITH_NVIDIA WITH_ILUVATAR WITH_METAX WITH_MOORE WITH_ASCEND) + if(${_gpu_backend}) + math(EXPR _gpu_backend_count "${_gpu_backend_count} + 1") + endif() +endforeach() + +if(_gpu_backend_count GREATER 1) + message(FATAL_ERROR "`WITH_NVIDIA`, `WITH_ILUVATAR`, `WITH_METAX`, `WITH_MOORE`, and `WITH_ASCEND` are mutually exclusive. Build one GPU backend at a time.") +endif() + +if(WITH_NVIDIA) + add_compile_definitions(WITH_NVIDIA=1) + enable_language(CUDA) + find_package(CUDAToolkit REQUIRED) +endif() + +# Iluvatar: CUDA-compatible device, uses `clang++` with `-x ivcore` (not `nvcc`). +if(WITH_ILUVATAR) + add_compile_definitions(WITH_ILUVATAR=1) + set(ILUVATAR_ARCH "ivcore20" CACHE STRING "Iluvatar GPU architecture") + find_program(CLANGXX NAMES clang++) + if(CLANGXX) + set(CMAKE_CUDA_COMPILER "${CLANGXX}" CACHE STRING "Iluvatar CUDA compiler (clang++)") + else() + set(CMAKE_CUDA_COMPILER "clang++" CACHE STRING "Iluvatar CUDA compiler (clang++)") + endif() + set(CMAKE_CUDA_FLAGS "-x ivcore -std=c++17 --cuda-gpu-arch=${ILUVATAR_ARCH} -fPIC -Wno-error=unused-variable -Wno-error=unused-private-field -Wno-unused-variable" CACHE STRING "Iluvatar CUDA flags") + set(CMAKE_CUDA_SEPARABLE_COMPILATION OFF CACHE BOOL "Disable RDC for Iluvatar") + message(STATUS "Iluvatar: CUDA compiler ${CMAKE_CUDA_COMPILER}, arch ${ILUVATAR_ARCH}") + enable_language(CUDA) + find_package(CUDAToolkit REQUIRED) +endif() + +if(WITH_METAX) + add_compile_definitions(WITH_METAX=1) + + # Normally can be found at: `/opt/maca/`. + set(MACA_PATH $ENV{MACA_PATH}) + set(CMAKE_C_COMPILER ${CMAKE_CURRENT_SOURCE_DIR}/scripts/mxcc_wrapper.sh) + set(CMAKE_CXX_COMPILER ${CMAKE_CURRENT_SOURCE_DIR}/scripts/mxcc_wrapper.sh) + + include_directories("${MACA_PATH}/include") + link_directories("${MACA_PATH}/lib") + + # Libraries: mcruntime / mcdnn / mcblas. + find_library(MACA_RUNTIME_LIB NAMES mcruntime HINTS "${MACA_PATH}/lib" REQUIRED) +endif() + +if(WITH_MOORE) + add_compile_definitions(WITH_MOORE=1) + + set(MUSA_ROOT "") + foreach(_musa_env MUSA_ROOT MUSA_HOME MUSA_PATH) + if(NOT MUSA_ROOT AND DEFINED ENV{${_musa_env}} AND NOT "$ENV{${_musa_env}}" STREQUAL "") + set(MUSA_ROOT "$ENV{${_musa_env}}") + endif() + endforeach() + + if(NOT MUSA_ROOT AND EXISTS "/usr/local/musa") + set(MUSA_ROOT "/usr/local/musa") + endif() + + if(NOT MUSA_ROOT) + message(FATAL_ERROR "`WITH_MOORE` is `ON` but `MUSA_ROOT`/`MUSA_HOME`/`MUSA_PATH` is not set and `/usr/local/musa` was not found.") + endif() + + if(NOT EXISTS "${MUSA_ROOT}/bin/mcc") + message(FATAL_ERROR "Could not find `mcc` under `${MUSA_ROOT}/bin`.") + endif() + + message(STATUS "Using Moore from `${MUSA_ROOT}`.") + + set(CMAKE_C_COMPILER ${CMAKE_CURRENT_SOURCE_DIR}/scripts/mcc_wrapper.sh) + set(CMAKE_CXX_COMPILER ${CMAKE_CURRENT_SOURCE_DIR}/scripts/mcc_wrapper.sh) + + include_directories("${MUSA_ROOT}/include") + link_directories("${MUSA_ROOT}/lib") + + find_library(MUSA_LIB NAMES musa HINTS "${MUSA_ROOT}/lib" REQUIRED) + find_library(MUSART_LIB NAMES musart HINTS "${MUSA_ROOT}/lib" REQUIRED) +endif() + +if(WITH_CAMBRICON) + add_compile_definitions(WITH_CAMBRICON=1) + set(NEUWARE_HOME $ENV{NEUWARE_HOME}) + + include_directories("${NEUWARE_HOME}/include") + link_directories("${NEUWARE_HOME}/lib") + link_directories("${NEUWARE_HOME}/lib64") + + # Libraries: `cnrt`. + find_library(CAMBRICON_RUNTIME_LIB NAMES cnrt HINTS "${NEUWARE_HOME}/lib64" REQUIRED) +endif() + +if(WITH_ASCEND) + add_compile_definitions(WITH_ASCEND=1) + if(NOT DEFINED ASCEND_HOME) + if(DEFINED ENV{ASCEND_HOME_PATH} AND NOT "$ENV{ASCEND_HOME_PATH}" STREQUAL "") + set(ASCEND_HOME "$ENV{ASCEND_HOME_PATH}" CACHE PATH "Ascend toolkit root") + else() + set(ASCEND_HOME "/usr/local/Ascend/ascend-toolkit/latest" CACHE PATH "Ascend toolkit root") + endif() + endif() + if(NOT EXISTS "${ASCEND_HOME}") + message(FATAL_ERROR "`WITH_ASCEND` is ON but `${ASCEND_HOME}` was not found. Set ASCEND_HOME_PATH.") + endif() + message(STATUS "Using Ascend from `${ASCEND_HOME}`.") +endif() + +# If all other platforms are not enabled, CPU is enabled by default. +if(NOT WITH_NVIDIA AND NOT WITH_ILUVATAR AND NOT WITH_METAX AND NOT WITH_MOORE AND NOT WITH_CAMBRICON AND NOT WITH_ASCEND) + add_compile_definitions(WITH_CPU=1) +endif() + +add_subdirectory(src) diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt new file mode 100644 index 0000000..fc540ec --- /dev/null +++ b/src/CMakeLists.txt @@ -0,0 +1,96 @@ +add_library(infinirt SHARED) + +file(GLOB BASE_SRCS CONFIGURE_DEPENDS "*.cc") +target_sources(infinirt PRIVATE ${BASE_SRCS}) + +if(WITH_CPU) + target_compile_definitions(infinirt PUBLIC WITH_CPU=1) + + find_package(OpenMP REQUIRED) + target_link_libraries(infinirt PRIVATE OpenMP::OpenMP_CXX) +endif() + +if(WITH_NVIDIA) + enable_language(CUDA) + + target_compile_definitions(infinirt PUBLIC WITH_NVIDIA=1) + + find_package(CUDAToolkit REQUIRED) + target_link_libraries(infinirt PUBLIC CUDA::cudart) + + set_target_properties(infinirt PROPERTIES + CUDA_STANDARD 17 + CUDA_STANDARD_REQUIRED ON + ) +endif() + +if(WITH_ILUVATAR) + enable_language(CUDA) + + target_compile_definitions(infinirt PUBLIC WITH_ILUVATAR=1) + + find_package(CUDAToolkit REQUIRED) + target_link_libraries(infinirt PUBLIC CUDA::cudart) + + set_target_properties(infinirt PROPERTIES + CUDA_STANDARD 17 + CUDA_STANDARD_REQUIRED ON + ) +endif() + +if(WITH_METAX) + target_compile_definitions(infinirt PRIVATE WITH_METAX=1) + + target_include_directories(infinirt PUBLIC "${MACA_PATH}/include") + target_link_libraries(infinirt PUBLIC ${MACA_RUNTIME_LIB}) +endif() + +if(WITH_MOORE) + target_compile_definitions(infinirt PRIVATE WITH_MOORE=1) + + target_include_directories(infinirt PUBLIC "${MUSA_ROOT}/include") + target_link_libraries(infinirt PUBLIC ${MUSA_LIB} ${MUSART_LIB}) +endif() + +if(WITH_CAMBRICON) + target_compile_definitions(infinirt PRIVATE WITH_CAMBRICON=1) + + target_include_directories(infinirt PUBLIC "${NEUWARE_HOME}/include") + target_link_libraries(infinirt PUBLIC ${CAMBRICON_RUNTIME_LIB}) +endif() + +if(WITH_ASCEND) + # ASCEND_HOME is set by the top-level CMakeLists.txt. + target_compile_definitions(infinirt PUBLIC WITH_ASCEND=1) + + # Resolve the driver lib dir two levels above the toolkit root. + get_filename_component(ASCEND_ROOT "${ASCEND_HOME}/../.." ABSOLUTE) + + # Prefer the real driver HAL; fall back to the toolkit stub for build-only + # environments (e.g., Docker CI images without hardware drivers installed). + # CANN <= 8.0: stub at runtime/lib64/stub/; CANN >= 8.5: devlib/-linux/devlib/. + set(ASCEND_HAL_REAL "${ASCEND_ROOT}/driver/lib64/driver/libascend_hal.so") + set(ASCEND_HAL_STUB "${ASCEND_HOME}/runtime/lib64/stub/libascend_hal.so") + set(ASCEND_HAL_DEVLIB "${ASCEND_HOME}/${CMAKE_SYSTEM_PROCESSOR}-linux/devlib/libascend_hal.so") + if(EXISTS "${ASCEND_HAL_REAL}") + set(ASCEND_HAL_LIB "${ASCEND_HAL_REAL}") + elseif(EXISTS "${ASCEND_HAL_STUB}") + set(ASCEND_HAL_LIB "${ASCEND_HAL_STUB}") + message(STATUS "ascend_hal: driver not found, using stub for linking") + elseif(EXISTS "${ASCEND_HAL_DEVLIB}") + set(ASCEND_HAL_LIB "${ASCEND_HAL_DEVLIB}") + message(STATUS "ascend_hal: driver not found, using devlib for linking") + else() + message(FATAL_ERROR "libascend_hal.so not found (tried ${ASCEND_HAL_REAL}, ${ASCEND_HAL_STUB}, and ${ASCEND_HAL_DEVLIB})") + endif() + + target_include_directories(infinirt PUBLIC + "${ASCEND_HOME}/include" + "${ASCEND_HOME}/include/aclnn" + "${ASCEND_HOME}/include/aclnnop") + target_link_libraries(infinirt PUBLIC + "${ASCEND_HOME}/lib64/libascendcl.so" + "${ASCEND_HAL_LIB}") +endif() + +target_include_directories(infinirt PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}) From 9a88bf6ab577a1b1308fba8bcd346b4b0e33f2a3 Mon Sep 17 00:00:00 2001 From: Jiacheng Huang Date: Thu, 16 Apr 2026 05:00:52 +0000 Subject: [PATCH 19/22] chore: add `build/` to `.gitignore` --- .gitignore | 3 +++ 1 file changed, 3 insertions(+) diff --git a/.gitignore b/.gitignore index d4fb281..37e0bb5 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,6 @@ +# Generated files +build/ + # Prerequisites *.d From 3ebd561411fd32ca598bcf6d482e3bccb99fc673 Mon Sep 17 00:00:00 2001 From: Jiacheng Huang Date: Thu, 16 Apr 2026 08:11:51 +0000 Subject: [PATCH 20/22] refactor: remove project-specific utilities MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Remove `RuntimeUtils`, `CudaRuntimeUtils`, `DevicePropertyCache`, `DispatchFunc`, `Caster`, and `generic_utils` — these are operator-level utilities, not pure runtime APIs. Rewrite `Tensor::ToStringHelper()` to use a switch instead of `DispatchFunc`. --- src/caster.h | 13 -- src/common/generic_utils.h | 26 --- src/cpu/caster_.h | 74 ------- src/cuda/runtime_utils.h | 25 --- src/dispatcher.h | 341 --------------------------------- src/iluvatar/device_property.h | 42 ---- src/iluvatar/runtime_.h | 1 - src/iluvatar/runtime_utils.h | 15 -- src/metax/device_property.h | 11 -- src/metax/runtime_.h | 1 - src/metax/runtime_utils.h | 15 -- src/moore/device_property.h | 18 -- src/moore/runtime_.h | 1 - src/moore/runtime_utils.h | 15 -- src/nvidia/device_property.h | 42 ---- src/nvidia/runtime_.h | 1 - src/nvidia/runtime_utils.h | 15 -- src/tensor.cc | 38 +++- 18 files changed, 28 insertions(+), 666 deletions(-) delete mode 100644 src/caster.h delete mode 100644 src/common/generic_utils.h delete mode 100644 src/cpu/caster_.h delete mode 100644 src/cuda/runtime_utils.h delete mode 100644 src/dispatcher.h delete mode 100644 src/iluvatar/device_property.h delete mode 100644 src/iluvatar/runtime_utils.h delete mode 100644 src/metax/device_property.h delete mode 100644 src/metax/runtime_utils.h delete mode 100644 src/moore/device_property.h delete mode 100644 src/moore/runtime_utils.h delete mode 100644 src/nvidia/device_property.h delete mode 100644 src/nvidia/runtime_utils.h diff --git a/src/caster.h b/src/caster.h deleted file mode 100644 index 58612a9..0000000 --- a/src/caster.h +++ /dev/null @@ -1,13 +0,0 @@ -#ifndef INFINI_RT_CASTER_H_ -#define INFINI_RT_CASTER_H_ - -#include "device.h" - -namespace infini::rt { - -template -struct Caster; - -} // namespace infini::rt - -#endif diff --git a/src/common/generic_utils.h b/src/common/generic_utils.h deleted file mode 100644 index bda6cde..0000000 --- a/src/common/generic_utils.h +++ /dev/null @@ -1,26 +0,0 @@ -#ifndef INFINI_RT_COMMON_GENERIC_UTILS_H_ -#define INFINI_RT_COMMON_GENERIC_UTILS_H_ - -#include - -namespace infini::rt::utils { - -std::size_t IndexToOffset(std::size_t flat_index, std::size_t ndim, - const std::size_t* shape, - const std::ptrdiff_t* strides) { - std::size_t res = 0; - for (std::size_t i = ndim; i-- > 0;) { - res += (flat_index % shape[i]) * strides[i]; - flat_index /= shape[i]; - } - return res; -} - -template -constexpr auto CeilDiv(const X& x, const Y& y) { - return (x + y - 1) / y; -} - -} // namespace infini::rt::utils - -#endif diff --git a/src/cpu/caster_.h b/src/cpu/caster_.h deleted file mode 100644 index 15f081b..0000000 --- a/src/cpu/caster_.h +++ /dev/null @@ -1,74 +0,0 @@ -#ifndef INFINI_RT_CPU_CASTER__H_ -#define INFINI_RT_CPU_CASTER__H_ - -#include - -#include "caster.h" -#include "cpu/data_type_.h" - -namespace infini::rt { - -template <> -struct Caster { - template - static Dst Cast(Src&& x) { - static_assert(!std::is_reference_v, - "`Cast` cannot return reference types"); - - using PureDst = std::remove_cv_t>; - using PureSrc = std::remove_cv_t>; - - if constexpr (std::is_same_v) { - return std::forward(x); - } - - constexpr bool src_is_custom = IsBFloat16 || - IsFP16; - constexpr bool dst_is_custom = IsBFloat16 || - IsFP16; - - if constexpr (!src_is_custom && !dst_is_custom) { - return static_cast(std::forward(x)); - } else { - return FromFloatHelper(ToFloatHelper(std::forward(x))); - } - } - - private: - template - struct HasToFloat : std::false_type {}; - - template - struct HasToFloat().ToFloat())>> - : std::true_type {}; - - template - struct HasFromFloat : std::false_type {}; - - template - struct HasFromFloat< - T, std::void_t()))>> - : std::true_type {}; - - template - static constexpr float ToFloatHelper(T&& x) { - if constexpr (HasToFloat::value) { - return std::forward(x).ToFloat(); - } else { - return static_cast(x); - } - } - - template - static constexpr PureDst FromFloatHelper(float f) { - if constexpr (HasFromFloat::value) { - return PureDst::FromFloat(f); - } else { - return static_cast(f); - } - } -}; - -} // namespace infini::rt - -#endif diff --git a/src/cuda/runtime_utils.h b/src/cuda/runtime_utils.h deleted file mode 100644 index f85eace..0000000 --- a/src/cuda/runtime_utils.h +++ /dev/null @@ -1,25 +0,0 @@ -#ifndef INFINI_RT_CUDA_RUNTIME_UTILS_H_ -#define INFINI_RT_CUDA_RUNTIME_UTILS_H_ - -#include "device.h" - -namespace infini::rt { - -template -struct RuntimeUtils; - -template -struct CudaRuntimeUtils { - static int GetOptimalBlockSize() { - int max_threads = QueryMaxThreadsPerBlockFn(); - if (max_threads >= 2048) return 2048; - if (max_threads >= 1024) return 1024; - if (max_threads >= 512) return 512; - if (max_threads >= 256) return 256; - return 128; - } -}; - -} // namespace infini::rt - -#endif diff --git a/src/dispatcher.h b/src/dispatcher.h deleted file mode 100644 index e277e5d..0000000 --- a/src/dispatcher.h +++ /dev/null @@ -1,341 +0,0 @@ -#ifndef INFINI_RT_DISPATCHER_H_ -#define INFINI_RT_DISPATCHER_H_ - -#include -#include -#include -#include - -#include "common/traits.h" -#include "data_type.h" -#include "device.h" - -namespace infini::rt { - -// ----------------------------------------------------------------------------- -// Core Generic Runtime Dispatchers -// ----------------------------------------------------------------------------- - -namespace detail { - -// Implements the dispatch body over a resolved `List`. -template -auto DispatchFuncImpl(ValueType value, Functor&& func, - std::string_view context_str, List, - Args&&... args) { - using ReturnType = decltype(std::forward(func)( - ValueTag(head)>{}, std::forward(args)...)); - - // Path for void functions. - if constexpr (std::is_void_v) { - bool handled = ((value == static_cast(tail) - ? (std::forward(func)( - ValueTag{}, std::forward(args)...), - true) - : false) || - ... || - (value == static_cast(head) - ? (std::forward(func)( - ValueTag{}, std::forward(args)...), - true) - : false)); - - if (!handled) { - // TODO(lzm): change to logging. - std::cerr << "dispatch error (void): value " << static_cast(value) - << " not supported in the context: " << context_str << "\n"; - std::abort(); - } - } - // Path for non-void functions. - else { - std::optional result; - bool handled = ((value == static_cast(tail) - ? (result.emplace(std::forward(func)( - ValueTag{}, std::forward(args)...)), - true) - : false) || - ... || - (value == static_cast(head) - ? (result.emplace(std::forward(func)( - ValueTag{}, std::forward(args)...)), - true) - : false)); - - if (handled) { - return *result; - } - // TODO(lzm): change to logging. - std::cerr << "dispatch error (non-void): value " << static_cast(value) - << " not supported in the context: " << context_str << "\n"; - std::abort(); - return ReturnType{}; - } -} - -// Deduces `head`/`tail` from a `List` type via partial specialization, -// then forwards to `DispatchFuncImpl`. -template -struct DispatchFuncUnwrap; - -template -struct DispatchFuncUnwrap, - std::tuple> { - static auto call(ValueType value, Functor&& func, - std::string_view context_str, Args&&... args) { - return DispatchFuncImpl(value, std::forward(func), context_str, - List{}, std::forward(args)...); - } -}; - -// Empty-list specialization -template -struct DispatchFuncUnwrap, std::tuple> { - static auto call(ValueType value, Functor&&, std::string_view context_str, - Args&&...) { - // TODO(lzm): change to logging. - std::cerr << "dispatch error: no allowed values registered for value " - << static_cast(value) - << " in the context: " << context_str << "\n"; - std::abort(); - } -}; - -} // namespace detail - -// (Single Dispatch) Dispatches a runtime value to a compile-time functor. -template -auto DispatchFunc(ValueType value, Functor&& func, - std::string_view context_str = "", Args&&... args) { - using FilteredPack = typename Filter, List<>, - all_values...>::type; - - return detail::DispatchFuncUnwrap< - ValueType, Functor, FilteredPack, - std::tuple>::call(value, std::forward(func), - context_str, std::forward(args)...); -} - -// (Multi-Dispatch) Dispatches a vector of runtime values to a compile-time -// functor. -// Base Case: All Dimensions Resolved -template -auto DispatchFunc(const std::vector& values, size_t /*index*/, - Functor&& func, std::string_view /*context_str*/, - List, Args&&... args) { - return std::forward(func)(List{}, - std::forward(args)...); -} - -// Forward declaration of the recursive multi-dispatch overload. -template -auto DispatchFunc(const std::vector& values, size_t index, - Functor&& func, std::string_view context_str, List, - Args&&... args); - -// Adapter used in the recursive multi-dispatch case: given a resolved value -// `val` recurse into the next dimension. -template -struct MultiDispatchRecurseAdapter; - -template -struct MultiDispatchRecurseAdapter, Functor, items...> { - const std::vector& values; - size_t next_index; - Functor& func; - std::string_view context_str; - - template - auto operator()(ValueTag, Args&&... args) const { - return DispatchFunc(values, next_index, func, context_str, - List{}, - std::forward(args)...); - } -}; - -template -auto MultiDispatchFirstDim(const std::vector& values, size_t index, - Functor& func, std::string_view context_str, - List, List, Args&&... args) { - static_assert(sizeof...(allowed) > 0, - "`DispatchFunc` dimension list is empty"); - using EnumType = std::common_type_t; - - MultiDispatchRecurseAdapter adapter{ - values, index + 1, func, context_str}; - - return DispatchFunc( - static_cast(values.at(index)), adapter, context_str, - std::forward(args)...); -} - -// (Multi-Dispatch) Recursive Case -template -auto DispatchFunc(const std::vector& values, size_t index, - Functor&& func, std::string_view context_str, List, - Args&&... args) { - return MultiDispatchFirstDim>( - values, index, func, context_str, List{}, FirstList{}, - std::forward(args)...); -} - -// ----------------------------------------------------------------------------- -// High-Level Specialized Dispatchers -// ----------------------------------------------------------------------------- -// These provide cleaner and more convenient APIs for common InfiniRT types. - -namespace detail { - -// Bridges the generic value dispatch layer to the `DataType`-specific type -// dispatch layer. -template -struct DataTypeAdapter { - Functor& func; - - template - auto operator()(ValueTag, Args&&... args) const { - using T = TypeMapType(dtype)>; - return func(TypeTag{}, std::forward(args)...); - } -}; - -template -struct DataTypeMultiAdapter { - Functor& func; - - template - auto operator()(List, Args&&... args) const { - return func(TypeTag(dtypes)>>{}..., - std::forward(args)...); - } -}; - -template -struct DeviceAdapter { - Functor& func; - - template - auto operator()(ValueTag, Args&&... args) const { - return func(ValueTag{}, std::forward(args)...); - } -}; - -template -struct DeviceMultiAdapter { - Functor& func; - - template - auto operator()(List, Args&&... args) const { - return func(ValueTag{}..., std::forward(args)...); - } -}; - -} // namespace detail - -// `DataType` Dispatch -template -auto DispatchFunc(DataType dtype, Functor&& func, - std::string_view context_str = "", Args&&... args) { - detail::DataTypeAdapter> adapter{func}; - return DispatchFunc(dtype, adapter, context_str, - std::forward(args)...); -} - -// `DataType` Multi-Dispatch -template -auto DispatchFunc(std::initializer_list dtypes, Functor&& func, - std::string_view context_str = "", Args&&... args) { - std::vector v; - for (auto d : dtypes) v.push_back(static_cast(d)); - - detail::DataTypeMultiAdapter> adapter{ - func}; - return DispatchFunc(v, 0, adapter, context_str, List<>{}, - std::forward(args)...); -} - -// `Device` Dispatch -template -auto DispatchFunc(Device::Type device, Functor&& func, - std::string_view context_str = "", Args&&... args) { - detail::DeviceAdapter> adapter{func}; - return DispatchFunc(allowed_devices)...>( - device, adapter, context_str, std::forward(args)...); -} - -// `Device` Multi-Dispatch -template -auto DispatchFunc(std::initializer_list devices, Functor&& func, - std::string_view context_str = "", Args&&... args) { - std::vector v; - for (auto d : devices) v.push_back(static_cast(d)); - - detail::DeviceMultiAdapter> adapter{func}; - return DispatchFunc(v, 0, adapter, context_str, List<>{}, - std::forward(args)...); -} - -template -auto DispatchFuncListAliasImpl(ValueType value, Functor&& func, - std::string_view context_str, List, - Args&&... args) { - return DispatchFunc>(items)...>( - value, std::forward(func), context_str, - std::forward(args)...); -} - -template -auto DispatchFuncListAliasImpl(ValueType value, Functor&& func, - std::string_view context_str, List, - Args&&... args) { - return DispatchFunc>(items)...>( - value, std::forward(func), context_str, - std::forward(args)...); -} - -// Interface for Generic `List` Aliases (for non-DataType dispatch, e.g. Device) -template ::value>> -auto DispatchFunc(ValueType value, Functor&& func, - std::string_view context_str = "", Args&&... args) { - return DispatchFuncListAliasImpl(value, std::forward(func), - context_str, ListType{}, - std::forward(args)...); -} - -// Interface for Generic `List` Aliases (for DataType dispatch with device type) -template ::value>> -auto DispatchFunc(ValueType value, Functor&& func, - std::string_view context_str = "", Args&&... args) { - return DispatchFuncListAliasImpl(value, std::forward(func), - context_str, ListType{}, - std::forward(args)...); -} - -// Interface for Any `int64_t`-Convertible Types -template -auto DispatchFunc(std::initializer_list keys, Functor&& func, - std::string_view context_str = "", Args&&... args) { - std::vector v_keys(keys); - return DispatchFunc(v_keys, 0, std::forward(func), - context_str, List<>{}, - std::forward(args)...); -} - -} // namespace infini::rt - -#endif diff --git a/src/iluvatar/device_property.h b/src/iluvatar/device_property.h deleted file mode 100644 index 1605603..0000000 --- a/src/iluvatar/device_property.h +++ /dev/null @@ -1,42 +0,0 @@ -#ifndef INFINI_RT_ILUVATAR_DEVICE_PROPERTY_H_ -#define INFINI_RT_ILUVATAR_DEVICE_PROPERTY_H_ - -#include - -#include -#include - -namespace infini::rt { - -class DevicePropertyCache { - public: - static const cudaDeviceProp& GetCurrentDeviceProps() { - int device_id = 0; - cudaGetDevice(&device_id); - return GetDeviceProps(device_id); - } - - static const cudaDeviceProp& GetDeviceProps(int device_id) { - static std::vector cache = []() { - int count = 0; - cudaGetDeviceCount(&count); - if (count == 0) return std::vector{}; - std::vector props(count); - for (int i = 0; i < count; ++i) { - cudaGetDeviceProperties(&props[i], i); - } - return props; - }(); - - assert(device_id >= 0 && device_id < static_cast(cache.size())); - return cache[device_id]; - } -}; - -inline int QueryMaxThreadsPerBlock() { - return DevicePropertyCache::GetCurrentDeviceProps().maxThreadsPerBlock; -} - -} // namespace infini::rt - -#endif diff --git a/src/iluvatar/runtime_.h b/src/iluvatar/runtime_.h index db2eeaf..c442753 100644 --- a/src/iluvatar/runtime_.h +++ b/src/iluvatar/runtime_.h @@ -9,7 +9,6 @@ #include "cuda/runtime.h" #include "iluvatar/device_.h" -#include "iluvatar/runtime_utils.h" namespace infini::rt { diff --git a/src/iluvatar/runtime_utils.h b/src/iluvatar/runtime_utils.h deleted file mode 100644 index fcda918..0000000 --- a/src/iluvatar/runtime_utils.h +++ /dev/null @@ -1,15 +0,0 @@ -#ifndef INFINI_RT_ILUVATAR_RUNTIME_UTILS_H_ -#define INFINI_RT_ILUVATAR_RUNTIME_UTILS_H_ - -#include "cuda/runtime_utils.h" -#include "iluvatar/device_property.h" - -namespace infini::rt { - -template <> -struct RuntimeUtils - : CudaRuntimeUtils {}; - -} // namespace infini::rt - -#endif diff --git a/src/metax/device_property.h b/src/metax/device_property.h deleted file mode 100644 index 5ceaed6..0000000 --- a/src/metax/device_property.h +++ /dev/null @@ -1,11 +0,0 @@ -#ifndef INFINI_RT_METAX_DEVICE_PROPERTY_H_ -#define INFINI_RT_METAX_DEVICE_PROPERTY_H_ - -namespace infini::rt { - -// TODO: Add MCR device properties query for Metax. -inline int QueryMaxThreadsPerBlock() { return 256; } - -} // namespace infini::rt - -#endif diff --git a/src/metax/runtime_.h b/src/metax/runtime_.h index 885d8f1..2d33123 100644 --- a/src/metax/runtime_.h +++ b/src/metax/runtime_.h @@ -5,7 +5,6 @@ #include "cuda/runtime.h" #include "metax/device_.h" -#include "metax/runtime_utils.h" namespace infini::rt { diff --git a/src/metax/runtime_utils.h b/src/metax/runtime_utils.h deleted file mode 100644 index 2527124..0000000 --- a/src/metax/runtime_utils.h +++ /dev/null @@ -1,15 +0,0 @@ -#ifndef INFINI_RT_METAX_RUNTIME_UTILS_H_ -#define INFINI_RT_METAX_RUNTIME_UTILS_H_ - -#include "cuda/runtime_utils.h" -#include "metax/device_property.h" - -namespace infini::rt { - -template <> -struct RuntimeUtils - : CudaRuntimeUtils {}; - -} // namespace infini::rt - -#endif diff --git a/src/moore/device_property.h b/src/moore/device_property.h deleted file mode 100644 index c9eac81..0000000 --- a/src/moore/device_property.h +++ /dev/null @@ -1,18 +0,0 @@ -#ifndef INFINI_RT_MOORE_DEVICE_PROPERTY_H_ -#define INFINI_RT_MOORE_DEVICE_PROPERTY_H_ - -#include - -namespace infini::rt { - -inline int QueryMaxThreadsPerBlock() { - int device = 0; - musaGetDevice(&device); - musaDeviceProp prop; - musaGetDeviceProperties(&prop, device); - return prop.maxThreadsPerBlock; -} - -} // namespace infini::rt - -#endif diff --git a/src/moore/runtime_.h b/src/moore/runtime_.h index 076f436..c268a8a 100644 --- a/src/moore/runtime_.h +++ b/src/moore/runtime_.h @@ -7,7 +7,6 @@ #include "cuda/runtime.h" #include "moore/device_.h" -#include "moore/runtime_utils.h" namespace infini::rt { diff --git a/src/moore/runtime_utils.h b/src/moore/runtime_utils.h deleted file mode 100644 index 053146b..0000000 --- a/src/moore/runtime_utils.h +++ /dev/null @@ -1,15 +0,0 @@ -#ifndef INFINI_RT_MOORE_RUNTIME_UTILS_H_ -#define INFINI_RT_MOORE_RUNTIME_UTILS_H_ - -#include "cuda/runtime_utils.h" -#include "moore/device_property.h" - -namespace infini::rt { - -template <> -struct RuntimeUtils - : CudaRuntimeUtils {}; - -} // namespace infini::rt - -#endif diff --git a/src/nvidia/device_property.h b/src/nvidia/device_property.h deleted file mode 100644 index 2557cb3..0000000 --- a/src/nvidia/device_property.h +++ /dev/null @@ -1,42 +0,0 @@ -#ifndef INFINI_RT_NVIDIA_DEVICE_PROPERTY_H_ -#define INFINI_RT_NVIDIA_DEVICE_PROPERTY_H_ - -#include - -#include -#include - -namespace infini::rt { - -class DevicePropertyCache { - public: - static const cudaDeviceProp& GetCurrentDeviceProps() { - int device_id = 0; - cudaGetDevice(&device_id); - return GetDeviceProps(device_id); - } - - static const cudaDeviceProp& GetDeviceProps(int device_id) { - static std::vector cache = []() { - int count = 0; - cudaGetDeviceCount(&count); - if (count == 0) return std::vector{}; - std::vector props(count); - for (int i = 0; i < count; ++i) { - cudaGetDeviceProperties(&props[i], i); - } - return props; - }(); - - assert(device_id >= 0 && device_id < static_cast(cache.size())); - return cache[device_id]; - } -}; - -inline int QueryMaxThreadsPerBlock() { - return DevicePropertyCache::GetCurrentDeviceProps().maxThreadsPerBlock; -} - -} // namespace infini::rt - -#endif diff --git a/src/nvidia/runtime_.h b/src/nvidia/runtime_.h index f10cc6d..f3d815f 100644 --- a/src/nvidia/runtime_.h +++ b/src/nvidia/runtime_.h @@ -9,7 +9,6 @@ #include "cuda/runtime.h" #include "nvidia/device_.h" -#include "nvidia/runtime_utils.h" namespace infini::rt { diff --git a/src/nvidia/runtime_utils.h b/src/nvidia/runtime_utils.h deleted file mode 100644 index 783f71a..0000000 --- a/src/nvidia/runtime_utils.h +++ /dev/null @@ -1,15 +0,0 @@ -#ifndef INFINI_RT_NVIDIA_RUNTIME_UTILS_H_ -#define INFINI_RT_NVIDIA_RUNTIME_UTILS_H_ - -#include "cuda/runtime_utils.h" -#include "nvidia/device_property.h" - -namespace infini::rt { - -template <> -struct RuntimeUtils - : CudaRuntimeUtils {}; - -} // namespace infini::rt - -#endif diff --git a/src/tensor.cc b/src/tensor.cc index 13f5521..68ac0e3 100644 --- a/src/tensor.cc +++ b/src/tensor.cc @@ -4,8 +4,6 @@ #include #include -#include "dispatcher.h" - namespace infini::rt { static Tensor::Index GetEffectiveIndex(Tensor::Index index, Tensor::Size size) { @@ -112,14 +110,34 @@ Tensor::Strides Tensor::DefaultStrides(const Shape& shape) { std::string Tensor::ToStringHelper() const { if (ndim() == 0) { - return DispatchFunc>( - dtype_, - [&](auto tag) { - using T = typename decltype(tag)::type; - return std::to_string(*static_cast(data_)); - }, - "Tensor::ToStringHelper()"); + switch (dtype_) { + case DataType::kFloat16: + return std::to_string(static_cast(data_)->ToFloat()); + case DataType::kBFloat16: + return std::to_string(static_cast(data_)->ToFloat()); + case DataType::kFloat32: + return std::to_string(*static_cast(data_)); + case DataType::kFloat64: + return std::to_string(*static_cast(data_)); + case DataType::kInt8: + return std::to_string(*static_cast(data_)); + case DataType::kInt16: + return std::to_string(*static_cast(data_)); + case DataType::kInt32: + return std::to_string(*static_cast(data_)); + case DataType::kInt64: + return std::to_string(*static_cast(data_)); + case DataType::kUInt8: + return std::to_string(*static_cast(data_)); + case DataType::kUInt16: + return std::to_string(*static_cast(data_)); + case DataType::kUInt32: + return std::to_string(*static_cast(data_)); + case DataType::kUInt64: + return std::to_string(*static_cast(data_)); + default: + return "?"; + } } std::string result{"["}; From ef2febacd1dce93fc655418de770289c29988368 Mon Sep 17 00:00:00 2001 From: Jiacheng Huang Date: Thu, 16 Apr 2026 08:14:16 +0000 Subject: [PATCH 21/22] refactor: rename `Tensor` to `TensorView` The class holds a non-owning `void*` pointer with shape/strides metadata, which is a view rather than an owning tensor. --- src/{tensor.cc => tensor_view.cc} | 68 ++++++++++++++++--------------- src/{tensor.h => tensor_view.h} | 39 +++++++++--------- 2 files changed, 56 insertions(+), 51 deletions(-) rename src/{tensor.cc => tensor_view.cc} (60%) rename src/{tensor.h => tensor_view.h} (70%) diff --git a/src/tensor.cc b/src/tensor_view.cc similarity index 60% rename from src/tensor.cc rename to src/tensor_view.cc index 68ac0e3..ffc8e04 100644 --- a/src/tensor.cc +++ b/src/tensor_view.cc @@ -1,4 +1,4 @@ -#include "tensor.h" +#include "tensor_view.h" #include #include @@ -6,17 +6,18 @@ namespace infini::rt { -static Tensor::Index GetEffectiveIndex(Tensor::Index index, Tensor::Size size) { +static TensorView::Index GetEffectiveIndex(TensorView::Index index, + TensorView::Size size) { return index < 0 ? index + size : index; } -Tensor::Tensor(void* data, std::initializer_list shape, - const DataType& dtype, const Device& device, - std::initializer_list strides) - : Tensor{data, decltype(shape_){shape}, dtype, device, - decltype(strides_){strides}} {} +TensorView::TensorView(void* data, std::initializer_list shape, + const DataType& dtype, const Device& device, + std::initializer_list strides) + : TensorView{data, decltype(shape_){shape}, dtype, device, + decltype(strides_){strides}} {} -Tensor Tensor::operator[](const Index& index) const { +TensorView TensorView::operator[](const Index& index) const { return { reinterpret_cast( reinterpret_cast(data_) + @@ -25,37 +26,39 @@ Tensor Tensor::operator[](const Index& index) const { Strides{strides_.cbegin() + 1, strides_.cend()}}; } -void*& Tensor::data() { return data_; } +void*& TensorView::data() { return data_; } -const void* Tensor::data() const { return data_; } +const void* TensorView::data() const { return data_; } -const Tensor::Shape& Tensor::shape() const { return shape_; } +const TensorView::Shape& TensorView::shape() const { return shape_; } -const DataType& Tensor::dtype() const { return dtype_; } +const DataType& TensorView::dtype() const { return dtype_; } -const Device& Tensor::device() const { return device_; } +const Device& TensorView::device() const { return device_; } -const Tensor::Strides& Tensor::strides() const { return strides_; } +const TensorView::Strides& TensorView::strides() const { return strides_; } -Tensor::Size Tensor::size(const Index& index) const { +TensorView::Size TensorView::size(const Index& index) const { return shape_[GetEffectiveIndex(index, shape_.size())]; } -Tensor::Stride Tensor::stride(const Index& index) const { +TensorView::Stride TensorView::stride(const Index& index) const { return strides_[GetEffectiveIndex(index, strides_.size())]; } -Tensor::Size Tensor::ndim() const { return shape_.size(); } +TensorView::Size TensorView::ndim() const { return shape_.size(); } -Tensor::Size Tensor::element_size() const { return kDataTypeToSize.at(dtype_); } +TensorView::Size TensorView::element_size() const { + return kDataTypeToSize.at(dtype_); +} -Tensor::Size Tensor::numel() const { - return std::accumulate(shape_.begin(), shape_.end(), - static_cast(1), - [](Tensor::Size a, Tensor::Size b) { return a * b; }); +TensorView::Size TensorView::numel() const { + return std::accumulate( + shape_.begin(), shape_.end(), static_cast(1), + [](TensorView::Size a, TensorView::Size b) { return a * b; }); } -Tensor Tensor::T() const { +TensorView TensorView::T() const { return {data_, {shape_[1], shape_[0]}, dtype_, @@ -63,20 +66,20 @@ Tensor Tensor::T() const { {strides_[1], strides_[0]}}; } -std::string Tensor::ToString() const { +std::string TensorView::ToString() const { return "tensor(" + ToStringHelper() + ", dtype=" + std::string(kDataTypeToDesc.at(dtype_)) + ", device='" + device_.ToString() + "')"; } -bool Tensor::HasBroadcastDim() const { +bool TensorView::HasBroadcastDim() const { return std::any_of(shape_.begin(), shape_.end(), [&, i = 0](const auto&) mutable { return shape_[i] != 1 && strides_[i++] == 0; }); } -bool Tensor::IsContiguous() const { +bool TensorView::IsContiguous() const { if (ndim() == 0) { return true; } @@ -88,11 +91,11 @@ bool Tensor::IsContiguous() const { return stride(ndim() - 1) == 1; } -const DataType Tensor::DefaultDataType() { return DataType::kFloat32; } +const DataType TensorView::DefaultDataType() { return DataType::kFloat32; } -Device Tensor::DefaultDevice() { return Device{Device::Type::kCpu}; } +Device TensorView::DefaultDevice() { return Device{Device::Type::kCpu}; } -Tensor::Strides Tensor::DefaultStrides(const Shape& shape) { +TensorView::Strides TensorView::DefaultStrides(const Shape& shape) { if (shape.empty()) { return {}; } @@ -108,7 +111,7 @@ Tensor::Strides Tensor::DefaultStrides(const Shape& shape) { return strides; } -std::string Tensor::ToStringHelper() const { +std::string TensorView::ToStringHelper() const { if (ndim() == 0) { switch (dtype_) { case DataType::kFloat16: @@ -152,12 +155,13 @@ std::string Tensor::ToStringHelper() const { return result; } -bool Tensor::IsMergeable(Tensor::Size dim_start, Tensor::Size dim_end) const { +bool TensorView::IsMergeable(TensorView::Size dim_start, + TensorView::Size dim_end) const { if (dim_start == dim_end) { return true; } - for (Tensor::Size i = dim_start; i < dim_end; ++i) { + for (TensorView::Size i = dim_start; i < dim_end; ++i) { if (size(i) == 1 && stride(i) == 0) { return false; } diff --git a/src/tensor.h b/src/tensor_view.h similarity index 70% rename from src/tensor.h rename to src/tensor_view.h index c8a51c4..e747d19 100644 --- a/src/tensor.h +++ b/src/tensor_view.h @@ -1,5 +1,5 @@ -#ifndef INFINI_RT_TENSOR_H_ -#define INFINI_RT_TENSOR_H_ +#ifndef INFINI_RT_TENSOR_VIEW_H_ +#define INFINI_RT_TENSOR_VIEW_H_ #include #include @@ -11,7 +11,7 @@ namespace infini::rt { -class Tensor { +class TensorView { public: using Size = std::size_t; @@ -24,7 +24,7 @@ class Tensor { using Strides = std::vector; template - Tensor(void* data, const Shape& shape) + TensorView(void* data, const Shape& shape) : data_{data}, shape_{shape}, dtype_{DefaultDataType()}, @@ -32,7 +32,7 @@ class Tensor { strides_{DefaultStrides(shape)} {} template - Tensor(void* data, const Shape& shape, const DataType& dtype) + TensorView(void* data, const Shape& shape, const DataType& dtype) : data_{data}, shape_{shape}, dtype_{dtype}, @@ -40,7 +40,7 @@ class Tensor { strides_{DefaultStrides(shape)} {} template - Tensor(void* data, const Shape& shape, const Device& device) + TensorView(void* data, const Shape& shape, const Device& device) : data_{data}, shape_{shape}, dtype_{DefaultDataType()}, @@ -48,8 +48,8 @@ class Tensor { strides_{DefaultStrides(shape)} {} template - Tensor(void* data, const Shape& shape, const DataType& dtype, - const Device& device) + TensorView(void* data, const Shape& shape, const DataType& dtype, + const Device& device) : data_{data}, shape_{shape}, dtype_{dtype}, @@ -57,18 +57,19 @@ class Tensor { strides_{DefaultStrides(shape)} {} template - Tensor(void* data, const Shape& shape, const DataType& dtype, - const Device& device, const Strides& strides) + TensorView(void* data, const Shape& shape, const DataType& dtype, + const Device& device, const Strides& strides) : data_{data}, shape_{shape}, dtype_{dtype}, device_{device}, strides_{strides} {} - Tensor(void* data, std::initializer_list shape, const DataType& dtype, - const Device& device, std::initializer_list strides); + TensorView(void* data, std::initializer_list shape, + const DataType& dtype, const Device& device, + std::initializer_list strides); - Tensor operator[](const Index& index) const; + TensorView operator[](const Index& index) const; void*& data(); @@ -92,7 +93,7 @@ class Tensor { Size numel() const; - Tensor T() const; + TensorView T() const; std::string ToString() const; @@ -125,8 +126,8 @@ class Tensor { } // namespace infini::rt template <> -struct std::hash { - std::size_t operator()(const infini::rt::Tensor& tensor) const { +struct std::hash { + std::size_t operator()(const infini::rt::TensorView& tensor) const { std::size_t seed{0}; for (const auto& size : tensor.shape()) { @@ -146,9 +147,9 @@ struct std::hash { }; template <> -struct std::equal_to { - bool operator()(const infini::rt::Tensor& a, - const infini::rt::Tensor& b) const { +struct std::equal_to { + bool operator()(const infini::rt::TensorView& a, + const infini::rt::TensorView& b) const { return a.dtype() == b.dtype() && a.device() == b.device() && a.shape() == b.shape() && a.strides() == b.strides(); } From 11c6388e58ccf1d65c420229f7ec4421d508edd1 Mon Sep 17 00:00:00 2001 From: zhushuang Date: Wed, 3 Jun 2026 11:58:26 +0000 Subject: [PATCH 22/22] refactor InfiniOps cpu runtime through InfiniRT --- CMakeLists.txt | 1 + src/CMakeLists.txt | 13 +++++++++++++ src/device.h | 3 +++ src/infini_rt/cpu/data_type_.h | 6 ++++++ src/infini_rt/cpu/device_.h | 6 ++++++ src/infini_rt/cpu/runtime_.h | 6 ++++++ src/infini_rt/data_type.h | 6 ++++++ src/infini_rt/device.h | 6 ++++++ src/infini_rt/runtime.h | 6 ++++++ src/infini_rt/tensor_view.h | 6 ++++++ 10 files changed, 59 insertions(+) create mode 100644 src/infini_rt/cpu/data_type_.h create mode 100644 src/infini_rt/cpu/device_.h create mode 100644 src/infini_rt/cpu/runtime_.h create mode 100644 src/infini_rt/data_type.h create mode 100644 src/infini_rt/device.h create mode 100644 src/infini_rt/runtime.h create mode 100644 src/infini_rt/tensor_view.h diff --git a/CMakeLists.txt b/CMakeLists.txt index 445834b..e92b202 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -190,6 +190,7 @@ endif() # If all other platforms are not enabled, CPU is enabled by default. if(NOT WITH_NVIDIA AND NOT WITH_ILUVATAR AND NOT WITH_METAX AND NOT WITH_MOORE AND NOT WITH_CAMBRICON AND NOT WITH_ASCEND) + set(WITH_CPU ON CACHE BOOL "Enable CPU backend" FORCE) add_compile_definitions(WITH_CPU=1) endif() diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index fc540ec..6ca3924 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -1,5 +1,7 @@ add_library(infinirt SHARED) +include(GNUInstallDirs) + file(GLOB BASE_SRCS CONFIGURE_DEPENDS "*.cc") target_sources(infinirt PRIVATE ${BASE_SRCS}) @@ -94,3 +96,14 @@ if(WITH_ASCEND) endif() target_include_directories(infinirt PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}) + +install(TARGETS infinirt + LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} + ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR} + RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR} +) + +install(DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/ + DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/infinirt + FILES_MATCHING PATTERN "*.h" +) diff --git a/src/device.h b/src/device.h index d4d9fcb..7da2255 100644 --- a/src/device.h +++ b/src/device.h @@ -1,6 +1,9 @@ #ifndef INFINI_RT_DEVICE_H_ #define INFINI_RT_DEVICE_H_ +#include +#include + #include "common/constexpr_map.h" #include "common/traits.h" #include "hash.h" diff --git a/src/infini_rt/cpu/data_type_.h b/src/infini_rt/cpu/data_type_.h new file mode 100644 index 0000000..31d69df --- /dev/null +++ b/src/infini_rt/cpu/data_type_.h @@ -0,0 +1,6 @@ +#ifndef INFINI_RT_PUBLIC_CPU_DATA_TYPE__H_ +#define INFINI_RT_PUBLIC_CPU_DATA_TYPE__H_ + +#include "../../cpu/data_type_.h" + +#endif diff --git a/src/infini_rt/cpu/device_.h b/src/infini_rt/cpu/device_.h new file mode 100644 index 0000000..022d838 --- /dev/null +++ b/src/infini_rt/cpu/device_.h @@ -0,0 +1,6 @@ +#ifndef INFINI_RT_PUBLIC_CPU_DEVICE__H_ +#define INFINI_RT_PUBLIC_CPU_DEVICE__H_ + +#include "../../cpu/device_.h" + +#endif diff --git a/src/infini_rt/cpu/runtime_.h b/src/infini_rt/cpu/runtime_.h new file mode 100644 index 0000000..aab20b9 --- /dev/null +++ b/src/infini_rt/cpu/runtime_.h @@ -0,0 +1,6 @@ +#ifndef INFINI_RT_PUBLIC_CPU_RUNTIME__H_ +#define INFINI_RT_PUBLIC_CPU_RUNTIME__H_ + +#include "../../cpu/runtime_.h" + +#endif diff --git a/src/infini_rt/data_type.h b/src/infini_rt/data_type.h new file mode 100644 index 0000000..8fe5786 --- /dev/null +++ b/src/infini_rt/data_type.h @@ -0,0 +1,6 @@ +#ifndef INFINI_RT_PUBLIC_DATA_TYPE_H_ +#define INFINI_RT_PUBLIC_DATA_TYPE_H_ + +#include "../data_type.h" + +#endif diff --git a/src/infini_rt/device.h b/src/infini_rt/device.h new file mode 100644 index 0000000..9e9eb8b --- /dev/null +++ b/src/infini_rt/device.h @@ -0,0 +1,6 @@ +#ifndef INFINI_RT_PUBLIC_DEVICE_H_ +#define INFINI_RT_PUBLIC_DEVICE_H_ + +#include "../device.h" + +#endif diff --git a/src/infini_rt/runtime.h b/src/infini_rt/runtime.h new file mode 100644 index 0000000..ab6d557 --- /dev/null +++ b/src/infini_rt/runtime.h @@ -0,0 +1,6 @@ +#ifndef INFINI_RT_PUBLIC_RUNTIME_H_ +#define INFINI_RT_PUBLIC_RUNTIME_H_ + +#include "../runtime.h" + +#endif diff --git a/src/infini_rt/tensor_view.h b/src/infini_rt/tensor_view.h new file mode 100644 index 0000000..0193b2d --- /dev/null +++ b/src/infini_rt/tensor_view.h @@ -0,0 +1,6 @@ +#ifndef INFINI_RT_PUBLIC_TENSOR_VIEW_H_ +#define INFINI_RT_PUBLIC_TENSOR_VIEW_H_ + +#include "../tensor_view.h" + +#endif