Skip to content
This repository was archived by the owner on Nov 17, 2023. It is now read-only.

Commit d7e2139

Browse files
junrushaozheng-da
authored andcommitted
[MXNET-1417][Performance] Caching Dynamic Shape Checking Result (#15262)
* fix * address comments
1 parent cab1dfa commit d7e2139

2 files changed

Lines changed: 7 additions & 0 deletions

File tree

src/imperative/cached_op.cc

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,7 @@ CachedOp::CachedOp(
100100
static const std::vector<const Op*> zero_ops{Op::Get("zeros_like"), Op::Get("_zeros")};
101101
static const auto _copy_op = Op::Get("_copy");
102102
config_.Init(flags);
103+
this->dynamic_shape_checked_ = false;
103104

104105
if (config_.static_shape) {
105106
CHECK(config_.static_alloc) << "static_alloc must be True when static_shape is True";
@@ -272,6 +273,11 @@ bool CachedOp::CheckDynamicShapeExists(const Context& default_ctx,
272273
bool erase_result) {
273274
using namespace nnvm;
274275
using namespace imperative;
276+
if (this->dynamic_shape_checked_) {
277+
return config_.is_dynamic;
278+
} else {
279+
this->dynamic_shape_checked_ = true;
280+
}
275281
CHECK_EQ(inputs.size(), num_inputs());
276282

277283
auto state_ptr = GetCachedOpState(default_ctx);

src/imperative/cached_op.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -196,6 +196,7 @@ class CachedOp {
196196
nnvm::Graph grad_graph_;
197197
nnvm::Graph full_graph_;
198198
bool inlining_;
199+
bool dynamic_shape_checked_;
199200
std::vector<nnvm::NodeEntry> ograd_entries_;
200201
std::vector<uint32_t> bwd_in_dep_, bwd_out_dep_, bwd_ograd_dep_;
201202
std::unordered_map<uint32_t, uint32_t> fwd_input_to_grad_output_;

0 commit comments

Comments
 (0)