import ternary
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import dirichlet
from scipy.special import gamma, digamma
KL-Divergence of the Dirichlet Distribution
Work
Compositional
Useful
Breakin’ it down
def get_comparison_axes():
    _, (ax1, ax2) = plt.subplots(ncols=2, figsize=(12, 6))
    return (
        get_ternary_axes_for_dirichlet(ax=ax1),
        get_ternary_axes_for_dirichlet(ax=ax2)
    )
def get_ternary_axes_for_dirichlet(ax):
    ## Boundary and Gridlines
    scale = 30
    figure, tax = ternary.figure(ax=ax, scale=scale)

    # Draw Boundary and Gridlines
    tax.boundary(linewidth=1.5)
    tax.gridlines(color="black", multiple=6)
    tax.gridlines(color="blue", multiple=2, linewidth=0.5)

    # Set Axis labels and Title
    fontsize = 12
    offset = 0.14
    tax.set_title("Dirichlet Distribution\n", fontsize=fontsize)
    tax.left_axis_label("$\\alpha_1$", fontsize=fontsize, offset=offset)
    tax.right_axis_label("$\\alpha_2$", fontsize=fontsize, offset=offset)
    tax.bottom_axis_label("$\\alpha_3$", fontsize=fontsize, offset=offset)

    # Background color
    tax.set_background_color(color="whitesmoke", alpha=0.7)  # the default, essentially

    # Remove default Matplotlib Axes
    tax.clear_matplotlib_ticks()
    tax.get_axes().axis('off')

    return tax
def dirch(p, alphas):
    """Evaluates the Dirichlet density at a point in the simplex, nudging boundary points inward."""
    if min(p) == 0:
        new_ = np.array(p) + 0.01
        return dirch(new_ / np.sum(new_), alphas)
    return dirichlet.pdf(p, alphas)
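For reference, `dirichlet.pdf` evaluates the Dirichlet density over the simplex,

$$\mathrm{Dir}(p \mid \alpha) = \frac{\Gamma\!\left(\sum_i \alpha_i\right)}{\prod_i \Gamma(\alpha_i)} \prod_i p_i^{\alpha_i - 1},$$

which blows up on the boundary when any $\alpha_i < 1$; the small nudge in `dirch` keeps the points that `heatmapf` hands over strictly inside the simplex.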
def KL(alphas_1, alphas_2):
    # Magnitude comparison term
    sum_1 = gamma(np.sum(alphas_1))
    sum_2 = gamma(np.sum(alphas_2))
    term_1 = np.log(sum_1 / sum_2)

    # Orientation comparison term
    term_2 = np.sum(
        np.log(
            gamma(alphas_1) / gamma(alphas_2)
        )
    )

    # Weighted orientation comparison term
    term_3 = np.sum(
        (alphas_1 - alphas_2)
        * (digamma(alphas_1) - np.sum(digamma(alphas_2)))
    )

    return term_1, term_2, term_3
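For reference, the closed-form KL divergence between two Dirichlet distributions $\mathrm{Dir}(\alpha)$ and $\mathrm{Dir}(\beta)$, with $\alpha_0 = \sum_i \alpha_i$, $\beta_0 = \sum_i \beta_i$, and $\psi$ the digamma function, is usually written as

$$\mathrm{KL}\big(\mathrm{Dir}(\alpha)\,\big\|\,\mathrm{Dir}(\beta)\big) = \ln\frac{\Gamma(\alpha_0)}{\Gamma(\beta_0)} - \sum_i \ln\frac{\Gamma(\alpha_i)}{\Gamma(\beta_i)} + \sum_i (\alpha_i - \beta_i)\big(\psi(\alpha_i) - \psi(\alpha_0)\big).$$

The `KL` function above returns a magnitude term and two orientation-style terms separately rather than collapsing them into a single scalar, so the contribution of each piece can be inspected in the sweeps below.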
tax1, tax2 = get_comparison_axes()
alphas1 = np.array([0.99, 0.8, 0.99])
alphas2 = np.array([0.99, 1.1, 0.99])
tax1.heatmapf(lambda p: dirch(p, alphas1), boundary=True, style='h')
tax2.heatmapf(lambda p: dirch(p, alphas2), boundary=True, style='h')
KL(alphas1, alphas2)
(-0.2682430037849523, 0.20193211965967683, -0.1938956945448387)
alphas1 = np.array([0.99, 0.99, 0.99])
alphas2 = np.array([0.8, 0.8, 0.8])
display(out := KL(alphas1, alphas2))
print(sum(out))
print(sum(out[1:]))
print(sum(out[1:])/sum(out))
(0.4487827459423686, -0.43861461490538106, 1.311706455755807)
1.3218745867927946
0.8730918408504258
0.6604952160921481
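The last printed ratio, the share of the summed terms carried by the two orientation terms, is the quantity the sweeps below track. A small helper (hypothetical, not part of the original code, and assuming the `KL` defined above) makes it explicit:

def orientation_fraction(alphas_1, alphas_2):
    """Fraction of the summed comparison carried by the two orientation terms."""
    term_1, term_2, term_3 = KL(alphas_1, alphas_2)
    return (term_2 + term_3) / (term_1 + term_2 + term_3)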
for i in range(10):
    alphas1 = np.array([0.99, 0.99, 0.99])
    alphas2 = np.array([i/10+0.01]*3)
    out = KL(alphas1, alphas2)
    print(sum(out[1:])/sum(out))
1.0032516022811897
1.0049366015508523
0.9895612860037356
0.9633624172613837
0.9272040376306717
0.8803714691968922
0.820979045835589
0.7458099235921262
0.6497473905563662
0.5245956186761072
for i in range(10):
    alphas1 = np.array([0.99, 0.99, 0.99])
    alphas2 = np.array([0.99, i/10+0.01, 0.99])
    out = KL(alphas1, alphas2)
    print(sum(out[1:])/sum(out))
0.9929644665475431
0.9150841636416476
0.8359960393617754
0.7611816416699346
0.6920482952977239
0.6287613014304215
0.5710404581911143
0.5184542687327516
0.47053741788439
0.42683850783756333
for i in range(10):
    alphas1 = np.array([0.99, 0.99, 0.99])
    alphas2 = np.array([i/10+0.01]*3)
    print(
        (alphas1 @ alphas2)
        / (np.linalg.norm(alphas1) * np.linalg.norm(alphas2))
    )
1.0000000000000002
1.0
1.0000000000000002
1.0000000000000002
1.0000000000000002
1.0
1.0000000000000002
1.0
1.0000000000000002
1.0000000000000002
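The quantity printed in this loop (and in the two that follow) is the cosine similarity between the two concentration vectors,

$$\cos\theta = \frac{\alpha_1 \cdot \alpha_2}{\lVert\alpha_1\rVert\,\lVert\alpha_2\rVert},$$

which depends only on the direction of the vectors, not their magnitude; that is why rescaling all components together leaves it pinned at 1, up to floating-point error.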
for i in range(10):
    alphas1 = np.array([0.99, 0.99, 0.99])
    alphas2 = np.array([0.99, i/10+0.01, 0.99])
    print(
        (alphas1 @ alphas2)
        / (np.linalg.norm(alphas1) * np.linalg.norm(alphas2))
    )
0.8205993697788486
0.8592097001276557
0.8931041846101601
0.922001477953326
0.9458469136837342
0.9647902152478092
0.9791448197419074
0.9893383956554502
0.9958634776150151
0.999234605291067
change = 1
total = 100
for i in range(10):
    alphas1 = np.array([0.99] * total)
    alphas2 = np.array([i/10+0.01] * change + [0.99] * (total-change))
    print(
        (alphas1 @ alphas2)
        / (np.linalg.norm(alphas1) * np.linalg.norm(alphas2))
    )
0.9950884433038646
0.996042042598324
0.9968928178048972
0.9976406076930077
0.9982853465829943
0.9988270641922016
0.9992658852368395
0.9996020287923161
0.9998358074159202
0.9999676260368984