<label id="jgr5k"></label>
    <legend id="jgr5k"><track id="jgr5k"></track></legend>

    <sub id="jgr5k"></sub>
  1. <u id="jgr5k"></u>
      久草国产视频,91资源总站,在线免费看AV,丁香婷婷社区,久久精品99久久久久久久久,色天使av,无码探花,香蕉av在线
      您正在使用IE低版瀏覽器,為了您的雷峰網賬號安全和更好的產品體驗,強烈建議使用更快更安全的瀏覽器
      此為臨時鏈接,僅用于文章預覽,將在時失效
      人工智能開發者 正文
      發私信給skura
      發送

      0

      基于JAX的大規模并行MCMC:CPU25秒就可以處理10億樣本

      本文作者: skura 2020-01-14 11:43
      導語:在概率編程中,JAX 有很多優勢

      JAX 的表現出乎所有人的意料,在極端情況下,最大性能可提高 20 倍。由于 JAX 的 JIT 編譯開銷,Numpy 在少樣本、少量鏈的情況下會勝出。我報告了 tensorflow probability (TFP) 的結果,但請記住,這種比較是不公平的,因為它實現的隨機游走 metroplis 比我們的包含更多的功能。

      重現結果所需的代碼可以在這里找到。使代碼運行得更快的技巧值得學習。

      矢量化 MCMC

      Colin Carroll 最近發布了一篇有趣的博文,使用 Numpy 和隨機游走 metropolis 算法 (RWMH) 的矢量化版本來生成大量的樣本,同時運行多個鏈以便對算法的收斂性進行后驗檢驗。這通常是通過在多線程機器上每個線程運行一個鏈來實現的,在 Python 中使用 joblib 或自定義后端。這么做很麻煩,但它能完成任務。

      Colin 的 文章讓我感到非常興奮,因為我可以在幾乎不增加成本的情況下,同時對成千上萬的鏈進行取樣。他在文章中詳細介紹了幾個這一方法的應用,但我有一種直覺,它可以完成更多的事情。

      大約在同一時間,我偶然發現了 JAX。JAX 在概率編程語言環境中似乎很有趣,原因如下:

      • 在大多數情況下,它完全可以替代 Numpy;

      • Autodiff 很簡單;

      • 它的正向微分模式使得計算高階導數變得容易;

      • JAX 使用 XLA 執行 JIT 編譯,即使在 CPU 上也可以加速代碼的運行;

      • 使用 GPU 和 TPU 非常簡單;

      • 這是一個偏好問題,但它更傾向于函數式編程。

      在開始使用 JAX 實現一個框架之前,我想做一些基準測試,以了解我要注冊的是什么。這里我將進行比較:

      • Numpy

      • Jax

      • Tensorflow Probability (TFP)

      • XLA 編譯的 Tensorflow Probability

      關于基準測試

      在給出結果之前,首先需要聲明的是:

      1. 報告的時間是在我的筆記本電腦上運行 10 次的平均值,除了終端打開外,沒有任何其它操作。除了編譯后的 JAX 運行外,所有運行的時間都是使用 hyperfine 命令行工具測量的。

      2. 我的代碼可能不是最優的,對于 TFP 來說尤其如此。

      3. 實驗是在 CPU 上進行的。JAX 和 TFP 可以運行在 GPU/TPU 上,所以可以期待額外的加速。

      4. 對于 Numpy 和 JAX 來說,采樣器是一個生成器,樣本不保存在內存中但對 TFP 來說并非如此,因此在大型實驗期間,計算機會耗盡內存。如果 TFP 沒有在堆棧上預先分配內存,不斷地分配內存也會影響性能。

      5. 在概率編程中重要的度量是每秒有效采樣的數量,而不是每秒采樣數量,前者后者更像是你使用的算法。這個基準測試仍然可以很好地反映不同框架的原始性能。

      設置和結果

      我在對一個含有 4 個分量的任意高斯混合樣本進行采樣。使用 Numpy:

      import numpy as np
      from scipy.stats import norm
      from scipy.special import logsumexp

      def mixture_logpdf(x):
          loc = np.array([[-2, 0, 3.2, 2.5]]).T
          scale = np.array([[1.2, 1, 5, 2.8]]).T
          weights = np.array([[0.2, 0.3, 0.1, 0.4]]).T

          log_probs = norm(loc, scale).logpdf(x)

          return -logsumexp(np.log(weights) - log_probs, axis=0)

      Numpy

      Colin Carroll 的 MiniMC 是我見過的最簡單、最易讀的大都市隨機游走  Metropolis 和 Hamiltonian Monte Carlo 的實現。我的 Numpy 實現是他的一個迭代:

      import numpy as np

      def rw_metropolis_sampler(logpdf, initial_position):
          position = initial_position
          log_prob = logpdf(initial_position)
          yield position

          while True:
              move_proposals = np.random.normal(0, 0.1, size=initial_position.shape)
              proposal = position + move_proposals
              proposal_log_prob = logpdf(proposal)

              log_uniform = np.log(np.random.rand(initial_position.shape[0], initial_position.shape[1]))
              do_accept = log_uniform < proposal_log_prob - log_prob

              position = np.where(do_accept, proposal, position)
              log_prob = np.where(do_accept, proposal_log_prob, log_prob)
              yield position

      JAX

      JAX 的實現與 Numpy 非常相似:

      from functools import partial

      import jax
      import jax.numpy as np

      @partial(jax.jit, static_argnums=(0, 1))
      def rw_metropolis_kernel(rng_key, logpdf, position, log_prob):
          move_proposals = jax.random.normal(rng_key, shape=position.shape) * 0.1
          proposal = position + move_proposals
          proposal_log_prob = logpdf(proposal)

          log_uniform = np.log(jax.random.uniform(rng_key, shape=position.shape))
          do_accept = log_uniform < proposal_log_prob - log_prob

          position = np.where(do_accept, proposal, position)
          log_prob = np.where(do_accept, proposal_log_prob, log_prob)
          return position, log_prob


      def rw_metropolis_sampler(rng_key, logpdf, initial_position):
          position = initial_position
          log_prob = logpdf(initial_position)
          yield position

          while True:
              position, log_prob = rw_metropolis_kernel(rng_key, logpdf, position, log_prob)
              yield position

      如果你熟悉 Numpy,那么你應該非常熟悉它的語法。JAX 和它有一些不同之處:

      •  jax.numpy 充當 numpy 的替代。對于只涉及數組操作的函數,用 import jax.numpy as np 替換 import numpy as np,這會給你帶來性能上的提升。

      • JAX 處理隨機數生成的方式與其他 Python 包不同,這是有原因的 (請閱讀這篇文章:https://github.com/google/jax/blob/master/design_notes/prng.md ) 。每個發行版都以一個 PRNG 鍵作為輸入。

      • 因為 JAX 不能編譯生成器,我從采樣器中提取內核。因此,我們提取并 JIT 完成所有繁重工作的函數:rw_metropolis_kernel。

      • 我們需要對 JAX 的編譯器提供一點幫助,即指出當函數多次運行時哪些參數不會改變:@partial(jax.jit, argnums=(0, 1))。如果將函數作為參數傳遞,這是必需的,并且可以啟用進一步的編譯時優化。

      Tensorflow Probability

      對于 TFP,我們使用庫中實現的隨機游走 Metropolis 算法:

      from functools import partial

      import numpy as np
      import tensorflow as tf
      import tensorflow_probability as tfp
      tfd = tfp.distributions

      def run_raw_metropolis(n_dims, n_samples, n_chains, target):
          samples, _ = tfp.mcmc.sample_chain(
              num_results=n_samples,
              current_state=np.zeros((n_dims, n_chains), dtype=np.float32),
              kernel=tfp.mcmc.RandomWalkMetropolis(target.log_prob, seed=42),
              num_burnin_steps=0,
              parallel_iterations=8,
          )
          return samples

      run_mcm = partial(run_tfp_mcmc, n_dims, n_samples, n_chains, target)

      ## Without XLA
      run_mcm()

      ## With XLA compilation
      tf.xla.experimental.compile(run_mcm)

      結果

      我們有兩個自由維度:樣本的數量和鏈的數量,第一個依賴于原始的數字處理能力,第二個也依賴于向量化的實現方式。因此,我決定在兩個維度上對算法進行基準測試。

      我考慮以下情況:

      1. Numpy 實現;

      2. JAX 實現;

      3. 減去編譯時間的 JAX 實現。這只是一個假設的情況,目的是顯示編譯帶來的改進。

      4. Tensorflow Probability;

      5. 實驗 XLA 編譯的 Tensorflow Probability。

      用 1000 條鏈繪制越來越多的樣本

      我們固定鏈的數量,并改變樣本的數量。

      基于JAX的大規模并行MCMC:CPU25秒就可以處理10億樣本

      你將注意到 TFP 實現的缺失點。由于 TFP 算法存儲所有的樣本,所以它會耗盡內存。這在 XLA 編譯的版本中沒有發生,可能是因為它使用了內存效率更高的數據結構。

      對于少于 1000 個樣本,普通的 TFP 和 Numpy 實現比它們的編譯副本要快。這是由于編譯開銷造成的:當你減去 JAX 的編譯時間 (從而獲得綠色曲線) 時,它會大大加快速度。只有當樣本的數量變得很大,并且總抽樣時間取決于抽取樣本的時間時,你才開始從編譯中獲益。

      沒有什么神奇的:JIT 編譯意味著一個明顯的、但不變的計算開銷。

      我建議在大多數情況下使用 JAX。只有當相同的代碼執行超過 10 次時,在 0.3 秒而不是 3 秒內進行采樣的差異才會產生影響。然而,編譯是只會發生一次。在這種情況下,計算開銷將在你達到 10 次迭代之前得到回報。實際上,JAX 贏了。

      用越來越多的鏈繪制 1000 個樣本

      在這里,我們固定樣本的數量,改變鏈的數量。

      基于JAX的大規模并行MCMC:CPU25秒就可以處理10億樣本

      JAX 仍然明顯地贏了:只要鏈的數量達到 10,000,它就比 Numpy 更快。你將注意到 JAX 曲線上有一個凸起,這完全是由于編譯造成的 (綠色曲線沒有這個凸起)。我不知道為什么,如果有答案請告訴我!

      這就是令人興奮的亮點:

      JAX 可以在 25 秒內在 CPU 上生成 10 億個樣本,比 Numpy 快 20 倍!

      結論

      對于允許我們用純 python 編寫代碼的項目,JAX 的性能是令人難以置信的。Numpy 仍然是一個不錯的選擇,特別是對于那些 JAX 的大部分執行時間都花在編譯上的項目來說尤其如此。

      但是,Numpy 不適合概率編程語言。如 Hamiltonian Monte Carlo 這樣的高效抽樣算 Uber 優步的團隊開始和 JAX 在 Numpyro 上合作。

      不要過多地解讀 Tensorflow Probability 的拙劣表現。當從分布中采樣時,重要的不是原始速度,而是每秒有效采樣的數量。TFP 的實現包括更多的附加功能,我希望它在每秒有效采樣樣本數方面更具競爭力。

      最后,請注意,用鏈的數量乘以樣本的數量要比用樣本的數量乘以樣本的數量容易得多。我們還不知道如何處理這些鏈,但我有一種直覺,一旦我們這樣做了,概率編程將會有另一個突破。

      via:https://rlouf.github.io/post/jax-random-walk-metropolis/

      雷鋒網雷鋒網雷鋒網

      雷峰網版權文章,未經授權禁止轉載。詳情見轉載須知

      基于JAX的大規模并行MCMC:CPU25秒就可以處理10億樣本

      分享:
      相關文章
      當月熱門文章
      最新文章
      請填寫申請人資料
      姓名
      電話
      郵箱
      微信號
      作品鏈接
      個人簡介
      為了您的賬戶安全,請驗證郵箱
      您的郵箱還未驗證,完成可獲20積分喲!
      請驗證您的郵箱
      立即驗證
      完善賬號信息
      您的賬號已經綁定,現在您可以設置密碼以方便用郵箱登錄
      立即設置 以后再說
      主站蜘蛛池模板: 人妻无码熟妇乱又伦精品视频| 若羌县| 午夜精品久久久久久久无码软件| 免费区欧美一级猛片| 国产福利萌白酱在线观看视频| 日韩成人无码| 亚洲欧洲自拍| 无码中出人妻中文字幕AV| 亚州精品无码| 久久无码人妻热线精品| 婷婷综合亚洲| 国产精品 视频一区 二区三区 | 一本久道久久综合无码中文| 无码少妇一区二区三区免费| 国产一区韩国主播| 婷婷久久久久| 黑人巨大精品欧美视频一区| 久久精品欧美日韩精品| 婷婷五月激情综合| 性交大片| 亚洲人成欧美中文字幕| 日韩欧美国产精品| 亚洲精品喷潮一区二区三区| 国产成人精品a视频一区| 广东少妇大战黑人34厘米视频 | 午夜综合网| 在线综合亚洲欧洲综合网站| 在线国产毛片| 国产精品亚洲综合色区韩国| 久久久999| 草草影院ccyy| 爱情岛论坛首页永久入口| 扎鲁特旗| 国产69久久精品成人看| 亚洲熟妇色自偷自拍另类| 国产人妻丰满熟妇嗷嗷叫| 3Pav图| www.色吊丝av.com| 男人天堂网址| 中文字幕久久精品无码综合网| 黑巨人与欧美精品一区|