Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .gitmodules
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,6 @@
[submodule "third_party/googletest"]
path = third_party/googletest
url = git@github.com:google/googletest.git
[submodule "third_party/InfiniOps"]
path = third_party/InfiniOps
url = git@github.com:InfiniTensor/InfiniOps.git
34 changes: 34 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ option(USE_CUDA "Support NVIDIA CUDA" OFF)
option(PROFILE_MODE "ENABLE PROFILE MODE" OFF)
option(USE_OMP "Use OpenMP as backend for Eigen" ON)
option(USE_NCCL "Build project for distributed running" ON)
option(USE_INFINIOPS "Use InfiniOps as an optional kernel provider" OFF)
option(BUILD_TEST "Build InfiniTrain tests" OFF)

project(infini_train VERSION 0.5.0 LANGUAGES CXX)
Expand Down Expand Up @@ -51,6 +52,32 @@ include_directories(${PROJECT_SOURCE_DIR}/third_party/eigen)

include_directories(${PROJECT_SOURCE_DIR})

if(USE_INFINIOPS)
add_compile_definitions(USE_INFINIOPS=1)

set(INFINIOPS_SOURCE_DIR "${PROJECT_SOURCE_DIR}/third_party/InfiniOps")
if(NOT EXISTS "${INFINIOPS_SOURCE_DIR}/CMakeLists.txt")
message(FATAL_ERROR
"USE_INFINIOPS=ON requires InfiniOps under third_party/InfiniOps. "
"Run: git submodule update --init third_party/InfiniOps")
endif()

set(INFINIOPS_WITH_CPU OFF)
if(NOT USE_CUDA)
set(INFINIOPS_WITH_CPU ON)
endif()

set(WITH_CPU ${INFINIOPS_WITH_CPU} CACHE BOOL "Enable InfiniOps CPU backend" FORCE)
set(WITH_NVIDIA ${USE_CUDA} CACHE BOOL "Enable InfiniOps NVIDIA backend" FORCE)
add_subdirectory(${INFINIOPS_SOURCE_DIR} ${CMAKE_BINARY_DIR}/third_party/InfiniOps EXCLUDE_FROM_ALL)
if(NOT TARGET infiniops)
message(FATAL_ERROR "InfiniOps third-party project did not define target `infiniops`")
endif()
if(NOT TARGET InfiniOps::infiniops)
add_library(InfiniOps::infiniops ALIAS infiniops)
endif()
endif()

