diff --git a/include/zmq.h b/include/zmq.h index 4858d693ee..ff80f13318 100644 --- a/include/zmq.h +++ b/include/zmq.h @@ -190,6 +190,16 @@ extern "C" { #define ETERM (ZMQ_HAUSNUMERO + 53) #define EMTHREAD (ZMQ_HAUSNUMERO + 54) +// Curve error codes +#define ECURVEKEY (ZMQ_HAUSNUMERO + 55) // possibly wrong server key +#define ECURVEHANDSHAKE (ZMQ_HAUSNUMERO + 56) // invalid handshake command +#define ECURVECLIENT (ZMQ_HAUSNUMERO + 57) // invalid curve client +#define ECURVENONCE (ZMQ_HAUSNUMERO + 58) // wrong nonce +#define ECURVEHELLOVER (ZMQ_HAUSNUMERO + 59) // wrong hello version +#define ECURVEHELLOSIZE (ZMQ_HAUSNUMERO + 60) // wrong hello size +#define ECURVEHELLOCMD (ZMQ_HAUSNUMERO + 61) // wrong hello command + + /* This function retrieves the errno as it is known to 0MQ library. The goal */ /* of this function is to make the code 100% portable, including where 0MQ */ /* compiled with certain CRT library (on Windows) is linked to an */ @@ -202,6 +212,10 @@ ZMQ_EXPORT const char *zmq_strerror (int errnum); /* Run-time API version detection */ ZMQ_EXPORT void zmq_version (int *major, int *minor, int *patch); +/* Error handler callback */ +typedef void(*zmq_error_fn) (int err, const char* host, void* data); +ZMQ_EXPORT int zmq_error_handler(void* context, zmq_error_fn ffn, void* data); + /******************************************************************************/ /* 0MQ infrastructure (a.k.a. context) initialisation & termination. */ /******************************************************************************/ diff --git a/src/ctx.cpp b/src/ctx.cpp index 9dcdeca25e..6d7942a4fb 100644 --- a/src/ctx.cpp +++ b/src/ctx.cpp @@ -70,6 +70,8 @@ int clipped_maxsocket (int max_requested) zmq::ctx_t::ctx_t () : tag (ZMQ_CTX_TAG_VALUE_GOOD), + error_fn(0), + error_data(0), starting (true), terminating (false), reaper (NULL), @@ -107,6 +109,25 @@ bool zmq::ctx_t::check_tag () return tag == ZMQ_CTX_TAG_VALUE_GOOD; } +void zmq::ctx_t::set_error_handler(zmq_error_fn ffn, void* data) +{ + if(error_data) + { + free(error_data); + } + + error_fn = ffn; + error_data = data; +} + +void zmq::ctx_t::handle_error(int errno_, const char* host) +{ + if(error_fn) + { + error_fn(errno_, host, error_data); + } +} + zmq::ctx_t::~ctx_t () { // Check that there are no remaining sockets. @@ -137,6 +158,11 @@ zmq::ctx_t::~ctx_t () randombytes_close (); #endif + if(error_data) + { + free(error_data); + } + // Remove the tag, so that the object is considered dead. tag = ZMQ_CTX_TAG_VALUE_BAD; } diff --git a/src/ctx.hpp b/src/ctx.hpp index 953b7941ef..957198d913 100644 --- a/src/ctx.hpp +++ b/src/ctx.hpp @@ -75,6 +75,12 @@ namespace zmq // Returns false if object is not a context. bool check_tag (); + // set error handler callback + void set_error_handler(zmq_error_fn ffn, void* data); + + // redirects error to error handler callback if it is set + void handle_error(int errno_, const char* host); + // This function is called when user invokes zmq_ctx_term. If there are // no more sockets open it'll cause all the infrastructure to be shut // down. If there are open sockets still, the deallocation happens @@ -145,6 +151,10 @@ namespace zmq // Used to check whether the object is a context. uint32_t tag; + // Error handler callback + zmq_error_fn error_fn; + void* error_data; + // Sockets belonging to this context. We need the list so that // we can notify the sockets when zmq_ctx_term() is called. // The sockets will return ETERM then. diff --git a/src/curve_server.cpp b/src/curve_server.cpp index eff4290b33..8b50d3e459 100644 --- a/src/curve_server.cpp +++ b/src/curve_server.cpp @@ -103,7 +103,7 @@ int zmq::curve_server_t::process_handshake_command (msg_t *msg_) default: // Temporary support for security debugging puts ("CURVE I: invalid handshake command"); - errno = EPROTO; + errno = ECURVEHANDSHAKE; rc = -1; break; } @@ -175,7 +175,7 @@ int zmq::curve_server_t::decode (msg_t *msg_) if (msg_->size () < 33) { // Temporary support for security debugging puts ("CURVE I: invalid CURVE client, sent malformed command"); - errno = EPROTO; + errno = ECURVECLIENT; return -1; } @@ -183,7 +183,7 @@ int zmq::curve_server_t::decode (msg_t *msg_) if (memcmp (message, "\x07MESSAGE", 8)) { // Temporary support for security debugging puts ("CURVE I: invalid CURVE client, did not send MESSAGE"); - errno = EPROTO; + errno = ECURVECLIENT; return -1; } @@ -192,7 +192,7 @@ int zmq::curve_server_t::decode (msg_t *msg_) memcpy (message_nonce + 16, message + 8, 8); uint64_t nonce = get_uint64(message + 8); if (nonce <= cn_peer_nonce) { - errno = EPROTO; + errno = ECURVENONCE; return -1; } cn_peer_nonce = nonce; @@ -231,7 +231,7 @@ int zmq::curve_server_t::decode (msg_t *msg_) else { // Temporary support for security debugging puts ("CURVE I: connection key used for MESSAGE is wrong"); - errno = EPROTO; + errno = ECURVEKEY; } free (message_plaintext); free (message_box); @@ -269,7 +269,7 @@ int zmq::curve_server_t::process_hello (msg_t *msg_) if (msg_->size () != 200) { // Temporary support for security debugging puts ("CURVE I: client HELLO is not correct size"); - errno = EPROTO; + errno = ECURVEHELLOSIZE; return -1; } @@ -277,7 +277,7 @@ int zmq::curve_server_t::process_hello (msg_t *msg_) if (memcmp (hello, "\x05HELLO", 6)) { // Temporary support for security debugging puts ("CURVE I: client HELLO has invalid command name"); - errno = EPROTO; + errno = ECURVEHELLOCMD; return -1; } @@ -287,7 +287,7 @@ int zmq::curve_server_t::process_hello (msg_t *msg_) if (major != 1 || minor != 0) { // Temporary support for security debugging puts ("CURVE I: client HELLO has unknown version number"); - errno = EPROTO; + errno = ECURVEHELLOVER; return -1; } @@ -312,7 +312,7 @@ int zmq::curve_server_t::process_hello (msg_t *msg_) if (rc != 0) { // Temporary support for security debugging puts ("CURVE I: cannot open client HELLO -- wrong server key?"); - errno = EPROTO; + errno = ECURVEKEY; return -1; } diff --git a/src/stream_engine.cpp b/src/stream_engine.cpp index 1f38983ad5..e8abcca23b 100644 --- a/src/stream_engine.cpp +++ b/src/stream_engine.cpp @@ -965,6 +965,8 @@ int zmq::stream_engine_t::push_one_then_decode_and_push (msg_t *msg_) void zmq::stream_engine_t::error (error_reason_t reason) { + socket->get_ctx()->handle_error(errno, peer_address.c_str()); + if (options.raw_socket && options.raw_notify) { // For raw sockets, send a final 0-length message to the application // so that it knows the peer has been disconnected. diff --git a/src/zmq.cpp b/src/zmq.cpp index 5058f835be..9a93f0b77a 100644 --- a/src/zmq.cpp +++ b/src/zmq.cpp @@ -116,6 +116,18 @@ int zmq_errno (void) } +int zmq_error_handler(void* ctx_, zmq_error_fn ffn, void* data) +{ + if (!ctx_ || !((zmq::ctx_t *) ctx_)->check_tag()) { + errno = EFAULT; + return -1; + } + + ((zmq::ctx_t *) ctx_)->set_error_handler(ffn, data); + + return 0; +} + // New context API void *zmq_ctx_new (void)