Greedy Number Partitions in Java
In this post we will describe and show a Java implementation of the greedy number partitioning algorithm. We will start with a simple version that processes a list of numbers in the order given and work our way up to an improvement that solves the problem better. The code for the algorithm implementations and the Monte Carlo simulation is available on GitHub.
Problem Description
You are running the local kids' soccer league in your town. You have 350 kids and have given each one a number from 0 to 99 that roughly captures how good they are at soccer. It is your job to create 20 teams that are roughly equal so that everyone has a fair and fun time during the soccer season. How do you do it?
The problem above can be approximately solved with the greedy number partitioning algorithm. Given a collection of numbers, S, and a partition count, p, generate a set of p partitions P such that the partitions are mutually exclusive and collectively exhaustive (of S) and the sums of the numbers in the partitions are as equal as possible. The greedy algorithm is a heuristic: it is fast, but it does not guarantee the best possible split.
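Writing the problem as an assignment of each number to a partition makes the "mutually exclusive and collectively exhaustive" requirement automatic, since every number goes to exactly one partition. One way to make "as equal as possible" precise (our assumption; minimizing the largest partition sum is another common choice) is to minimize the gap between the largest and smallest sums, which is the same "difference between the maximum and minimum" reported in the simulation results later in this post:

```latex
\begin{aligned}
&\text{Given } S = (s_1, \dots, s_n) \text{ and } p \ge 1, \text{ choose } a : \{1, \dots, n\} \to \{1, \dots, p\},\\
&\sigma_j = \sum_{i \,:\, a(i) = j} s_i \qquad (j = 1, \dots, p),\\
&\text{minimizing } \max_{j} \sigma_j - \min_{j} \sigma_j .
\end{aligned}
```

Here a(i) is the partition that receives s_i and \sigma_j is the sum of partition j. Both greedy versions in this post only approximate this minimum.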
Simple Implementation
The simplest algorithm iterates once through the elements of S and adds each one to the partition with the smallest sum so far. To help keep track of the sum of each partition, let's create a Partition class:
class Partition {
    private int sum;
    Partition() {
        this.sum = 0;
    }
    public void increaseSum(int amount) {
        this.sum += amount;
    }
    public int getSum() {
        return this.sum;
    }
}
This class keeps track of the sum of a partition so far. It has a method to increase the partition’s sum by some amount and a getter method for the sum. Using the helper class, we can look at the signature of the method we want to create:
Collection<Partition> simplePartition(
        Collection<Integer> numbers,
        int partitionCount)
For our first example, the input is a collection of integers (S in the problem description) and the number of partitions we want to make (p in the problem description). In the soccer analogy, numbers holds the scores of all our little soccer players and partitionCount is the number of teams, so the output of the method would be the teams we are trying to create for the season. (The analogy falls apart a bit, because we are not guaranteed equal-sized teams and because Partition only records each team's total, not its members.)
We will want a way to compare Partition instances:
class PartitionComparator implements Comparator<Partition> {
    @Override
    public int compare(Partition partitionA, Partition partitionB) {
        return Integer.compare(partitionA.getSum(), partitionB.getSum());
    }
}
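Since Java 8, the same ascending-by-sum ordering can also be built without a named class through Comparator.comparingInt. The sketch below is an alternative to PartitionComparator, not what the rest of this post uses; ComparatorDemo and BY_SUM are illustrative names.

```java
import java.util.Comparator;

// Same shape as the article's Partition: only the running sum is tracked.
class Partition {
    private int sum = 0;

    public void increaseSum(int amount) {
        this.sum += amount;
    }

    public int getSum() {
        return this.sum;
    }
}

class ComparatorDemo {
    // Same ordering as PartitionComparator: ascending by partition sum.
    static final Comparator<Partition> BY_SUM = Comparator.comparingInt(Partition::getSum);

    public static void main(String[] args) {
        Partition small = new Partition();
        small.increaseSum(3);
        Partition large = new Partition();
        large.increaseSum(10);
        // A negative result means "small" orders before "large".
        System.out.println(BY_SUM.compare(small, large) < 0); // prints true
    }
}
```

Both forms compare sums with the same ordering. A named class is handy when the comparator is reused in several places, while comparingInt keeps the ordering right where it is used.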
With the utility class and comparator defined, we can finally take a look at the whole implementation:
// Needs java.util.ArrayDeque, ArrayList, Collection, List and Queue.
Collection<Partition> simplePartition(Collection<Integer> numbers, int partitionCount) {
    PartitionComparator partitionComparator = new PartitionComparator();
    Queue<Integer> numberQueue = new ArrayDeque<>(numbers);
    List<Partition> partitions = new ArrayList<>();
    for (int i = 0; i < partitionCount; i++) {
        partitions.add(new Partition());
    }
    while (!numberQueue.isEmpty()) {
        Integer number = numberQueue.poll();
        // getLowestSumPartition() returns the partition with the smallest sum so far.
        Partition lowestSumPartition = getLowestSumPartition(partitions, partitionComparator);
        lowestSumPartition.increaseSum(number);
    }
    return partitions;
}
Let’s review what the code is doing.
I copy the numbers into a Queue because I like using Queue.poll() instead of List.remove(0), fight me.
Queue<Integer> numberQueue = new ArrayDeque<>(numbers);
Next, we initialize our list of Partition instances, each of which starts with a sum of zero.
List<Partition> partitions = new ArrayList<>();
for (int i = 0; i < partitionCount; i++) {
    partitions.add(new Partition());
}
The rest of the code is what actually performs the algorithm. It iterates over each number, removes it from the queue, and adds it to the partition with the lowest sum.
while (!numberQueue.isEmpty()) {
    Integer number = numberQueue.poll();
    Partition lowestSumPartition = getLowestSumPartition(partitions, partitionComparator);
    lowestSumPartition.increaseSum(number);
}
Getting the “partition with the lowest sum” is not free; the getLowestSumPartition() method takes care of that for us. It searches through all the partitions on every iteration to find the lowest one, so the algorithm runs in O(nk) time, where n is the length of the number list and k is the partition count.
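The post does not show getLowestSumPartition() itself. A minimal sketch, assuming it is a plain linear scan (the helper body, the SimpleGreedy class and the sample input are all hypothetical), makes the O(k) cost per number visible:

```java
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Comparator;
import java.util.List;
import java.util.Queue;

// Same shape as the article's Partition: only the running sum is tracked.
class Partition {
    private int sum = 0;

    public void increaseSum(int amount) {
        this.sum += amount;
    }

    public int getSum() {
        return this.sum;
    }
}

class PartitionComparator implements Comparator<Partition> {
    @Override
    public int compare(Partition partitionA, Partition partitionB) {
        return Integer.compare(partitionA.getSum(), partitionB.getSum());
    }
}

class SimpleGreedy {
    // Hypothetical helper: one pass over all k partitions, so each call costs O(k).
    // Assumes the list is non-empty (partitionCount >= 1).
    static Partition getLowestSumPartition(List<Partition> partitions, Comparator<Partition> comparator) {
        Partition lowest = partitions.get(0);
        for (Partition candidate : partitions) {
            // Strictly smaller only, so ties keep the earlier partition.
            if (comparator.compare(candidate, lowest) < 0) {
                lowest = candidate;
            }
        }
        return lowest;
    }

    static Collection<Partition> simplePartition(Collection<Integer> numbers, int partitionCount) {
        PartitionComparator partitionComparator = new PartitionComparator();
        Queue<Integer> numberQueue = new ArrayDeque<>(numbers);
        List<Partition> partitions = new ArrayList<>();
        for (int i = 0; i < partitionCount; i++) {
            partitions.add(new Partition());
        }
        // n numbers, each followed by an O(k) scan: O(nk) in total.
        while (!numberQueue.isEmpty()) {
            Integer number = numberQueue.poll();
            getLowestSumPartition(partitions, partitionComparator).increaseSum(number);
        }
        return partitions;
    }

    public static void main(String[] args) {
        for (Partition partition : simplePartition(Arrays.asList(7, 2, 5, 4, 6), 2)) {
            System.out.println(partition.getSum()); // prints 11, then 13
        }
    }
}
```

On the input 7, 2, 5, 4, 6 with two partitions, this version ends with sums 11 and 13, even though {7, 5} and {2, 4, 6} would split the total evenly at 12 each. Greedy choices made one number at a time are not always optimal.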
I ran this code 1000 times. Each time, I generated an array of 350 uniformly random integers in [0, 100) and partitioned them into 20 partitions. As a result:
the average sum across all the partitions was 17345.952
the average of the average sum for a partition was 835.66
the average maximum sum for a partition was 915.181
the average minimum sum for a partition was 866.832
the average difference between the maximum and minimum was 79.521
the average standard deviation of the sum across all partitions was 553.4461
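The simulation code lives on GitHub and is not reproduced here. A minimal sketch of how such statistics could be gathered, assuming a population standard deviation of the partition sums and plain means over the runs (SimpleGreedySimulation, greedySums and stats are hypothetical names), tracks the sums in an int[] to stay self-contained:

```java
import java.util.Random;

class SimpleGreedySimulation {
    // Simple greedy on a plain int[] of sums: each number goes to the lowest-sum partition.
    static int[] greedySums(int[] numbers, int partitionCount) {
        int[] sums = new int[partitionCount];
        for (int number : numbers) {
            int lowest = 0;
            for (int j = 1; j < partitionCount; j++) {
                if (sums[j] < sums[lowest]) { // strict, so ties keep the lower index
                    lowest = j;
                }
            }
            sums[lowest] += number;
        }
        return sums;
    }

    // Returns {total, mean, max, min, max - min, population standard deviation} for one run.
    static double[] stats(int[] sums) {
        double total = 0;
        int max = Integer.MIN_VALUE;
        int min = Integer.MAX_VALUE;
        for (int s : sums) {
            total += s;
            max = Math.max(max, s);
            min = Math.min(min, s);
        }
        double mean = total / sums.length;
        double squaredDeviations = 0;
        for (int s : sums) {
            squaredDeviations += (s - mean) * (s - mean);
        }
        double stdDev = Math.sqrt(squaredDeviations / sums.length);
        return new double[] {total, mean, max, min, max - min, stdDev};
    }

    public static void main(String[] args) {
        int trials = 1000, count = 350, partitions = 20;
        Random random = new Random(42); // fixed seed so runs are repeatable
        double[] averages = new double[6];
        for (int t = 0; t < trials; t++) {
            int[] numbers = new int[count];
            for (int i = 0; i < count; i++) {
                numbers[i] = random.nextInt(100); // uniform in [0, 100)
            }
            double[] run = stats(greedySums(numbers, partitions));
            for (int m = 0; m < run.length; m++) {
                averages[m] += run[m] / trials;
            }
        }
        String[] labels = {"sum", "mean", "max", "min", "max - min", "std dev"};
        for (int m = 0; m < labels.length; m++) {
            System.out.printf("average %s: %.3f%n", labels[m], averages[m]);
        }
    }
}
```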
Improved Implementation
We are going to make two major changes to the implementation above:
use a smarter algorithm to make the sums of the partitions closer to each other
use a more sophisticated data structure to clean up the code a bit
The first part is really simple: we will sort the number list in descending order before we process it. The second part is also simple: we will store the Partition instances in a PriorityQueue instead of a List. A PriorityQueue is a neat data structure whose head is always the least element according to its ordering; here, that is the partition with the lowest sum. Java's PriorityQueue takes O(log(k)) time for add() and poll() on a queue of k elements and O(1) time for peek(). Each of the n numbers causes one poll() and one add() on a queue that always holds k partitions, so selecting partitions costs O(n log(k)) instead of the O(nk) linear scans of the simple version. For very large k that is a real saving, although the sort adds its own O(n log(n)) cost.
public Collection<Partition> improvedPartition(Collection<Integer> numbers, int partitionCount) {
    // Sort a copy in descending order so the largest numbers are placed first.
    List<Integer> numbersCopy = new ArrayList<>(numbers);
    Collections.sort(numbersCopy, Collections.reverseOrder());
    Queue<Integer> numberQueue = new ArrayDeque<>(numbersCopy);
    // The head of this queue is always the partition with the lowest sum.
    PriorityQueue<Partition> partitionPriorityQueue =
            new PriorityQueue<>(partitionCount, new PartitionComparator());
    for (int i = 0; i < partitionCount; i++) {
        partitionPriorityQueue.add(new Partition());
    }
    while (!numberQueue.isEmpty()) {
        Integer number = numberQueue.poll();
        // Remove, update, then re-insert so the heap ordering stays valid.
        Partition lowestSumPartition = partitionPriorityQueue.poll();
        lowestSumPartition.increaseSum(number);
        partitionPriorityQueue.add(lowestSumPartition);
    }
    return partitionPriorityQueue;
}
The code begins by copying the Collection into a List so that it can be sorted in descending order and then copied into a Queue. Again, the Queue copying is a code clarity preference, not a necessity. The sort is a big up-front O(n log(n)) cost that the simple version did not pay. We hope that the improvement in the results is worth it.
List<Integer> numbersCopy = new ArrayList<>(numbers);
Collections.sort(numbersCopy, Collections.reverseOrder());
Queue<Integer> numberQueue = new ArrayDeque<>(numbersCopy);
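Why does sorting help? Here is a small, hand-checkable case using hypothetical helpers spread() and sortedSpread() that share the same linear-scan selection, so only the processing order differs. With the numbers 1, 1, 1, 1, 4 and two partitions, the unsorted pass has to drop the 4 on top of a partition that already holds 2, while the descending pass places the 4 first and lets the 1s fill the other partition. Processing the largest numbers first is the same idea as the Longest Processing Time (LPT) rule from scheduling.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;

class SortingEffect {
    // Greedy assignment in the given order; returns the spread (max sum - min sum).
    static int spread(List<Integer> numbers, int partitionCount) {
        int[] sums = new int[partitionCount];
        for (int number : numbers) {
            int lowest = 0;
            for (int j = 1; j < partitionCount; j++) {
                if (sums[j] < sums[lowest]) {
                    lowest = j;
                }
            }
            sums[lowest] += number;
        }
        int max = Arrays.stream(sums).max().getAsInt();
        int min = Arrays.stream(sums).min().getAsInt();
        return max - min;
    }

    // Same greedy, but on a copy sorted in descending order first.
    static int sortedSpread(List<Integer> numbers, int partitionCount) {
        List<Integer> sorted = new ArrayList<>(numbers);
        Collections.sort(sorted, Collections.reverseOrder());
        return spread(sorted, partitionCount);
    }

    public static void main(String[] args) {
        List<Integer> numbers = Arrays.asList(1, 1, 1, 1, 4);
        System.out.println(spread(numbers, 2));       // prints 4 (sums 6 and 2)
        System.out.println(sortedSpread(numbers, 2)); // prints 0 (sums 4 and 4)
    }
}
```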
Once again, the Partition instances are initialized, but this time they are placed in a PriorityQueue. Each add() call restores the queue's heap ordering, guaranteeing that the head element is the smallest. Note that the PriorityQueue constructor takes a comparator, so ‘smallest’ here is defined by that comparator.
PriorityQueue<Partition> partitionPriorityQueue =
        new PriorityQueue<>(partitionCount, new PartitionComparator());
for (int i = 0; i < partitionCount; i++) {
    partitionPriorityQueue.add(new Partition());
}
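To see the comparator-defined ordering in action, the sketch below adds partitions with sums 5, 1 and 3 and polls them back out; they come off the head in ascending order of sum. PriorityQueueDemo and drainSums are illustrative names, and Comparator.comparingInt stands in for PartitionComparator to keep the block short.

```java
import java.util.Arrays;
import java.util.Comparator;
import java.util.PriorityQueue;

// Same shape as the article's Partition: only the running sum is tracked.
class Partition {
    private int sum = 0;

    public void increaseSum(int amount) {
        this.sum += amount;
    }

    public int getSum() {
        return this.sum;
    }
}

class PriorityQueueDemo {
    // Adds one partition per given sum, then polls them all and records the order.
    static int[] drainSums(int... sums) {
        PriorityQueue<Partition> queue =
                new PriorityQueue<>(Comparator.comparingInt(Partition::getSum));
        for (int s : sums) {
            Partition partition = new Partition();
            partition.increaseSum(s);
            queue.add(partition); // O(log k): sifts the new element into place
        }
        int[] order = new int[sums.length];
        for (int i = 0; i < order.length; i++) {
            order[i] = queue.poll().getSum(); // O(log k): removes the head and restores the heap
        }
        return order;
    }

    public static void main(String[] args) {
        System.out.println(Arrays.toString(drainSums(5, 1, 3))); // prints [1, 3, 5]
    }
}
```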
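The last portion of the code handles the main algorithm. It again iterates over the Queue of numbers once. This time it relies on the PriorityQueue to get the lowest-sum Partition instance, and because it polls a sorted Queue, it always gets the highest remaining number. After the lowest-sum Partition instance has its sum increased, it needs to be added back to the PriorityQueue so that the queue can move whichever partition now has the lowest sum to its head. Polling before updating matters: a PriorityQueue does not notice when an element's sum changes while that element is still inside it, so updating in place would break the ordering.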
while (!numberQueue.isEmpty()) {
    Integer number = numberQueue.poll();
    Partition lowestSumPartition = partitionPriorityQueue.poll();
    lowestSumPartition.increaseSum(number);
    partitionPriorityQueue.add(lowestSumPartition);
}
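The poll, update, add pattern is easy to get wrong. The sketch below (MutationDemo and its methods are hypothetical) contrasts updating the head in place with the pattern used in the loop above. Changing an element while it sits inside a PriorityQueue is not supported; in the JDK's implementation peek() simply returns whatever is at the root of the heap, so the stale head stays put.

```java
import java.util.Comparator;
import java.util.PriorityQueue;

// Same shape as the article's Partition: only the running sum is tracked.
class Partition {
    private int sum = 0;

    public void increaseSum(int amount) {
        this.sum += amount;
    }

    public int getSum() {
        return this.sum;
    }
}

class MutationDemo {
    // Builds a queue whose head is always the lowest-sum partition.
    static PriorityQueue<Partition> queueWithSums(int... sums) {
        PriorityQueue<Partition> queue =
                new PriorityQueue<>(Comparator.comparingInt(Partition::getSum));
        for (int s : sums) {
            Partition partition = new Partition();
            partition.increaseSum(s);
            queue.add(partition);
        }
        return queue;
    }

    // Wrong: updates the head while it is still inside the queue, so the heap is never reordered.
    static int headSumAfterInPlaceUpdate() {
        PriorityQueue<Partition> queue = queueWithSums(1, 2);
        queue.peek().increaseSum(10);
        return queue.peek().getSum(); // the stale head (sum 11) is still returned
    }

    // Right: remove, update, then re-insert so add() can place the partition correctly.
    static int headSumAfterPollAndAdd() {
        PriorityQueue<Partition> queue = queueWithSums(1, 2);
        Partition lowest = queue.poll();
        lowest.increaseSum(10);
        queue.add(lowest);
        return queue.peek().getSum(); // the partition with sum 2 is now at the head
    }

    public static void main(String[] args) {
        System.out.println(headSumAfterInPlaceUpdate()); // prints 11
        System.out.println(headSumAfterPollAndAdd());    // prints 2
    }
}
```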
I ran this code 1000 times. Each time, I generated an array of 350 uniformly random integers in [0, 100) and partitioned them into 20 partitions. As a result:
the average sum across all the partitions was 17311.068
the average of the average sum for a partition was 864.486
the average maximum sum for a partition was 866.957
the average minimum sum for a partition was 865.097
the average difference between the maximum and minimum was 2.471
the average standard deviation of the sum across all partitions was 0.9889
The last two numbers in the list show the impact of this small improvement to the algorithm. The differences between the partition sums are tiny, and the only extra cost is the O(n log(n)) sort.
Note that the randomization used for generating the number lists was uniform. I am curious about the effects that other distributions have on the average standard deviations of the improved algorithm.
The Wikipedia article shows a third algorithm that is ‘exact’, meaning that it always finds the optimal partition. Because the problem is NP-hard, exact algorithms are very expensive and likely infeasible for long number sequences and large partition counts, so I didn't include one in this article.