Overview
TensorFlow and JAX
This package provides two ‘substrates’ for modeling: modeling with TensorFlow and modeling with JAX. The benefits of modeling with TensorFlow include an easier setup (TensorFlow is fairly standard anywhere with access to GPUs) and a more mature implementation of algorithms such as variational inference and HMC in TensorFlow Probability (TFP). However, for non-neural-network applications, TensorFlow does not provide an easy way to parallelize across multiple GPUs. JAX, on the other hand, natively supports multiple devices and has experimental support in TFP. For beginners, we recommend starting with the TensorFlow substrate in gigalens.tf; however, experienced users who are willing to experiment may find that gigalens.jax with multiple GPUs can be significantly faster than gigalens.tf.
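To illustrate JAX's native multi-device support, here is a minimal sketch (the toy computation and data are placeholders, not gigalens code) that shards a batch across all local devices with jax.pmap:

```python
import jax
import jax.numpy as jnp

# One batch shard per local device (e.g., one per GPU).
n_dev = jax.local_device_count()
xs = jnp.arange(n_dev * 4.0).reshape(n_dev, 4)

# pmap compiles the function once and runs it in parallel,
# with each device receiving one slice along the leading axis.
sum_sq = jax.pmap(lambda x: jnp.sum(x ** 2))
print(sum_sq(xs))  # one result per device
```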
Although the two implementations are fundamentally identical, there are some small differences. While flattening of the parameter vector is done automatically by TFP in the TF substrate, flattening in the JAX substrate is, for now, done by hand (JAX may add this feature in the future, in which case we will update our implementation).
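For concreteness, below is a minimal sketch of what hand-flattening a parameter pytree looks like; the params dictionary and its keys are hypothetical stand-ins, not the actual gigalens parameter structure:

```python
import numpy as np
import jax
import jax.numpy as jnp

# Hypothetical parameter pytree, for illustration only.
params = {"center": jnp.zeros(2), "theta_E": jnp.array(1.5)}

# Flatten by hand: ravel every leaf and concatenate into one 1-D vector.
leaves, treedef = jax.tree_util.tree_flatten(params)
flat = jnp.concatenate([jnp.ravel(leaf) for leaf in leaves])

# Record leaf shapes and split points so the vector can be restored.
shapes = [leaf.shape for leaf in leaves]
splits = np.cumsum([leaf.size for leaf in leaves])[:-1]

def unflatten(vec):
    chunks = jnp.split(vec, splits)
    return jax.tree_util.tree_unflatten(
        treedef, [c.reshape(s) for c, s in zip(chunks, shapes)]
    )
```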
Similarly, we implement VI ourselves for the JAX substrate, whereas in the TensorFlow substrate TFP already implements VI, so we use its built-in method (see tfp.vi.fit_surrogate_posterior()).
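As an illustrative sketch (with a toy target and surrogate, not the gigalens VI setup), TFP's built-in routine can be used like this:

```python
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions

# Toy target density, for illustration only.
target = tfd.Normal(loc=2.0, scale=0.5)

# Mean-field normal surrogate with trainable location and scale.
surrogate = tfd.Normal(
    loc=tf.Variable(0.0),
    scale=tfp.util.TransformedVariable(1.0, bijector=tfp.bijectors.Softplus()),
)

# TFP's built-in VI loop minimizes the negative ELBO.
losses = tfp.vi.fit_surrogate_posterior(
    target_log_prob_fn=target.log_prob,
    surrogate_posterior=surrogate,
    optimizer=tf.optimizers.Adam(learning_rate=0.1),
    num_steps=200,
)
```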
Also, we note that using classes with TensorFlow can cause a ~15% slowdown compared with less modular code. This is due to the @tf.function wrapper tracing the self object in class methods, but we choose this approach for ease of use and maintenance.
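To make the tracing behavior concrete, the toy class below (hypothetical, not from gigalens) shows the pattern: @tf.function treats self like any other argument, so each instance is traced separately:

```python
import tensorflow as tf

class Scaler:
    """Hypothetical example class; not part of gigalens."""

    def __init__(self, scale):
        self.scale = tf.Variable(scale)

    @tf.function  # `self` is traced along with `x`
    def apply(self, x):
        return self.scale * x

a = Scaler(2.0)
b = Scaler(3.0)
a.apply(tf.constant(1.0))  # traces apply for instance `a`
b.apply(tf.constant(1.0))  # retraces for instance `b`
```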