Mingze Gao

Centrifuge Problem

| 9 min read

Given a centrifuge with nn holes, can we balance it with kk (1kn1\le k \le n) identical test tubes?

This is a simple yet interesting problem, very well illustrated by Numberphile and discussed by Matt Baker's blog.

The now proved solution is that:

An example: balanced 18-hole centrifuge

18-hole-centrifuge

Another example: balanced 20-hole centrifuge

20-hole-centrifuge

Below is my attempt to programmatically answer the centrifuge problem.

Method 1: Naïve DFS

The very first method literally follows the solution. For a given (n,k)(n,k) pair, check if kk and nkn-k can be written as a linear combination of the prime divisors of nn (with non-negative coefficients).

def is_linear_combination(x: int, prime_numbers: list) -> bool:
    """Check if `x` can be written as a linear combination of prime numbers, i.e.,

    x = b1*p1 + b2*p2 + b3*p3 + ... + bn*pn

    where pi represents a prime number in `prime_numbers`, bi is a non-negative integer.
    """
    # very naive and not optimized
    for n in prime_numbers:
        # n divides x
        if x % n:
            return True
        # n does not divides x, check if the difference between x and multiples of n can be
        # a linear combination of other remaining prime numbers
        for i in range(x//n):
            if is_linear_combination(x - i*n, [p for p in prime_numbers if p!=n]):
                return True
    return False

def centrifuge_naive(n: int, k: int) -> bool:
    """Check if a `n`-hole centrifuge can be balanced with `k` identical test tubes.

    True if both `k` and `n-k` can be written as a linear combination of the prime divisors of `n`.
    """
    prime_divisors = get_prime_divisors(n) # simple cached function, skipped
    return is_linear_combination(k, prime_divisors) and is_linear_combination(n-k, prime_divisors)

Some Optimizations

The above method works just fine, but very slow if we want to compute the total number of solutions, instead of just checking whether a particular kk works.

There can be a few optimizations, for example, we can compute only the lower half of kk:

from functools import lru_cache

@lru_cache(maxsize=None)
def centrifuge_naive(n: int, k: int) -> bool:
    prime_divisors = get_prime_divisors(n) # cached
    if k > n//2:
        return centrifuge(n, n-k)
    return is_linear_combination(k, prime_divisors) and is_linear_combination(n-k, prime_divisors)

Further, if nn is a (large) prime number itself, we understand that all 1k<n1\le k\lt n will not work. Similarly, if nn is a power of prime number, we can bypass many values of kk too.

@lru_cache(maxsize=None)
def centrifuge_naive(n, k):
    prime_divisors = get_prime_divisors(n)
    # ...
    # special case when n is power of prime
    if len(prime_divisors) == 1:
        p = prime_divisors[0]
        return (k % p == 0) and ((n - k) % p == 0)
    # ...

At certain point, we will realize that it would be faster to simply compute all possible kks instead of checking one by one whether a certain kk can balance the centrifuge. This leads us to the second approach, which I call "bootstrap".

Method 2: Bootstrap

The bootstrap method is a variant of DFS, which essentially generates all possible kk for a given nn by exhausting the values from linear combinations of nn's prime divisors. The generated values should be between 2 and nn. Then we can tell if kk' can balance the nn-hole centrifuge by checking whether kk' and nkn-k' are in the generated values.

def bootstrap(x, n, numbers, result):
    """Compute all linear combinations of the given numbers smaller than n"""
    for p in numbers:
        if p+x > n:
            break
        for i in range((n-x) // p):
            p_ = x + p * i # p_ <= n
            if not result[p_]:
                # x + p*i has not been tested, and is a linear combination of given numbers
                result[p_] = True
                # check whether we can add multiples of remaining numbers
                bootstrap(p_, n, [n2 for n2 in numbers if n2 != p], result)

def centrifuge_bootstrap(n: int, k: int) -> bool:
    prime_divisors = get_prime_divisors(n) # cached, `prime_divisors` is sorted
    # result[k] represents whether k is valid, k=0...n
    result = [True] + [False] * (n-1) + [True]
    bootstrap(0, n, prime_divisors, result) # TODO: bootstrap only once for a given `n`
    return result[k] and result[n-k]

This method invests some time in pre-computing all possible linear combinations of the prime divisors of nn. If we are only interested to see a particular (n,k)(n,k) pair, we can break out when we have done result[k] and result[n-k] in bootstrap().

Method 3: Dynamic Programming

The last method uses dynamic programming.

We can use f[k]f[k]=True to represent that kk is a linear combination of nn's prime divisors. A value ii is either itself a prime divisor of nn (and thus a linear combination of the prime divisors), or the sum of a nn's prime divisor pp and (ip)(i-p). In the latter case, if (ip)(i-p) is a linear combination of nn's prime divisors, so is p+(ip)=ip+(i-p)=i.

Hence,

The boundary condition is f[0]f[0] = True, i.e., an empty centrifuge is balanced.

The whole function is extremely short:

def centrifuge_dp(n: int, k: int) -> bool:
    prime_divisors = get_prime_divisors(n) # cached, `prime_divisors` is sorted
    f = [True] + [False] * n
    for p in prime_divisors: # TODO: DP only once for a given `n`
        for i in range(p, n+1):
            f[i] = f[i] or f[i-p]
    return f[k] and f[n-k]

Performance Comparison

Obviously, the Method 2 and 3 are much faster than the naïve Method 1. Method 3 does not even use recursion and is the fastest.

Visualization

Below are some plots of balanced centrifuges. Note that for a particular value of kk, there can be more than one way to balance the centrifuge. Here, I illustrate only one.

"6-hole"

plot_centrifuge(6, "6-hole-centrifuge.svg")

6-hole-centrifuge

"10-hole"

plot_centrifuge(10, "10-hole-centrifuge.svg")

10-hole-centrifuge

"12-hole"

plot_centrifuge(12, "12-hole-centrifuge.svg")

12-hole-centrifuge

"12-hole"

plot_centrifuge(18, "18-hole-centrifuge.svg")

18-hole-centrifuge

"20-hole"

plot_centrifuge(20, "20-hole-centrifuge.svg")

20-hole-centrifuge

"24-hole"

plot_centrifuge(24, "24-hole-centrifuge.svg")

24-hole-centrifuge

"33-hole"

plot_centrifuge(33, "33-hole-centrifuge.svg")

33-hole-centrifuge

Python code

The code to generate the plots above:

from functools import lru_cache
import numpy as np
import matplotlib.pyplot as plt


@lru_cache(maxsize=None)
def prime_divisors(n):
    """Return list of n's prime divisors"""
    primes = []
    p = 2
    while p**2 <= n:
        if n % p == 0:
            primes.append(p)
            n //= p
        else:
            p += 1 if p % 2 == 0 else 2
    if n > 1:
        primes.append(n)
    return primes


def centrifuge(n):
    """Return a list of which the k-th element represents if k tubes can balance the n-hole centrifuge"""
    F = [True] + [False] * n
    for p in prime_divisors(n):
        for i in range(p, n + 1):
            F[i] = F[i] or F[i - p]
    return [F[k] and F[n - k] for k in range(n + 1)]


def factorize(k: int, nums: list) -> list:
    """Given k, return the list of numbers from the given numbers which add up to k.
    The given numbers are guaranteed to be able to generate k via a linear combination.

    Examples:
        >>> factorize(5, [2, 3])
        [2, 3]
        >>> factorize(6, [2, 3])
        [2, 2, 2]
        >>> factorize(7, [2, 3])
        [2, 2, 3]
    """

    def _factorize(k, nums, res: list):
        for p in nums:
            if k % p == 0:
                res.extend([p] * (k // p))
                return True
            else:
                for i in range(1, k // p):
                    if _factorize(k - p * i, [n for n in nums if n != p], res):
                        res.extend([p] * i)
                        return True
        return False

    res = []
    _factorize(k, nums, res)
    return res


@lru_cache(maxsize=None)
def centrifuge_k(n, k):
    """Given (n, k) and that k balances a n-hole centrifuge, find the positions of k tubes"""
    if n == k:
        return [True] * n
    factors = factorize(k, prime_divisors(n))
    pos = [False] * n

    def c(factors: list, pos: list) -> bool:
        if sum(pos) == k:
            return True
        if not factors:
            return False
        p = factors.pop(0)
        pos_wanted = [n // p * i for i in range(p)]
        for offset in range(n):
            pos_rotated = [(i + offset) % n for i in pos_wanted]
            # the intended positions of the p tubes are all available
            if not any(pos[i] for i in pos_rotated):
                # claim the positions
                for i in pos_rotated:
                    pos[i] = True
                if not c(factors, pos):
                    # unclaim the positions
                    for i in pos_rotated:
                        pos[i] = False
                else:
                    return True
        # all rotated positions failed, add p back to factors to place later
        factors.append(p)

    c(factors, pos)
    return pos


def plot_centrifuge(n, figname="centrifuge.svg"):
    ncols = max(int(n**0.5), 1)  # minimum 1 column
    nrows = n // ncols if n % ncols == 0 else n // ncols + 1
    height = 3 if nrows == ncols else 2
    width = 2
    fig, axes = plt.subplots(nrows, ncols, figsize=(height * nrows, width * ncols))
    z = np.exp(2 * np.pi * 1j / n)

    theta = np.linspace(0, 2 * np.pi, 20)
    radius = 1 / (ncols + nrows)
    a = radius * np.cos(theta)
    b = radius * np.sin(theta)

    cent = centrifuge(n)
    for nr in range(nrows):
        for nc in range(ncols):
            k = nr * ncols + nc + 1
            axis = axes[nr, nc] if ncols > 1 else axes[nr]
            if k > n:
                axis.axis("off")
                continue
            # draw the n-holes
            for i in [z**i for i in range(n)]:
                axis.plot(a + i.real, b + i.imag, color="b" if cent[k] else "gray")
            # draw the k tubes
            if cent[k]:
                if k > n // 2:
                    pos = [not b for b in centrifuge_k(n, n - k)]
                else:
                    pos = centrifuge_k(n, k)
                for i, ok in enumerate(pos):
                    i = z**i
                    if ok:
                        axis.fill(a + i.real, b + i.imag, color="r")

            axis.set_aspect(1)
            axis.set(xticklabels=[], yticklabels=[])
            axis.set(xlabel=None)
            axis.set_ylabel(f"k={k}", rotation=0, labelpad=10)
            axis.tick_params(bottom=False, left=False)

    fig.suptitle(f"$k$ Test Tubes to Balance a {n}-Hole Centrifuge")
    fig.text(0.1, 0.05, "Red dot represents the position of test tubes.")
    plt.savefig(figname)
    plt.close(fig)


if __name__ == "__main__":
    for n in range(6, 51):
        print(f"Balancing {n}-hole centrifuge...")
        plot_centrifuge(n, f"{n}-hole-centrifuge.png")