forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest_class_import.cpp
158 lines (136 loc) · 4.78 KB
/
test_class_import.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
#include <gtest/gtest.h>
#include <ATen/core/qualified_name.h>
#include <test/cpp/jit/test_utils.h>
#include <torch/csrc/jit/frontend/resolver.h>
#include <torch/csrc/jit/serialization/import_source.h>
#include <torch/torch.h>
namespace torch {
namespace jit {
static constexpr c10::string_view classSrcs1 = R"JIT(
class FooNestedTest:
def __init__(self, y):
self.y = y
class FooNestedTest2:
def __init__(self, y):
self.y = y
self.nested = __torch__.FooNestedTest(y)
class FooTest:
def __init__(self, x):
self.class_attr = __torch__.FooNestedTest(x)
self.class_attr2 = __torch__.FooNestedTest2(x)
self.x = self.class_attr.y + self.class_attr2.y
)JIT";
static constexpr c10::string_view classSrcs2 = R"JIT(
class FooTest:
def __init__(self, x):
self.dx = x
)JIT";
static void import_libs(
std::shared_ptr<CompilationUnit> cu,
const std::string& class_name,
const std::shared_ptr<Source>& src,
const std::vector<at::IValue>& tensor_table) {
SourceImporter si(
cu,
&tensor_table,
[&](const std::string& name) -> std::shared_ptr<Source> { return src; },
/*version=*/2);
si.loadType(QualifiedName(class_name));
}
TEST(ClassImportTest, Basic) {
auto cu1 = std::make_shared<CompilationUnit>();
auto cu2 = std::make_shared<CompilationUnit>();
std::vector<at::IValue> constantTable;
// Import different versions of FooTest into two namespaces.
import_libs(
cu1,
"__torch__.FooTest",
std::make_shared<Source>(classSrcs1),
constantTable);
import_libs(
cu2,
"__torch__.FooTest",
std::make_shared<Source>(classSrcs2),
constantTable);
// We should get the correct version of `FooTest` for whichever namespace we
// are referencing
c10::QualifiedName base("__torch__");
auto classType1 = cu1->get_class(c10::QualifiedName(base, "FooTest"));
ASSERT_TRUE(classType1->hasAttribute("x"));
ASSERT_FALSE(classType1->hasAttribute("dx"));
auto classType2 = cu2->get_class(c10::QualifiedName(base, "FooTest"));
ASSERT_TRUE(classType2->hasAttribute("dx"));
ASSERT_FALSE(classType2->hasAttribute("x"));
// We should only see FooNestedTest in the first namespace
auto c = cu1->get_class(c10::QualifiedName(base, "FooNestedTest"));
ASSERT_TRUE(c);
c = cu2->get_class(c10::QualifiedName(base, "FooNestedTest"));
ASSERT_FALSE(c);
}
TEST(ClassImportTest, ScriptObject) {
Module m1("m1");
Module m2("m2");
std::vector<at::IValue> constantTable;
import_libs(
m1._ivalue()->compilation_unit(),
"__torch__.FooTest",
std::make_shared<Source>(classSrcs1),
constantTable);
import_libs(
m2._ivalue()->compilation_unit(),
"__torch__.FooTest",
std::make_shared<Source>(classSrcs2),
constantTable);
// Incorrect arguments for constructor should throw
c10::QualifiedName base("__torch__");
// NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto)
ASSERT_ANY_THROW(m1.create_class(c10::QualifiedName(base, "FooTest"), {1}));
auto x = torch::ones({2, 3});
auto obj = m2.create_class(c10::QualifiedName(base, "FooTest"), x).toObject();
auto dx = obj->getAttr("dx");
ASSERT_TRUE(almostEqual(x, dx.toTensor()));
auto new_x = torch::rand({2, 3});
obj->setAttr("dx", new_x);
auto new_dx = obj->getAttr("dx");
ASSERT_TRUE(almostEqual(new_x, new_dx.toTensor()));
}
static const auto methodSrc = R"JIT(
def __init__(self, x):
return x
)JIT";
TEST(ClassImportTest, ClassDerive) {
auto cu = std::make_shared<CompilationUnit>();
auto cls = ClassType::create("foo.bar", cu);
const auto self = SimpleSelf(cls);
auto methods = cu->define("foo.bar", methodSrc, nativeResolver(), &self);
auto method = methods[0];
cls->addAttribute("attr", TensorType::get());
ASSERT_TRUE(cls->findMethod(method->name()));
// Refining a new class should retain attributes and methods
auto newCls = cls->refine({TensorType::get()});
ASSERT_TRUE(newCls->hasAttribute("attr"));
ASSERT_TRUE(newCls->findMethod(method->name()));
auto newCls2 = cls->withContained({TensorType::get()})->expect<ClassType>();
ASSERT_TRUE(newCls2->hasAttribute("attr"));
ASSERT_TRUE(newCls2->findMethod(method->name()));
}
static constexpr c10::string_view torchbindSrc = R"JIT(
class FooBar1234(Module):
__parameters__ = []
f : __torch__.torch.classes._TorchScriptTesting._StackString
training : bool
def forward(self: __torch__.FooBar1234) -> str:
return (self.f).top()
)JIT";
TEST(ClassImportTest, CustomClass) {
auto cu1 = std::make_shared<CompilationUnit>();
std::vector<at::IValue> constantTable;
// Import different versions of FooTest into two namespaces.
import_libs(
cu1,
"__torch__.FooBar1234",
std::make_shared<Source>(torchbindSrc),
constantTable);
}
} // namespace jit
} // namespace torch