Friday, October 3, 2025

Creating Integration Tests for Databricks

Integration Tests with Databricks

Overview

  • There are no special tools required for testing Databricks Notebooks. All that is required is Python, Pytest, and patience.

Strategy

  • Use mock data.
  • Use the Arrange, Act, Assert pattern.
  • Leave the system as you found it: perform a data cleanup before and after each test (see the sketch after this list).
  • Run the full suite of integration tests when deploying to an environment.
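
If you prefer to keep cleanup out of the test body, pytest can run it automatically before and after each test via an autouse fixture. This is a minimal sketch, assuming a _cleanup() helper like the one defined in the integration test later in this post (the full example below simply calls _cleanup() directly instead):

import pytest

@pytest.fixture(autouse=True)
def clean_environment():
    '''Run cleanup before and after every test in this module.'''
    _cleanup()   # leave the system as you found it (before the test)
    yield        # Arrange, Act, Assert happen here
    _cleanup()   # leave the system as you found it (after the test)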

Create a New Branch

Before you do anything else, create a new branch off your development branch (for example, main or develop).

Open with VS Code

Checkout the branch and open the directory with VS Code.

Create Folders

Integration Tests should be separated from Unit Tests. In VS Code, in your project, create a folder structure like this:

📁 project
    📁 tests
        📁 integration
        📁 unit
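
Later steps in this post add a test runner notebook, a shared common.py helper module, and the notebooks under test. One possible final layout (the placement of pytest_databricks.py and the notebooks folder is an assumption; adjust it to your project):

📁 project
    📁 notebooks
        notebook_to_test
    📁 tests
        📁 integration
            common.py
            test_example.py
        📁 unit
            test_example.py
    pytest_databricks.py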

Create example tests

Create an example passing test in the integration and unit folders. Name the file:

test_example.py

'''
This is an example unit test
'''
import logging
import pytest

logger = logging.getLogger()

# test data
MESSAGE = "Hello World"

def test_pass():
    '''Example pass scenario'''
    logger.info("This is an example unit test for a pass scenario")
    assert MESSAGE == "Hello World"
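
You can sanity-check the example tests locally before involving Databricks at all. A minimal sketch, assuming pytest is installed in your local environment and you run it from the project root:

import pytest

# Run everything under tests/ locally; a zero return code means all tests passed.
retcode = pytest.main(["tests", "-v"])
assert retcode == 0, "The local pytest run failed. See the output for details."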

Create a test runner

A Databricks Python notebook must be created to run the tests. It has to be a Python notebook so that any notebooks run from it will include debugging information.

pytest_databricks.py

# Databricks notebook source
# MAGIC %pip install pytest

# COMMAND ----------

import pytest
import sys

# Skip writing .pyc files, since the repo filesystem may be read-only.
sys.dont_write_bytecode = True

# Run all tests found under the notebook's directory; skip the pytest cache folder.
retcode = pytest.main([".", "-v", "-p", "no:cacheprovider"])

# Fail the cell (and therefore the notebook run) if any test failed.
assert retcode == 0, "The pytest invocation failed. See the log for details."
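
The runner above discovers every test under the notebook's directory ("."). If you only want the integration suite, for example as a post-deployment check, point pytest at that folder instead. A minimal sketch, assuming the runner sits at the project root shown earlier, replacing the last two lines of the runner:

# Run only the integration tests and fail the cell on any test failure.
retcode = pytest.main(["tests/integration", "-v", "-p", "no:cacheprovider"])
assert retcode == 0, "The pytest invocation failed. See the log for details."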

Verify the Example Test

  1. Commit and push your branch to Git.
  2. Launch Databricks.
  3. Run pytest_databricks.py for your project in Databricks to confirm the example tests pass.

Plan the Integration Test for your Databricks Notebook

  • Identify the parameters and inputs required for the Databricks Notebook (a sketch of how a notebook reads its parameters follows this list).
  • Identify what is output by the Databricks Notebook.
  • Decide what is critical to verify: table existence, column presence, expected data, files.
  • Identify cleanup steps: remove tables, delete directories, files, etc.
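
A notebook that accepts parameters usually declares them as widgets, which is what lets the integration test pass values in through dbutils.notebook.run. A minimal sketch of the top of such a notebook (the parameter names are the hypothetical ones used in the example test below):

# Databricks notebook source
# Declare the parameters the notebook accepts, with empty defaults for interactive runs.
dbutils.widgets.text("parameter_1", "")
dbutils.widgets.text("parameter_2", "")
dbutils.widgets.text("parameter_3", "")

parameter_1 = dbutils.widgets.get("parameter_1")
parameter_2 = dbutils.widgets.get("parameter_2")
parameter_3 = dbutils.widgets.get("parameter_3")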

Create the Integration Test

Use the template below to create the integration test. Both the file name and the test function name must begin with test_ so pytest can discover them.

Example Databricks Integration Test

import logging
import pytest
from common import TestUtil
from pyspark.sql import SparkSession
from pyspark.dbutils import DBUtils

logger = logging.getLogger()

spark = SparkSession.getActiveSession()
if spark is None:
    spark = SparkSession.builder.getOrCreate()

dbutils = DBUtils(spark)

integration_test_volume = TestUtil.get_integration_test_volume()   
integration_test_table_path = TestUtil.get_integration_test_table_path()
catalog_name = integration_test_table_path.split('.')[0]
schema_name = integration_test_table_path.split('.')[1]    
source_dir = f"{integration_test_volume}/source_dir"
dest_dir = f"{integration_test_volume}/dest_dir"
output_table = f"{integration_test_table_path}.output_table"

def _cleanup():    
    TestUtil.recreate_directory(source_dir)
    TestUtil.recreate_directory(dest_dir)
    TestUtil.drop_table_if_exists(output_table)
    logger.info("✅ Cleanup Completed")
    
def test_my_notebook():
    # Start from a clean slate in case a previous run left data behind.
    _cleanup()

    # Arrange: create whatever input data the notebook expects in source_dir
    logger.info("✅ Test data created")
    notebook_path = "../../notebooks/notebook_to_test"

    params = {
        "parameter_1": "1",
        "parameter_2": "2",
        "parameter_3": "3"
    }

    # Act
    result = dbutils.notebook.run(
        notebook_path,
        timeout_seconds=300,
        arguments=params
    )

    logger.info("✅ Notebook finished")

    # Assert
    TestUtil.verify_table_exists(output_table)

    expected_cols = [
        "Column1",
        "Column2"
    ]

    TestUtil.verify_table_columns(output_table, expected_cols)
    TestUtil.verify_column_data(output_table, "Column1", "Hello World")
    logger.info("✅ Assert Completed")

    _cleanup()

Test Util

This class holds the common helper methods used by the test above. Save it as common.py next to the integration tests so that the from common import TestUtil line resolves; the module needs its own imports and Spark/dbutils handles:

import os
import logging
from typing import Any

from pyspark.sql import SparkSession
from pyspark.dbutils import DBUtils

logger = logging.getLogger()

spark = SparkSession.getActiveSession() or SparkSession.builder.getOrCreate()
dbutils = DBUtils(spark)


class TestUtil:

    @staticmethod
    def recreate_directory(path: str) -> None:
        try:
            dbutils.fs.rm(path, recurse=True)
            logger.info(f"Deleted: {path}")
        except Exception as e:
            logger.info(f"(Info) Could not delete {path}, may not exist yet: {e}")
        dbutils.fs.mkdirs(path)
        logger.info(f"✅ Recreated: {path}")

    @staticmethod
    def verify_table_columns(table_full_name: str, expected_columns: list[str]) -> None:
        assert spark.catalog.tableExists(table_full_name), f"❌ Table does not exist: {table_full_name}"
        df = spark.table(table_full_name)
        # Compare case-insensitively so the check ignores column-name casing.
        table_columns = [c.lower() for c in df.columns]
        missing = [col for col in expected_columns if col.lower() not in table_columns]
        assert not missing, f"❌ Missing columns in {table_full_name}: {', '.join(missing)}"
        logger.info(f"✅ All columns exist in {table_full_name}")

    @staticmethod
    def verify_column_data(table_name: str, col_name: str, data: Any) -> None:
        df = spark.table(table_name)
        matching = df.where(df[col_name] == data).limit(1).count()
        assert matching > 0, f"❌ No record found with {col_name} = {data} in {table_name}"

    @staticmethod
    def verify_table_exists(table_name: str) -> None:
        assert spark.catalog.tableExists(table_name), f"❌ Table does not exist: {table_name}"
        logger.info(f"✅ Verified table exists: {table_name}")

    @staticmethod
    def drop_table_if_exists(table_name: str) -> None:
        if spark.catalog.tableExists(table_name):
            spark.sql(f"DROP TABLE {table_name}")
            logger.info(f"✅ Dropped table: {table_name}")

    @staticmethod
    def get_integration_test_volume() -> str:
        return os.environ.get("INTEGRATION_TEST_VOLUME", "/Volumes/test_volume")

    @staticmethod
    def get_integration_test_table_path() -> str:
        return os.environ.get("INTEGRATION_TEST_TABLE_PATH", "test_catalog.test_schema")
