Philosophy

TensorFlow and JAX

This package provides two ‘substrates’ for modeling: modeling with TensorFlow and modeling with JAX. The benefits of modeling with TensorFlow include an easier setup (TensorFlow is fairly standard anywhere with access to GPUs) and a more mature implementation of algorithms such as variational inference and HMC in TensorFlow Probability (TFP). However, TensorFlow does not provide an easy way to parallelize across multiple GPUs. JAX, on the other hand, natively supports multiple devices, although its TFP substrate is still experimental. For beginners, we recommend starting with the TensorFlow implementation in gigalens.tf; for experienced users willing to experiment, gigalens.jax can be significantly faster than gigalens.tf.
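To illustrate what native multi-device support looks like in practice, here is a minimal, generic JAX sketch (plain JAX, not gigalens code): jax.pmap replicates a computation across every local accelerator, mapping one slice of the leading batch axis to each device.

```python
# Minimal, generic JAX sketch (not gigalens code): jax.pmap runs the same
# computation on every local device in parallel.
import jax
import jax.numpy as jnp

print(jax.devices())  # lists all available accelerators

@jax.pmap
def scaled_sum(x):
    # Each device receives one slice along the leading axis.
    return jnp.sum(x) * 2.0

n_dev = jax.local_device_count()
batch = jnp.ones((n_dev, 4))  # leading axis must equal the device count
print(scaled_sum(batch))      # one result per device
```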

Although the two implementations are fundamentally identical, there are some small differences. The most significant is that the unconstraining bijector used in the JAX implementation is not able to fully flatten unconstrained parameters into a JAX array. This appears to be due to limited support for tfp.substrates.jax.bijectors.Split, and it may be resolved in upcoming releases as the JAX substrate in TFP matures. Another difference is that variational inference (VI) for JAX is implemented within the gigalens package, whereas in the TensorFlow implementation, TFP already implements VI, so we use its built-in method (see tfp.vi.fit_surrogate_posterior()); a sketch of that interface follows below.
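As a rough sketch of that built-in interface, the following fits a trainable surrogate to a toy target density. The standard-normal target and mean-field surrogate here are illustrative stand-ins, not gigalens’s actual lens-model posterior.

```python
# Hedged sketch of TFP's built-in VI: the toy target and surrogate below are
# illustrative stand-ins, not gigalens's actual model.
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions

def target_log_prob_fn(x):
    # Toy target density standing in for the (unnormalized) posterior.
    return tfd.Normal(0.0, 1.0).log_prob(x)

# Trainable surrogate: Softplus keeps the scale parameter positive.
surrogate_posterior = tfd.Normal(
    loc=tf.Variable(0.0),
    scale=tfp.util.TransformedVariable(1.0, bijector=tfp.bijectors.Softplus()),
)

# Minimizes the negative ELBO; returns the per-step loss trace.
losses = tfp.vi.fit_surrogate_posterior(
    target_log_prob_fn,
    surrogate_posterior,
    optimizer=tf.optimizers.Adam(learning_rate=0.1),
    num_steps=200,
)
```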