# Source code for mujoco.mjx._src.passive

# Copyright 2023 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Passive forces."""

from typing import Tuple

import jax
from jax import numpy as jp
from mujoco.mjx._src import math
from mujoco.mjx._src import scan
from mujoco.mjx._src import support
# pylint: disable=g-importing-member
from mujoco.mjx._src.types import Data
from mujoco.mjx._src.types import DataJAX
from mujoco.mjx._src.types import DisableBit
from mujoco.mjx._src.types import JointType
from mujoco.mjx._src.types import Model
from mujoco.mjx._src.types import ModelJAX
from mujoco.mjx._src.types import OptionJAX
# pylint: enable=g-importing-member


def _spring_damper(m: Model, d: Data) -> jax.Array:
  """Applies joint level spring and damping forces.

  Computes dof-level springs and dampers plus tendon-level springs and
  dampers (mapped to joint space through the tendon Jacobian).

  Args:
    m: model; must use the JAX backend implementation (``ModelJAX``).
    d: data; must use the JAX backend implementation (``DataJAX``).

  Returns:
    Generalized force vector of length ``m.nv``. Spring terms are skipped
    when ``DisableBit.SPRING`` is set; damper terms are skipped when
    ``DisableBit.DAMPER`` is set.

  Raises:
    ValueError: if either ``m`` or ``d`` is not a JAX backend implementation.
  """
  # Either non-JAX implementation is an error. (With `and`, a single
  # mismatch would slip past this check and fail on the assert below.)
  if not isinstance(d._impl, DataJAX) or not isinstance(m._impl, ModelJAX):
    raise ValueError('_spring_damper requires JAX backend implementation.')
  # Kept so static type checkers can narrow `_impl` below.
  assert isinstance(d._impl, DataJAX) and isinstance(m._impl, ModelJAX)

  def fn(jnt_typs, stiffness, qpos_spring, qpos):
    # Spring force per joint. Free/ball rotational parts use math.quat_sub,
    # which maps a 4-wide quaternion pair to a 3-wide dof vector.
    qpos_i = 0
    qfrcs = []
    for i, typ in enumerate(jnt_typs):
      jnt_typ = JointType(typ)
      width = jnt_typ.qpos_width()
      q = qpos[qpos_i : qpos_i + width]
      qs = qpos_spring[qpos_i : qpos_i + width]
      qfrc = jp.zeros(jnt_typ.dof_width())
      if jnt_typ == JointType.FREE:
        qfrc = qfrc.at[:3].set(-stiffness[i] * (q[:3] - qs[:3]))
        qfrc = qfrc.at[3:6].set(-stiffness[i] * math.quat_sub(q[3:7], qs[3:7]))
      elif jnt_typ == JointType.BALL:
        qfrc = -stiffness[i] * math.quat_sub(q, qs)
      elif jnt_typ in (
          JointType.SLIDE,
          JointType.HINGE,
      ):
        qfrc = -stiffness[i] * (q - qs)
      else:
        raise RuntimeError(f'unrecognized joint type: {jnt_typ}')
      qfrcs.append(qfrc)
      qpos_i += width
    return jp.concatenate(qfrcs)

  # dof-level springs
  qfrc = jp.zeros(m.nv)
  if not m.opt.disableflags & DisableBit.SPRING:
    qfrc = scan.flat(
        m,
        fn,
        'jjqq',
        'v',
        m.jnt_type,
        m.jnt_stiffness,
        m.qpos_spring,
        d.qpos,
    )

  # dof-level dampers
  if not m.opt.disableflags & DisableBit.DAMPER:
    qfrc -= m.dof_damping * d.qvel

  # tendon-level springs: the two columns of tendon_lengthspring bound a
  # deadband; force is applied only when the length is below the first
  # bound or above the second one.
  if not m.opt.disableflags & DisableBit.SPRING:
    below, above = m.tendon_lengthspring.T - d.ten_length
    frc_spring = jp.where(below > 0, m.tendon_stiffness * below, 0)
    frc_spring = jp.where(above < 0, m.tendon_stiffness * above, frc_spring)
  else:
    frc_spring = jp.zeros(m.ntendon)

  # tendon-level dampers
  frc_damper = (
      -m.tendon_damping * d._impl.ten_velocity
      if not m.opt.disableflags & DisableBit.DAMPER
      else jp.zeros(m.ntendon)
  )

  # map tendon-space forces into joint space
  qfrc += d._impl.ten_J.T @ (frc_spring + frc_damper)

  return qfrc


def _gravcomp(m: Model, d: Data) -> jax.Array:
  """Applies body-level gravity compensation.

  Every body gets a force opposite to gravity, scaled by
  ``body_mass * body_gravcomp`` and applied at its ``xipos``. The force is
  mapped to joint space with the first Jacobian returned by ``support.jac``
  (presumably the positional one), then summed over bodies.
  """
  compensated_mass = m.body_mass * m.body_gravcomp
  body_forces = -m.opt.gravity * compensated_mass[:, None]

  def _to_joint_space(frc, point, body_id):
    jacp = support.jac(m, d, point, body_id)[0]
    return jacp @ frc

  body_ids = jp.arange(m.nbody)
  per_body = jax.vmap(_to_joint_space)(body_forces, d.xipos, body_ids)
  return per_body.sum(axis=0)


def _fluid(m: Model, d: Data) -> jax.Array:
  """Applies body-level viscosity, lift and drag.

  Evaluates the inertia-box fluid model for every body (vectorized with
  ``jax.vmap``). Each resulting force/torque pair, applied at the body's
  ``xipos``, is converted to generalized forces via ``support.apply_ft``, and
  the results are summed over bodies.
  """
  body_ids = jp.arange(m.nbody)
  root_com = d.subtree_com[jp.array(m.body_rootid)]

  box_model = jax.vmap(
      _inertia_box_fluid_model, in_axes=(None, 0, 0, 0, 0, 0, 0)
  )
  force, torque = box_model(
      m, m.body_inertia, m.body_mass, root_com, d.xipos, d.ximat, d.cvel
  )

  to_qfrc = jax.vmap(support.apply_ft, in_axes=(None, None, 0, 0, 0, 0))
  per_body = to_qfrc(m, d, force, torque, d.xipos, body_ids)
  return per_body.sum(axis=0)


def passive(m: Model, d: Data) -> Data:
  """Adds all passive forces.

  Sums spring/damper forces, gravity compensation (unless it is routed
  through actuators) and fluid forces into ``qfrc_passive``. Gravity
  compensation is also stored on its own in ``qfrc_gravcomp``.

  Args:
    m: model; ``m``, ``m.opt`` must use the JAX backend implementation.
    d: data; must use the JAX backend implementation.

  Returns:
    ``d`` with ``qfrc_passive`` and ``qfrc_gravcomp`` replaced.

  Raises:
    ValueError: if ``m``, ``d`` or ``m.opt`` is not a JAX backend
      implementation.
  """
  if (
      not isinstance(m._impl, ModelJAX)
      or not isinstance(d._impl, DataJAX)
      or not isinstance(m.opt._impl, OptionJAX)
  ):
    raise ValueError('passive requires JAX backend implementation.')

  # Return early only when *both* springs and dampers are disabled.
  # Individually disabled flags are honored inside _spring_damper. Returning
  # on either flag would also drop the other force type, gravcomp and fluid.
  # NOTE(review): intended to mirror MuJoCo's mj_passive early return.
  spring_off = m.opt.disableflags & DisableBit.SPRING
  damper_off = m.opt.disableflags & DisableBit.DAMPER
  if spring_off and damper_off:
    return d.replace(qfrc_passive=jp.zeros(m.nv), qfrc_gravcomp=jp.zeros(m.nv))

  qfrc_passive = _spring_damper(m, d)
  qfrc_gravcomp = jp.zeros(m.nv)

  if m.ngravcomp and not m.opt.disableflags & DisableBit.GRAVITY:
    qfrc_gravcomp = _gravcomp(m, d)
    # add gravcomp unless added via actuators
    qfrc_passive += qfrc_gravcomp * (1 - m.jnt_actgravcomp[m.dof_jntid])

  if m.opt._impl.has_fluid_params:  # pytype: disable=attribute-error
    qfrc_passive += _fluid(m, d)

  return d.replace(qfrc_passive=qfrc_passive, qfrc_gravcomp=qfrc_gravcomp)
