Skip to content

Commit

Permalink
Refactor forward declarations
Browse files Browse the repository at this point in the history
  • Loading branch information
asadchev committed Apr 30, 2021
1 parent 626eb79 commit 7aa5114
Show file tree
Hide file tree
Showing 14 changed files with 255 additions and 232 deletions.
3 changes: 2 additions & 1 deletion src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
set(TILEDARRAY_HEADER_FILES
tiledarray.h
tiledarray_fwd.h
TiledArray/fwd.h
TiledArray/config.h
TiledArray/array_impl.h
TiledArray/bitset.h
Expand Down Expand Up @@ -98,6 +99,7 @@ TiledArray/expressions/contraction_helpers.h
TiledArray/expressions/expr.h
TiledArray/expressions/expr_engine.h
TiledArray/expressions/expr_trace.h
TiledArray/expressions/fwd.h
TiledArray/expressions/leaf_engine.h
TiledArray/expressions/mult_engine.h
TiledArray/expressions/mult_expr.h
Expand Down Expand Up @@ -177,7 +179,6 @@ TiledArray/util/random.h
TiledArray/util/singleton.h
TiledArray/util/time.h
TiledArray/util/vector.h

)

if(CUDA_FOUND)
Expand Down
70 changes: 32 additions & 38 deletions src/TiledArray/dist_array.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,7 @@
#ifndef TILEDARRAY_ARRAY_H__INCLUDED
#define TILEDARRAY_ARRAY_H__INCLUDED

#include <cstdlib>

#include <madness/world/parallel_archive.h>
#include "TiledArray/expressions/fwd.h"

#include "TiledArray/array_impl.h"
#include "TiledArray/conversions/clone.h"
Expand All @@ -35,15 +33,14 @@
#include "TiledArray/util/initializer_list.h"
#include "TiledArray/util/random.h"

#include <cstdlib>
#include <madness/world/parallel_archive.h>

