Recently I had fun reimplementing an old (but still usable!) code for accelerator optics. It involved transfer matrices for a 6D phase space to second order. Most of the FORTRAN77 source code was just pages and pages of hand-differentiated 6x6x6 matrices (with quite non-trivial elements) and the plumbing to painstakingly propagate those Jacobians around for fitting... all replaced with a single, magic call to jax.grad(). Felt like cheating!
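For anyone curious what that looks like, here is a minimal sketch, not the original code. The `drift` element below is a hypothetical toy map I made up for illustration, and I use `jax.jacfwd` and `jax.hessian` (siblings of `jax.grad`) to get the first-order 6x6 matrix and a second-order 6x6x6 tensor, with `jax.grad` for the fitting gradient.

```python
import jax
import jax.numpy as jnp

def drift(x, length=1.0):
    # Hypothetical toy element (an assumption, not from the FORTRAN77 code):
    # a field-free drift with nonlinear dependence on the transverse angles.
    # State layout assumed here: (x, x', y, y', l, delta).
    px, py = x[1], x[3]
    pz = jnp.sqrt(1.0 - px**2 - py**2)
    return jnp.array([
        x[0] + length * px / pz,
        px,
        x[2] + length * py / pz,
        py,
        x[4] + length * (1.0 / pz - 1.0),
        x[5],
    ])

x0 = jnp.zeros(6)  # expand around the reference trajectory

# First-order transfer matrix: Jacobian of the map, shape (6, 6).
R = jax.jacfwd(drift)(x0)

# Second-order terms: half the Hessian of each output, shape (6, 6, 6).
T = 0.5 * jax.hessian(drift)(x0)

def loss(length, x, target):
    # Scalar misfit used for fitting the element length.
    return jnp.sum((drift(x, length) - target) ** 2)

# Gradient of the misfit with respect to the fitted parameter.
x = jnp.array([0.0, 0.01, 0.0, 0.0, 0.0, 0.0])
g = jax.grad(loss)(2.0, x, drift(x, 1.0))
```

The point is that `R`, `T` and `g` all come from the same forward function, so there is nothing to keep in sync by hand.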
I'm also super interested in its application to modelling, e.g. projects like https://github.com/deepmodeling/jax-fem -- particularly for chaining different sorts of simulations and analysis together and getting gradients through the lot. Also quite magic!
I had a lot of fun writing the article! And it is only half a joke.
My intuition for so-called world models is that we'll have to plug modules together, each responsible for a domain (text, video, sound, robot haptics, physical modelling), in a way that allows the gradient to propagate: a differentiable architecture. JAX seems well placed for this because it makes function manipulation a first-class citizen. Your account reassures me in this view.
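A rough sketch of what I mean by plugging modules so the gradient propagates. The two modules and every parameter name here (`perception`, `physics`, `w_vis`, `stiffness`, `damping`) are hypothetical stand-ins, not real models; the only point is that `jax.grad` differentiates the composed pipeline with respect to all parameters at once.

```python
import jax
import jax.numpy as jnp

def perception(params, image):
    # Hypothetical "vision" module: a linear read-out squashed by tanh.
    return jnp.tanh(params["w_vis"] @ image)

def physics(params, features):
    # Hypothetical "physical model" module: a nonlinear spring-like response.
    return params["stiffness"] * features - params["damping"] * features**3

def pipeline(params, image, target):
    # Modules are plain functions, so chaining them is function composition.
    prediction = physics(params, perception(params, image))
    return jnp.mean((prediction - target) ** 2)

params = {
    "w_vis": 0.1 * jnp.ones((2, 3)),
    "stiffness": jnp.array(1.0),
    "damping": jnp.array(0.1),
}
image = jnp.array([0.5, -0.2, 0.3])
target = jnp.array([0.2, 0.1])

# One call gives gradients for every module's parameters,
# returned in the same dictionary structure as `params`.
grads = jax.grad(pipeline)(params, image, target)
```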
> the thing JAX was truly meant for: a graphics renderer
I mean, just like ray-tracing, SDF (ray-marching) is neat, but basically everything useful is expensive or hard to do (collisions, meshes, texturing etc.). I mean mathy stuff is easier (rotations, unions/intersections, function composition, etc.) but 3D is usually used in either modeling software or video games, which care more about the former than they do the latter.
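To make the "mathy stuff is easier" part concrete, here is a small sketch of the easy operations in JAX: a rotation, a union, and sphere tracing, with the surface normal obtained from `jax.grad` of the distance function. The scene, the function names and the fixed step count are all assumptions for illustration, not code from the article.

```python
import jax
import jax.numpy as jnp

def sphere(p, radius=1.0):
    # Signed distance from point p to a sphere centred at the origin.
    return jnp.linalg.norm(p) - radius

def rotate_z(p, angle):
    # Rotate a point about the z axis.
    c, s = jnp.cos(angle), jnp.sin(angle)
    return jnp.array([c * p[0] - s * p[1], s * p[0] + c * p[1], p[2]])

def scene(p):
    # Hypothetical scene: union (min) of two spheres, one in a rotated frame.
    a = sphere(p - jnp.array([-0.5, 0.0, 0.0]), 1.0)
    b = sphere(rotate_z(p, 0.3) - jnp.array([0.5, 0.0, 0.0]), 0.7)
    return jnp.minimum(a, b)

def sphere_trace(origin, direction, steps=64):
    # Advance along the ray by the current distance bound, a fixed number of times.
    def step(_, t):
        return t + scene(origin + t * direction)
    return jax.lax.fori_loop(0, steps, step, jnp.zeros(()))

origin = jnp.array([-0.5, 0.0, -3.0])
direction = jnp.array([0.0, 0.0, 1.0])

t_hit = sphere_trace(origin, direction)
hit = origin + t_hit * direction

# The surface normal is the gradient of the SDF at the hit point.
normal = jax.grad(scene)(hit)
```

None of this addresses the hard parts listed above (collisions, meshes, texturing); it only shows why composition-style operations are cheap to express.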
I believe we could get there eventually. For collision, for example, there is work on making it differentiable (or on using a local surrogate at the collision point):
https://arxiv.org/abs/2207.00669
Robotics will need to connect vision with motors, with haptics, and with 3D modelling, and to propagate gradients seamlessly through all of them, for example to calibrate torque against the elastic deformation of the material. After all, matter is not discrete at small scales (as long as we stay above the atomic scale).
All this will require every module to be compatible with differentiability. It'll be expensive at first, but I'm sure some optimizations can get us close to the discrete case.
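As a toy illustration of trading a discrete rule for a differentiable one, here is a sketch of a contact penalty. This is my own simplified example, not the method from the linked paper; `hard_contact`, `soft_contact`, `stiffness` and `sharpness` are hypothetical names. The hard version has zero gradient whenever the bodies are apart, while the softplus version gives a usable gradient and approaches the hard one as `sharpness` grows.

```python
import jax
import jax.numpy as jnp

def hard_contact(gap, stiffness=100.0):
    # Discrete-style rule: force only when the gap is negative (penetration).
    return stiffness * jnp.maximum(0.0, -gap)

def soft_contact(gap, stiffness=100.0, sharpness=50.0):
    # Smooth surrogate: softplus replaces the hard max(0, .) kink.
    return stiffness * jax.nn.softplus(-sharpness * gap) / sharpness

gap = 0.05  # bodies slightly apart

# Hard rule: no gradient signal until contact actually happens.
g_hard = jax.grad(hard_contact)(gap)

# Soft rule: a small negative slope, so an optimizer can "feel" the contact coming.
g_soft = jax.grad(soft_contact)(gap)

# With a large sharpness the surrogate matches the hard rule closely.
close_to_hard = soft_contact(-0.1, sharpness=1000.0)
reference = hard_contact(-0.1)
```

The extra cost is the smooth function evaluation everywhere instead of a branch, which is the kind of overhead I'd expect optimizations to shrink.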
JAX is designed from the start to fit well with systolic arrays (TPUs, Nvidia's tensor cores, etc), which are extremely energy-efficient. WebGL won't be the tool that connects it on the web, but the generation after WebGPU will.
You did not miss much though: it just rotates the scene.
I like the concept of applying JAX to SDF sphere tracing :)
Also even for meshes there is a lot to gain with trying to go the continuous way:
https://www.cs.cmu.edu/~kmcrane/Projects/DDG/
Also, what's the story for JAX + WebGL when it comes to targeting hardware-accelerated ray tracing?