Visualising the Wishart Distribution

This semester, I have been teaching Multivariate Statistics, and one issue that comes up is how to think about the Wishart distribution, which generalises the χ² distribution. The χ² distribution with k degrees of freedom has probability density function

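$$f_{\chi^2_k}(x) = \frac{1}{2\,\Gamma\!\left(\frac{k}{2}\right)} \left(\frac{x}{2}\right)^{\frac{k}{2}-1} e^{-\frac{x}{2}}.$$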

To visualise this at different degrees of freedom k, we use the following Python code.

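%matplotlib inline
import numpy as np
import scipy.stats as st
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D  # registers the 3D projection (needed on older matplotlib versions)
np.random.seed(42)  # make the random draws further down reproducible
# Scale matrix and number of draws used for the Wishart plots further down
Sig1 = np.array([[1,0.75],[0.75,2]])
n=1000
# Plot the chi-squared density for several degrees of freedom k
x = np.linspace(0, 10, 1000)
plt.figure(figsize=(8,4))
colors = plt.cm.Accent(np.linspace(0,1,4))
for i, k in enumerate((1,2,5,10)):
    c = st.chi2.pdf(x, k)
    plt.plot(x,c, label='k = ' + str(k), color=colors[i],linewidth=3)
plt.xlim((0,10))
plt.ylim((0,0.5))
plt.xlabel('x')
plt.ylabel('f(x)')
plt.legend()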
../output_2_1.png
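
As a quick sanity check on the χ² density formula, the following sketch evaluates it directly and compares the result with `st.chi2.pdf`. The helper name `chi2_pdf_manual` and the grid `x_check` are illustrative choices, not part of the original notebook.

```python
import numpy as np
import scipy.stats as st
from scipy.special import gammaln

def chi2_pdf_manual(x, k):
    """Chi-squared density written out term by term, computed on the log scale."""
    x = np.asarray(x, dtype=float)
    log_pdf = (-np.log(2.0) - gammaln(k / 2.0)
               + (k / 2.0 - 1.0) * np.log(x / 2.0)
               - x / 2.0)
    return np.exp(log_pdf)

# Start the grid above 0 to avoid log(0)
x_check = np.linspace(0.1, 10, 50)

# Largest absolute difference from SciPy's implementation for each k
max_abs_diff = {
    k: float(np.max(np.abs(chi2_pdf_manual(x_check, k) - st.chi2.pdf(x_check, k))))
    for k in (1, 2, 5, 10)
}
print(max_abs_diff)
```

The differences should be at the level of floating-point rounding for every k.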

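The Wishart distribution arises when we make k observations, each of p variables. If the population is multivariate normal with mean zero and covariance matrix Σ, the scatter matrix V of these data (the sum of the outer products of the observations, which is the sample covariance matrix up to scaling) has the Wishart distribution W_p(k, Σ), with probability density function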

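$$f_{W_p(k,\Sigma)}(V) = \frac{\det(V)^{\frac{k-p-1}{2}}}{2^{\frac{kp}{2}} \det(\Sigma)^{\frac{k}{2}}\, \Gamma_p\!\left(\frac{k}{2}\right)} \exp\!\left(-\frac{1}{2}\operatorname{Tr}\!\left(\Sigma^{-1}V\right)\right).$$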
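
The Wishart density can be checked numerically in the same spirit. The sketch below writes out the log-density term by term, using `scipy.special.multigammaln` for the log of the multivariate gamma function Γ_p, and compares it with `st.wishart.logpdf`. It also checks that for p=1 and Σ=1 the density reduces to the χ² density. The function name `wishart_logpdf_manual` and the test matrix `V_test` are illustrative assumptions.

```python
import numpy as np
import scipy.stats as st
from scipy.special import multigammaln

def wishart_logpdf_manual(V, k, Sigma):
    """Log of the Wishart density W_p(k, Sigma) at V, written out term by term."""
    V = np.atleast_2d(np.asarray(V, dtype=float))
    Sigma = np.atleast_2d(np.asarray(Sigma, dtype=float))
    p = Sigma.shape[0]
    _, logdet_V = np.linalg.slogdet(V)
    _, logdet_Sigma = np.linalg.slogdet(Sigma)
    # Tr(Sigma^{-1} V), computed without forming the inverse explicitly
    trace_term = np.trace(np.linalg.solve(Sigma, V))
    return (0.5 * (k - p - 1) * logdet_V
            - 0.5 * k * p * np.log(2.0)
            - 0.5 * k * logdet_Sigma
            - multigammaln(0.5 * k, p)
            - 0.5 * trace_term)

Sigma = np.array([[1, 0.75], [0.75, 2]])
V_test = np.array([[3.0, 1.0], [1.0, 4.0]])  # an arbitrary positive-definite matrix
k = 5

manual = wishart_logpdf_manual(V_test, k, Sigma)
reference = st.wishart.logpdf(V_test, df=k, scale=Sigma)
print(manual, reference)

# p = 1, Sigma = 1: the Wishart density is the chi-squared density
p1_manual = wishart_logpdf_manual([[2.5]], k, [[1.0]])
p1_chi2 = st.chi2.logpdf(2.5, k)
print(p1_manual, p1_chi2)
```

Working on the log scale avoids overflow in the determinant and gamma terms when k is large.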

In general, this density cannot be visualised, but for the p=2 case the symmetric matrix V has only three distinct entries (V11, V22 and V12=V21), so one option is to draw random matrices from this distribution and produce a three-dimensional scatter plot of those entries. The code and results below, for k = 2, 5 and 10, show that this can be seen as behaving somewhat like the chi-squared distribution that it generalises as the degrees of freedom k are increased.

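# Draw n matrices from W_2(2, Sig1) and store their three distinct entries
x=np.zeros(n)
y=np.zeros(n)
z=np.zeros(n)
for i in range(0,n):
    M=st.wishart.rvs(2,scale=Sig1,size=1)
    x[i]=M[0][0]  # V11
    y[i]=M[1][1]  # V22
    z[i]=M[1][0]  # V21 (equal to V12 by symmetry)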
fig = plt.figure(figsize=(8,5))
ax = fig.add_subplot(111, projection='3d')
ax.scatter3D(x,y,z, marker='o', c=z, cmap='seismic')
ax.set_xlabel('V11')
ax.set_ylabel('V22')
ax.set_zlabel('V12=V21')
ax.set_xlim((0,30))
ax.set_ylim((0,70))
ax.set_zlim((0,30))
plt.tight_layout()
../output_3_0.png
# Same plot with k = 5 degrees of freedom
x=np.zeros(n)
y=np.zeros(n)
z=np.zeros(n)
for i in range(0,n):
    M=st.wishart.rvs(5,scale=Sig1,size=1)
    x[i]=M[0][0]  # V11
    y[i]=M[1][1]  # V22
    z[i]=M[1][0]  # V21 (equal to V12 by symmetry)
fig = plt.figure(figsize=(8,5))
ax = fig.add_subplot(111, projection='3d')
ax.scatter3D(x,y,z, marker='o', c=z, cmap='seismic')
ax.set_xlabel('V11')
ax.set_ylabel('V22')
ax.set_zlabel('V12=V21')
ax.set_xlim((0,30))
ax.set_ylim((0,70))
ax.set_zlim((0,30))
plt.tight_layout()
../output_4_0.png
# Same plot with k = 10 degrees of freedom
x=np.zeros(n)
y=np.zeros(n)
z=np.zeros(n)
for i in range(0,n):
    M=st.wishart.rvs(10,scale=Sig1,size=1)
    x[i]=M[0][0]  # V11
    y[i]=M[1][1]  # V22
    z[i]=M[1][0]  # V21 (equal to V12 by symmetry)
fig = plt.figure(figsize=(8,5))
ax = fig.add_subplot(111, projection='3d')
ax.scatter3D(x,y,z, marker='o', c=z, cmap='seismic')
ax.set_xlabel('V11')
ax.set_ylabel('V22')
ax.set_zlabel('V12=V21')
ax.set_xlim((0,30))
ax.set_ylim((0,70))
ax.set_zlim((0,30))
plt.tight_layout()
../output_5_0.png
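
The link to the chi-squared distribution can also be checked numerically rather than by eye. The sketch below builds Wishart draws directly from multivariate normal samples, then checks two standard facts: the mean of V is kΣ, and V11/Σ11 follows a χ² distribution with k degrees of freedom. The seed, the number of draws `n_draws` and the variable names are illustrative assumptions, not part of the original notebook.

```python
import numpy as np
import scipy.stats as st

rng = np.random.default_rng(0)
Sigma = np.array([[1, 0.75], [0.75, 2]])
k, n_draws = 5, 20000

# n_draws independent data sets, each of k rows drawn from N_2(0, Sigma);
# the result has shape (n_draws, k, 2)
X = rng.multivariate_normal(np.zeros(2), Sigma, size=(n_draws, k))

# Scatter matrix of each data set, V = sum_i x_i x_i^T; shape (n_draws, 2, 2)
V = np.einsum('nki,nkj->nij', X, X)

# The average of the draws should be close to k * Sigma
mean_V = V.mean(axis=0)
print(mean_V)
print(k * Sigma)

# Kolmogorov-Smirnov test of V11 / Sigma11 against chi-squared with k d.o.f.
ks = st.kstest(V[:, 0, 0] / Sigma[0, 0], 'chi2', args=(k,))
print(ks.statistic, ks.pvalue)
```

A small Kolmogorov-Smirnov statistic here is consistent with the scaled diagonal entry being χ² distributed, which is the sense in which the scatter plots above resemble the one-dimensional case.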