Skip to content

Commit

Permalink
Bugfix for single_predict
Browse files Browse the repository at this point in the history
  • Loading branch information
josephjaspers committed Nov 17, 2019
1 parent 42a9491 commit 7d28487
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 63 deletions.
1 change: 1 addition & 0 deletions include/neural_networks/Network.h
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,7 @@ struct NeuralNetwork {

void move_training_data_to_single_predict(int batch_index) {
m_layer_chain.for_each([&](auto& layer) {
layer.zero_time_index();
layer.move_training_data_to_single_predict(batch_index);
});
}
Expand Down
89 changes: 26 additions & 63 deletions include/shape/Dim.h
Original file line number Diff line number Diff line change
Expand Up @@ -149,59 +149,16 @@ struct Dim {
public:


BCINLINE auto operator + (const Dim& other) const {
return op_impl(BC::oper::Add(), other);
#define BC_DIM_OP(op, functor)\
Dim operator op(const Dim& other) const { \
return this->op_impl(BC::oper::functor(), other); \
}

BCINLINE auto operator - (const Dim& other) const {
return op_impl(BC::oper::Sub(), other);
#define BC_DIM_INPLACE_OP(op, functor)\
Dim operator op##=(const Dim& other) { \
return this->inplace_op_impl(BC::oper::functor(), other); \
}

BCINLINE auto operator / (const Dim& other) const {
return op_impl(BC::oper::Div(), other);
}

BCINLINE auto operator * (const Dim& other) const {
return op_impl(BC::oper::Mul(), other);
}

BCINLINE auto operator > (const Dim& other) const {
return op_impl(BC::oper::Greater(), other);
}

BCINLINE auto operator < (const Dim& other) const {
return op_impl(BC::oper::Lesser(), other);
}

BCINLINE auto operator >= (const Dim& other) const {
return op_impl(BC::oper::Greater_Equal(), other);
}

BCINLINE auto operator <= (const Dim& other) const {
return op_impl(BC::oper::Lesser_Equal(), other);
}

BCINLINE auto equal(const Dim& other) const {
return op_impl(BC::oper::Equal(), other);
}

BCINLINE Dim& operator += (const Dim& other) {
return inplace_op_impl(BC::oper::Add(), other);
}

BCINLINE Dim& operator -= (const Dim& other) {
return inplace_opother_impl(BC::oper::Sub(), other);
}

BCINLINE Dim& operator /= (const Dim& other) {
return inplace_op_impl(BC::oper::Div(), other);
}

BCINLINE Dim& operator *= (const Dim& other) {
return inplace_op_impl(BC::oper::Mul(), other);
}


#define BC_DIM_INPLACE_SCALAR_OP(op, functor) \
friend Dim operator op##=(Dim &dim, const value_type& scalar) { \
return dim.inplace_scalar_op_impl(BC::oper::functor(), scalar); \
Expand All @@ -216,24 +173,30 @@ struct Dim {
return dim.scalar_op_impl(BC::oper::functor(), scalar); \
}

#define BC_DIM_SCALAR_OP_BOTH(op, functor) \
BC_DIM_INPLACE_SCALAR_OP(op, functor) \
BC_DIM_SCALAR_OP(op, functor)

#define BC_DIM_OP_FACTORY(op, functor) \
BC_DIM_OP(op, functor) \
BC_DIM_SCALAR_OP(op, functor)

BC_DIM_SCALAR_OP_BOTH(+, Add)
BC_DIM_SCALAR_OP_BOTH(-, Sub)
BC_DIM_SCALAR_OP_BOTH(/, Div)
BC_DIM_SCALAR_OP_BOTH(*, Mul)
BC_DIM_SCALAR_OP(<, Lesser)
BC_DIM_SCALAR_OP(<=, Lesser_Equal)
BC_DIM_SCALAR_OP(>, Greater)
BC_DIM_SCALAR_OP(>=, Greater_Equal)
#define BC_DIM_OP_BOTH(op, functor) \
BC_DIM_OP_FACTORY(op, functor) \
BC_DIM_INPLACE_OP(op, functor) \
BC_DIM_INPLACE_SCALAR_OP(op, functor)

BC_DIM_OP_BOTH(+, Add)
BC_DIM_OP_BOTH(-, Sub)
BC_DIM_OP_BOTH(/, Div)
BC_DIM_OP_BOTH(*, Mul)
BC_DIM_OP_FACTORY(<, Lesser)
BC_DIM_OP_FACTORY(<=, Lesser_Equal)
BC_DIM_OP_FACTORY(>, Greater)
BC_DIM_OP_FACTORY(>=, Greater_Equal)

#undef BC_DIM_SCALAR_OP
#undef BC_DIM_OP
#undef BC_DIM_INPLACE_OP
#undef BC_DIM_INPLACE_SCALAR_OP

#undef BC_DIM_SCALAR_OP
#undef BC_DIM_OP_FACTORY
#undef BC_DIM_OP_BOTH

BCINLINE
bool all(size_t start, size_t end) const {
Expand Down

0 comments on commit 7d28487

Please sign in to comment.