tak0kadaの何でもノート

発声練習、生存確認用。

医学関連は 医学ノート

ディリクレ分布をRとpythonで描いてみる

ディリクレ分布

$$ p(\mathbf{x}, \mathbf{\alpha}) = \dfrac{1}{B(\mathbf{\alpha})}\prod x_{i}^{a_{i}-1} $$ $$ B(\mathbf{\alpha}) = \dfrac{\prod(\Gamma a_i)}{\Gamma(\sum a_{i})} $$ ベータ分布と同じく式の中にベータ関数が含まれていて、多次元ベータ分布と呼ばれている。ディリクレ分布の表示に使う分布は多次元のベータ関数になる。

  • R
library("rgl")
library("MCMCpack")

n_dev <- 100
alphas <- list(rep(1, 3), c(3, 7, 5), c(6, 2, 6), c(60, 20, 60))

dirichlet3d.density <- function(vec_1, vec_2, alpha){
  f <- function(x, y) ddirichlet(c(x, y, 1-x-y), alpha)
  mapply(f, vec_1, vec_2)
}

for (alpha in alphas){
  x <- y <- seq(0, 1, length=n_dev)
  # mesh
  z <- outer(x, y, function(x, y){dirichlet3d.density(x, y, alpha)})
  z[z==NaN] <- 0
  image(x, y, z, col=colorRampPalette(c("white", "yellow", "red"), space = "rgb")(10))
  contour(z, add = TRUE, drawlabels = FALSE, nlevels=5)
  polygon(c(0,1,0), c(0,0,1), border="black", lwd=3)
}

f:id:kuyata:20140617225519j:plain

f:id:kuyata:20140617225526j:plain

f:id:kuyata:20140617225532j:plain

f:id:kuyata:20140617225540j:plain

4つ作った画像が一列に並ぶのダサい。

from scipy.special import gammaln
import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from matplotlib.path import Path

def dirichlet_pdf(xs, alphas):
    if np.any(xs < 0) or np.any(xs > 1):
        return 0
    else:
        ibeta = np.exp(gammaln(alphas.sum()) - gammaln(alphas).sum())
        return ibeta * (xs ** (alphas - 1)).prod()

n_dev = 100
alphas = np.array([[1, 1, 1], [3, 7, 5], [6, 2, 6], [60, 20, 60]])

def plotframe():
    verts = [(0, 0), (1, 0), (0, 1), (0, 0)]
    codes = [Path.MOVETO, Path.LINETO, Path.LINETO, Path.CLOSEPOLY]
    path = Path(verts, codes)
    patch = patches.PathPatch(path, edgecolor='black', facecolor='none', linewidth=3)
    ax.add_patch(patch)

def plotcontour(x, y, z, cmap):
    CS = ax.contourf(X, Y, Z, 20, cmap=cmap)
    #contour line
    CS1 = ax.contour(CS, levels=CS.levels[::4], colors='black')
    #label the contour
    ax.clabel(CS1, inline=1, fontsize=8)

#colormap
colors = [(1, 1, 1)]
colors.extend(mpl.cm.autumn_r(np.linspace(0, 1, 10)))
cmap = mpl.colors.ListedColormap(colors)

x = y = np.linspace(0, 1, num=n_dev)
X, Y = np.meshgrid(x, y)
for i, alpha in enumerate(alphas):
    ax = plt.subplot(int("22"+str(i+1)))
    Z = np.array([
        [dirichlet_pdf(np.array([i, j, 1-i-j]), alpha) for i in x]
        for j in y
        ])
    plotcontour(X, Y, Z, cmap)
    plotframe()
    ax.set_title(str(alpha))

plt.show()

コードがRより圧倒的に長い(;´Д`)。cmapはjet_rなどとするとカラーマップの配置を逆にすることができる。あとcmap.set_over()とかも便利。

f:id:kuyata:20140619005213p:plain

  • 参考

ディリクレ分布まとめ
dirichlet.py
stackoverflow: カラーマップの設定
ufunc備忘録: pythonのouter関数が作れるのは積のみなのでリスト内包で対応する。ufuncの関数たちにはnp.add.outer(a, b)などとする使い方がある。