diff --git a/include/infinicore/nn/module.hpp b/include/infinicore/nn/module.hpp index 32a788484..8f30998fb 100644 --- a/include/infinicore/nn/module.hpp +++ b/include/infinicore/nn/module.hpp @@ -15,12 +15,14 @@ class Module { virtual ~Module() = default; - const std::unordered_map &state_dict() const; + std::unordered_map state_dict() const; void load_state_dict(const std::unordered_map &_state_dict); void load_parameter(const std::string &name, const Tensor ¶m); + void load_parameter_no_sync(const std::string &name, const Tensor ¶m); + void load_parameter_(const std::string &name, const Tensor ¶m); void load_parameter_from_blob(const std::string &name, const void *data); diff --git a/include/infinicore/nn/parameter.hpp b/include/infinicore/nn/parameter.hpp index 6b956f574..1602f58d6 100644 --- a/include/infinicore/nn/parameter.hpp +++ b/include/infinicore/nn/parameter.hpp @@ -27,6 +27,8 @@ class Parameter : public Tensor { void load(const Tensor &tensor); + void load_no_sync(const Tensor &tensor); + protected: // Tensor parallel configs Size tp_dim_; // dimension partitioned diff --git a/src/infinicore/nn/module.cc b/src/infinicore/nn/module.cc index f055f0958..dae21ed92 100644 --- a/src/infinicore/nn/module.cc +++ b/src/infinicore/nn/module.cc @@ -3,12 +3,9 @@ #include namespace infinicore::nn { -const std::unordered_map &Module::state_dict() const { - static std::unordered_map result; - result.clear(); - +std::unordered_map Module::state_dict() const { + std::unordered_map result; collect_all_parameters(result, ""); - return result; } @@ -36,6 +33,24 @@ void Module::load_parameter(const std::string &name, const Tensor ¶m) { throw std::runtime_error("Parameter '" + name + "' not found in module."); } +void Module::load_parameter_no_sync(const std::string &name, const Tensor ¶m) { + auto all_params = state_dict(); + auto it = all_params.find(name); + if (it != all_params.end()) { + auto existing_param = it->second; + try { + existing_param.load_no_sync(param); + } catch (const std::exception &e) { + throw std::runtime_error("Error loading parameter '" + name + "'. \n" + e.what()); + } + return; + } + + spdlog::debug("load_parameter_no_sync: Parameter '{}' not found. Available: {} params", + name, parameters_.size()); + throw std::runtime_error("Parameter '" + name + "' not found in module."); +} + void Module::load_parameter_(const std::string &name, const Tensor ¶m) { // This function only handles direct parameters (no hierarchical traversal) auto it = parameters_.find(name); diff --git a/src/infinicore/nn/parameter.cc b/src/infinicore/nn/parameter.cc index 8b42c7810..1b8a82762 100644 --- a/src/infinicore/nn/parameter.cc +++ b/src/infinicore/nn/parameter.cc @@ -58,6 +58,11 @@ void Parameter::load_blob(const void *data) { } void Parameter::load(const Tensor &tensor) { + load_no_sync(tensor); + infinicore::context::syncStream(); +} + +void Parameter::load_no_sync(const Tensor &tensor) { if (impl_->dtype() != tensor->dtype()) { throw std::runtime_error("Dtype mismatch when loading tensor into parameter. Weight: " + impl_->info() + ", Tensor: " + tensor->info() + "."); } @@ -97,7 +102,5 @@ void Parameter::load(const Tensor &tensor) { impl_->copy_from(tensor->narrow({{tp_dim_, offset, shard_size}})); } - - infinicore::context::syncStream(); } } // namespace infinicore::nn