En/review eg/dynamics fix#38
Conversation
…f initializing particles at a given distance from an axis, correcting energy/pitch/mu/perpnedicular_velocity methods for different types of tracing and adding first iteration of differentiable loss functions for loss fraction, heat flux and loss positions
rogeriojorge
left a comment
There was a problem hiding this comment.
Mostly looks good! A few things I spot:
-
dynamics.py:1371-1375,
loss_fraction_collisions_differentiable(): this checkshasattr(self, "energy")and then usesself.energyas if it were an array, butenergyis a method. This should likely beenergies = self.energy()once, then useenergies * per_particle_loss[:, None]. -
dynamics.py:1414-1427,
escape_location_rmax(): this usestrajectories = self.trajectories, so the weighted escape locations can include the full state dimension, not just position. Since the docstring and penalty functions expect(n_timesteps, 3), this should usetrajectories_xyz = self.trajectories[:, :, :3]for the positions. -
dynamics.py:1512,
1585,1659, and1741, classifier penalty methods: the JIT static args look wrong/incomplete. For example,escape_location_penalty_classifier()markstarget_positionstatic but notboundary. The target/line/plane/band arrays should stay dynamic, whileboundaryshould be static, e.g.@partial(jit, static_argnums=(0,), static_argnames=["boundary"]). Otherwise this can fail with non-array boundary objects or recompile unnecessarily for target arrays. -
dynamics.py:1335-1336
and1391-1392: the detailed differentiable loss functions computecumsum(...)and then normalize by the final value, which makes the final loss approximately1` whenever there is any loss. Is this intended? I think these should return the actual cumulative loss fraction/probability (divide by total number of particles?)
Minor things
-
remove the commented-out sharding block at
dynamics.py:21-27, fix the indentation at31-35, remove unusedxm_axisandrandom_phis_offsetinInitializeParticlesAroundSurfaceAxis()(209,236), and avoid reusing the original PRNG key forinitial_vparallel_over_vat288after splitting it. -
Another minor thing: dynamics.py:1128-1129,
v_perp()full-orbit branch:jnp.sqrt(vperp_squared)can produce NaNs from tiny negative roundoff. Please usejnp.sqrt(jnp.maximum(vperp_squared, 0.0)). -
Final minor thing. dynamics.py:1268-1269, loss_fraction_collisions(): this currently indexes with
x-1, so the first particle uses index-1, andlost_indices[x-1]-1can also wrap to the last timestep. I’d replace this with the other wrapper thing already used inloss_fraction_BioSavart_collisions()
… positions ans lodd energy. This will only be used in future examples which are not included in the prsent refactoring so they will be added later
…ed sharding section
…moved unused variables and added more clear key splitting
…ke loss_fraction_BioSavart_collisions which solved the existing bug with the first particle index
|
Addressed all issues here. I removed the differentiable loss functions as they will not be used in the present examples, and they are better added at a future pull request for alpha collisional transport optimzation. |
EstevaoMGomes
left a comment
There was a problem hiding this comment.
Overall, it looks good. I think the PR is ready to be merged.
Adding sharding safe check for number of devices, adding capability of initializing particles at a given distance from an axis, correcting energy/pitch/mu/perpnedicular_velocity methods for different types of tracing and adding first iteration of differentiable loss functions for loss fraction, heat flux and loss positions. Changes the dynamics.py.