File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -70,7 +70,7 @@ RUN if [ "$framework" = "jax" ] ; then \
7070 echo "Installing Jax GPU" \
7171 && cd /algorithmic-efficiency \
7272 && pip install -e '.[pytorch_cpu]' -f https://download.pytorch.org/whl/cpu \
73- && pip install -e '.[jax_gpu]' ; \
73+ && pip install -e '.[jax_gpu]' -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ --pre ; \
7474 elif [ "$framework" = "pytorch" ] ; then \
7575 echo "Installing Pytorch GPU" \
7676 && cd /algorithmic-efficiency \
@@ -80,7 +80,7 @@ RUN if [ "$framework" = "jax" ] ; then \
8080 echo "Installing Jax GPU and Pytorch GPU" \
8181 && cd /algorithmic-efficiency \
8282 && pip install -e '.[pytorch_gpu]' \
83- && pip install -e '.[jax_gpu]' ; \
83+ && pip install -e '.[jax_gpu]' -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ --pre ; \
8484 else \
8585 echo "Invalid build-arg $framework: framework should be either jax, pytorch or both." >&2 \
8686 && exit 1 ; \
You can’t perform that action at this time.
0 commit comments