-
-
Notifications
You must be signed in to change notification settings - Fork 8.7k
/
Copy pathgblinear_model.h
143 lines (129 loc) · 4.11 KB
/
gblinear_model.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
/*!
* Copyright 2018-2019 by Contributors
*/
#pragma once
#include <cstring>
#include <sstream>
#include <string>
#include <vector>

#include <dmlc/io.h>
#include <dmlc/parameter.h>
#include <xgboost/learner.h>

#include "xgboost/base.h"
#include "xgboost/feature_map.h"
#include "xgboost/json.h"
#include "xgboost/model.h"
#include "xgboost/parameter.h"
namespace xgboost {
class Json;
namespace gbm {
/*!
 * \brief Legacy linear-model parameter block, deprecated in 1.0.0.
 *
 * Kept only for compatible binary model IO: the raw bytes of this struct are
 * written/read as-is, which is why its size is pinned by a static_assert and
 * its fields must never be reordered, added or removed.
 */
struct DeprecatedGBLinearModelParam : public dmlc::Parameter<DeprecatedGBLinearModelParam> {
  // Number of feature dimensions (legacy copy).
  uint32_t deprecated_num_feature;
  // Deprecated: use learner_model_param_->num_output_group instead.
  int32_t deprecated_num_output_group;
  // Reserved padding that keeps the on-disk block at 34 * 4 bytes.
  int32_t reserved[32];

  // Zero every field (including the reserved area) on construction.
  DeprecatedGBLinearModelParam() {
    static_assert(sizeof(DeprecatedGBLinearModelParam) == 34 * sizeof(int32_t),
                  "Model parameter size can not be changed.");
    std::memset(this, 0, sizeof(*this));
  }

  DMLC_DECLARE_PARAMETER(DeprecatedGBLinearModelParam) {}
};
// model for linear booster
class GBLinearModel : public Model {
private:
// Deprecated in 1.0.0
DeprecatedGBLinearModelParam param_;
public:
LearnerModelParam const* learner_model_param;
public:
explicit GBLinearModel(LearnerModelParam const* learner_model_param) :
learner_model_param {learner_model_param} {}
void Configure(Args const &cfg) { }
// weight for each of feature, bias is the last one
std::vector<bst_float> weight;
// initialize the model parameter
inline void LazyInitModel() {
if (!weight.empty()) {
return;
}
// bias is the last weight
weight.resize((learner_model_param->num_feature + 1) *
learner_model_param->num_output_group);
std::fill(weight.begin(), weight.end(), 0.0f);
}
void SaveModel(Json *p_out) const override;
void LoadModel(Json const &in) override;
// save the model to file
void Save(dmlc::Stream *fo) const {
fo->Write(¶m_, sizeof(param_));
fo->Write(weight);
}
// load model from file
void Load(dmlc::Stream *fi) {
CHECK_EQ(fi->Read(¶m_, sizeof(param_)), sizeof(param_));
fi->Read(&weight);
}
// model bias
inline bst_float *Bias() {
return &weight[learner_model_param->num_feature *
learner_model_param->num_output_group];
}
inline const bst_float *Bias() const {
return &weight[learner_model_param->num_feature *
learner_model_param->num_output_group];
}
// get i-th weight
inline bst_float *operator[](size_t i) {
return &weight[i * learner_model_param->num_output_group];
}
inline const bst_float *operator[](size_t i) const {
return &weight[i * learner_model_param->num_output_group];
}
std::vector<std::string> DumpModel(const FeatureMap &fmap, bool with_stats,
std::string format) const {
const int ngroup = learner_model_param->num_output_group;
const unsigned nfeature = learner_model_param->num_feature;
std::stringstream fo("");
if (format == "json") {
fo << " { \"bias\": [" << std::endl;
for (int gid = 0; gid < ngroup; ++gid) {
if (gid != 0) {
fo << "," << std::endl;
}
fo << " " << this->Bias()[gid];
}
fo << std::endl
<< " ]," << std::endl
<< " \"weight\": [" << std::endl;
for (unsigned i = 0; i < nfeature; ++i) {
for (int gid = 0; gid < ngroup; ++gid) {
if (i != 0 || gid != 0) {
fo << "," << std::endl;
}
fo << " " << (*this)[i][gid];
}
}
fo << std::endl << " ]" << std::endl << " }";
} else {
fo << "bias:\n";
for (int gid = 0; gid < ngroup; ++gid) {
fo << this->Bias()[gid] << std::endl;
}
fo << "weight:\n";
for (unsigned i = 0; i < nfeature; ++i) {
for (int gid = 0; gid < ngroup; ++gid) {
fo << (*this)[i][gid] << std::endl;
}
}
}
std::vector<std::string> v;
v.push_back(fo.str());
return v;
}
};
} // namespace gbm
} // namespace xgboost