TaskContext
TaskContext is the contract for contextual information about a Task in Spark that allows for registering task listeners.
You can access the active TaskContext instance using the TaskContext.get method.
import org.apache.spark.TaskContext
val ctx = TaskContext.get
Using TaskContext, you can access local properties that were set by the driver.
Note: TaskContext is serializable.
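Spark's SparkContext keeps local properties per driver thread (in an InheritableThreadLocal) and ships a copy of them with the tasks of each submitted job. The toy model below sketches that idea so it is clear why getLocalProperty sees the value set before submission. ToyDriver, ToyTaskContext and snapshotForJob are hypothetical names. This is not Spark's code.

```scala
import java.util.Properties

// Hypothetical, simplified model (not Spark's classes): the driver keeps local
// properties per thread, and a snapshot of the submitting thread's properties
// is shipped with every task of a job.
object ToyDriver {
  // InheritableThreadLocal: threads spawned by a thread start with a copy of its properties
  private val localProperties = new InheritableThreadLocal[Properties] {
    override protected def initialValue(): Properties = new Properties()
    override protected def childValue(parent: Properties): Properties = copyOf(parent)
  }

  private def copyOf(props: Properties): Properties = {
    val copy = new Properties()
    copy.putAll(props)
    copy
  }

  def setLocalProperty(key: String, value: String): Unit =
    localProperties.get().setProperty(key, value)

  // Called at "job submission": later changes on the driver do not leak into shipped tasks
  def snapshotForJob(): Properties = copyOf(localProperties.get())
}

// Hypothetical task-side view: exposes only the properties shipped with the task
final class ToyTaskContext(props: Properties) {
  def getLocalProperty(key: String): String = props.getProperty(key) // null when not set
}

ToyDriver.setLocalProperty("spark.scheduler.pool", "reports")
val taskCtx = new ToyTaskContext(ToyDriver.snapshotForJob())
ToyDriver.setLocalProperty("spark.scheduler.pool", "adhoc") // changed after "submission"
```

In the sketch, the task keeps seeing "reports" even though the driver thread later switched to "adhoc". A missing key yields null, as with java.util.Properties.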
TaskContext Contract
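// Excerpt of org.apache.spark.TaskContext (see the table below for more members)
abstract class TaskContext extends Serializable {
  def stageId(): Int
  def partitionId(): Int
  def attemptNumber(): Int
  def taskAttemptId(): Long
  def getLocalProperty(key: String): String
}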
| Method | Description |
|---|---|
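| stageId | Id of the Stage the task belongs to. Used when… |
| partitionId | Id of the Partition computed by the task. Used when… |
| attemptNumber | Specifies how many times the task has been attempted (starting from 0). Used when… |
| taskAttemptId | Id unique to this task attempt (no two task attempts in the same SparkContext share it). Used when… |
| getMetricsSources | Gives all the metrics sources with the given sourceName that are associated with the instance running the task. |
| getLocalProperty | Accesses local properties set by the driver using SparkContext.setLocalProperty. Used when… |
| taskMetrics | TaskMetrics of the active Task. Used when… |
|  | Used when… |
|  | Used when… |
|  | Used when… |
| isInterrupted | A flag that is enabled when the task was killed. Used when… |
| addTaskCompletionListener | Registers a TaskCompletionListener. Used when… |
| addTaskFailureListener | Registers a TaskFailureListener. Used when… |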
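The table separates attemptNumber (how many times this task has been attempted, from 0) from taskAttemptId (an id unique to a single attempt). The toy scheduler below only illustrates how the four ids relate when one partition's task is retried. It uses hypothetical names (ToyScheduler, TaskInfo, runAttempts) and is not Spark's scheduler.

```scala
import java.util.concurrent.atomic.AtomicLong

// Hypothetical model of how the ids in the table relate (not Spark's scheduler code)
final case class TaskInfo(stageId: Int, partitionId: Int, attemptNumber: Int, taskAttemptId: Long)

final class ToyScheduler {
  // One counter for the whole "SparkContext": no two task attempts share an id
  private val nextTaskAttemptId = new AtomicLong(0L)

  // Launches `failedAttempts` failing attempts of a partition's task, then one successful attempt
  def runAttempts(stageId: Int, partitionId: Int, failedAttempts: Int): Seq[TaskInfo] =
    (0 to failedAttempts).map { attempt =>
      TaskInfo(stageId, partitionId, attempt, nextTaskAttemptId.getAndIncrement())
    }
}

val scheduler = new ToyScheduler
val p0 = scheduler.runAttempts(stageId = 0, partitionId = 0, failedAttempts = 0)
val p1 = scheduler.runAttempts(stageId = 0, partitionId = 1, failedAttempts = 2)
(p0 ++ p1).foreach(println)
```

In this model, attemptNumber restarts at 0 for every partition, while taskAttemptId keeps growing across all of them.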
unset Method
Caution: FIXME
setTaskContext Method
Caution: FIXME
Accessing Active TaskContext — get Method
get(): TaskContext
The get method returns the TaskContext instance of the active task (as a TaskContextImpl). There is at most one such instance per task thread, and a task can use it to access contextual information about itself. Outside a running task (e.g. on the driver), get returns null.
val rdd = sc.range(0, 3, numSlices = 3)
scala> rdd.partitions.size
res0: Int = 3
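rdd.foreach { n =>
  import org.apache.spark.TaskContext
  // TaskContext.get is called inside the closure, i.e. on the executor running the task
  val tc = TaskContext.get
  val msg = s"""|-------------------
                |partitionId: ${tc.partitionId}
                |stageId: ${tc.stageId}
                |attemptNum: ${tc.attemptNumber}
                |taskAttemptId: ${tc.taskAttemptId}
                |-------------------""".stripMargin
  // println writes to the executor's stdout (the console in local mode)
  println(msg)
}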
Note: The TaskContext object (the companion of the TaskContext class) uses a ThreadLocal to hold the active TaskContext, i.e. to associate it with the thread that runs the task.
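The ThreadLocal-based design can be sketched in a few lines. The helper names mirror the get, setTaskContext and unset methods mentioned in this section, but their bodies, and the MiniTaskContext and ActiveContext names, are assumptions for illustration only. Each thread sees only the context it set, and get returns null on a thread that runs no task.

```scala
import scala.collection.concurrent.TrieMap

// Hypothetical miniature of the pattern (illustrative, not Spark's source):
// an object holds the active context in a ThreadLocal, so each task thread sees its own.
final class MiniTaskContext(val partitionId: Int)

object ActiveContext {
  private val current = new ThreadLocal[MiniTaskContext]

  def get(): MiniTaskContext = current.get() // null when no task runs on this thread
  def setTaskContext(tc: MiniTaskContext): Unit = current.set(tc)
  def unset(): Unit = current.remove()
}

// Simulates an executor thread running the task for one partition
def runTask(partition: Int): Int = {
  ActiveContext.setTaskContext(new MiniTaskContext(partition))
  try ActiveContext.get().partitionId * 10 // the task body reads "its" context
  finally ActiveContext.unset()            // clear it before the thread is reused
}

val results = TrieMap.empty[Int, Int]
val threads = (0 until 3).map(p => new Thread(() => results.update(p, runTask(p))))
threads.foreach(_.start())
threads.foreach(_.join())
```

Clearing the ThreadLocal in finally matters because executor threads are pooled. Without it, a later task on the same thread could observe a stale context.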
Registering Task Listeners
Using a TaskContext, you can register task listeners that are executed either on task completion, regardless of the final state, or on task failure only.
addTaskCompletionListener Method
addTaskCompletionListener(listener: TaskCompletionListener): TaskContext
addTaskCompletionListener(f: (TaskContext) => Unit): TaskContext
The addTaskCompletionListener methods register a TaskCompletionListener to be executed when the task completes.
Note: The listener is executed regardless of the final state of the task: success, failure, or cancellation.
val rdd = sc.range(0, 5, numSlices = 1)
import org.apache.spark.TaskContext
val printTaskInfo = (tc: TaskContext) => {
  val msg = s"""|-------------------
                |partitionId: ${tc.partitionId}
                |stageId: ${tc.stageId}
                |attemptNum: ${tc.attemptNumber}
                |taskAttemptId: ${tc.taskAttemptId}
                |-------------------""".stripMargin
  println(msg)
}
rdd.foreachPartition { _ =>
  val tc = TaskContext.get
  tc.addTaskCompletionListener(printTaskInfo)
}
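To see "regardless of the final state" in isolation, here is a toy context that keeps completion listeners and runs them after the task body whether or not the body threw. ToyTaskContext and its run method are hypothetical and only model the behavior described above. The sketch runs listeners in registration order, which is a choice of this sketch and not a statement about Spark's ordering.

```scala
import scala.collection.mutable.ArrayBuffer
import scala.util.Try

// Hypothetical toy context (not Spark's TaskContextImpl) that keeps completion listeners
final class ToyTaskContext(val partitionId: Int) {
  private val completionListeners = ArrayBuffer.empty[ToyTaskContext => Unit]

  def addTaskCompletionListener(f: ToyTaskContext => Unit): ToyTaskContext = {
    completionListeners += f
    this
  }

  // Runs the task body, then the completion listeners whether or not the body threw
  def run[T](body: ToyTaskContext => T): Try[T] =
    try Try(body(this))
    finally completionListeners.foreach(listener => listener(this))
}

val seen = ArrayBuffer.empty[String]

val ok = new ToyTaskContext(0)
ok.addTaskCompletionListener(tc => seen += s"partition ${tc.partitionId} done")
val okResult = ok.run(_ => 1 + 1)

val bad = new ToyTaskContext(1)
bad.addTaskCompletionListener(tc => seen += s"partition ${tc.partitionId} done")
val badResult = bad.run[Int](_ => throw new IllegalStateException("boom"))
```

Both listeners fire: one after a successful body and one after a failing body.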
addTaskFailureListener Method
addTaskFailureListener(listener: TaskFailureListener): TaskContext
addTaskFailureListener(f: (TaskContext, Throwable) => Unit): TaskContext
The addTaskFailureListener methods register a TaskFailureListener to be executed only when the task fails. Because a failed task can be re-attempted, and each attempt runs the registration code again, such a listener can be executed multiple times for the same partition.
val rdd = sc.range(0, 2, numSlices = 2)
import org.apache.spark.TaskContext
val printTaskErrorInfo = (tc: TaskContext, error: Throwable) => {
  val msg = s"""|-------------------
                |partitionId: ${tc.partitionId}
                |stageId: ${tc.stageId}
                |attemptNum: ${tc.attemptNumber}
                |taskAttemptId: ${tc.taskAttemptId}
                |error: ${error.toString}
                |-------------------""".stripMargin
  println(msg)
}
val throwExceptionForOddNumber = (n: Long) => {
  if (n % 2 == 1) {
    throw new Exception(s"No way it will pass for odd number: $n")
  }
}
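// The failure listener is never called here: foreachPartition ignores the
// partition iterator, so throwExceptionForOddNumber never runs and the task succeeds.
rdd.map(throwExceptionForOddNumber).foreachPartition { _ =>
  val tc = TaskContext.get
  tc.addTaskFailureListener(printTaskErrorInfo)
}
// Listener registration matters: register the listener before the iterator is
// consumed (and the exception thrown), so it is in place when the task fails.
rdd.mapPartitions { (it: Iterator[Long]) =>
  val tc = TaskContext.get
  tc.addTaskFailureListener(printTaskErrorInfo)
  it
}.map(throwExceptionForOddNumber).count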
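The two points above can be modeled without Spark. Every attempt gets a fresh context, so a failure listener registered in the task body fires once per failed attempt. A listener registered after the failure point is never registered at all. AttemptContext and runWithRetries are hypothetical stand-ins for a task context and task re-attempts, not Spark APIs.

```scala
import scala.collection.mutable.ArrayBuffer
import scala.util.Try

// Hypothetical toy model: every attempt gets a fresh context with its own failure listeners
final class AttemptContext(val partitionId: Int, val attemptNumber: Int) {
  private val failureListeners = ArrayBuffer.empty[(AttemptContext, Throwable) => Unit]

  def addTaskFailureListener(f: (AttemptContext, Throwable) => Unit): AttemptContext = {
    failureListeners += f
    this
  }

  // Runs the body; failure listeners fire only if it threw
  def run[T](body: AttemptContext => T): Try[T] = {
    val result = Try(body(this))
    result.failed.foreach(e => failureListeners.foreach(listener => listener(this, e)))
    result
  }
}

// Re-attempts a partition's task up to maxAttempts times (stand-in for task retries)
def runWithRetries[T](partitionId: Int, maxAttempts: Int)(body: AttemptContext => T): Try[T] = {
  def loop(attempt: Int): Try[T] = {
    val result = new AttemptContext(partitionId, attempt).run(body)
    if (result.isFailure && attempt + 1 < maxAttempts) loop(attempt + 1) else result
  }
  loop(0)
}

// Registers the listener first, then fails on the first two attempts
val failures = ArrayBuffer.empty[String]
val outcome = runWithRetries(partitionId = 7, maxAttempts = 4) { tc =>
  tc.addTaskFailureListener((c, e) => failures += s"attempt ${c.attemptNumber}: ${e.getMessage}")
  if (tc.attemptNumber < 2) throw new RuntimeException("flaky") else "ok"
}

// Throws before registering, so no listener exists when the attempt fails
val late = ArrayBuffer.empty[String]
val lateOutcome = runWithRetries(partitionId = 9, maxAttempts = 1) { tc =>
  if (tc.partitionId % 2 == 1) throw new RuntimeException(s"odd partition ${tc.partitionId}")
  tc.addTaskFailureListener((_, e) => late += e.getMessage)
}
```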
(Unused) Accessing Partition Id — getPartitionId Method
getPartitionId(): Int
getPartitionId gets the active TaskContext and returns its partitionId, or 0 if no TaskContext is available.
Note: getPartitionId is not used.
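The fallback rule of getPartitionId can be expressed with a ThreadLocal lookup. Ctx, activeCtx and this getPartitionId are hypothetical stand-ins that only reproduce the behavior described above, not Spark's implementation.

```scala
// Hypothetical stand-ins: a context type and the ThreadLocal that holds the active one
final case class Ctx(partitionId: Int)
val activeCtx = new ThreadLocal[Ctx]

// Fallback rule as described: the active task's partitionId, or 0 without an active context
def getPartitionId(): Int = Option(activeCtx.get()).map(_.partitionId).getOrElse(0)

val outside = getPartitionId() // no context set on this thread
activeCtx.set(Ctx(5))
val inside = getPartitionId()
activeCtx.remove()
```

One consequence of this rule is that 0 is also a valid partition id. A caller therefore cannot tell "no active task" apart from "partition 0" by the return value alone.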