PySpark SQL Functions | collect_set method
PySpark SQL Functions' collect_set(~) method returns the set of unique values in a column. Null values are ignored.
Use collect_list(~) instead to obtain a list of values that allows for duplicates.
Parameters
1. col | string or Column object
The column label or a Column object.
Return Value
A PySpark SQL Column object (pyspark.sql.column.Column).
Note that the order of the returned set is non-deterministic since it is affected by shuffle operations.
Examples
Consider the following PySpark DataFrame:
data = [("Alex", "A"), ("Alex", "B"), ("Bob", "A"), ("Cathy", "C"), ("Dave", None)]
+-----+-----+
| name|group|
+-----+-----+
| Alex|    A|
| Alex|    B|
|  Bob|    A|
|Cathy|    C|
| Dave| null|
+-----+-----+
Getting a set of column values in PySpark
To get the unique set of values in the group column:
Equivalently, you can pass in a Column object to collect_set(~) as well:
Notice how the null value does not appear in the resulting set.
Getting the set as a standard list
To get the set as a standard list:
Here, the PySpark DataFrame's collect() method returns a list of Row objects. Since collect_set(~) aggregates the entire column into a single value, this list is guaranteed to have length one. The Row object itself holds the list of unique values, so we need a second [0] to extract it.
Getting a set of column values of each group in PySpark
The method collect_set(~) is often used in the context of aggregation. Consider the same PySpark DataFrame as before:
+-----+-----+
| name|group|
+-----+-----+
| Alex|    A|
| Alex|    B|
|  Bob|    A|
|Cathy|    C|
| Dave| null|
+-----+-----+
To flatten the group column into a single set for each name: