PySpark Patterns
Common Patterns and Best Practices for PySpark
Overview
This document covers common PySpark patterns for data engineering at scale. The patterns are production-tested and tuned for terabyte-to-petabyte workloads. Understanding them is essential for writing efficient, maintainable PySpark code.
Core Patterns
Pattern 1: Medallion Architecture (Bronze-Silver-Gold)
```python
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, current_timestamp

spark = SparkSession.builder.appName("MedallionETL").getOrCreate()

def process_bronze_to_silver(source_path: str, target_path: str, table_name: str):
    """Transform bronze (raw) to silver (cleaned) layer."""
    # Read bronze
    bronze_df = spark.read.format("delta").load(source_path)

    # Transform: keep valid rows, deduplicate, stamp processing time
    silver_df = bronze_df.filter(col("valid") == True) \
        .dropDuplicates(["id"]) \
        .withColumn("processed_at", current_timestamp())

    # Write silver
    silver_df.write.format("delta") \
        .mode("overwrite") \
        .partitionBy("date") \
        .save(target_path)

# Process multiple tables
tables = ["users", "events", "transactions"]
for table in tables:
    process_bronze_to_silver(
        f"s3://bucket/bronze/{table}",
        f"s3://bucket/silver/{table}",
        table
    )
```
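The pattern names a gold layer, but only bronze-to-silver is shown above. Below is a minimal, hypothetical silver-to-gold sketch; `process_silver_to_gold`, the paths, and the `category`/`amount` columns are illustrative assumptions, not part of the original pipeline.

```python
from pyspark.sql.functions import sum, count

def process_silver_to_gold(source_path: str, target_path: str):
    """Aggregate silver data into a gold, analytics-ready table (illustrative)."""
    silver_df = spark.read.format("delta").load(source_path)

    # Example business aggregate: daily totals per category (assumed columns)
    gold_df = silver_df.groupBy("date", "category").agg(
        sum("amount").alias("total_amount"),
        count("*").alias("event_count")
    )

    gold_df.write.format("delta") \
        .mode("overwrite") \
        .partitionBy("date") \
        .save(target_path)
```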
Pattern 2: Type 2 SCD (Slowly Changing Dimension)

```python
from delta.tables import DeltaTable
from pyspark.sql import DataFrame

def apply_scd_type_2(target_table: str, source_df: DataFrame):
    """Apply SCD Type 2: track history with start_date and end_date."""
    target = DeltaTable.forPath(spark, target_table)

    # Merge: expire the current row when it changed, insert brand-new keys.
    # Note: inserting the *new* version of a changed row needs a staging step
    # (see the sketch below); this merge only closes out the old version.
    target.alias("t").merge(
        source_df.alias("s"),
        "t.id = s.id AND t.current = true"
    ).whenMatchedUpdate(
        condition="s.updated_at > t.updated_at",
        set={
            "current": "false",
            "end_date": "s.updated_at"
        }
    ).whenNotMatchedInsert(
        values={
            "id": "s.id",
            "name": "s.name",
            "start_date": "s.updated_at",
            "end_date": "CAST(NULL AS DATE)",
            "current": "true"
        }
    ).execute()
```
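A common way to complete the pattern, following the approach shown in Delta Lake's own SCD Type 2 examples, is to stage each changed row twice so the merge both expires the old version and inserts the new one. This is a sketch, not the original pipeline's implementation; `stage_scd_updates` is a hypothetical helper reusing the `id`/`updated_at` columns above.

```python
from pyspark.sql.functions import col, lit

def stage_scd_updates(target_table: str, source_df):
    """Stage changed rows twice: one copy expires the old version via the
    update branch, the other (null merge key) is inserted as the new version."""
    current = DeltaTable.forPath(spark, target_table).toDF() \
        .filter(col("current") == True) \
        .select(col("id").alias("t_id"), col("updated_at").alias("t_updated_at"))

    # Source rows that are newer than the current target row for the same key
    changed = source_df.join(current, source_df["id"] == current["t_id"]) \
        .filter(source_df["updated_at"] > col("t_updated_at")) \
        .select(source_df["*"])

    # merge_key = id hits the update branch; merge_key = NULL never matches,
    # so that copy falls through to whenNotMatchedInsert as the new version
    updates = source_df.withColumn("merge_key", col("id"))
    inserts = changed.withColumn("merge_key", lit(None))
    return updates.unionByName(inserts)
```

With this staging, the merge condition becomes `t.id = s.merge_key AND t.current = true`, so the null-keyed copies always take the insert branch.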
Pattern 3: Late Data Handling

```python
from pyspark.sql.functions import col, row_number
from pyspark.sql.window import Window

def handle_late_data(df, watermark_threshold="30 minutes"):
    """Handle late-arriving data with watermarking."""
    # Watermark bounds how long state is kept for late events
    stream_df = df.withWatermark("event_time", watermark_threshold)

    # Deduplicate, keeping the latest event per id.
    # Note: row_number over a non-time window only works on batch DataFrames;
    # for a streaming source, use dropDuplicates after the watermark instead
    # (see the streaming sketch below).
    window_spec = Window.partitionBy("id") \
        .orderBy(col("event_time").desc())

    deduplicated = stream_df.withColumn(
        "rank", row_number().over(window_spec)
    ).filter(col("rank") == 1).drop("rank")

    return deduplicated
```
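For a true Structured Streaming source, the same intent is usually expressed with dropDuplicates inside the watermark window. A minimal sketch; the paths and the `event_time` column are illustrative assumptions.

```python
# Streaming variant: deduplicate within the watermark window
events = (
    spark.readStream.format("delta").load("s3://bucket/bronze/events")
    .withWatermark("event_time", "30 minutes")
    .dropDuplicates(["id", "event_time"])
)

query = (
    events.writeStream.format("delta")
    .option("checkpointLocation", "s3://bucket/checkpoints/events_dedup")
    .outputMode("append")
    .start("s3://bucket/silver/events")
)
```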
Join Patterns

Broadcast Join (Small + Large)
```python
from pyspark.sql.functions import broadcast

# Small dimension table (< 10MB)
countries = spark.read.parquet("s3://bucket/dim/countries")

# Large fact table
transactions = spark.read.parquet("s3://bucket/fact/transactions")

# Broadcast the small table to every executor (avoids shuffling the fact table)
result = transactions.join(
    broadcast(countries),
    "country_code",
    "left"
)

# Alternative: configure the automatic broadcast threshold
spark.conf.set("spark.sql.autoBroadcastJoinThreshold", "10MB")
```
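To confirm the planner actually chose a broadcast join, inspect the physical plan; a `BroadcastHashJoin` node (with a `BroadcastExchange`) should appear. Output shape varies by Spark version.

```python
# Look for "BroadcastHashJoin" in the physical plan
result.explain()
```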
Skew Join Handling

```python
from pyspark.sql.functions import floor, rand

def handle_skew_join(large_df, small_df, key="user_id", num_salts=10):
    """Handle skewed joins with salting."""
    # Add a random salt to each row of the large (skewed) side so that
    # a hot key is spread across num_salts partitions
    salted_large = large_df.withColumn("salt", floor(rand() * num_salts))

    # Replicate the small table once per salt value
    salted_small = small_df.crossJoin(
        spark.range(num_salts).withColumnRenamed("id", "salt")
    )

    # Join on both the key and the salt, then drop the helper column
    joined = salted_large.join(
        salted_small, [key, "salt"]
    ).drop("salt")

    return joined
```
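On Spark 3.x, Adaptive Query Execution can often handle join skew without manual salting by splitting oversized partitions at runtime. A configuration sketch; the thresholds shown are illustrative and should be tuned per workload.

```python
# Let AQE detect and split skewed partitions during sort-merge joins
spark.conf.set("spark.sql.adaptive.enabled", "true")
spark.conf.set("spark.sql.adaptive.skewJoin.enabled", "true")
# A partition is treated as skewed when it is both this many times larger than
# the median partition and above the byte threshold below
spark.conf.set("spark.sql.adaptive.skewJoin.skewedPartitionFactor", "5")
spark.conf.set("spark.sql.adaptive.skewJoin.skewedPartitionThresholdInBytes", "256MB")
```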
Shuffle Sort Merge Join

```python
# Configure for large joins
spark.conf.set("spark.sql.join.preferSortMergeJoin", "true")
spark.conf.set("spark.sql.autoBroadcastJoinThreshold", "-1")  # Disable broadcast

# Large + Large join
result = df1.join(df2, "key", "inner")
```

Aggregation Patterns
Simple Aggregation
```python
from pyspark.sql.functions import sum, count, avg, stddev

# Group-by aggregation
result = df.groupBy("category").agg(
    sum("amount").alias("total_amount"),
    count("*").alias("transaction_count"),
    avg("amount").alias("avg_amount"),
    stddev("amount").alias("stddev_amount")
)
```

Window Aggregation
```python
from pyspark.sql.window import Window
from pyspark.sql.functions import col, avg, row_number, rank, dense_rank, lag, lead

# Define window: latest events first per user
window_spec = Window.partitionBy("user_id") \
    .orderBy(col("event_timestamp").desc())

# Row number
df_with_rank = df.withColumn(
    "rank", row_number().over(window_spec)
)

# Rolling average over the previous 10 rows plus the current row
rolling_window = Window.partitionBy("user_id") \
    .orderBy("event_timestamp") \
    .rowsBetween(-10, 0)

df_rolling = df.withColumn(
    "rolling_avg", avg("amount").over(rolling_window)
)
```
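The imports above also pull in lag and lead, which compare each row with its neighbor inside the same window. A small sketch computing the gap between consecutive events per user, reusing the column names from the snippets above:

```python
# Time since the previous event for the same user
ordered = Window.partitionBy("user_id").orderBy("event_timestamp")

df_with_gap = df.withColumn(
    "prev_event_timestamp", lag("event_timestamp", 1).over(ordered)
).withColumn(
    "seconds_since_prev",
    col("event_timestamp").cast("long") - col("prev_event_timestamp").cast("long")
)
```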
Cube and Rollup

```python
from pyspark.sql.functions import sum, grouping_id

# Cube (all combinations of the dimensions); the grouping level must be
# computed inside the aggregation, where grouping_id is valid
cube_result = df.cube("category", "region").agg(
    sum("amount").alias("total"),
    grouping_id("category", "region").alias("grouping_level")
)

# Rollup (hierarchical: category + subcategory, category only, grand total)
rollup_result = df.rollup("category", "subcategory").agg(
    sum("amount").alias("total")
)
```

Performance Patterns
Filter Early
```python
# Bad: filter after expensive operations
result = df.join(other_df, "key") \
    .filter(col("date") > "2025-01-01")

# Good: filter both sides before the join
result = df.filter(col("date") > "2025-01-01") \
    .join(other_df.filter(col("date") > "2025-01-01"), "key")
```

Select Columns Early
```python
# Bad: read all columns, project later
df = spark.read.parquet("s3://bucket/large_table")
result = df.select("id", "name", "value")

# Good: select the needed columns at read time
df = spark.read.parquet("s3://bucket/large_table") \
    .select("id", "name", "value")
```

Repartition for Write
```python
# Repartition before write (avoids many small files)
df.repartition(200, "date") \
    .write.format("delta") \
    .mode("overwrite") \
    .partitionBy("date") \
    .save("s3://bucket/output")

# Coalesce (reduce partitions without a full shuffle)
df.coalesce(10) \
    .write.format("delta") \
    .mode("overwrite") \
    .save("s3://bucket/output")
```

Cache Strategy
```python
# Cache for multiple actions
df_cached = df.filter(col("date") >= "2025-01-01").cache()

df_cached.count()                             # First action materializes the cache
df_cached.groupBy("category").count().show()  # Uses cache
df_cached.groupBy("region").count().show()    # Uses cache

# Unpersist when done
df_cached.unpersist()
```
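cache() uses a default storage level; persist() lets you choose one explicitly, which helps when the intermediate result is large relative to executor memory. A small sketch:

```python
from pyspark import StorageLevel

# Explicit storage level: keep what fits in memory, spill the rest to local disk
df_cached = df.filter(col("date") >= "2025-01-01") \
    .persist(StorageLevel.MEMORY_AND_DISK)
```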
Error Handling Patterns

Safe Parsing Instead of Try-Catch
```python
from pyspark.sql.functions import col, when, lit
from pyspark.sql.types import IntegerType

def safe_parse_int(column):
    """Safely parse a string column to int; yields null on failure.

    Note: cast() already returns null when parsing fails, so the when/otherwise
    wrapper mainly documents the intent.
    """
    return when(col(column).cast(IntegerType()).isNotNull(),
                col(column).cast(IntegerType())) \
        .otherwise(lit(None))

# Apply safe parsing
df = df.withColumn("user_id_int", safe_parse_int("user_id_str"))
```
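Rows that fail to parse end up with a null in the new column; it is often worth routing them to a quarantine DataFrame for inspection rather than silently dropping them. A minimal sketch using the columns from the snippet above:

```python
# Rows where the raw value was present but could not be parsed
quarantined = df.filter(
    col("user_id_int").isNull() & col("user_id_str").isNotNull()
)

# Keep only successfully parsed rows in the main pipeline
clean = df.filter(col("user_id_int").isNotNull())
```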
Validation Framework

```python
from pyspark.sql.functions import col

def validate_data(df, rules):
    """Validate data against rules.

    Returns a list of dicts with the rule name, the number of violating rows,
    and a small sample of violations.
    """
    validation_results = []

    for rule_name, rule_expr in rules.items():
        # Rows that violate the rule (each rule is a boolean Column expression)
        violations = df.filter(~rule_expr)

        # Record result
        validation_results.append({
            "rule": rule_name,
            "violations": violations.count(),
            "sample": violations.limit(10).collect()
        })

    return validation_results

# Define rules as boolean Column expressions
rules = {
    "no_null_ids": col("id").isNotNull(),
    "positive_amount": col("amount") > 0,
    "valid_email": col("email").rlike(r"^[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,}$")
}

# Validate
results = validate_data(df, rules)
for result in results:
    print(f"Rule: {result['rule']}, Violations: {result['violations']}")
```

UDF Patterns
Pandas UDF (Vectorized)
```python
from pyspark.sql.functions import pandas_udf, col
from pyspark.sql.types import IntegerType
import pandas as pd

# Pandas UDFs operate on whole batches via Arrow, so keep the body vectorized
@pandas_udf(IntegerType())
def calculate_category_pandas(description: pd.Series) -> pd.Series:
    """Categorize based on description using vectorized pandas operations."""
    categories = pd.Series(2, index=description.index)            # default bucket
    is_premium = description.str.contains("premium", case=False, na=False)
    categories[is_premium] = 1                                     # premium bucket
    categories[description.isna()] = 0                            # missing description
    return categories.astype("int32")

# Apply the UDF
df = df.withColumn(
    "category",
    calculate_category_pandas(col("description"))
)
```

Regular UDF (Use Sparingly)
```python
from pyspark.sql.functions import udf, col
from pyspark.sql.types import StringType

# Define UDF
@udf(StringType())
def format_name(first_name: str, last_name: str) -> str:
    """Format name as "Last, First"."""
    if not first_name or not last_name:
        return None
    return f"{last_name.upper()}, {first_name.capitalize()}"

# Apply UDF
df = df.withColumn(
    "formatted_name",
    format_name(col("first_name"), col("last_name"))
)
```
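"Use sparingly" here means: prefer built-in column functions whenever they can express the same logic, since they stay inside the JVM and remain visible to the Catalyst optimizer. A close equivalent without a UDF (a sketch; empty-string handling differs slightly from the Python version above):

```python
from pyspark.sql.functions import col, concat, initcap, lit, upper, when

# Equivalent logic with native column expressions (no Python round-trip)
df = df.withColumn(
    "formatted_name",
    when(
        col("first_name").isNotNull() & col("last_name").isNotNull(),
        concat(upper(col("last_name")), lit(", "), initcap(col("first_name")))
    )  # otherwise null, matching the UDF's missing-name behavior
)
```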
File I/O Patterns

Read with Schema
```python
from pyspark.sql.types import StructType, StructField, StringType, IntegerType, TimestampType

# Define schema
schema = StructType([
    StructField("id", IntegerType(), nullable=False),
    StructField("name", StringType(), nullable=True),
    StructField("created_at", TimestampType(), nullable=True)
])

# Read with an explicit schema (faster and safer than inference; this matters
# most for CSV/JSON sources, since Delta and Parquet carry their own schema)
df = spark.read.format("json") \
    .schema(schema) \
    .load("s3://bucket/data")
```
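A schema can also be given as a DDL-formatted string, which DataFrameReader.schema accepts directly and which stays shorter for wide tables. A small sketch with the same columns:

```python
# Same schema expressed as a DDL string
ddl_schema = "id INT, name STRING, created_at TIMESTAMP"

df = spark.read.format("json") \
    .schema(ddl_schema) \
    .load("s3://bucket/data")
```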
Write with Options

```python
# Write with file-size optimization
df.write.format("delta") \
    .mode("overwrite") \
    .option("maxRecordsPerFile", 1000000) \
    .option("compression", "zstd") \
    .partitionBy("date", "country") \
    .save("s3://bucket/output")
```

Cost Optimization Patterns
Use Column Pruning
```python
# Read only required columns
df = spark.read.format("delta") \
    .load("s3://bucket/table") \
    .select("id", "name", "amount", "date")
```

Use Predicate Pushdown
```python
# Filter at the source (pushed down to the storage layer for file skipping)
df = spark.read.format("delta") \
    .load("s3://bucket/table") \
    .filter(col("date") >= "2025-01-01")
```

Optimize Shuffle Partitions
```python
# Reduce shuffle partitions for small data
spark.conf.set("spark.sql.shuffle.partitions", "50")

# Increase for large data
spark.conf.set("spark.sql.shuffle.partitions", "500")
```
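On Spark 3.x, Adaptive Query Execution can also coalesce shuffle partitions at runtime, which reduces the need to hand-tune this value per job. A configuration sketch; the advisory size is an illustrative value.

```python
# Let AQE merge small shuffle partitions after each stage based on actual sizes
spark.conf.set("spark.sql.adaptive.enabled", "true")
spark.conf.set("spark.sql.adaptive.coalescePartitions.enabled", "true")
# Target size per coalesced partition (tune per workload)
spark.conf.set("spark.sql.adaptive.advisoryPartitionSizeInBytes", "128MB")
```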
Testing Patterns

Data Quality Checks
```python
# Test data quality
def test_data_quality(df):
    """Test data quality and return a report."""
    tests = {
        "row_count": df.count() > 0,
        "no_null_ids": df.filter(col("id").isNull()).count() == 0,
        "positive_amounts": df.filter(col("amount") <= 0).count() == 0,
        "unique_ids": df.select("id").distinct().count() == df.count()
    }

    return tests

# Run tests
results = test_data_quality(df)
for test_name, passed in results.items():
    status = "PASS" if passed else "FAIL"
    print(f"{test_name}: {status}")
```
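To make these checks gate a pipeline run rather than just print, fail fast when any check does not pass. A minimal sketch:

```python
# Abort the job if any data quality check failed
failed = [name for name, passed in results.items() if not passed]
if failed:
    raise ValueError(f"Data quality checks failed: {', '.join(failed)}")
```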
Key Takeaways

- Filter and select early: Reduce data movement
- Use broadcast joins: For small tables; avoids shuffling the large side and is often dramatically faster
- Handle skew: Salting for skewed joins
- Cache strategically: For reused DataFrames
- Use Pandas UDFs: Vectorized via Arrow, typically much faster than row-at-a-time Python UDFs
- Define schemas: Faster reads, type safety
- Monitor shuffle: Tune partitions
- Test data quality: Validation in pipelines