0
JAX 的表現出乎所有人的意料,在極端情況下,最大性能可提高 20 倍。由于 JAX 的 JIT 編譯開銷,Numpy 在少樣本、少量鏈的情況下會勝出。我報告了 tensorflow probability (TFP) 的結果,但請記住,這種比較是不公平的,因為它實現的隨機游走 metroplis 比我們的包含更多的功能。
重現結果所需的代碼可以在這里找到。使代碼運行得更快的技巧值得學習。
矢量化 MCMC
Colin Carroll 最近發布了一篇有趣的博文,使用 Numpy 和隨機游走 metropolis 算法 (RWMH) 的矢量化版本來生成大量的樣本,同時運行多個鏈以便對算法的收斂性進行后驗檢驗。這通常是通過在多線程機器上每個線程運行一個鏈來實現的,在 Python 中使用 joblib 或自定義后端。這么做很麻煩,但它能完成任務。
Colin 的 文章讓我感到非常興奮,因為我可以在幾乎不增加成本的情況下,同時對成千上萬的鏈進行取樣。他在文章中詳細介紹了幾個這一方法的應用,但我有一種直覺,它可以完成更多的事情。
大約在同一時間,我偶然發現了 JAX。JAX 在概率編程語言環境中似乎很有趣,原因如下:
在大多數情況下,它完全可以替代 Numpy;
Autodiff 很簡單;
它的正向微分模式使得計算高階導數變得容易;
JAX 使用 XLA 執行 JIT 編譯,即使在 CPU 上也可以加速代碼的運行;
使用 GPU 和 TPU 非常簡單;
這是一個偏好問題,但它更傾向于函數式編程。
在開始使用 JAX 實現一個框架之前,我想做一些基準測試,以了解我要注冊的是什么。這里我將進行比較:
Numpy
Jax
Tensorflow Probability (TFP)
XLA 編譯的 Tensorflow Probability
關于基準測試
在給出結果之前,首先需要聲明的是:
報告的時間是在我的筆記本電腦上運行 10 次的平均值,除了終端打開外,沒有任何其它操作。除了編譯后的 JAX 運行外,所有運行的時間都是使用 hyperfine 命令行工具測量的。
我的代碼可能不是最優的,對于 TFP 來說尤其如此。
實驗是在 CPU 上進行的。JAX 和 TFP 可以運行在 GPU/TPU 上,所以可以期待額外的加速。
對于 Numpy 和 JAX 來說,采樣器是一個生成器,樣本不保存在內存中但對 TFP 來說并非如此,因此在大型實驗期間,計算機會耗盡內存。如果 TFP 沒有在堆棧上預先分配內存,不斷地分配內存也會影響性能。
在概率編程中重要的度量是每秒有效采樣的數量,而不是每秒采樣數量,前者后者更像是你使用的算法。這個基準測試仍然可以很好地反映不同框架的原始性能。
設置和結果
我在對一個含有 4 個分量的任意高斯混合樣本進行采樣。使用 Numpy:
import numpy as np
from scipy.stats import norm
from scipy.special import logsumexp
def mixture_logpdf(x):
loc = np.array([[-2, 0, 3.2, 2.5]]).T
scale = np.array([[1.2, 1, 5, 2.8]]).T
weights = np.array([[0.2, 0.3, 0.1, 0.4]]).T
log_probs = norm(loc, scale).logpdf(x)
return -logsumexp(np.log(weights) - log_probs, axis=0)
Numpy
Colin Carroll 的 MiniMC 是我見過的最簡單、最易讀的大都市隨機游走 Metropolis 和 Hamiltonian Monte Carlo 的實現。我的 Numpy 實現是他的一個迭代:
import numpy as np
def rw_metropolis_sampler(logpdf, initial_position):
position = initial_position
log_prob = logpdf(initial_position)
yield position
while True:
move_proposals = np.random.normal(0, 0.1, size=initial_position.shape)
proposal = position + move_proposals
proposal_log_prob = logpdf(proposal)
log_uniform = np.log(np.random.rand(initial_position.shape[0], initial_position.shape[1]))
do_accept = log_uniform < proposal_log_prob - log_prob
position = np.where(do_accept, proposal, position)
log_prob = np.where(do_accept, proposal_log_prob, log_prob)
yield position
JAX
JAX 的實現與 Numpy 非常相似:
from functools import partial
import jax
import jax.numpy as np
@partial(jax.jit, static_argnums=(0, 1))
def rw_metropolis_kernel(rng_key, logpdf, position, log_prob):
move_proposals = jax.random.normal(rng_key, shape=position.shape) * 0.1
proposal = position + move_proposals
proposal_log_prob = logpdf(proposal)
log_uniform = np.log(jax.random.uniform(rng_key, shape=position.shape))
do_accept = log_uniform < proposal_log_prob - log_prob
position = np.where(do_accept, proposal, position)
log_prob = np.where(do_accept, proposal_log_prob, log_prob)
return position, log_prob
def rw_metropolis_sampler(rng_key, logpdf, initial_position):
position = initial_position
log_prob = logpdf(initial_position)
yield position
while True:
position, log_prob = rw_metropolis_kernel(rng_key, logpdf, position, log_prob)
yield position
如果你熟悉 Numpy,那么你應該非常熟悉它的語法。JAX 和它有一些不同之處:
jax.numpy 充當 numpy 的替代。對于只涉及數組操作的函數,用 import jax.numpy as np 替換 import numpy as np,這會給你帶來性能上的提升。
JAX 處理隨機數生成的方式與其他 Python 包不同,這是有原因的 (請閱讀這篇文章:https://github.com/google/jax/blob/master/design_notes/prng.md ) 。每個發行版都以一個 PRNG 鍵作為輸入。
因為 JAX 不能編譯生成器,我從采樣器中提取內核。因此,我們提取并 JIT 完成所有繁重工作的函數:rw_metropolis_kernel。
我們需要對 JAX 的編譯器提供一點幫助,即指出當函數多次運行時哪些參數不會改變:@partial(jax.jit, argnums=(0, 1))。如果將函數作為參數傳遞,這是必需的,并且可以啟用進一步的編譯時優化。
Tensorflow Probability
對于 TFP,我們使用庫中實現的隨機游走 Metropolis 算法:
from functools import partial
import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions
def run_raw_metropolis(n_dims, n_samples, n_chains, target):
samples, _ = tfp.mcmc.sample_chain(
num_results=n_samples,
current_state=np.zeros((n_dims, n_chains), dtype=np.float32),
kernel=tfp.mcmc.RandomWalkMetropolis(target.log_prob, seed=42),
num_burnin_steps=0,
parallel_iterations=8,
)
return samples
run_mcm = partial(run_tfp_mcmc, n_dims, n_samples, n_chains, target)
## Without XLA
run_mcm()
## With XLA compilation
tf.xla.experimental.compile(run_mcm)
結果
我們有兩個自由維度:樣本的數量和鏈的數量,第一個依賴于原始的數字處理能力,第二個也依賴于向量化的實現方式。因此,我決定在兩個維度上對算法進行基準測試。
我考慮以下情況:
Numpy 實現;
JAX 實現;
減去編譯時間的 JAX 實現。這只是一個假設的情況,目的是顯示編譯帶來的改進。
Tensorflow Probability;
實驗 XLA 編譯的 Tensorflow Probability。
用 1000 條鏈繪制越來越多的樣本
我們固定鏈的數量,并改變樣本的數量。

你將注意到 TFP 實現的缺失點。由于 TFP 算法存儲所有的樣本,所以它會耗盡內存。這在 XLA 編譯的版本中沒有發生,可能是因為它使用了內存效率更高的數據結構。
對于少于 1000 個樣本,普通的 TFP 和 Numpy 實現比它們的編譯副本要快。這是由于編譯開銷造成的:當你減去 JAX 的編譯時間 (從而獲得綠色曲線) 時,它會大大加快速度。只有當樣本的數量變得很大,并且總抽樣時間取決于抽取樣本的時間時,你才開始從編譯中獲益。
沒有什么神奇的:JIT 編譯意味著一個明顯的、但不變的計算開銷。
我建議在大多數情況下使用 JAX。只有當相同的代碼執行超過 10 次時,在 0.3 秒而不是 3 秒內進行采樣的差異才會產生影響。然而,編譯是只會發生一次。在這種情況下,計算開銷將在你達到 10 次迭代之前得到回報。實際上,JAX 贏了。
用越來越多的鏈繪制 1000 個樣本
在這里,我們固定樣本的數量,改變鏈的數量。

JAX 仍然明顯地贏了:只要鏈的數量達到 10,000,它就比 Numpy 更快。你將注意到 JAX 曲線上有一個凸起,這完全是由于編譯造成的 (綠色曲線沒有這個凸起)。我不知道為什么,如果有答案請告訴我!
這就是令人興奮的亮點:
JAX 可以在 25 秒內在 CPU 上生成 10 億個樣本,比 Numpy 快 20 倍!
結論
對于允許我們用純 python 編寫代碼的項目,JAX 的性能是令人難以置信的。Numpy 仍然是一個不錯的選擇,特別是對于那些 JAX 的大部分執行時間都花在編譯上的項目來說尤其如此。
但是,Numpy 不適合概率編程語言。如 Hamiltonian Monte Carlo 這樣的高效抽樣算 Uber 優步的團隊開始和 JAX 在 Numpyro 上合作。
不要過多地解讀 Tensorflow Probability 的拙劣表現。當從分布中采樣時,重要的不是原始速度,而是每秒有效采樣的數量。TFP 的實現包括更多的附加功能,我希望它在每秒有效采樣樣本數方面更具競爭力。
最后,請注意,用鏈的數量乘以樣本的數量要比用樣本的數量乘以樣本的數量容易得多。我們還不知道如何處理這些鏈,但我有一種直覺,一旦我們這樣做了,概率編程將會有另一個突破。
via:https://rlouf.github.io/post/jax-random-walk-metropolis/
雷鋒網雷鋒網雷鋒網
雷峰網版權文章,未經授權禁止轉載。詳情見轉載須知。