17
17
#include " ./base.h"
18
18
#include " ../../src/common/span.h"
19
19
20
+ #include " ../../src/common/host_device_vector.h"
21
+
20
22
namespace xgboost {
21
23
// forward declare learner.
22
24
class LearnerImpl ;
@@ -41,7 +43,7 @@ class MetaInfo {
41
43
/*! \brief number of nonzero entries in the data */
42
44
uint64_t num_nonzero_{0 };
43
45
/*! \brief label of each instance */
44
- std::vector <bst_float> labels_;
46
+ HostDeviceVector <bst_float> labels_;
45
47
/* !
46
48
* \brief specified root index of each instance,
47
49
* can be used for multi task setting
@@ -53,15 +55,15 @@ class MetaInfo {
53
55
*/
54
56
std::vector<bst_uint> group_ptr_;
55
57
/*! \brief weights of each instance, optional */
56
- std::vector <bst_float> weights_;
58
+ HostDeviceVector <bst_float> weights_;
57
59
/*! \brief session-id of each instance, optional */
58
60
std::vector<uint64_t > qids_;
59
61
/* !
60
62
* \brief initialized margins,
61
63
* if specified, xgboost will start from this init margin
62
64
* can be used to specify initial prediction to boost from.
63
65
*/
64
- std::vector <bst_float> base_margin_;
66
+ HostDeviceVector <bst_float> base_margin_;
65
67
/*! \brief version flag, used to check version of this info */
66
68
static const int kVersion = 2 ;
67
69
/*! \brief version that introduced qid field */
@@ -74,7 +76,7 @@ class MetaInfo {
74
76
* \return The weight.
75
77
*/
76
78
inline bst_float GetWeight(size_t i) const {
  // No weights stored: every instance falls back to the default weight 1.0f.
  if (weights_.Size() == 0) {
    return 1.0f;
  }
  return weights_.HostVector()[i];
}
79
81
/* !
80
82
* \brief Get the root index of i-th instance.
@@ -86,12 +88,12 @@ class MetaInfo {
86
88
}
87
89
/*! \brief get sorted indexes (argsort) of labels by absolute value (used by cox loss) */
88
90
inline const std::vector<size_t >& LabelAbsSort () const {
89
- if (label_order_cache_.size () == labels_.size ()) {
91
+ if (label_order_cache_.size () == labels_.Size ()) {
90
92
return label_order_cache_;
91
93
}
92
- label_order_cache_.resize (labels_.size ());
94
+ label_order_cache_.resize (labels_.Size ());
93
95
std::iota (label_order_cache_.begin (), label_order_cache_.end (), 0 );
94
- const auto l = labels_;
96
+ const auto & l = labels_. HostVector () ;
95
97
XGBOOST_PARALLEL_SORT (label_order_cache_.begin (), label_order_cache_.end (),
96
98
[&l](size_t i1, size_t i2) {return std::abs (l[i1]) < std::abs (l[i2]);});
97
99
@@ -151,9 +153,9 @@ struct Entry {
151
153
*/
152
154
class SparsePage {
153
155
public:
154
- std::vector <size_t > offset;
156
+ HostDeviceVector <size_t > offset;
155
157
/*! \brief the data of the segments */
156
- std::vector <Entry> data;
158
+ HostDeviceVector <Entry> data;
157
159
158
160
size_t base_rowid;
159
161
@@ -162,8 +164,10 @@ class SparsePage {
162
164
163
165
/* ! \brief get i-th row from the batch */
164
166
inline Inst operator [](size_t i) const {
165
- return {data.data () + offset[i],
166
- static_cast <Inst::index_type>(offset[i + 1 ] - offset[i])};
167
+ const auto & data_vec = data.HostVector ();
168
+ const auto & offset_vec = offset.HostVector ();
169
+ return {data_vec.data () + offset_vec[i],
170
+ static_cast <Inst::index_type>(offset_vec[i + 1 ] - offset_vec[i])};
167
171
}
168
172
169
173
/* ! \brief constructor */
@@ -172,73 +176,81 @@ class SparsePage {
172
176
}
173
177
/* ! \return number of instance in the page */
174
178
inline size_t Size () const {
175
- return offset.size () - 1 ;
179
+ return offset.Size () - 1 ;
176
180
}
177
181
/* ! \return estimation of memory cost of this page */
178
182
inline size_t MemCostBytes () const {
179
- return offset.size () * sizeof (size_t ) + data.size () * sizeof (Entry);
183
+ return offset.Size () * sizeof (size_t ) + data.Size () * sizeof (Entry);
180
184
}
181
185
/* ! \brief clear the page */
182
186
inline void Clear () {
183
187
base_rowid = 0 ;
184
- offset.clear ();
185
- offset.push_back (0 );
186
- data.clear ();
188
+ auto & offset_vec = offset.HostVector ();
189
+ offset_vec.clear ();
190
+ offset_vec.push_back (0 );
191
+ data.HostVector ().clear ();
187
192
}
188
193
189
194
/* !
190
195
* \brief Push row block into the page.
191
196
* \param batch the row batch.
192
197
*/
193
198
inline void Push (const dmlc::RowBlock<uint32_t >& batch) {
194
- data.reserve (data.size () + batch.offset [batch.size ] - batch.offset [0 ]);
195
- offset.reserve (offset.size () + batch.size );
199
+ auto & data_vec = data.HostVector ();
200
+ auto & offset_vec = offset.HostVector ();
201
+ data_vec.reserve (data.Size () + batch.offset [batch.size ] - batch.offset [0 ]);
202
+ offset_vec.reserve (offset.Size () + batch.size );
196
203
CHECK (batch.index != nullptr );
197
204
for (size_t i = 0 ; i < batch.size ; ++i) {
198
- offset .push_back (offset .back () + batch.offset [i + 1 ] - batch.offset [i]);
205
+ offset_vec .push_back (offset_vec .back () + batch.offset [i + 1 ] - batch.offset [i]);
199
206
}
200
207
for (size_t i = batch.offset [0 ]; i < batch.offset [batch.size ]; ++i) {
201
208
uint32_t index = batch.index [i];
202
209
bst_float fvalue = batch.value == nullptr ? 1 .0f : batch.value [i];
203
- data .emplace_back (index , fvalue);
210
+ data_vec .emplace_back (index , fvalue);
204
211
}
205
- CHECK_EQ (offset .back (), data.size ());
212
+ CHECK_EQ (offset_vec .back (), data.Size ());
206
213
}
207
214
/* !
208
215
* \brief Push a sparse page
209
216
* \param batch the row page
210
217
*/
211
218
inline void Push (const SparsePage &batch) {
212
- size_t top = offset.back ();
213
- data.resize (top + batch.data .size ());
214
- std::memcpy (dmlc::BeginPtr (data) + top,
215
- dmlc::BeginPtr (batch.data ),
216
- sizeof (Entry) * batch.data .size ());
217
- size_t begin = offset.size ();
218
- offset.resize (begin + batch.Size ());
219
+ auto & data_vec = data.HostVector ();
220
+ auto & offset_vec = offset.HostVector ();
221
+ const auto & batch_offset_vec = batch.offset .HostVector ();
222
+ const auto & batch_data_vec = batch.data .HostVector ();
223
+ size_t top = offset_vec.back ();
224
+ data_vec.resize (top + batch.data .Size ());
225
+ std::memcpy (dmlc::BeginPtr (data_vec) + top,
226
+ dmlc::BeginPtr (batch_data_vec),
227
+ sizeof (Entry) * batch.data .Size ());
228
+ size_t begin = offset.Size ();
229
+ offset_vec.resize (begin + batch.Size ());
219
230
for (size_t i = 0 ; i < batch.Size (); ++i) {
220
- offset [i + begin] = top + batch. offset [i + 1 ];
231
+ offset_vec [i + begin] = top + batch_offset_vec [i + 1 ];
221
232
}
222
233
}
223
234
/* !
224
235
* \brief Push one instance into page
225
236
* \param inst an instance row
226
237
*/
227
238
inline void Push (const Inst &inst) {
228
- offset.push_back (offset.back () + inst.size ());
229
- size_t begin = data.size ();
230
- data.resize (begin + inst.size ());
239
+ auto & data_vec = data.HostVector ();
240
+ auto & offset_vec = offset.HostVector ();
241
+ offset_vec.push_back (offset_vec.back () + inst.size ());
242
+
243
+ size_t begin = data_vec.size ();
244
+ data_vec.resize (begin + inst.size ());
231
245
if (inst.size () != 0 ) {
232
- std::memcpy (dmlc::BeginPtr (data ) + begin, inst.data (),
246
+ std::memcpy (dmlc::BeginPtr (data_vec ) + begin, inst.data (),
233
247
sizeof (Entry) * inst.size ());
234
248
}
235
249
}
236
250
237
- size_t Size () { return offset.size () - 1 ; }
251
+ size_t Size () { return offset.Size () - 1 ; }
238
252
};
239
253
240
-
241
-
242
254
/* !
243
255
* \brief This is data structure that user can pass to DMatrix::Create
244
256
* to create a DMatrix for training, user can create this data structure
0 commit comments