I'd like to preprocess my data by sending a string input through predictor.predict(data) and turning it into numerical embeddings, just as my train_input_fn does with vocab_processor.fit_transform before the data reaches my model_fn:
```python
import os

import numpy as np
import pandas as pd
import tensorflow as tf

def train_input_fn(training_dir, hyperparameters):
    return _input_fn(training_dir, 'meta_data_train.csv')

def _input_fn(training_dir, training_filename):
    training_set = pd.read_csv(os.path.join(training_dir, training_filename),
                               dtype={'Classification class name': object},
                               encoding='cp1252')
    global n_words
    # Prepare training and testing data
    data = training_set['Features']
    target = pd.Series(training_set['Labels'])
    if training_filename == 'meta_data_test.csv':
        # Reuse the vocabulary fitted during training so token ids match.
        vocab_processor = tf.contrib.learn.preprocessing.VocabularyProcessor.restore("s3://sagemaker-blah/vocab")
        data = np.array(list(vocab_processor.transform(data)))
        return tf.estimator.inputs.numpy_input_fn(
            x={INPUT_TENSOR_NAME: data},
            y=target,
            num_epochs=100,
            shuffle=False)()
    elif training_filename == 'meta_data_train.csv':
        # Fit a fresh vocabulary on the training text and persist it.
        vocab_processor = tf.contrib.learn.preprocessing.VocabularyProcessor(MAX_DOCUMENT_LENGTH)
        data = np.array(list(vocab_processor.fit_transform(data)))
        vocab_processor.save("s3://sagemaker-blah/vocab")
        n_words = len(vocab_processor.vocabulary_)
        return tf.estimator.inputs.numpy_input_fn(
            x={INPUT_TENSOR_NAME: data},
            y=target,
            batch_size=len(data),
            num_epochs=None,
            shuffle=True)()
```
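For reference, this is roughly how I call the endpoint from the client side (the deploy arguments here are just placeholders):

```python
# Sketch of the client-side call; instance type and count are illustrative.
predictor = estimator.deploy(initial_instance_count=1, instance_type='ml.m4.xlarge')
result = predictor.predict("some raw text to classify")
```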
The documentation says to do this through serving_input_fn, but I'm not sure how I can access and manipulate the data from my tensor using vocab_processor.transform. Here's my serving_input_fn for context:
```python
def serving_input_fn(hyperparameters):
    tensor = tf.placeholder(tf.int64, shape=[None, MAX_DOCUMENT_LENGTH])
    return build_raw_serving_input_receiver_fn({INPUT_TENSOR_NAME: tensor})()
```
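One variation I've considered (just a sketch, not something I've verified) is to declare a string placeholder so the endpoint accepts raw text, though the numeric transform would still have to happen somewhere before model_fn sees the data:

```python
def serving_input_fn(hyperparameters):
    # Sketch: accept raw strings instead of pre-computed int64 ids.
    # The vocab lookup would still need to happen in input_fn or in-graph.
    tensor = tf.placeholder(tf.string, shape=[None])
    return build_raw_serving_input_receiver_fn({INPUT_TENSOR_NAME: tensor})()
```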
I tried doing the preprocessing through an input_fn instead:
```python
import pickle

def input_fn(serialized_input, content_type):
    vocab_processor = tf.contrib.learn.preprocessing.VocabularyProcessor(MAX_DOCUMENT_LENGTH)
    deserialized_input = pickle.loads(serialized_input)
    deserialized_input = np.array(list(vocab_processor.fit_transform(deserialized_input)))
    return deserialized_input
```
Here, I got an error while deserializing:

```
KeyError: '['
```
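My guess is that the payload arriving at input_fn is actually JSON rather than a pickle ('[' looks like the first byte of a JSON list, and it isn't a valid pickle opcode), and that I should restore the vocabulary fitted at training time instead of fitting a new one. A sketch of what I have in mind (the local vocab path is made up):

```python
import json

import numpy as np
import tensorflow as tf

def input_fn(serialized_input, content_type):
    # Assume the request body is JSON, e.g. '["some raw text"]'.
    deserialized_input = json.loads(serialized_input)
    # Restore the training-time vocabulary so token ids match the model;
    # '/opt/ml/model/vocab' is a hypothetical path inside the container.
    vocab_processor = tf.contrib.learn.preprocessing.VocabularyProcessor.restore('/opt/ml/model/vocab')
    return np.array(list(vocab_processor.transform(deserialized_input)))
```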
What would be the best method to preprocess the data?