-
Notifications
You must be signed in to change notification settings - Fork 17
/
Copy pathaccident-glm.R
56 lines (46 loc) · 1.51 KB
/
accident-glm.R
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
#' Copyright(c) Microsoft Corporation.
#' Licensed under the MIT license.
library(azuremlsdk)
library(optparse)
library(caret)
## Command-line interface: -d/--data_folder and -p/--percent_train.
## Neither option declares a type or a default in its specification.
opt_parser <- OptionParser(
  option_list = list(
    make_option(c("-d", "--data_folder")),
    make_option(c("-p", "--percent_train"))
  )
)
opt <- parse_args(opt_parser)
## Fail fast with a clear message when --data_folder was not supplied;
## otherwise file.path() returns character(0) and readRDS() fails with an
## error that does not mention the missing option.
if (is.null(opt$data_folder)) {
  stop("--data_folder is required", call. = FALSE)
}
## Print data folder to log. print() is explicit so the value is also
## emitted when the script is source()d instead of run at top level.
print(paste(opt$data_folder))
## Read the accidents data set (RDS format, file name "accidents.Rd")
accidents <- readRDS(file.path(opt$data_folder, "accidents.Rd"))
print(summary(accidents))
## Create data partition for use with caret
## --percent_train is converted with as.numeric(); fall back to 0.75 when it
## is absent, not a single number (as.numeric() gives NA for non-numeric
## text, which would otherwise make the `if` condition NA and abort the
## script), or outside [0, 1].
train.pct <- as.numeric(opt$percent_train)
if (length(train.pct) != 1 || is.na(train.pct) ||
    train.pct < 0 || train.pct > 1) {
  train.pct <- 0.75
}
## Row indices for the training split, sampled on the `dead` column
accident_idx <- createDataPartition(accidents$dead, p = train.pct, list = FALSE)
accident_trn <- accidents[accident_idx, ]
accident_tst <- accidents[-accident_idx, ]
## Utility function to calculate accuracy in the test set.
##
## actual:    vector of observed outcomes.
## predicted: vector of predicted outcomes; must have the same length as
##            `actual`.
## Returns the proportion of positions where the two vectors are equal
## (NaN when both are empty; NA when any comparison is NA).
calc_acc <- function(actual, predicted) {
  # Guard against silent recycling, which would yield a meaningless score.
  if (length(actual) != length(predicted)) {
    stop("`actual` and `predicted` must have the same length", call. = FALSE)
  }
  mean(actual == predicted)
}
## Fit a binomial GLM through caret on the training rows, resampling with
## 5-fold cross validation
cv_control <- trainControl(method = "cv", number = 5)
accident_glm_mod <- train(
  form = dead ~ .,
  data = accident_trn,
  method = "glm",
  family = "binomial",
  trControl = cv_control
)
summary(accident_glm_mod)
## Score the held-out rows, then record the run metrics
test_predictions <- predict(accident_glm_mod, newdata = accident_tst)
test_accuracy <- calc_acc(
  actual = accident_tst$dead,
  predicted = test_predictions
)
log_metric_to_run("Accuracy", test_accuracy)
log_metric_to_run("Method", "GLM")
log_metric_to_run("TrainPCT", train.pct)
## Persist the fitted model as model.rds inside the output directory
output_dir <- "outputs"
if (!dir.exists(output_dir)) {
  dir.create(output_dir)
}
## Build the path from output_dir so the directory created above and the
## save location cannot drift apart.
saveRDS(accident_glm_mod, file = file.path(output_dir, "model.rds"))
message("Model saved")