Skip to content

Commit a8dbce2

Browse files
Patrick GerbesPatrick Gerbes
Patrick Gerbes
authored and
Patrick Gerbes
committed
Initial commit
0 parents  commit a8dbce2

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

66 files changed

+2006
-0
lines changed

app/actors/BatchTrainer.scala

+67
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
package actors
2+
3+
import actors.Director.BatchTrainingFinished
4+
import akka.actor.{Actor, ActorLogging, ActorRef, Props}
5+
import akka.event.LoggingReceive
6+
import models.preprocessing.DataFrame
7+
import models.math.OptimizationResult
8+
import models.math.OptimizationRoutine._
9+
import models.math.Optimizer._
10+
import models.math.WeightInitializer._
11+
import models.math.WeightUpdate
12+
import models.ml.LinearRegression._
13+
14+
/** Marker trait for the batch trainer actor; it adds no members beyond those inherited from `Actor`. */
trait BatchTrainerProxy extends Actor
15+
16+
/**
  * Actor that runs a training pass when it receives a `Train` message and
  * replies with the most recently trained model on `GetLatestModel`.
  *
  * @param director actor that is sent `BatchTrainingFinished` after each training run
  */
class BatchTrainer(director: ActorRef) extends Actor with ActorLogging with BatchTrainerProxy {

  import BatchTrainer._

  // Result of the latest training run; remains None until a Train message has been processed.
  var model: Option[OptimizationResult] = None

  override def receive = LoggingReceive {

    case Train(featureData: DataFrame) =>
      log.debug("Received Train message with feature data")
      log.info("Starting batch training")

      // Warning: the hyper-parameters are hard-coded as a stop-gap until this
      // can be implemented the right way.
      // `optimize` is partially applied here: the weight-update rule and the
      // data set are left open and supplied in the call below.
      val trainWith = optimize(
        iter = 200,
        seed = 123L,
        initAlpha = 0.1,
        momentum = 0.9,
        gradientFunction = linearRegressionGradient,
        costFunction = linearRegressionCost,
        _: WeightUpdate,
        miniBatchFraction = 0.1,
        weightInitializer = gaussianInitialization,
        _: DataFrame
      )

      val trainedModel = trainWith(stochasticGradientDescent, featureData)

      // The value is discarded.
      // NOTE(review): presumably kept to force evaluation of the weight history
      // before "finished" is logged - confirm before removing.
      trainedModel.weights.last

      model = Option(trainedModel)

      log.info("Batch training finished")

      director ! BatchTrainingFinished

    case GetLatestModel =>
      log.debug("Received GetLatestModel message")
      // Reply to whoever asked; the payload is None if no training run has completed yet.
      sender() ! BatchTrainerModel(model)
      log.debug(s"Returned model $model")
  }
}
59+
60+
/** Companion of [[BatchTrainer]]: `Props` factory and message types. */
object BatchTrainer {

  /** Builds the `Props` used to create a [[BatchTrainer]] that reports to `director`. */
  def props(director: ActorRef): Props = Props(new BatchTrainer(director))

  /** Reply to `GetLatestModel`; `model` is empty when nothing has been trained yet. */
  case class BatchTrainerModel(model: Option[OptimizationResult])

  /** Wrapper around an optional feature `DataFrame`. */
  case class BatchFeatures(features: Option[DataFrame])

}
66+
67+

app/actors/Classifier.scala

+54
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
package actors
2+
3+
import actors.Classifier.Classify
4+
import actors.DataHandler.Fetch
5+
import akka.actor._
6+
import akka.event.LoggingReceive
7+
import breeze.linalg.DenseVector
8+
import models.preprocessing.DataFrameMonad
9+
import services.prediction.Predictor
10+
11+
/**
  * Entry point for classification requests.
  *
  * For every `Classify` message a dedicated [[FetchResponseHandler]] child is
  * created; the data handler is then asked to fetch the event data with that
  * child set as the sender, so the fetch response is delivered to the child.
  *
  * @param dataHandler  actor that answers `Fetch` requests
  * @param batchTrainer actor handed on to the per-request handler
  * @param predictor    prediction service handed on to the per-request handler
  */
class Classifier(dataHandler: ActorRef,
                 batchTrainer: ActorRef,
                 predictor: Predictor) extends Actor with ActorLogging {

  override def receive = LoggingReceive {

    case Classify(eventId: Int) =>

      log.info(s"Start predicting attendance for event '$eventId'")

      // Capture the sender now so the child replies to the requester.
      val originalSender = sender()

      // Child names must be unique among siblings. A fixed name makes actorOf
      // throw InvalidActorNameException when a second request arrives while an
      // earlier handler is still alive, so each handler gets a unique suffix.
      val handler = context.actorOf(
        FetchResponseHandler.props(batchTrainer, originalSender, predictor),
        s"fetch-response-message-handler-${java.util.UUID.randomUUID()}")

      log.debug(s"Created handler $handler")

      dataHandler.tell(Fetch(eventId), handler)
  }
}
31+
32+
/** Companion of [[Classifier]]: `Props` factory, message types and the prediction alias. */
object Classifier {

