-
Notifications
You must be signed in to change notification settings - Fork 0
/
wp_predict.py
133 lines (116 loc) · 5.55 KB
/
wp_predict.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
import tensorflow as tf
import pickle
import json
import pandas as pd
import numpy as np
from sklearn.externals import joblib
from mongoengine import connect
from models import DealW2v
from models import PosData
from models import WepickDeal
from wp_rnn_37 import wp_rnn_classifier_fn
from tf_wals_lib import wals,wals_cate
import wprecservice_pb2
import wprecservice_pb2_grpc
# Register the mongoengine connection to the 'wprec' database at import time.
# Presumably the document classes imported from `models` query through this
# default connection -- confirm against models.py.
connect('wprec',host='mongodb://10.102.61.251:27017')
# Temporary placeholder dates; these values should eventually come from user
# input. NOTE(review): none of the service handlers visible in this file read
# these constants -- they take their dates from the incoming request instead.
HISTORY_FROM='04-01'
HISTORY_TO='04-10'
PREDICT_DATE='04-11'
class WpRecService(wprecservice_pb2_grpc.WpRecServiceServicer):
    """gRPC servicer exposing the deal-recommendation endpoints.

    Response convention shared by every handler in this class: ``error=0``
    is returned when the request cannot be served, ``error=-1`` when the
    handler produced a result.
    """

    # Method names GetRecommend can actually score. The remaining names the
    # service knows about ('dnn_tf', 'alibaba_din', 'logistic_tf',
    # 'boosted_tree_tf') are unimplemented stubs.
    _SUPPORTED_METHODS = ('gbc', 'logistic', 'rnn', 'rnn_bi')
    # Pickled classifier file for each method scored through predict_proba.
    _PICKLED_MODELS = {'gbc': 'wpgbc.pkl', 'logistic': 'wplr.pkl'}
    # Number of most-recent history entries kept per user; every sequence
    # row is this many slots plus one for the candidate deal.
    _SEQ_LEN = 37

    def GetRecommend(self, request, context):
        """Score the candidate deals for ``request.user`` with one model.

        Args:
            request: message providing ``methodName``, ``dayFrom``, ``dayTo``,
                ``predictMoment`` and ``user``.
            context: gRPC servicer context (unused).

        Returns:
            ``RecommendResponse(error=-1, result=[...])`` with one
            ``{'id', 'slot', 'score'}`` entry per scorable deal, or
            ``RecommendResponse(error=0)`` when a required field is empty,
            the method is not implemented, or the user has no history or
            profile row.
        """
        method = request.methodName
        if not (method and request.dayFrom and request.dayTo
                and request.predictMoment):
            return wprecservice_pb2.RecommendResponse(error=0)
        # Reject stub/unknown methods up front; previously they ran the whole
        # pipeline and then crashed on an unbound `probs`.
        if method not in self._SUPPORTED_METHODS:
            return wprecservice_pb2.RecommendResponse(error=0)

        # NOTE(review): the file names below are built from request fields;
        # confirm callers are trusted or validate the date format.
        period = request.dayFrom + '_' + request.dayTo
        user_seq = self._load_user_history('wp_' + period + '_seq.json',
                                           request.user)
        if user_seq is None:
            # User absent from the sequence file (used to raise on the
            # unbound `user_seq`).
            return wprecservice_pb2.RecommendResponse(error=0)

        profile_df = pd.read_csv('profile_' + period + '.csv', index_col=0)
        try:
            user_profile = profile_df.loc[request.user].tolist()
        except KeyError:
            # No profile row for this user.
            return wprecservice_pb2.RecommendResponse(error=0)

        deals = WepickDeal.objects(pk__gte=request.predictMoment + ' 20',
                                   pk__lte=request.predictMoment + ' 99')
        deal_slots = []
        deal_ids = []
        predict_input = []
        predict_seq_input = []
        # Zero padding that brings every sequence row to _SEQ_LEN + 1 entries.
        padding = [0] * (self._SEQ_LEN - len(user_seq))
        for slot_doc in deals:
            deal_id = slot_doc['deal'].id
            deal = DealW2v.objects(pk=deal_id).first()
            if deal is None:
                # No word vector stored for this deal, so it cannot be scored.
                continue
            predict_input.append(user_profile + deal.vectorizedWords)
            # The slot number is the last two characters of the document id.
            deal_slots.append(int(slot_doc.id[-2:]))
            deal_ids.append(deal_id)
            predict_seq_input.append(user_seq + [deal_id] + padding)

        if not predict_input:
            # Nothing to score: report success with an empty list instead of
            # pushing an empty batch into the scaler / estimator.
            return wprecservice_pb2.RecommendResponse(error=-1, result=[])

        if method in self._PICKLED_MODELS:
            scaler = joblib.load('scaler.pkl')
            features = scaler.transform(predict_input)
            model = joblib.load(self._PICKLED_MODELS[method])
            # Column 1 of predict_proba -- presumably the positive class.
            probs = model.predict_proba(features)[:, 1]
        else:
            seq_lens = [len(user_seq) + 1] * len(predict_seq_input)
            probs = self._predict_rnn(request, predict_seq_input, seq_lens)

        results = [
            {'id': deal_id, 'slot': slot, 'score': score}
            for deal_id, slot, score in zip(deal_ids, deal_slots, probs)
        ]
        return wprecservice_pb2.RecommendResponse(error=-1, result=results)

    def GetMfRecommend(self, request, context):
        """Return the recommendations computed by ``wals`` for one user.

        Zero-valued ``dimension``/``weight``/``coef``/``nIter`` fields are
        replaced by defaults on the request before ``wals`` is called.

        Returns:
            ``RecommendResponse(error=0)`` when ``dayFrom``, ``dayTo`` or
            ``predictMoment`` is empty, ``user`` is 0, or ``wals`` returns
            -1; otherwise ``RecommendResponse(error=-1, result=...)`` with
            the value ``wals`` returned.
        """
        if (request.dayFrom == '' or request.dayTo == ''
                or request.predictMoment == '' or request.user == 0):
            return wprecservice_pb2.RecommendResponse(error=0)
        self._apply_mf_defaults(request)
        results = wals(request.user, request.dayFrom, request.dayTo,
                       request.predictMoment, request.dimension,
                       request.weight, request.coef, request.nIter)
        # wals uses -1 as its failure sentinel.
        if results == -1:
            return wprecservice_pb2.RecommendResponse(error=0)
        return wprecservice_pb2.RecommendResponse(error=-1, result=results)

    def GetMfCateRecommend(self, request, context):
        """Return the three values computed by ``wals_cate`` for a date range.

        Zero-valued ``dimension``/``weight``/``coef``/``nIter`` fields are
        replaced by defaults on the request before ``wals_cate`` is called.

        Returns:
            ``MfDataResponse(error=0)`` when ``dayFrom`` or ``dayTo`` is
            empty; otherwise ``MfDataResponse(error=-1, ...)`` with the
            values from ``wals_cate`` in ``numFeatures``, ``users``, ``items``.
        """
        if request.dayFrom == '' or request.dayTo == '':
            return wprecservice_pb2.MfDataResponse(error=0)
        self._apply_mf_defaults(request)
        dimension, row, col = wals_cate(request.dayFrom, request.dayTo,
                                        request.dimension, request.weight,
                                        request.coef, request.nIter)
        return wprecservice_pb2.MfDataResponse(
            error=-1, numFeatures=dimension, users=row, items=col)

    @classmethod
    def _load_user_history(cls, seq_path, user):
        """Return the last ``_SEQ_LEN`` 'pos' entries of *user*, or None.

        Reads the JSON list at *seq_path* and uses the first entry whose
        'id' equals *user*; None means no entry matched.
        """
        with open(seq_path, 'r') as f:
            seq_data = json.load(f)
        for entry in seq_data:
            if entry['id'] == user:
                # Slicing already returns the whole list when it is shorter
                # than _SEQ_LEN, so no explicit length check is needed.
                return entry['pos'][-cls._SEQ_LEN:]
        return None

    @staticmethod
    def _predict_rnn(request, seq_input, seq_lens):
        """Run the (bi)directional RNN estimator and return its 'prob' outputs."""
        bidirectional = request.methodName == 'rnn_bi'
        model_path = './seq_bi_models' if bidirectional else './seq_models'
        deal_list = np.load('dict_' + request.dayFrom + '_'
                            + request.predictMoment[5:-3] + '.npy')
        # Row 0 is an all-zero vector of width 100 (presumably the embedding
        # for the 0 padding value -- confirm against wp_rnn_37); the other
        # rows are the stored word vectors of deal_list[1:].
        deal_dict = np.array(
            [[0.0] * 100]
            + [DealW2v.objects(pk=deal_id).first().vectorizedWords
               for deal_id in deal_list[1:]])
        predict_input_fn = tf.estimator.inputs.numpy_input_fn(
            {'seq': np.array(seq_input), 'seq_len': np.array(seq_lens)},
            shuffle=False)
        rnn_predictor = tf.estimator.Estimator(
            wp_rnn_classifier_fn, model_path,
            params={
                'dict': deal_dict,
                'rnn_depth': 3,
                'bidirectional': bidirectional,
                'use_dropout': True,
                'dropout_input_keep': 0.9,
                'dropout_output_keep': 0.9,
            })
        return [row['prob'] for row in rnn_predictor.predict(predict_input_fn)]

    @staticmethod
    def _apply_mf_defaults(request):
        """Fill zero-valued hyper-parameter fields of *request* in place."""
        if request.dimension == 0:
            request.dimension = 30
        if request.weight == 0.0:
            request.weight = 0.5
        if request.coef == 0.0:
            request.coef = 2.0
        if request.nIter == 0:
            request.nIter = 30