-
Notifications
You must be signed in to change notification settings - Fork 0
/
viz_app.py
65 lines (50 loc) · 1.75 KB
/
viz_app.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
import streamlit as st
import s3fs
import os
from io import StringIO
import pandas as pd
from constants import APP_DIR
# S3 connection. `anon=False` forces authenticated access using the
# credential pair stored in Streamlit secrets.
fs = s3fs.S3FileSystem(
    anon=False,
    key=st.secrets["ACCESS_KEY"],
    secret=st.secrets["PRIVATE_KEY"],
)
# st.cache_data memoizes the S3 read: it only reruns when `filename`
# changes or after the 10-minute TTL expires.
@st.cache_data(ttl=600)
def read_file(filename):
    """Fetch `filename` from S3 via the module-level `fs` and return its
    contents decoded as a UTF-8 string."""
    with fs.open(filename) as handle:
        raw = handle.read()
    return raw.decode("utf-8")
# Load the test CSV from S3. There is no header row, so pandas assigns
# integer column labels (0, 1, 2, ...).
content = read_file("s3://pmpf-data/sagemaker-xgboost-prediction/data/test.csv")
df = pd.read_csv(StringIO(content), header=None)

# Column 1 holds timestamps; parse it to datetime in place.
df.iloc[:, 1] = pd.to_datetime(df.iloc[:, 1])

# Expose the first ten timestamps as string options for the selectbox.
dates = df.iloc[:, 1]
first_n_dates = list(dates.head(10))
options = [str(d) for d in first_n_dates]
""" Call Model """
import boto3
sagemaker = boto3.client('sagemaker-runtime')
def get_row(date):
    """Return the row(s) of the module-level `df` whose timestamp
    column (column 1) equals `date`."""
    matches = df[1] == date
    return df.loc[matches]
def get_prediction(row, endpoint_name='sagemaker-xgboost-2023-02-10-04-30-05-328'):
    """Invoke a SageMaker endpoint on a single-row feature frame.

    Parameters
    ----------
    row : pandas.DataFrame
        Frame whose column 0 is dropped; the remaining columns are
        serialized as header-less CSV and sent as the request body.
    endpoint_name : str, optional
        Endpoint to invoke. Defaults to the deployed XGBoost model,
        preserving the original hard-coded behavior.

    Returns
    -------
    str
        The raw response body decoded as UTF-8.
    """
    df_row = pd.DataFrame(row)
    res = sagemaker.invoke_endpoint(
        EndpointName=endpoint_name,
        # Drop column 0 and send the rest as header-less CSV — the input
        # format the XGBoost container expects.
        Body=df_row.iloc[:, 1:].to_csv(index=False, header=False),
        ContentType='text/csv',
        # Fix: was Accept='Accept', which is not a valid media type;
        # request the prediction back as CSV text.
        Accept='text/csv',
    )
    prediction = res['Body'].read().decode('UTF-8')
    return prediction
def update_cell(date):
    """Zero out the timestamp cell for `date`'s row and return the model's
    prediction for that row.

    Column 1 (the parsed date) is replaced with 0 before invoking the
    endpoint, so the serialized feature vector is purely numeric.
    """
    # Fix: copy before mutating. `get_row` returns a slice of the
    # module-level `df`; assigning into it directly is chained assignment
    # (SettingWithCopyWarning) and risks mutating the shared frame.
    row = get_row(date).copy()
    row[1] = 0
    # A date may match several rows; invoke the model on the first only.
    return get_prediction(row.iloc[[0]])
""" Actual App Front-end """
st.title("Azure Predictive Maintenance Challenge")
st.subheader("Test Data")
date = st.selectbox("Which Date would you like to test?", tuple(options))
st.write(update_cell(date))