tak0kadaの何でもノート

発声練習、生存確認用。

医学関連は 医学ノート

ディリクレ分布を立体的に表示する

ディリクレ分布をRとpythonで描いてみるの続き。今回はRで書くのがしんどかったのでpythonのみ。

from __future__ import division
from scipy.special import gammaln
import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import  matplotlib.patches as patches
from matplotlib.path import Path

# ディリクレ分布確率密度
def dirichlet_pdf(xs, alphas):
    xs, alphas = np.array(xs), np.array(alphas)
    if np.any(xs < 0) or np.any(xs > 1):
        return 0
    else:
        ibeta = np.exp(gammaln(alphas.sum()) - gammaln(alphas).sum())
        return ibeta * (xs ** (alphas - 1)).prod()

fig = plt.figure()
ax = fig.gca(projection = '3d')

# ディリクレ分布
X = np.arange(0, 1.01, 0.01)
Y = np.arange(0, 1.01, 0.01)
X, Y = np.meshgrid(X, Y)
x1 = 1 - X - 3**(1/2) * Y / 3
x2 = X - 3**(1/2) * Y / 3
x3 = 2 * 3**(1/2) * Y / 3
Z = np.array(
    [[dirichlet_pdf((x1[i, j], x2[i, j], x3[i, j]), (11, 21, 31))
        for i in range(101)]
        for j in range(101)]
    )
# プロットの軸がなぜかずれるのでXとYの順番を入れ替えた
surf = ax.plot_surface(Y, X, Z, rstride=1, cstride=1, cmap=mpl.cm.coolwarm,
        linewidth=0, antialiased=False, alpha=0.6)

# 正三角形
verts = [(0, 0, 0), (1, 0, 0), (0.5, 3**(1/2)/2, 0), (0, 0, 0)]
for i in range(3):
    x = verts[i][0] + (verts[i+1][0] - verts[i][0]) * np.arange(0, 1.01, 0.01)
    y = verts[i][1] + (verts[i+1][1] - verts[i][1]) * np.arange(0, 1.01, 0.01)
    z = verts[i][2] + (verts[i+1][2] - verts[i][2]) * np.arange(0, 1.01, 0.01)
    ax.plot(x, y, z, color="black", linewidth=1)

# 軸を表示しない
ax.set_axis_off()

plt.show()

f:id:kuyata:20141114031209p:plain