Skip to content
6 changes: 4 additions & 2 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -194,17 +194,19 @@ add_executable(gpt2
example/gpt2/main.cc
example/common/tiny_shakespeare_dataset.cc
example/common/utils.cc
example/gpt2/checkpoint_loader.cc
example/common/checkpoint_loader.cc
example/common/tokenizer.cc
example/gpt2/checkpoint_loader.cc
)
link_infini_train_exe(gpt2)

add_executable(llama3
example/llama3/main.cc
example/common/tiny_shakespeare_dataset.cc
example/common/utils.cc
example/llama3/checkpoint_loader.cc
example/common/checkpoint_loader.cc
example/common/tokenizer.cc
example/llama3/checkpoint_loader.cc
)
link_infini_train_exe(llama3)

Expand Down
119 changes: 119 additions & 0 deletions example/common/checkpoint_loader.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
#include "example/common/checkpoint_loader.h"
Copy link
Copy Markdown
Contributor

@chen2021673 chen2021673 May 8, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个文件太重了,有一千多行,checkpoint相关的基建和 llama / gpt 的 save / load 都混在一起了。要不要拆分一个example/common/checkpoint_utils.h/.cc,然后保留 gpt2 和 llama3 各自的特化调用?这个可以再讨论一下

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

还是按模型拆分吧,通用的公共函数放这里,gpt2/llama3 的特化部分放 example 下模型各自文件夹里。

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done


#include <cmath>
#include <cstdlib>
#include <filesystem>
#include <memory>
#include <string>
#include <vector>

#include "glog/logging.h"

#include "infini_train/include/nn/modules/transformer/transformer_config.h"
#include "infini_train/include/nn/parallel/global.h"
#include "infini_train/include/tensor.h"

using namespace infini_train;
namespace nn = infini_train::nn;

// TODO(jym): ckpt is a new checkpoint format; bin is the legacy format. Keeping both as an interim solution; plan to
// consolidate into one later.
ResumeFromCheckpointResult ResumeFromCheckpoint(const ResumeFromCheckpointArgs &args) {
ResumeFromCheckpointResult result;
if (args.resume_root.empty()) {
LOG(INFO) << "No checkpoint specified for resume. Starting training from scratch.";
return result;
}

int ddp_world_size = nn::parallel::global::GetDataParallelSize();
int tp_world_size = nn::parallel::global::GetTensorParallelSize();
int sp_world_size = nn::parallel::global::GetSequenceParallelEnabled() ? tp_world_size : 1;
int pp_world_size = nn::parallel::global::GetPipelineParallelSize();

std::filesystem::path resume_dir = args.resume_root;
if (args.rank.IsParallel()) {
const auto rank_dir = resume_dir / std::format("rank_{:06d}", args.rank.GlobalRank());
if (std::filesystem::exists(rank_dir)) {
resume_dir = rank_dir;
}
}

Checkpoint::Load(resume_dir, *args.model, args.optimizer.get(), args.state);

result.global_step = static_cast<int>(args.state.global_step);

CHECK_EQ(args.state.n_layer, args.model_config.n_layer)
<< "n_layer mismatch: ckpt=" << args.state.n_layer << ", config=" << args.model_config.n_layer;
CHECK_EQ(args.state.n_head, args.model_config.n_head)
<< "n_head mismatch: ckpt=" << args.state.n_head << ", config=" << args.model_config.n_head;
CHECK_EQ(args.state.n_kv_head, args.model_config.n_kv_head)
<< "n_kv_head mismatch: ckpt=" << args.state.n_kv_head << ", config=" << args.model_config.n_kv_head;
CHECK_EQ(args.state.n_embd, args.model_config.n_embd)
<< "n_embd mismatch: ckpt=" << args.state.n_embd << ", config=" << args.model_config.n_embd;
CHECK_EQ(args.state.vocab_size, args.model_config.vocab_size)
<< "vocab_size mismatch: ckpt=" << args.state.vocab_size << ", config=" << args.model_config.vocab_size;

CHECK_EQ(args.state.ddp_size, ddp_world_size) << "DDP size mismatch: checkpoint has DDP=" << args.state.ddp_size
<< ", but current run has DDP=" << ddp_world_size;
CHECK_EQ(args.state.tp_size, tp_world_size)
<< "TP size mismatch: checkpoint has TP=" << args.state.tp_size << ", but current run has TP=" << tp_world_size;
CHECK_EQ(args.state.sp_size, sp_world_size)
<< "SP size mismatch: checkpoint has SP=" << args.state.sp_size << ", but current run has SP=" << sp_world_size;
CHECK_EQ(args.state.pp_size, pp_world_size)
<< "PP size mismatch: checkpoint has PP=" << args.state.pp_size << ", but current run has PP=" << pp_world_size;

result.consumed_batches = static_cast<size_t>(std::max<int64_t>(args.state.consumed_batches, 0));
if (args.rank.IsMainRank()) {
LOG(INFO) << std::format("Resume training from step {}, last_lr {:.3e}, consumed_batches {}",
args.state.global_step, args.state.last_lr, args.state.consumed_batches);
}

return result;
}