namespace TiledArray {

// Forward declarations
template <typename, typename>
class Tensor;
namespace expressions {
template <typename, bool>
class TsrExpr;
} // namespace expressions

/// A (multidimensional) tiled array

Expand All @@ -56,7 +53,6 @@ template <typename Tile = Tensor<double, Eigen::aligned_allocator<double>>,
typename Policy = DensePolicy>
class DistArray : public madness::archive::ParallelSerializableObject {
public:
typedef DistArray<Tile, Policy> DistArray_; ///< This object's type
typedef TiledArray::detail::ArrayImpl<Tile, Policy>
impl_type; ///< The type of the PIMPL
typedef typename impl_type::policy_type policy_type; ///< Policy type
Expand Down Expand Up @@ -107,7 +103,7 @@ class DistArray : public madness::archive::ParallelSerializableObject {
/// used elsewhere too.
///
template <typename OtherTile>
using is_my_type = std::is_same<DistArray_, DistArray<OtherTile, Policy>>;
using is_my_type = std::is_same<DistArray, DistArray<OtherTile, Policy>>;

template <typename OtherTile>
using enable_if_not_my_type =
Expand Down Expand Up @@ -149,7 +145,7 @@ class DistArray : public madness::archive::ParallelSerializableObject {
try {
world.gop.lazy_sync(id, [pimpl]() {
delete pimpl;
DistArray_::cleanup_counter_--;
DistArray::cleanup_counter_--;
});
} catch (madness::MadnessException& e) {
fprintf(stderr,
Expand Down Expand Up @@ -245,7 +241,7 @@ class DistArray : public madness::archive::ParallelSerializableObject {

/// This is a shallow copy, that is no data is copied.
/// \param other The array to be copied
DistArray(const DistArray_& other) : pimpl_(other.pimpl_) {}
DistArray(const DistArray& other) : pimpl_(other.pimpl_) {}

/// Dense array constructor

Expand Down Expand Up @@ -301,27 +297,27 @@ class DistArray : public madness::archive::ParallelSerializableObject {
/// raised \p world and \p il are unchanged.
template <typename T>
DistArray(World& world, detail::vector_il<T> il)
: DistArray(array_from_il<DistArray_>(world, il)) {}
: DistArray(array_from_il<DistArray>(world, il)) {}

template <typename T>
DistArray(World& world, detail::matrix_il<T> il)
: DistArray(array_from_il<DistArray_>(world, il)) {}
: DistArray(array_from_il<DistArray>(world, il)) {}

template <typename T>
DistArray(World& world, detail::tensor3_il<T> il)
: DistArray(array_from_il<DistArray_>(world, il)) {}
: DistArray(array_from_il<DistArray>(world, il)) {}

template <typename T>
DistArray(World& world, detail::tensor4_il<T> il)
: DistArray(array_from_il<DistArray_>(world, il)) {}
: DistArray(array_from_il<DistArray>(world, il)) {}

template <typename T>
DistArray(World& world, detail::tensor5_il<T> il)
: DistArray(array_from_il<DistArray_>(world, il)) {}
: DistArray(array_from_il<DistArray>(world, il)) {}

template <typename T>
DistArray(World& world, detail::tensor6_il<T> il)
: DistArray(array_from_il<DistArray_>(world, il)) {}
: DistArray(array_from_il<DistArray>(world, il)) {}
///@}

/// \name Tiling initializer list constructors
Expand Down Expand Up @@ -352,27 +348,27 @@ class DistArray : public madness::archive::ParallelSerializableObject {
/// raised \p world and \p il are unchanged.
template <typename T>
DistArray(World& world, const trange_type& trange, detail::vector_il<T> il)
: DistArray(array_from_il<DistArray_>(world, trange, il)) {}
: DistArray(array_from_il<DistArray>(world, trange, il)) {}

template <typename T>
DistArray(World& world, const trange_type& trange, detail::matrix_il<T> il)
: DistArray(array_from_il<DistArray_>(world, trange, il)) {}
: DistArray(array_from_il<DistArray>(world, trange, il)) {}

template <typename T>
DistArray(World& world, const trange_type& trange, detail::tensor3_il<T> il)
: DistArray(array_from_il<DistArray_>(world, trange, il)) {}
: DistArray(array_from_il<DistArray>(world, trange, il)) {}

template <typename T>
DistArray(World& world, const trange_type& trange, detail::tensor4_il<T> il)
: DistArray(array_from_il<DistArray_>(world, trange, il)) {}
: DistArray(array_from_il<DistArray>(world, trange, il)) {}

template <typename T>
DistArray(World& world, const trange_type& trange, detail::tensor5_il<T> il)
: DistArray(array_from_il<DistArray_>(world, trange, il)) {}
: DistArray(array_from_il<DistArray>(world, trange, il)) {}

template <typename T>
DistArray(World& world, const trange_type& trange, detail::tensor6_il<T> il)
: DistArray(array_from_il<DistArray_>(world, trange, il)) {}
: DistArray(array_from_il<DistArray>(world, trange, il)) {}
/// @}

/// converting copy constructor
Expand Down Expand Up @@ -410,7 +406,7 @@ class DistArray : public madness::archive::ParallelSerializableObject {
/// Create a deep copy of this array

/// \return An array that is equal to this array
DistArray_ clone() const { return TiledArray::clone(*this); }
DistArray clone() const { return TiledArray::clone(*this); }

/// Accessor for the (shared_ptr to) implementation object

Expand Down Expand Up @@ -462,7 +458,7 @@ class DistArray : public madness::archive::ParallelSerializableObject {

/// This is a shallow copy, that is no data is copied.
/// \param other The array to be copied
DistArray_& operator=(const DistArray_& other) {
DistArray& operator=(const DistArray& other) {
pimpl_ = other.pimpl_;

return *this;
Expand Down Expand Up @@ -977,10 +973,9 @@ class DistArray : public madness::archive::ParallelSerializableObject {
/// \return A const tensor expression object
/// \note size and contents of \p vars are validated using
/// DistArray::check_str_index()
TiledArray::expressions::TsrExpr<const DistArray_, true> operator()(
const std::string& vars) const {
auto operator()(const std::string& vars) const {
check_str_index(vars);
return TiledArray::expressions::TsrExpr<const DistArray_, true>(*this,
return TiledArray::expressions::TsrExpr<const DistArray>(*this,
vars);
}

Expand All @@ -990,10 +985,9 @@ class DistArray : public madness::archive::ParallelSerializableObject {
/// \return A non-const tensor expression object
/// \note size and contents of \p vars are validated using
/// DistArray::check_str_index()
TiledArray::expressions::TsrExpr<DistArray_, true> operator()(
const std::string& vars) {
auto operator()(const std::string& vars) {
check_str_index(vars);
return TiledArray::expressions::TsrExpr<DistArray_, true>(*this, vars);
return TiledArray::expressions::TsrExpr<DistArray>(*this, vars);
}

/// \deprecated use DistArray::world()
Expand Down Expand Up @@ -1161,7 +1155,7 @@ class DistArray : public madness::archive::ParallelSerializableObject {

/// \param other The array to be swapped with this array.
/// \throw None no throw guarantee.
void swap(DistArray_& other) { std::swap(pimpl_, other.pimpl_); }
void swap(DistArray& other) { std::swap(pimpl_, other.pimpl_); }

/// Convert a distributed array into a replicated array
/// \throw TiledArray::Exception if the PIMPL is not initialized. Strong throw
Expand All @@ -1170,19 +1164,19 @@ class DistArray : public madness::archive::ParallelSerializableObject {
if ((!impl_ref().pmap()->is_replicated()) && (world().size() > 1)) {
// Construct a replicated array
auto pmap = std::make_shared<detail::ReplicatedPmap>(world(), size());
DistArray_ result = DistArray_(world(), trange(), shape(), pmap);
DistArray result = DistArray(world(), trange(), shape(), pmap);

// Create the replicator object that will do an all-to-all broadcast of
// the local tile data.
auto replicator =
std::make_shared<detail::Replicator<DistArray_>>(*this, result);
std::make_shared<detail::Replicator<DistArray>>(*this, result);

// Put the replicator pointer in the deferred cleanup object so it will
// be deleted at the end of the next fence.
TA_ASSERT(replicator.unique()); // Required for deferred_cleanup
madness::detail::deferred_cleanup(world(), replicator);

DistArray_::operator=(result);
DistArray::operator=(result);
}
}

Expand Down Expand Up @@ -1264,7 +1258,7 @@ class DistArray : public madness::archive::ParallelSerializableObject {

// use default pmap, ensure it's the same pmap used to serialize
auto volume = trange.tiles_range().volume();
auto pmap = detail::policy_t<DistArray_>::default_pmap(world, volume);
auto pmap = detail::policy_t<DistArray>::default_pmap(world, volume);
size_t pmap_hash_code = 0;
ar& pmap_hash_code;
if (pmap_hash_code != typeid(pmap.get()).hash_code())
Expand Down Expand Up @@ -1337,7 +1331,7 @@ class DistArray : public madness::archive::ParallelSerializableObject {

// use default pmap
auto volume = trange.tiles_range().volume();
auto pmap = detail::policy_t<DistArray_>::default_pmap(world, volume);
auto pmap = detail::policy_t<DistArray>::default_pmap(world, volume);
pimpl_.reset(
new impl_type(world, std::move(trange), std::move(shape), pmap));

Expand Down Expand Up @@ -1371,7 +1365,7 @@ class DistArray : public madness::archive::ParallelSerializableObject {

// use default pmap
auto volume = trange.tiles_range().volume();
auto pmap = detail::policy_t<DistArray_>::default_pmap(world, volume);
auto pmap = detail::policy_t<DistArray>::default_pmap(world, volume);
pimpl_.reset(
new impl_type(world, std::move(trange), std::move(shape), pmap));
}
Expand Down
11 changes: 1 addition & 10 deletions src/TiledArray/expressions/blk_tsr_engine.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
#ifndef TILEDARRAY_EXPRESSIONS_BLK_TSR_ENGINE_H__INCLUDED
#define TILEDARRAY_EXPRESSIONS_BLK_TSR_ENGINE_H__INCLUDED

#include <TiledArray/expressions/fwd.h>
#include <TiledArray/expressions/leaf_engine.h>
#include <TiledArray/tile_op/shift.h>

Expand All @@ -37,16 +38,6 @@ class DistArray;

namespace expressions {

// Forward declaration
template <typename, bool>
class BlkTsrExpr;
template <typename, typename>
class ScalBlkTsrExpr;
template <typename, typename, bool>
class BlkTsrEngine;
template <typename, typename, typename>
class ScalBlkTsrEngine;

template <typename Tile, typename Policy, typename Result, bool Alias>
struct EngineTrait<BlkTsrEngine<DistArray<Tile, Policy>, Result, Alias>> {
// Argument typedefs
Expand Down
8 changes: 0 additions & 8 deletions src/TiledArray/expressions/blk_tsr_expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,14 +35,6 @@
namespace TiledArray {
namespace expressions {

// Forward declaration
template <typename, bool>
class TsrExpr;
template <typename, bool>
class BlkTsrExpr;
template <typename, typename>
class ScalBlkTsrExpr;

template <typename Array>
using ConjBlkTsrExpr =
ScalBlkTsrExpr<Array, TiledArray::detail::ComplexConjugate<void>>;
Expand Down
19 changes: 5 additions & 14 deletions src/TiledArray/expressions/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,9 @@
#ifndef TILEDARRAY_EXPRESSIONS_EXPR_H__INCLUDED
#define TILEDARRAY_EXPRESSIONS_EXPR_H__INCLUDED

#include "TiledArray/expressions/fwd.h"


#include "../reduce_task.h"
#include "../tile_interface/cast.h"
#include "../tile_interface/scale.h"
Expand All @@ -45,18 +48,7 @@

#include <TiledArray/tensor/type_traits.h>

namespace TiledArray {
namespace expressions {

// Forward declaration
template <typename>
struct ExprTrait;
template <typename, bool>
class TsrExpr;
template <typename, bool>
class BlkTsrExpr;
template <typename>
struct is_aliased;
namespace TiledArray::expressions {

template <typename Engine>
struct EngineParamOverride {
Expand Down Expand Up @@ -881,7 +873,6 @@ class Expr {

}; // class Expr

} // namespace expressions
} // namespace TiledArray
}

#endif // TILEDARRAY_EXPRESSIONS_EXPR_H__INCLUDED
6 changes: 1 addition & 5 deletions src/TiledArray/expressions/expr_trace.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,17 +26,13 @@
#ifndef TILEDARRAY_EXPR_TRACE_H__INCLUDED
#define TILEDARRAY_EXPR_TRACE_H__INCLUDED

#include <TiledArray/expressions/fwd.h>
#include <TiledArray/expressions/index_list.h>
#include <iostream>

namespace TiledArray {
namespace expressions {

template <typename>
class Expr;
template <typename, bool>
class TsrExpr;

/// Expression output stream
class ExprOStream {
std::ostream& os_; ///< output stream
Expand Down
Loading

0 comments on commit 7aa5114

Please sign in to comment.