Apache Spark Unit Testing

Apache Spark is a lightning-fast unified analytics engine for big data and machine learning. It is included in almost all Hadoop distributions. Spark is one of the most active open source projects in big data and is often preferred over Hadoop MapReduce for big data processing. Thanks to its in-memory computation engine, Spark is much faster than Hadoop MapReduce for many workloads.

Spark ETL and big data analytics pipelines can become fairly complex, comprising multiple transformations and actions. Without a proper unit testing mechanism, developing Spark applications can be challenging. This post shows how to use pytest to unit test a Spark application.

Spark Application

To demonstrate unit testing, we will use the publicly available Airline On-Time Performance dataset and build a Spark analytics job around it. The data used in the demo is published here.

The Spark application produces a list of airlines ordered by the number of times their flights arrived late. We will read the arrival data from a CSV file and write the output to another CSV file. We will also use a second dataset that maps each carrier code to the airline name. The carrier code descriptions can be found here.

The analytics pipeline therefore consists of two major stages:

  1. Order the airlines in descending order of the number of times their flights arrived late
  2. Enrich the result with the airline name for each unique carrier code

The corresponding Spark pipeline looks like this:

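from argparse import Namespace

from pyspark.sql import SparkSession

# load_data, find_the_late_flights, enrich_with_flight_desc and write_data are
# helper functions defined in the application's source code.


def execute(spark: SparkSession, args: Namespace) -> None:
    # args is read via attribute access, so it is an argparse Namespace, not a dict
    arrival_df = load_data(spark, args.arrival_data_file_location)

    # Stage 1: carriers ordered by the number of late arrivals
    late_flights_df = find_the_late_flights(arrival_df)

    # Stage 2: enrich with the airline name for each carrier code
    carrier_desc_df = load_data(spark, args.carrier_description_file_location)
    late_flights_enriched_df = enrich_with_flight_desc(late_flights_df, carrier_desc_df)

    write_data(df=late_flights_enriched_df, destination_path="/tmp/devrats/output")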
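The execute function reads attributes such as args.arrival_data_file_location, which suggests args comes from argparse. A minimal sketch of how it might be built follows; parse_args and the flag names are hypothetical and simply mirror the attribute names used in execute.

```python
import argparse
from typing import List, Optional


def parse_args(argv: Optional[List[str]] = None) -> argparse.Namespace:
    # Hypothetical CLI: argparse converts --arrival-data-file-location into the
    # attribute args.arrival_data_file_location that execute() reads
    parser = argparse.ArgumentParser(description="Late-arrival airline analysis")
    parser.add_argument("--arrival-data-file-location", required=True,
                        help="CSV file with the on-time arrival data")
    parser.add_argument("--carrier-description-file-location", required=True,
                        help="CSV file mapping carrier codes to airline names")
    # argv=None makes argparse fall back to sys.argv[1:]
    return parser.parse_args(argv)
```

In an entry point, this could be combined with a SparkSession as execute(spark, parse_args()). Accepting an explicit argv list also makes the argument handling easy to unit test without a real command line.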


The Spark application loads DataFrames from the CSV files and finally writes the results to the output location. We will define unit test cases for the two transformation steps in the pipeline, “find_the_late_flights” and “enrich_with_flight_desc”.

Testing the Pipeline with Pytest

Unit testing for PySpark will be done using pytest, a full-featured Python testing framework that makes it easy to write unit tests. A SparkSession is needed to test Spark transformations, so we will use a pytest fixture with module scope to provide a Spark session wherever our unit tests need one.

Test functions receive fixture objects by naming them as input arguments. For each argument name, the fixture function with that name provides the fixture object. Fixture functions are registered by marking them with @pytest.fixture. The pytest fixture that provides the Spark session looks like the following:

import pytest
from pyspark.sql import SparkSession

@pytest.fixture(scope="module")
def spark():
    return SparkSession.builder.master("local[2]").getOrCreate()  # local mode, 2 threads


The fixture is defined in conftest.py. Fixtures defined in conftest.py can be shared across multiple test files, and you don't need to import them: pytest discovers them automatically.

The fixture returns the existing SparkSession or, if none exists, creates a new one. The Spark session used in these unit tests runs locally with 2 worker threads, as set by master("local[2]") in the builder.

The Unit Test Case

To test the methods “find_the_late_flights” and “enrich_with_flight_desc”, we need to define DataFrames and pass them to the methods under test. The code block below defines the test data, its schema and a DataFrame fixture.

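import pytest
from pyspark.sql import DataFrame, SparkSession
from pyspark.sql.types import IntegerType, StringType, StructField, StructType

# Sample rows: (UniqueCarrier, ArrDelay in minutes)
test_flight_arrival_data = [('BA104', 20),
                            ('AI101', 0)]

test_flight_arrival_schema = StructType([
    StructField("UniqueCarrier", StringType(), True),
    StructField("ArrDelay", IntegerType(), True)
])


@pytest.fixture
def flight_df(spark):
    # type: (SparkSession) -> DataFrame
    # Build a small test DataFrame with the shared spark fixture from conftest.py
    df = spark.createDataFrame(test_flight_arrival_data, test_flight_arrival_schema)
    return df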


The DataFrame is defined as a pytest fixture. It uses the spark fixture defined earlier to create the DataFrame from the test schema and tuples. The DataFrame fixture is then used as an input to the unit test case:

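# find_the_late_flights must be imported from the application module under test
def test_find_the_late_flights(flight_df):
    # pytest injects the flight_df fixture by matching the argument name
    late_flight_df = find_the_late_flights(flight_df)
    assert late_flight_df.count() == 2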


Running the unit tests is as simple as passing the test file to pytest:

$ pytest flight_analyser_tests.py

You can read more about pytest and Apache Spark at the links below:

You can find the source code in the GitHub repo here. Please leave a comment if you liked the post or have any suggestions. Thank you for reading. Cheers! 🙂