void SaveCheckpoint(const SaveCheckpointArgs &args) {
const auto ckpt_start = std::chrono::high_resolution_clock::now();

TrainerState state;
state.global_step = args.global_step;
state.consumed_batches = static_cast<int64_t>(args.consumed_batches);
state.last_lr = args.last_lr;
state.n_layer = args.n_layer;
state.n_head = args.n_head;
state.n_kv_head = args.n_kv_head;
state.n_embd = args.n_embd;
state.vocab_size = args.vocab_size;
state.ddp_size = args.ddp_size;
state.tp_size = args.tp_size;
state.sp_size = args.sp_size;
state.pp_size = args.pp_size;

Checkpoint::Save(args.save_dir, args.model, &args.optimizer, state);

const auto ckpt_end = std::chrono::high_resolution_clock::now();
const double ckpt_ms = std::chrono::duration<double, std::milli>(ckpt_end - ckpt_start).count();

if (!args.rank.IsMainRank()) {
return;
}

LOG(INFO) << std::format("Checkpoint saved at: {} ({:.2f} ms)", args.save_dir.string(), ckpt_ms);

if (!args.prune_step_checkpoints) {
return;
}

std::vector<std::filesystem::path> ckpts;
if (std::filesystem::exists(args.checkpoint_root_dir)) {
for (const auto &entry : std::filesystem::directory_iterator(args.checkpoint_root_dir)) {
if (entry.is_directory() && entry.path().filename().string().starts_with("checkpoint_step_")) {
ckpts.push_back(entry.path());
}
}
std::sort(ckpts.begin(), ckpts.end());
while (ckpts.size() > args.max_checkpoint_keep) {
std::filesystem::remove_all(ckpts.front());
ckpts.erase(ckpts.begin());
}
}
}
60 changes: 60 additions & 0 deletions example/common/checkpoint_loader.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
#pragma once

#include <cstdint>
#include <cstring>
#include <filesystem>

#include "infini_train/include/checkpoint.h"
#include "infini_train/include/dataloader.h"
#include "infini_train/include/nn/modules/module.h"
#include "infini_train/include/nn/parallel/rank.h"
#include "infini_train/include/optimizer.h"

using namespace infini_train;
namespace nn = infini_train::nn;

namespace infini_train::nn {
class TransformerConfig;
}

struct ResumeFromCheckpointArgs {
std::filesystem::path resume_root;
const nn::parallel::Rank &rank;
std::shared_ptr<nn::Module> model;
std::shared_ptr<Optimizer> optimizer;
DistributedDataLoader &train_loader;
const nn::TransformerConfig &model_config;
TrainerState &state;
};

struct ResumeFromCheckpointResult {
int global_step = 0;
size_t consumed_batches = 0;
};

struct SaveCheckpointArgs {
std::filesystem::path save_dir;
int64_t global_step = 0;
size_t consumed_batches = 0;
double last_lr = 0.0;
int64_t n_layer = 0;
int64_t n_head = 0;
int64_t n_kv_head = 0;
int64_t n_embd = 0;
int64_t vocab_size = 0;
int ddp_size = 1;
int tp_size = 1;
int sp_size = 1;
int pp_size = 1;
bool no_save_optim = false;
bool prune_step_checkpoints = false;
std::filesystem::path checkpoint_root_dir;
size_t max_checkpoint_keep = 0;
const nn::parallel::Rank &rank;
const nn::Module &model;
const Optimizer &optimizer;
};

ResumeFromCheckpointResult ResumeFromCheckpoint(const ResumeFromCheckpointArgs &args);

void SaveCheckpoint(const SaveCheckpointArgs &args);
41 changes: 16 additions & 25 deletions example/gpt2/checkpoint_loader.cc
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,12 @@
#include <filesystem>
#include <fstream>
#include <memory>
#include <random>
#include <string>
#include <tuple>
#include <vector>

#include "glog/logging.h"

#include "example/common/utils.h"
#include "example/gpt2/config.h"
#include "infini_train/include/nn/modules/normalization.h"
#include "infini_train/include/nn/modules/sparse.h"
#include "infini_train/include/nn/modules/transformer/causal_self_attention.h"
Expand All @@ -24,39 +21,34 @@
#include "infini_train/include/nn/parallel/tensor_parallel.h"
#include "infini_train/include/tensor.h"

#include "example/common/utils.h"
#include "example/gpt2/config.h"

using namespace infini_train;
namespace nn = infini_train::nn;

namespace {
constexpr int kRandomSeed = 42;
constexpr int32_t kGPT2Magic = 20240326;
constexpr int32_t kGPT2FP32Version = 3;
constexpr int32_t kGPT2BF16Version = 5;

// TODO(dcj): make this rng generator compatible with torch later
static std::mt19937 gen{kRandomSeed};
} // namespace

namespace {
constexpr int32_t kHeaderMagic = 20240326;
constexpr int32_t kHeaderFP32Version = 3;
constexpr int32_t kHeaderBF16Version = 5;

std::tuple<int32_t, infini_train::DataType> DetermineAndCheckVersion(const std::vector<uint8_t> &header,
size_t offset) {
std::tuple<int32_t, DataType> DetermineAndCheckVersion(const std::vector<uint8_t> &header, size_t offset) {
const auto version = BytesToType<uint32_t>(header, offset);
switch (version) {
case kHeaderBF16Version:
return {version, infini_train::DataType::kBFLOAT16};
case kHeaderFP32Version:
return {version, infini_train::DataType::kFLOAT32};
case kGPT2BF16Version:
return {version, DataType::kBFLOAT16};
case kGPT2FP32Version:
return {version, DataType::kFLOAT32};
default:
LOG(FATAL) << "Unsupported version: " << version << " at " << __FILE__ << ":" << __LINE__;
return {}; // Unreachable, but keeps compiler happy
}
}
} // namespace

namespace gpt2 {

std::shared_ptr<nn::TransformerModel> LoadFromLLMC(const std::string &filepath) {
// TODO(jym): ckpt is a new checkpoint format; bin is the legacy format. Keeping both as an interim solution; plan to
// consolidate into one later.
std::shared_ptr<nn::TransformerModel> gpt2::LoadFromLLMC(const std::string &filepath) {
if (!std::filesystem::exists(filepath)) {
LOG(FATAL) << "File not found: " << filepath;
}
Expand All @@ -65,9 +57,9 @@ std::shared_ptr<nn::TransformerModel> LoadFromLLMC(const std::string &filepath)
const auto header = ReadSeveralBytesFromIfstream(256 * sizeof(int32_t), &ifs);

const auto magic = BytesToType<uint32_t>(header, 0);
CHECK_EQ(magic, kHeaderMagic);
CHECK_EQ(magic, kGPT2Magic);
auto [version, dtype] = DetermineAndCheckVersion(header, 4);
CHECK_EQ(version, kHeaderFP32Version);
CHECK_EQ(version, kGPT2FP32Version);

auto tp_size = nn::parallel::global::GetTensorParallelSize();

Expand Down Expand Up @@ -428,4 +420,3 @@ std::shared_ptr<nn::TransformerModel> LoadFromLLMC(const std::string &filepath)

return local_gpt2;
}
} // namespace gpt2
3 changes: 2 additions & 1 deletion example/gpt2/checkpoint_loader.h
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
#pragma once

#include <cstring>
#include <memory>
#include <string>

namespace infini_train::nn {
class TransformerModel;
} // namespace infini_train::nn
}

namespace gpt2 {
std::shared_ptr<infini_train::nn::TransformerModel> LoadFromLLMC(const std::string &filepath);
Expand Down
1 change: 0 additions & 1 deletion example/gpt2/config.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,5 +38,4 @@ inline void SanitizeGPT2Config(const nn::TransformerConfig &c) {
CHECK(c.activation_type == nn::MLPType::kGELU) << "GPT-2 requires GELU activation";
CHECK(c.norm_type == nn::NormType::kLayerNorm) << "GPT-2 requires LayerNorm";
}

} // namespace gpt2
Loading
Loading