Jax's true calling: Ray-Marching renderers on WebGL

(benoit.paris)

88 points | by BenoitP 1 day ago


  • heisenzombie 1 day ago
    JAX is super fun to use outside of ML!

    Recently I had fun reimplementing an old (but still usable!) code for accelerator optics. It involved second-order transfer matrices for a 6D phase space. Most of the FORTRAN77 source code was just pages and pages of hand-differentiated 6x6x6 matrices (with quite non-trivial elements), plus the plumbing to painstakingly propagate those Jacobians around for fitting... all replaced with a single, magic call to jax.grad(). Felt like cheating!
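A minimal JAX sketch of that workflow, assuming a toy 6D cell (drift, thin sextupole, drift) with made-up parameters, a linearized drift that ignores chromatic terms, and one common sign convention for the sextupole kick; it is not the original FORTRAN code's model. jax.jacfwd recovers the first-order matrix R, half of jax.hessian gives the second-order tensor T, and jax.grad differentiates a hypothetical fitting objective.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)  # optics codes usually work in double precision

def drift(z, length):
    # Linearized drift (chromatic terms ignored): positions advance by length * angle.
    # z = (x, px, y, py, tau, delta); tau and delta pass through unchanged.
    x, px, y, py, tau, delta = z
    return jnp.array([x + length * px, px, y + length * py, py, tau, delta])

def thin_sextupole(z, k2l):
    # Thin sextupole kick: its quadratic terms are what fill the 6x6x6 tensor.
    x, px, y, py, tau, delta = z
    return jnp.array([x, px - 0.5 * k2l * (x**2 - y**2), y, py + k2l * x * y, tau, delta])

def cell(z, k2l):
    # Hypothetical cell: drift, sextupole, drift (unit lengths).
    return drift(thin_sextupole(drift(z, 1.0), k2l), 1.0)

z0 = jnp.zeros(6)
R = jax.jacfwd(cell)(z0, 0.3)          # first-order (6, 6) transfer matrix
T = 0.5 * jax.hessian(cell)(z0, 0.3)   # second-order (6, 6, 6) tensor T[i, j, k]

# Fitting: gradient of a hypothetical objective with respect to the magnet strength.
z_in = jnp.array([1.0, 0.0, 0.0, 0.0, 0.0, 0.0])

def loss(k2l):
    # Target: the test particle should leave the cell at x = 0.9 (arbitrary units).
    return (cell(z_in, k2l)[0] - 0.9) ** 2

g = jax.grad(loss)(0.3)
```

The point is the one made above: the derivative bookkeeping is generated from the forward map, so adding or changing an element only means editing `cell`.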

    I'm also super interested in its application to modelling, e.g. projects like https://github.com/deepmodeling/jax-fem -- particularly for chaining different sorts of simulations and analysis together and getting gradients through the lot. Also quite magic!

    • BenoitP 23 hours ago
      Yeah :)

      I had a lot of fun writing the article! And it is only half a joke.

      My intuition for so-called world models is that we'll have to plug together modules, each responsible for a domain (text, video, sound, robot haptics, physical modelling). They will have to be connected in a way that lets the gradient propagate through all of them: a differentiable architecture. JAX seems well placed for this because it makes function transformation a first-class citizen. Your testimony reinforces this view.

  • corndoge 1 day ago
    Moving my thumb across the image makes the ball-and-cube graphic go black, and then the page scrolls. Firefox on iOS.
    • BenoitP 23 hours ago
      Damn, I should have spent more time QA-ing that post. I'll try to patch it.

      You did not miss much though: it just rotates the scene.

    • akoboldfrying 1 day ago
      Me too, Chrome on Android.

      I like the concept of applying Jax to SDF sphere tracing :)
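For readers new to the technique, a minimal JAX sketch of SDF sphere tracing, assuming a hypothetical scene containing only a unit sphere and an arbitrary hit threshold `eps`. Each step advances the ray by the SDF value, and a fixed step count via jax.lax.fori_loop keeps the loop compatible with jit and grad.

```python
import jax
import jax.numpy as jnp

def sdf_sphere(p, radius=1.0):
    # Signed distance from point p to a sphere of the given radius at the origin.
    return jnp.linalg.norm(p) - radius

def sphere_trace(origin, direction, n_steps=64, eps=1e-3):
    # Fixed step count (no data-dependent early exit) keeps the loop jit- and grad-friendly.
    direction = direction / jnp.linalg.norm(direction)

    def step(_, t):
        # The SDF value is the largest step that cannot overshoot the surface.
        return t + sdf_sphere(origin + t * direction)

    t = jax.lax.fori_loop(0, n_steps, step, jnp.zeros((), dtype=origin.dtype))
    hit = sdf_sphere(origin + t * direction) < eps
    return t, hit

trace = jax.jit(sphere_trace)
t, hit = trace(jnp.array([0.0, 0.0, -3.0]), jnp.array([0.0, 0.0, 1.0]))
t_miss, hit_miss = trace(jnp.array([0.0, 5.0, -3.0]), jnp.array([0.0, 0.0, 1.0]))
```

The first ray reaches the sphere at distance 2; the second passes 5 units off-axis and never gets within `eps`, so its `hit` flag stays false.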

  • VHRanger 1 day ago
    PyTorch is such a maddening mess of half-implemented research features in a state of Heisen-deprecation that JAX becomes more appealing to me by the day.
  • vatsachak 1 day ago
    Yeah, GPU compilers will be used for way more things than AI, because parallel = good.
    • jedbrooke 1 day ago
      We’ve come full circle and started using Graphics Processing Units to process graphics again
  • chillee 21 hours ago
    I'd also note that you can more or less write the same code in PyTorch with torch.vmap
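Since vmap is a function transform, it composes with grad, which is the part that matters for a renderer like this. A minimal JAX sketch, assuming a hypothetical unit-sphere scene: the gradient of a signed distance field gives the outward normal direction, and jax.vmap batches that per-point computation over many points without manual reshaping. PyTorch's torch.func module offers grad and vmap transforms that compose in the same way.

```python
import jax
import jax.numpy as jnp

def sdf(p):
    # Hypothetical scene: unit sphere at the origin.
    return jnp.linalg.norm(p) - 1.0

# Gradient of the SDF = outward normal direction at point p.
normal = jax.grad(sdf)

# vmap lifts the single-point function to a batch of points (leading axis).
normals = jax.vmap(normal)

points = jnp.array([[1.0, 0.0, 0.0], [0.0, 2.0, 0.0], [0.0, 0.0, -0.5]])
n = normals(points)
```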
  • dvt 1 day ago
    > the thing JAX was truly meant for: a graphics renderer

    I mean, just like ray tracing, SDF ray marching is neat, but basically everything useful is expensive or hard to do with it (collisions, meshes, texturing, etc.). Sure, the mathy stuff is easier (rotations, unions/intersections, function composition, etc.), but 3D is usually used in either modeling software or video games, which care more about the former than they do the latter.
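The "mathy stuff" listed here is where SDFs are convenient: CSG operations reduce to min/max, and transforms are function composition on the query point. A minimal JAX sketch; the helper names (`sphere`, `box`, `union`, `translate`, `rotate_z`, ...) are hypothetical, not a library API, and the box uses the standard exact SDF of an axis-aligned box. In general min/max yield a conservative bound rather than the exact distance, which sphere tracing tolerates.

```python
import jax.numpy as jnp

def sphere(radius):
    return lambda p: jnp.linalg.norm(p) - radius

def box(half_extents):
    half_extents = jnp.asarray(half_extents)
    def f(p):
        q = jnp.abs(p) - half_extents
        # Distance outside the box plus (negative) depth inside it.
        return jnp.linalg.norm(jnp.maximum(q, 0.0)) + jnp.minimum(jnp.max(q), 0.0)
    return f

def union(a, b):
    return lambda p: jnp.minimum(a(p), b(p))

def intersection(a, b):
    return lambda p: jnp.maximum(a(p), b(p))

def difference(a, b):
    # Inside a and outside b.
    return lambda p: jnp.maximum(a(p), -b(p))

def translate(f, offset):
    # Moving the shape by offset = sampling it at the point moved by -offset.
    offset = jnp.asarray(offset)
    return lambda p: f(p - offset)

def rotate_z(f, angle):
    # Rotating the shape = sampling it at the inversely rotated point.
    c, s = jnp.cos(angle), jnp.sin(angle)
    inverse = jnp.array([[c, s, 0.0], [-s, c, 0.0], [0.0, 0.0, 1.0]])
    return lambda p: f(inverse @ p)

# A hypothetical ball-and-cube scene built purely by composing functions.
cube = box([0.5, 0.5, 0.5])
scene = union(translate(sphere(1.0), [2.0, 0.0, 0.0]), rotate_z(cube, jnp.pi / 4))
```

Because every shape is just a function of a point, the whole scene stays differentiable and can be fed straight into a sphere tracer.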

    • BenoitP 23 hours ago
      I believe we could get there eventually. For collisions, for example, there is work on making them differentiable (or on using a local surrogate at the collision point): https://arxiv.org/abs/2207.00669

      Robotics will need to connect vision, motors, haptics, and 3D modelling, and to propagate gradients seamlessly between them, for example to calibrate torque against the elastic deformation of the material. After all, matter is not discrete at small scales (staying above the atomic scale).

      All this will require every module to be compatible with differentiability. It'll be expensive at first, but I'm sure some optimizations can get us close to the discrete case.

      Also, even for meshes there is a lot to gain from going the continuous way:

      https://www.cs.cmu.edu/~kmcrane/Projects/DDG/

    • Archit3ch 1 day ago
      Games and simulations are typically stateful, I'm not sure the functional purity of JAX is a good fit.

      Also, what's the story for JAX + WebGL when it comes to targeting hardware-accelerated ray tracing?

      • BenoitP 23 hours ago
        JAX was designed from the start to fit well with systolic arrays (TPUs, Nvidia's tensor cores, etc.), which are extremely energy-efficient. WebGL won't be the tool that brings it to the web, but the generation after WebGPU will.
      • nightski 1 day ago
        Maybe you mean mutation? State is inherently a part of functional purity. It's just handled explicitly instead of implicitly.
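To make the explicit-state point concrete, a minimal JAX sketch of a stateful simulation loop, assuming a hypothetical 1D falling ball integrated with semi-implicit Euler (time step and gravity values are illustrative). Each step takes the state and returns a new one, jax.lax.scan threads it through the loop, and because nothing is mutated in place, jax.grad can differentiate the final position with respect to the initial velocity.

```python
import jax
import jax.numpy as jnp

DT, GRAVITY = 0.1, -9.81  # hypothetical time step (s) and acceleration (m/s^2)

def step(state, _):
    # Pure update: the old state goes in, a new state comes out.
    pos, vel = state
    vel = vel + GRAVITY * DT  # semi-implicit Euler: update velocity first,
    pos = pos + vel * DT      # then advance position with the new velocity
    return (pos, vel), pos    # (next carry, per-step output)

def simulate(v0, n_steps=10):
    init = (jnp.zeros((), dtype=jnp.float32), jnp.asarray(v0, dtype=jnp.float32))
    (pos, vel), trajectory = jax.lax.scan(step, init, xs=None, length=n_steps)
    return pos, trajectory

final_pos, trajectory = simulate(10.0)
# Explicit state is what lets grad see through the whole loop.
dpos_dv0 = jax.grad(lambda v0: simulate(v0)[0])(10.0)
```

After 10 steps the final position works out to 10 * DT * v0 plus a fixed gravity term, so the derivative with respect to v0 is 1.0.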