[exchangeable] Update lecture to use JAX#772
Conversation
|
📖 Netlify Preview Ready! Preview URL: https://pr-772--sunny-cactus-210e3e.netlify.app (5098c0e) 📚 Changed Lecture Pages: exchangeable |
|
@longye-tian , would you be willing to review this PR? @mmcky The deploy link seems broken. Do they go stale? after how long? Should @longye-tian trigger a new build with an empty commit? |
|
@longye-tian or any other reviewer: Does adding JAX improve this lecture? Please also make this value judgement. |
There was a problem hiding this comment.
I think the JAX conversion is natural in the appendix simulation section, especially around the replacement of the old nested-loop simulation. This seems like a good fit because the simulation has two parts: recursion over time and repetition across independent paths.
It might be helpful to say this explicitly in the text before the code. Currently the lecture says:
To proceed, we create some Python code.
Perhaps add something like:
The simulation has two dimensions: recursion over time and repetition across independent paths. In JAX, we use
lax.scanfor the time recursion andvmapto apply the path simulator across many independent key sequences.
For the learning_example block, I’m less sure that the JAX conversion adds much. This part changes the beta-density helper from the old Numba/vectorized version and wraps several scalar helpers with jax.jit, but the surrounding computation is still mostly SciPy root finding, SciPy quadrature, NumPy grids, and Matplotlib plotting. So JAX transformations do not seem central to this part of the lecture. I can see the value of removing Numba and using JAX for consistency, but plain NumPy/SciPy may be clearer here.
A more useful exposition improvement might be to split this page-long helper function into the three conceptual pieces shown in the output figure:
- likelihood ratio plot
- density / probability-region plot
- posterior-dynamics arrows
That would let the text introduce and discuss each panel one by one, rather than asking readers to parse a long plotting helper before seeing the three ideas it creates.
A few small code-level updates could also make the JAX appendix clearer:
- Instead of
random_seed = int(a * b + T + N), consider adding an explicitseedorkeyargument. That would make the source of randomness clearer. - The comment
# Generate all random keys upfrontappears insidesimulate_path, but the keys are actually generated insimulate; perhaps change it to something like# Use one random key for each date.
A few small remaining typos I noticed while reading:
- “less that one” → “less than one”
- “Bayes’ Law make
$\pi$ decrease” → “Bayes’ Law makes$\pi$ decrease” - “The above graphs shows” → “The above graph shows”
-
expected_rario→expected_ratio
Best,
Longye
Update the lecture to use JAX and removes numba related code. Also fixes minor typos and code styling issues.