How to find the cumulative sum / sum for each group using DataFrame in Spark [Python version]

This article shows how to calculate a cumulative sum, grouped by one column and sorted by another, using the Window functions of the DataFrame API in the Python version of Spark.

I worked this out while reading the official Python API documentation, so there may be a better way. The Spark version used is 1.5.2.

Sample data

Prepare the test data in the PostgreSQL table and load it into pyspark as a DataFrame.

$ SPARK_CLASSPATH=postgresql-9.4-1202.jdbc41.jar PYSPARK_DRIVER_PYTHON=ipython pyspark
In [1]: df = sqlContext.read.format('jdbc').options(url='jdbc:postgresql://localhost:5432/postgres?user=postgres', dbtable='').load()
In [2]: df.printSchema()
 |-- a: integer (nullable = true)
 |-- b: timestamp (nullable = true)
 |-- c: integer (nullable = true)
In [4]: df.show()
|  a|                   b|  c|
|  1|2015-11-22 10:00:...|  1|
|  1|2015-11-22 10:10:...|  2|
|  1|2015-11-22 10:20:...|  3|
|  1|2015-11-22 10:30:...|  4|
|  1|2015-11-22 10:40:...|  5|
|  1|2015-11-22 10:50:...|  6|
|  1|2015-11-22 11:00:...|  7|
|  1|2015-11-22 11:10:...|  8|
|  1|2015-11-22 11:20:...|  9|
|  1|2015-11-22 11:30:...| 10|
|  1|2015-11-22 11:40:...| 11|
|  1|2015-11-22 11:50:...| 12|
|  1|2015-11-22 12:00:...| 13|
|  2|2015-11-22 10:00:...|  1|
|  2|2015-11-22 10:10:...|  2|
|  2|2015-11-22 10:20:...|  3|
|  2|2015-11-22 10:30:...|  4|
|  2|2015-11-22 10:40:...|  5|
|  2|2015-11-22 10:50:...|  6|
|  2|2015-11-22 11:00:...|  7|
only showing top 20 rows

Column a is used for grouping, column b for sorting, and column c for the calculation.

Cumulative sum for each column group

Group by column a, sort by column b, and take the cumulative sum of column c.

First, define the Window.

In [5]: import sys

In [6]: from pyspark.sql.window import Window

In [7]: from pyspark.sql import functions as func

In [8]: window = Window.partitionBy(df.a).orderBy(df.b).rangeBetween(-sys.maxsize, 0)

In [9]: window
Out[9]: <pyspark.sql.window.WindowSpec at 0x18368d0>
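One point worth knowing about a range-based frame such as the one built with rangeBetween above: rows that share the same sort key are treated as peers and fall into the frame together, whereas a row-based frame advances one physical row at a time. The following is a minimal plain-Python sketch of that difference, not Spark code; the sample rows and both function names are hypothetical and exist only for illustration.

```python
from itertools import groupby

# Hypothetical rows for a single partition: (order_key, value).
# Two rows share the order key 2, so they are peers in a range frame.
rows = [(1, 10), (2, 20), (2, 30), (3, 40)]


def running_sum_rows(rows):
    """ROWS-style frame: sum from the first row up to the current physical row."""
    total, out = 0, []
    for _, value in sorted(rows, key=lambda r: r[0]):
        total += value
        out.append(total)
    return out


def running_sum_range(rows):
    """RANGE-style frame: sum over every row whose key is <= the current key."""
    total, out = 0, []
    ordered = sorted(rows, key=lambda r: r[0])
    for _, peers in groupby(ordered, key=lambda r: r[0]):
        peers = list(peers)
        # All peers are added at once and share the same running total.
        total += sum(value for _, value in peers)
        out.extend([total] * len(peers))
    return out


rows_result = running_sum_rows(rows)
range_result = running_sum_range(rows)
print(rows_result)   # [10, 30, 60, 100]
print(range_result)  # [10, 60, 60, 100]
```

In the sample data of this article the values of column b are unique within each group, so the two kinds of frame would give the same cumulative sums.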

Create a Column that computes pyspark.sql.functions.sum() over this window.

In [10]: cum_c = func.sum(df.c).over(window)

In [11]: cum_c
Out[11]: Column<'sum(c) WindowSpecDefinition UnspecifiedFrame>

Create a new DataFrame with this Column attached to the original DataFrame

In [12]: mod_df = df.withColumn("cum_c", cum_c)

In [13]: mod_df
Out[13]: DataFrame[a: int, b: timestamp, c: int, cum_c: bigint]

In [14]: mod_df.printSchema()
 |-- a: integer (nullable = true)
 |-- b: timestamp (nullable = true)
 |-- c: integer (nullable = true)
 |-- cum_c: long (nullable = true)

In [15]: mod_df.show()
|  a|                   b|  c|cum_c|
|  1|2015-11-22 10:00:...|  1|    1|
|  1|2015-11-22 10:10:...|  2|    3|
|  1|2015-11-22 10:20:...|  3|    6|
|  1|2015-11-22 10:30:...|  4|   10|
|  1|2015-11-22 10:40:...|  5|   15|
|  1|2015-11-22 10:50:...|  6|   21|
|  1|2015-11-22 11:00:...|  7|   28|
|  1|2015-11-22 11:10:...|  8|   36|
|  1|2015-11-22 11:20:...|  9|   45|
|  1|2015-11-22 11:30:...| 10|   55|
|  1|2015-11-22 11:40:...| 11|   66|
|  1|2015-11-22 11:50:...| 12|   78|
|  1|2015-11-22 12:00:...| 13|   91|
|  2|2015-11-22 10:00:...|  1|    1|
|  2|2015-11-22 10:10:...|  2|    3|
|  2|2015-11-22 10:20:...|  3|    6|
|  2|2015-11-22 10:30:...|  4|   10|
|  2|2015-11-22 10:40:...|  5|   15|
|  2|2015-11-22 10:50:...|  6|   21|
|  2|2015-11-22 11:00:...|  7|   28|
only showing top 20 rows

The cumulative sum has been calculated for each group.
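As a cross-check of what the window computes, the same "group by a, sort by b, accumulate c" logic can be sketched in plain Python. This is only an illustration of the semantics on a hypothetical miniature of the sample table (column b is a plain string here), not a replacement for the Spark code; the function name is made up.

```python
from itertools import accumulate, groupby
from operator import itemgetter

# Hypothetical miniature of the sample table: (a, b, c), deliberately unsorted.
rows = [
    (2, "10:10", 2), (1, "10:20", 3), (1, "10:00", 1),
    (2, "10:00", 1), (1, "10:10", 2),
]


def cumulative_sum_by_group(rows):
    """Return (a, b, c, cum_c) rows, accumulating c within each a, ordered by b."""
    ordered = sorted(rows, key=itemgetter(0, 1))  # sort by a, then by b
    result = []
    for _, group in groupby(ordered, key=itemgetter(0)):
        group = list(group)
        running = accumulate(c for _, _, c in group)
        result.extend((a, b, c, cum) for (a, b, c), cum in zip(group, running))
    return result


cum_rows = cumulative_sum_by_group(rows)
for row in cum_rows:
    print(row)
```

The running total restarts at the first row of each group, which is the same pattern as the cum_c column in the output above.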

Sum for each column group

Now calculate the sum of column c for each group of column a. Convert the DataFrame to pyspark.sql.GroupedData with groupBy() and use pyspark.sql.GroupedData.sum(). Confusingly, this is yet another sum(); note that this one raises an error if you pass a Column object as an argument, so pass the column name as a string.

In [25]: sum_c_df = df.groupBy('a').sum('c')

Also, unlike the Window function above, the result here is a DataFrame rather than a Column. Moreover, the name of the column that stores the sum is assigned automatically.

In [26]: sum_c_df
Out[26]: DataFrame[a: int, sum(c): bigint]

Well, it's complicated.

For the time being, attach it as a column to the original DataFrame.

In [27]: mod_df3 = mod_df.join(sum_c_df, 'a')

In [28]: mod_df3.printSchema()
 |-- a: integer (nullable = true)
 |-- b: timestamp (nullable = true)
 |-- c: integer (nullable = true)
 |-- cum_c: long (nullable = true)
 |-- sum(c): long (nullable = true)

In [29]: mod_df3.show()
|  a|                   b|  c|  cum_c|sum(c)|
|  1|2015-11-22 10:00:...|  1|      1|    91|
|  1|2015-11-22 10:10:...|  2|      3|    91|
|  1|2015-11-22 10:20:...|  3|      6|    91|
|  1|2015-11-22 10:30:...|  4|     10|    91|
|  1|2015-11-22 10:40:...|  5|     15|    91|
|  1|2015-11-22 10:50:...|  6|     21|    91|
|  1|2015-11-22 11:00:...|  7|     28|    91|
|  1|2015-11-22 11:10:...|  8|     36|    91|
|  1|2015-11-22 11:20:...|  9|     45|    91|
|  1|2015-11-22 11:30:...| 10|     55|    91|
|  1|2015-11-22 11:40:...| 11|     66|    91|
|  1|2015-11-22 11:50:...| 12|     78|    91|
|  1|2015-11-22 12:00:...| 13|     91|    91|
|  2|2015-11-22 10:00:...|  1|      1|    91|
|  2|2015-11-22 10:10:...|  2|      3|    91|
|  2|2015-11-22 10:20:...|  3|      6|    91|
|  2|2015-11-22 10:30:...|  4|     10|    91|
|  2|2015-11-22 10:40:...|  5|     15|    91|
|  2|2015-11-22 10:50:...|  6|     21|    91|
|  2|2015-11-22 11:00:...|  7|     28|    91|
only showing top 20 rows

You have successfully calculated the sum for each group.
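The aggregate-then-join pattern used here can be pictured with a small plain-Python sketch: first reduce each group to a single total, then look that total up for every original row. The sample rows are hypothetical and the code only illustrates the idea, not Spark's implementation.

```python
from collections import defaultdict

# Hypothetical miniature of the sample table: (a, b, c).
rows = [(1, "10:00", 1), (1, "10:10", 2), (2, "10:00", 5)]

# Step 1: aggregate, one total of c per value of a (the role of groupBy('a').sum('c')).
totals = defaultdict(int)
for a, _, c in rows:
    totals[a] += c

# Step 2: join, attaching the group total to every original row by its key a.
joined = [(a, b, c, totals[a]) for a, b, c in rows]

for row in joined:
    print(row)
```

Every row of a group carries the same total, which is why the sum(c) column repeats one value per group in the output above.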

Sum minus cumulative sum for each column group

Now let's calculate, for column c, how much remains until the group total is reached. That is, the sum minus the cumulative sum.

In [30]: diff_sum_c = mod_df3['sum(c)'] - mod_df3['cum_c']

In [31]: mod_df4 = mod_df3.withColumn("diff_sum_c", diff_sum_c)

In [34]: mod_df4.show()
|  a|                   b|  c|  cum_c|sum(c)|diff_sum_c|
|  1|2015-11-22 10:00:...|  1|      1|    91|        90|
|  1|2015-11-22 10:10:...|  2|      3|    91|        88|
|  1|2015-11-22 10:20:...|  3|      6|    91|        85|
|  1|2015-11-22 10:30:...|  4|     10|    91|        81|
|  1|2015-11-22 10:40:...|  5|     15|    91|        76|
|  1|2015-11-22 10:50:...|  6|     21|    91|        70|
|  1|2015-11-22 11:00:...|  7|     28|    91|        63|
|  1|2015-11-22 11:10:...|  8|     36|    91|        55|
|  1|2015-11-22 11:20:...|  9|     45|    91|        46|
|  1|2015-11-22 11:30:...| 10|     55|    91|        36|
|  1|2015-11-22 11:40:...| 11|     66|    91|        25|
|  1|2015-11-22 11:50:...| 12|     78|    91|        13|
|  1|2015-11-22 12:00:...| 13|     91|    91|         0|
|  2|2015-11-22 10:00:...|  1|      1|    91|        90|
|  2|2015-11-22 10:10:...|  2|      3|    91|        88|
|  2|2015-11-22 10:20:...|  3|      6|    91|        85|
|  2|2015-11-22 10:30:...|  4|     10|    91|        81|
|  2|2015-11-22 10:40:...|  5|     15|    91|        76|
|  2|2015-11-22 10:50:...|  6|     21|    91|        70|
|  2|2015-11-22 11:00:...|  7|     28|    91|        63|
only showing top 20 rows
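The arithmetic in the diff_sum_c column can be checked with a short plain-Python sketch. It assumes one group whose c values run from 1 to 13, as in the sample data, and uses a simple integer in place of the timestamp column b; the function name is hypothetical.

```python
def remaining_by_group(rows):
    """Return (a, b, c, remaining) where remaining = group total - running total."""
    # First pass: total of c per group a.
    totals = {}
    for a, _, c in rows:
        totals[a] = totals.get(a, 0) + c

    # Second pass: running total per group in (a, b) order, subtracted from the total.
    running = {}
    out = []
    for a, b, c in sorted(rows):
        running[a] = running.get(a, 0) + c
        out.append((a, b, c, totals[a] - running[a]))
    return out


# One group (a = 1) with c = 1..13; b is just the position in the sort order.
sample = [(1, i, i) for i in range(1, 14)]
remaining = [r[3] for r in remaining_by_group(sample)]
print(remaining)
```

The list starts at 90, 88, 85 and ends at 0, matching the diff_sum_c values shown for the first group.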


One thing I noticed this time: using SPARK_CLASSPATH is deprecated in Spark 1.0 and above. When I started pyspark, I got the following message.

15/11/22 12:32:44 WARN spark.SparkConf: 
SPARK_CLASSPATH was detected (set to 'postgresql-9.4-1202.jdbc41.jar').
This is deprecated in Spark 1.0+.

Please instead use:
 - ./spark-submit with --driver-class-path to augment the driver classpath
 - spark.executor.extraClassPath to augment the executor classpath

Apparently, in a cluster this environment variable is not propagated correctly to the other servers, so the settings named in the message are recommended instead.

Hmm. I need to understand the difference between local and distributed environments.