fridom.framework.utils.jax_utils.jaxjit
fridom.framework.utils.jax_utils.jaxjit(fun: callable, *args, **kwargs) -> callable
Decorator for JAX JIT compilation.
Description
This decorator is a wrapper around `jax.jit`. When JAX is not installed, `fun` is returned unchanged and no compilation takes place.
Parameters
- fun : callable
The function to JIT compile.
Returns
- callable
The JIT compiled function.
Examples
>>> import fridom.framework as fr
>>> @fr.utils.jaxjit
... def my_function(x):
...     return x**2