diff --git a/docs/structured-streaming-kafka-integration.md b/docs/structured-streaming-kafka-integration.md
index 016faa735acd6..58fa01edb80e8 100644
--- a/docs/structured-streaming-kafka-integration.md
+++ b/docs/structured-streaming-kafka-integration.md
@@ -440,9 +440,10 @@ The following configurations are optional:
| kafkaConsumer.pollTimeoutMs |
long |
- 512 |
+ 120000 |
streaming and batch |
- The timeout in milliseconds to poll data from Kafka in executors. |
+ The timeout in milliseconds to poll data from Kafka in executors. When not defined it falls
+ back to spark.network.timeout. |
| fetchOffset.numRetries |
diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/consumer/KafkaDataConsumer.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/consumer/KafkaDataConsumer.scala
index 5f23029d9fed3..f2bf7cd1360ec 100644
--- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/consumer/KafkaDataConsumer.scala
+++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/consumer/KafkaDataConsumer.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.kafka010.consumer
import java.{util => ju}
import java.io.Closeable
+import java.time.Duration
import java.util.concurrent.TimeoutException
import scala.collection.JavaConverters._
@@ -73,7 +74,7 @@ private[kafka010] class InternalKafkaConsumer(
// Seek to the offset because we may call seekToBeginning or seekToEnd before this.
seek(offset)
- val p = consumer.poll(pollTimeoutMs)
+ val p = consumer.poll(Duration.ofMillis(pollTimeoutMs))
val r = p.records(topicPartition)
logDebug(s"Polled $groupId ${p.partitions()} ${r.size}")
val offsetAfterPoll = consumer.position(topicPartition)
diff --git a/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/KafkaDataConsumer.scala b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/KafkaDataConsumer.scala
index 142e946188ace..09af5a0815147 100644
--- a/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/KafkaDataConsumer.scala
+++ b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/KafkaDataConsumer.scala
@@ -18,6 +18,7 @@
package org.apache.spark.streaming.kafka010
import java.{util => ju}
+import java.time.Duration
import scala.collection.JavaConverters._
@@ -203,7 +204,7 @@ private[kafka010] class InternalKafkaConsumer[K, V](
}
private def poll(timeout: Long): Unit = {
- val p = consumer.poll(timeout)
+ val p = consumer.poll(Duration.ofMillis(timeout))
val r = p.records(topicPartition)
logDebug(s"Polled ${p.partitions()} ${r.size}")
buffer = r.listIterator