
Optimization on the Stiefel Manifold

Active

The topic of my math senior thesis and its Python implementation. Gradient-based optimization over the Stiefel manifold - the set of $n \times k$ matrices with orthonormal columns - using a descent algorithm on the tangent space to minimize arbitrary functions subject to orthogonality constraints.

Python · NumPy · Matplotlib · Jupyter

The Stiefel Manifold

The Stiefel manifold $V_k(\mathbb{R}^n)$ is the set of $n \times k$ matrices with orthonormal columns - equivalently, ordered $k$-tuples of mutually orthogonal unit vectors in $\mathbb{R}^n$:

$$V_k(\mathbb{R}^n) = \{ A \in \text{Mat}_{n \times k}(\mathbb{R}) \mid A^\top A = I_k \}$$

It is a compact embedded submanifold of $\mathbb{R}^{nk}$ of dimension $nk - \frac{k(k+1)}{2}$, and carries the structure of a homogeneous space $\mathcal{O}_n / \mathcal{O}_{n-k}$.

Special cases include:

- $V_1(\mathbb{R}^n) = S^{n-1}$, the unit sphere in $\mathbb{R}^n$
- $V_n(\mathbb{R}^n) = \mathcal{O}_n$, the orthogonal group
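A point on $V_k(\mathbb{R}^n)$ can be generated numerically by orthonormalizing a random Gaussian matrix via QR. A minimal sketch (the helper name `random_stiefel_point` is illustrative, not from the notebook):

```python
import numpy as np

def random_stiefel_point(n, k, seed=None):
    """Random n×k matrix with orthonormal columns, i.e. a point on V_k(R^n)."""
    rng = np.random.default_rng(seed)
    # QR factorization of a Gaussian matrix yields orthonormal columns
    Q, _ = np.linalg.qr(rng.standard_normal((n, k)))
    return Q

A = random_stiefel_point(4, 2, seed=0)
print(np.allclose(A.T @ A, np.eye(2)))  # True: A^T A = I_k
```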

The Descent Algorithm

To minimize a smooth function $F : V_k(\mathbb{R}^n) \to \mathbb{R}$, the algorithm follows a descent curve that stays on the manifold at every step. The curve is defined via the Cayley transform:

$$\gamma(\tau) = \left(I + \frac{\tau}{2} M\right)^{-1}\left(I - \frac{\tau}{2} M\right) A$$

where $M = GA^\top - AG^\top$ is a skew-symmetric matrix constructed from the naive gradient $G = \left(\frac{\partial F}{\partial A_{ij}}\right)$. The curve satisfies $\gamma(\tau)^\top \gamma(\tau) = I_k$ for all $\tau$, so it never leaves $V_k(\mathbb{R}^n)$.
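This invariance is easy to check numerically. The sketch below builds $M$ from a random stand-in for the gradient $G$ and evaluates the curve at several values of $\tau$ (all variable names are illustrative):

```python
import numpy as np

rng = np.random.default_rng(42)
n, k = 6, 3

A, _ = np.linalg.qr(rng.standard_normal((n, k)))  # point on V_k(R^n)
G = rng.standard_normal((n, k))                   # stand-in for the gradient
M = G @ A.T - A @ G.T                             # skew-symmetric by construction

def gamma(tau):
    """Cayley descent curve (I + tau/2 M)^{-1} (I - tau/2 M) A."""
    I = np.eye(n)
    return np.linalg.solve(I + tau / 2 * M, (I - tau / 2 * M) @ A)

# The orthogonality constraint holds along the whole curve
for tau in (0.0, 0.5, 2.0, 10.0):
    print(np.allclose(gamma(tau).T @ gamma(tau), np.eye(k)))  # True each time
```

Because $M$ is skew-symmetric, its eigenvalues are purely imaginary, so $I + \frac{\tau}{2}M$ is invertible and the curve is defined for every $\tau$.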

For efficiency at large $n, k$, the Sherman–Morrison–Woodbury identity reduces the inversion of an $n \times n$ matrix to a $2k \times 2k$ one:

$$\gamma(\tau) = A - \tau U \left(I + \frac{\tau}{2} V^\top U\right)^{-1} V^\top A$$

where $U = [G \;\; A]$ and $V = [A \;\; {-G}]$.
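The equivalence of the two expressions can be verified directly. This sketch compares the full $n \times n$ inversion against the $2k \times 2k$ Woodbury form on random data (names illustrative):

```python
import numpy as np

rng = np.random.default_rng(0)
n, k, tau = 8, 2, 0.7

A, _ = np.linalg.qr(rng.standard_normal((n, k)))  # point on V_k(R^n)
G = rng.standard_normal((n, k))                   # stand-in for the gradient
M = G @ A.T - A @ G.T

# Naive form: solve an n×n system
I_n = np.eye(n)
naive = np.linalg.solve(I_n + tau / 2 * M, (I_n - tau / 2 * M) @ A)

# Woodbury form: U = [G A], V = [A -G], so that M = U V^T;
# only a 2k×2k system is solved
U = np.concatenate((G, A), axis=1)
V = np.concatenate((A, -G), axis=1)
woodbury = A - tau * U @ np.linalg.solve(np.eye(2 * k) + tau / 2 * V.T @ U, V.T @ A)

print(np.allclose(naive, woodbury))  # True
```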

The step size $\tau$ is halved at each iteration to ensure convergence.

Tangent Space

The tangent space $T_A V_k(\mathbb{R}^n)$ at a point $A \in V_k(\mathbb{R}^n)$ consists of all matrices $Z \in \text{Mat}_{n \times k}(\mathbb{R})$ satisfying:

$$Z^\top A + A^\top Z = 0$$
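The descent curve is compatible with this condition: its initial velocity is $\gamma'(0) = -MA$, and skew-symmetry of $M$ gives $(-MA)^\top A + A^\top(-MA) = A^\top M A - A^\top M A = 0$. A quick numeric check (names illustrative):

```python
import numpy as np

rng = np.random.default_rng(1)
n, k = 5, 2

A, _ = np.linalg.qr(rng.standard_normal((n, k)))  # point on V_k(R^n)
G = rng.standard_normal((n, k))                   # stand-in for the gradient
M = G @ A.T - A @ G.T                             # skew-symmetric

Z = -M @ A  # initial velocity gamma'(0) of the Cayley descent curve

print(np.allclose(Z.T @ A + A.T @ Z, 0))  # True: Z is a tangent vector at A
```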

A basis is constructed by extending $A$ to a full orthonormal basis of $\mathbb{R}^n$. Let $A_\perp \in \text{Mat}_{n \times (n-k)}(\mathbb{R})$ be the complementary columns (obtained via SVD). Every tangent vector can be written as:

$$Z = A \cdot C + A_\perp \cdot B$$

where $C \in \text{SkewSym}_k$ and $B \in \text{Mat}_{(n-k) \times k}$ is arbitrary. This gives $\frac{k(k-1)}{2} + (n-k)k = nk - \frac{k(k+1)}{2}$ basis vectors, matching the manifold dimension.

```python
import numpy as np

def T_A(A):
    """Basis of the tangent space T_A V_k(R^n) at A."""
    n, k = A.shape
    # Left singular vectors beyond the first k span the orthogonal
    # complement of col(A)
    A_p, _, _ = np.linalg.svd(A)
    A_p = A_p[:, k:]          # complement: shape (n, n-k)

    # basis for the B part: standard basis matrices of shape (n-k, k)
    B = []
    for i in range((n - k) * k):
        b_i = np.zeros((n - k, k))
        b_i[i // k, i % k] = 1
        B.append(b_i)

    # basis for the C part: skew-symmetric k×k matrices
    # (skew_symmetric_base is defined elsewhere in the notebook)
    C = skew_symmetric_base(k)

    # tangent vectors Z = A C + A_p B, with the other component zero
    TB = [A_p @ b for b in B] + [A @ c for c in C]
    return C, B, TB, A_p
```

The tangent condition is verified for every basis vector:

```python
for b in TB:
    print(np.around(b.T @ A + A.T @ b, decimals=8))
# the k×k zero matrix for each basis element
```

Under the Euclidean inner product $\langle Z_1, Z_2 \rangle = \text{tr}(C_1^\top C_2) + \text{tr}(B_1^\top B_2)$ (where $C_i = A^\top Z_i$, $B_i = A_\perp^\top Z_i$), the Gram matrix of the constructed basis is the identity - confirming orthonormality.

```python
def euclidean_inner_product(Z1, Z2, A, A_p):
    C1, C2 = A.T @ Z1, A.T @ Z2
    B1, B2 = A_p.T @ Z1, A_p.T @ Z2
    return np.trace(C1.T @ C2) + np.trace(B1.T @ B2)

# Gram matrix is the identity:
# [[ 1. -0.]
#  [-0.  1.]]
```

The notebook also implements the canonical (non-isotropic) inner product, which weights the $C$ component by $\frac{1}{2}$:

```python
def canonical_inner_product(Z1, Z2, A, A_p):
    C1, C2 = A.T @ Z1, A.T @ Z2
    B1, B2 = A_p.T @ Z1, A_p.T @ Z2
    return np.trace(C1.T @ C2) / 2.0 + np.trace(B1.T @ B2)
```

Cayley Transform

The Cayley map $C \mapsto (I + C)^{-1}(I - C)$ sends any skew-symmetric matrix to an orthogonal one. For each basis element of $\text{SkewSym}_n$, the transform produces a rotation (determinant 1):

```python
import sympy as sym

def cayley(A):
    n = A.shape[0]
    return (sym.eye(n) + A).inv() @ (sym.eye(n) - A)

# For the 4×4 case, each basis element C gives cayley(C) with det = 1
```

This is the key property exploited by the descent - the curve $\gamma(\tau) = \text{cayley}(\frac{\tau}{2} M) \cdot A$ stays on $V_k(\mathbb{R}^n)$ for all $\tau$.

Implementation

The gradient matrix $G$ is computed symbolically, then evaluated at the current point:

```python
import numpy as np
import sympy as sym

def generate_G(f, n, k):
    """Entrywise symbolic gradient of F with respect to the symbols a_ij."""
    G = f(n, k)                          # symbolic expression for F
    dG = np.empty((n, k), dtype=object)  # entries are SymPy expressions
    for i in range(n):
        for j in range(k):
            dG[i, j] = sym.diff(G, sym.Symbol(f'a{i+1}{j+1}'))
    return dG

def get_M(func, A, n, k):
    # get_G substitutes the current point A into the symbolic gradient
    G = get_G(func, A, n, k).evalf()
    return G @ A.T - A @ G.T
```
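As a standalone illustration of the symbolic step (independent of the notebook's helpers, whose full definitions are not shown here), SymPy can differentiate the objective of Example 1 entrywise:

```python
import sympy as sym

n, k = 2, 2
# Symbols a11, a12, a21, a22 for the entries of A
a = [[sym.Symbol(f'a{i+1}{j+1}') for j in range(k)] for i in range(n)]

# F(A) = sum_{i,j} i * j * a_ij^2 (indices starting at 1)
F = sum((i + 1) * (j + 1) * a[i][j]**2 for i in range(n) for j in range(k))

# Entrywise gradient G_ij = dF/da_ij = 2 * i * j * a_ij
G = sym.Matrix(n, k, lambda i, j: sym.diff(F, a[i][j]))
print(G)  # Matrix([[2*a11, 4*a12], [4*a21, 8*a22]])
```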

The two descent variants - naive (inverting $n \times n$) and optimized (inverting $2k \times 2k$ via Woodbury):

```python
def descent(t, M, A):
    return cayley(t / 2 * M) @ A

def descent_UV(t, U, V, A, k):
    return A - t * U @ np.linalg.inv(np.eye(2*k) + t/2 * V.T @ U) @ V.T @ A
```

where $U = [G \;\; A]$ and $V = [A \;\; {-G}]$ are $n \times 2k$ matrices:

```python
def get_UV(func, A, n, k):
    G = get_G(func, A, n, k).evalf()
    U = np.concatenate((np.array(G), np.array(A)), axis=1).astype(np.float32)
    V = np.concatenate((np.array(A), -np.array(G)), axis=1).astype(np.float32)
    return U, V
```

Results

Example 1 - $n = 4$, $k = 2$, $F(A) = \sum_{i,j} i \cdot j \cdot a_{ij}^2$

The known minimum is $F = 4$. Starting from a random point on $V_2(\mathbb{R}^4)$ with $\tau = 4$, the algorithm converges in approximately 10 iterations.
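A purely numeric re-implementation of this example is sketched below; it hard-codes the gradient $G_{ij} = 2ij\,a_{ij}$ instead of using SymPy, and replaces the fixed halving schedule with a simple backtracking rule (accept a Cayley step only if it decreases $F$, otherwise halve $\tau$):

```python
import numpy as np

n, k = 4, 2
W = np.outer(np.arange(1, n + 1), np.arange(1, k + 1))  # weights i*j

def F(A):
    return np.sum(W * A**2)

def grad(A):
    return 2 * W * A  # G_ij = 2 * i * j * a_ij

rng = np.random.default_rng(3)
A, _ = np.linalg.qr(rng.standard_normal((n, k)))  # random start on V_2(R^4)
F0, tau = F(A), 4.0

for _ in range(50):
    G = grad(A)
    M = G @ A.T - A @ G.T
    step = np.linalg.solve(np.eye(n) + tau / 2 * M, (np.eye(n) - tau / 2 * M) @ A)
    if F(step) < F(A):
        A = step          # accept the improving Cayley step
    else:
        tau /= 2          # otherwise shrink the step size

print(F(A) <= F0)                       # True: the objective never increased
print(np.allclose(A.T @ A, np.eye(k)))  # True: A stayed on the manifold
```

From a generic starting point the iterates typically settle near the known minimum $F = 4$.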

Example 2 - $n = 17$, $k = 11$, $F(A) = \sum_{i,j} i \cdot j \cdot \sin(\pi a_{ij})^{i+j}$

A larger, non-trivial function. The algorithm converges in roughly 20 iterations - fast given the scale.

Example 3 - $n = 3$, $k = 1$ (optimization on $S^2$)

The descent path lies visibly on the sphere, confirming that the curve stays on the manifold throughout optimization.