Modern data pipelines handle massive volumes of structured and unstructured data every day. As datasets grow, poorly optimized Spark jobs become slower, more expensive, and harder to scale. Common issues include long execution times, excessive shuffling, memory bottlenecks, and inefficient joins.
Effective PySpark optimization can significantly improve performance, reduce infrastructure costs, and enhance cluster efficiency. In this article, we’ll explore 12 proven PySpark optimization techniques with practical examples and real-world performance strategies used by data engineers.
How Spark Executes Your Code
Before you start optimizing, you need to understand how Spark executes your code. Many developers write PySpark code without understanding the processes underneath it, and that gap leads to poor performance decisions. The core mechanics covered in this section underpin every optimization technique that follows.
Understanding Spark Architecture
Spark is a distributed system that processes data in parallel across multiple machines. Every Spark application consists of two primary components that work together.
- Driver vs Executors
The Driver is the central coordinator of your Spark application. It runs your main program, builds the execution plan, and supervises the work. The Executors are the workers. They run on machines across the cluster, hold data in memory, and perform the actual computation.
When you submit a Spark job, the Driver divides the work into smaller tasks and dispatches them to Executors. Each Executor works on its assigned partitions of data independently of the others. This parallelism is what makes Spark fast.
- Jobs, Stages, and Tasks
Spark organizes your computation work into three hierarchical layers.
- Job: A complete computation triggered by an action (like count() or write()).
- Stage: A set of tasks that can run without shuffling data across the network.
- Task: The smallest unit of work. Each task processes one partition of data.
The Spark UI follows the same hierarchy, so you can use it to trace a performance problem from a slow job down to the stage and tasks responsible.
Lazy Evaluation in Spark
Spark does not execute transformations when you define them. When you call filter(), select(), or groupBy(), Spark only records what you intend to do and builds a logical plan. Execution starts only when you call an action such as show(), count(), or write().
This pattern is called lazy evaluation. It lets Spark see the entire query before running any of it. Before work begins, Spark can reorder operations, push filters closer to the data source, and drop steps it does not need.
Understanding Spark Transformations and Actions
All PySpark operations fall into two categories.
- Transformations: Lazy operations that define a new DataFrame. Examples include filter(), select(), join(), groupBy(), and withColumn(). Spark records these but does not run them yet.
- Actions: Operations that trigger actual execution. Examples include count(), collect(), show(), write(), and first(). When you call an action, Spark evaluates all the queued transformations and runs the job.
A common mistake is running several actions on the same DataFrame without caching it. Unless the data is cached, Spark re-executes all the transformations for every action.
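To make the cost of repeated actions concrete, here is a minimal plain-Python sketch of lazy evaluation. It is a toy model, not Spark's API: the `LazyPipeline` class and its `full_runs` counter are hypothetical names invented for this illustration.

```python
class LazyPipeline:
    """Toy stand-in for a DataFrame: records filters, runs them only on an action."""

    def __init__(self, rows, predicates=(), stats=None):
        self.rows = rows
        self.predicates = tuple(predicates)
        # Counts how many times the whole pipeline was executed from the source.
        self.stats = stats if stats is not None else {"full_runs": 0}
        self.use_cache = False
        self.cached_result = None

    def filter(self, predicate):
        # Transformation: record the step and return a new pipeline. No work yet.
        return LazyPipeline(self.rows, self.predicates + (predicate,), self.stats)

    def cache(self):
        # Ask for the result of the first action to be kept for later actions.
        self.use_cache = True
        return self

    def _execute(self):
        if self.cached_result is not None:
            return self.cached_result
        self.stats["full_runs"] += 1
        result = list(self.rows)
        for predicate in self.predicates:
            result = [row for row in result if predicate(row)]
        if self.use_cache:
            self.cached_result = result
        return result

    def count(self):
        # Action: triggers execution.
        return len(self._execute())

    def collect(self):
        # Action: triggers execution.
        return list(self._execute())


salaries = [95000, 72000, 110000, 65000, 88000]

uncached = LazyPipeline(salaries).filter(lambda s: s > 70000)
runs_before_action = uncached.stats["full_runs"]  # still 0: nothing has run
uncached.count()
uncached.collect()
uncached_runs = uncached.stats["full_runs"]  # two actions, two full runs

cached = LazyPipeline(salaries).filter(lambda s: s > 70000).cache()
cached.count()
cached.collect()
cached_runs = cached.stats["full_runs"]  # two actions, one full run
```

The uncached pipeline executes once per action, while the cached one executes once in total. Real Spark caching involves memory and storage trade-offs that this model ignores.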
Reading Spark Execution Plans with explain()
The explain() method is your main debugging tool here. It prints the query execution plan, so you can check whether filters were pushed down, whether a broadcast join was used, and where shuffles occur.
from pyspark.sql import SparkSession
spark = SparkSession.builder.appName("ExplainDemo").getOrCreate()
df = spark.read.parquet("/data/sales.parquet")
df_filtered = df.filter(df["revenue"] > 5000).select("product", "revenue")
# Read the execution plan
df_filtered.explain(True)
Output:
== Parsed Logical Plan ==
'Project ['product,'revenue]
+- 'Filter ('revenue > 5000)
   +- Relation[...] parquet

== Analyzed Logical Plan ==
...

== Optimized Logical Plan ==
Project [product#10,revenue#11]
+- Filter (isnotnull(revenue#11) AND (revenue#11 > 5000))
   +- Relation[...] parquet

== Physical Plan ==
*(1) Project [product#10,revenue#11]
+- *(1) Filter (isnotnull(revenue#11) AND (revenue#11 > 5000))
   +- *(1) FileScan parquet [...] PushedFilters:[IsNotNull(revenue),GreaterThan(revenue,5000.0)]
Note PushedFilters in the physical plan. It shows that the filter is applied during the file scan, which is a good sign for performance.
Ways to Optimize Your Spark Jobs
Now we'll go through the techniques that will help you optimize your Spark jobs.
Technique 1: Use Columnar File Formats Like Parquet or ORC
The file format you choose has a significant effect on how fast Spark can read your data. Teams often default to CSV and JSON because they are easy to produce. At scale, these formats cause major performance problems.
Why CSV and JSON Are Slower
CSV and JSON are row-based formats. To read a single column, Spark must read every row and parse all columns. This wastes I/O and CPU time. They also carry no built-in schema, so unless you supply one, Spark must infer it, which adds extra overhead.
Benefits of Parquet and ORC
Parquet and ORC are columnar formats designed for analytical workloads. They store data column by column rather than row by row.
- Columnar Storage: Spark reads only the columns you ask for. If you select 3 columns from a dataset with 50, Spark skips the other 47.
- Compression Benefits: Columnar formats compress better because similar values within a column sit next to each other. This lowers storage costs and speeds up reads.
- Predicate Pushdown: Parquet and ORC maintain statistical information (minimum and maximum values and null counts) for every column across all row groups. Spark uses these statistics to skip entire chunks of data without reading them.
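The statistics-based skipping described in the last bullet can be sketched in plain Python. This is a simplified model under stated assumptions, not how Parquet or ORC readers are actually implemented: the `row_groups` layout and the `scan_greater_than` helper are hypothetical.

```python
# Each "row group" keeps min/max statistics next to its values,
# loosely mirroring the per-column statistics stored in columnar files.
row_groups = [
    {"min": 10, "max": 900, "rows": [10, 450, 900]},
    {"min": 1200, "max": 4800, "rows": [1200, 3000, 4800]},
    {"min": 5200, "max": 9900, "rows": [5200, 7500, 9900]},
    {"min": 4000, "max": 6000, "rows": [4000, 5000, 6000]},
]


def scan_greater_than(groups, threshold):
    """Return values above threshold and how many groups had to be read."""
    matches, groups_read = [], 0
    for group in groups:
        if group["max"] <= threshold:
            # No value in this group can pass the filter, so skip it unread.
            continue
        groups_read += 1
        matches.extend(value for value in group["rows"] if value > threshold)
    return matches, groups_read


matches, groups_read = scan_greater_than(row_groups, 5000)
```

With a threshold of 5000, only two of the four groups are read. The other two are ruled out from their maximum value alone.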
PySpark Code Example
from pyspark.sql import SparkSession
from pyspark.sql.types import (
StructType,
StructField,
StringType,
IntegerType,
DoubleType
)
spark = SparkSession.builder.appName("FileFormatDemo").getOrCreate()
# Create dummy sales data
data = [
("P001", "Laptop", "Electronics", 1200.50, 30),
("P002", "Phone", "Electronics", 800.00, 75),
("P003", "Desk", "Furniture", 350.00, 20),
("P004", "Chair", "Furniture", 200.00, 50),
("P005", "Monitor", "Electronics", 450.75, 40),
("P006", "Keyboard", "Electronics", 80.00, 100),
("P007", "Lamp", "Furniture", 60.00, 60),
("P008", "Tablet", "Electronics", 600.00, 25),
]
schema = StructType([
StructField("product_id", StringType(), True),
StructField("product_name", StringType(), True),
StructField("category", StringType(), True),
StructField("price", DoubleType(), True),
StructField("units_sold", IntegerType(), True),
])
df = spark.createDataFrame(data, schema)
# Write as CSV (slow format)
df.write.mode("overwrite").csv("/tmp/sales_csv")
# Write as Parquet (fast columnar format)
df.write.mode("overwrite").parquet("/tmp/sales_parquet")
# Read back Parquet — fast, schema-aware
df_parquet = spark.read.parquet("/tmp/sales_parquet")
df_parquet.select("product_name", "price").show()
Best Practices for File Formats
- Use Parquet for analytical workloads and pipelines.
- Use ORC when working in Hive-centric ecosystems.
- Prefer Snappy compression (Spark's default for Parquet) for a good balance of speed and size.
- Avoid CSV and JSON for intermediate storage between pipeline steps.
Technique 2: Filter Data as Early as Possible
Filtering early is one of the simplest and most effective PySpark optimizations. The less data Spark carries through your pipeline, the faster every later step runs.
What Is Predicate Pushdown?
A predicate is a filter condition, such as age > 30 or status == "active". Predicate pushdown means Spark moves these filter conditions as close to the data source as possible, ideally into the file scan itself. Spark then applies the filter while reading, instead of loading everything and filtering afterward.
Why Early Filtering Improves Performance
Filtering before processing means every downstream operation (joins, aggregations, sorts) works on a smaller dataset. That reduces memory use, network traffic, and CPU time at each stage of the job.
PySpark Code Example
from pyspark.sql import SparkSession
from pyspark.sql.functions import col
spark = SparkSession.builder.appName("EarlyFilterDemo").getOrCreate()
# Dummy employee data
data = [
(1, "Alice", "Engineering", 95000, "active"),
(2, "Bob", "Marketing", 72000, "inactive"),
(3, "Charlie", "Engineering", 110000, "active"),
(4, "Diana", "HR", 65000, "active"),
(5, "Eve", "Engineering", 88000, "inactive"),
(6, "Frank", "Marketing", 78000, "active"),
(7, "Grace", "HR", 70000, "active"),
(8, "Hank", "Engineering", 120000, "active"),
]
schema = ["emp_id", "name", "department", "salary", "status"]
df = spark.createDataFrame(data, schema)
# BAD: Aggregate every row first, then filter only the aggregated result
df_bad = (
df.groupBy("department")
.sum("salary")
.filter(col("sum(salary)") > 200000)
)
# GOOD: Filter early before aggregation
df_good = (
df.filter(
(col("status") == "active") &
(col("salary") > 70000)
)
.groupBy("department")
.sum("salary")
)
df_good.show()
Verifying Optimization Using explain()
df_good.explain()
Common Filtering Mistakes
- Filtering after a join instead of before it.
- Calling collect() to bring data into Python and then filtering it with Python loops.
- Filtering on computed columns when the filter could be applied to the original source columns first.
Technique 3: Select Only Required Columns
Reading unnecessary columns wastes I/O time and memory. Many developers write select("*") out of habit, but on wide datasets it slows Spark jobs down noticeably.
The Problem with Wide DataFrames
A wide DataFrame has many columns, sometimes hundreds in real data warehouse environments. If your analysis needs only 5 of 200 columns, loading all 200 is wasted work.
Why select("*") Hurts Performance
select("*") tells Spark to carry every column through each stage of your job. When you select specific columns instead, columnar formats such as Parquet let Spark skip the other columns entirely.
Column Pruning in Spark
Column pruning is the process of eliminating unused columns from the query plan. Spark's Catalyst optimizer performs column pruning automatically when you use explicit select() statements. With columnar sources, the pruned columns are never read from the source at all.
PySpark Code Example
from pyspark.sql import SparkSession
spark = SparkSession.builder.appName("ColumnPruningDemo").getOrCreate()
# Wide dummy dataset
data = [
("E001", "Alice", 30, "F", "Engineering", 95000, "New York", "[email protected]", "2018-05-10", "active"),
("E002", "Bob", 35, "M", "Marketing", 72000, "Chicago", "[email protected]", "2019-03-15", "inactive"),
("E003", "Charlie", 28, "M", "Engineering", 110000, "San Francisco", "[email protected]", "2020-01-20", "active"),
("E004", "Diana", 42, "F", "HR", 65000, "Austin", "[email protected]", "2015-07-08", "active"),
]
schema = [
"emp_id",
"name",
"age",
"gender",
"department",
"salary",
"city",
"email",
"join_date",
"status"
]
df = spark.createDataFrame(data, schema)
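# BAD: Carry every column through the job
df_bad = df.select("*").filter(df["status"] == "active")
# GOOD: Filter first, then keep only the columns you need
# (filtering before select avoids referencing a column that was already dropped)
df_good = (
    df.filter(df["status"] == "active")
    .select("emp_id", "name", "department", "salary")
)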
df_good.show()
How Catalyst Optimizer Helps
Catalyst prunes unneeded columns automatically when it builds the physical plan. Even in complex queries, it traces which columns are actually used downstream and drops the rest. Explicit select() statements make your intent clear and help Catalyst do this precisely.
Technique 4: Optimize Partitioning
Partitioning is one of the most impactful areas of PySpark performance. Getting your partition strategy wrong can make even simple jobs run slowly.
Understanding Spark Partitions
A partition is a chunk of a DataFrame that is processed by a single task on one executor. Spark processes partitions in parallel. More partitions mean more parallelism, but too many tiny partitions add scheduling overhead. Too few, overly large partitions leave your cluster underused.
Default Partitioning Behavior
When reading files, Spark derives the number of partitions from the input splits, which is roughly one partition per file block on HDFS and S3. For shuffle operations such as groupBy and join, Spark creates 200 partitions by default, controlled by spark.sql.shuffle.partitions.
For small datasets, 200 shuffle partitions is too many and produces a swarm of tiny tasks. For very large datasets, 200 may be too few.
How Partitions Affect Parallelism
Spark runs one task per partition, and each task uses one core. If your cluster has 20 cores and your data has 200 partitions, Spark runs 20 tasks at a time and needs 10 waves to finish them all. If you have only 10 partitions, only 10 cores do any work and the other 10 sit idle.
A common rule of thumb is 2 to 4 partitions for each CPU core in your cluster.
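The arithmetic behind these numbers is easy to check. The helper functions below are hypothetical names written for this illustration, and the 2x to 4x multipliers are the rule of thumb from the text rather than a Spark setting.

```python
import math


def task_waves(num_partitions, total_cores):
    """Number of sequential waves needed when each core runs one task at a time."""
    return math.ceil(num_partitions / total_cores)


def busy_cores(num_partitions, total_cores):
    """Cores that have work to do in the first wave."""
    return min(num_partitions, total_cores)


def suggested_partition_range(total_cores, low=2, high=4):
    """Rule-of-thumb partition range: low to high partitions per core."""
    return total_cores * low, total_cores * high


waves = task_waves(200, 20)                      # 200 partitions on 20 cores
idle = 20 - busy_cores(10, 20)                   # 10 partitions on 20 cores
partition_range = suggested_partition_range(20)  # target range for 20 cores
```

For a 20-core cluster this gives 10 waves for 200 partitions, 10 idle cores for 10 partitions, and a suggested range of 40 to 80 partitions.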
repartition() vs coalesce()
Both methods change the partition count, but they work differently.
- repartition(n): Redistributes data with a full shuffle across the network. Use it when you want more partitions or need evenly sized partitions. It is expensive because it moves data over the network.
- coalesce(n): Reduces the partition count without a full shuffle by merging existing partitions. Use it to decrease partitions (for example, before writing output). It is cheaper than repartition(), but the resulting partitions may be uneven in size.
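The difference in resulting partition sizes can be illustrated with a simplified plain-Python model. This is only a sketch: `model_repartition` and `model_coalesce` are hypothetical helpers, and Spark's real coalesce also takes data locality into account when choosing which partitions to merge.

```python
def model_repartition(partitions, n):
    """Model of a full shuffle: every row may move, dealt out round-robin."""
    rows = [row for part in partitions for row in part]
    return [rows[i::n] for i in range(n)]


def model_coalesce(partitions, n):
    """Model of coalesce: whole existing partitions are merged into n groups."""
    merged = [[] for _ in range(n)]
    for index, part in enumerate(partitions):
        # Neighbouring partitions land in the same group; rows are never split up.
        merged[index * n // len(partitions)].extend(part)
    return merged


# Four uneven partitions holding 8, 1, 1, and 2 rows.
partitions = [list(range(1, 9)), [9], [10], [11, 12]]

repartitioned_sizes = [len(p) for p in model_repartition(partitions, 2)]
coalesced_sizes = [len(p) for p in model_coalesce(partitions, 2)]
```

In this model, repartitioning to 2 yields two partitions of 6 rows each, while coalescing to 2 yields partitions of 9 and 3 rows. That is the trade-off: coalesce moves less data but inherits the unevenness of its inputs.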
PySpark Code Example
from pyspark.sql import SparkSession
spark = (
SparkSession.builder
.appName("PartitionDemo")
.config("spark.sql.shuffle.partitions", "10")
.getOrCreate()
)
# Create dummy transaction data
data = [
(
i,
f"TXN{i:05d}",
float(i * 15.5),
"completed" if i % 3 != 0 else "failed"
)
for i in range(1, 101)
]
schema = ["txn_id", "txn_ref", "amount", "status"]
df = spark.createDataFrame(data, schema)
print(f"Initial partitions: {df.rdd.getNumPartitions()}")
# Increase partitions for parallel processing
df_repartitioned = df.repartition(20)
print(
f"After repartition(20): "
f"{df_repartitioned.rdd.getNumPartitions()}"
)
# Reduce partitions before writing output
df_coalesced = df_repartitioned.coalesce(4)
print(
f"After coalesce(4): "
f"{df_coalesced.rdd.getNumPartitions()}"
)
# Repartition by a column for join optimization
df_by_status = df.repartition(10, "status")
df_by_status.groupBy("status").count().show()
Technique 5: Use Broadcast Joins for Small Tables
Joins are among the most expensive operations in Spark because they usually have to move data across the network. When one of the tables is small, a broadcast join removes that data movement.
Why Spark Joins Are Expensive
A standard Spark join requires rows with matching keys from both DataFrames to sit on the same executor. Spark achieves this by shuffling: it moves rows between machines over the network until matching keys are co-located. That network transfer is slow and costly.
What Is a Broadcast Join?
In a broadcast join, Spark sends a full copy of the small table to every executor. Each executor then joins its local partitions of the large table against that copy, so the large table never has to be shuffled. This can cut execution time substantially.
When to Use Broadcast Joins
Use a broadcast join when one table is small enough to fit entirely in the memory of each executor. Spark automatically broadcasts tables smaller than spark.sql.autoBroadcastJoinThreshold (default 10 MB). You can manually broadcast larger tables if your executors have enough memory.
PySpark Code Example
from pyspark.sql import SparkSession
from pyspark.sql.functions import broadcast
spark = (
SparkSession.builder
.appName("BroadcastJoinDemo")
.getOrCreate()
)
# Large fact table — orders
orders_data = [
(1001, "C01", "P001", 2, 2401.00),
(1002, "C02", "P003", 1, 350.00),
(1003, "C01", "P002", 3, 2400.00),
(1004, "C03", "P001", 1, 1200.50),
(1005, "C02", "P005", 2, 901.50),
(1006, "C04", "P006", 5, 400.00),
(1007, "C03", "P004", 2, 400.00),
(1008, "C01", "P007", 1, 60.00),
]
orders = spark.createDataFrame(
orders_data,
["order_id", "customer_id", "product_id", "qty", "total_amount"]
)
# Small dimension table — product categories
# (candidate for broadcast)
product_data = [
("P001", "Laptop", "Electronics"),
("P002", "Phone", "Electronics"),
("P003", "Desk", "Furniture"),
("P004", "Chair", "Furniture"),
("P005", "Monitor", "Electronics"),
("P006", "Keyboard", "Electronics"),
("P007", "Lamp", "Furniture"),
]
products = spark.createDataFrame(
product_data,
["product_id", "product_name", "category"]
)
# BAD: Standard join (triggers shuffle)
df_standard = orders.join(
products,
on="product_id",
how="inner"
)
# GOOD: Broadcast join
# (no shuffle for small table)
df_broadcast = orders.join(
broadcast(products),
on="product_id",
how="inner"
)
df_broadcast.select(
"order_id",
"product_name",
"category",
"total_amount"
).show()
How Broadcast Joins Reduce Shuffle
When Spark sees broadcast(products), it ships the entire products table to every executor up front, and each executor keeps it in memory. Each executor then joins its own partitions of orders against that local copy, with no network transfer of the orders rows. The join finishes much faster than a shuffle-based join.
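The mechanics can be modelled in plain Python as a hash join against a shared lookup table. This is a minimal sketch, not Spark's implementation: `broadcast_hash_join` is a hypothetical helper, and it assumes the join key is unique in the small table.

```python
def broadcast_hash_join(large_partitions, small_table, key):
    """Inner-join each partition of the large table against one shared lookup."""
    # Build the lookup once. In Spark, this is the copy shipped to every executor.
    lookup = {row[key]: row for row in small_table}
    joined_partitions = []
    for partition in large_partitions:
        # Each partition is joined locally; no rows move between partitions.
        joined = [
            {**row, **lookup[row[key]]}
            for row in partition
            if row[key] in lookup
        ]
        joined_partitions.append(joined)
    return joined_partitions


orders_partitions = [
    [{"order_id": 1001, "product_id": "P001"}, {"order_id": 1002, "product_id": "P003"}],
    [{"order_id": 1003, "product_id": "P002"}, {"order_id": 1004, "product_id": "P999"}],
]
products = [
    {"product_id": "P001", "category": "Electronics"},
    {"product_id": "P002", "category": "Electronics"},
    {"product_id": "P003", "category": "Furniture"},
]

joined = broadcast_hash_join(orders_partitions, products, "product_id")
rows_per_partition = [len(partition) for partition in joined]
```

The orders rows stay in their original partitions throughout. Order 1004 is dropped because its product has no match, as expected for an inner join.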
Technique 6: Enable Adaptive Query Execution (AQE)
Adaptive Query Execution (AQE), introduced in Spark 3.0, is one of the most significant performance features added to Spark in recent major releases. It lets Spark re-optimize a query during execution using statistics collected at runtime.
What Is AQE in Spark?
Before AQE, Spark built a complete execution plan up front and followed it to the end, regardless of what the data actually looked like. AQE changes that: after each shuffle stage finishes, Spark inspects the real statistics and can re-plan the stages that remain.
Runtime Query Optimization with AQE
AQE provides three main optimizations once it is enabled.
- Dynamic Join Strategy Selection: AQE can switch a sort-merge join to a broadcast join at runtime. If one side of the join turns out to be smaller than estimated once its actual size is known, Spark broadcasts that side. This catches cases where the static estimate, which is based on file sizes, put the table above the broadcast threshold.
- Skew Join Optimization: Data skew means some partitions hold far more data than others, so one or two slow tasks hold up the whole job. AQE detects skewed partitions at runtime and splits them into smaller pieces so the work is spread more evenly.
- Post-Shuffle Partition Coalescing: After a shuffle, AQE can merge many small shuffle partitions into fewer, larger ones. This avoids launching lots of tiny tasks that each do very little work.
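Post-shuffle coalescing can be pictured as greedily merging neighbouring partitions up to a target size. The sketch below is a simplified plain-Python model, not Spark's actual algorithm: `coalesce_shuffle_partitions` is a hypothetical helper, and `target_mb` is an assumed parameter that plays a role similar to Spark's advisory partition size setting (spark.sql.adaptive.advisoryPartitionSizeInBytes).

```python
def coalesce_shuffle_partitions(sizes_mb, target_mb):
    """Group neighbouring partition indexes so each group stays near target_mb."""
    groups, current, current_size = [], [], 0
    for index, size in enumerate(sizes_mb):
        # Close the current group if adding this partition would exceed the target.
        if current and current_size + size > target_mb:
            groups.append(current)
            current, current_size = [], 0
        current.append(index)
        current_size += size
    if current:
        groups.append(current)
    return groups


# Eight shuffle partitions of very different sizes (in MB).
shuffle_sizes_mb = [5, 10, 3, 60, 2, 4, 70, 1]
groups = coalesce_shuffle_partitions(shuffle_sizes_mb, target_mb=64)
```

Here eight shuffle partitions collapse into five groups, so five tasks run instead of eight. A partition that is already over the target, such as the 70 MB one, stays on its own.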
PySpark Code Example
from pyspark.sql import SparkSession
spark = (
SparkSession.builder
.appName("AQEDemo")
.config("spark.sql.adaptive.enabled", "true")
.config("spark.sql.adaptive.coalescePartitions.enabled", "true")
.config("spark.sql.adaptive.skewJoin.enabled", "true")
.config("spark.sql.adaptive.localShuffleReader.enabled", "true")
.getOrCreate()
)
# Dummy sales transactions
sales_data = [
(
i,
f"CUST_{i % 50:03d}",
f"PROD_{i % 20:03d}",
float(i * 10.5)
)
for i in range(1, 201)
]
sales = spark.createDataFrame(
sales_data,
["sale_id", "customer_id", "product_id", "revenue"]
)
# Dummy product catalog
catalog_data = [
(
f"PROD_{i:03d}",
f"Product {i}",
"Category A" if i % 2 == 0 else "Category B"
)
for i in range(20)
]
catalog = spark.createDataFrame(
catalog_data,
["product_id", "product_name", "category"]
)
# AQE will optimize this join dynamically at runtime
result = (
sales.join(catalog, on="product_id")
.groupBy("category")
.agg({"revenue": "sum"})
)
result.show()
AQE is a low-effort win. Enable it for all Spark 3.x workloads unless you have a specific reason not to. From Spark 3.2 onward it is enabled by default.
Technique 7: Avoid Python UDFs Whenever Possible
Python User Defined Functions (UDFs) are one of the most common causes of performance problems in PySpark. They are convenient for Python developers, but they can slow a job down significantly.
Why Python UDFs Slow Down Spark
Spark runs on the Java Virtual Machine (JVM), while Python runs outside it. When you use a Python UDF, Spark has to ship data from the JVM to a Python process, run your function, and send the results back to the JVM. A standard Python UDF is invoked one row at a time.
Serialization Overhead
Every row must be converted from Spark's internal binary format into Python objects before your function runs, and the results must be converted back afterward. Across millions of rows, this serialization and deserialization is costly.
JVM-to-Python Communication Cost
Spark launches separate Python worker processes alongside each executor. The JVM and the Python processes exchange data through a socket. At scale, this communication becomes a bottleneck, and Python UDFs can run as much as 10 times slower than equivalent native Spark functions.
Prefer Native Spark Functions
Functions from pyspark.sql.functions run entirely inside the JVM, so no data has to be converted for Python. They are compiled and optimized by Spark, and they outperform custom Python UDFs.
PySpark Code Example
from pyspark.sql import SparkSession
from pyspark.sql.functions import (
col,
when,
regexp_replace,
udf,
initcap
)
from pyspark.sql.types import StringType
spark = (
SparkSession.builder
.appName("UDFDemo")
.getOrCreate()
)
data = [
("alice smith", 85000, "engineering"),
("bob jones", 72000, "marketing"),
("charlie brown", 110000, "engineering"),
("diana prince", 65000, "hr"),
("eve white", 92000, "engineering"),
]
df = spark.createDataFrame(
data,
["name", "salary", "department"]
)
# BAD: Python UDF — slow due to serialization
def format_name_udf(name):
    return name.title().replace(" ", "_")
format_udf = udf(format_name_udf, StringType())
df_udf = df.withColumn(
"formatted_name",
format_udf(col("name"))
)
# GOOD: Native Spark functions
# — fast, no serialization
df_native = (
df.withColumn(
"formatted_name",
regexp_replace(
initcap(col("name")),
" ",
"_"
)
)
.withColumn(
"salary_band",
when(col("salary") >= 100000, "Senior")
.when(col("salary") >= 80000, "Mid")
.otherwise("Junior")
)
)
df_native.show()
Technique 8: Cache Data Strategically
Spark recomputes your DataFrame from scratch every time you call an action on it. So if you call count() and later show() on the same DataFrame, Spark runs the whole pipeline twice. Caching helps, but only when you use it deliberately, not just because it exists.
Understanding Spark Caching
Caching means that once the DataFrame has been computed the first time, Spark stores the result in memory (or on disk). For the next action, Spark reads those stored rows and skips recomputing them from the original sources.
When to Use cache()
You should cache a DataFrame when the following are true:
- You reuse the same DataFrame more than once in your workflow.
- The DataFrame is costly to build (think multiple joins, heavy aggregations, or lots of file reads).
- It can comfortably fit inside the memory available on the executors.
When Caching Can Hurt Performance
If you cache a DataFrame that you use only once, you pay the overhead for nothing. Caching huge DataFrames that do not fit in memory can also cause spilling to disk, which may be slower than simply recomputing. It is worth measuring whether caching actually helps in your scenario.
cache() vs persist()
cache() is shorthand for persist() with the default storage level. For DataFrames, that default keeps data in memory and falls back to disk when it does not fit. persist() lets you choose the storage level explicitly, such as memory only, memory and disk, disk only, or serialized in memory. When you need more control over storage behavior, persist() is usually the better choice.
PySpark Code Example
from pyspark.sql import SparkSession
from pyspark.sql.functions import (
col,
sum as spark_sum,
avg
)
spark = (
SparkSession.builder
.appName("CachingDemo")
.getOrCreate()
)
# Dummy retail data
data = [
("2024-01", "Electronics", "Laptop", 1200.00, 30),
("2024-01", "Furniture", "Chair", 200.00, 50),
("2024-02", "Electronics", "Phone", 800.00, 75),
("2024-02", "Electronics", "Monitor", 450.00, 40),
("2024-03", "Furniture", "Desk", 350.00, 20),
("2024-03", "Electronics", "Tablet", 600.00, 25),
("2024-04", "Furniture", "Lamp", 60.00, 60),
("2024-04", "Electronics", "Keyboard", 80.00, 100),
]
schema = [
"month",
"category",
"product",
"price",
"units"
]
df = spark.createDataFrame(data, schema)
# Compute revenue once
df_revenue = df.withColumn(
"revenue",
col("price") * col("units")
)
# Cache because we use df_revenue multiple times
df_revenue.cache()
# Action 1: Revenue by category
print("Revenue by Category:")
df_revenue.groupBy("category").agg(
spark_sum("revenue").alias("total_revenue")
).show()
# Action 2: Revenue by month
print("Revenue by Month:")
df_revenue.groupBy("month").agg(
spark_sum("revenue").alias("monthly_revenue")
).show()
# Action 3: Average unit price
print("Average Price per Category:")
df_revenue.groupBy("category").agg(
avg("price").alias("avg_price")
).show()
# Always unpersist when done
df_revenue.unpersist()
Removing Cached DataFrames
You need to use unpersist() after you finish working with a cached DataFrame. Cached DataFrames maintain their memory usage until either the Spark session terminates or you choose to free them. Excessive caching of DataFrames will lead to memory pressure which results in spilling.
Technique 9: Handle Data Skew Efficiently
Skewed data distribution is one of the hardest performance problems in Spark. It is easy to miss: a handful of tasks run far longer than the rest, and the job cannot finish until those stragglers do.
What Is Data Skew?
Data skew occurs when some partitions contain far more data than others. For example, in a customer orders dataset, one major customer might have 10 million orders while every other customer averages 1,000. Grouping by customer ID then puts all of that customer's rows into a single, oversized partition.
Symptoms of Skewed Spark Jobs
The classic symptom is a job that reaches 95% completion and then stalls on the last few tasks. Most tasks finish quickly, while a small number carrying most of the data hold up the entire job.
Detecting Skew Using Spark UI
Open the Spark UI and go to the Stages tab. Select a slow stage to see its task metrics. If a few tasks show much higher "Input Size", "Shuffle Read", or "Duration" values than the median, the data is skewed.
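If you copy per-task metrics out of the Spark UI, the comparison against the median can be expressed as a small check. This is a sketch under stated assumptions: `skew_ratio` and `looks_skewed` are hypothetical helpers, and the threshold of 5.0 is an arbitrary illustrative value, not a Spark default.

```python
from statistics import median


def skew_ratio(task_values):
    """How many times larger the slowest or largest task is than the median task."""
    return max(task_values) / median(task_values)


def looks_skewed(task_values, threshold=5.0):
    """Flag a stage when its worst task is far above the median (assumed threshold)."""
    return skew_ratio(task_values) >= threshold


# Task durations in seconds, as they might appear for two different stages.
skewed_stage = [12, 11, 13, 12, 14, 240]
balanced_stage = [12, 11, 13, 12, 14, 15]

skewed_ratio = skew_ratio(skewed_stage)      # 240 / 12.5
balanced_ratio = skew_ratio(balanced_stage)  # 15 / 12.5
```

The same check works for input size or shuffle read bytes. Pick the threshold to suit your own workloads.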
Techniques to Fix Data Skew
- Salting: Add a random salt in the range 0 to N-1 to the skewed key. This splits the heavy key across up to N smaller groups. After the partial aggregation, drop the salt and combine the partial results.
- AQE Skew Join: When spark.sql.adaptive.skewJoin.enabled is set, Spark detects and splits skewed join partitions automatically.
- Broadcast join: If the smaller side of the join falls below the broadcast threshold, broadcast it. The join then runs without a shuffle, so skewed shuffle partitions never form.
- Repartitioning: Manually repartition on a column, or combination of columns, that distributes the data more evenly.
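Salting is only safe for aggregations that can be computed in two stages, such as sums and counts. The plain-Python sketch below checks that property on made-up data: the dataset, the seed, and `NUM_SALTS` are all assumptions chosen for illustration.

```python
import random
from collections import Counter, defaultdict

random.seed(42)  # fixed seed so the example is repeatable
NUM_SALTS = 5

# Skewed data: customer C001 has 800 of the 1,000 orders.
orders = [("C001", 10.0)] * 800 + [(f"C{i % 10 + 2:03d}", 5.0) for i in range(200)]

# Without salting, every C001 row lands in the same group.
unsalted_group_sizes = Counter(customer for customer, _ in orders)

# Stage 1: partial sums per (customer, salt) pair.
partial_sums = defaultdict(float)
salted_group_sizes = Counter()
for customer, amount in orders:
    salt = random.randrange(NUM_SALTS)
    partial_sums[(customer, salt)] += amount
    salted_group_sizes[(customer, salt)] += 1

# Stage 2: drop the salt and combine the partial sums.
totals = defaultdict(float)
for (customer, _salt), subtotal in partial_sums.items():
    totals[customer] += subtotal

largest_unsalted_group = max(unsalted_group_sizes.values())
largest_salted_group = max(salted_group_sizes.values())
```

The final totals match a direct sum per customer, while the largest group processed in stage 1 is much smaller than the original 800-row group.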
PySpark Code Example
from pyspark.sql import SparkSession
from pyspark.sql.functions import (
col,
rand,
floor,
concat,
lit,
sum as spark_sum
)
spark = (
SparkSession.builder
.appName("SkewDemo")
.config("spark.sql.adaptive.skewJoin.enabled", "true")
.getOrCreate()
)
# Skewed data:
# customer C001 has 80% of all orders
orders_data = (
[
(i, "C001", float(i * 12.5))
for i in range(1, 801)
] +
[
(
i + 800,
f"C{str(i % 10 + 2).zfill(3)}",
float(i * 9.9)
)
for i in range(1, 201)
]
)
orders = spark.createDataFrame(
orders_data,
["order_id", "customer_id", "amount"]
)
# Salting technique to fix skew manually
num_salts = 5
# Add salt to orders
orders_salted = orders.withColumn(
"salted_key",
concat(
col("customer_id"),
lit("_"),
(floor(rand() * num_salts)).cast("string")
)
)
# Aggregate with salted key
agg_salted = (
orders_salted
.groupBy("salted_key", "customer_id")
.agg(
spark_sum("amount").alias("partial_sum")
)
)
# Final aggregation
# remove salt and sum partial results
result = (
agg_salted
.groupBy("customer_id")
.agg(
spark_sum("partial_sum").alias("total_amount")
)
)
result.orderBy(
"total_amount",
ascending=False
).show(5)
Real-World Skew Optimization Example
In real pipelines, skew typically shows up when you join on keys such as highly active user IDs, top-selling product IDs, or optional foreign keys that are mostly null or set to a default value. Always check your join key distributions before writing your pipeline. A quick check is df.groupBy("join_key").count().orderBy("count", ascending=False).show(10), which lists the ten most frequent keys.
Technique 10: Minimize Shuffle Operations
Shuffles are usually the most expensive operation in a Spark job because they move data across the network between executors. Reducing the number and size of shuffles is often the most effective optimization you can make.
Why Shuffles Are Expensive
During a shuffle, Spark serializes every row, writes it to local disk, sends it over the network to the executor that owns its target partition, and deserializes it on the other side. A shuffle therefore stresses disk I/O, network I/O, and CPU all at once. On large datasets, a single shuffle can take anywhere from several minutes to multiple hours.
Operations That Trigger Shuffles
The following common operations in Spark create shuffles:
- groupBy(): Groups rows by key. All rows that share a key must end up in the same partition, which requires moving them across the network.
- join(): Matches rows from two DataFrames on a key. Unless one side is broadcast or the data is already partitioned on the join key, Spark shuffles one or both sides so that matching keys are co-located.
- distinct(): Removes duplicate rows across the whole dataset. Identical rows must be brought into the same partition before the duplicates can be dropped.
- orderBy(): Sorts the data across all partitions. A global sort always requires a shuffle.
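What these operations have in common is that rows with the same key start out scattered across partitions and must be brought together. The following plain-Python sketch mimics that redistribution. It is an illustration only: `zlib.crc32` stands in for Spark's internal hash function, and `shuffle_by_key` is a hypothetical helper.

```python
import zlib
from collections import defaultdict


def target_partition(key, num_partitions):
    """Pick a partition from the key's hash (crc32 is a stand-in for Spark's hash)."""
    return zlib.crc32(str(key).encode("utf-8")) % num_partitions


def shuffle_by_key(partitions, num_partitions):
    """Route every row to the partition chosen by its key; count rows that move."""
    shuffled = defaultdict(list)
    moved = 0
    for source, rows in enumerate(partitions):
        for key, value in rows:
            dest = target_partition(key, num_partitions)
            shuffled[dest].append((key, value))
            if dest != source:
                moved += 1  # this row leaves the partition it started in
    return [shuffled[i] for i in range(num_partitions)], moved


# Rows for the same region start out scattered over three input partitions
before = [
    [("North", 36000.0), ("South", 60000.0)],
    [("North", 10000.0), ("East", 18000.0)],
    [("South", 8000.0), ("East", 3600.0), ("North", 7000.0)],
]
after, moved = shuffle_by_key(before, num_partitions=3)

# After the shuffle, each key lives in exactly one partition
for index, rows in enumerate(after):
    print(index, sorted({key for key, _ in rows}))
print("rows moved:", moved)
```

Every row counted in `moved` corresponds to data that a real cluster would serialize, write out, and transfer, which is why fewer and smaller shuffles pay off.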
PySpark Code Example
from pyspark.sql import SparkSession
from pyspark.sql.functions import (
    col,
    sum as spark_sum
)
spark = (
    SparkSession.builder
    .appName("ShuffleDemo")
    .config("spark.sql.shuffle.partitions", "8")
    .getOrCreate()
)
data = [
    ("2024-Q1", "North", "Electronics", "Laptop", 1200.00, 30),
    ("2024-Q1", "South", "Electronics", "Phone", 800.00, 75),
    ("2024-Q2", "North", "Furniture", "Chair", 200.00, 50),
    ("2024-Q2", "East", "Electronics", "Monitor", 450.00, 40),
    ("2024-Q3", "West", "Electronics", "Tablet", 600.00, 25),
    ("2024-Q3", "North", "Furniture", "Desk", 350.00, 20),
    ("2024-Q4", "South", "Electronics", "Keyboard", 80.00, 100),
    ("2024-Q4", "East", "Furniture", "Lamp", 60.00, 60),
]
schema = [
    "quarter",
    "region",
    "category",
    "product",
    "price",
    "units"
]
df = spark.createDataFrame(data, schema)
df = df.withColumn(
    "revenue",
    col("price") * col("units")
)
# BAD:
# Two separate groupBy operations over the full dataset
# (each one shuffles every row)
df_q1 = df.groupBy("category").agg(
    spark_sum("revenue").alias("cat_revenue")
)
df_q2 = df.groupBy("region").agg(
    spark_sum("revenue").alias("reg_revenue")
)
# GOOD:
# Aggregate once at the finest grain needed
# (one shuffle of the full dataset).
# Note: this result is grouped by (category, region),
# so it is not identical to the two queries above.
df_combined = (
    df.groupBy("category", "region")
    .agg(
        spark_sum("revenue").alias("total_revenue"),
        spark_sum("units").alias("total_units")
    )
)
# Per-category totals can then be rolled up
# from the much smaller combined result
cat_revenue = (
    df_combined.groupBy("category")
    .agg(spark_sum("total_revenue").alias("cat_revenue"))
)
df_combined.show()
Monitoring Shuffle Metrics in Spark UI
The Stages tab in the Spark UI displays Shuffle Read and Shuffle Write metrics for every stage. Stages with large shuffle sizes are the ones worth optimizing, for example by pre-partitioning the data or by reducing the rows and columns that reach the shuffle. The SQL tab shows the shuffle exchange nodes in your query plan.
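One way to act on the Shuffle Write figure is to size spark.sql.shuffle.partitions from it. The helper below is a rule-of-thumb sketch: the 128 MB target per partition is an assumption that should be adjusted to your executors, and `suggest_shuffle_partitions` is a hypothetical name, not part of Spark.

```python
import math


def suggest_shuffle_partitions(shuffle_bytes, target_partition_mb=128, minimum=1):
    """Estimate a partition count so each shuffle partition is near the target size."""
    target_bytes = target_partition_mb * 1024 * 1024
    return max(minimum, math.ceil(shuffle_bytes / target_bytes))


# A stage that reports 50 GiB of Shuffle Write in the Spark UI
print(suggest_shuffle_partitions(50 * 1024**3))  # 400
```

With AQE partition coalescing enabled, Spark can merge small shuffle partitions at runtime, so an estimate like this matters most as a starting value.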
Technique 11: Use Bucketing for Repeated Joins
When a pipeline joins the same large tables again and again, every join pays the shuffle cost anew. Bucketing removes that repeated overhead by organizing the data on disk ahead of time.
What Is Bucketing?
Bucketing is a technique where Spark writes data to disk pre-partitioned (and optionally pre-sorted) by a join key. At join time, Spark reuses that layout instead of shuffling the data again, so the join can run with no shuffle at all.
How Bucketing Improves Join Performance
When you bucket two tables on the same key with the same number of buckets, matching rows go into matching bucket files. When Spark reads these tables for a join, it can pair up corresponding bucket files directly, without any network transfer. The shuffle cost of the join drops to zero.
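The idea can be illustrated in plain Python. In this sketch, both tables are split into the same number of buckets by hashing the join key, and the join then only ever compares bucket i of one table with bucket i of the other. `zlib.crc32` stands in for Spark's bucket hash, and the helper names are hypothetical.

```python
import zlib
from collections import defaultdict


def bucket_id(key, num_buckets):
    """Map a join key to a bucket (crc32 is a stand-in for Spark's bucket hash)."""
    return zlib.crc32(key.encode("utf-8")) % num_buckets


def write_bucketed(rows, key_index, num_buckets):
    """Group rows into buckets by the hash of their join key."""
    buckets = defaultdict(list)
    for row in rows:
        buckets[bucket_id(row[key_index], num_buckets)].append(row)
    return buckets


def bucketed_join(orders, customers, num_buckets):
    """Join bucket i of orders with bucket i of customers only."""
    joined = []
    for b in range(num_buckets):
        names = {cust_id: name for cust_id, name in customers.get(b, [])}
        for order_id, cust_id, amount in orders.get(b, []):
            if cust_id in names:
                joined.append((order_id, cust_id, amount, names[cust_id]))
    return joined


orders = [(i, f"CUST_{i % 4:03d}", i * 25.0) for i in range(1, 9)]
customers = [(f"CUST_{i:03d}", f"Customer {i}") for i in range(4)]

orders_b = write_bucketed(orders, key_index=1, num_buckets=3)
customers_b = write_bucketed(customers, key_index=0, num_buckets=3)
result = bucketed_join(orders_b, customers_b, num_buckets=3)
print(len(result))  # 8, every order finds its customer inside its own bucket
```

Because both sides use the same hash and the same bucket count, no row ever has to be looked up outside its own bucket. That is the property that lets Spark skip the shuffle.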
PySpark Code Example
from pyspark.sql import SparkSession
spark = (
    SparkSession.builder
    .appName("BucketingDemo")
    .config(
        "spark.sql.sources.bucketing.enabled",
        "true"
    )
    .enableHiveSupport()
    .getOrCreate()
)
# Large orders table
orders_data = [
    (
        i,
        f"CUST_{i % 100:03d}",
        float(i * 25.0),
        "completed"
    )
    for i in range(1, 501)
]
orders = spark.createDataFrame(
    orders_data,
    ["order_id", "customer_id", "amount", "status"]
)
# Customer info table
customers_data = [
    (
        f"CUST_{i:03d}",
        f"Customer {i}",
        f"Region_{i % 5}"
    )
    for i in range(100)
]
customers = spark.createDataFrame(
    customers_data,
    ["customer_id", "customer_name", "region"]
)
# Write both tables bucketed on customer_id
# with the same number of buckets
orders.write \
    .bucketBy(10, "customer_id") \
    .sortBy("customer_id") \
    .mode("overwrite") \
    .saveAsTable("orders_bucketed")
customers.write \
    .bucketBy(10, "customer_id") \
    .sortBy("customer_id") \
    .mode("overwrite") \
    .saveAsTable("customers_bucketed")
# The join itself can now run without a shuffle:
# Spark pairs up matching bucket files directly.
# (The groupBy on region afterwards still shuffles, and with
# tables this small Spark may choose a broadcast join instead.)
result = (
    spark.table("orders_bucketed")
    .join(
        spark.table("customers_bucketed"),
        on="customer_id"
    )
    .groupBy("region")
    .agg({"amount": "sum"})
)
result.show()
Best Use Cases for Bucketing
- Pipelines that repeatedly join against the same large dimension tables.
- Data warehouse workloads built around fact-to-dimension joins.
- Any two large tables that share a key and are joined many times throughout the day.
- Sort-merge joins whose shuffle step you want to eliminate by reading pre-bucketed data.
Technique 12: Tune Spark Configuration Settings
Even after every code-level improvement is in place, the right Spark configuration can still deliver substantial gains. Misconfigured executors either waste resources or fail with memory errors, and both degrade job performance.
Important Spark Configurations for Performance
Spark provides more than 100 configuration settings. The following settings deliver the strongest impact for general-purpose performance improvements.
- Executor Memory: spark.executor.memory sets the heap size of each executor, which is shared between computation and cached data. If the value is too low, Spark spills data to disk. If it is too high, you waste memory that could have supported additional executors.
- Executor Cores: spark.executor.cores determines how many tasks each executor can run at the same time. A commonly recommended range is 2 to 5 cores per executor. With more cores, many concurrent tasks share one JVM heap, which increases garbage collection pressure.
- Driver Memory: spark.driver.memory sets the heap size of the driver. Increase it when the job collects large results, uses many broadcast variables, or involves complex query planning.
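The arithmetic behind these settings can be sketched as follows. The sizing heuristic (reserve one core and 1 GB per node for the operating system, then divide the rest) is a common rule of thumb and an assumption here, and the helper names are hypothetical. The 10% memory overhead with a 384 MiB floor and the 300 MiB of reserved heap follow Spark's documented defaults, but check them against your Spark version.

```python
import math


def executor_layout(node_cores, node_memory_gb, cores_per_executor=4,
                    reserved_cores=1, reserved_memory_gb=1,
                    overhead_fraction=0.10, min_overhead_gb=0.384):
    """Rule-of-thumb split of one worker node into executors."""
    executors = (node_cores - reserved_cores) // cores_per_executor
    if executors < 1:
        raise ValueError("not enough cores for a single executor")
    container_gb = (node_memory_gb - reserved_memory_gb) / executors
    # The container holds the heap plus off-heap overhead:
    # overhead = max(min_overhead_gb, heap * overhead_fraction)
    heap_gb = min(container_gb / (1 + overhead_fraction),
                  container_gb - min_overhead_gb)
    return {"executors_per_node": executors,
            "executor_memory_gb": math.floor(heap_gb)}


def unified_memory_mb(heap_mb, memory_fraction=0.6, storage_fraction=0.5):
    """Size of the execution+storage region and its eviction-protected storage part."""
    unified = (heap_mb - 300) * memory_fraction  # 300 MiB is reserved by Spark
    return unified, unified * storage_fraction


print(executor_layout(node_cores=16, node_memory_gb=64))
# {'executors_per_node': 3, 'executor_memory_gb': 19}

# A 4g heap with spark.memory.fraction=0.8 and spark.memory.storageFraction=0.3
unified, storage = unified_memory_mb(4096, memory_fraction=0.8, storage_fraction=0.3)
print(round(unified), round(storage))  # 3037 911
```

The second call uses the same values as the configuration example in this section, which shows how much of a 4g heap is actually available to execution and caching.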
PySpark Configuration Example
from pyspark.sql import SparkSession
from pyspark.sql.functions import (
    col,
    sum as spark_sum,
    avg
)
# Note: spark.driver.memory only takes effect if it is set before
# the driver JVM starts. When launching with spark-submit in client
# mode, pass --driver-memory on the command line instead.
spark = (
    SparkSession.builder
    .appName("ConfigTuningDemo")
    .config("spark.executor.memory", "4g")
    .config("spark.executor.cores", "4")
    .config("spark.driver.memory", "2g")
    .config("spark.sql.shuffle.partitions", "50")
    .config("spark.sql.adaptive.enabled", "true")
    .config(
        "spark.sql.adaptive.coalescePartitions.enabled",
        "true"
    )
    .config("spark.memory.fraction", "0.8")
    .config("spark.memory.storageFraction", "0.3")
    .config(
        "spark.serializer",
        "org.apache.spark.serializer.KryoSerializer"
    )
    .getOrCreate()
)
# Dummy payroll dataset
payroll_data = [
    (
        f"EMP_{i:04d}",
        f"Dept_{i % 10}",
        float(50000 + (i % 50) * 1000),
        "FT" if i % 4 != 0 else "PT"
    )
    for i in range(1, 201)
]
df = spark.createDataFrame(
    payroll_data,
    [
        "emp_id",
        "department",
        "annual_salary",
        "employment_type"
    ]
)
result = (
    df.filter(col("employment_type") == "FT")
    .groupBy("department")
    .agg(
        spark_sum("annual_salary").alias("total_payroll"),
        avg("annual_salary").alias("avg_salary")
    )
    .orderBy("total_payroll", ascending=False)
)
result.show(5)
Cluster-Level vs Application-Level Tuning
- Cluster-level settings: Defaults in spark-defaults.conf apply to every Spark application on the cluster. Use them to establish a sensible baseline.
- Application-level settings: Values set in SparkSession.builder.config() override the cluster defaults for a specific job. Use them for job-specific adjustments.
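The override order can be pictured as layered dictionaries. This sketch is only a model of the precedence rules, not how Spark is implemented. `effective_conf` is a hypothetical helper, and the spark-submit layer is included because Spark ranks command-line flags between the two levels described above.

```python
from collections import ChainMap


def effective_conf(spark_defaults, submit_flags, app_conf):
    """Resolve settings with the highest-precedence source listed first."""
    # Lookup order: values set in application code, then spark-submit
    # --conf flags, then spark-defaults.conf
    return dict(ChainMap(app_conf, submit_flags, spark_defaults))


spark_defaults = {  # cluster-wide baseline
    "spark.executor.memory": "4g",
    "spark.sql.shuffle.partitions": "200",
}
submit_flags = {"spark.executor.memory": "8g"}     # spark-submit --conf
app_conf = {"spark.sql.shuffle.partitions": "50"}  # SparkSession.builder.config()

conf = effective_conf(spark_defaults, submit_flags, app_conf)
print(conf["spark.executor.memory"])         # 8g
print(conf["spark.sql.shuffle.partitions"])  # 50
```

Settings that no higher layer touches simply fall through to the cluster baseline.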
End-to-End PySpark Optimization Example
Now let's stitch these techniques together into something closer to a real pipeline. We start with a slow, unoptimized job, identify where it stalls, and then stack several techniques to arrive at the optimized version.
Baseline Slow Spark Job
# Sketch of the slow baseline, reconstructed from the problems listed
# under "Identifying Performance Bottlenecks". The CSV paths are hypothetical.
from pyspark.sql import SparkSession
from pyspark.sql.functions import (
    col,
    sum as spark_sum
)
spark = (
    SparkSession.builder
    .appName("SlowJob")
    # AQE switched off to show the unoptimized behavior
    .config("spark.sql.adaptive.enabled", "false")
    .getOrCreate()
)
# Large transactions table, read as CSV
# (row-based text format: every column is read and parsed)
transactions = spark.read.csv(
    "/tmp/transactions.csv",
    header=True,
    inferSchema=True
)
# Product lookup table, also read as CSV
products = spark.read.csv(
    "/tmp/products.csv",
    header=True,
    inferSchema=True
)
# Join the full tables first:
# no column pruning and no broadcast hint
joined = transactions.join(
    products,
    on="product_id"
)
# Filter only after the join
result = (
    joined
    .filter(col("status") == "completed")
    .groupBy("category")
    .agg(
        spark_sum("amount").alias("total_amount")
    )
)
result.show()
Identifying Performance Bottlenecks
Running result.explain(True) on the slow job reveals several problems: there is no effective predicate pushdown, because CSV is a plain-text, row-based format that Spark has to scan in full; the join is planned as a full sort-merge join, which causes a huge shuffle; all columns are read from both files; and adaptive optimizations are not enabled at all.
Applying Multiple Optimization Techniques
Now let's rewrite the job step by step, with all the optimizations applied.
from pyspark.sql import SparkSession
from pyspark.sql.functions import (
    broadcast,
    col,
    sum as spark_sum
)
spark = (
    SparkSession.builder
    .appName("OptimizedJob")
    .config("spark.sql.adaptive.enabled", "true")
    .config(
        "spark.sql.adaptive.coalescePartitions.enabled",
        "true"
    )
    .config(
        "spark.sql.adaptive.skewJoin.enabled",
        "true"
    )
    .config("spark.sql.shuffle.partitions", "20")
    .config(
        "spark.serializer",
        "org.apache.spark.serializer.KryoSerializer"
    )
    .getOrCreate()
)
# Create dummy transactions
# (in a real job, read from Parquet)
txn_data = [
    (
        f"TXN{i:05d}",
        f"PROD_{i % 10:03d}",
        float(i * 14.5),
        "completed" if i % 5 != 0 else "failed",
        f"CUST_{i % 50:03d}"
    )
    for i in range(1, 1001)
]
transactions = spark.createDataFrame(
    txn_data,
    [
        "txn_id",
        "product_id",
        "amount",
        "status",
        "customer_id"
    ]
)
# Small products table,
# ideal for broadcasting
prod_data = [
    (
        f"PROD_{i:03d}",
        f"Product {i}",
        "Electronics" if i % 2 == 0 else "Furniture"
    )
    for i in range(10)
]
products = spark.createDataFrame(
    prod_data,
    [
        "product_id",
        "product_name",
        "category"
    ]
)
Optimizing Partitions
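# Repartition transactions on product_id before the join
# Note: a broadcast join does not need this step, and the repartition is itself a shuffle;
# keep it only when later stages group or join on product_id
transactions_repartitioned = transactions.repartition(20, "product_id")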
Adding Broadcast Join
# Broadcast the small products table so the large side is not shuffled for the join
joined = transactions_repartitioned.join(broadcast(products), on="product_id")
Enabling AQE
Already enabled in the SparkSession config above. AQE coalesces small shuffle partitions and splits skewed join partitions automatically at runtime.
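The rule AQE applies when deciding whether a join partition is skewed can be sketched in a few lines. According to the Spark documentation, a partition counts as skewed when it is larger than spark.sql.adaptive.skewJoin.skewedPartitionFactor times the median partition size and also larger than spark.sql.adaptive.skewJoin.skewedPartitionThresholdInBytes. The defaults used below (5 and 256 MB) should be checked against your Spark version, and `skewed_partitions` is a hypothetical helper, not a Spark API.

```python
from statistics import median

MB = 1024 * 1024


def skewed_partitions(sizes_bytes, factor=5.0, threshold_bytes=256 * MB):
    """Return the indexes of partitions that the skew-join rule would flag."""
    typical = median(sizes_bytes)
    return [
        index
        for index, size in enumerate(sizes_bytes)
        if size > factor * typical and size > threshold_bytes
    ]


sizes = [100 * MB, 120 * MB, 90 * MB, 2000 * MB, 110 * MB]
print(skewed_partitions(sizes))                         # [3]
print(skewed_partitions([10 * MB, 10 * MB, 200 * MB]))  # []
```

The second call shows why both conditions matter: a partition 20 times larger than the median is still left alone while it stays under the absolute threshold.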
Reducing Shuffle
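# Keep only completed transactions and the required columns, then aggregate in one pass.
# Catalyst can push this filter below the join; filtering before the join in code makes that explicit.
result = joined \
    .filter(col("status") == "completed") \
    .select("txn_id", "category", "amount") \
    .groupBy("category") \
    .agg(spark_sum("amount").alias("total_revenue"))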
Final Optimized Version
result.show()
result.explain()
Conclusion
PySpark optimization is not a single fix. It is a stack of layered choices that compound into large performance wins. Start with the high-impact basics: use Parquet, turn on AQE, filter early, and select only the columns you actually need. After that, move on to join strategy, partitioning, and skew handling.
With these 12 techniques in your toolkit, you can often bring hours-long Spark runs down to minutes, provided you apply them systematically. Measure the effect in the Spark UI and keep tuning as you learn. The gap between a slow Spark job and a fast one is usually obvious once you look at the execution plan.
