
PySpark Patterns

Common Patterns and Best Practices for PySpark


Overview

This document covers common PySpark patterns for data engineering at scale. These patterns are production-tested and optimized for TB-PB scale workloads. Understanding these patterns is essential for writing efficient, maintainable PySpark code.


Core Patterns

Pattern 1: Medallion Architecture (Bronze-Silver-Gold)

from pyspark.sql import SparkSession
from pyspark.sql.functions import col, current_timestamp

spark = SparkSession.builder.appName("MedallionETL").getOrCreate()

def process_bronze_to_silver(source_path: str, target_path: str, table_name: str):
    """
    Transform bronze (raw) to silver (cleaned) layer.
    """
    # Read bronze
    bronze_df = spark.read.format("delta").load(source_path)

    # Transform
    silver_df = bronze_df.filter(col("valid") == True) \
        .dropDuplicates(["id"]) \
        .withColumn("processed_at", current_timestamp())

    # Write silver
    silver_df.write.format("delta") \
        .mode("overwrite") \
        .partitionBy("date") \
        .save(target_path)

# Process multiple tables
tables = ["users", "events", "transactions"]
for table in tables:
    process_bronze_to_silver(
        f"s3://bucket/bronze/{table}",
        f"s3://bucket/silver/{table}",
        table
    )
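
The silver-to-gold step follows the same shape: read the cleaned silver table and aggregate it into a consumption-ready mart. A minimal sketch, assuming a silver transactions table with date, category, and amount columns (the paths and the process_silver_to_gold name are illustrative):

from pyspark.sql.functions import sum as spark_sum, count

def process_silver_to_gold(source_path: str, target_path: str):
    """Aggregate silver data into a gold, consumption-ready table."""
    silver_df = spark.read.format("delta").load(source_path)

    gold_df = silver_df.groupBy("date", "category").agg(
        spark_sum("amount").alias("total_amount"),
        count("*").alias("transaction_count")
    )

    gold_df.write.format("delta") \
        .mode("overwrite") \
        .partitionBy("date") \
        .save(target_path)

process_silver_to_gold("s3://bucket/silver/transactions", "s3://bucket/gold/daily_sales")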

Pattern 2: Type 2 SCD (Slowly Changing Dimension)

from delta.tables import DeltaTable
from pyspark.sql import DataFrame

def apply_scd_type_2(target_table: str, source_df: DataFrame):
    """
    Apply SCD Type 2: Track history with start_date and end_date.
    """
    target = DeltaTable.forPath(spark, target_table)

    # Merge logic: expire the current row on change, insert brand-new keys.
    # The set/values entries are SQL expression strings.
    target.alias("t").merge(
        source_df.alias("s"),
        "t.id = s.id AND t.current = true"
    ).whenMatchedUpdate(
        condition="s.updated_at > t.updated_at",
        set={
            "current": "false",
            "end_date": "s.updated_at"
        }
    ).whenNotMatchedInsert(
        values={
            "id": "s.id",
            "name": "s.name",
            "start_date": "s.updated_at",
            "end_date": "CAST(NULL AS DATE)",
            "current": "true"
        }
    ).execute()
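
Note that a single merge only expires the old row for a changed key; it does not insert the new version of that row. A common refinement is to stage the source so changed rows appear twice: once with the real merge key (to expire the current row) and once with a NULL merge key (to fall through to the insert branch), with the merge condition changed to "t.id = s.merge_key AND t.current = true". A hedged sketch using a hypothetical stage_scd_updates helper:

from pyspark.sql.functions import col, lit

def stage_scd_updates(target_table: str, source_df):
    """Union the source with a NULL-keyed copy of changed rows so one merge
    can both expire the old version and insert the new one."""
    current = DeltaTable.forPath(spark, target_table).toDF().filter("current = true")

    # Rows whose key exists in the target with an older updated_at
    changed = source_df.alias("s") \
        .join(current.alias("t"), col("s.id") == col("t.id")) \
        .filter(col("s.updated_at") > col("t.updated_at")) \
        .select("s.*")

    # merge_key = id for rows that should match; NULL for rows that must insert
    id_type = dict(source_df.dtypes)["id"]
    staged = source_df.withColumn("merge_key", col("id")) \
        .unionByName(changed.withColumn("merge_key", lit(None).cast(id_type)))
    return staged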

Pattern 3: Late Data Handling

from pyspark.sql.functions import col, row_number
from pyspark.sql.window import Window

def handle_late_data(df, watermark_threshold="30 minutes"):
    """
    Handle late-arriving data.

    For streaming DataFrames, the watermark bounds how long state is kept and
    dropDuplicates discards late duplicates within that window (window
    functions such as row_number are not supported on streams).
    """
    if df.isStreaming:
        return df.withWatermark("event_time", watermark_threshold) \
            .dropDuplicates(["id"])

    # Batch: keep the latest record per id
    window_spec = Window.partitionBy("id") \
        .orderBy(col("event_time").desc())
    return df.withColumn("rank", row_number().over(window_spec)) \
        .filter(col("rank") == 1) \
        .drop("rank")
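
Wiring this into a streaming pipeline looks roughly like the following; the paths and checkpoint location are placeholders:

# Illustrative streaming usage of handle_late_data
events = spark.readStream.format("delta").load("s3://bucket/bronze/events")

deduped = handle_late_data(events, watermark_threshold="30 minutes")

query = deduped.writeStream \
    .format("delta") \
    .option("checkpointLocation", "s3://bucket/checkpoints/events") \
    .outputMode("append") \
    .start("s3://bucket/silver/events")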

Join Patterns

Broadcast Join (Small + Large)

from pyspark.sql.functions import broadcast
# Small dimension table (< 10MB)
countries = spark.read.parquet("s3://bucket/dim/countries")
# Large fact table
transactions = spark.read.parquet("s3://bucket/fact/transactions")
# Broadcast small table
result = transactions.join(
    broadcast(countries),
    "country_code",
    "left"
)
# Alternative: Configure threshold
spark.conf.set("spark.sql.autoBroadcastJoinThreshold", "10MB")

Skew Join Handling

from pyspark.sql.functions import rand

def handle_skew_join(large_df, small_df, key="user_id"):
    """
    Handle skewed joins with salting.
    """
    # Add a random salt (0-9) to the skewed side so hot keys spread
    # across 10 partitions instead of one
    salted_large = large_df.withColumn(
        "salt",
        (rand() * 10).cast("int")
    )

    # Replicate the small table 10 times, once per salt value
    salted_small = small_df.crossJoin(
        spark.range(10).withColumnRenamed("id", "salt")
    )

    # Join on both key and salt
    joined = salted_large.join(
        salted_small,
        [key, "salt"]
    ).drop("salt")
    return joined
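
On Spark 3.x, adaptive query execution can detect and split skewed partitions at runtime, which often removes the need for manual salting (shown here with the same illustrative large_df and small_df names):

# Spark 3.x alternative: let adaptive query execution split skewed partitions
spark.conf.set("spark.sql.adaptive.enabled", "true")
spark.conf.set("spark.sql.adaptive.skewJoin.enabled", "true")

result = large_df.join(small_df, "user_id")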

Shuffle Sort Merge Join

# Configure for large joins
spark.conf.set("spark.sql.join.preferSortMergeJoin", "true")
spark.conf.set("spark.sql.autoBroadcastJoinThreshold", "-1") # Disable broadcast
# Large + Large join
result = df1.join(df2, "key", "inner")

Aggregation Patterns

Simple Aggregation

from pyspark.sql.functions import sum, count, avg, stddev
# Group by aggregation
result = df.groupBy("category").agg(
    sum("amount").alias("total_amount"),
    count("*").alias("transaction_count"),
    avg("amount").alias("avg_amount"),
    stddev("amount").alias("stddev_amount")
)

Window Aggregation

from pyspark.sql.window import Window
from pyspark.sql.functions import col, avg, row_number, rank, dense_rank, lag, lead

# Define window
window_spec = Window.partitionBy("user_id") \
    .orderBy(col("event_timestamp").desc())

# Row number
df_with_rank = df.withColumn(
    "rank",
    row_number().over(window_spec)
)

# Rolling average over the last 10 rows up to the current row
rolling_window = Window.partitionBy("user_id") \
    .orderBy("event_timestamp") \
    .rowsBetween(-10, 0)

df_rolling = df.withColumn(
    "rolling_avg",
    avg("amount").over(rolling_window)
)
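
lag and lead (imported above) compare each row with its neighbours inside the same window, for example to compute the time since a user's previous event (assuming event_timestamp is a timestamp column):

from pyspark.sql.functions import unix_timestamp

ordered = Window.partitionBy("user_id").orderBy("event_timestamp")

df_gaps = df.withColumn(
    "prev_event", lag("event_timestamp").over(ordered)
).withColumn(
    "seconds_since_prev",
    unix_timestamp("event_timestamp") - unix_timestamp("prev_event")
)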

Cube and Rollup

from pyspark.sql.functions import sum, grouping_id

# Cube (all combinations of dimensions); cube and rollup are DataFrame
# methods, and grouping_id must be computed inside the aggregation to
# identify which dimensions each output row aggregates over
cube_result = df.cube("category", "region").agg(
    sum("amount").alias("total"),
    grouping_id("category", "region").alias("grouping_level")
)

# Rollup (hierarchical aggregations)
rollup_result = df.rollup("category", "subcategory").agg(
    sum("amount").alias("total")
)

Performance Patterns

Filter Early

# Bad: Filter after expensive operations
result = df.join(other_df, "key") \
    .filter(col("date") > "2025-01-01")

# Good: Filter before join
result = df.filter(col("date") > "2025-01-01") \
    .join(other_df.filter(col("date") > "2025-01-01"), "key")

Select Columns Early

# Bad: Read all columns, filter later
df = spark.read.parquet("s3://bucket/large_table")
result = df.select("id", "name", "value")
# Good: Select columns at read time
df = spark.read.parquet("s3://bucket/large_table") \
    .select("id", "name", "value")

Repartition for Write

# Repartition before write (avoid small files)
df.repartition(200, "date") \
    .write.format("delta") \
    .mode("overwrite") \
    .partitionBy("date") \
    .save("s3://bucket/output")

# Coalesce (reduce partitions without shuffle)
df.coalesce(10) \
    .write.format("delta") \
    .mode("overwrite") \
    .save("s3://bucket/output")

Cache Strategy

# Cache for multiple actions
df_cached = df.filter(col("date") >= "2025-01-01").cache()
df_cached.count() # First action: materialize cache
df_cached.groupBy("category").count().show() # Uses cache
df_cached.groupBy("region").count().show() # Uses cache
# Unpersist when done
df_cached.unpersist()
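
When the cached data may not fit in executor memory, persist with an explicit storage level gives more control than cache():

from pyspark import StorageLevel
from pyspark.sql.functions import col

# Spill to disk when memory is tight instead of evicting partitions
df_persisted = df.filter(col("date") >= "2025-01-01") \
    .persist(StorageLevel.MEMORY_AND_DISK)

df_persisted.count()  # materialize
df_persisted.unpersist()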

Error Handling Patterns

Try-Catch for Transformations

from pyspark.sql.functions import col, when, lit
from pyspark.sql.types import IntegerType
def safe_parse_int(column):
    """Safely parse string to int, default to null on failure."""
    return when(col(column).cast(IntegerType()).isNotNull(),
                col(column).cast(IntegerType())) \
        .otherwise(lit(None))
# Apply safe parsing
df = df.withColumn("user_id_int", safe_parse_int("user_id_str"))

Validation Framework

from pyspark.sql.functions import col

def validate_data(df, rules):
    """
    Validate data against rules.
    Returns a list of {rule, violations, sample} dicts.
    """
    validation_results = []
    for rule_name, rule_func in rules.items():
        # Rows that violate the rule
        violations = df.filter(~rule_func(df))
        # Record result
        validation_results.append({
            "rule": rule_name,
            "violations": violations.count(),
            "sample": violations.limit(10).collect()
        })
    return validation_results

# Define rules: each rule maps a DataFrame to a boolean Column
rules = {
    "no_null_ids": lambda df: col("id").isNotNull(),
    "positive_amount": lambda df: col("amount") > 0,
    "valid_email": lambda df: col("email").rlike("^[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\\.[A-Za-z]{2,}$")
}

# Validate
results = validate_data(df, rules)
for result in results:
    print(f"Rule: {result['rule']}, Violations: {result['violations']}")

UDF Patterns

Pandas UDF (Vectorized)

from pyspark.sql.functions import col, pandas_udf
from pyspark.sql.types import IntegerType
import pandas as pd

# Define Pandas UDF (vectorized via Arrow; much faster than a regular UDF)
@pandas_udf(IntegerType())
def calculate_category_pandas(description: pd.Series) -> pd.Series:
    """
    Categorize based on description using pandas.
    """
    categories = []
    for desc in description:
        if pd.isna(desc):
            categories.append(0)
        elif "premium" in desc.lower():
            categories.append(1)
        else:
            categories.append(2)
    return pd.Series(categories)

# Apply UDF
df = df.withColumn(
    "category",
    calculate_category_pandas(col("description"))
)
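
For logic this simple, built-in column functions avoid UDF overhead entirely and let Catalyst optimize the whole expression; a sketch of the same categorization without any UDF:

from pyspark.sql.functions import when, lower, col

df = df.withColumn(
    "category",
    when(col("description").isNull(), 0)
    .when(lower(col("description")).contains("premium"), 1)
    .otherwise(2)
)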

Regular UDF (Use Sparingly)

from pyspark.sql.functions import col, udf
from pyspark.sql.types import StringType

# Define UDF
@udf(StringType())
def format_name(first_name: str, last_name: str) -> str:
    """
    Format name as "Last, First".
    """
    if not first_name or not last_name:
        return None
    return f"{last_name.upper()}, {first_name.capitalize()}"

# Apply UDF
df = df.withColumn(
    "formatted_name",
    format_name(col("first_name"), col("last_name"))
)

File I/O Patterns

Read with Schema

from pyspark.sql.types import StructType, StructField, StringType, IntegerType, TimestampType
# Define schema
schema = StructType([
    StructField("id", IntegerType(), nullable=False),
    StructField("name", StringType(), nullable=True),
    StructField("created_at", TimestampType(), nullable=True)
])

# Read with an explicit schema (faster and safer than inference; applies to
# schema-on-read formats like JSON/CSV, since Delta and Parquet carry their own schema)
df = spark.read.format("json") \
    .schema(schema) \
    .load("s3://bucket/data")

Write with Options

# Write with file size optimization
df.write.format("delta") \
    .mode("overwrite") \
    .option("maxRecordsPerFile", 1000000) \
    .option("compression", "zstd") \
    .partitionBy("date", "country") \
    .save("s3://bucket/output")

Cost Optimization Patterns

Use Column Pruning

# Read only required columns
df = spark.read.format("delta") \
    .load("s3://bucket/table") \
    .select("id", "name", "amount", "date")

Use Predicate Pushdown

# Filter at source (pushed to storage)
df = spark.read.format("delta") \
    .load("s3://bucket/table") \
    .filter(col("date") >= "2025-01-01")
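
On Spark 3.x you can confirm that pruning and pushdown actually happened by inspecting the physical plan; pushed filters and selected columns appear in the scan node:

# Inspect the physical plan; look for pushed/partition filters in the scan
df.explain(mode="formatted")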

Optimize Shuffle Partitions

# Reduce shuffle partitions for small data
spark.conf.set("spark.sql.shuffle.partitions", "50")
# Increase for large data
spark.conf.set("spark.sql.shuffle.partitions", "500")
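
On Spark 3.x, adaptive query execution can also size shuffle partitions at runtime instead of relying on a single static setting:

# Let AQE coalesce small shuffle partitions after each stage
spark.conf.set("spark.sql.adaptive.enabled", "true")
spark.conf.set("spark.sql.adaptive.coalescePartitions.enabled", "true")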

Testing Patterns

Data Quality Checks

from pyspark.sql.functions import col

# Test data quality
def test_data_quality(df):
    """
    Test data quality and return a {test_name: passed} report.
    """
    tests = {
        "row_count": df.count() > 0,
        "no_null_ids": df.filter(col("id").isNull()).count() == 0,
        "positive_amounts": df.filter(col("amount") <= 0).count() == 0,
        "unique_ids": df.select("id").distinct().count() == df.count()
    }
    return tests

# Run tests
results = test_data_quality(df)
for test_name, passed in results.items():
    status = "PASS" if passed else "FAIL"
    print(f"{test_name}: {status}")
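
Transformation logic itself can be unit tested against a local SparkSession. A minimal pytest-style sketch, assuming safe_parse_int (defined earlier) is importable from the pipeline module under test:

import pytest
from pyspark.sql import SparkSession

@pytest.fixture(scope="session")
def spark():
    # Small local session shared across the test run
    return SparkSession.builder.master("local[1]").appName("unit-tests").getOrCreate()

def test_safe_parse_int(spark):
    df = spark.createDataFrame([("1",), ("oops",), (None,)], ["user_id_str"])
    result = df.withColumn("user_id_int", safe_parse_int("user_id_str"))
    assert [row.user_id_int for row in result.collect()] == [1, None, None]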

Key Takeaways

  1. Filter and select early: Reduce data movement
  2. Use broadcast joins: For small tables (10-100x faster)
  3. Handle skew: Salting for skewed joins
  4. Cache strategically: For reused DataFrames
  5. Use Pandas UDFs: 10-100x faster than regular UDFs
  6. Define schemas: Faster reads, type safety
  7. Monitor shuffle: Tune partitions
  8. Test data quality: Validation in pipelines