  /** Builds the `Props` used to create a [[Classifier]] with the given collaborators. */
  def props(dataHandler: ActorRef,
            batchTrainer: ActorRef,
            predictor: Predictor): Props =
    Props(new Classifier(dataHandler, batchTrainer, predictor))

  /** Alias for a `DataFrameMonad` wrapping a dense vector of doubles. */
  type PredictionVector = DataFrameMonad[DenseVector[Double]]

  /** Request to classify the event identified by `eventId`. */
  case class Classify(eventId: Int)

  /** Carries the prediction vector produced by the batch model. */
  case class ClassificationResult(batchModelResult: PredictionVector)

  // Explicit (currently empty) companion of ClassificationResult.
  object ClassificationResult {

  }
}
46+
47+
48+
49+
50+
51+
52+
53+
54+

app/actors/DataHandler.scala

+43
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
package actors
2+
3+
import akka.actor.{Actor, Props}
4+
import akka.event.LoggingReceive
5+
import actors.DataHandler.{Fetch, FetchResponse}
6+
import models.preprocessing.DataFrame
7+
import play.api.libs.json._
8+
import play.api.Logger
9+
import utils.Utils.createTestDataFrame
10+
11+
/** Marker trait for the data handler actor; it adds no members beyond those inherited from `Actor`. */
trait DataHandlerProxy extends Actor
12+
13+
/**
  * Actor that answers `Fetch` requests with a `FetchResponse`.
  *
  * The response currently carries a generated test data frame rather than
  * data looked up for the requested event.
  */
class DataHandler extends Actor with DataHandlerProxy {

  // Play logger named after this class.
  val log = Logger(this.getClass)

  override def receive = LoggingReceive {

    case Fetch(eventId) =>
      log.debug(s"Received Fetch message for eventId = $eventId from $sender")

      // Warning: stop-gap until it is clear how (or whether) actors should
      // interact with futures coming from data services in this context.
      // TODO: Write websocket (or use streams) for on the fly classification
      val predictionData = createTestDataFrame(10)
      sender() ! FetchResponse(eventId, predictionData)

    // Anything that is not a Fetch is only logged.
    case undefined =>
      log.warn(s"Unexpected message $undefined")
  }
}
31+
32+
/** Companion of [[DataHandler]]: `Props` factory and the fetch request/response messages. */
object DataHandler {

  /** Builds the `Props` used to create a [[DataHandler]]. */
  def props: Props = Props(new DataHandler)

  /** Request for the data belonging to `eventId`. */
  case class Fetch(eventId: Int)

  object Fetch {
    // JSON format for Fetch, derived by Play's Json.format macro.
    implicit val fetchFormat = Json.format[Fetch]
  }

  /** Reply to a [[Fetch]]: echoes the event id and carries the data frame to predict on. */
  case class FetchResponse(eventId: Int, predictionData: DataFrame)

}
43+

app/actors/Director.scala

+38
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
package actors
2+
3+
import akka.actor.{Actor, ActorLogging, ActorRef, Props}
4+
import akka.event.LoggingReceive
5+
import services.prediction.Predictor
6+
7+
/**
  * Parent actor that creates the data handler, batch trainer, classifier and
  * training-data initializer as its children.
  *
  * @param eventServer actor passed on to the training-data initializer
  */
class Director(eventServer: ActorRef) extends Actor with ActorLogging {

  import Director._

  // Children are created in dependency order: the classifier needs the data
  // handler, the batch trainer and the predictor.
  val dataHandler = context.actorOf(Props[DataHandler], "data-handler")

  val batchTrainer = context.actorOf(BatchTrainer.props(self), "batch-trainer")

  val predictor = new Predictor

  val classifier = context.actorOf(
    Classifier.props(dataHandler, batchTrainer, predictor),
    "classifier")

  // Created for its side effects only; no reference is kept.
  context.actorOf(
    TrainingDataInitializer.props(batchTrainer, eventServer),
    "training-data-initializer")

  override def receive = LoggingReceive {

    // Hand the classifier's ActorRef to whoever asks for it.
    case GetClassifier =>
      sender() ! classifier

    // The trainer reports completion; ask it for the model it just produced.
    case BatchTrainingFinished =>
      batchTrainer ! GetLatestModel

    case undefined =>
      log.info(s"Unexpected message $undefined")
  }
}
30+
31+
/** Companion of [[Director]]: `Props` factory and message objects. */
object Director {

  /** Builds the `Props` used to create a [[Director]] wired to `eventServer`. */
  def props(eventServer: ActorRef): Props = Props(new Director(eventServer))

  /** Request for the classifier's `ActorRef`. */
  case object GetClassifier

  /** Notification that a batch training run has completed. */
  case object BatchTrainingFinished

}

app/actors/EventListener.scala

+16
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
package actors
2+
3+
import akka.actor.{Actor, ActorRef, Props}
4+
5+
/**
  * Forwards every `String` message it receives to `out`.
  *
  * Sends `Subscribe` to `producer` when it starts and `Unsubscribe` when it stops.
  *
  * @param out      destination of the forwarded strings
  * @param producer actor that is sent the subscribe/unsubscribe messages
  */
class EventListener(out: ActorRef, producer: ActorRef) extends Actor {

  // Register with the producer as soon as this actor starts.
  override def preStart(): Unit = producer ! Subscribe

  // Deregister when this actor stops.
  override def postStop(): Unit = producer ! Unsubscribe

  def receive = {
    case text: String => out ! text
  }
}
13+
14+
/** Companion of [[EventListener]]. */
object EventListener {

  /** Builds the `Props` used to create an [[EventListener]] forwarding to `out`. */
  def props(out: ActorRef, producer: ActorRef): Props =
    Props(new EventListener(out, producer))
}

app/actors/EventServer.scala

+22
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
package actors
2+
3+
import akka.actor._
4+
import akka.event.LoggingReceive
5+
6+
/** Marker trait for the event server actor; it adds no members beyond those inherited from `Actor`. */
trait EventServerProxy extends Actor
7+
8+
/**
  * Broadcasts every `String` message it receives to all subscribed clients.
  *
  * Clients join with `Subscribe` and leave with `Unsubscribe`. Subscribers are
  * death-watched so that a client which stops without unsubscribing is still
  * removed from the set.
  */
class EventServer extends Actor with ActorLogging with EventServerProxy {

  // Currently subscribed clients.
  var clients = Set.empty[ActorRef]

  def receive = LoggingReceive {
    case msg: String =>
      clients.foreach(_ ! msg)

    case Subscribe =>
      context.watch(sender())
      clients += sender()

    case Unsubscribe =>
      context.unwatch(sender())
      clients -= sender()

    // A watched client stopped without unsubscribing. Leaving Terminated
    // unhandled would raise DeathPactException in this actor and keep the
    // dead ref in the set.
    case Terminated(client) =>
      clients -= client
  }
}
20+
/** Companion of [[EventServer]]. */
object EventServer {

  /** Builds the `Props` used to create an [[EventServer]]. */
  def props: Props = Props[EventServer]
}

app/actors/FetchResponseHandler.scala

+47
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
package actors
2+
3+
import actors.DataHandler.FetchResponse
4+
import actors.FetchResponseHandler.FetchResponseTimeout
5+
import akka.actor.{Actor, ActorLogging, ActorRef, Props, Terminated}
6+
import akka.event.LoggingReceive
7+
import services.prediction.Predictor
8+
9+
import scala.concurrent.duration._
10+
11+
/**
  * Per-request actor that waits for the `FetchResponse` belonging to one
  * classification request.
  *
  * When the response arrives it creates a `TrainingModelResponseHandler` child
  * and asks the batch trainer for its latest model on the child's behalf. If no
  * response arrives within five seconds the original requester is sent
  * `FetchResponseTimeout`. The actor stops itself after the timeout or once the
  * child has terminated.
  *
  * @param batchTrainer   actor that answers `GetLatestModel`
  * @param originalSender requester of the classification
  * @param predictor      prediction service handed on to the child handler
  */
class FetchResponseHandler(batchTrainer: ActorRef,
                           originalSender: ActorRef,
                           predictor: Predictor) extends Actor with ActorLogging {

  import context.dispatcher

  // One-shot timer, cancelled as soon as the fetch response arrives.
  val timeoutMessenger = context.system.scheduler.scheduleOnce(5.seconds) {
    self ! FetchResponseTimeout
  }

  def receive = LoggingReceive {
    case fetchResponse: FetchResponse =>
      timeoutMessenger.cancel()
      val handler = context.actorOf(
        TrainingModelResponseHandler.props(
          fetchResponse,
          originalSender,
          predictor), "training-model-response-message-handler")
      log.debug(s"Created handler $handler")
      // Watch the child before any message can reach it.
      context.watch(handler)
      // `tell` sets the child as the sender so the trainer replies to it.
      // `batchTrainer ! (GetLatestModel, handler)` would be auto-tupled and
      // send a Tuple2 as the message, which the trainer does not handle.
      batchTrainer.tell(GetLatestModel, handler)

    case t: Terminated =>
      log.debug(s"Received Terminated message for training model response handler $t")
      context.stop(self)

    case FetchResponseTimeout =>
      log.debug("Timeout occurred")
      originalSender ! FetchResponseTimeout
      context.stop(self)
  }
}
40+
41+
/** Companion of [[FetchResponseHandler]]: timeout message and `Props` factory. */
object FetchResponseHandler {

