Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ class Word2VecModel private[ml] (
* Returns a dataframe with two fields, "word" and "vector", with "word" being a String and
* and the vector the DenseVector that it is mapped to.
*/
val getVectors: DataFrame = {
@transient lazy val getVectors: DataFrame = {
val sc = SparkContext.getOrCreate()
val sqlContext = SQLContext.getOrCreate(sc)
import sqlContext.implicits._
Expand Down
40 changes: 40 additions & 0 deletions python/pyspark/ml/feature.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,16 @@
# limitations under the License.
#

import sys
if sys.version > '3':
basestring = str

from pyspark.rdd import ignore_unicode_prefix
from pyspark.ml.param.shared import *
from pyspark.ml.util import keyword_only
from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaTransformer
from pyspark.mllib.common import inherit_doc
from pyspark.mllib.linalg import _convert_to_vector

__all__ = ['Binarizer', 'HashingTF', 'IDF', 'IDFModel', 'NGram', 'Normalizer', 'OneHotEncoder',
'PolynomialExpansion', 'RegexTokenizer', 'StandardScaler', 'StandardScalerModel',
Expand Down Expand Up @@ -954,6 +959,23 @@ class Word2Vec(JavaEstimator, HasStepSize, HasMaxIter, HasSeed, HasInputCol, Has
>>> sent = ("a b " * 100 + "a c " * 10).split(" ")
>>> doc = sqlContext.createDataFrame([(sent,), (sent,)], ["sentence"])
>>> model = Word2Vec(vectorSize=5, seed=42, inputCol="sentence", outputCol="model").fit(doc)
>>> model.getVectors().show()
+----+--------------------+
|word| vector|
+----+--------------------+
| a|[-0.3511952459812...|
| b|[0.29077222943305...|
| c|[0.02315592765808...|
+----+--------------------+
...
>>> model.findSynonyms("a", 2).show()
+----+-------------------+
|word| similarity|
+----+-------------------+
| b|0.29255685145799626|
| c|-0.5414068302988307|
+----+-------------------+
...
>>> model.transform(doc).head().model
DenseVector([-0.0422, -0.5138, -0.2546, 0.6885, 0.276])
"""
Expand Down Expand Up @@ -1047,6 +1069,24 @@ class Word2VecModel(JavaModel):
Model fitted by Word2Vec.
"""

def getVectors(self):
"""
Returns the vector representation of the words as a dataframe
with two fields, word and vector.
"""
return self._call_java("getVectors")

def findSynonyms(self, word, num):
"""
Find "num" number of words closest in similarity to "word".
word can be a string or vector representation.
Returns a dataframe with two fields word and similarity (which
gives the cosine similarity).
"""
if not isinstance(word, basestring):
word = _convert_to_vector(word)
return self._call_java("findSynonyms", word, num)


@inherit_doc
class PCA(JavaEstimator, HasInputCol, HasOutputCol):
Expand Down