def _inertia_box_fluid_model(
    m: Model,
    inertia: jax.Array,
    mass: jax.Array,
    root_com: jax.Array,
    xipos: jax.Array,
    ximat: jax.Array,
    cvel: jax.Array,
) -> Tuple[jax.Array, jax.Array]:
  """Fluid forces based on inertia-box approximation.

  Replaces the body by a box with the same mass and principal inertia, then
  applies viscous (Stokes-like) terms and quadratic lift/drag terms.

  Args:
    m: model; ``m.opt`` must use the JAX backend implementation.
    inertia: principal inertia of a single body (3 entries).
    mass: mass of the body.
    root_com: subtree center of mass of the body's root body.
    xipos: position of the body's inertial frame.
    ximat: orientation of the body's inertial frame as a 3x3 matrix.
    cvel: body velocity; after ``math.transform_motion`` the first 3 entries
      are treated as angular and the last 3 as linear.

  Returns:
    ``(force, torque)`` rotated to global orientation.

  Raises:
    ValueError: if ``m.opt`` is not a JAX backend implementation.
  """
  if not isinstance(m.opt._impl, OptionJAX):
    raise ValueError(
        '_inertia_box_fluid_model requires JAX backend implementation.'
    )

  # Equivalent box sides: side_i = sqrt(6 * (I_j + I_k - I_i) / mass).
  # Row i of the sign matrix negates I_i and keeps the other two inertias.
  box = jp.repeat(inertia[None, :], 3, axis=0)
  box *= jp.ones((3, 3)) - 2 * jp.eye(3)
  # jp.maximum is equivalent to jp.clip(..., a_min=...), whose `a_min`
  # keyword is deprecated in JAX.
  box = 6.0 * jp.maximum(jp.sum(box, axis=-1), 1e-12)
  # massless bodies get a zero-size box and therefore zero fluid force
  box = jp.sqrt(box / jp.maximum(mass, 1e-12)) * (mass > 0.0)

  # transform to local coordinate frame
  offset = xipos - root_com
  lvel = math.transform_motion(cvel, offset, ximat)
  lwind = ximat.T @ m.opt.wind
  lvel = lvel.at[3:].add(-lwind)

  # set viscous force and torque
  diam = jp.mean(box, axis=-1)
  lfrc_ang = lvel[:3] * -jp.pi * diam**3 * m.opt.viscosity
  lfrc_vel = lvel[3:] * -3.0 * jp.pi * diam * m.opt.viscosity

  # add lift and drag force and torque
  scale_vel = jp.array([box[1] * box[2], box[0] * box[2], box[0] * box[1]])
  scale_ang = jp.array([
      box[0] * (box[1] ** 4 + box[2] ** 4),
      box[1] * (box[0] ** 4 + box[2] ** 4),
      box[2] * (box[0] ** 4 + box[1] ** 4),
  ])
  lfrc_vel -= 0.5 * m.opt.density * scale_vel * jp.abs(lvel[3:]) * lvel[3:]
  lfrc_ang -= m.opt.density * scale_ang * jp.abs(lvel[:3]) * lvel[:3] / 64.0

  # rotate to global orientation: lfrc -> bfrc
  force, torque = ximat @ lfrc_vel, ximat @ lfrc_ang
  return force, torque