  /** Signals that no `FetchResponse` arrived before the handler's timer fired. */
  case object FetchResponseTimeout

  /** Builds the `Props` used to create a [[FetchResponseHandler]] for one request. */
  def props(batchTrainer: ActorRef,
            originalSender: ActorRef,
            predictor: Predictor): Props =
    Props(new FetchResponseHandler(batchTrainer, originalSender, predictor))
}
46+
47+

app/actors/ModelPerformanceSupervisor.scala

+89
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
package actors
2+
3+
import actors.BatchTrainer.BatchTrainerModel
4+
import actors.ModelPerformanceSupervisor.TrainerType.TrainerType
5+
import akka.actor.{Actor, ActorLogging, ActorRef, Props}
6+
import akka.event.LoggingReceive
7+
import models.preprocessing.DataFrame
8+
import play.api.libs.json.{Json, Reads, Writes}
9+
10+
import utils.EnumeratorUtils
11+
12+
/**
  * Tracks the latest batch model and publishes its performance figures to
  * subscribed clients.
  *
  * A new `BatchTrainerModel` is validated and the result is broadcast to all
  * clients. A client that subscribes later is immediately sent the figures for
  * the model currently held, if there is one.
  */
class ModelPerformanceSupervisor extends Actor with ActorLogging {

  import ModelPerformanceSupervisor._

  // Currently subscribed clients.
  var clients = Set.empty[ActorRef]

  // Most recently received training set, if any.
  var df: Option[DataFrame] = None

  // Most recently received batch model, if any.
  var batchTrainerModel: Option[BatchTrainerModel] = None

  override def receive = LoggingReceive {

    case batchModel: BatchTrainerModel =>
      batchTrainerModel = Some(batchModel)
      validateBatchModel(batchModel) foreach sendMessage

    case TrainingSet(c: DataFrame) =>
      df = Some(c)

    case Subscribe =>
      val subscriber = sender()
      context.watch(subscriber)
      clients += subscriber
      // Bring the new subscriber up to date with the current model, if any.
      batchTrainerModel
        .flatMap(validateBatchModel)
        .foreach(performance => subscriber ! performance)

    case Unsubscribe =>
      context.unwatch(sender())
      clients -= sender()

    // A watched client stopped without unsubscribing. Leaving Terminated
    // unhandled would raise DeathPactException in this actor and keep the
    // dead ref in the set.
    case akka.actor.Terminated(client) =>
      clients -= client

  }

  /**
    * Produces the performance figures for a batch model.
    *
    * TODO: Write actual performance implementation; the argument is currently
    * ignored and fixed placeholder figures are returned.
    */
  def validateBatchModel(batchTrainerModel: BatchTrainerModel): Option[ModelPerformance] = {
    Option(ModelPerformance("Regression", "1.0", .99, .99))
  }

  /** Sends `msg` to every subscribed client. */
  def sendMessage(msg: ModelPerformance) = clients.foreach(_ ! msg)

  /** Writes each field of `performance` to the log at info level. */
  def logStatistics(performance: ModelPerformance): Unit = {
    log.info(s"Trainer type: ${performance.trainer}")
    log.info(s"Current model: ${performance.model}")
    log.info(s"Area under the ROC curve: ${performance.areaUnderRoc}")
    log.info(s"Accuracy: ${performance.accuracy}")
  }

}
61+
62+
/** Companion of [[ModelPerformanceSupervisor]]: `Props` factory, trainer enumeration and message types. */
object ModelPerformanceSupervisor {

  /** Builds the `Props` used to create a [[ModelPerformanceSupervisor]]. */
  def props(): Props = Props(new ModelPerformanceSupervisor)

  /** Enumeration of trainer kinds, with JSON reads/writes supplied by `EnumeratorUtils`. */
  object TrainerType extends Enumeration {

    type TrainerType = TrainerType.Value

    val Batch = Value

    implicit val reads: Reads[TrainerType] = EnumeratorUtils.enumReads(TrainerType)

    implicit val writes: Writes[TrainerType] = EnumeratorUtils.enumWrites

  }

  /** Message carrying a training data set. */
  case class TrainingSet(data: DataFrame)

  /** Performance figures reported for a model. */
  case class ModelPerformance(trainer: String,
                              model: String,
                              areaUnderRoc: Double,
                              accuracy: Double)

  object ModelPerformance {

    // JSON format for ModelPerformance, derived by Play's Json.format macro.
    implicit val formatter = Json.format[ModelPerformance]

  }

}

0 commit comments

Comments
 (0)