Skip to content

Commit a26bcc4

Browse files
committed
MLP const update
1 parent ea4beae commit a26bcc4

File tree

1 file changed

+10
-8
lines changed

1 file changed

+10
-8
lines changed

include/clients/nrt/MLPRegressorClient.hpp

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@ class MLPRegressorClient : public FluidBaseClient,
6161
public:
6262
using string = std::string;
6363
using BufferPtr = std::shared_ptr<BufferAdaptor>;
64+
using InputBufferPtr = std::shared_ptr<const BufferAdaptor>;
6465
using IndexVector = FluidTensor<index, 1>;
6566
using StringVector = FluidTensor<string, 1>;
6667
using DataSet = FluidDataSet<string, double, 1>;
@@ -93,7 +94,8 @@ class MLPRegressorClient : public FluidBaseClient,
9394
return {};
9495
}
9596

96-
MessageResult<double> fit(DataSetClientRef source, DataSetClientRef target)
97+
MessageResult<double> fit(InputDataSetClientRef source,
98+
InputDataSetClientRef target)
9799
{
98100
auto sourceClientPtr = source.get().lock();
99101
if (!sourceClientPtr) return Error<double>(NoDataSet);
@@ -130,8 +132,8 @@ class MLPRegressorClient : public FluidBaseClient,
130132
return error;
131133
}
132134

133-
MessageResult<void> predict(DataSetClientRef srcClient,
134-
DataSetClientRef destClient)
135+
MessageResult<void> predict(InputDataSetClientRef srcClient,
136+
DataSetClientRef destClient)
135137
{
136138
index inputTap = get<kInputTap>();
137139
index outputTap = get<kOutputTap>();
@@ -160,7 +162,7 @@ class MLPRegressorClient : public FluidBaseClient,
160162
return OK();
161163
}
162164

163-
MessageResult<void> predictPoint(BufferPtr in, BufferPtr out)
165+
MessageResult<void> predictPoint(InputBufferPtr in, BufferPtr out)
164166
{
165167
index inputTap = get<kInputTap>();
166168
index outputTap = get<kOutputTap>();
@@ -174,7 +176,7 @@ class MLPRegressorClient : public FluidBaseClient,
174176
index outputSize = mAlgorithm.outputSize(outputTap);
175177

176178
if (!in || !out) return Error(NoBuffer);
177-
BufferAdaptor::Access inBuf(in.get());
179+
BufferAdaptor::ReadAccess inBuf(in.get());
178180
BufferAdaptor::Access outBuf(out.get());
179181
if (!inBuf.exists()) return Error(InvalidBuffer);
180182
if (!outBuf.exists()) return Error(InvalidBuffer);
@@ -245,13 +247,13 @@ class MLPRegressorClient : public FluidBaseClient,
245247
}
246248
};
247249

248-
using MLPRegressorRef = SharedClientRef<MLPRegressorClient>;
250+
using MLPRegressorRef = SharedClientRef<const MLPRegressorClient>;
249251

250252
constexpr auto MLPRegressorQueryParams =
251253
defineParameters(MLPRegressorRef::makeParam("model", "Source Model"),
252254
LongParam("tapIn", "Input Tap Index", 0, Min(0)),
253255
LongParam("tapOut", "Output Tap Index", -1, Min(-1)),
254-
BufferParam("inputPointBuffer", "Input Point Buffer"),
256+
InputBufferParam("inputPointBuffer", "Input Point Buffer"),
255257
BufferParam("predictionBuffer", "Prediction Buffer"));
256258

257259
class MLPRegressorQuery : public FluidBaseClient, ControlIn, ControlOut
@@ -297,7 +299,7 @@ class MLPRegressorQuery : public FluidBaseClient, ControlIn, ControlOut
297299
return;
298300
}
299301

300-
algorithm::MLP& algorithm = MLPRef->algorithm();
302+
algorithm::MLP const& algorithm = MLPRef->algorithm();
301303

302304
if (!algorithm.trained()) return;
303305
index inputTap = get<kInputTap>();

0 commit comments

Comments
 (0)