# Source code for mujoco_warp._src.derivative

# Copyright 2025 The Newton Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

import warp as wp

from mujoco_warp._src.types import BiasType
from mujoco_warp._src.types import Data
from mujoco_warp._src.types import DisableBit
from mujoco_warp._src.types import DynType
from mujoco_warp._src.types import GainType
from mujoco_warp._src.types import Model
from mujoco_warp._src.types import TileSet
from mujoco_warp._src.types import vec10f
from mujoco_warp._src.warp_util import cache_kernel
from mujoco_warp._src.warp_util import event_scope

wp.set_module_options({"enable_backward": False})


@wp.kernel
def _qderiv_actuator_passive_vel(
  # Model:
  actuator_dyntype: wp.array(dtype=int),
  actuator_gaintype: wp.array(dtype=int),
  actuator_biastype: wp.array(dtype=int),
  actuator_actadr: wp.array(dtype=int),
  actuator_actnum: wp.array(dtype=int),
  actuator_forcelimited: wp.array(dtype=bool),
  actuator_gainprm: wp.array2d(dtype=vec10f),
  actuator_biasprm: wp.array2d(dtype=vec10f),
  actuator_forcerange: wp.array2d(dtype=wp.vec2),
  # Data in:
  act_in: wp.array2d(dtype=float),
  ctrl_in: wp.array2d(dtype=float),
  actuator_force_in: wp.array2d(dtype=float),
  # Out:
  vel_out: wp.array2d(dtype=float),
):
  """Writes one velocity coefficient per (world, actuator) into `vel_out`.

  The coefficient is element 2 of the affine bias parameters plus element 2 of
  the affine gain parameters multiplied by the actuator input: `ctrl_in` when
  the actuator has no dynamics, otherwise the last activation slot belonging to
  the actuator. Non-affine gain/bias types contribute zero. The coefficient is
  forced to zero when the actuator is force-limited and its current force sits
  on or outside the force range.
  """
  worldid, actid = wp.tid()

  # only affine gain/bias carry a velocity term (parameter index 2)
  gain = float(0.0)
  if actuator_gaintype[actid] == GainType.AFFINE:
    gain = actuator_gainprm[worldid % actuator_gainprm.shape[0], actid][2]

  bias = float(0.0)
  if actuator_biastype[actid] == BiasType.AFFINE:
    bias = actuator_biasprm[worldid % actuator_biasprm.shape[0], actid][2]

  # nothing depends on velocity for this actuator
  if gain == 0.0 and bias == 0.0:
    vel_out[worldid, actid] = 0.0
    return

  # a force clamped by forcerange does not vary with velocity
  if actuator_forcelimited[actid]:
    limits = actuator_forcerange[worldid % actuator_forcerange.shape[0], actid]
    applied = actuator_force_in[worldid, actid]
    if applied <= limits[0] or applied >= limits[1]:
      vel_out[worldid, actid] = 0.0
      return

  coeff = float(bias)
  if gain != 0.0:
    if actuator_dyntype[actid] == DynType.NONE:
      # stateless actuator: gain multiplies the control signal
      coeff += gain * ctrl_in[worldid, actid]
    else:
      # stateful actuator: gain multiplies the last activation variable
      last_act = actuator_actadr[actid] + actuator_actnum[actid] - 1
      coeff += gain * act_in[worldid, last_act]

  vel_out[worldid, actid] = coeff


@wp.func
def _nonzero_mask(x: float) -> float:
  """Maps an exact zero to 0.0 and every other value to 1.0."""
  if x == 0.0:
    return 0.0
  return 1.0


@cache_kernel
def _qderiv_actuator_passive_actuation_dense(tile: TileSet, nu: int):
  """Builds the tiled kernel that fills dense diagonal blocks of the actuator term.

  For each (world, block) pair the generated kernel computes
  `moment^T @ diag(vel) @ moment` restricted to a `tile.size` x `tile.size`
  block of DOFs, zeroes the entries whose matching `qM_in` entry is exactly
  zero, and stores the block into `qDeriv_out`.

  Args:
    tile: Tile set whose `size` is baked into the kernel as the block dimension.
    nu: Row count of the velocity and moment tiles, baked into the kernel
      (the caller in this file passes `m.nu`).

  Returns:
    The Warp kernel specialized for `tile.size` and `nu`.
  """
  # NOTE(review): `cache_kernel` presumably memoizes the generated kernel per
  # argument set - confirm in warp_util.

  @wp.kernel(module="unique", enable_backward=False)
  def kernel(
    # Data in:
    actuator_moment_in: wp.array3d(dtype=float),
    qM_in: wp.array3d(dtype=float),
    # In:
    vel_in: wp.array3d(dtype=float),
    adr: wp.array(dtype=int),
    # Out:
    qDeriv_out: wp.array3d(dtype=float),
  ):
    worldid, nodeid = wp.tid()
    # compile-time constants captured from the enclosing factory arguments
    TILE_SIZE = wp.static(tile.size)
    NU = wp.static(nu)

    # offset of this block along both DOF axes
    dofid = adr[nodeid]
    # (NU, 1) column of per-actuator coefficients for this world
    vel_tile = wp.tile_load(vel_in[worldid], shape=(NU, 1), bounds_check=False)
    # (NU, TILE_SIZE) slice of the actuator moment matrix, columns starting at dofid
    moment_tile = wp.tile_load(actuator_moment_in[worldid], shape=(NU, TILE_SIZE), offset=(0, dofid), bounds_check=False)
    # scale every actuator row of the moment slice by that actuator's coefficient
    moment_weighted = wp.tile_map(wp.mul, wp.tile_broadcast(vel_tile, shape=(NU, TILE_SIZE)), moment_tile)
    # (TILE_SIZE, TILE_SIZE) block: moment^T @ (vel * moment)
    qderiv_tile = wp.tile_matmul(wp.tile_transpose(moment_tile), moment_weighted)

    # Mask out cross-terms for DOF pairs that are structurally zero in M
    # (e.g., sibling DOFs coupled only through tendons).  Without this,
    # stale actuation values at sibling positions make A = M - dt*qDeriv
    # non-positive-definite, causing the tiled Cholesky to produce NaN.
    # Dropping these terms matches MuJoCo CPU's implicitfast approximation.
    qM_tile = wp.tile_load(qM_in[worldid], shape=(TILE_SIZE, TILE_SIZE), offset=(dofid, dofid), bounds_check=False)
    # 1.0 where the mass-matrix entry is non-zero, 0.0 where it is exactly zero
    mask_tile = wp.tile_map(_nonzero_mask, qM_tile)
    qderiv_tile = wp.tile_map(wp.mul, qderiv_tile, mask_tile)

    # write the masked block back at the same (dofid, dofid) offset
    wp.tile_store(qDeriv_out[worldid], qderiv_tile, offset=(dofid, dofid), bounds_check=False)

  return kernel


@wp.kernel
def _qderiv_actuator_passive_actuation_sparse(
  # Model:
  nu: int,
  # Data in:
  actuator_moment_in: wp.array3d(dtype=float),
  # In:
  vel_in: wp.array2d(dtype=float),
  qMi: wp.array(dtype=int),
  qMj: wp.array(dtype=int),
  # Out:
  qDeriv_out: wp.array3d(dtype=float),
):
  """Writes the actuator term for one sparse element per (world, element).

  The element addressed by `(qMi[elemid], qMj[elemid])` receives the sum over
  actuators of `moment[act, i] * moment[act, j] * vel[act]`, stored at
  `qDeriv_out[worldid, 0, elemid]`. Actuators whose coefficient is exactly
  zero are skipped.
  """
  worldid, elemid = wp.tid()

  row = qMi[elemid]
  col = qMj[elemid]

  total = float(0.0)
  for actid in range(nu):
    weight = vel_in[worldid, actid]
    # a zero coefficient contributes nothing; avoid the moment reads
    if weight != 0.0:
      total += actuator_moment_in[worldid, actid, row] * actuator_moment_in[worldid, actid, col] * weight

  qDeriv_out[worldid, 0, elemid] = total


@wp.kernel
def _qderiv_actuator_passive(
  # Model:
  opt_timestep: wp.array(dtype=float),
  opt_disableflags: int,
  dof_damping: wp.array2d(dtype=float),
  is_sparse: bool,
  # Data in:
  qM_in: wp.array3d(dtype=float),
  # In:
  qMi: wp.array(dtype=int),
  qMj: wp.array(dtype=int),
  qDeriv_in: wp.array3d(dtype=float),
  # Out:
  qDeriv_out: wp.array3d(dtype=float),
):
  """Combines the accumulated derivative with DOF damping and the mass matrix.

  For the element `(qMi[elemid], qMj[elemid])` this reads the derivative from
  `qDeriv_in`, subtracts `dof_damping` on the diagonal unless the damper
  disable bit is set, scales by the timestep and writes `qM - dt * qDeriv`.
  Sparse storage is addressed as `[world, 0, elem]`; dense storage is addressed
  as `[world, i, j]` and the result is mirrored to `[world, j, i]`.
  """
  worldid, elemid = wp.tid()

  row = qMi[elemid]
  col = qMj[elemid]

  # pick up the value accumulated so far from the layout in use
  deriv = float(0.0)
  if is_sparse:
    deriv = qDeriv_in[worldid, 0, elemid]
  else:
    deriv = qDeriv_in[worldid, row, col]

  # DOF damping only touches diagonal entries
  if row == col:
    if not opt_disableflags & DisableBit.DAMPER:
      deriv -= dof_damping[worldid % dof_damping.shape[0], row]

  deriv *= opt_timestep[worldid % opt_timestep.shape[0]]

  if is_sparse:
    qDeriv_out[worldid, 0, elemid] = qM_in[worldid, 0, elemid] - deriv
    return

  # dense: write the element and its symmetric counterpart
  value = qM_in[worldid, row, col] - deriv
  qDeriv_out[worldid, row, col] = value
  if row != col:
    qDeriv_out[worldid, col, row] = value


# TODO(team): improve performance with tile operations?
@wp.kernel
def _qderiv_tendon_damping(
  # Model:
  ntendon: int,
  opt_timestep: wp.array(dtype=float),
  tendon_damping: wp.array2d(dtype=float),
  is_sparse: bool,
  # Data in:
  ten_J_in: wp.array3d(dtype=float),
  # In:
  qMi: wp.array(dtype=int),
  qMj: wp.array(dtype=int),
  # Out:
  qDeriv_out: wp.array3d(dtype=float),
):
  """Applies the tendon-damping term to one element of `qDeriv_out` in place.

  Accumulates `-sum_t J[t, i] * J[t, j] * damping[t]` over all tendons for the
  element `(qMi[elemid], qMj[elemid])`, scales it by the timestep and subtracts
  it from the stored value. Dense storage is updated symmetrically.
  """
  worldid, elemid = wp.tid()

  row = qMi[elemid]
  col = qMj[elemid]
  damping_world = worldid % tendon_damping.shape[0]

  accum = float(0.0)
  for tenid in range(ntendon):
    jac_row = ten_J_in[worldid, tenid, row]
    jac_col = ten_J_in[worldid, tenid, col]
    accum -= jac_row * jac_col * tendon_damping[damping_world, tenid]

  accum *= opt_timestep[worldid % opt_timestep.shape[0]]

  if is_sparse:
    qDeriv_out[worldid, 0, elemid] -= accum
    return

  # dense: update the element and its symmetric counterpart
  qDeriv_out[worldid, row, col] -= accum
  if row != col:
    qDeriv_out[worldid, col, row] -= accum


@event_scope
def deriv_smooth_vel(m: Model, d: Data, out: wp.array3d(dtype=float)):
  """Analytical derivative of smooth forces w.r.t. velocities.

  Fills `out` with `qM - dt * qDeriv`, where `qDeriv` collects the actuator,
  DOF-damping and tendon-damping velocity derivatives that are not disabled.

  Args:
    m: The model containing kinematic and dynamic information (device).
    d: The data object containing the current state and output arrays (device).
    out: qM - dt * qDeriv (derivatives of smooth forces w.r.t velocities).
      Written in place. The kernels launched here index it as
      `[world, 0, element]` when `m.is_sparse` is set and as
      `[world, dof_i, dof_j]` otherwise.
  """
  qMi = m.qM_fullm_i
  qMj = m.qM_fullm_j

  # TODO(team): implicit requires different sparsity structure

  # Take the kernel path unless BOTH actuation and damping are disabled.
  # The previous test, `~(flags & mask)`, was a bitwise NOT and therefore truthy
  # for every flag combination, which left the plain-copy branch unreachable.
  passive_mask = DisableBit.ACTUATION | DisableBit.DAMPER
  if (m.opt.disableflags & passive_mask) != passive_mask:
    # TODO(team): only clear elements not set by _qderiv_actuator_passive
    out.zero_()

    if m.nu > 0 and not m.opt.disableflags & DisableBit.ACTUATION:
      # per-world, per-actuator velocity coefficients
      vel = wp.empty((d.nworld, m.nu), dtype=float)
      wp.launch(
        _qderiv_actuator_passive_vel,
        dim=(d.nworld, m.nu),
        inputs=[
          m.actuator_dyntype,
          m.actuator_gaintype,
          m.actuator_biastype,
          m.actuator_actadr,
          m.actuator_actnum,
          m.actuator_forcelimited,
          m.actuator_gainprm,
          m.actuator_biasprm,
          m.actuator_forcerange,
          d.act,
          d.ctrl,
          d.actuator_force,
        ],
        outputs=[vel],
      )

      if m.is_sparse:
        wp.launch(
          _qderiv_actuator_passive_actuation_sparse,
          dim=(d.nworld, qMi.size),
          inputs=[m.nu, d.actuator_moment, vel, qMi, qMj],
          outputs=[out],
        )
      else:
        # the tiled kernel loads the coefficients as an (nu, 1) column
        vel_3d = vel.reshape(vel.shape + (1,))
        for tile in m.qM_tiles:
          wp.launch_tiled(
            _qderiv_actuator_passive_actuation_dense(tile, m.nu),
            dim=(d.nworld, tile.adr.size),
            inputs=[d.actuator_moment, d.qM, vel_3d, tile.adr],
            outputs=[out],
            block_dim=m.block_dim.qderiv_actuator_dense,
          )

    # out is both read (accumulated qDeriv) and overwritten (qM - dt * qDeriv)
    wp.launch(
      _qderiv_actuator_passive,
      dim=(d.nworld, qMi.size),
      inputs=[
        m.opt.timestep,
        m.opt.disableflags,
        m.dof_damping,
        m.is_sparse,
        d.qM,
        qMi,
        qMj,
        out,
      ],
      outputs=[out],
    )
  else:
    # TODO(team): directly utilize qM for these settings
    # with actuation and damping both disabled, qDeriv is zero and out == qM
    wp.copy(out, d.qM)

  if not m.opt.disableflags & DisableBit.DAMPER:
    wp.launch(
      _qderiv_tendon_damping,
      dim=(d.nworld, qMi.size),
      inputs=[m.ntendon, m.opt.timestep, m.tendon_damping, m.is_sparse, d.ten_J, qMi, qMj],
      outputs=[out],
    )
# TODO(team): rne derivative