# Copyright 2025 DeepMind Technologies Limited## Licensed under the Apache License, Version 2.0 (the "License");# you may not use this file except in compliance with the License.# You may obtain a copy of the License at## http://www.apache.org/licenses/LICENSE-2.0## Unless required by applicable law or agreed to in writing, software# distributed under the License is distributed on an "AS IS" BASIS,# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.# See the License for the specific language governing permissions and# limitations under the License.# =============================================================================="""Derivative functions."""fromtypingimportOptionalimportjaxfromjaximportnumpyasjp# pylint: disable=g-importing-memberfrommujoco.mjx._src.typesimportBiasTypefrommujoco.mjx._src.typesimportDatafrommujoco.mjx._src.typesimportDataJAXfrommujoco.mjx._src.typesimportDisableBitfrommujoco.mjx._src.typesimportDynTypefrommujoco.mjx._src.typesimportGainTypefrommujoco.mjx._src.typesimportModelfrommujoco.mjx._src.typesimportModelJAXfrommujoco.mjx._src.typesimportOptionJAX# pylint: enable=g-importing-member
def deriv_smooth_vel(m: Model, d: Data) -> Optional[jax.Array]:
  """Analytical derivative of smooth forces w.r.t. velocities.

  Accumulates into one matrix the velocity Jacobian of:
    * affine actuator forces (skipped when ACTUATION is disabled), and
    * joint and tendon damping forces (skipped when DAMPER is disabled).
  The RNE (bias force) derivative is not included yet.

  Args:
    m: model; ``m._impl`` and ``m.opt._impl`` must be the JAX variants.
    d: data; ``d._impl`` must be the JAX variant.

  Returns:
    The accumulated Jacobian. Returns ``None`` when both the actuation and
    the damper contributions are disabled through ``m.opt.disableflags``.

  Raises:
    ValueError: if ``m``, ``d`` or ``m.opt`` is not backed by the JAX MJX
      implementation.
    NotImplementedError: if neither DAMPER nor SPRING is disabled and
      ``m.opt._impl.has_fluid_params`` is truthy, because fluid drag
      derivatives are not supported here.
  """
  # Short-circuiting `and` means each later `_impl` attribute is only read
  # after the previous check has passed.
  if not (
      isinstance(m._impl, ModelJAX)
      and isinstance(d._impl, DataJAX)
      and isinstance(m.opt._impl, OptionJAX)
  ):
    raise ValueError('deriv_smooth_vel requires JAX MJX implementation.')

  flags = m.opt.disableflags
  jacobian = None

  # Actuator term: d qfrc_actuator / d qvel = M^T diag(force_vel) M, where M
  # is the actuator moment matrix. Its rows are indexed by actuator, matching
  # `force_vel`; its shape is assumed to be (nu, nv) — TODO confirm.
  if not flags & DisableBit.ACTUATION:
    # Only AFFINE bias/gain types contribute. Their velocity coefficient is
    # column 2 of the parameter arrays, and the boolean mask zeroes it for
    # every other type.
    is_affine_bias = m.actuator_biastype == BiasType.AFFINE
    bias_coef = m.actuator_biasprm[:, 2] * is_affine_bias
    is_affine_gain = m.actuator_gaintype == GainType.AFFINE
    gain_coef = m.actuator_gainprm[:, 2] * is_affine_gain
    # Actuators with dynamics scale their gain by activation instead of ctrl.
    # NOTE(review): assumes `d.act` has exactly one entry per actuator whose
    # dyntype != NONE, in actuator order — confirm.
    has_dynamics = m.actuator_dyntype != DynType.NONE
    ctrl = d.ctrl.at[has_dynamics].set(d.act)
    force_vel = bias_coef + gain_coef * ctrl
    moment = d._impl.actuator_moment
    # Scale each actuator's moment row by its force/velocity slope.
    jacobian = moment.T @ (moment * force_vel[:, None])

  # Passive term: d qfrc_passive / d qvel from joint and tendon damping.
  if not flags & DisableBit.DAMPER:
    dof_damping = jp.diag(m.dof_damping)
    jacobian = -dof_damping if jacobian is None else jacobian - dof_damping
    if m.ntendon:
      ten_j = d._impl.ten_J
      jacobian -= ten_j.T @ jp.diag(m.tendon_damping) @ ten_j

  # Fluid drag derivatives are not supported, so raise rather than silently
  # return a Jacobian without that term.
  # NOTE(review): this check only runs when neither DAMPER nor SPRING is
  # disabled — confirm that gating is intended.
  # TODO(robotics-simulation): fluid drag model
  if (
      not flags & (DisableBit.DAMPER | DisableBit.SPRING)
      and m.opt._impl.has_fluid_params  # pytype: disable=attribute-error
  ):
    raise NotImplementedError('fluid drag not supported for implicitfast')

  # TODO(team): rne derivative
  return jacobian