2 changes: 2 additions & 0 deletions docker-init/demo/Dockerfile
@@ -32,4 +32,6 @@ ENV SPARK_DIST_CLASSPATH="/opt/spark/jars/*"
ENV SPARK_EXTRA_CLASSPATH="/opt/spark/jars/*:/opt/chronon/jars/*"
ENV HADOOP_CLASSPATH="/opt/spark/jars/*"

EXPOSE 8181
@coderabbitai bot (Contributor) · Nov 27, 2024

💡 Codebase verification

Port 8181 exposure requires additional security measures

Based on the code analysis, the DataServer implementation has several security concerns:

  1. No authentication/authorization mechanisms are implemented
  2. The server accepts any incoming HTTP requests without filtering
  3. Sensitive data from drift and summary series is exposed without access controls
  4. The server binds to port 8181 without IP address restrictions

Required actions (a sketch of the authentication and bind-address items follows this list):

  • Implement authentication middleware
  • Add request origin validation
  • Consider restricting the binding address to non-public interfaces
  • Add rate limiting for the endpoints
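A minimal hardening sketch for the authentication and bind-address items, assuming a shared token carried in the Authorization header; the TokenAuthHandler class and the DATA_SERVER_TOKEN environment variable are hypothetical additions, not code from this PR:

import io.netty.buffer.Unpooled
import io.netty.channel.{ChannelFutureListener, ChannelHandlerContext, SimpleChannelInboundHandler}
import io.netty.handler.codec.http._
import io.netty.util.{CharsetUtil, ReferenceCountUtil}

import java.net.InetSocketAddress

// Hypothetical guard placed in the pipeline after HttpObjectAggregator and before
// HttpServerHandler: requests without the expected bearer token get a 401 and the
// connection is closed; authorized requests are retained and passed downstream.
class TokenAuthHandler(expectedToken: String) extends SimpleChannelInboundHandler[HttpRequest] {
  override def channelRead0(ctx: ChannelHandlerContext, request: HttpRequest): Unit = {
    val header = Option(request.headers().get(HttpHeaderNames.AUTHORIZATION))
    if (header.contains(s"Bearer $expectedToken")) {
      ctx.fireChannelRead(ReferenceCountUtil.retain(request))
    } else {
      val body = Unpooled.copiedBuffer("""{"error": "unauthorized"}""", CharsetUtil.UTF_8)
      val response = new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.UNAUTHORIZED, body)
      response.headers().set(HttpHeaderNames.CONTENT_TYPE, "application/json")
      response.headers().set(HttpHeaderNames.CONTENT_LENGTH, body.readableBytes())
      ctx.writeAndFlush(response).addListener(ChannelFutureListener.CLOSE)
    }
  }
}

// In DataServer.start(), the handler would be registered and the bind restricted to loopback:
//   p.addLast(new TokenAuthHandler(sys.env.getOrElse("DATA_SERVER_TOKEN", "change-me")))
//   val f = b.bind(new InetSocketAddress("127.0.0.1", port)).sync()

Rate limiting and origin validation would need further handlers (Netty's traffic-shaping handlers could be a starting point) and are left out of this sketch.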
🔗 Analysis chain

Verify port exposure security implications

The exposure of port 8181 aligns with the DataServer requirements, but ensure:

  1. The port is not exposed in production environments if not needed (a loopback-only publish sketch follows this list)
  2. Appropriate authentication/authorization is implemented in DataServer
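For the first point, a low-touch option (not part of this PR) is to publish the port on the host's loopback interface only, so the demo stays reachable from the local machine but not from the network. A sketch of the docker run invocation from run.sh with that change; the image name is a placeholder, since the tail of the script is elided in this diff:

# Publish 8181 on loopback only; other hosts cannot reach the DataServer.
docker run -d \
  -p 127.0.0.1:8181:8181 \
  --name spark-app \
  -v "$SPARK_JAR_PATH":/opt/chronon/jars \
  -v "$LOG_CONFIG_FILE":/opt/chronon/log4j2.properties \
  spark-app-image   # placeholder; the actual image name is not shown in this diff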

Let's verify the DataServer implementation and its security measures:


Let me analyze the DataServer implementation further to verify security measures.


Let me check one more aspect of the DataServer implementation to verify request handling and response mechanisms.

🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Description: Check DataServer implementation for security measures
# Look for authentication/authorization code and security headers

# Search for DataServer implementation
echo "Searching for DataServer security implementations..."
rg -A 5 "class DataServer" 

# Look for security-related configurations
echo "Checking for security configurations..."
rg -i "auth|security|cors|csrf" --type scala

Length of output: 18758


Script:

#!/bin/bash
# Check DataServer handler implementation and network configuration
echo "Checking DataServer handler and network setup..."
rg -A 10 "private def handle" spark/src/main/scala/ai/chronon/spark/scripts/DataServer.scala

# Look for network binding configuration
echo "Checking network binding configuration..."
rg "bind|bootstrap|channel" spark/src/main/scala/ai/chronon/spark/scripts/DataServer.scala

Length of output: 997


Script:

#!/bin/bash
# Check HTTP request handling and response generation
echo "Checking HTTP handling in DataServer..."
rg -A 15 "extends SimpleChannelInboundHandler" spark/src/main/scala/ai/chronon/spark/scripts/DataServer.scala

# Look for any access control or filtering logic
echo "Checking request filtering..."
rg -B 2 -A 5 "HttpRequest|HttpResponse" spark/src/main/scala/ai/chronon/spark/scripts/DataServer.scala

Length of output: 2630

@nikhil-zlai (Contributor, Author)

chill. this is an internal demo

@coderabbitai bot (Contributor)

@nikhil-zlai Understood, thanks for clarifying that this is an internal demo.


✏️ Learnings added
Learnt from: nikhil-zlai
PR: zipline-ai/chronon#101
File: docker-init/demo/Dockerfile:35-35
Timestamp: 2024-11-27T20:34:54.376Z
Learning: In internal demos, when exposing port 8181 in `docker-init/demo/Dockerfile`, implementing security measures like authentication and access controls on the `DataServer` may not be necessary.

Note: Learnings are effective only in the context of similar code segments. To apply general coding guidelines, please configure review instructions. You can manage existing learnings in the app.


CMD ["tail", "-f", "/dev/null"]
7 changes: 5 additions & 2 deletions docker-init/demo/README.md
@@ -1,2 +1,5 @@
run build.sh once, and you can repeatedly exec
sbt spark/assembly + run.sh on iterations to the chronon code.
run build.sh once, and you can repeatedly exec to quickly visualize

In first terminal: `sbt spark/assembly`
In second terminal: `./run.sh` to load the built jar and serve the data on localhost:8181
In third terminal: `streamlit run viz.py`
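A quick sanity check before launching Streamlit, using the /health route that DataServer.scala in this PR exposes (the expected JSON body is taken from that file):

# Should print {"status": "healthy"} once run.sh has the container serving on 8181
curl http://localhost:8181/health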
Empty file modified docker-init/demo/build.sh
100644 → 100755
Empty file.
1 change: 1 addition & 0 deletions docker-init/demo/run.sh
@@ -16,6 +16,7 @@ fi

# Run new container
docker run -d \
-p 8181:8181 \
--name spark-app \
-v "$SPARK_JAR_PATH":/opt/chronon/jars \
-v "$LOG_CONFIG_FILE":/opt/chronon/log4j2.properties \
94 changes: 94 additions & 0 deletions docker-init/demo/viz.py
@@ -0,0 +1,94 @@
import streamlit as st
import requests
from datetime import datetime
import pandas as pd
from collections import defaultdict

# Configure the page to use wide mode
st.set_page_config(layout="wide")

# Add custom CSS to make charts wider
st.markdown("""
<style>
.element-container {
width: 100%;
}
.stChart {
width: 100%;
min-width: 400px;
}
</style>
""", unsafe_allow_html=True)

def format_timestamp(ts_ms):
"""Format millisecond timestamp to readable date."""
return datetime.fromtimestamp(ts_ms/1000).strftime('%Y-%m-%d %H:%M')

def load_data():
"""Load data from the API."""
try:
response = requests.get("http://localhost:8181/api/drift-series")
return response.json()
except Exception as e:
st.error(f"Error loading data: {e}")
return None

def create_series_df(timestamps, values):
"""Create a DataFrame for a series."""
return pd.DataFrame({
'timestamp': [format_timestamp(ts) for ts in timestamps],
'value': values
})

def is_valid_series(series):
return any(value is not None for value in series)

def main():
st.title("Drift Board")

# Load data
data = load_data()
if not data:
return

# Group data by groupName
grouped_data = defaultdict(list)
for entry in data:
group_name = entry["key"].get("groupName", "unknown")
grouped_data[group_name].append(entry)

# Series types and their display names
series_types = {
"percentileDriftSeries": "Percentile Drift",
"histogramDriftSeries": "Histogram Drift",
"countChangePercentSeries": "Count Change %"
}

# Create tabs for each group
group_tabs = st.tabs(list(grouped_data.keys()))

# Fill each tab with its group's data
for tab, (group_name, group_entries) in zip(group_tabs, grouped_data.items()):
with tab:
for entry in group_entries:
# Create expander for each column
column_name = entry["key"].get("column", "unknown")
with st.expander(f"Column: {column_name}", expanded=True):
# Get available series for this entry
available_series = [s_type for s_type in series_types.keys()
if s_type in entry and is_valid_series(entry[s_type])]

if available_series:
# Create columns for charts with extra padding
cols = st.columns([1] * len(available_series))

# Create charts side by side
for col_idx, series_type in enumerate(available_series):
if is_valid_series(entry[series_type]):
with cols[col_idx]:
st.subheader(series_types[series_type])
df = create_series_df(entry["timestamps"], entry[series_type])
st.line_chart(df.set_index('timestamp'), height=400)

if __name__ == "__main__":
main()
141 changes: 141 additions & 0 deletions spark/src/main/scala/ai/chronon/spark/scripts/DataServer.scala
@@ -0,0 +1,141 @@
package ai.chronon.spark.scripts

import ai.chronon.api.TileDriftSeries
import ai.chronon.api.TileSeriesKey
import ai.chronon.api.TileSummarySeries
import ai.chronon.api.thrift.TBase
import ai.chronon.online.stats.DriftStore
import ai.chronon.online.stats.DriftStore.SerializableSerializer
import com.fasterxml.jackson.databind.ObjectMapper
import com.fasterxml.jackson.databind.SerializationFeature
import com.fasterxml.jackson.module.scala.DefaultScalaModule
import io.netty.bootstrap.ServerBootstrap
import io.netty.buffer.Unpooled
import io.netty.channel._
import io.netty.channel.nio.NioEventLoopGroup
import io.netty.channel.socket.SocketChannel
import io.netty.channel.socket.nio.NioServerSocketChannel
import io.netty.handler.codec.http._
import io.netty.util.CharsetUtil

import java.util.Base64
import java.util.function.Supplier
import scala.reflect.ClassTag

class DataServer(driftSeries: Seq[TileDriftSeries], summarySeries: Seq[TileSummarySeries], port: Int = 8181) {
private val logger = org.slf4j.LoggerFactory.getLogger(getClass)
private val bossGroup = new NioEventLoopGroup(1)
private val workerGroup = new NioEventLoopGroup()
private val mapper = new ObjectMapper()
.registerModule(DefaultScalaModule)
.enable(SerializationFeature.INDENT_OUTPUT)

private class HttpServerHandler extends SimpleChannelInboundHandler[HttpObject] {
override def channelReadComplete(ctx: ChannelHandlerContext): Unit = {
ctx.flush()
}

private val serializer: ThreadLocal[SerializableSerializer] =
ThreadLocal.withInitial(new Supplier[SerializableSerializer] {
override def get(): SerializableSerializer = DriftStore.compactSerializer
})

private def convertToBytesMap[T <: TBase[_, _]: Manifest: ClassTag](
series: T,
keyF: T => TileSeriesKey): Map[String, String] = {
val serializerInstance = serializer.get()
val encoder = Base64.getEncoder
val keyBytes = serializerInstance.serialize(keyF(series))
val valueBytes = serializerInstance.serialize(series)
Map(
"keyBytes" -> encoder.encodeToString(keyBytes),
"valueBytes" -> encoder.encodeToString(valueBytes)
)
}

override def channelRead0(ctx: ChannelHandlerContext, msg: HttpObject): Unit = {
msg match {
case request: HttpRequest =>
val uri = request.uri()

val start = System.currentTimeMillis()
val (status, content) = uri match {
case "/health" =>
(HttpResponseStatus.OK, """{"status": "healthy"}""")

case "/api/drift-series" =>
//val dtos = driftSeries.map(d => convertToBytesMap(d, (tds: TileDriftSeries) => tds.getKey))
(HttpResponseStatus.OK, mapper.writeValueAsString(driftSeries))

case "/api/summary-series" =>
val dtos = summarySeries.map(d => convertToBytesMap(d, (tds: TileSummarySeries) => tds.getKey))
(HttpResponseStatus.OK, mapper.writeValueAsString(dtos))

case "/api/metrics" =>
val metrics = Map(
"driftSeriesCount" -> driftSeries.size,
"summarySeriesCount" -> summarySeries.size
)
(HttpResponseStatus.OK, mapper.writeValueAsString(metrics))

case _ =>
(HttpResponseStatus.NOT_FOUND, """{"error": "Not Found"}""")
}
val end = System.currentTimeMillis()
logger.info(s"Request $uri took ${end - start}ms, status: $status, content-size: ${content.length}")

val response = new DefaultFullHttpResponse(
HttpVersion.HTTP_1_1,
status,
Unpooled.copiedBuffer(content, CharsetUtil.UTF_8)
)

response
.headers()
.set(HttpHeaderNames.CONTENT_TYPE, "application/json")
.set(HttpHeaderNames.CONTENT_LENGTH, response.content().readableBytes())

if (HttpUtil.isKeepAlive(request)) {
response.headers().set(HttpHeaderNames.CONNECTION, HttpHeaderValues.KEEP_ALIVE)
}

ctx.write(response)
case _ =>
}
}

override def exceptionCaught(ctx: ChannelHandlerContext, cause: Throwable): Unit = {
cause.printStackTrace()
ctx.close()
}
}

def start(): Unit = {
try {
val b = new ServerBootstrap()
b.group(bossGroup, workerGroup)
.channel(classOf[NioServerSocketChannel])
.childHandler(new ChannelInitializer[SocketChannel] {
override def initChannel(ch: SocketChannel): Unit = {
val p = ch.pipeline()
p.addLast(new HttpServerCodec())
p.addLast(new HttpObjectAggregator(65536))
p.addLast(new HttpServerHandler())
}
})
.option[Integer](ChannelOption.SO_BACKLOG, 128)
.childOption[java.lang.Boolean](ChannelOption.SO_KEEPALIVE, true)

val f = b.bind(port).sync()
println(s"Server started at http://localhost:$port/metrics")
f.channel().closeFuture().sync()
} finally {
shutdown()
}
}

private def shutdown(): Unit = {
workerGroup.shutdownGracefully()
bossGroup.shutdownGracefully()
}
}
@@ -32,42 +32,6 @@ import scala.util.ScalaJavaConversions.IteratorOps
object ObservabilityDemo {
@transient lazy val logger: Logger = LoggerFactory.getLogger(getClass)

def time(message: String)(block: => Unit): Unit = {
logger.info(s"$message..".yellow)
val start = System.currentTimeMillis()
block
val end = System.currentTimeMillis()
logger.info(s"$message took ${end - start} ms".green)
}

class Conf(arguments: Seq[String]) extends ScallopConf(arguments) {
val startDs: ScallopOption[String] = opt[String](
name = "start-ds",
default = Some("2023-01-01"),
descr = "Start date in YYYY-MM-DD format"
)

val endDs: ScallopOption[String] = opt[String](
name = "end-ds",
default = Some("2023-02-30"),
descr = "End date in YYYY-MM-DD format"
)

val rowCount: ScallopOption[Int] = opt[Int](
name = "row-count",
default = Some(700000),
descr = "Number of rows to generate"
)

val namespace: ScallopOption[String] = opt[String](
name = "namespace",
default = Some("observability_demo"),
descr = "Namespace for the demo"
)

verify()
}

def main(args: Array[String]): Unit = {

val config = new Conf(args)
@@ -183,6 +147,9 @@
}
}

val server = new DataServer(driftSeries, summarySeries)
server.start()

val startTs = 1673308800000L
val endTs = 1674172800000L
val joinName = "risk.user_transactions.txn_join"
@@ -211,4 +178,40 @@
spark.stop()
System.exit(0)
}

def time(message: String)(block: => Unit): Unit = {
logger.info(s"$message..".yellow)
val start = System.currentTimeMillis()
block
val end = System.currentTimeMillis()
logger.info(s"$message took ${end - start} ms".green)
}

class Conf(arguments: Seq[String]) extends ScallopConf(arguments) {
val startDs: ScallopOption[String] = opt[String](
name = "start-ds",
default = Some("2023-01-08"),
descr = "Start date in YYYY-MM-DD format"
)

val endDs: ScallopOption[String] = opt[String](
name = "end-ds",
default = Some("2023-02-30"),
descr = "End date in YYYY-MM-DD format"
)

val rowCount: ScallopOption[Int] = opt[Int](
name = "row-count",
default = Some(700000),
descr = "Number of rows to generate"
)

val namespace: ScallopOption[String] = opt[String](
name = "namespace",
default = Some("observability_demo"),
descr = "Namespace for the demo"
)

verify()
}
}