Skip to content

Commit

Permalink
TypeGraph: Introduce NodeTracker for efficient cycle detection
Browse files Browse the repository at this point in the history
Added to Flattener and TypeIdentifier passes for now as a
proof-of-concept. Other passes can come later.
  • Loading branch information
ajor committed Jul 11, 2023
1 parent f676112 commit bc5ba69
Show file tree
Hide file tree
Showing 11 changed files with 259 additions and 44 deletions.
17 changes: 5 additions & 12 deletions oi/type_graph/Flattener.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,26 +22,19 @@ namespace type_graph {

Pass Flattener::createPass() {
auto fn = [](TypeGraph& typeGraph) {
Flattener flattener;
flattener.flatten(typeGraph.rootTypes());
// TODO should flatten just operate on a single type and we do the looping
// here?
Flattener flattener{typeGraph.resetTracker()};
for (auto& type : typeGraph.rootTypes()) {
flattener.accept(type);
}
};

return Pass("Flattener", fn);
}

void Flattener::flatten(std::vector<std::reference_wrapper<Type>>& types) {
for (auto& type : types) {
accept(type);
}
}

void Flattener::accept(Type& type) {
if (visited_.count(&type) != 0)
if (tracker_.visit(type))
return;

visited_.insert(&type);
type.accept(*this);
}

Expand Down
11 changes: 6 additions & 5 deletions oi/type_graph/Flattener.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,9 @@
#pragma once

#include <string>
#include <unordered_set>
#include <vector>

#include "NodeTracker.h"
#include "PassManager.h"
#include "Types.h"
#include "Visitor.h"
Expand All @@ -28,14 +28,15 @@ namespace type_graph {
/*
* Flattener
*
* Flattens classes by removing parents and adding their members directly into
* derived classes.
* Flattens classes by removing parents and adding their attributes directly
* into derived classes.
*/
class Flattener : public RecursiveVisitor {
public:
static Pass createPass();

void flatten(std::vector<std::reference_wrapper<Type>>& types);
Flattener(NodeTracker& tracker) : tracker_(tracker) {
}

using RecursiveVisitor::accept;

Expand All @@ -46,7 +47,7 @@ class Flattener : public RecursiveVisitor {
static const inline std::string ParentPrefix = "__oi_parent";

private:
std::unordered_set<Type*> visited_;
NodeTracker& tracker_;
std::vector<Member> flattened_members_;
std::vector<uint64_t> offset_stack_;
};
Expand Down
70 changes: 70 additions & 0 deletions oi/type_graph/NodeTracker.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once

#include <vector>

#include "Types.h"

namespace type_graph {

/*
* NodeTracker
*
* Helper class for visitors. Efficiently tracks whether or not a graph node has
* been seen before, to avoid infinite looping on cycles.
*/
class NodeTracker {
public:
NodeTracker() = default;
NodeTracker(size_t size) : visited_(size) {
}

/*
* visit
*
* Marks a given node as visited.
* Returns true if this node has already been visited, false otherwise.
*/
bool visit(const Type& type) {
auto id = type.id();
if (id < 0)
return false;
if (visited_.size() <= static_cast<size_t>(id))
visited_.resize(id + 1);
bool result = visited_[id];
visited_[id] = true;
return result;
}

/*
* reset
*
* Clears the contents of this NodeTracker and marks every node as unvisited.
*/
void reset() {
std::fill(visited_.begin(), visited_.end(), false);
}

void resize(size_t size) {
visited_.resize(size);
}

private:
std::vector<bool> visited_;
};

} // namespace type_graph
6 changes: 6 additions & 0 deletions oi/type_graph/TypeGraph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,12 @@

namespace type_graph {

NodeTracker& TypeGraph::resetTracker() noexcept {
tracker_.reset();
tracker_.resize(size());
return tracker_;
}

template <>
Primitive& TypeGraph::makeType<Primitive>(Primitive::Kind kind) {
switch (kind) {
Expand Down
4 changes: 4 additions & 0 deletions oi/type_graph/TypeGraph.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#include <memory>
#include <vector>

#include "NodeTracker.h"
#include "Types.h"

namespace type_graph {
Expand Down Expand Up @@ -47,6 +48,8 @@ class TypeGraph {
rootTypes_.push_back(type);
}

NodeTracker& resetTracker() noexcept;

// Override of the generic makeType function that returns singleton Primitive
// objects
template <typename T>
Expand Down Expand Up @@ -83,6 +86,7 @@ class TypeGraph {
std::vector<std::reference_wrapper<Type>> rootTypes_;
// Store all type objects in vectors for ownership. Order is not significant.
std::vector<std::unique_ptr<Type>> types_;
NodeTracker tracker_;
NodeId next_id_ = 0;
};

Expand Down
6 changes: 3 additions & 3 deletions oi/type_graph/TypeIdentifier.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@ namespace type_graph {
Pass TypeIdentifier::createPass(
const std::vector<ContainerInfo>& passThroughTypes) {
auto fn = [&passThroughTypes](TypeGraph& typeGraph) {
TypeIdentifier typeId{typeGraph, passThroughTypes};
TypeIdentifier typeId{typeGraph.resetTracker(), typeGraph,
passThroughTypes};
for (auto& type : typeGraph.rootTypes()) {
typeId.accept(type);
}
Expand All @@ -48,10 +49,9 @@ bool TypeIdentifier::isAllocator(Type& t) {
}

void TypeIdentifier::accept(Type& type) {
if (visited_.count(&type) != 0)
if (tracker_.visit(type))
return;

visited_.insert(&type);
type.accept(*this);
}

Expand Down
12 changes: 7 additions & 5 deletions oi/type_graph/TypeIdentifier.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,9 @@
*/
#pragma once

#include <array>
#include <unordered_set>
#include <vector>

#include "NodeTracker.h"
#include "PassManager.h"
#include "Types.h"
#include "Visitor.h"
Expand All @@ -38,9 +37,12 @@ class TypeIdentifier : public RecursiveVisitor {
static Pass createPass(const std::vector<ContainerInfo>& passThroughTypes);
static bool isAllocator(Type& t);

TypeIdentifier(TypeGraph& typeGraph,
TypeIdentifier(NodeTracker& tracker,
TypeGraph& typeGraph,
const std::vector<ContainerInfo>& passThroughTypes)
: typeGraph_(typeGraph), passThroughTypes_(passThroughTypes) {
: tracker_(tracker),
typeGraph_(typeGraph),
passThroughTypes_(passThroughTypes) {
}

using RecursiveVisitor::accept;
Expand All @@ -49,7 +51,7 @@ class TypeIdentifier : public RecursiveVisitor {
void visit(Container& c) override;

private:
std::unordered_set<Type*> visited_;
NodeTracker& tracker_;
TypeGraph& typeGraph_;
const std::vector<ContainerInfo>& passThroughTypes_;
};
Expand Down
Loading

0 comments on commit bc5ba69

Please sign in to comment.