if(PROFILE_MODE)
add_compile_definitions(PROFILE_MODE=1)
endif()
Expand All @@ -62,9 +89,13 @@ endif()
# Framework core sources (*.cc), excluding cpu kernels (they are built separately)
file(GLOB_RECURSE SRC ${PROJECT_SOURCE_DIR}/infini_train/src/*.cc)
list(FILTER SRC EXCLUDE REGEX ".*kernels/cpu/.*")
if(NOT USE_INFINIOPS)
list(FILTER SRC EXCLUDE REGEX ".*infini_train/src/core/kernel_provider/infiniops/.*\.cc$")
endif()
if(NOT USE_CUDA)
list(FILTER SRC EXCLUDE REGEX ".*runtime/cuda/.*")
list(FILTER SRC EXCLUDE REGEX ".*ccl/cuda/.*")
list(FILTER SRC EXCLUDE REGEX ".*infini_train/src/core/kernel_provider/infiniops/cuda/.*")
endif()
if(NOT USE_NCCL)
list(FILTER SRC EXCLUDE REGEX ".*infini_train/src/core/ccl/cuda/.*")
Expand Down Expand Up @@ -126,6 +157,9 @@ endif()
# ------------------------------------------------------------------------------

add_library(infini_train STATIC ${SRC})
if(USE_INFINIOPS)
target_link_libraries(infini_train PUBLIC InfiniOps::infiniops)
endif()
target_link_libraries(infini_train
PUBLIC
glog
Expand Down
45 changes: 45 additions & 0 deletions infini_train/include/core/kernel_provider/infiniops/adapter.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
#pragma once

#include <cstdint>
#include <memory>
#include <mutex>
#include <vector>

#include <handle.h>

#include "data_type.h"
#include "tensor.h"

#include "infini_train/include/datatype.h"
#include "infini_train/include/device.h"

namespace infini_train {
class Tensor;
} // namespace infini_train

namespace infini_train::core {
class Stream;
} // namespace infini_train::core

namespace infini_train::kernel_provider::infiniops {

infini::ops::DataType ToOpsDataType(DataType dtype);

infini::ops::Device ToOpsDevice(const Device &device);

std::mutex &InfiniOpsCallMutex();

using HandleFactory = infini::ops::Handle (*)(const Device &device, core::Stream *stream);

void RegisterHandleFactory(Device::DeviceType type, HandleFactory factory);

infini::ops::Handle GetHandle(const Device &device);

infini::ops::Tensor ToOpsTensor(const std::shared_ptr<Tensor> &tensor);

infini::ops::Tensor ToOpsTensor(void *data, const std::vector<int64_t> &dims, DataType dtype, const Device &device);

infini::ops::Tensor ToOpsTensor(void *data, const std::vector<int64_t> &dims, DataType dtype, const Device &device,
const std::vector<int64_t> &strides);

} // namespace infini_train::kernel_provider::infiniops
50 changes: 50 additions & 0 deletions infini_train/include/core/kernel_provider/infiniops_registry.h
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里是出于什么原因要单独写一套 registry,而不能直接复用 InfiniTrain 原有的注册表呢?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

见下回复

Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
#pragma once

#include <map>
#include <string>
#include <utility>

#include "glog/logging.h"

#include "infini_train/include/device.h"
#include "infini_train/include/dispatcher.h"

namespace infini_train::kernel_provider {

using KeyT = std::pair<Device::DeviceType, std::string>;

class InfiniOpsRegistry {
public:
static InfiniOpsRegistry &Instance() {
static InfiniOpsRegistry instance;
return instance;
}

const KernelFunction *Lookup(const std::string &kernel_name) const {
auto it = name_to_kernel_map_.find(kernel_name);
return it == name_to_kernel_map_.end() ? nullptr : &it->second;
}

template <typename FuncT> void Register(const std::string &kernel_name, FuncT &&kernel) {
CHECK(!name_to_kernel_map_.contains(kernel_name)) << "InfiniOps kernel already registered: " << kernel_name;
name_to_kernel_map_.emplace(kernel_name, kernel);
}

private:
std::map<std::string, KernelFunction> name_to_kernel_map_;
};

// Bridge functions used by Dispatcher::GetKernel. Implemented in
// infiniops_registry.cc; declared here for users that already include
// the full registry header (e.g. unit tests).
bool InfiniOpsEnabled();
bool InfiniOpsEnabled(const KeyT &key);
const KernelFunction *LookupInfiniOpsKernel(const KeyT &key);

} // namespace infini_train::kernel_provider

#define REGISTER_INFINIOPS_KERNEL(kernel_name, kernel_func) \
static const bool _##kernel_name##_infiniops_registered##__COUNTER__ = []() { \
infini_train::kernel_provider::InfiniOpsRegistry::Instance().Register(#kernel_name, kernel_func); \
return true; \
}();
17 changes: 17 additions & 0 deletions infini_train/include/dispatcher.h
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里不应该给 infinops 开额外分支,之前接沐曦 kernel 这块是不需要动的。

Copy link
Copy Markdown
Contributor Author

@chen2021673 chen2021673 Jun 2, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里沐曦和 infinops 的区别在于 infinops 需要解耦 device。

沐曦/MACA 是 InfiniTrain 的一个 device backend。它和 CUDA/CPU 一样,kernel 是按 device + op 注册的,所以可以继续用现有 REGISTER_KERNEL(device, kernel_name, kernel_func)

但 infinops 里面封装了不同 device 的执行,不是某一个具体 device,而是 kernel provider / backend provider。InfiniOps 自己内部再根据 handle/tensor device 去适配 NVIDIA、CPU、MUSA、Moore 等后端。所以框架这边调的时候也要在 Dispatcher::GetKernel 之前或里面做一层 provider policy。 infiniops_registry.h 里重写一套注册也是这个原因,其实就是无法复用REGISTER_KERNEL 接口,要把 device 参数换成 backend 参数。

这里其实可以不特化写一个 InfiniOpsRegistry,而是给Dispatcher添加一个通用的REGISTER_KERNEL_BACKEND(backend, kernel_name, kernel_func) ,InfiniOps调用REGISTER_KERNEL_BACKEND("InfiniOps", kernel_name, kernel_func)。这样 dispatcher.h 可能会整洁一点。

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

我理解可以复用 REGISTER_KERNEL(device, kernel_name, kernel_func) 来注册不同 device 的 infiniops kernel。在框架层,不需要关心底层算子库的实现,只按需注册对应 device 的 kernel 即可,类似:
https://github.com/InfiniTensor/InfiniTensor/blob/fcd1fb0299e181f841918c4db4e5f13a18a2ae60/src/kernels/infiniop/element_wise.cc#L36

Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#pragma once

#include <map>
#include <string>
#include <type_traits>
#include <utility>

Expand Down Expand Up @@ -47,6 +48,11 @@ class KernelFunction {
void *func_ptr_ = nullptr;
};

namespace kernel_provider {
bool InfiniOpsEnabled(const std::pair<Device::DeviceType, std::string> &key);
const KernelFunction *LookupInfiniOpsKernel(const std::pair<Device::DeviceType, std::string> &key);
} // namespace kernel_provider

class Dispatcher {
public:
using KeyT = std::pair<Device::DeviceType, std::string>;
Expand All @@ -57,6 +63,17 @@ class Dispatcher {
}

const KernelFunction &GetKernel(KeyT key) const {
if (kernel_provider::InfiniOpsEnabled(key)) {
if (const auto *kernel = kernel_provider::LookupInfiniOpsKernel(key)) {
#ifdef PROFILE_MODE
SetProfileContext(key.second, key.first);
#endif
return *kernel;
}
LOG(WARNING) << "InfiniOps kernel enabled but not registered: " << key.second
<< " on device: " << static_cast<int>(key.first) << "; falling back to default kernel";
}

CHECK(key_to_kernel_map_.contains(key))
<< "Kernel not found: " << key.second << " on device: " << static_cast<int>(key.first);
#ifdef PROFILE_MODE
Expand Down
48 changes: 48 additions & 0 deletions infini_train/include/kernels/common/gemm.h
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个头文件内容没什么问题,但不适合放到 include 里作为公共头文件暴露,先放 infini_train/src/kernels/common 里吧

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
#pragma once

#include "infini_train/include/datatype.h"
#include "infini_train/include/device.h"

namespace infini_train::kernels {

enum class GemmTranspose : int {
kNoTranspose = 0,
kTranspose = 1,
};

/**
* Parameter bundle for a single GEMM call:
* C = alpha * op(A) * op(B) + beta * C
*
* batch_count == 1 describes a non-batched GEMM. batch_count > 1 describes a
* strided-batched GEMM. When batch_count == 1, stride_a/b/c are unused and must
* be left at 0.
*/
struct GemmParams {
GemmTranspose trans_a = GemmTranspose::kNoTranspose;
GemmTranspose trans_b = GemmTranspose::kNoTranspose;

int m = 0; // rows of op(A) and C
int n = 0; // cols of op(B) and C
int k = 0; // cols of op(A) == rows of op(B)

const void *A = nullptr;
int lda = 0;
const void *B = nullptr;
int ldb = 0;
void *C = nullptr;
int ldc = 0;

float alpha = 1.0f;
float beta = 0.0f;

int batch_count = 1;
long long stride_a = 0;
long long stride_b = 0;
long long stride_c = 0;

DataType input_dtype;
DataType output_dtype;
};

} // namespace infini_train::kernels
2 changes: 1 addition & 1 deletion infini_train/include/tensor.h
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ class Tensor : public std::enable_shared_from_this<Tensor> {
std::shared_ptr<Tensor> View(const std::vector<int64_t> &dims);
std::shared_ptr<Tensor> Contiguous();
// FIXME: Currently returns true unconditionally. Requires stride tracking in the Tensor
// class before this can be implemented correctly. The guard in elementwise.cu ensures
// class before this can be implemented correctly. The elementwise broadcast guard ensures
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

不用改吧

Copy link
Copy Markdown
Contributor Author

@chen2021673 chen2021673 Jun 2, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

不应该限制在.cu?

// non-contiguous tensors fall back to the broadcast path until this is resolved.
bool IsContiguous() const;
std::shared_ptr<Tensor> Flatten(int64_t start = 0, int64_t end = -1);
Expand Down
120 changes: 120 additions & 0 deletions infini_train/src/core/kernel_provider/infiniops/adapter.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
#include "infini_train/include/core/kernel_provider/infiniops/adapter.h"

#include <map>
#include <unordered_map>

#include "glog/logging.h"

#include "infini_train/include/core/runtime/device_guard.h"

namespace infini_train::kernel_provider::infiniops {

namespace {

inline const std::unordered_map<DataType, infini::ops::DataType> kOpsDataTypeMap = {
{DataType::kFLOAT16, infini::ops::DataType::kFloat16}, {DataType::kBFLOAT16, infini::ops::DataType::kBFloat16},
{DataType::kFLOAT32, infini::ops::DataType::kFloat32}, {DataType::kFLOAT64, infini::ops::DataType::kFloat64},
{DataType::kINT8, infini::ops::DataType::kInt8}, {DataType::kINT16, infini::ops::DataType::kInt16},
{DataType::kINT32, infini::ops::DataType::kInt32}, {DataType::kINT64, infini::ops::DataType::kInt64},
{DataType::kUINT8, infini::ops::DataType::kUInt8}, {DataType::kUINT16, infini::ops::DataType::kUInt16},
{DataType::kUINT32, infini::ops::DataType::kUInt32}, {DataType::kUINT64, infini::ops::DataType::kUInt64},
};

inline const std::unordered_map<Device::DeviceType, infini::ops::Device::Type> kOpsDeviceTypeMap = {
{Device::DeviceType::kCUDA, infini::ops::Device::Type::kNvidia},
{Device::DeviceType::kCPU, infini::ops::Device::Type::kCpu},
};

std::map<Device::DeviceType, HandleFactory> &HandleFactories() {
static std::map<Device::DeviceType, HandleFactory> factories;
return factories;
}

} // namespace

void RegisterHandleFactory(Device::DeviceType type, HandleFactory factory) {
CHECK(factory != nullptr);
auto &factories = HandleFactories();
CHECK(!factories.contains(type)) << "InfiniOps handle factory already registered for device type "
<< static_cast<int>(type);
factories.emplace(type, factory);
}

infini::ops::Handle GetHandle(const Device &device) {
auto &factories = HandleFactories();
auto it = factories.find(device.type());
CHECK(it != factories.end()) << "InfiniOps handle factory is not registered for device type "
<< static_cast<int>(device.type());

auto *stream = core::GetDeviceGuardImpl(device.type())->GetStream(device);
return it->second(device, stream);
}

infini::ops::DataType ToOpsDataType(DataType dtype) {
auto it = kOpsDataTypeMap.find(dtype);
if (it == kOpsDataTypeMap.end()) {
LOG(FATAL) << "Unsupported DataType for InfiniOps: " << static_cast<int>(dtype);
__builtin_unreachable();
}
return it->second;
}

infini::ops::Device ToOpsDevice(const Device &device) {
auto it = kOpsDeviceTypeMap.find(device.type());
if (it == kOpsDeviceTypeMap.end()) {
LOG(FATAL) << "Unsupported DeviceType for InfiniOps: " << static_cast<int>(device.type());
__builtin_unreachable();
}
return {it->second, device.index()};
}

std::mutex &InfiniOpsCallMutex() {
static std::mutex mutex;
return mutex;
}

namespace {
infini::ops::Tensor::Strides ComputeContiguousStrides(const std::vector<int64_t> &dims) {
infini::ops::Tensor::Strides strides(dims.size());
if (dims.empty()) {
return strides;
}
strides.back() = 1;
for (int i = static_cast<int>(dims.size()) - 2; i >= 0; --i) {
strides[i] = strides[i + 1] * static_cast<infini::ops::Tensor::Stride>(dims[i + 1]);
}
return strides;
}

infini::ops::Tensor::Shape ToShape(const std::vector<int64_t> &dims) {
infini::ops::Tensor::Shape shape(dims.size());
for (size_t i = 0; i < dims.size(); ++i) { shape[i] = static_cast<infini::ops::Tensor::Size>(dims[i]); }
return shape;
}

infini::ops::Tensor::Strides ToStrides(const std::vector<int64_t> &strides) {
infini::ops::Tensor::Strides ops_strides(strides.size());
for (size_t i = 0; i < strides.size(); ++i) {
ops_strides[i] = static_cast<infini::ops::Tensor::Stride>(strides[i]);
}
return ops_strides;
}
} // namespace

infini::ops::Tensor ToOpsTensor(const std::shared_ptr<Tensor> &tensor) {
const auto &dims = tensor->Dims();
return {tensor->DataPtr(), ToShape(dims), ToOpsDataType(tensor->Dtype()), ToOpsDevice(tensor->GetDevice()),
ComputeContiguousStrides(dims)};
}

infini::ops::Tensor ToOpsTensor(void *data, const std::vector<int64_t> &dims, DataType dtype, const Device &device) {
return {data, ToShape(dims), ToOpsDataType(dtype), ToOpsDevice(device), ComputeContiguousStrides(dims)};
}

infini::ops::Tensor ToOpsTensor(void *data, const std::vector<int64_t> &dims, DataType dtype, const Device &device,
const std::vector<int64_t> &strides) {
CHECK_EQ(dims.size(), strides.size());
return {data, ToShape(dims), ToOpsDataType(dtype), ToOpsDevice(device), ToStrides(strides)};
}

} // namespace infini_train::kernel_provider::infiniops
Loading
Loading