# Optimization on the Stiefel Manifold
The topic of my math senior thesis and its Python implementation: gradient-based optimization over the Stiefel manifold - the set of orthonormal matrices - using a descent algorithm on the tangent space to minimize arbitrary smooth functions subject to orthogonality constraints.
## The Stiefel Manifold
The Stiefel manifold $V_k(\mathbb{R}^n)$ is the set of $n \times k$ orthonormal matrices - equivalently, ordered $k$-tuples of mutually orthogonal unit vectors in $\mathbb{R}^n$:

$$V_k(\mathbb{R}^n) = \\{ A \in \mathbb{R}^{n \times k} : A^T A = I_k \\}$$

It is a compact embedded submanifold of $\mathbb{R}^{n \times k}$ of dimension $nk - \tfrac{k(k+1)}{2}$, and carries the structure of a homogeneous space $O(n)/O(n-k)$.
Special cases include:
- $V_1(\mathbb{R}^n) = S^{n-1}$: the unit sphere
- $V_n(\mathbb{R}^n) = O(n)$: the orthogonal group
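The defining condition $A^T A = I_k$ is easy to check numerically. A minimal sketch that draws a random point on the manifold by QR-orthonormalizing a Gaussian matrix (the helper `random_stiefel` is illustrative, not part of the thesis code):

```python
import numpy as np

def random_stiefel(n, k, seed=None):
    """Illustrative helper: a random point on V_k(R^n) via QR factorization."""
    rng = np.random.default_rng(seed)
    Q, _ = np.linalg.qr(rng.standard_normal((n, k)))
    return Q  # columns are orthonormal, so Q.T @ Q = I_k

A = random_stiefel(5, 2, seed=0)
print(np.allclose(A.T @ A, np.eye(2)))  # True: A lies on V_2(R^5)
```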
## The Descent Algorithm
To minimize a smooth function $f : V_k(\mathbb{R}^n) \to \mathbb{R}$, the algorithm follows a descent curve that stays on the manifold at every step. The curve is defined via the Cayley transform:

$$Y(t) = \left(I + \tfrac{t}{2} W\right)^{-1} \left(I - \tfrac{t}{2} W\right) A$$

where $W = G A^T - A G^T$ is a skew-symmetric matrix constructed from the naive (Euclidean) gradient $G = \left(\partial f / \partial a_{ij}\right)$. The curve satisfies $Y(t)^T Y(t) = I_k$ for all $t$, so it never leaves $V_k(\mathbb{R}^n)$.
For efficiency at large $n$, the Sherman–Morrison–Woodbury identity reduces the inversion of an $n \times n$ matrix to a $2k \times 2k$ one:

$$Y(t) = A - t\, U \left(I_{2k} + \tfrac{t}{2} V^T U\right)^{-1} V^T A$$

where $U = [G, A]$ and $V = [A, -G]$.
The step size is halved at each iteration to ensure convergence.
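The manifold-preservation property of the Cayley curve can be checked numerically. A sketch under the definitions above, with $W = G A^T - A G^T$ and a random stand-in for the gradient $G$:

```python
import numpy as np

rng = np.random.default_rng(1)
n, k = 6, 2
A, _ = np.linalg.qr(rng.standard_normal((n, k)))  # a point on V_k(R^n)
G = rng.standard_normal((n, k))                   # stand-in for the gradient
W = G @ A.T - A @ G.T                             # skew-symmetric by construction

def cayley_curve(t, W, A):
    """Y(t) = (I + t/2 W)^{-1} (I - t/2 W) A, computed via a linear solve."""
    I = np.eye(W.shape[0])
    return np.linalg.solve(I + t / 2 * W, (I - t / 2 * W) @ A)

# Y(t)^T Y(t) = I_k for every t: the curve never leaves the manifold
for t in (0.1, 0.5, 2.0):
    Y = cayley_curve(t, W, A)
    print(np.allclose(Y.T @ Y, np.eye(k)))  # True for each t
```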
## Tangent Space
The tangent space $T_A V_k(\mathbb{R}^n)$ at a point $A$ consists of all matrices $Z \in \mathbb{R}^{n \times k}$ satisfying:

$$A^T Z + Z^T A = 0$$
A basis is constructed by extending $A$ to a full orthonormal basis of $\mathbb{R}^n$. Let $A_\perp \in \mathbb{R}^{n \times (n-k)}$ be the complementary columns (obtained via the full SVD of $A$). Every tangent vector can be written as:

$$Z = A C + A_\perp B$$

where $C \in \mathbb{R}^{k \times k}$ is skew-symmetric and $B \in \mathbb{R}^{(n-k) \times k}$ is arbitrary. This gives $\tfrac{k(k-1)}{2} + (n-k)k = nk - \tfrac{k(k+1)}{2}$ basis vectors, matching the manifold dimension.
```python
def T_A(A):
    """Orthonormal basis of the tangent space at A."""
    n, k = A.shape
    A_p, _, _ = np.linalg.svd(A)
    A_p = A_p[:, k:]  # complement: shape (n, n-k)
    # basis for the B part: standard basis matrices of shape (n-k, k)
    B = []
    for i in range((n - k) * k):
        b_i = np.zeros((n - k, k))
        b_i[i // k, i % k] = 1
        B.append(b_i)
    # basis for the C part: skew-symmetric k×k matrices
    C = skew_symmetric_base(k)
    TB = []
    for b in B:
        TB.append(A @ np.zeros((k, k)) + A_p @ b)      # Z = A·0 + A_p·B
    for c in C:
        TB.append(A @ c + A_p @ np.zeros((n - k, k)))  # Z = A·C + A_p·0
    return C, B, TB, A_p
```

The tangent condition is verified for every basis vector:
```python
for b in TB:
    print(np.around(b.T @ A + A.T @ b, decimals=8))
# [[0.]] for each basis element
```

Under the Euclidean inner product $\langle Z_1, Z_2 \rangle_e = \operatorname{tr}(C_1^T C_2) + \operatorname{tr}(B_1^T B_2)$ (where $C_i = A^T Z_i$ and $B_i = A_\perp^T Z_i$), the Gram matrix of the constructed basis is the identity - confirming orthonormality.
```python
def euclidean_inner_product(Z1, Z2, A, A_p):
    C1, C2 = A.T @ Z1, A.T @ Z2
    B1, B2 = A_p.T @ Z1, A_p.T @ Z2
    return np.trace(C1.T @ C2) + np.trace(B1.T @ B2)

# Gram matrix is identity:
# [[ 1. -0.]
#  [-0.  1.]]
```

The notebook also implements the canonical (non-isotropic) inner product, which weights the $C$ component by $\tfrac{1}{2}$:
```python
def canonical_inner_product(Z1, Z2, A, A_p):
    C1, C2 = A.T @ Z1, A.T @ Z2
    B1, B2 = A_p.T @ Z1, A_p.T @ Z2
    return np.trace(C1.T @ C2) / 2.0 + np.trace(B1.T @ B2)
```

## Cayley Transform
The Cayley map $S \mapsto (I + S)^{-1}(I - S)$ sends any skew-symmetric matrix to an orthogonal one. For each skew-symmetric basis element $C$, the transform produces a rotation (determinant $1$):
```python
def cayley(A):
    n = A.shape[0]
    return (sym.eye(n) + A).inv() @ (sym.eye(n) - A)

# For the 4x4 case, each basis element C gives cayley(C) with det = 1.0
```

This is the key property exploited by the descent - the curve $Y(t)$ stays on $V_k(\mathbb{R}^n)$ for all $t$.
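The same two facts - orthogonality and unit determinant of the Cayley image - can also be checked numerically for a random skew-symmetric matrix; this is a quick numpy sketch, not the notebook's symbolic computation:

```python
import numpy as np

rng = np.random.default_rng(2)
X = rng.standard_normal((4, 4))
S = (X - X.T) / 2                                  # random skew-symmetric matrix
Q = np.linalg.solve(np.eye(4) + S, np.eye(4) - S)  # Cayley image of S

print(np.allclose(Q.T @ Q, np.eye(4)))             # True: Q is orthogonal
print(np.isclose(np.linalg.det(Q), 1.0))           # True: determinant 1, a rotation
```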
## Implementation
The gradient matrix $G = \left(\partial f / \partial a_{ij}\right)$ is computed symbolically, then evaluated at the current point:
```python
def generate_G(f, n, k):
    """Symbolic gradient of f with respect to the entries a_ij of A."""
    G = f(n, k)
    dG = np.empty((n, k), dtype=object)  # array of sympy expressions
    for i in range(n):
        for j in range(k):
            dG[i, j] = sym.diff(G, sym.Symbol(f'a{i+1}{j+1}'))
    return dG

def get_M(func, A, n, k):
    # get_G substitutes the entries of the current point A into the
    # symbolic gradient produced by generate_G
    G = get_G(func, A, n, k).evalf()
    return G @ A.T - A @ G.T  # skew-symmetric W = G A^T - A G^T
```

The two descent variants - naive (inverting the $n \times n$ matrix $I + \tfrac{t}{2}W$) and optimized (inverting a $2k \times 2k$ matrix via Woodbury):
```python
def descent(t, M, A):
    return cayley(t / 2 * M) @ A

def descent_UV(t, U, V, A, k):
    return A - t * U @ np.linalg.inv(np.eye(2*k) + t/2 * V.T @ U) @ V.T @ A
```

where $U = [G, A]$ and $V = [A, -G]$ are $n \times 2k$ matrices:
```python
def get_UV(func, A, n, k):
    G = get_G(func, A, n, k).evalf()
    U = np.concatenate((np.array(G), np.array(A)), axis=1).astype(np.float32)
    V = np.concatenate((np.array(A), -np.array(G)), axis=1).astype(np.float32)
    return U, V
```

## Results
### Example 1

Starting from a random point on the manifold, the algorithm converges to the known minimum in approximately 10 iterations.
### Example 2

A larger, non-trivial objective function. The algorithm converges in roughly 20 iterations - fast given the scale.
### Example 3 - optimization on the sphere

With $k = 1$ the Stiefel manifold reduces to the unit sphere, so the descent path can be plotted directly: it lies visibly on the sphere, confirming that the curve stays on the manifold throughout optimization.
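As an end-to-end illustration (not one of the thesis examples), here is a self-contained numpy sketch of the Woodbury descent on a stand-in objective $f(A) = \operatorname{tr}(A^T D A)$, whose constrained minimum is the sum of the $k$ smallest eigenvalues of $D$. The step-size rule here halves $t$ only when the objective fails to decrease, a simple variant of the halving schedule described above:

```python
import numpy as np

rng = np.random.default_rng(4)
n, k = 8, 2
D = rng.standard_normal((n, n))
D = (D + D.T) / 2                                 # symmetric test matrix
f = lambda A: np.trace(A.T @ D @ A)               # stand-in objective

A, _ = np.linalg.qr(rng.standard_normal((n, k)))  # random starting point
f0, t = f(A), 1.0
for _ in range(500):
    G = 2 * D @ A                                 # Euclidean gradient of f
    U = np.concatenate((G, A), axis=1)
    V = np.concatenate((A, -G), axis=1)
    A_new = A - t * U @ np.linalg.solve(np.eye(2 * k) + t / 2 * V.T @ U, V.T @ A)
    if f(A_new) < f(A):
        A = A_new                                 # accept the step
    else:
        t /= 2                                    # shrink step when no decrease

print(np.allclose(A.T @ A, np.eye(k)))            # True: still on the manifold
print(f(A), np.sort(np.linalg.eigvalsh(D))[:k].sum())  # f(A) vs. eigenvalue bound
```

How close the final value gets to the eigenvalue bound depends on the step schedule and the spectrum of $D$; what the Cayley update guarantees is the invariant that every iterate stays orthonormal.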