Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion include/infinicore/nn/module.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,14 @@ class Module {

virtual ~Module() = default;

const std::unordered_map<std::string, Parameter> &state_dict() const;
std::unordered_map<std::string, Parameter> state_dict() const;

void load_state_dict(const std::unordered_map<std::string, Tensor> &_state_dict);

void load_parameter(const std::string &name, const Tensor &param);

void load_parameter_no_sync(const std::string &name, const Tensor &param);

void load_parameter_(const std::string &name, const Tensor &param);

void load_parameter_from_blob(const std::string &name, const void *data);
Expand Down
2 changes: 2 additions & 0 deletions include/infinicore/nn/parameter.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ class Parameter : public Tensor {

void load(const Tensor &tensor);

void load_no_sync(const Tensor &tensor);

protected:
// Tensor parallel configs
Size tp_dim_; // dimension partitioned
Expand Down
25 changes: 20 additions & 5 deletions src/infinicore/nn/module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,9 @@
#include <stdexcept>

namespace infinicore::nn {
const std::unordered_map<std::string, Parameter> &Module::state_dict() const {
static std::unordered_map<std::string, Parameter> result;
result.clear();

std::unordered_map<std::string, Parameter> Module::state_dict() const {
std::unordered_map<std::string, Parameter> result;
collect_all_parameters(result, "");

return result;
}

Expand Down Expand Up @@ -36,6 +33,24 @@ void Module::load_parameter(const std::string &name, const Tensor &param) {
throw std::runtime_error("Parameter '" + name + "' not found in module.");
}

void Module::load_parameter_no_sync(const std::string &name, const Tensor &param) {
auto all_params = state_dict();
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里每load一个权重,就要调用一次state_dict()去遍历所有的模块。 这里能优化么

auto it = all_params.find(name);
if (it != all_params.end()) {
auto existing_param = it->second;
try {
existing_param.load_no_sync(param);
} catch (const std::exception &e) {
throw std::runtime_error("Error loading parameter '" + name + "'. \n" + e.what());
}
return;
}

spdlog::debug("load_parameter_no_sync: Parameter '{}' not found. Available: {} params",
name, parameters_.size());
throw std::runtime_error("Parameter '" + name + "' not found in module.");
}

void Module::load_parameter_(const std::string &name, const Tensor &param) {
// This function only handles direct parameters (no hierarchical traversal)
auto it = parameters_.find(name);
Expand Down
7 changes: 5 additions & 2 deletions src/infinicore/nn/parameter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,11 @@ void Parameter::load_blob(const void *data) {
}

void Parameter::load(const Tensor &tensor) {
load_no_sync(tensor);
infinicore::context::syncStream();
}

void Parameter::load_no_sync(const Tensor &tensor) {
if (impl_->dtype() != tensor->dtype()) {
throw std::runtime_error("Dtype mismatch when loading tensor into parameter. Weight: " + impl_->info() + ", Tensor: " + tensor->info() + ".");
}
Expand Down Expand Up @@ -97,7 +102,5 @@ void Parameter::load(const Tensor &tensor) {

impl_->copy_from(tensor->narrow({{tp_dim_, offset, shard_size}}));
}

infinicore::context::syncStream();
}
} // namespace infinicore::nn
Loading