diff --git a/augmentation.c b/augmentation.c index d678ac8a..aab35c44 100644 --- a/augmentation.c +++ b/augmentation.c @@ -72,6 +72,10 @@ static void housekeep_augment_cache_locked(void) static augmentation_response *augmentation_response_init(void) { augmentation_response *response = kzalloc(sizeof(augmentation_response), GFP_KERNEL); + if (!response) + { + return ERR_PTR(-ENOMEM); + } kref_init(&response->kref); @@ -97,12 +101,10 @@ static void augmentation_response_release(struct kref *kref) augmentation_response_free(response); } -void augmentation_response_get(augmentation_response *response) +static void augmentation_response_get(augmentation_response *response) { if (!response) - { return; - } kref_get(&response->kref); } @@ -110,9 +112,7 @@ void augmentation_response_get(augmentation_response *response) void augmentation_response_put(augmentation_response *response) { if (!response) - { return; - } kref_put(&response->kref, augmentation_response_release); } @@ -121,17 +121,25 @@ static char *get_task_context_key(task_context *ctx) { int keylen = snprintf(NULL, 0, "%s %d %d %u %s", ctx->cgroup_path, ctx->uid.val, ctx->gid.val, ctx->namespace_ids.mnt, ctx->command_path); char *key = kzalloc(keylen + 1, GFP_KERNEL); + if (!key) + { + return ERR_PTR(-ENOMEM); + } snprintf(key, keylen + 1, "%s %d %d %u %s", ctx->cgroup_path, ctx->uid.val, ctx->gid.val, ctx->namespace_ids.mnt, ctx->command_path); return key; } -static void augmentation_response_cache_set_locked(char *key, augmentation_response *response) +static int augmentation_response_cache_set_locked(char *key, augmentation_response *response) { if (!key) { - pr_err("invalid key"); - return; + return -EINVAL; + } + + if (!response) + { + return -EINVAL; } housekeep_augment_cache_locked(); @@ -139,15 +147,21 @@ static void augmentation_response_cache_set_locked(char *key, augmentation_respo augmentation_response_cache_entry *new_entry = kzalloc(sizeof(augmentation_response_cache_entry), GFP_KERNEL); if (!new_entry) { - pr_err("memory allocation error"); - return; + return -ENOMEM; } - new_entry->key = strdup(key); + new_entry->key = kstrdup(key, GFP_KERNEL); + if (!new_entry->key) + { + kfree(new_entry); + return -ENOMEM; + } new_entry->response = response; pr_debug("add entry # key[%s]", new_entry->key); list_add(&new_entry->list, &augmentation_response_cache); + + return 0; } static augmentation_response *augmentation_response_cache_get_locked(char *key) @@ -174,34 +188,78 @@ static augmentation_response *augmentation_response_cache_get_locked(char *key) augmentation_response *augment_workload() { augmentation_response *response; + void *error; - augmentation_response_cache_lock(); + task_context *ctx = get_task_context(); + if (IS_ERR(ctx)) + return (void *)ctx; + + char *key = get_task_context_key(ctx); + if (IS_ERR(key)) + return (void *)key; - char *key = get_task_context_key(get_task_context()); + augmentation_response_cache_lock(); response = augmentation_response_cache_get_locked(key); + if (response) { augmentation_response_get(response); + augmentation_response_cache_unlock(); goto ret; } + augmentation_response_cache_unlock(); + response = augmentation_response_init(); + if (IS_ERR(response)) + { + goto ret; + } + command_answer *answer = send_augment_command(); + if (IS_ERR(answer)) + { + error = answer; + goto error; + } if (answer->error) - response->error = strdup(answer->error); + { + response->error = kstrdup(answer->error, GFP_KERNEL); + if (!response->error) + { + error = ERR_PTR(-ENOMEM); + goto error; + } + } else { - response->response = strdup(answer->answer); - augmentation_response_cache_set_locked(key, response); + response->response = kstrdup(answer->answer, GFP_KERNEL); + if (!response->response) + { + error = ERR_PTR(-ENOMEM); + goto error; + } + + augmentation_response_cache_lock(); + int ret = augmentation_response_cache_set_locked(key, response); + if (ret < 0) + { + error = ERR_PTR(ret); + augmentation_response_cache_unlock(); + goto error; + } augmentation_response_get(response); + augmentation_response_cache_unlock(); } free_command_answer(answer); ret: kfree(key); - - augmentation_response_cache_unlock(); - return response; + +error: + kfree(key); + augmentation_response_free(response); + return error; } diff --git a/augmentation.h b/augmentation.h index 63db47a3..f059b385 100644 --- a/augmentation.h +++ b/augmentation.h @@ -27,8 +27,12 @@ typedef struct struct list_head list; } augmentation_response_cache_entry; -void augmentation_response_get(augmentation_response *response); void augmentation_response_put(augmentation_response *response); +/* + * augment_workload + * + * returns an augmentation_response struct pointer or ERR_PTR() on error + */ augmentation_response *augment_workload(void); #endif diff --git a/buffer.c b/buffer.c index 45aa2b29..b5f8c536 100644 --- a/buffer.c +++ b/buffer.c @@ -14,7 +14,14 @@ buffer_t *buffer_new(int capacity) { buffer_t *buffer = kmalloc(sizeof(buffer_t), GFP_KERNEL); + if (!buffer) + return ERR_PTR(-ENOMEM); buffer->data = kmalloc(capacity, GFP_KERNEL); + if (!buffer->data) + { + buffer_free(buffer); + return ERR_PTR(-ENOMEM); + } buffer->size = 0; buffer->capacity = capacity; return buffer; @@ -22,11 +29,11 @@ buffer_t *buffer_new(int capacity) void buffer_free(buffer_t *buffer) { - if (buffer) - { - kfree(buffer->data); - kfree(buffer); - } + if (IS_ERR_OR_NULL(buffer)) + return; + + kfree(buffer->data); + kfree(buffer); } char *buffer_grow(buffer_t *buffer, int len) @@ -43,6 +50,10 @@ char *buffer_grow(buffer_t *buffer, int len) } buffer->data = krealloc(buffer->data, new_capacity, GFP_KERNEL); + if (!buffer->data) + { + return NULL; + } buffer->capacity = new_capacity; } diff --git a/buffer.h b/buffer.h index ed14095e..73f08ff7 100644 --- a/buffer.h +++ b/buffer.h @@ -11,16 +11,26 @@ #ifndef buffer_h #define buffer_h -typedef struct buffer_t { +typedef struct buffer_t +{ char *data; int size; int capacity; } buffer_t; - +/* + * buffer_new + * + * returns a buffer_t struct pointer or ERR_PTR() on error + */ buffer_t *buffer_new(int capacity); void buffer_free(buffer_t *buffer); -// returns a pointer to the buffer where the caller can write len long data (possibly resizing the buffer) +/* + * buffer_grow + * + * returns a buffer_t struct pointer where the caller can write len long data + * (possibly resizing the buffer) or NULL on realloc error + */ char *buffer_grow(buffer_t *buffer, int len); #endif diff --git a/cert_tools.c b/cert_tools.c index f15a65ea..4358c57f 100644 --- a/cert_tools.c +++ b/cert_tools.c @@ -16,6 +16,7 @@ #include "cert_tools.h" #include "string.h" #include "rsa_tools.h" +#include "task_context.h" // Define the maximum number of elements inside the cache #define MAX_CACHE_LENGTH 64 @@ -51,26 +52,32 @@ size_t linkedlist_length(struct list_head *head) // add_cert_to_cache adds a certificate chain with a given trust anchor to a linked list. The key will identify this entry. // the function is thread safe. -void add_cert_to_cache(char *key, x509_certificate *cert) +int add_cert_to_cache(char *key, x509_certificate *cert) { if (!key) { pr_err("invalid key"); - return; + return -EINVAL; } cert_with_key *new_entry = kzalloc(sizeof(cert_with_key), GFP_KERNEL); if (!new_entry) { - pr_err("memory allocation error"); - return; + return -ENOMEM; + } + new_entry->key = kstrdup(key, GFP_KERNEL); + if (!new_entry->key) + { + kfree(new_entry); + return -ENOMEM; } - new_entry->key = strdup(key); new_entry->cert = cert; cert_cache_lock(); list_add(&new_entry->list, &cert_cache); cert_cache_unlock(); + + return 0; } // remove_unused_expired_certs_from_cache iterates over the whole cache and tries to clean up the unused/expired certificates. @@ -199,6 +206,8 @@ bool validate_cert(x509_certificate_validity cert_validity) x509_certificate *x509_certificate_init(void) { x509_certificate *cert = kzalloc(sizeof(x509_certificate), GFP_KERNEL); + if (!cert) + return ERR_PTR(-ENOMEM); kref_init(&cert->kref); @@ -207,10 +216,8 @@ x509_certificate *x509_certificate_init(void) static void x509_certificate_free(x509_certificate *cert) { - if (!cert) - { + if (IS_ERR_OR_NULL(cert)) return; - } free_br_x509_certificate(cert->chain, cert->chain_len); free_br_x509_trust_anchors(cert->trust_anchors, cert->trust_anchors_len); @@ -227,18 +234,16 @@ static void x509_certificate_release(struct kref *kref) void x509_certificate_get(x509_certificate *cert) { - if (!cert) - { + if (IS_ERR_OR_NULL(cert)) return; - } + kref_get(&cert->kref); } void x509_certificate_put(x509_certificate *cert) { - if (!cert) - { + if (IS_ERR_OR_NULL(cert)) return; - } + kref_put(&cert->kref, x509_certificate_release); } diff --git a/cert_tools.h b/cert_tools.h index 44549ce6..b684b7cb 100644 --- a/cert_tools.h +++ b/cert_tools.h @@ -41,11 +41,16 @@ typedef struct struct list_head list; } cert_with_key; +/* + * x509_certificate_init + * + * returns an x509_certificate struct pointer or ERR_PTR() on error + */ x509_certificate *x509_certificate_init(void); void x509_certificate_get(x509_certificate *cert); void x509_certificate_put(x509_certificate *cert); -void add_cert_to_cache(char *key, x509_certificate *cert); +int add_cert_to_cache(char *key, x509_certificate *cert); cert_with_key *find_cert_from_cache(char *key); void remove_cert_from_cache(cert_with_key *cert); void remove_cert_from_cache_locked(cert_with_key *cert); diff --git a/commands.c b/commands.c index faaf5ee4..23707bb7 100644 --- a/commands.c +++ b/commands.c @@ -37,6 +37,8 @@ static LIST_HEAD(in_flight_command_list); static command_answer *this_send_command(char *name, char *data, task_context *context, bool is_message) { struct command *cmd = kzalloc(sizeof(struct command), GFP_KERNEL); + if (!cmd) + return ERR_PTR(-ENOMEM); uuid_gen(&cmd->uuid); cmd->name = name; @@ -70,18 +72,26 @@ static command_answer *this_send_command(char *name, char *data, task_context *c finish_wait(&cmd->wait_queue, &wait); + command_answer *cmd_answer = NULL; if (cmd->answer == NULL) { pr_err("command timeout # name[%s] uuid[%pUB] ", name, cmd->uuid.b); cmd->answer = answer_with_error("timeout"); + if (IS_ERR(cmd->answer)) + { + cmd_answer = (void *)cmd->answer; + goto out; + } } + cmd_answer = cmd->answer; + +out: mutex_lock(&command_list_lock); list_del(&cmd->list); mutex_unlock(&command_list_lock); - command_answer *cmd_answer = cmd->answer; free_command(cmd); return cmd_answer; @@ -103,12 +113,23 @@ void send_message(char *name, char *data, task_context *context) command_answer *send_augment_command() { - return send_command("augment", NULL, get_task_context()); + task_context *ctx = get_task_context(); + if (IS_ERR(ctx)) + return (void *)ctx; + + return send_command("augment", NULL, ctx); } command_answer *send_accept_command(u16 port) { + task_context *ctx = get_task_context(); + if (IS_ERR(ctx)) + return (void *)ctx; + JSON_Value *root_value = json_value_init_object(); + if (!root_value) + return ERR_PTR(-ENOMEM); + JSON_Object *root_object = json_value_get_object(root_value); if (!root_object) @@ -116,9 +137,11 @@ command_answer *send_accept_command(u16 port) return answer_with_error("could not get root object"); } - json_object_set_number(root_object, "port", port); - - command_answer *answer = send_command("accept", json_serialize_to_string(root_value), get_task_context()); + command_answer *answer = NULL; + if (json_object_set_number(root_object, "port", port) < 0) + answer = answer_with_error("could not set port"); + else + answer = send_command("accept", json_serialize_to_string(root_value), ctx); json_value_free(root_value); @@ -127,7 +150,14 @@ command_answer *send_accept_command(u16 port) command_answer *send_connect_command(u16 port) { + task_context *ctx = get_task_context(); + if (IS_ERR(ctx)) + return (void *)ctx; + JSON_Value *root_value = json_value_init_object(); + if (!root_value) + return ERR_PTR(-ENOMEM); + JSON_Object *root_object = json_value_get_object(root_value); if (!root_object) @@ -135,48 +165,87 @@ command_answer *send_connect_command(u16 port) return answer_with_error("could not get root object"); } - json_object_set_number(root_object, "port", port); - - command_answer *answer = send_command("connect", json_serialize_to_string(root_value), get_task_context()); + command_answer *answer = NULL; + if (json_object_set_number(root_object, "port", port) < 0) + answer = answer_with_error("could not set port"); + else + answer = send_command("connect", json_serialize_to_string(root_value), ctx); json_value_free(root_value); return answer; } +void csr_sign_answer_free(csr_sign_answer *answer) +{ + if (IS_ERR_OR_NULL(answer)) + return; + + kfree(answer->error); + kfree(answer); +} + +static char *prepare_csr_json(const unsigned char *csr, const char *ttl) +{ + char *json = NULL; + + JSON_Value *root_value = json_value_init_object(); + if (!root_value) + return ERR_PTR(-ENOMEM); + + JSON_Object *root_object = json_value_get_object(root_value); + if (!root_object) + goto out; + + if (json_object_set_string(root_object, "csr", csr) < 0) + goto out; + + if (ttl) + if (json_object_set_string(root_object, "ttl", ttl) < 0) + goto out; + + json = json_serialize_to_string(root_value); + +out: + json_value_free(root_value); + return json; +} + csr_sign_answer *send_csrsign_command(const unsigned char *csr, const char *ttl) { + csr_sign_answer *csr_sign_answer = NULL; + command_answer *answer = NULL; JSON_Value *json = NULL; - const char *errormsg; + const char *errormsg = NULL; + void *error = NULL; + + task_context *ctx = get_task_context(); + if (IS_ERR(ctx)) + return (void *)ctx; - csr_sign_answer *csr_sign_answer = kzalloc(sizeof(struct csr_sign_answer), GFP_KERNEL); + csr_sign_answer = kzalloc(sizeof(struct csr_sign_answer), GFP_KERNEL); + if (!csr_sign_answer) + return ERR_PTR(-ENOMEM); if (!csr) { - csr_sign_answer->error = strdup("nil csr"); - - return csr_sign_answer; + error = ERR_PTR(-EINVAL); + goto error; } - JSON_Value *root_value = json_value_init_object(); - JSON_Object *root_object = json_value_get_object(root_value); - - if (!root_object) + char *csr_json = prepare_csr_json(csr, ttl); + if (!csr_json) { - errormsg = "could not get root object"; + errormsg = "could not prepare csr json"; goto error; } - - json_object_set_string(root_object, "csr", csr); - if (ttl) + if (IS_ERR(csr_json)) { - json_object_set_string(root_object, "ttl", ttl); + error = csr_json; + goto error; } - command_answer *answer = send_command("csr_sign", json_serialize_to_string(root_value), get_task_context()); - - json_value_free(root_value); - + answer = send_command("csr_sign", csr_json, ctx); if (answer->error) { errormsg = answer->error; @@ -186,7 +255,6 @@ csr_sign_answer *send_csrsign_command(const unsigned char *csr, const char *ttl) if (answer->answer) { json = json_parse_string(answer->answer); - if (json == NULL) { errormsg = "could not parse answer JSON data"; @@ -194,7 +262,6 @@ csr_sign_answer *send_csrsign_command(const unsigned char *csr, const char *ttl) } JSON_Object *root = json_value_get_object(json); - if (root == NULL) { errormsg = "could not get root object from parsed JSON"; @@ -202,7 +269,6 @@ csr_sign_answer *send_csrsign_command(const unsigned char *csr, const char *ttl) } JSON_Array *trust_anchors = json_object_get_array(root, "trust_anchors"); - if (trust_anchors == NULL) { errormsg = "could not find trust anchors"; @@ -210,13 +276,23 @@ csr_sign_answer *send_csrsign_command(const unsigned char *csr, const char *ttl) } csr_sign_answer->cert = x509_certificate_init(); + if (IS_ERR(csr_sign_answer->cert)) + { + error = csr_sign_answer->cert; + goto error; + } csr_sign_answer->cert->trust_anchors_len = json_array_get_count(trust_anchors); size_t srclen; if (csr_sign_answer->cert->trust_anchors_len > 0) { - csr_sign_answer->cert->trust_anchors = kmalloc(csr_sign_answer->cert->trust_anchors_len * sizeof *csr_sign_answer->cert->trust_anchors, GFP_KERNEL); + csr_sign_answer->cert->trust_anchors = kzalloc(csr_sign_answer->cert->trust_anchors_len * sizeof *csr_sign_answer->cert->trust_anchors, GFP_KERNEL); + if (!csr_sign_answer->cert->trust_anchors) + { + error = ERR_PTR(-ENOMEM); + goto error; + } size_t u; for (u = 0; u < csr_sign_answer->cert->trust_anchors_len; u++) @@ -231,7 +307,12 @@ csr_sign_answer *send_csrsign_command(const unsigned char *csr, const char *ttl) if (raw_subject != NULL) { srclen = strlen(raw_subject); - csr_sign_answer->cert->trust_anchors[u].dn.data = kmalloc(srclen, GFP_KERNEL); + csr_sign_answer->cert->trust_anchors[u].dn.data = kzalloc(srclen, GFP_KERNEL); + if (!csr_sign_answer->cert->trust_anchors[u].dn.data) + { + error = ERR_PTR(-ENOMEM); + goto error; + } csr_sign_answer->cert->trust_anchors[u].dn.len = base64_decode(csr_sign_answer->cert->trust_anchors[u].dn.data, srclen, raw_subject, srclen); } @@ -240,7 +321,12 @@ csr_sign_answer *send_csrsign_command(const unsigned char *csr, const char *ttl) if (rsa_n != NULL) { srclen = strlen(rsa_n); - csr_sign_answer->cert->trust_anchors[u].pkey.key.rsa.n = kmalloc(srclen, GFP_KERNEL); + csr_sign_answer->cert->trust_anchors[u].pkey.key.rsa.n = kzalloc(srclen, GFP_KERNEL); + if (!csr_sign_answer->cert->trust_anchors[u].pkey.key.rsa.n) + { + error = ERR_PTR(-ENOMEM); + goto error; + } csr_sign_answer->cert->trust_anchors[u].pkey.key.rsa.nlen = base64_decode(csr_sign_answer->cert->trust_anchors[u].pkey.key.rsa.n, srclen, rsa_n, srclen); } @@ -249,7 +335,12 @@ csr_sign_answer *send_csrsign_command(const unsigned char *csr, const char *ttl) if (rsa_e != NULL) { srclen = strlen(rsa_e); - csr_sign_answer->cert->trust_anchors[u].pkey.key.rsa.e = kmalloc(srclen, GFP_KERNEL); + csr_sign_answer->cert->trust_anchors[u].pkey.key.rsa.e = kzalloc(srclen, GFP_KERNEL); + if (!csr_sign_answer->cert->trust_anchors[u].pkey.key.rsa.e) + { + error = ERR_PTR(-ENOMEM); + goto error; + } csr_sign_answer->cert->trust_anchors[u].pkey.key.rsa.elen = base64_decode(csr_sign_answer->cert->trust_anchors[u].pkey.key.rsa.e, srclen, rsa_e, srclen); } } @@ -272,9 +363,19 @@ csr_sign_answer *send_csrsign_command(const unsigned char *csr, const char *ttl) csr_sign_answer->cert->chain_len = json_array_get_count(chain); csr_sign_answer->cert->chain_len++; csr_sign_answer->cert->chain = kzalloc(csr_sign_answer->cert->chain_len * sizeof *csr_sign_answer->cert->chain, GFP_KERNEL); + if (!csr_sign_answer->cert->chain) + { + error = ERR_PTR(-ENOMEM); + goto error; + } srclen = strlen(raw); - csr_sign_answer->cert->chain[0].data = kmalloc(srclen, GFP_KERNEL); + csr_sign_answer->cert->chain[0].data = kzalloc(srclen, GFP_KERNEL); + if (!csr_sign_answer->cert->chain[0].data) + { + error = ERR_PTR(-ENOMEM); + goto error; + } csr_sign_answer->cert->chain[0].data_len = base64_decode(csr_sign_answer->cert->chain[0].data, srclen, raw, srclen); int k; @@ -288,7 +389,12 @@ csr_sign_answer *send_csrsign_command(const unsigned char *csr, const char *ttl) if (raw) { srclen = strlen(raw); - csr_sign_answer->cert->chain[j].data = kmalloc(srclen, GFP_KERNEL); + csr_sign_answer->cert->chain[j].data = kzalloc(srclen, GFP_KERNEL); + if (!csr_sign_answer->cert->chain[j].data) + { + error = ERR_PTR(-ENOMEM); + goto error; + } csr_sign_answer->cert->chain[j].data_len = base64_decode(csr_sign_answer->cert->chain[j].data, srclen, raw, srclen); j++; @@ -309,12 +415,24 @@ csr_sign_answer *send_csrsign_command(const unsigned char *csr, const char *ttl) error: if (errormsg) { - csr_sign_answer->error = strdup(errormsg); + csr_sign_answer->error = kstrdup(errormsg, GFP_KERNEL); + if (!csr_sign_answer->error) + { + error = ERR_PTR(-ENOMEM); + } + error = ERR_PTR(-ENOMEM); } json_value_free(json); free_command_answer(answer); + if (IS_ERR(error)) + { + x509_certificate_put(csr_sign_answer->cert); + csr_sign_answer_free(csr_sign_answer); + return error; + } + return csr_sign_answer; } @@ -372,17 +490,20 @@ command *get_next_command(void) command_answer *answer_with_error(char *error_message) { command_answer *answer = kzalloc(sizeof(struct command_answer), GFP_KERNEL); - answer->error = strdup(error_message); + if (!answer) + return ERR_PTR(-ENOMEM); + + answer->error = kstrdup(error_message, GFP_KERNEL); + if (!answer->error) + return ERR_PTR(-ENOMEM); return answer; } void free_command_answer(command_answer *cmd_answer) { - if (!cmd_answer) - { + if (IS_ERR_OR_NULL(cmd_answer)) return; - } kfree(cmd_answer->error); kfree(cmd_answer->answer); diff --git a/commands.h b/commands.h index 117c7b86..ae2b0d4c 100644 --- a/commands.h +++ b/commands.h @@ -32,6 +32,7 @@ typedef struct csr_sign_answer } csr_sign_answer; void free_command_answer(command_answer *cmd_answer); +void csr_sign_answer_free(csr_sign_answer *answer); void send_message(char *name, char *data, task_context *context); command_answer *send_command(char *name, char *data, task_context *context); diff --git a/config.c b/config.c index 646d7bb0..9f115d50 100644 --- a/config.c +++ b/config.c @@ -26,12 +26,20 @@ void camblet_config_unlock(void) mutex_unlock(&camblet_config_mutex_lock); } -void camblet_config_init() +int camblet_config_init() { camblet_config_lock(); config = kzalloc(sizeof(camblet_config), GFP_KERNEL); + if (IS_ERR(config)) + { + camblet_config_unlock(); + return -ENOMEM; + } + strlcpy(config->trust_domain, "camblet", MAX_TRUST_DOMAIN_LEN); camblet_config_unlock(); + + return 0; } camblet_config *camblet_config_get_locked() @@ -42,11 +50,6 @@ camblet_config *camblet_config_get_locked() void camblet_config_free() { camblet_config_lock(); - if (!config) - { - return; - } - kfree(config); camblet_config_unlock(); } diff --git a/config.h b/config.h index ed0bf5a2..d7a1cbb2 100644 --- a/config.h +++ b/config.h @@ -21,7 +21,7 @@ typedef struct camblet_config camblet_config *camblet_config_get_locked(void); void camblet_config_lock(void); void camblet_config_unlock(void); -void camblet_config_init(void); +int camblet_config_init(void); void camblet_config_free(void); #endif diff --git a/csr.c b/csr.c index cf610547..20226026 100644 --- a/csr.c +++ b/csr.c @@ -48,14 +48,18 @@ void csr_unlock(csr_module *csr) wasm_vm_result init_csr_for(wasm_vm *vm, wasm_vm_module *module) { + wasm_vm_result result; csr_module *csr = csr_modules[wasm_vm_cpu(vm)]; if (csr == NULL) { csr = kzalloc(sizeof(struct csr_module), GFP_KERNEL); + if (!csr) + { + return wasm_vm_error("could not allocate memory"); + } csr->vm = vm; csr_modules[wasm_vm_cpu(vm)] = csr; } - wasm_vm_result result; wasm_vm_try_get_function(csr->generate_csr, wasm_vm_get_function(vm, module->name, "csr_gen")); wasm_vm_try_get_function(csr->csr_malloc, wasm_vm_get_function(vm, module->name, "csr_malloc")); wasm_vm_try_get_function(csr->csr_free, wasm_vm_get_function(vm, module->name, "csr_free")); @@ -67,7 +71,7 @@ wasm_vm_result init_csr_for(wasm_vm *vm, wasm_vm_module *module) return result; } - return (wasm_vm_result){.err = NULL}; + return wasm_vm_ok; } wasm_vm_result csr_malloc(csr_module *csr, i32 size) diff --git a/device_driver.c b/device_driver.c index c2c685b5..f6d729fd 100644 --- a/device_driver.c +++ b/device_driver.c @@ -72,7 +72,12 @@ int chardev_init(void) cls = class_create(DEVICE_NAME); #endif - device_create(cls, NULL, MKDEV(major, 0), NULL, DEVICE_NAME); + if (IS_ERR(cls)) + return PTR_ERR(cls); + + struct device *dev = device_create(cls, NULL, MKDEV(major, 0), NULL, DEVICE_NAME); + if (IS_ERR(dev)) + return PTR_ERR(dev); pr_info("device created # device[/dev/%s]", DEVICE_NAME); @@ -81,11 +86,19 @@ int chardev_init(void) void chardev_exit(void) { - device_destroy(cls, MKDEV(major, 0)); - class_destroy(cls); + if (!IS_ERR(cls)) + { + device_destroy(cls, MKDEV(major, 0)); + class_destroy(cls); + pr_info("device destroyed # device[/dev/%s]", DEVICE_NAME); + } /* Unregister the device */ - unregister_chrdev(major, DEVICE_NAME); + if (major > 0) + { + unregister_chrdev(major, DEVICE_NAME); + pr_info("char device unregistered # major[%d]", major); + } } /* Methods */ @@ -204,49 +217,103 @@ wasm_vm_result load_module(const char *name, const char *code, unsigned length, return result; } -static void load_sd_info(const char *data) +static int load_sd_info(const char *data) { + int retval = 0; + JSON_Value *json; + if (!data) { - return; + retval = -EINVAL; + goto ret; } pr_info("load service discovery info # data[%s]", data); - JSON_Value *json = json_parse_string(data); - if (json == NULL) + json = json_parse_string(data); + if (!json) { pr_err("could not load sd info: invalid json"); + retval = -EINVAL; + goto ret; } JSON_Object *root = json_value_get_object(json); - if (root == NULL) + if (!root) { pr_err("could not load sd info: invalid json root"); + retval = -EINVAL; + goto ret; } service_discovery_table *table = service_discovery_table_create(); + if (IS_ERR(table)) + { + retval = PTR_ERR(table); + goto ret; + } service_discovery_entry *entry; size_t i, k; for (i = 0; i < json_object_get_count(root); i++) { const char *name = json_object_get_name(root, i); + if (!name) + { + pr_err("could not load sd info: record[%d]: could not get object name", i); + retval = -EINVAL; + goto ret; + } JSON_Object *json_entry = json_object_get_object(root, name); + if (!json_entry) + { + pr_err("could not load sd info: record[%d]: could not get object", i); + retval = -EINVAL; + goto ret; + } JSON_Array *labels = json_object_get_array(json_entry, "labels"); + if (!labels) + { + pr_err("could not load sd info: record[%d]: could not get labels", i); + retval = -EINVAL; + goto ret; + } entry = kzalloc(sizeof(*entry), GFP_KERNEL); - entry->address = strdup(name); + if (!entry) + { + retval = -ENOMEM; + goto ret; + } + entry->address = kstrdup(name, GFP_KERNEL); + if (!entry->address) + { + service_discovery_entry_free(entry); + retval = -ENOMEM; + goto ret; + } pr_debug("create sd entry # address[%s]", entry->address); entry->labels_len = json_array_get_count(labels); - entry->labels = kmalloc(entry->labels_len * sizeof(char *), GFP_KERNEL); + entry->labels = kzalloc(entry->labels_len * sizeof(char *), GFP_KERNEL); + if (!entry->labels) + { + service_discovery_entry_free(entry); + retval = -ENOMEM; + goto ret; + } for (k = 0; k < entry->labels_len; k++) { const char *label = json_array_get_string(labels, k); - entry->labels[k] = strdup(label); + entry->labels[k] = kstrdup(label, GFP_KERNEL); + if (!entry->labels[k]) + { + service_discovery_entry_free(entry); + retval = -ENOMEM; + goto ret; + } pr_debug("set sd entry label # address[%s] label[%s]", entry->address, entry->labels[k]); } @@ -255,26 +322,33 @@ static void load_sd_info(const char *data) sd_table_replace(table); +ret: json_value_free(json); + + return retval; } -static void load_camblet_config(const char *data) +static int load_camblet_config(const char *data) { + int status = SUCCESS; + if (!data) - { - return; - } + return -EINVAL; JSON_Value *json = json_parse_string(data); - if (json == NULL) + if (!json) { pr_err("could not load camblet config: invalid json"); + status = -EINVAL; + goto out; } JSON_Object *root = json_value_get_object(json); - if (root == NULL) + if (!root) { pr_err("could not load camblet config: invalid json root"); + status = -EINVAL; + goto out; } const char *trust_domain = json_object_get_string(root, "trust_domain"); @@ -290,257 +364,348 @@ static void load_camblet_config(const char *data) camblet_config_unlock(); } +out: json_value_free(json); + return status; +} + +static char *base64_decode_data(const char *src, int *decoded_length) +{ + char *decoded = kzalloc(strlen(src) * 2, GFP_KERNEL); + if (!decoded) + return ERR_PTR(-ENOMEM); + + int length = base64_decode(decoded, strlen(src) * 2, src, strlen(src)); + if (length < 0) + { + kfree(decoded); + return ERR_PTR(-EINVAL); + } + + *decoded_length = length; + + return decoded; } static int parse_command(const char *data) { int status = SUCCESS; JSON_Value *json = NULL; + const char *command = NULL; + + if (!data) + return -EINVAL; - if (data) + json = json_parse_string(data); + if (!json) { - json = json_parse_string(data); - JSON_Object *root = json_value_get_object(json); - const char *command = json_object_get_string(root, "command"); + pr_err("parse_command: could not parse data"); + status = -EINVAL; + goto out; + } - pr_debug("incoming command # command[%s]", command); + JSON_Object *root = json_value_get_object(json); + if (!root) + { + pr_err("parse_command: invalid JSON root object"); + status = -EINVAL; + goto out; + } - if (strcmp("load", command) == 0) - { - const char *name = json_object_get_string(root, "name"); - pr_info("load module # name[%s]", name); + command = json_object_get_string(root, "command"); + if (!command) + { + pr_err("parse_command: missing 'command' property"); + status = -EINVAL; + goto out; + } - const char *code = json_object_get_string(root, "code"); - char *decoded = kzalloc(strlen(code) * 2, GFP_KERNEL); - int length = base64_decode(decoded, strlen(code) * 2, code, strlen(code)); - if (length < 0) - { - pr_crit("base64 decode failed"); - status = -1; - kfree(decoded); - goto cleanup; - } + pr_debug("incoming command # command[%s]", command); - const char *entrypoint = json_object_get_string(root, "entrypoint"); - if (entrypoint == NULL) - { - pr_info("setting default module entrypoint # entrypoint[%s]", DEFAULT_MODULE_ENTRYPOINT); - entrypoint = DEFAULT_MODULE_ENTRYPOINT; - } + if (strcmp("load", command) == 0) + { + const char *name = json_object_get_string(root, "name"); + if (!name) + { + pr_err("load: missing 'name' property"); + status = -EINVAL; + goto out; + } - wasm_vm_result result = load_module(name, decoded, length, entrypoint); - if (result.err) - { - pr_crit("could not load module # err[%s]", result.err); - status = -1; - kfree(decoded); - goto cleanup; - } + pr_info("load module # name[%s]", name); - kfree(decoded); + const char *code = json_object_get_string(root, "code"); + if (!code) + { + pr_err("load: missing 'code' property"); + status = -EINVAL; + goto out; } - if (strcmp("reset", command) == 0) + int length = 0; + char *decoded = base64_decode_data(code, &length); + if (IS_ERR(decoded)) { - pr_info("reseting vm"); - - wasm_vm_result result = reset_vms(); - if (result.err) - { - pr_crit("could not reset vm # err[%s]", result.err); - status = -1; - goto cleanup; - } + pr_err("could not decode data: err[%ld]", PTR_ERR(decoded)); + goto out; } - else if (strcmp("load_policies", command) == 0) + + const char *entrypoint = json_object_get_string(root, "entrypoint"); + if (!entrypoint) { - pr_info("load policies"); + pr_info("setting default module entrypoint # entrypoint[%s]", DEFAULT_MODULE_ENTRYPOINT); + entrypoint = DEFAULT_MODULE_ENTRYPOINT; + } - const char *code = json_object_get_string(root, "code"); - char *decoded = kzalloc(strlen(code) * 2, GFP_KERNEL); - int length = base64_decode(decoded, strlen(code) * 2, code, strlen(code)); - if (length < 0) - { - pr_crit("base64 decode failed"); - status = -1; - kfree(decoded); - goto cleanup; - } + wasm_vm_result result = load_module(name, decoded, length, entrypoint); + kfree(decoded); + if (result.err) + { + pr_crit("could not load module # err[%s]", result.err); + status = FAILURE; + goto out; + } + } + else if (strcmp("reset", command) == 0) + { + pr_info("reseting vm"); - load_opa_data(decoded); - kfree(decoded); + wasm_vm_result result = reset_vms(); + if (result.err) + { + pr_crit("could not reset vm # err[%s]", result.err); + status = FAILURE; + goto out; } - else if (strcmp("load_config", command) == 0) + } + else if (strcmp("load_policies", command) == 0) + { + pr_info("load policies"); + const char *code = json_object_get_string(root, "code"); + if (!code) { - pr_info("load config"); - - const char *code = json_object_get_string(root, "code"); - char *decoded = kzalloc(strlen(code) * 2, GFP_KERNEL); - int length = base64_decode(decoded, strlen(code) * 2, code, strlen(code)); - if (length < 0) - { - pr_crit("base64 decode failed"); - status = -1; - kfree(decoded); - goto cleanup; - } - - if (decoded) - { - load_camblet_config(decoded); - kfree(decoded); - } + pr_err("load_policies: missing 'code' property"); + status = -EINVAL; + goto out; } - else if (strcmp("load_sd_info", command) == 0) + int length = 0; + char *decoded = base64_decode_data(code, &length); + if (IS_ERR(decoded)) { - pr_info("load sd info"); - - const char *code = json_object_get_string(root, "code"); - char *decoded = kzalloc(strlen(code) * 2, GFP_KERNEL); - int length = base64_decode(decoded, strlen(code) * 2, code, strlen(code)); - if (length < 0) - { - pr_crit("base64 decode failed"); - status = -1; - kfree(decoded); - goto cleanup; - } - - if (decoded) - { - load_sd_info(decoded); - kfree(decoded); - } + status = PTR_ERR(decoded); + pr_err("load_policies: could not decode data: err[%d]", status); + goto out; } - else if (strcmp("manage_trace_requests", command) == 0) + load_opa_data(decoded); + kfree(decoded); + goto out; + } + else if (strcmp("load_config", command) == 0) + { + const char *code = json_object_get_string(root, "code"); + if (!code) { - const char *data = json_object_get_string(root, "data"); - if (data == NULL) - { - pr_debug("could not find data # command[%s]", command); - - goto cleanup; - } - - JSON_Value *data_json = json_parse_string(data); - if (data_json == NULL) - { - pr_debug("could not parse json # command[%s]", command); - - goto cleanup; - } + pr_err("load_config: missing 'code' property"); + status = -EINVAL; + goto out; + } + int length = 0; + char *decoded = base64_decode_data(code, &length); + if (IS_ERR(decoded)) + { + pr_err("could not decode data: err[%ld]", PTR_ERR(decoded)); + goto out; + } + int ret = load_camblet_config(decoded); + if (ret < 0) + pr_err("could not load camblet config # error_code[%d]", ret); + kfree(decoded); + goto out; + } + else if (strcmp("load_sd_info", command) == 0) + { + pr_info("load sd info"); - JSON_Object *data_root = json_value_get_object(data_json); - if (data_root == NULL) - { - pr_debug("invalid json format # command[%s]", command); + const char *code = json_object_get_string(root, "code"); + if (!code) + { + pr_err("load_sd_info: missing 'code' property"); + status = -EINVAL; + goto out; + } + int length = 0; + char *decoded = base64_decode_data(code, &length); + if (IS_ERR(decoded)) + { + pr_err("could not decode data: err[%ld]", PTR_ERR(decoded)); + goto out; + } + pr_info("sd info arrived # length[%d]", length); + int ret = load_sd_info(decoded); + if (ret < 0) + pr_err("could not load sd info # error_code[%d]", ret); + kfree(decoded); + goto out; + } + else if (strcmp("manage_trace_requests", command) == 0) + { + const char *data = json_object_get_string(root, "data"); + if (!data) + { + pr_err("missing 'data' property # command[%s]", command); + status = -EINVAL; + goto out; + } - goto request_trace_out; - } + JSON_Value *data_json = json_parse_string(data); + if (!data_json) + { + pr_err("could not parse json # command[%s]", command); + status = -EINVAL; + goto out; + } - const char *action = json_object_get_string(data_root, "action"); - if (action == NULL) - { - pr_debug("could not find action # command[%s]", command); + JSON_Object *data_root = json_value_get_object(data_json); + if (!data_root) + { + pr_err("invalid JSON root object # command[%s]", command); + status = -EINVAL; + goto request_trace_out; + } - goto request_trace_out; - } + const char *action = json_object_get_string(data_root, "action"); + if (!action) + { + pr_err("missing 'action' property # command[%s]", command); + status = -EINVAL; + goto request_trace_out; + } - int pid = -1; - if (json_object_has_value(data_root, "pid") == 1) - { - pid = json_object_get_number(data_root, "pid"); - } + int pid = -1; + if (json_object_has_value(data_root, "pid") == 1) + { + pid = json_object_get_number(data_root, "pid"); + } - int uid = -1; - if (json_object_has_value(data_root, "uid") == 1) - { - uid = json_object_get_number(data_root, "uid"); - } + int uid = -1; + if (json_object_has_value(data_root, "uid") == 1) + { + uid = json_object_get_number(data_root, "uid"); + } - const char *command_name = json_object_get_string(data_root, "command_name"); + const char *command_name = json_object_get_string(data_root, "command_name"); + if (!command_name) + { + pr_err("missing 'command_name' property # command[%s]", command); + status = -EINVAL; + goto request_trace_out; + } - pr_debug("manage trace # command[%s] action[%s] pid[%d] uid[%d] command_name[%s]", command, action, pid, uid, command_name); + pr_debug("manage trace # command[%s] action[%s] pid[%d] uid[%d] command_name[%s]", command, action, pid, uid, command_name); - if (strcmp(action, "add") == 0) - { - pr_debug("add trace # command[%s] pid[%d] uid[%d] command_name[%s]", command, pid, uid, command_name); - add_trace_request(pid, uid, command_name); - } - else if (strcmp(action, "remove") == 0) + if (strcmp(action, "add") == 0) + { + pr_debug("add trace # command[%s] pid[%d] uid[%d] command_name[%s]", command, pid, uid, command_name); + int ret = add_trace_request(pid, uid, command_name); + if (ret < 0) { - pr_debug("disable trace # command[%s] pid[%d] uid[%d] command_name[%s]", command, pid, uid, command_name); - trace_request *tr = get_trace_request(pid, uid, command_name); - if (tr) - { - remove_trace_request(tr); - } - else - { - pr_debug("trace not exists # command[%s] pid[%d] uid[%d] command_name[%s]", command, pid, uid, command_name); - } + pr_err("could not add trace request # command[%s] error_code[%d]", command, ret); } - else if (strcmp(action, "clear") == 0) - { - pr_debug("clear trace requests"); + } + else if (strcmp(action, "remove") == 0) + { + pr_debug("disable trace # command[%s] pid[%d] uid[%d] command_name[%s]", command, pid, uid, command_name); + trace_request *tr = get_trace_request(pid, uid, command_name); + if (tr) + remove_trace_request(tr); + else + pr_debug("trace not exists # command[%s] pid[%d] uid[%d] command_name[%s]", command, pid, uid, command_name); + } + else if (strcmp(action, "clear") == 0) + { + pr_debug("clear trace requests"); - clear_trace_requests(); - } + clear_trace_requests(); + } - request_trace_out: - json_value_free(data_json); + request_trace_out: + json_value_free(data_json); - goto cleanup; - } - else if (strcmp("answer", command) == 0) + goto out; + } + else if (strcmp("answer", command) == 0) + { + const char *command_id = json_object_get_string(root, "id"); + if (!command_id) { - const char *command_id = json_object_get_string(root, "id"); + pr_err("answer: missing 'id' property"); + status = -EINVAL; + goto out; + } - pr_debug("command answer # uuid[%s]", command_id); + pr_debug("command answer # id[%s]", command_id); - uuid_t uuid; - uuid_parse(command_id, &uuid); - struct command *cmd = lookup_in_flight_command(uuid.b); + uuid_t uuid; + int ret = uuid_parse(command_id, &uuid); + if (ret < 0) + { + pr_err("answer: invalid command id # id[%s]", command_id); + status = ret; + goto out; + } - if (cmd == NULL) - { - pr_err("command not found # uuid[%s]", command_id); - status = -1; - goto cleanup; - } + struct command *cmd = lookup_in_flight_command(uuid.b); + if (!cmd) + { + pr_err("command not found # id[%s]", command_id); + status = -ENOENT; + goto out; + } - struct command_answer *cmd_answer = kzalloc(sizeof(struct command_answer), GFP_KERNEL); - const char *answer = json_object_get_string(root, "answer"); - const char *error = json_object_get_string(root, "error"); + struct command_answer *cmd_answer = kzalloc(sizeof(struct command_answer), GFP_KERNEL); + if (!cmd_answer) + { + status = -ENOMEM; + goto out; + } - if (error) + const char *error = json_object_get_string(root, "error"); + if (error) + { + cmd_answer->error = kstrdup(error, GFP_KERNEL); + if (!cmd_answer->error) { - cmd_answer->error = strdup(error); + kfree(cmd_answer); + status = -ENOMEM; + goto out; } + } - if (answer) + const char *answer = json_object_get_string(root, "answer"); + if (answer) + { + cmd_answer->answer = kstrdup(answer, GFP_KERNEL); + if (!cmd_answer->answer) { - cmd_answer->answer = strdup(answer); + kfree(cmd_answer); + status = -ENOMEM; + goto out; } + } - cmd->answer = cmd_answer; + cmd->answer = cmd_answer; - wake_up_interruptible(&cmd->wait_queue); - } - else - { - pr_err("invalid command # command[%s]", command); - status = -1; - goto cleanup; - } + wake_up_interruptible(&cmd->wait_queue); } - -cleanup: - if (json) + else { - json_value_free(json); + pr_err("invalid command # command[%s]", command); + status = -EINVAL; + goto out; } +out: + json_value_free(json); + return status; } @@ -565,51 +730,213 @@ static int device_release(struct inode *inode, struct file *file) static char *serialize_command(struct command *cmd) { char *serialized_string = NULL; + char *error = NULL; char uuid[UUID_STRING_LEN + 1]; + int length = snprintf(uuid, UUID_STRING_LEN + 1, "%pUB", cmd->uuid.b); if (length < 0) { - pr_crit("could not stringify uuid"); - goto cleanup; + pr_err("serialize_command: could not stringify uuid"); + return ERR_PTR(-EINVAL); } JSON_Value *root_value = json_value_init_object(); + if (!root_value) + { + pr_err("serialize_command: could not init root json value"); + return ERR_PTR(-ENOMEM); + } JSON_Object *root_object = json_value_get_object(root_value); + if (!root_object) + { + pr_err("serialize_command: could not get root object"); + error = ERR_PTR(-EINVAL); + goto cleanup; + } + + JSON_Value *context_value = NULL; + bool context_value_free = false; + JSON_Value *namespace_ids_value = NULL; + bool namespace_ids_value_free = false; if (cmd->context) { - JSON_Value *context_value = json_value_init_object(); + context_value = json_value_init_object(); + if (!context_value) + { + pr_err("serialize_command: could not init context json value"); + error = ERR_PTR(-ENOMEM); + goto cleanup; + } + else + context_value_free = true; JSON_Object *context_object = json_value_get_object(context_value); - json_object_set_number(context_object, "uid", cmd->context->uid.val); - json_object_set_number(context_object, "gid", cmd->context->gid.val); - json_object_set_number(context_object, "pid", cmd->context->pid); - json_object_set_string(context_object, "command_path", cmd->context->command_path); - json_object_set_string(context_object, "command_name", cmd->context->command_name); - json_object_set_string(context_object, "cgroup_path", cmd->context->cgroup_path); - - JSON_Value *namespace_ids_value = json_value_init_object(); + if (!context_object) + { + pr_err("serialize_command: could not get context object"); + error = ERR_PTR(-EINVAL); + goto cleanup; + } + if (json_object_set_number(context_object, "uid", cmd->context->uid.val) < 0) + { + pr_err("serialize_command: could not set 'uid' property"); + error = ERR_PTR(-ENOMEM); + goto cleanup; + } + if (json_object_set_number(context_object, "gid", cmd->context->gid.val) < 0) + { + pr_err("serialize_command: could not set 'gid' property"); + error = ERR_PTR(-ENOMEM); + goto cleanup; + } + if (json_object_set_number(context_object, "pid", cmd->context->pid) < 0) + { + pr_err("serialize_command: could not set 'pid' property"); + error = ERR_PTR(-ENOMEM); + goto cleanup; + } + if (json_object_set_string(context_object, "command_path", cmd->context->command_path) < 0) + { + pr_err("serialize_command: could not set 'command_path' property"); + error = ERR_PTR(-ENOMEM); + goto cleanup; + } + if (json_object_set_string(context_object, "command_name", cmd->context->command_name) < 0) + { + pr_err("serialize_command: could not set 'command_name' property"); + error = ERR_PTR(-ENOMEM); + goto cleanup; + } + if (json_object_set_string(context_object, "cgroup_path", cmd->context->cgroup_path) < 0) + { + pr_err("serialize_command: could not set 'cgroup_path' property"); + error = ERR_PTR(-ENOMEM); + goto cleanup; + } + + if (json_object_set_value(root_object, "task_context", context_value) < 0) + { + pr_err("serialize_command: could not set 'task_context' property"); + error = ERR_PTR(-ENOMEM); + goto cleanup; + } + else + context_value_free = false; + + namespace_ids_value = json_value_init_object(); + if (!namespace_ids_value) + { + pr_err("serialize_command: could not init namespace ids json value"); + error = ERR_PTR(-ENOMEM); + goto cleanup; + } + else + namespace_ids_value_free = true; JSON_Object *namespace_ids_object = json_value_get_object(namespace_ids_value); - json_object_set_number(namespace_ids_object, "uts", cmd->context->namespace_ids.uts); - json_object_set_number(namespace_ids_object, "ipc", cmd->context->namespace_ids.ipc); - json_object_set_number(namespace_ids_object, "mnt", cmd->context->namespace_ids.mnt); - json_object_set_number(namespace_ids_object, "pid", cmd->context->namespace_ids.pid); - json_object_set_number(namespace_ids_object, "net", cmd->context->namespace_ids.net); - json_object_set_number(namespace_ids_object, "time", cmd->context->namespace_ids.time); - json_object_set_number(namespace_ids_object, "cgroup", cmd->context->namespace_ids.cgroup); + if (!namespace_ids_object) + { + pr_err("serialize_command: could not get namespace ids object"); + error = ERR_PTR(-EINVAL); + goto cleanup; + } - json_object_set_value(context_object, "namespace_ids", namespace_ids_value); - json_object_set_value(root_object, "task_context", context_value); + if (json_object_set_number(namespace_ids_object, "uts", cmd->context->namespace_ids.uts) < 0) + { + pr_err("serialize_command: could not set 'uts' property"); + error = ERR_PTR(-ENOMEM); + goto cleanup; + } + if (json_object_set_number(namespace_ids_object, "ipc", cmd->context->namespace_ids.ipc) < 0) + { + pr_err("serialize_command: could not set 'ipc' property"); + error = ERR_PTR(-ENOMEM); + goto cleanup; + } + if (json_object_set_number(namespace_ids_object, "mnt", cmd->context->namespace_ids.mnt) < 0) + { + pr_err("serialize_command: could not set 'mnt' property"); + error = ERR_PTR(-ENOMEM); + goto cleanup; + } + if (json_object_set_number(namespace_ids_object, "pid", cmd->context->namespace_ids.pid) < 0) + { + pr_err("serialize_command: could not set 'pid' property"); + error = ERR_PTR(-ENOMEM); + goto cleanup; + } + if (json_object_set_number(namespace_ids_object, "net", cmd->context->namespace_ids.net) < 0) + { + pr_err("serialize_command: could not set 'net' property"); + error = ERR_PTR(-ENOMEM); + goto cleanup; + } + if (json_object_set_number(namespace_ids_object, "time", cmd->context->namespace_ids.time) < 0) + { + pr_err("serialize_command: could not set 'time' property"); + error = ERR_PTR(-ENOMEM); + goto cleanup; + } + if (json_object_set_number(namespace_ids_object, "cgroup", cmd->context->namespace_ids.cgroup) < 0) + { + pr_err("serialize_command: could not set 'cgroup' property"); + error = ERR_PTR(-ENOMEM); + goto cleanup; + } + + if (json_object_set_value(context_object, "namespace_ids", namespace_ids_value) < 0) + { + pr_err("serialize_command: could not set 'namespace_ids' property"); + error = ERR_PTR(-ENOMEM); + goto cleanup; + } + else + namespace_ids_value_free = false; + } + + if (json_object_set_string(root_object, "id", uuid) < 0) + { + pr_err("serialize_command: could not set 'id' property"); + error = ERR_PTR(-ENOMEM); + goto cleanup; + } + + if (json_object_set_string(root_object, "command", cmd->name) < 0) + { + pr_err("serialize_command: could not set 'command' property"); + error = ERR_PTR(-ENOMEM); + goto cleanup; + } + + if (cmd->data && json_object_set_string(root_object, "data", cmd->data) < 0) + { + pr_err("serialize_command: could not set 'data' property"); + error = ERR_PTR(-ENOMEM); + goto cleanup; } - json_object_set_string(root_object, "id", uuid); - json_object_set_string(root_object, "command", cmd->name); - json_object_set_string(root_object, "data", cmd->data); - json_object_set_boolean(root_object, "is_message", cmd->is_message); + if (json_object_set_boolean(root_object, "is_message", cmd->is_message) < 0) + { + pr_err("serialize_command: could not set 'is_message' property"); + error = ERR_PTR(-ENOMEM); + goto cleanup; + } serialized_string = json_serialize_to_string(root_value); + if (!serialized_string) + { + pr_err("serialize_command: could not serialize json"); + error = ERR_PTR(-ENOMEM); + } cleanup: json_value_free(root_value); + if (context_value_free) + json_value_free(context_value); + if (namespace_ids_value_free) + json_value_free(namespace_ids_value); + + if (IS_ERR(error)) + return error; return serialized_string; } @@ -636,9 +963,15 @@ static ssize_t device_read(struct file *file, /* see include/linux/fs.h */ { free_command(c); } + if (IS_ERR(command_json)) + { + pr_err("could not marshal command json # uuid[%pUB] error_code[%ld]", c->uuid.b, PTR_ERR(command_json)); + return -EINTR; + } if (command_json == NULL) { - return -EFAULT; + pr_err("could not marshal command json # uuid[%pUB]", c->uuid.b); + return -EINTR; } pr_debug("sent command # command[%s]", command_json); diff --git a/device_driver.h b/device_driver.h index f64635d8..06c9f908 100644 --- a/device_driver.h +++ b/device_driver.h @@ -19,7 +19,8 @@ wasm_vm_result load_module(const char *name, const char *code, unsigned length, const char *entrypoint); #define SUCCESS 0 -#define DEVICE_NAME "camblet" /* Dev name as it appears in /dev/devices */ +#define FAILURE -1 +#define DEVICE_NAME "camblet" /* Dev name as it appears in /dev/devices */ #define DEVICE_BUFFER_SIZE 2 * 1024 * 1024 /* Max length of the message from the device */ enum diff --git a/http.c b/http.c index 0c53ad74..cd8f795a 100644 --- a/http.c +++ b/http.c @@ -37,7 +37,11 @@ void inject_header(buffer_t *buffer, struct phr_header *headers, size_t num_head const char *old_value_pos = headers[i].value; // shift the rest of the buffer - buffer_grow(buffer, shift); // resize the buffer if necessary + if (!buffer_grow(buffer, shift)) // resize the buffer if necessary + { + pr_err("could not inject header: could not allocate memory"); + continue; + } memmove(old_value_pos + new_value_len, old_value_pos + headers[i].value_len, buffer->size - (old_value_pos + headers[i].value_len - buffer->data)); // inject the new value @@ -63,7 +67,11 @@ void inject_header(buffer_t *buffer, struct phr_header *headers, size_t num_head char *new_header_pos = headers[num_headers - 1].value + headers[num_headers - 1].value_len + 2; // shift the rest of the buffer - buffer_grow(buffer, new_header_len); // resize the buffer if necessary + if (!buffer_grow(buffer, new_header_len)) // resize the buffer if necessary + { + pr_err("could not inject header: could not allocate memory"); + return; + } memmove(new_header_pos + new_header_len, new_header_pos, buffer->size - (new_header_pos - buffer->data)); // inject the new header diff --git a/main.c b/main.c index e80e01fd..8080559f 100644 --- a/main.c +++ b/main.c @@ -39,24 +39,77 @@ bool ktls_available = true; module_param(ktls_available, bool, 0644); MODULE_PARM_DESC(ktls_available, "Marks if kTLS is available on the system"); +typedef struct camblet_init_status +{ + bool wasm; + bool chardev; + bool socket; + bool sd_table; + bool config; +} camblet_init_status; + +camblet_init_status __camblet_init_status = {0}; + +static void __camblet_exit(void) +{ + if (__camblet_init_status.socket) + socket_exit(); + if (__camblet_init_status.chardev) + chardev_exit(); + if (__camblet_init_status.wasm) + wasm_vm_destroy_per_cpu(); + if (__camblet_init_status.sd_table) + sd_table_free(); + if (__camblet_init_status.config) + camblet_config_free(); +} + static int __init camblet_init(void) { - pr_info("module loaded at 0x%p running on %d CPUs", camblet_init, num_online_cpus()); + int ret = 0; + + pr_info("load module at 0x%p running on %d CPUs", camblet_init, num_online_cpus()); wasm_vm_result result = wasm_vm_new_per_cpu(); if (result.err) { FATAL("wasm_vm_new_per_cpu: %s", result.err); - return -1; + ret = -1; + goto out; + } + __camblet_init_status.wasm = true; + + ret = camblet_config_init(); + if (ret < 0) + { + FATAL("could not init config: %d", ret); + goto out; } + __camblet_init_status.config = true; - camblet_config_init(); - sd_table_init(); + ret = sd_table_init(); + if (ret < 0) + { + FATAL("could not init sd table: %d", ret); + goto out; + } + __camblet_init_status.sd_table = true; - int ret = 0; + ret = chardev_init(); + if (ret < 0) + { + FATAL("could not init char device: %d", ret); + goto out; + } + __camblet_init_status.chardev = true; - ret += chardev_init(); - ret += socket_init(); + ret = socket_init(); + if (ret < 0) + { + FATAL("could not init socket proto: %d", ret); + goto out; + } + __camblet_init_status.socket = true; if (proxywasm_modules) { @@ -64,14 +117,16 @@ static int __init camblet_init(void) if (result.err) { FATAL("load_module -> proxywasm_tcp_metadata_filter: %s", result.err); - return -1; + ret = -1; + goto out; } result = load_module("proxywasm_stats_filter", filter_stats, size_filter_stats, "_initialize"); if (result.err) { FATAL("load_module -> proxywasm_stats_filter: %s", result.err); - return -1; + ret = -1; + goto out; } } @@ -79,27 +134,30 @@ static int __init camblet_init(void) if (result.err) { FATAL("load_module -> csr_module: %s", result.err); - return -1; + ret = -1; + goto out; } result = load_module("socket_opa", socket_wasm, socket_wasm_len, NULL); if (result.err) { FATAL("load_module -> socket_opa: %s", result.err); - return -1; + ret = -1; + goto out; } +out: + if (ret < 0) + __camblet_exit(); + else + pr_info("module loaded at 0x%p running on %d CPUs", camblet_init, num_online_cpus()); + return ret; } static void __exit camblet_exit(void) { - socket_exit(); - chardev_exit(); - wasm_vm_destroy_per_cpu(); - - sd_table_free(); - camblet_config_free(); + __camblet_exit(); pr_info("%s: module unloaded from 0x%p", KBUILD_MODNAME, camblet_exit); } diff --git a/opa.c b/opa.c index 565e614f..fad15201 100644 --- a/opa.c +++ b/opa.c @@ -228,12 +228,15 @@ static wasm_vm_result parse_opa_builtins(opa_wrapper *opa, char *json) pr_debug("opa builtins # json[%s]", json); if (builtin_count == 0) - { goto bail; - } // indexing starts from 1 for some reason, so we need one bigger array opa->builtins = kzalloc(builtin_count + 1 * sizeof(void *), GFP_KERNEL); + if (!opa->builtins) + { + result = wasm_vm_error("could not allocate memory"); + goto bail; + } int i; for (i = 0; i < builtin_count; i++) @@ -285,15 +288,24 @@ void opa_socket_context_free(opa_socket_context ctx) opa_socket_context parse_opa_socket_eval_result(char *json) { - JSON_Value *root_value = json_parse_string(json); + JSON_Value *root_value = NULL; opa_socket_context ret = {0}; - if (root_value) + root_value = json_parse_string(json); + if (!root_value) + { + ret.error = -EINVAL; + return ret; + } + else { JSON_Array *results = json_value_get_array(root_value); JSON_Object *result = json_array_get_object(results, 0); - if (result == NULL) + if (!result) + { + ret.error = -EINVAL; goto free; + } JSON_Array *policies; @@ -357,7 +369,12 @@ opa_socket_context parse_opa_socket_eval_result(char *json) ttl = json_object_dotget_string(matched_policy, "certificate.ttl"); if (ttl != NULL) { - ret.ttl = strdup(ttl); + ret.ttl = kstrdup(ttl, GFP_KERNEL); + if (!ret.ttl) + { + ret.error = -ENOMEM; + goto free; + } } size_t egress_id_len = json_object_dotget_string_len(matched_policy, "egress.id"); @@ -374,6 +391,11 @@ opa_socket_context parse_opa_socket_eval_result(char *json) camblet_config *config = camblet_config_get_locked(); size_t trust_domain_len = strlen(config->trust_domain); ret.id = kzalloc(egress_id_len + id_len + trust_domain_len + ttl_len + 1, GFP_KERNEL); + if (!ret.id) + { + ret.error = -ENOMEM; + goto free; + } strcat(ret.id, config->trust_domain); camblet_config_unlock(); @@ -405,7 +427,12 @@ opa_socket_context parse_opa_socket_eval_result(char *json) ret.allowed_spiffe_ids_length = json_array_get_count(allowed_spiffe_ids); for (i = 0; i < ret.allowed_spiffe_ids_length; i++) { - ret.allowed_spiffe_ids[i] = strdup(json_array_get_string(allowed_spiffe_ids, i)); + ret.allowed_spiffe_ids[i] = kstrdup(json_array_get_string(allowed_spiffe_ids, i), GFP_KERNEL); + if (!ret.allowed_spiffe_ids[i]) + { + ret.error = -ENOMEM; + goto free; + } } } @@ -441,6 +468,11 @@ opa_socket_context parse_opa_socket_eval_result(char *json) camblet_config *config = camblet_config_get_locked(); int workload_id_len = snprintf(NULL, 0, "spiffe://%s/%s", config->trust_domain, workload_id); ret.workload_id = kzalloc(workload_id_len + 1, GFP_KERNEL); + if (!ret.workload_id) + { + ret.error = -ENOMEM; + goto free; + } snprintf(ret.workload_id, workload_id_len + 1, "spiffe://%s/%s", config->trust_domain, workload_id); camblet_config_unlock(); } @@ -463,6 +495,11 @@ opa_socket_context parse_opa_socket_eval_result(char *json) } ret.dns = kzalloc(dns_len + 1, GFP_KERNEL); + if (!ret.dns) + { + ret.error = -ENOMEM; + goto free; + } for (i = 0; i < json_array_get_count(dns); i++) { strcat(ret.dns, json_array_get_string(dns, i)); @@ -650,6 +687,10 @@ wasm_vm_result init_opa_for(wasm_vm *vm, wasm_vm_module *module) wasm_vm_function *builtinsFunc; opa_wrapper *opa = kmalloc(sizeof(struct opa_wrapper), GFP_KERNEL); + if (!opa) + { + return wasm_vm_error("could not allocate memory"); + } wasm_vm_try_get_function(opa->malloc, wasm_vm_get_function(vm, module->name, "opa_malloc")); wasm_vm_try_get_function(opa->free, wasm_vm_get_function(vm, module->name, "opa_free")); wasm_vm_try_get_function(opa->eval, wasm_vm_get_function(vm, module->name, "opa_eval")); diff --git a/opa.h b/opa.h index 16ed4848..007d945a 100644 --- a/opa.h +++ b/opa.h @@ -31,6 +31,7 @@ typedef struct char *ttl; char *allowed_spiffe_ids[MAX_ALLOWED_SPIFFE_ID]; int allowed_spiffe_ids_length; + int error; } opa_socket_context; void opa_socket_context_free(opa_socket_context ctx); diff --git a/rsa_tools.c b/rsa_tools.c index 416c1b8d..2c81b735 100644 --- a/rsa_tools.c +++ b/rsa_tools.c @@ -41,9 +41,17 @@ uint32_t generate_rsa_keys(br_rsa_private_key *rsa_priv, br_rsa_public_key *rsa_ br_rsa_keygen rsa_keygen = br_rsa_keygen_get_default(); unsigned char *raw_priv_key = kmalloc(BR_RSA_KBUF_PRIV_SIZE(RSA_BIT_LENGTH), GFP_KERNEL); + if (!raw_priv_key) + return -ENOMEM; unsigned char *raw_pub_key = kmalloc(BR_RSA_KBUF_PUB_SIZE(RSA_BIT_LENGTH), GFP_KERNEL); + if (!raw_pub_key) + return -ENOMEM; - return rsa_keygen(&hmac_drbg_ctx.vtable, rsa_priv, raw_priv_key, rsa_pub, raw_pub_key, RSA_BIT_LENGTH, RSA_PUB_EXP); + int ret = rsa_keygen(&hmac_drbg_ctx.vtable, rsa_priv, raw_priv_key, rsa_pub, raw_pub_key, RSA_BIT_LENGTH, RSA_PUB_EXP); + if (ret == 0) + return -1; + + return ret; } void free_rsa_private_key(br_rsa_private_key *key) diff --git a/sd.c b/sd.c index df59d2b4..e5afbf9f 100644 --- a/sd.c +++ b/sd.c @@ -32,13 +32,24 @@ static void sd_table_unlock(void) service_discovery_table *service_discovery_table_create() { service_discovery_table *table = kzalloc(sizeof(service_discovery_table), GFP_KERNEL); + if (!table) + { + return ERR_PTR(-ENOMEM); + } + hash_init(table->htable); return table; } -void sd_table_init() +int sd_table_init() { sd_table = service_discovery_table_create(); + if (IS_ERR(sd_table)) + { + return PTR_ERR(sd_table); + } + + return 0; } static u64 sd_entry_hash(const char *name, int len) @@ -101,7 +112,7 @@ void sd_table_entry_del(service_discovery_entry *entry) sd_table_unlock(); } -static void service_discovery_entry_free(service_discovery_entry *entry) +void service_discovery_entry_free(service_discovery_entry *entry) { if (!entry) { @@ -120,10 +131,8 @@ static void service_discovery_entry_free(service_discovery_entry *entry) static void service_discovery_table_free_locked(service_discovery_table *table) { - if (!table) - { + if (IS_ERR_OR_NULL(table)) return; - } service_discovery_entry *entry; int i; diff --git a/sd.h b/sd.h index fbcbbf03..c6b7fc3f 100644 --- a/sd.h +++ b/sd.h @@ -26,13 +26,18 @@ typedef struct service_discovery_table DECLARE_HASHTABLE(htable, 8); } service_discovery_table; -void sd_table_init(void); +int sd_table_init(void); void sd_table_free(void); service_discovery_entry *sd_table_entry_get(const char *address); void sd_table_replace(service_discovery_table *table); void sd_table_entry_del(service_discovery_entry *entry); - +/* + * service_discovery_table_create + * + * returns a service_discovery_table struct pointer or ERR_PTR() on error + */ service_discovery_table *service_discovery_table_create(void); void service_discovery_table_entry_add(service_discovery_table *table, service_discovery_entry *entry); +void service_discovery_entry_free(service_discovery_entry *entry); #endif diff --git a/socket.c b/socket.c index aa6911ac..66c3cd50 100644 --- a/socket.c +++ b/socket.c @@ -128,6 +128,7 @@ struct camblet_socket }; static int get_read_buffer_capacity(camblet_socket *s); +static void tcp_connection_context_free(tcp_connection_context *ctx); static br_ssl_engine_context *get_ssl_engine_context(camblet_socket *s) { @@ -334,7 +335,7 @@ static void set_write_buffer_size(camblet_socket *s, int size) static void camblet_socket_free(camblet_socket *s) { - if (s) + if (!IS_ERR_OR_NULL(s)) { pr_debug("free camblet socket # command[%s]", current->comm); @@ -370,8 +371,7 @@ static void camblet_socket_free(camblet_socket *s) x509_certificate_put(s->cert); kfree(s->parameters); - kfree(s->conn_ctx->peer_spiffe_id); - kfree(s->conn_ctx); + tcp_connection_context_free(s->conn_ctx); kfree(s); } } @@ -405,13 +405,26 @@ int proxywasm_attach(proxywasm *p, camblet_socket *s, ListenerDirection directio static camblet_socket *camblet_new_server_socket(struct sock *sock, opa_socket_context opa_socket_ctx, tcp_connection_context *conn_ctx) { camblet_socket *s = kzalloc(sizeof(camblet_socket), GFP_KERNEL); + if (!s) + return ERR_PTR(-ENOMEM); s->sc = kzalloc(sizeof(br_ssl_server_context), GFP_KERNEL); + if (!s->sc) + goto enomem; s->rsa_priv = kzalloc(sizeof(br_rsa_private_key), GFP_KERNEL); + if (!s->rsa_priv) + goto enomem; s->rsa_pub = kzalloc(sizeof(br_rsa_public_key), GFP_KERNEL); + if (!s->rsa_pub) + goto enomem; s->parameters = kzalloc(sizeof(csr_parameters), GFP_KERNEL); + if (!s->parameters) + goto enomem; s->read_buffer = buffer_new(16 * 1024); + if (IS_ERR(s->read_buffer)) + goto enomem; s->write_buffer = buffer_new(16 * 1024); - + if (IS_ERR(s->write_buffer)) + goto enomem; s->sock = sock; s->opa_socket_ctx = opa_socket_ctx; @@ -437,17 +450,39 @@ static camblet_socket *camblet_new_server_socket(struct sock *sock, opa_socket_c } return s; + +enomem: + camblet_socket_free(s); + return ERR_PTR(-ENOMEM); } static camblet_socket *camblet_new_client_socket(struct sock *sock, opa_socket_context opa_socket_ctx, tcp_connection_context *conn_ctx) { camblet_socket *s = kzalloc(sizeof(camblet_socket), GFP_KERNEL); + if (!s) + return ERR_PTR(-ENOMEM); s->cc = kzalloc(sizeof(br_ssl_client_context), GFP_KERNEL); + if (!s->sc) + goto enomem; s->rsa_priv = kzalloc(sizeof(br_rsa_private_key), GFP_KERNEL); + if (!s->rsa_priv) + goto enomem; s->rsa_pub = kzalloc(sizeof(br_rsa_public_key), GFP_KERNEL); + if (!s->rsa_pub) + goto enomem; s->parameters = kzalloc(sizeof(csr_parameters), GFP_KERNEL); + if (!s->parameters) + goto enomem; s->read_buffer = buffer_new(16 * 1024); + if (IS_ERR(s->read_buffer)) + { + goto enomem; + } s->write_buffer = buffer_new(16 * 1024); + if (IS_ERR(s->write_buffer)) + { + goto enomem; + } s->sock = sock; s->opa_socket_ctx = opa_socket_ctx; @@ -474,6 +509,10 @@ static camblet_socket *camblet_new_client_socket(struct sock *sock, opa_socket_c } return s; + +enomem: + camblet_socket_free(s); + return ERR_PTR(-ENOMEM); } void dump_array(unsigned char array[], size_t len) @@ -690,7 +729,13 @@ int camblet_recvmsg(struct sock *sock, while (action != Continue) { - ret = camblet_socket_read(s, get_read_buffer_for_read(s, len), len, flags); + char *buf = get_read_buffer_for_read(s, len); + if (!buf) + { + ret = -ENOMEM; + goto bail; + } + ret = camblet_socket_read(s, buf, len, flags); if (ret < 0) { if (ret == -ERESTARTSYS) @@ -789,7 +834,13 @@ int camblet_sendmsg(struct sock *sock, struct msghdr *msg, size_t size) size_t prevbuflen = get_write_buffer_size(s); - len = copy_from_iter(get_write_buffer_for_write(s, size), size, &msg->msg_iter); + char *buf = get_write_buffer_for_write(s, size); + if (!buf) + { + ret = -ENOMEM; + goto bail; + } + len = copy_from_iter(buf, size, &msg->msg_iter); set_write_buffer_size(s, get_write_buffer_size(s) + len); @@ -1070,6 +1121,11 @@ int camblet_setsockopt(struct sock *sk, int level, opa_socket_context opa_socket_ctx = {.allowed = true, .passthrough = false, .mtls = false}; tcp_connection_context conn_ctx = {.direction = OUTPUT}; camblet_socket *s = camblet_new_client_socket(sk, opa_socket_ctx, &conn_ctx); + if (IS_ERR(s)) + { + pr_err("camblet_setsockopt error # command[%s] error_code[%ld]", current->comm, PTR_ERR(s)); + return PTR_ERR(s); + } sk->sk_user_data = s; @@ -1088,6 +1144,8 @@ int camblet_setsockopt(struct sock *sk, int level, } s->hostname = kzalloc(optlen, GFP_KERNEL); + if (!s->hostname) + return -ENOMEM; copy_from_sockptr(s->hostname, optval, optlen); return 0; } @@ -1220,10 +1278,10 @@ static int handle_cert_gen_locked(camblet_socket *sc) if (sc->rsa_priv->plen == 0 || sc->rsa_pub->elen == 0) { u_int32_t result = generate_rsa_keys(sc->rsa_priv, sc->rsa_pub); - if (result == 0) + if (result < 0) { pr_err("could not generate rsa keys"); - return -1; + return result; } } @@ -1287,7 +1345,12 @@ static int handle_cert_gen_locked(camblet_socket *sc) return -1; } - csr_ptr = strndup(generated_csr.csr_ptr + mem, generated_csr.csr_len); + csr_ptr = kstrndup(generated_csr.csr_ptr + mem, generated_csr.csr_len, GFP_KERNEL); + if (!csr_ptr) + { + csr_unlock(csr); + return -ENOMEM; + } free_result = csr_free(csr, generated_csr.csr_ptr); if (free_result.err) { @@ -1299,19 +1362,20 @@ static int handle_cert_gen_locked(camblet_socket *sc) csr_sign_answer *csr_sign_answer; csr_sign_answer = send_csrsign_command(csr_ptr, sc->opa_socket_ctx.ttl); + if (IS_ERR(csr_sign_answer)) + return PTR_ERR(csr_sign_answer); if (csr_sign_answer->error) { pr_err("error during CSR signing # err[%s]", csr_sign_answer->error); - kfree(csr_sign_answer->error); - kfree(csr_sign_answer); + x509_certificate_put(csr_sign_answer->cert); + csr_sign_answer_free(csr_sign_answer); return -1; } - else - { - x509_certificate_get(csr_sign_answer->cert); - sc->cert = csr_sign_answer->cert; - } - kfree(csr_sign_answer); + + x509_certificate_get(csr_sign_answer->cert); + sc->cert = csr_sign_answer->cert; + csr_sign_answer_free(csr_sign_answer); + return 0; } @@ -1335,11 +1399,11 @@ static int cache_and_validate_cert(camblet_socket *sc, char *key) { regen_cert: err = handle_cert_gen(sc); - if (err == -1) - { - return -1; - } - add_cert_to_cache(key, sc->cert); + if (err < 0) + return err; + err = add_cert_to_cache(key, sc->cert); + if (err < 0) + return err; } // Cert found in the cache use that else @@ -1368,6 +1432,8 @@ static int cache_and_validate_cert(camblet_socket *sc, char *key) static tcp_connection_context *tcp_connection_context_init(direction direction, struct sock *s, u16 port) { tcp_connection_context *ctx = kzalloc(sizeof(tcp_connection_context), GFP_KERNEL); + if (!ctx) + return ERR_PTR(-ENOMEM); ctx->direction = direction; ctx->id = (u64)s; @@ -1422,6 +1488,15 @@ static tcp_connection_context *tcp_connection_context_init(direction direction, return ctx; } +static void tcp_connection_context_free(tcp_connection_context *ctx) +{ + if (IS_ERR_OR_NULL(ctx)) + return; + + kfree(ctx->peer_spiffe_id); + kfree(ctx); +} + void add_sd_entry_labels_to_json(service_discovery_entry *sd_entry, JSON_Value *json) { if (!json) @@ -1513,6 +1588,11 @@ static command_answer *prepare_opa_input(const tcp_connection_context *conn_ctx, } answer = kzalloc(sizeof(struct command_answer), GFP_KERNEL); + if (!answer) + { + json_value_free(json); + return ERR_PTR(-ENOMEM); + } answer->answer = json_serialize_to_string(json); cleanup: @@ -1576,6 +1656,14 @@ opa_socket_context enriched_socket_eval(const tcp_connection_context *conn_ctx, // augmenting process connection augmentation_response *response = augment_workload(); + if (IS_ERR(response)) + { + opa_socket_ctx.error = PTR_ERR(response); + char err_code_str[8]; + snprintf(err_code_str, 8, "%d", opa_socket_ctx.error); + trace_err(conn_ctx, "could not augment process connection", 2, "error_code", err_code_str); + goto ret; + } if (response->error) { trace_err(conn_ctx, "could not augment process connection", 2, "error", response->error); @@ -1597,10 +1685,11 @@ opa_socket_context enriched_socket_eval(const tcp_connection_context *conn_ctx, augmentation_response_put(response); } +ret: return opa_socket_ctx; } -void camblet_configure_server_tls(camblet_socket *sc) +int camblet_configure_server_tls(camblet_socket *sc) { /* * Initialise the context with the cipher suites and @@ -1624,7 +1713,10 @@ void camblet_configure_server_tls(camblet_socket *sc) br_ssl_server_set_trust_anchor_names_alt(sc->sc, sc->cert->trust_anchors, sc->cert->trust_anchors_len); bool insecure; - br_x509_camblet_init(&sc->xc, &sc->sc->eng, &sc->opa_socket_ctx, sc->conn_ctx, insecure = false); + int ret = br_x509_camblet_init(&sc->xc, &sc->sc->eng, &sc->opa_socket_ctx, sc->conn_ctx, insecure = false); + if (ret < 0) + return ret; + br_ssl_engine_set_default_rsavrfy(&sc->sc->eng); } @@ -1649,9 +1741,11 @@ void camblet_configure_server_tls(camblet_socket *sc) * Initialise the simplified I/O wrapper context. */ br_sslio_init(&sc->ioc, &sc->sc->eng, br_low_read, sc, br_low_write, sc); + + return 0; } -void camblet_configure_client_tls(camblet_socket *sc) +int camblet_configure_client_tls(camblet_socket *sc) { /* * Initialise the context with the cipher suites and @@ -1689,7 +1783,9 @@ void camblet_configure_client_tls(camblet_socket *sc) (sizeof suites) / (sizeof suites[0])); bool insecure; - br_x509_camblet_init(&sc->xc, &sc->cc->eng, &sc->opa_socket_ctx, sc->conn_ctx, insecure = trust_anchors_len == 0); + int ret = br_x509_camblet_init(&sc->xc, &sc->cc->eng, &sc->opa_socket_ctx, sc->conn_ctx, insecure = trust_anchors_len == 0); + if (ret < 0) + return ret; // mTLS enablement if (sc->opa_socket_ctx.mtls) @@ -1724,6 +1820,23 @@ void camblet_configure_client_tls(camblet_socket *sc) * SSL client context, and the two callbacks for socket I/O. */ br_sslio_init(&sc->ioc, &sc->cc->eng, br_low_read, sc, br_low_write, sc); + + return 0; +} + +void socket_reset(struct sock *sk) +{ + pr_info("sk reset\n"); + tcp_set_state(sk, TCP_CLOSE); +#if LINUX_VERSION_CODE < KERNEL_VERSION(6, 1, 0) + if (!(sk->sk_userlocks & SOCK_BINDADDR_LOCK)) + inet_reset_saddr(sk); +#else + inet_bhash2_reset_saddr(sk); +#endif + sk->sk_route_caps = 0; + struct inet_sock *inet = inet_sk(sk); + inet->inet_dport = 0; } struct sock *camblet_accept(struct sock *sk, int flags, int *err, bool kern) @@ -1745,7 +1858,7 @@ struct sock *camblet_accept(struct sock *sk, int flags, int *err, bool kern) if (!client_sk && *err != 0) { - goto error; + return NULL; } // return if the agent is not running @@ -1757,8 +1870,20 @@ struct sock *camblet_accept(struct sock *sk, int flags, int *err, bool kern) u16 port = (u16)(sk->sk_portpair >> 16); tcp_connection_context *conn_ctx = tcp_connection_context_init(INPUT, client_sk, port); + if (IS_ERR(conn_ctx)) + { + *err = PTR_ERR(conn_ctx); + goto error; + } opa_socket_context opa_socket_ctx = enriched_socket_eval(conn_ctx, INPUT, client_sk, port); + if (opa_socket_ctx.error < 0) + { + tcp_connection_context_free(conn_ctx); + *err = opa_socket_ctx.error; + opa_socket_context_free(opa_socket_ctx); + goto error; + } if (opa_socket_ctx.allowed) { @@ -1769,9 +1894,15 @@ struct sock *camblet_accept(struct sock *sk, int flags, int *err, bool kern) } sc = camblet_new_server_socket(client_sk, opa_socket_ctx, conn_ctx); + if (IS_ERR(sc)) + { + pr_err("could not create camblet server socket"); + *err = PTR_ERR(sc); + goto error; + } if (!sc) { - pr_err("could not create camblet socket"); + pr_err("could not create camblet server socket"); goto error; } @@ -1779,12 +1910,18 @@ struct sock *camblet_accept(struct sock *sk, int flags, int *err, bool kern) memcpy(sc->rsa_pub, rsa_pub, sizeof *sc->rsa_pub); int result = cache_and_validate_cert(sc, sc->opa_socket_ctx.id); - if (result == -1) + if (result < 0) { + *err = result; goto error; } - camblet_configure_server_tls(sc); + result = camblet_configure_server_tls(sc); + if (result < 0) + { + *err = result; + goto error; + } // We should save the ssl context here to the socket // and overwrite the socket protocol with our own @@ -1823,9 +1960,7 @@ int camblet_connect(struct sock *sk, struct sockaddr *uaddr, int addr_len) } if (err != 0) - { - goto error; - } + return err; // return if the agent is not running, and we don't have a socket context attached to the socket if (sc == NULL && atomic_read(&already_open) == CDEV_NOT_USED) @@ -1834,22 +1969,41 @@ int camblet_connect(struct sock *sk, struct sockaddr *uaddr, int addr_len) } tcp_connection_context *conn_ctx = tcp_connection_context_init(OUTPUT, sk, port); + if (IS_ERR(conn_ctx)) + { + err = PTR_ERR(conn_ctx); + goto error; + } opa_socket_context opa_socket_ctx = enriched_socket_eval(conn_ctx, OUTPUT, sk, port); + if (opa_socket_ctx.error < 0) + { + tcp_connection_context_free(conn_ctx); + err = opa_socket_ctx.error; + opa_socket_context_free(opa_socket_ctx); + goto error; + } if (opa_socket_ctx.allowed) { if (!opa_socket_ctx.workload_id_is_valid) { - return -CAMBLET_EINVALIDSPIFFEID; + err = -CAMBLET_EINVALIDSPIFFEID; + goto error; } if (!sc) { sc = camblet_new_client_socket(sk, opa_socket_ctx, conn_ctx); + if (IS_ERR(sc)) + { + pr_err("could not create camblet client socket"); + err = PTR_ERR(sc); + goto error; + } if (!sc) { - pr_err("could not create camblet socket"); + pr_err("could not create camblet client socket"); goto error; } } @@ -1860,13 +2014,19 @@ int camblet_connect(struct sock *sk, struct sockaddr *uaddr, int addr_len) if (sc->opa_socket_ctx.mtls) { int result = cache_and_validate_cert(sc, sc->opa_socket_ctx.id); - if (result == -1) + if (result < 0) { + err = result; goto error; } } - camblet_configure_client_tls(sc); + int result = camblet_configure_client_tls(sc); + if (result < 0) + { + err = result; + goto error; + } // We should save the ssl context here to the socket // and overwrite the socket protocol with our own @@ -1878,10 +2038,7 @@ int camblet_connect(struct sock *sk, struct sockaddr *uaddr, int addr_len) error: camblet_socket_free(sc); - - lock_sock(sk); - sk->sk_prot->close(sk, 0); - release_sock(sk); + socket_reset(sk); return err; } @@ -1945,12 +2102,16 @@ int socket_init(void) //- generate global tls key rsa_priv = kzalloc(sizeof(br_rsa_private_key), GFP_KERNEL); + if (!rsa_priv) + return -ENOMEM; rsa_pub = kzalloc(sizeof(br_rsa_public_key), GFP_KERNEL); + if (!rsa_pub) + return -ENOMEM; u_int32_t result = generate_rsa_keys(rsa_priv, rsa_pub); - if (result == 0) + if (result < 0) { pr_err("could not generate rsa keys"); - return -1; + return result; } pr_info("socket support loaded"); diff --git a/string.c b/string.c index ad8744e2..834b7a49 100644 --- a/string.c +++ b/string.c @@ -10,18 +10,6 @@ #include -char *strndup(const char *str, size_t size) -{ - char *dst = kzalloc(size + 1, GFP_KERNEL); - return strncpy(dst, str, size); -} - -char *strdup(const char *str) -{ - int len = strlen(str); - return strndup(str, len); -} - char *strprintf(const char *fmt, ...) { va_list args; @@ -31,6 +19,8 @@ char *strprintf(const char *fmt, ...) va_end(args); char *dst = kzalloc(len, GFP_KERNEL); + if (!dst) + return ERR_PTR(-ENOMEM); va_start(args, fmt); vsnprintf(dst, len, fmt, args); diff --git a/string.h b/string.h index c75afa80..ee64dad1 100644 --- a/string.h +++ b/string.h @@ -13,8 +13,6 @@ #include -char *strdup(const char *str); -char *strndup(const char *str, size_t size); char *strprintf(const char *fmt, ...); #endif diff --git a/task_context.c b/task_context.c index 540ec6e1..a6727ff9 100644 --- a/task_context.c +++ b/task_context.c @@ -29,12 +29,21 @@ struct mnt_namespace struct ns_common ns; }; +static char *get_current_proc_path(char *buf, int buflen); + task_context *get_task_context(void) { struct task_context *context = kmalloc(sizeof(struct task_context), GFP_KERNEL); + if (!context) + return ERR_PTR(-ENOMEM); strcpy(context->command_name, current->comm); context->command_path = get_current_proc_path(context->command_path_buffer, sizeof(context->command_path_buffer)); + if (IS_ERR(context->command_path)) + { + free_task_context(context); + return (void *)context->command_path; + } current_uid_gid(&context->uid, &context->gid); context->pid = current->pid; @@ -58,10 +67,18 @@ task_context *get_task_context(void) void free_task_context(struct task_context *context) { + if (IS_ERR_OR_NULL(context)) + return; + kfree(context); } -char *get_current_proc_path(char *buf, int buflen) +/* + * get_current_proc_path + * + * returns a pointer into the buffer or ERR_PTR() on error + */ +static char *get_current_proc_path(char *buf, int buflen) { struct file *exe_file; char *result = ERR_PTR(-ENOENT); diff --git a/task_context.h b/task_context.h index cf508bdc..50ffad1d 100644 --- a/task_context.h +++ b/task_context.h @@ -17,8 +17,6 @@ #define MAX_PATH_LEN 256 #define MAX_COMM_LEN 64 -char *get_current_proc_path(char *buf, int buflen); - struct namespace_ids { unsigned int uts; @@ -42,6 +40,11 @@ typedef struct task_context char cgroup_path[MAX_PATH_LEN]; } task_context; +/* + * get_task_context + * + * returns a task_context struct pointer or ERR_PTR() on error + */ task_context *get_task_context(void); void free_task_context(struct task_context *context); diff --git a/third-party/BearSSL b/third-party/BearSSL index bab27368..55dcf07d 160000 --- a/third-party/BearSSL +++ b/third-party/BearSSL @@ -1 +1 @@ -Subproject commit bab27368a2a365111482fbf763641d21d6905348 +Subproject commit 55dcf07dbb693dd1eed3985e3a1f3cca85cc5500 diff --git a/tls.c b/tls.c index 1ad51d70..2618f83b 100644 --- a/tls.c +++ b/tls.c @@ -146,7 +146,12 @@ xwc_end_chain(const br_x509_class **ctx) { if (camblet_cc->conn_ctx->peer_spiffe_id == NULL && mini_cc->name_elts[i].buf != NULL) { - camblet_cc->conn_ctx->peer_spiffe_id = strdup(mini_cc->name_elts[i].buf); + camblet_cc->conn_ctx->peer_spiffe_id = kstrdup(mini_cc->name_elts[i].buf, GFP_KERNEL); + if (!camblet_cc->conn_ctx->peer_spiffe_id) + { + pr_crit("xwc_end_chain: could not allocate memory"); + break; + } } spiffe_id = mini_cc->name_elts[i].buf; @@ -190,38 +195,55 @@ static const br_x509_class x509_camblet_vtable = { xwc_get_pkey, }; -void br_x509_camblet_init(br_x509_camblet_context *ctx, br_ssl_engine_context *eng, opa_socket_context *socket_context, tcp_connection_context *conn_ctx, bool insecure) -{ - ctx->vtable = &x509_camblet_vtable; - ctx->socket_context = socket_context; - ctx->conn_ctx = conn_ctx; - ctx->insecure = insecure; +static void br_x509_name_elts_free(br_name_element *name_elts, size_t num); +int br_x509_camblet_init(br_x509_camblet_context *ctx, br_ssl_engine_context *eng, opa_socket_context *socket_context, tcp_connection_context *conn_ctx, bool insecure) +{ br_name_element *name_elts = kmalloc(sizeof(br_name_element) * 3, GFP_KERNEL); + if (!name_elts) + return -ENOMEM; char const *oids[] = {OID_rfc822Name, OID_dNSName, OID_uniformResourceIdentifier}; int i; + int num = 0; for (i = 0; i < sizeof(oids) / sizeof(oids[0]); i++) { name_elts[i].oid = oids[i]; name_elts[i].buf = kmalloc(sizeof(char) * 256, GFP_KERNEL); + if (!name_elts[i].buf) + { + br_x509_name_elts_free(name_elts, num); + return -ENOMEM; + } name_elts[i].len = 256; + num++; } - br_x509_minimal_set_name_elements(&ctx->ctx, name_elts, sizeof(oids) / sizeof(oids[0])); + ctx->vtable = &x509_camblet_vtable; + ctx->socket_context = socket_context; + ctx->conn_ctx = conn_ctx; + ctx->insecure = insecure; + br_x509_minimal_set_name_elements(&ctx->ctx, name_elts, num); br_ssl_engine_set_x509(eng, &ctx->vtable); + + return 0; } -void br_x509_camblet_free(br_x509_camblet_context *ctx) +static void br_x509_name_elts_free(br_name_element *name_elts, size_t num) { int i; - for (i = 0; i < ctx->ctx.num_name_elts; i++) + for (i = 0; i < num; i++) { - kfree(ctx->ctx.name_elts[i].buf); + kfree(name_elts[i].buf); } - kfree(ctx->ctx.name_elts); + kfree(name_elts); +} + +void br_x509_camblet_free(br_x509_camblet_context *ctx) +{ + br_x509_name_elts_free(ctx->ctx.name_elts, ctx->ctx.num_name_elts); } void setup_aes_gcm_128_crypto_info(crypto_info *crypto_info, const uint8_t *iv, const uint8_t *key, uint64_t seq) diff --git a/tls.h b/tls.h index 00b048dd..eae692a9 100644 --- a/tls.h +++ b/tls.h @@ -44,7 +44,7 @@ typedef struct crypto_info size_t cipher_type_len; } crypto_info; -void br_x509_camblet_init(br_x509_camblet_context *ctx, br_ssl_engine_context *eng, opa_socket_context *socket_context, tcp_connection_context *conn_ctx, bool insecure); +int br_x509_camblet_init(br_x509_camblet_context *ctx, br_ssl_engine_context *eng, opa_socket_context *socket_context, tcp_connection_context *conn_ctx, bool insecure); void br_x509_camblet_free(br_x509_camblet_context *ctx); bool is_tls_handshake(const uint8_t *b); diff --git a/trace.c b/trace.c index 610492d5..4d9eafbb 100644 --- a/trace.c +++ b/trace.c @@ -253,10 +253,15 @@ char *compose_log_message(const char *message, int n, va_list args) int trace_log(const tcp_connection_context *conn_ctx, const char *message, int log_level, int n, ...) { + int ret = 0; unsigned int i; va_list args, args_copy; char *level = NULL; + task_context *task_ctx = get_task_context(); + if (IS_ERR(task_ctx)) + return PTR_ERR(task_ctx); + if (n < 0 || (n > 0 && n % 2 != 0)) { return -EINVAL; @@ -297,47 +302,48 @@ int trace_log(const tcp_connection_context *conn_ctx, const char *message, int l va_end(args); } - task_context *tc = get_task_context(); - trace_request *tr = get_trace_request_by_partial_match(tc->pid, tc->uid.val, tc->command_name); - free_task_context(tc); - - if (tr == NULL) + trace_request *tr = get_trace_request_by_partial_match(task_ctx->pid, task_ctx->uid.val, task_ctx->command_name); + if (!tr) { - return 0; + free_task_context(task_ctx); + return ret; } JSON_Value *root_value = json_value_init_object(); - JSON_Object *root_object = json_value_get_object(root_value); - if (!root_value) { + free_task_context(task_ctx); return -ENOMEM; } + JSON_Object *root_object = json_value_get_object(root_value); + if (json_object_set_string(root_object, "message", message) < 0) { - json_value_free(root_value); - - return -ENOMEM; + ret = -ENOMEM; + goto out; } if (log_level > 0 && json_object_set_string(root_object, "level", level) < 0) { - json_value_free(root_value); - - return -ENOMEM; + ret = -ENOMEM; + goto out; } if (conn_ctx) { const char *id_str = strprintf("%llu", conn_ctx->id); + if (IS_ERR(id_str)) + { + ret = PTR_ERR(id_str); + goto out; + } int retval = json_object_set_string(root_object, "correlation_id", id_str); kfree(id_str); if (retval < 0) { - json_value_free(root_value); - - return -ENOMEM; + ret = -ENOMEM; + goto out; } } @@ -353,9 +359,11 @@ int trace_log(const tcp_connection_context *conn_ctx, const char *message, int l } va_end(args); - send_message("log", json_serialize_to_string(root_value), get_task_context()); + send_message("log", json_serialize_to_string(root_value), task_ctx); +out: json_value_free(root_value); + free_task_context(task_ctx); - return 0; + return ret; } diff --git a/wasm.c b/wasm.c index d0d6a943..35014d7d 100644 --- a/wasm.c +++ b/wasm.c @@ -45,6 +45,11 @@ static M3Result m3_link_all(IM3Module module); wasm_vm *wasm_vm_new(int cpu) { wasm_vm *vm = kzalloc(sizeof(wasm_vm), GFP_KERNEL); + if (!vm) + { + pr_crit("wasm_vm_new: could not allocate memory"); + return NULL; + } vm->cpu = cpu; @@ -198,7 +203,12 @@ wasm_vm_result wasm_vm_load_module(wasm_vm *vm, const char *name, unsigned char goto on_error; } - char *module_name = strdup(name); + char *module_name = kstrdup(name, GFP_KERNEL); + if (!module_name) + { + result = "could not allocate memory"; + goto on_error; + } m3_SetModuleName(module, module_name); result = m3_link_all(module);