Skip to content

Commit

Permalink
ConstantType::StrictNoConst introduced
Browse files Browse the repository at this point in the history
  • Loading branch information
v-Golubev committed Dec 21, 2023
1 parent cd22e4a commit d835ea7
Show file tree
Hide file tree
Showing 4 changed files with 12 additions and 9 deletions.
12 changes: 7 additions & 5 deletions src/plugins/intel_cpu/src/node.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -988,16 +988,18 @@ bool Node::isConstant() {
}

void Node::updateConstantType() {
bool isConst = true;
for (const auto& parentEdge : getParentEdges()) {
isConst &= parentEdge.lock()->getParent()->isConstant();
if (constant != ConstantType::StrictNoConst) {
bool isConst = true;
for (const auto& parentEdge : getParentEdges()) {
isConst &= parentEdge.lock()->getParent()->isConstant();
}
constant = isConst ? ConstantType::Const : ConstantType::NoConst;
}
constant = isConst ? ConstantType::Const : ConstantType::NoConst;

for (const auto& childEdge : getChildEdges()) {
const auto childNode = childEdge.lock()->getChild();
const auto childConstType = childNode->getConstantType();
if (childConstType != ConstantType::Unknown && childConstType != constant) {
if (!one_of(childConstType, ConstantType::Unknown, ConstantType::StrictNoConst, constant)) {
childNode->updateConstantType();
}
}
Expand Down
3 changes: 2 additions & 1 deletion src/plugins/intel_cpu/src/node.h
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,8 @@ class Node {
enum class ConstantType {
Unknown,
Const,
NoConst
NoConst,
StrictNoConst,
};
ConstantType getConstantType() const;
void updateConstantType();
Expand Down
2 changes: 1 addition & 1 deletion src/plugins/intel_cpu/src/nodes/multinomial.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ Multinomial::Multinomial(const std::shared_ptr<ov::Node>& op, const GraphContext
m_num_samples_precision = ov::element::i32;
m_output_precision = multinomial_op->get_convert_type();

constant = ConstantType::NoConst;
constant = ConstantType::StrictNoConst;

m_const_batch = op->get_input_partial_shape(PROBS_PORT)[0].is_static();
m_const_inputs[PROBS_PORT] = is_type<op::v0::Constant>(op->get_input_node_ptr(PROBS_PORT));
Expand Down
4 changes: 2 additions & 2 deletions src/plugins/intel_cpu/src/nodes/random_uniform.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,8 @@ RandomUniform::RandomUniform(const std::shared_ptr<ov::Node>& op, const GraphCon

// RandomUniform should generate new sequence each run even if all inputs are constants. So that method Node::IsConstant()
// doesn't return 'True' for RandomUniform with all constant inputs and the node generates new values for each inference,
// we set 'NoConst' value for 'ConstantType' in ctor.
constant = ConstantType::NoConst;
// we set 'StrictNoConst' value for 'ConstantType' in ctor.
constant = ConstantType::StrictNoConst;

auto rnd_op = as_type_ptr<op::v8::RandomUniform>(op);
m_global_seed = rnd_op->get_global_seed();
Expand Down

0 comments on commit d835ea7

Please sign in to comment.