diff --git a/CONTRIBUTORS.md b/CONTRIBUTORS.md index 9f2bc44ed683..ab442743df08 100644 --- a/CONTRIBUTORS.md +++ b/CONTRIBUTORS.md @@ -241,6 +241,7 @@ List of Contributors * [Zak Jost](https://github.com/zjost) * [Shoubhik Bhattacharya](https://github.com/shoubhik) * [Rohit Srivastava](https://github.com/access2rohit) +* [Caner Turkmen](https://github.com/canerturkmen) Label Bot --------- diff --git a/src/operator/contrib/hawkes_ll-inl.h b/src/operator/contrib/hawkes_ll-inl.h new file mode 100644 index 000000000000..d5e90ad6545d --- /dev/null +++ b/src/operator/contrib/hawkes_ll-inl.h @@ -0,0 +1,506 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2018 by Contributors + * \file hawkes_ll-inl.h + * \brief Log likelihood of a marked self-exciting Hawkes process + * \author Caner Turkmen + */ +#ifndef MXNET_OPERATOR_CONTRIB_HAWKES_LL_INL_H_ +#define MXNET_OPERATOR_CONTRIB_HAWKES_LL_INL_H_ + +#include +#include + +#include "../operator_common.h" +#include "../mshadow_op.h" +#include "../mxnet_op.h" + +namespace mxnet { +namespace op { + +namespace hawkesll { + enum HawkesLLOpInputs {kMu, kAlpha, kBeta, kState, kIATimes, kMarks, + kValidLength, kMaxTime}; + enum HawkesLLGradInputs {kOutGradLL, kOutGradStates, kGradMu, kGradAlpha, + kGradBeta, kGradState, kGradIATimes, kGradMarks, + kGradValidLength, kGradMaxTime}; + enum HawkesLLOpOutputs {kOutLL, kOutStates}; + enum HawkesLLOpResource {kTempSpace}; +} // namespace hawkesll + +inline bool HawkesLLOpType(const nnvm::NodeAttrs& attrs, + std::vector* in_attrs, + std::vector* out_attrs) { + // check dimensions of the type vectors + CHECK_EQ(in_attrs->size(), 8U); + CHECK_EQ(out_attrs->size(), 2U); + + TYPE_ASSIGN_CHECK(*out_attrs, hawkesll::kOutLL, in_attrs->at(0)) + TYPE_ASSIGN_CHECK(*out_attrs, hawkesll::kOutStates, in_attrs->at(0)) + + for (index_t j = 0; j < 8; ++j) { + if (j != hawkesll::kMarks) { + TYPE_ASSIGN_CHECK(*in_attrs, j, out_attrs->at(0)) + } + } + TYPE_ASSIGN_CHECK(*in_attrs, hawkesll::kMarks, 4) // int32 + + return out_attrs->at(hawkesll::kOutLL) != -1; +} + +inline bool HawkesLLOpShape(const nnvm::NodeAttrs& attrs, + std::vector* in_attrs, + std::vector* out_attrs) { + using namespace mshadow; + int N, T, K; + + CHECK_EQ(in_attrs->size(), 8U); + CHECK_EQ(out_attrs->size(), 2U); + + // check ndims + CHECK_EQ(in_attrs->at(hawkesll::kMu).ndim(), 2); // mu (N, K) + CHECK_EQ(in_attrs->at(hawkesll::kAlpha).ndim(), 1); // branching ratio (K,) + CHECK_EQ(in_attrs->at(hawkesll::kBeta).ndim(), 1); // decay exponent (K,) + CHECK_EQ(in_attrs->at(hawkesll::kState).ndim(), 2); // Hawkes states (N, K) + CHECK_EQ(in_attrs->at(hawkesll::kIATimes).ndim(), 2); // i.a. times (N, T) + CHECK_EQ(in_attrs->at(hawkesll::kMarks).ndim(), 2); // marks (N, T) + CHECK_EQ(in_attrs->at(hawkesll::kValidLength).ndim(), 1); // valid len (N,) + CHECK_EQ(in_attrs->at(hawkesll::kMaxTime).ndim(), 1); // max_time (N,) + + N = in_attrs->at(hawkesll::kIATimes)[0]; // number of samples in batch + T = in_attrs->at(hawkesll::kIATimes)[1]; // time length + K = in_attrs->at(hawkesll::kMu)[1]; // number of marks + + // check inputs consistent + CHECK_EQ(in_attrs->at(hawkesll::kMu)[0], N); + CHECK_EQ(in_attrs->at(hawkesll::kMu)[1], K); + CHECK_EQ(in_attrs->at(hawkesll::kAlpha)[0], K); + CHECK_EQ(in_attrs->at(hawkesll::kBeta)[0], K); + CHECK_EQ(in_attrs->at(hawkesll::kState)[0], N); + CHECK_EQ(in_attrs->at(hawkesll::kState)[1], K); + CHECK_EQ(in_attrs->at(hawkesll::kMarks)[0], N); + CHECK_EQ(in_attrs->at(hawkesll::kMarks)[1], T); + CHECK_EQ(in_attrs->at(hawkesll::kValidLength)[0], N); + CHECK_EQ(in_attrs->at(hawkesll::kMaxTime)[0], N); + + // infer output type + SHAPE_ASSIGN_CHECK(*out_attrs, hawkesll::kOutLL, Shape1(N)) + SHAPE_ASSIGN_CHECK(*out_attrs, hawkesll::kOutStates, Shape2(N, K)) + + return out_attrs->at(hawkesll::kOutLL).ndim() != 0U && + out_attrs->at(hawkesll::kOutStates).Size() != 0U; +} + +template +struct hawkesll_forward { + template + MSHADOW_XINLINE static void Map(int i, + DType* out_loglike, + DType* out_state, + const DType* mu, + const DType* alpha, + const DType* beta, + DType* state, + const DType* lags, + const int32_t* marks, + DType* valid_length, + DType* max_time, + int K, + int T, + DType* temp_register + ) { + int32_t ci; // current mark + DType ll = 0; // log likelihood + DType t = 0; // current time + DType d, ed, lda, comp; + DType *last_ = &temp_register[i * K]; + + const DType *mu_ = &mu[i * K]; + const DType *lag_ = &lags[i * T]; + const int32_t *mark_ = &marks[i * T]; + DType *state_ = &out_state[i * K]; + + // iterate over points in sequence + for (index_t j = 0; j < valid_length[i]; ++j) { + ci = mark_[j]; + t += lag_[j]; + d = t - last_[ci]; + ed = expf(-beta[ci] * d); + + lda = mu_[ci] + alpha[ci] * beta[ci] * state_[ci] * ed; + comp = mu_[ci] * d + alpha[ci] * state_[ci] * (1 - ed); + + ll += logf(lda) - comp; + + KERNEL_ASSIGN(state_[ci], req, 1 + (state_[ci] * ed)) + + last_[ci] = t; + } + + KERNEL_ASSIGN(out_loglike[i], req, ll) + } +}; + +template +struct hawkesll_forward_compensator { + template + MSHADOW_XINLINE static void Map(int i, + DType* rem_comp, + DType* out_state, + const DType* mu, + const DType* alpha, + const DType* beta, + const DType* max_time, + const int K, + const DType* last_buffer + ) { + DType d, ed; + int m = i % K; // mark + int j = i / K; // particle + + // take care of the remaining compensators and state update + d = max_time[j] - last_buffer[i]; + ed = expf(-beta[m] * d); + + // return the remaining compensator + KERNEL_ASSIGN(rem_comp[i], req, + mu[i] * d + alpha[m] * out_state[i] * (1 - ed)) + + // update the state + KERNEL_ASSIGN(out_state[i], req, ed * out_state[i]) + } +}; + +template +void HawkesLLForward(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + using namespace mshadow; + using namespace mxnet_op; + + Stream *s = ctx.get_stream(); + + CHECK_EQ(inputs.size(), 8U); + CHECK_EQ(outputs.size(), 2U); + + const TBlob& out_loglike = outputs[hawkesll::kOutLL]; + const TBlob& out_state = outputs[hawkesll::kOutStates]; + + int K = inputs[hawkesll::kMu].shape_[1]; + int N = inputs[hawkesll::kIATimes].shape_[0]; + int T = inputs[hawkesll::kIATimes].shape_[1]; + + MSHADOW_TYPE_SWITCH(out_loglike.type_flag_, DType, { + Tensor temp_space = ctx.requested[hawkesll::kTempSpace] + .get_space_typed( + Shape2(2*N, K), + s); + + Tensor last_buffer = + Tensor(&temp_space.dptr_[0], Shape2(N, K), s); + Tensor rem_comp = + Tensor(&temp_space.dptr_[N*K], Shape2(N, K), s); + + Tensor out_loglike_ts = + out_loglike.get_with_shape(Shape1(N), s); + + last_buffer = DType(0.0); + rem_comp = DType(0.0); + + Tensor out_state_ts = + out_state.get_with_shape(Shape2(N, K), s); + Tensor in_state_ts = + inputs[hawkesll::kState].get_with_shape(Shape2(N, K), s); + + mshadow::Copy(out_state_ts, in_state_ts, s); + + MXNET_ASSIGN_REQ_SWITCH(req[0], req_type, { + Kernel, xpu>::Launch( + s, N, + out_loglike.dptr(), + out_state.dptr(), + inputs[hawkesll::kMu].dptr(), // mu + inputs[hawkesll::kAlpha].dptr(), // alpha + inputs[hawkesll::kBeta].dptr(), // beta + inputs[hawkesll::kState].dptr(), // states + inputs[hawkesll::kIATimes].dptr(), // interarrival times + inputs[hawkesll::kMarks].dptr(), // marks + inputs[hawkesll::kValidLength].dptr(), // valid_length + inputs[hawkesll::kMaxTime].dptr(), // max_time + K, + T, + last_buffer.dptr_); + }); + + // in parallel, we take care of the remaining compensators + MXNET_ASSIGN_REQ_SWITCH(req[0], req_type, { + Kernel, xpu>::Launch( + s, N * K, + rem_comp.dptr_, + out_state.dptr(), + inputs[hawkesll::kMu].dptr(), // mu + inputs[hawkesll::kAlpha].dptr(), // alpha + inputs[hawkesll::kBeta].dptr(), // beta + inputs[hawkesll::kMaxTime].dptr(), // max_time + K, + last_buffer.dptr_); + }); + out_loglike_ts -= mshadow::expr::sumall_except_dim<0>(rem_comp); + }) +} + +template +struct hawkesll_backward { + template + MSHADOW_XINLINE static void Map(int i, // indexes the sample (particle) + DType* mu_gbfr, + DType* alpha_gbfr, + DType* beta_gbfr, // (N, K) + const DType* mu, // (N, K) + const DType* alpha, // (K,) + const DType* beta, // (K,) + const DType* lags, // (N, T) + const int32_t* marks, // (N, T) + const DType* valid_length, // (N,) + const DType* max_time, // (N,) + const int K, + const int T, + DType* last_buffer, + DType* phi_buffer, + DType* phig_buffer + ) { + int32_t ci; + int32_t part_ix_K = i*K, part_ix_T = i*T; + + DType t = 0, d, lda, ed; + DType* last_ = &last_buffer[part_ix_K]; + DType* state_ = &phi_buffer[part_ix_K]; + DType* dstate_ = &phig_buffer[part_ix_K]; + + DType* mug_ = &mu_gbfr[part_ix_K]; + DType* alphag_ = &alpha_gbfr[part_ix_K]; + DType* betag_ = &beta_gbfr[part_ix_K]; + + const DType* lag_ = &lags[part_ix_T]; + const int32_t* mark_ = &marks[part_ix_T]; + + // iterate over points + for (index_t j = 0; j < valid_length[i]; ++j){ + ci = mark_[j]; + t += lag_[j]; + d = t - last_[ci]; + ed = expf(-beta[ci] * d); + + lda = mu[part_ix_K + ci] + alpha[ci] * beta[ci] * state_[ci] * ed; + + KERNEL_ASSIGN(mug_[ci], req, mug_[ci] + (1 / lda) - d) + KERNEL_ASSIGN(alphag_[ci], req, + ( + alphag_[ci] + + (beta[ci] * state_[ci] * ed) / lda + - state_[ci] * (1 - ed) + ) + ) + KERNEL_ASSIGN(betag_[ci], req, + betag_[ci] + + alpha[ci] * ed + * (state_[ci] * (1 - beta[ci] * d) + beta[ci] * dstate_[ci]) + / lda + - alpha[ci] + * (dstate_[ci] * (1 - ed) + state_[ci] * d * ed) + ) + + KERNEL_ASSIGN(dstate_[ci], req, ed * (-d * state_[ci] + dstate_[ci])) + KERNEL_ASSIGN(state_[ci], req, 1 + (state_[ci] * ed)) + + last_[ci] = t; + } + } +}; + + +template +struct hawkesll_backward_compensator { + template + MSHADOW_XINLINE static void Map(int i, + DType* mu_gbfr, + DType* alpha_gbfr, + DType* beta_gbfr, // (N, K) + DType* out_grad, // read this (N,) + const DType* mu, // (N, K) + const DType* alpha, // (K,) + const DType* beta, // (K,) + const DType* max_time, // (N,) + const int K, + DType* last_buffer, + DType* phi_buffer, + DType* phig_buffer + ) { + DType d, ed; + int m = i % K; // mark + int j = i / K; // particle + int32_t part_ix_K = j*K; + DType* mug_ = &mu_gbfr[part_ix_K]; + DType* alphag_ = &alpha_gbfr[part_ix_K]; + DType* betag_ = &beta_gbfr[part_ix_K]; + + // take care of the remaining compensators and state update + d = max_time[j] - last_buffer[i]; + ed = expf(-beta[m] * d); + + // take care of the gradients of the remaining compensator + KERNEL_ASSIGN(mug_[m], req, mug_[m] - d) + KERNEL_ASSIGN(alphag_[m], req, + alphag_[m] - phi_buffer[i] * (1 - ed) + ) + KERNEL_ASSIGN(betag_[m], req, + betag_[m] - alpha[m] * ( + phig_buffer[i] * (1 - ed) + + phi_buffer[i] * d * ed + ) + ) + + // // correct the gradients with respect to output gradients + KERNEL_ASSIGN(mug_[m], req, out_grad[j] * mug_[m]) + KERNEL_ASSIGN(alphag_[m], req, out_grad[j] * alphag_[m]) + KERNEL_ASSIGN(betag_[m], req, out_grad[j] * betag_[m]) + } +}; + +template +void HawkesLLBackward(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + CHECK_EQ(inputs.size(), 10U); + CHECK_EQ(outputs.size(), 8U); + CHECK_EQ(req.size(), 8U); + + mshadow::Stream *s = ctx.get_stream(); + + int K = inputs[hawkesll::kGradMu].shape_[1]; // mu data + int N = inputs[hawkesll::kGradIATimes].shape_[0]; + int T = inputs[hawkesll::kGradIATimes].shape_[1]; + + CHECK_EQ(inputs[hawkesll::kOutGradLL].shape_[0], N); // grad of out 0 (LL) + CHECK_EQ(inputs[hawkesll::kOutGradStates].shape_[0], N); // grad out 1-states + CHECK_EQ(inputs[hawkesll::kOutGradStates].shape_[1], K); + + // sufficient statistics are not differentiated w.r.t. + CHECK_EQ(req[hawkesll::kIATimes], OpReqType::kNullOp); + CHECK_EQ(req[hawkesll::kMarks], OpReqType::kNullOp); + CHECK_EQ(req[hawkesll::kValidLength], OpReqType::kNullOp); + CHECK_EQ(req[hawkesll::kMaxTime], OpReqType::kNullOp); + + const TBlob& out_grad = inputs[hawkesll::kOutGradLL]; + + using namespace mshadow; + using namespace mxnet_op; + MSHADOW_TYPE_SWITCH(out_grad.type_flag_, DType, { + // allocate gradient buffers + Tensor bfr = + ctx.requested[hawkesll::kTempSpace] + .get_space_typed(Shape2(6*N, K), s); + + Tensor alpha_gbfr = + Tensor(&bfr.dptr_[N*K], Shape2(N, K), s); + Tensor beta_gbfr = + Tensor(&bfr.dptr_[2*N*K], Shape2(N, K), s); + Tensor last_buffer = + Tensor(&bfr.dptr_[3*N*K], Shape2(N, K), s); + Tensor phig_buffer = + Tensor(&bfr.dptr_[4*N*K], Shape2(N, K), s); + Tensor phi_buffer = + Tensor(&bfr.dptr_[5*N*K], Shape2(N, K), s); + + alpha_gbfr = DType(0.0); + beta_gbfr = DType(0.0); + last_buffer = DType(0.0); + phig_buffer = DType(0.0); + + mshadow::Copy(phi_buffer, + inputs[hawkesll::kGradState] + .get_with_shape(Shape2(N, K), s), + s); + + // get the gradient to be output + Tensor in_grad_mu = + outputs[hawkesll::kMu].get_with_shape(Shape2(N, K), s); + Tensor in_grad_alpha = + outputs[hawkesll::kAlpha].get_with_shape(Shape1(K), s); + Tensor in_grad_beta = + outputs[hawkesll::kBeta].get_with_shape(Shape1(K), s); + + in_grad_mu = DType(0.0); + + MXNET_ASSIGN_REQ_SWITCH(req[hawkesll::kMu], req_type, { + Kernel, xpu>::Launch( + s, + N, + in_grad_mu.dptr_, alpha_gbfr.dptr_, beta_gbfr.dptr_, // gradients + inputs[hawkesll::kGradMu].dptr(), // mu data + inputs[hawkesll::kGradAlpha].dptr(), // alpha data + inputs[hawkesll::kGradBeta].dptr(), // beta data + inputs[hawkesll::kGradIATimes].dptr(), // lags data + inputs[hawkesll::kGradMarks].dptr(), // marks data + inputs[hawkesll::kGradValidLength].dptr(), // valid_length data + inputs[hawkesll::kGradMaxTime].dptr(), // max_time data + K, + T, + last_buffer.dptr_, // buffer to keep timestamp of last item + phi_buffer.dptr_, // "states" + phig_buffer.dptr_); // derivatives of "states" + }); + + MXNET_ASSIGN_REQ_SWITCH(req[hawkesll::kMu], req_type, { + Kernel, xpu>::Launch( + s, + N * K, + in_grad_mu.dptr_, alpha_gbfr.dptr_, beta_gbfr.dptr_, // gradients + out_grad.dptr(), + inputs[hawkesll::kGradMu].dptr(), // mu data + inputs[hawkesll::kGradAlpha].dptr(), // alpha data + inputs[hawkesll::kGradBeta].dptr(), // beta data + inputs[hawkesll::kGradMaxTime].dptr(), // max_time data + K, + last_buffer.dptr_, // buffer to keep timestamp of last item + phi_buffer.dptr_, // "states" + phig_buffer.dptr_); // derivatives of "states" + }); + + // reduce the gradients + Assign(in_grad_alpha, req[hawkesll::kAlpha], + mshadow::expr::sumall_except_dim<1>(alpha_gbfr) + ) + + Assign(in_grad_beta, req[hawkesll::kBeta], + mshadow::expr::sumall_except_dim<1>(beta_gbfr) + ) + }) +} + +} // namespace op +} // namespace mxnet + +#endif // MXNET_OPERATOR_CONTRIB_HAWKES_LL_INL_H_ diff --git a/src/operator/contrib/hawkes_ll.cc b/src/operator/contrib/hawkes_ll.cc new file mode 100644 index 000000000000..758ab2012580 --- /dev/null +++ b/src/operator/contrib/hawkes_ll.cc @@ -0,0 +1,148 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/*! + * Copyright (c) 2018 by Contributors + * \file hawkes_ll.cc + * \brief Log likelihood of a marked self-exciting Hawkes process + * \author Caner Turkmen + */ + +#include "./hawkes_ll-inl.h" +#include "../tensor/init_op.h" + +namespace mxnet { +namespace op { + +NNVM_REGISTER_OP(_contrib_hawkesll) + .describe(R"code(Computes the log likelihood of a univariate Hawkes process. + +The log likelihood is calculated on point process observations represented +as *ragged* matrices for *lags* (interarrival times w.r.t. the previous point), +and *marks* (identifiers for the process ID). Note that each mark is considered independent, +i.e., computes the joint likelihood of a set of Hawkes processes determined by the conditional intensity: + +.. math:: + + \lambda_k^*(t) = \lambda_k + \alpha_k \sum_{\{t_i < t, y_i = k\}} \beta_k \exp(-\beta_k (t - t_i)) + +where :math:`\lambda_k` specifies the background intensity ``lda``, :math:`\alpha_k` specifies the *branching ratio* or ``alpha``, and :math:`\beta_k` the delay density parameter ``beta``. + +``lags`` and ``marks`` are two NDArrays of shape (N, T) and correspond to the representation of the point process observation, the first dimension corresponds to the batch index, and the second to the sequence. These are "left-aligned" *ragged* matrices (the first index of the second dimension is the beginning of every sequence. The length of each sequence is given by ``valid_length``, of shape (N,) where ``valid_length[i]`` corresponds to the number of valid points in ``lags[i, :]`` and ``marks[i, :]``. + +``max_time`` is the length of the observation period of the point process. That is, specifying ``max_time[i] = 5`` computes the likelihood of the i-th sample as observed on the time interval :math:`(0, 5]`. Naturally, the sum of all valid ``lags[i, :valid_length[i]]`` must be less than or equal to 5. + +The input ``state`` specifies the *memory* of the Hawkes process. Invoking the memoryless property of exponential decays, we compute the *memory* as + +.. math:: + + s_k(t) = \sum_{t_i < t} \exp(-\beta_k (t - t_i)). + +The ``state`` to be provided is :math:`s_k(0)` and carries the added intensity due to past events before the current batch. :math:`s_k(T)` is returned from the function where :math:`T` is ``max_time[T]``. + +Example:: + + # define the Hawkes process parameters + lda = nd.array([1.5, 2.0, 3.0]).tile((N, 1)) + alpha = nd.array([0.2, 0.3, 0.4]) # branching ratios should be < 1 + beta = nd.array([1.0, 2.0, 3.0]) + + # the "data", or observations + ia_times = nd.array([[6, 7, 8, 9], [1, 2, 3, 4], [3, 4, 5, 6], [8, 9, 10, 11]]) + marks = nd.zeros((N, T)).astype(np.int32) + + # starting "state" of the process + states = nd.zeros((N, K)) + + valid_length = nd.array([1, 2, 3, 4]) # number of valid points in each sequence + max_time = nd.ones((N,)) * 100.0 # length of the observation period + + A = nd.contrib.hawkesll( + lda, alpha, beta, states, ia_times, marks, valid_length, max_time + ) + +References: + +- Bacry, E., Mastromatteo, I., & Muzy, J. F. (2015). + Hawkes processes in finance. Market Microstructure and Liquidity + , 1(01), 1550005. +)code" ADD_FILELINE) + .set_num_inputs(8) + .set_num_outputs(2) + .set_attr("FListInputNames", + [](const NodeAttrs& attrs) { + return std::vector{ + "lda", "alpha", "beta", "state", "lags", + "marks", "valid_length", "max_time" + }; + }) + .set_attr("FListOutputNames", + [](const NodeAttrs& attrs) { + return std::vector{"output", "out_state"}; + }) + .set_attr("FInferShape", HawkesLLOpShape) + .set_attr("FInferType", HawkesLLOpType) + .set_attr("FCompute", HawkesLLForward) + .set_attr( + "FGradient", ElemwiseGradUseIn{"_contrib_backward_hawkesll"} + ) + .set_attr("FResourceRequest", [](const NodeAttrs& n) { + return std::vector{ResourceRequest::Type::kTempSpace}; + }) + .add_argument( + "lda", "NDArray-or-Symbol", + "Shape (N, K) The intensity for each of the K processes, for each sample" + ) + .add_argument( + "alpha", "NDArray-or-Symbol", + "Shape (K,) The infectivity factor (branching ratio) for each process" + ) + .add_argument( + "beta", "NDArray-or-Symbol", + "Shape (K,) The decay parameter for each process" + ) + .add_argument( + "state", "NDArray-or-Symbol", + "Shape (N, K) the Hawkes state for each process" + ) + .add_argument( + "lags", "NDArray-or-Symbol", + "Shape (N, T) the interarrival times" + ) + .add_argument( + "marks", "NDArray-or-Symbol", + "Shape (N, T) the marks (process ids)" + ) + .add_argument( + "valid_length", "NDArray-or-Symbol", + "The number of valid points in the process" + ) + .add_argument( + "max_time", "NDArray-or-Symbol", + "the length of the interval where the processes were sampled"); + +NNVM_REGISTER_OP(_contrib_backward_hawkesll) + .set_num_inputs(10) + .set_num_outputs(8) + .set_attr("TIsBackward", true) + .set_attr("FCompute", HawkesLLBackward) + .set_attr("FResourceRequest", [](const NodeAttrs& n) { + return std::vector{ResourceRequest::Type::kTempSpace}; + }); +} // namespace op +} // namespace mxnet diff --git a/src/operator/contrib/hawkes_ll.cu b/src/operator/contrib/hawkes_ll.cu new file mode 100755 index 000000000000..d35d7d0b0c08 --- /dev/null +++ b/src/operator/contrib/hawkes_ll.cu @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/*! + * Copyright (c) 2018 by Contributors + * \file hawkes_ll.cu + * \brief Log likelihood of a marked self-exciting Hawkes process + * \author Caner Turkmen + */ + +#include "./hawkes_ll-inl.h" + +namespace mxnet { +namespace op { + +NNVM_REGISTER_OP(_contrib_hawkesll) +.set_attr("FCompute", HawkesLLForward); + +NNVM_REGISTER_OP(_contrib_backward_hawkesll) +.set_attr("FCompute", HawkesLLBackward); + +} // namespace op +} // namespace mxnet diff --git a/tests/python/unittest/test_contrib_hawkesll.py b/tests/python/unittest/test_contrib_hawkesll.py new file mode 100644 index 000000000000..a4b1d9de605f --- /dev/null +++ b/tests/python/unittest/test_contrib_hawkesll.py @@ -0,0 +1,161 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import mxnet as mx +import numpy as np +from mxnet import nd + + +def test_hawkesll_output_ok(): + T, N, K = 4, 4, 3 + + mu = nd.array([1.5, 2.0, 3.0]).tile((N, 1)) + alpha = nd.array([0.2, 0.3, 0.4]) + beta = nd.array([1.0, 2.0, 3.0]) + + lags = nd.array([[6, 7, 8, 9], [1, 2, 3, 4], [3, 4, 5, 6], [8, 9, 10, 11]]) + marks = nd.zeros((N, T)).astype(np.int32) + states = nd.zeros((N, K)) + + valid_length = nd.array([1, 2, 3, 4]) + max_time = nd.ones((N,)) * 100.0 + + A = nd.contrib.hawkesll( + mu, alpha, beta, states, lags, marks, valid_length, max_time + ) + + assert np.allclose( + np.array([-649.79453489, -649.57118596, -649.38025115, -649.17811484]), + A[0].asnumpy(), + ) + + +def test_hawkesll_output_multivariate_ok(): + T, N, K = 9, 2, 3 + + mu = nd.array([1.5, 2.0, 3.0]) + alpha = nd.array([0.2, 0.3, 0.4]) + beta = nd.array([2.0, 2.0, 2.0]) + + lags = nd.array([[6, 7, 8, 9, 3, 2, 5, 1, 7], [1, 2, 3, 4, 2, 1, 2, 1, 4]]) + marks = nd.array([[0, 1, 2, 1, 0, 2, 1, 0, 2], [1, 2, 0, 0, 0, 2, 2, 1, 0]]).astype( + np.int32 + ) + + states = nd.zeros((N, K)) + + valid_length = nd.array([7, 9]) + max_time = nd.ones((N,)) * 100.0 + + A = nd.contrib.hawkesll( + mu.tile((N, 1)), alpha, beta, states, lags, marks, valid_length, max_time + ) + + assert np.allclose(np.array([-647.01240372, -646.28617272]), A[0].asnumpy()) + + +def test_hawkesll_backward_correct(): + ctx = mx.cpu() + + mu = nd.array([1.5, 2.0, 3.0]) + alpha = nd.array([0.2, 0.3, 0.4]) + beta = nd.array([2.0, 2.0, 2.0]) + + T, N, K = 9, 2, 3 + lags = nd.array([[6, 7, 8, 9, 3, 2, 5, 1, 7], [1, 2, 3, 4, 2, 1, 2, 1, 4]]) + marks = nd.array([[0, 0, 0, 1, 0, 0, 1, 2, 0], [1, 2, 0, 0, 0, 2, 2, 1, 0]]).astype( + np.int32 + ) + valid_length = nd.array([9, 9]) + states = nd.zeros((N, K)) + + max_time = nd.ones((N,)) * 100.0 + + mu.attach_grad() + alpha.attach_grad() + beta.attach_grad() + + with mx.autograd.record(): + A, _ = nd.contrib.hawkesll( + mu.tile((N, 1)), alpha, beta, states, lags, marks, valid_length, max_time + ) + A.backward() + + dmu, dalpha, dbeta = ( + np.array([-193.33987481, -198.0, -198.66828681]), + np.array([-9.95093892, -4.0, -3.98784892]), + np.array([-1.49052169e-02, -5.87469511e-09, -7.29065224e-03]), + ) + assert np.allclose(dmu, mu.grad.asnumpy()) + assert np.allclose(dalpha, alpha.grad.asnumpy()) + assert np.allclose(dbeta, beta.grad.asnumpy()) + + +def test_hawkesll_forward_single_mark(): + _dtype = np.float32 + + mu = nd.array([1.5]).astype(_dtype) + alpha = nd.array([0.2]).astype(_dtype) + beta = nd.array([1.0]).astype(_dtype) + + T, N, K = 7, 1, 1 + lags = nd.array([[6, 7, 8, 3, 2, 1, 7]]).astype(_dtype) + marks = nd.array([[0, 0, 0, 0, 0, 0, 0]]).astype(np.int32) + valid_length = nd.array([7]).astype(_dtype) + + states = nd.zeros((N, K)).astype(_dtype) + max_time = nd.ones((N,)).astype(_dtype) * 100 + + A, _ = nd.contrib.hawkesll( + mu.tile((N, 1)), alpha, beta, states, lags, marks, valid_length, max_time + ) + + assert np.allclose(A[0].asscalar(), -148.4815) + + +def test_hawkesll_backward_single_mark(): + _dtype = np.float32 + + mu = nd.array([1.5]).astype(_dtype) + alpha = nd.array([0.2]).astype(_dtype) + beta = nd.array([1.0]).astype(_dtype) + + T, N, K = 7, 1, 1 + lags = nd.array([[6, 7, 8, 3, 2, 1, 7]]).astype(_dtype) + marks = nd.array([[0, 0, 0, 0, 0, 0, 0]]).astype(np.int32) + valid_length = nd.array([7]).astype(_dtype) + + states = nd.zeros((N, K)).astype(_dtype) + max_time = nd.ones((N,)).astype(_dtype) * 40 + + mu.attach_grad() + beta.attach_grad() + + with mx.autograd.record(): + A, _ = nd.contrib.hawkesll( + mu.tile((N, 1)), alpha, beta, states, lags, marks, valid_length, max_time + ) + + A.backward() + + assert np.allclose(beta.grad.asnumpy().sum(), -0.05371582) + + +if __name__ == "__main__": + import nose + + nose.runmodule()