Skip to content

[exchangeable] Update lecture to use JAX#772

Open
kp992 wants to merge 2 commits into
mainfrom
exchangeable_fixes
Open

[exchangeable] Update lecture to use JAX#772
kp992 wants to merge 2 commits into
mainfrom
exchangeable_fixes

Conversation

@kp992
Copy link
Copy Markdown
Contributor

@kp992 kp992 commented Dec 26, 2025

Update the lecture to use JAX and removes numba related code. Also fixes minor typos and code styling issues.

@github-actions
Copy link
Copy Markdown

📖 Netlify Preview Ready!

Preview URL: https://pr-772--sunny-cactus-210e3e.netlify.app (5098c0e)

📚 Changed Lecture Pages: exchangeable

@kp992 kp992 requested review from HumphreyYang and mmcky December 26, 2025 02:10
@jstac
Copy link
Copy Markdown
Contributor

jstac commented May 30, 2026

@longye-tian , would you be willing to review this PR?

@mmcky The deploy link seems broken. Do they go stale? after how long? Should @longye-tian trigger a new build with an empty commit?

@jstac
Copy link
Copy Markdown
Contributor

jstac commented May 30, 2026

@longye-tian or any other reviewer: Does adding JAX improve this lecture? Please also make this value judgement.

Copy link
Copy Markdown
Contributor

@longye-tian longye-tian left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @kp992 @jstac,

I think the JAX conversion is natural in the appendix simulation section, especially around the replacement of the old nested-loop simulation. This seems like a good fit because the simulation has two parts: recursion over time and repetition across independent paths.

It might be helpful to say this explicitly in the text before the code. Currently the lecture says:

To proceed, we create some Python code.

Perhaps add something like:

The simulation has two dimensions: recursion over time and repetition across independent paths. In JAX, we use lax.scan for the time recursion and vmap to apply the path simulator across many independent key sequences.

For the learning_example block, I’m less sure that the JAX conversion adds much. This part changes the beta-density helper from the old Numba/vectorized version and wraps several scalar helpers with jax.jit, but the surrounding computation is still mostly SciPy root finding, SciPy quadrature, NumPy grids, and Matplotlib plotting. So JAX transformations do not seem central to this part of the lecture. I can see the value of removing Numba and using JAX for consistency, but plain NumPy/SciPy may be clearer here.

A more useful exposition improvement might be to split this page-long helper function into the three conceptual pieces shown in the output figure:

  1. likelihood ratio plot
  2. density / probability-region plot
  3. posterior-dynamics arrows

That would let the text introduce and discuss each panel one by one, rather than asking readers to parse a long plotting helper before seeing the three ideas it creates.

A few small code-level updates could also make the JAX appendix clearer:

  • Instead of random_seed = int(a * b + T + N), consider adding an explicit seed or key argument. That would make the source of randomness clearer.
  • The comment # Generate all random keys upfront appears inside simulate_path, but the keys are actually generated in simulate; perhaps change it to something like # Use one random key for each date.

A few small remaining typos I noticed while reading:

  • “less that one” → “less than one”
  • “Bayes’ Law make $\pi$ decrease” → “Bayes’ Law makes $\pi$ decrease”
  • “The above graphs shows” → “The above graph shows”
  • expected_rarioexpected_ratio

Best,
Longye

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants