Iterating over each row of a PySpark DataFrame

Last updated: Jul 1, 2022
Tags: PySpark

Iterating over a PySpark DataFrame is tricky because of its distributed nature - the data of a PySpark DataFrame is typically scattered across multiple worker nodes. This guide explores three solutions for iterating over each row, but I recommend opting for the first solution!

Using the map method of RDD to iterate over the rows of PySpark DataFrame

All Spark DataFrames are internally represented using Spark's built-in data structure called RDD (resilient distributed dataset). One way of iterating over the rows of a PySpark DataFrame is to use the map(~) function, which is available only to RDDs - we therefore need to convert the PySpark DataFrame into an RDD first.

As an example, consider the following PySpark DataFrame:

df = spark.createDataFrame([("Alex", 15), ("Bob", 20), ("Cathy", 25)], ["name", "age"])
df.show()
+-----+---+
| name|age|
+-----+---+
| Alex| 15|
| Bob| 20|
|Cathy| 25|
+-----+---+

We can iterate over each row of this PySpark DataFrame like so:

from pyspark.sql import Row

def my_func(row):
    d = row.asDict()
    d.update({'name': d['name'].upper()})
    updated_row = Row(**d)
    return updated_row

rdd = df.rdd.map(my_func)
rdd.toDF().show()
+-----+---+
| name|age|
+-----+---+
| ALEX| 15|
| BOB| 20|
|CATHY| 25|
+-----+---+

Here, note the following:

  • the conversion from PySpark DataFrame to RDD is simple - df.rdd.

  • we then use the map(~) method of the RDD, which takes a function as its argument. This function receives a single Row object and is invoked once for each row of the PySpark DataFrame.

  • in the first line of our custom function my_func(~), we convert the Row into a dictionary using asDict(). The reason for this is that we cannot mutate the Row object directly - and so we must convert the Row object into a dictionary, then perform an update on the dictionary, and then finally convert the updated dictionary back to a Row object.

  • the ** in Row(**d) converts the dictionary d into keyword arguments for the Row(~) constructor.

NOTE

Unlike the other solutions that will be discussed below, this solution allows us to update the values of each row while we iterate over the rows.

Using the collect method and then iterating in the driver node

Another solution is to use the collect(~) method to push all the data from the worker nodes to the driver program, and then iterate over the rows.

As an example, consider the following PySpark DataFrame:

df = spark.createDataFrame([["Alex", 20], ["Bob", 30], ["Cathy", 40]], ["name", "age"])
df.show()
+-----+---+
| name|age|
+-----+---+
| Alex| 20|
| Bob| 30|
|Cathy| 40|
+-----+---+

We can use the collect(~) method to first send all the data from the worker nodes to the driver program, and then perform a simple for-loop:

for row in df.collect():
    print(row.name)
Alex
Bob
Cathy
WARNING

Watch out for the following limitations:

  • since the collect(~) method will send all the data to the driver node, make sure that your driver node has enough memory to avoid an out-of-memory error.

  • we cannot update the value of the rows while we iterate.

Using foreach to iterate over the rows in the worker nodes

The foreach(~) method instructs the worker nodes in the cluster to iterate over each row (as a Row object) of a PySpark DataFrame, applying a function to each row on the worker node that hosts it:

# This function fires in the worker node
def f(row):
    print(row.name)

df.foreach(f)

Here, the printed results will only be displayed in the standard output of the worker node instead of the driver program.

WARNING

The following are some hard limitations of foreach(~) imposed by Spark:

  • the row is read-only. This means that you cannot update the row values while iterating.

  • since the worker nodes are performing the iteration and not the driver program, standard output/error will not be shown in our session/notebook. For instance, performing a print(~) as we have done in our function will not display the printed results in our session/notebook - instead we would need to check the log of the worker nodes.

Given these limitations, one of the main use cases of foreach(~) is to log the rows of the PySpark DataFrame - either to a file or to an external database.

Published by Isshin Inada