Source code for mujoco.mjx._src.forward

# Copyright 2023 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Forward step functions."""

import functools
from typing import Optional, Sequence

import jax
from jax import numpy as jp
import mujoco
from mujoco.mjx._src import collision_driver
from mujoco.mjx._src import constraint
from mujoco.mjx._src import derivative
from mujoco.mjx._src import math
from mujoco.mjx._src import passive
from mujoco.mjx._src import scan
from mujoco.mjx._src import sensor
from mujoco.mjx._src import smooth
from mujoco.mjx._src import solver
from mujoco.mjx._src import support
# pylint: disable=g-importing-member
from mujoco.mjx._src.types import BiasType
from mujoco.mjx._src.types import Data
from mujoco.mjx._src.types import DataJAX
from mujoco.mjx._src.types import DisableBit
from mujoco.mjx._src.types import DynType
from mujoco.mjx._src.types import GainType
from mujoco.mjx._src.types import Impl
from mujoco.mjx._src.types import IntegratorType
from mujoco.mjx._src.types import JointType
from mujoco.mjx._src.types import Model
from mujoco.mjx._src.types import ModelJAX
from mujoco.mjx._src.types import TrnType
# pylint: enable=g-importing-member
import mujoco.mjx.warp as mjxw
import numpy as np

# RK4 tableau
# _RK4_A: stage coefficients. rungekutta4 reads only the diagonal of this
# matrix (via jp.diag), one entry per intermediate stage, and derives the
# per-stage time offsets from it.
_RK4_A = np.array([
    [0.5, 0.0, 0.0],
    [0.0, 0.5, 0.0],
    [0.0, 0.0, 1.0],
])
# _RK4_B: weights used when summing the stage results. rungekutta4 applies
# entry 0 to the state it is called with and entries 1..3 to the three stages
# evaluated inside its scan.
_RK4_B = np.array([1.0 / 6.0, 1.0 / 3.0, 1.0 / 3.0, 1.0 / 6.0])


def named_scope(fn, name: str = ''):
  """Wraps `fn` so that every call executes inside a `jax.named_scope`.

  Args:
    fn: the callable to wrap.
    name: label for the scope. When empty, `fn.__name__` is used; callables
      that carry no `__name__` (for example `functools.partial` objects) fall
      back to the name of their type instead of raising `AttributeError`.

  Returns:
    A wrapper that forwards all positional and keyword arguments to `fn` and
    returns its result unchanged.
  """

  @functools.wraps(fn)
  def wrapper(*args, **kwargs):
    # resolve the label lazily so an explicit `name` always takes precedence
    scope = name or getattr(fn, '__name__', type(fn).__name__)
    with jax.named_scope(scope):
      return fn(*args, **kwargs)

  return wrapper


@named_scope
def fwd_position(m: Model, d: Data) -> Data:
  """Runs the position-dependent stages of the forward pipeline.

  Args:
    m: model passed to every stage.
    d: data threaded through the stages; each stage receives the output of
      the previous one.

  Returns:
    The Data produced by the last stage.
  """
  # TODO(robotics-simulation): tendon
  # order matters: each stage consumes the Data returned by the previous one
  stages = (
      smooth.kinematics,
      smooth.com_pos,
      smooth.camlight,
      smooth.tendon,
      smooth.crb,
      smooth.tendon_armature,
      smooth.factor_m,
      collision_driver.collision,
      constraint.make_constraint,
      smooth.transmission,
  )
  for stage in stages:
    d = stage(m, d)
  return d
@named_scope
def fwd_velocity(m: Model, d: Data) -> Data:
  """Runs the velocity-dependent stages of the forward pipeline.

  Args:
    m: model; its `_impl` must be a `ModelJAX`.
    d: data; its `_impl` must be a `DataJAX`.

  Returns:
    Data with `_impl.actuator_velocity` and `_impl.ten_velocity` replaced and
    then passed through `smooth.com_vel`, `passive.passive`, `smooth.rne` and
    `smooth.tendon_bias`, in that order.

  Raises:
    ValueError: if either `m._impl` or `d._impl` is not the JAX variant.
  """
  if not (isinstance(m._impl, ModelJAX) and isinstance(d._impl, DataJAX)):
    raise ValueError('fwd_velocity requires JAX backend implementation.')

  # project qvel through the actuator moment matrix and the tendon Jacobian
  impl = d._impl
  velocities = {
      '_impl.actuator_velocity': impl.actuator_moment @ d.qvel,
      '_impl.ten_velocity': impl.ten_J @ d.qvel,
  }
  d = d.tree_replace(velocities)

  stages = (smooth.com_vel, passive.passive, smooth.rne, smooth.tendon_bias)
  for stage in stages:
    d = stage(m, d)
  return d
@named_scope
def fwd_actuation(m: Model, d: Data) -> Data:
  """Actuation-dependent computations.

  Computes activation derivatives, per-actuator forces and the resulting
  joint-space actuator force.

  Args:
    m: model providing actuator, tendon and joint parameters.
    d: data; its `_impl` must be a `DataJAX`.

  Returns:
    Data with `act_dot`, `qfrc_actuator` and `actuator_force` replaced. When
    the model has no actuators or actuation is disabled, only `act_dot` and
    `qfrc_actuator` are replaced (with zeros).

  Raises:
    ValueError: if `d._impl` is not a `DataJAX`.
    NotImplementedError: if an actuator dynamics type is not handled.
    RuntimeError: if an actuator gain type is not recognized.
  """
  if not isinstance(d._impl, DataJAX):
    raise ValueError('fwd_actuation requires JAX backend implementation.')

  # nothing to actuate: zero the outputs and leave everything else untouched
  if not m.nu or m.opt.disableflags & DisableBit.ACTUATION:
    return d.replace(
        act_dot=jp.zeros((m.na,)),
        qfrc_actuator=jp.zeros((m.nv,)),
    )

  ctrl = d.ctrl
  if not m.opt.disableflags & DisableBit.CLAMPCTRL:
    # actuators without a control limit get an infinite range, so the clip
    # below leaves them unchanged
    ctrlrange = jp.where(
        m.actuator_ctrllimited[:, None],
        m.actuator_ctrlrange,
        jp.array([-jp.inf, jp.inf]),
    )
    ctrl = jp.clip(ctrl, ctrlrange[:, 0], ctrlrange[:, 1])

  # act_dot for stateful actuators
  def get_act_dot(dyn_typ, dyn_prm, ctrl, act):
    if dyn_typ == DynType.NONE:
      act_dot = jp.array(0.0)
    elif dyn_typ == DynType.INTEGRATOR:
      act_dot = ctrl
    elif dyn_typ in (DynType.FILTER, DynType.FILTEREXACT):
      # dyn_prm[0] is the divisor; it is bounded below by mjMINVAL to avoid
      # dividing by zero
      act_dot = (ctrl - act) / jp.clip(dyn_prm[0], mujoco.mjMINVAL)
    elif dyn_typ == DynType.MUSCLE:
      act_dot = support.muscle_dynamics(ctrl, act, dyn_prm)
    else:
      raise NotImplementedError(f'dyntype {dyn_typ.name} not implemented.')
    return act_dot

  act_dot = jp.zeros((m.na,))
  if m.na:
    act_dot = scan.flat(
        m,
        get_act_dot,
        'uuua',
        'a',
        m.actuator_dyntype,
        m.actuator_dynprm,
        ctrl,
        d.act,
        group_by='u',
    )

  # input fed to the gain: the control itself for actuators whose actadr is
  # -1, otherwise the last activation element belonging to the actuator
  ctrl_act = ctrl
  if m.na:
    act_last_dim = d.act[m.actuator_actadr + m.actuator_actnum - 1]
    ctrl_act = jp.where(m.actuator_actadr == -1, ctrl, act_last_dim)

  def get_force(*args):
    gain_t, gain_p, bias_t, bias_p, len_, vel, ctrl_act, len_range, acc0 = args

    typ, prm = GainType(gain_t), gain_p
    if typ == GainType.FIXED:
      gain = prm[0]
    elif typ == GainType.AFFINE:
      gain = prm[0] + prm[1] * len_ + prm[2] * vel
    elif typ == GainType.MUSCLE:
      gain = support.muscle_gain(len_, vel, len_range, acc0, prm)
    else:
      raise RuntimeError(f'unrecognized gaintype {typ.name}.')

    # bias stays zero unless the bias type is AFFINE or MUSCLE
    typ, prm = BiasType(bias_t), bias_p
    bias = jp.array(0.0)
    if typ == BiasType.AFFINE:
      bias = prm[0] + prm[1] * len_ + prm[2] * vel
    elif typ == BiasType.MUSCLE:
      bias = support.muscle_bias(len_, len_range, acc0, prm)

    return gain * ctrl_act + bias

  force = scan.flat(
      m,
      get_force,
      'uuuuuuuuu',
      'u',
      m.actuator_gaintype,
      m.actuator_gainprm,
      m.actuator_biastype,
      m.actuator_biasprm,
      d.actuator_length,
      d._impl.actuator_velocity,
      ctrl_act,
      jp.array(m.actuator_lengthrange),
      jp.array(m.actuator_acc0),
      group_by='u',
  )

  # tendon total force clamping
  if np.any(m.tendon_actfrclimited):
    (tendon_actfrclimited_id,) = np.nonzero(m.tendon_actfrclimited)
    actuator_tendon = m.actuator_trntype == TrnType.TENDON
    # one boolean row per force-limited tendon, marking the tendon-transmission
    # actuators whose first trnid entry is that tendon
    force_mask = [
        actuator_tendon & (m.actuator_trnid[:, 0] == tendon_id)
        for tendon_id in tendon_actfrclimited_id
    ]
    force_ids = np.concatenate([np.nonzero(mask)[0] for mask in force_mask])
    force_mat = np.array(force_mask)[:, force_ids]
    # summed force of all selected actuators, per limited tendon
    tendon_total_force = force_mat @ force[force_ids]

    # scale factor that brings the total back to the violated bound; 1 when
    # the total lies within the range
    force_scaling = jp.where(
        tendon_total_force < m.tendon_actfrcrange[tendon_actfrclimited_id, 0],
        m.tendon_actfrcrange[tendon_actfrclimited_id, 0] / tendon_total_force,
        1,
    )
    force_scaling = jp.where(
        tendon_total_force > m.tendon_actfrcrange[tendon_actfrclimited_id, 1],
        m.tendon_actfrcrange[tendon_actfrclimited_id, 1] / tendon_total_force,
        force_scaling,
    )

    # apply each tendon's scale factor to the actuators selected for it
    tendon_forces = force[force_ids] * (force_mat.T @ force_scaling)
    force = force.at[force_ids].set(tendon_forces)

  # per-actuator force limits; unlimited actuators get an infinite range
  forcerange = jp.where(
      m.actuator_forcelimited[:, None],
      m.actuator_forcerange,
      jp.array([-jp.inf, jp.inf]),
  )
  force = jp.clip(force, forcerange[:, 0], forcerange[:, 1])

  # map actuator forces to degrees of freedom
  qfrc_actuator = d._impl.actuator_moment.T @ force

  if m.ngravcomp:
    # actuator-level gravity compensation, skip if added as passive force
    qfrc_actuator += d.qfrc_gravcomp * m.jnt_actgravcomp[m.dof_jntid]

  # clamp qfrc_actuator
  actfrcrange = jp.where(
      m.jnt_actfrclimited[:, None],
      m.jnt_actfrcrange,
      jp.array([-jp.inf, jp.inf]),
  )
  # expand the per-joint ranges to one row per degree of freedom
  actfrcrange = actfrcrange[m.dof_jntid]
  qfrc_actuator = jp.clip(qfrc_actuator, actfrcrange[:, 0], actfrcrange[:, 1])

  d = d.replace(
      act_dot=act_dot, qfrc_actuator=qfrc_actuator, actuator_force=force
  )
  return d
@named_scope
def fwd_acceleration(m: Model, d: Data) -> Data:
  """Sums the non-constraint forces and solves for `qacc_smooth`.

  Args:
    m: model passed to `support.xfrc_accumulate` and `smooth.solve_m`.
    d: data holding the passive, bias, actuator and applied forces.

  Returns:
    Data with `qfrc_smooth` and `qacc_smooth` replaced.
  """
  # applied force: the explicit qfrc_applied plus what xfrc_accumulate returns
  applied = d.qfrc_applied + support.xfrc_accumulate(m, d)
  total = d.qfrc_passive - d.qfrc_bias + d.qfrc_actuator + applied
  return d.replace(
      qfrc_smooth=total,
      qacc_smooth=smooth.solve_m(m, d, total),
  )
@named_scope
def _integrate_pos(
    jnt_typs: Sequence[str], qpos: jax.Array, qvel: jax.Array, dt: jax.Array
) -> jax.Array:
  """Integrate position given velocity.

  Args:
    jnt_typs: joint types, one per joint, in the order their coordinates appear
      in `qpos` / `qvel`.
    qpos: position coordinates of the listed joints.
    qvel: velocity coordinates of the listed joints.
    dt: integration interval.

  Returns:
    The integrated positions, concatenated in joint order; an empty array when
    `jnt_typs` is empty.

  Raises:
    RuntimeError: if a joint type is not FREE, BALL, HINGE or SLIDE.
  """
  # qi / vi are running offsets into qpos / qvel
  qs, qi, vi = [], 0, 0

  for jnt_typ in jnt_typs:
    if jnt_typ == JointType.FREE:
      # 7 position coords (3 translation + 4 quaternion), 6 velocity coords
      pos = qpos[qi : qi + 3] + dt * qvel[vi : vi + 3]
      quat = math.quat_integrate(
          qpos[qi + 3 : qi + 7], qvel[vi + 3 : vi + 6], dt
      )
      qs.append(jp.concatenate([pos, quat]))
      qi, vi = qi + 7, vi + 6
    elif jnt_typ == JointType.BALL:
      # 4 quaternion coords, 3 velocity coords
      quat = math.quat_integrate(qpos[qi : qi + 4], qvel[vi : vi + 3], dt)
      qs.append(quat)
      qi, vi = qi + 4, vi + 3
    elif jnt_typ in (JointType.HINGE, JointType.SLIDE):
      pos = qpos[qi] + dt * qvel[vi]
      qs.append(pos[None])
      qi, vi = qi + 1, vi + 1
    else:
      raise RuntimeError(f'unrecognized joint type: {jnt_typ}')

  return jp.concatenate(qs) if qs else jp.empty((0,))


def _next_activation(m: Model, d: Data, act_dot: jax.Array) -> jax.Array:
  """Returns the next act given the current act_dot, after clamping.

  Args:
    m: model providing actuator dynamics parameters, activation limits and the
      timestep.
    d: data providing the current `act`.
    act_dot: activation derivative to integrate.

  Returns:
    The advanced activations reshaped to `(m.na,)`; `d.act` unchanged when the
    model has no activations.
  """
  act = d.act

  if not m.na:
    return act

  # actuators without an activation limit get an infinite range, so the clip
  # in `fn` leaves them unchanged
  actrange = jp.where(
      m.actuator_actlimited[:, None],
      m.actuator_actrange,
      jp.array([-jp.inf, jp.inf]),
  )

  def fn(dyntype, dynprm, act, act_dot, actrange):
    if dyntype == DynType.FILTEREXACT:
      # Lower bound passed positionally: the `a_min` keyword of jp.clip is
      # deprecated, and the positional form (already used in fwd_actuation) is
      # accepted by both old and new JAX signatures.
      tau = jp.clip(dynprm[0], mujoco.mjMINVAL)
      act = act + act_dot * tau * (1 - jp.exp(-m.opt.timestep / tau))
    else:
      act = act + act_dot * m.opt.timestep
    act = jp.clip(act, actrange[0], actrange[1])
    return act

  args = (m.actuator_dyntype, m.actuator_dynprm, act, act_dot, actrange)
  act = scan.flat(m, fn, 'uuaau', 'a', *args, group_by='u')

  return act.reshape(m.na)


@named_scope
def _advance(
    m: Model,
    d: Data,
    act_dot: jax.Array,
    qacc: jax.Array,
    qvel: Optional[jax.Array] = None,
) -> Data:
  """Advance state and time given activation derivatives and acceleration.

  Args:
    m: model providing the timestep and joint types.
    d: data to advance.
    act_dot: activation derivative used to advance `act`.
    qacc: acceleration used to advance `qvel`.
    qvel: velocity used to advance `qpos`; when None, the freshly advanced
      `d.qvel` is used.

  Returns:
    Data with `act`, `qvel`, `qpos`, `time` and `qacc_warmstart` replaced.
  """
  act = _next_activation(m, d, act_dot)

  # advance velocities
  d = d.replace(qvel=d.qvel + qacc * m.opt.timestep)

  # advance positions with qvel if given, d.qvel otherwise (semi-implicit)
  qvel = d.qvel if qvel is None else qvel
  integrate_fn = lambda *args: _integrate_pos(*args, dt=m.opt.timestep)
  qpos = scan.flat(m, integrate_fn, 'jqv', 'q', m.jnt_type, d.qpos, qvel)

  # advance time
  time = d.time + m.opt.timestep

  # save qacc for next step warmstart
  d = d.replace(qacc_warmstart=d.qacc)

  return d.replace(act=act, qpos=qpos, time=time)
@named_scope
def euler(m: Model, d: Data) -> Data:
  """Euler integrator, semi-implicit in velocity.

  Unless `DisableBit.EULERDAMP` is set, joint damping is integrated implicitly:
  `timestep * dof_damping` is added to the diagonal of the inertia matrix and
  the acceleration is re-solved against that matrix.

  Args:
    m: model; its `_impl` must be a `ModelJAX`.
    d: data; its `_impl` must be a `DataJAX`.

  Returns:
    The Data returned by `_advance` for the chosen acceleration.

  Raises:
    ValueError: if either `m._impl` or `d._impl` is not the JAX variant.
  """
  if not (isinstance(m._impl, ModelJAX) and isinstance(d._impl, DataJAX)):
    raise ValueError('euler requires JAX backend implementation.')

  if m.opt.disableflags & DisableBit.EULERDAMP:
    # implicit damping disabled: integrate the acceleration already in d
    return _advance(m, d, d.act_dot, d.qacc)

  # integrate damping implicitly
  scaled_damping = m.opt.timestep * m.dof_damping
  if support.is_sparse(m):
    # sparse storage: dof_Madr addresses the diagonal entries
    damped_qm = d._impl.qM.at[m.dof_Madr].add(scaled_damping)
  else:
    damped_qm = d._impl.qM + jp.diag(scaled_damping)

  # factorize on a copy so that d keeps the undamped inertia matrix
  d_damped = smooth.factor_m(m, d.tree_replace({'_impl.qM': damped_qm}))
  total_force = d.qfrc_smooth + d.qfrc_constraint
  damped_qacc = smooth.solve_m(m, d_damped, total_force)

  return _advance(m, d, d.act_dot, damped_qacc)
@named_scope
def rungekutta4(m: Model, d: Data) -> Data:
  """Runge-Kutta explicit order 4 integrator.

  Args:
    m: model providing the timestep and joint types.
    d: data on which the forward dynamics have already been evaluated (its
      `qacc` and `act_dot` are read as the first stage).

  Returns:
    Data advanced by `_advance` using the weighted sums of the stage
    velocities, accelerations and activation derivatives.
  """
  # keep the initial state: every stage integrates from d0, and d0's
  # qpos/qvel/act/time are restored before the final advance
  d0 = d
  # pylint: disable=invalid-name
  A, B = _RK4_A, _RK4_B
  # NOTE(review): axis=0 sums down columns, whereas the formula in the comment
  # is a row sum; the two agree here because _RK4_A is diagonal.
  C = jp.tril(A).sum(axis=0)  # C(i) = sum_j A(i,j)
  # time at which each intermediate stage is evaluated
  T = d.time + C * m.opt.timestep
  # pylint: enable=invalid-name

  kqvel = d.qvel  # intermediate RK solution
  # RK solutions sum
  qvel, qacc, act_dot = jax.tree_util.tree_map(
      lambda k: B[0] * k, (kqvel, d.qacc, d.act_dot)
  )
  integrate_fn = lambda *args: _integrate_pos(*args, dt=m.opt.timestep)

  def f(carry, x):
    # carry: running weighted sums, the previous stage velocity and the
    # previous stage Data
    qvel, qacc, act_dot, kqvel, d = carry
    a, b, t = x  # tableau numbers
    dqvel, dqacc, dact_dot = jax.tree_util.tree_map(
        lambda k: a * k, (kqvel, d.qacc, d.act_dot)
    )
    # get intermediate RK solutions
    kqpos = scan.flat(m, integrate_fn, 'jqv', 'q', m.jnt_type, d0.qpos, dqvel)
    kact = d0.act + dact_dot * m.opt.timestep
    kqvel = d0.qvel + dqacc * m.opt.timestep
    d = d.replace(qpos=kqpos, qvel=kqvel, act=kact, time=t)
    # re-evaluate the dynamics at the intermediate state
    d = forward(m, d)

    # accumulate this stage into the weighted sums
    qvel += b * kqvel
    qacc += b * d.qacc
    act_dot += b * d.act_dot

    return (qvel, qacc, act_dot, kqvel, d), None

  # one row per stage: (diagonal entry of A, weight from B, stage time)
  abt = jp.vstack([jp.diag(A), B[1:4], T]).T
  out, _ = jax.lax.scan(f, (qvel, qacc, act_dot, kqvel, d), abt, unroll=3)
  qvel, qacc, act_dot, _, d1 = out

  # take the last stage's Data but restore the initial state before advancing
  d = d1.replace(qpos=d0.qpos, qvel=d0.qvel, act=d0.act, time=d0.time)
  d = _advance(m, d, act_dot, qacc, qvel)
  return d
@named_scope
def implicit(m: Model, d: Data) -> Data:
  """Integrates fully implicit in velocity.

  Args:
    m: model; its `_impl` must be a `ModelJAX`.
    d: data; its `_impl` must be a `DataJAX`.

  Returns:
    The Data returned by `_advance`. When `derivative.deriv_smooth_vel` returns
    None, the existing `d.qacc` is integrated unchanged.

  Raises:
    ValueError: if either `m._impl` or `d._impl` is not the JAX variant.
  """
  if not (isinstance(m._impl, ModelJAX) and isinstance(d._impl, DataJAX)):
    raise ValueError('implicit requires JAX backend implementation.')

  qderiv = derivative.deriv_smooth_vel(m, d)
  if qderiv is None:
    # no velocity derivative available: fall back to the current acceleration
    return _advance(m, d, d.act_dot, d.qacc)

  # TODO(robotics-simulation): use smooth.factor_m / solve_m here:
  dense_m = support.full_m(m, d) if support.is_sparse(m) else d._impl.qM
  modified_m = dense_m - m.opt.timestep * qderiv

  # Cholesky-factorize the modified matrix and solve for the acceleration
  factor, _ = jax.scipy.linalg.cho_factor(modified_m)
  total_force = d.qfrc_smooth + d.qfrc_constraint
  implicit_qacc = jax.scipy.linalg.cho_solve((factor, False), total_force)

  return _advance(m, d, d.act_dot, implicit_qacc)
@named_scope
def forward(m: Model, d: Data) -> Data:
  """Forward dynamics.

  Dispatches to the Warp implementation when both the model and the data use
  it and Warp is installed. Otherwise runs the JAX pipeline: position,
  velocity, actuation and acceleration stages (with their sensor stages),
  the constraint solve, and finally the acceleration-stage sensors.

  Args:
    m: model.
    d: data.

  Returns:
    Data with the forward dynamics quantities updated.

  Raises:
    ValueError: if the JAX pipeline is selected but `m._impl` / `d._impl` are
      not the JAX variants.
  """
  if m.impl == Impl.WARP and d.impl == Impl.WARP and mjxw.WARP_INSTALLED:
    from mujoco.mjx.warp import forward as mjxw_forward  # pylint: disable=g-import-not-at-top  # pytype: disable=import-error
    return mjxw_forward.forward(m, d)

  if not isinstance(m._impl, ModelJAX) or not isinstance(d._impl, DataJAX):
    raise ValueError('forward requires JAX backend implementation.')

  d = fwd_position(m, d)
  d = sensor.sensor_pos(m, d)
  d = fwd_velocity(m, d)
  d = sensor.sensor_vel(m, d)
  d = fwd_actuation(m, d)
  d = fwd_acceleration(m, d)

  if d._impl.efc_J.size == 0:
    # no constraint rows: the unconstrained acceleration is the answer and the
    # solver can be skipped
    d = d.replace(qacc=d.qacc_smooth)
  else:
    d = named_scope(solver.solve)(m, d)

  # Acceleration-stage sensors must be evaluated on both paths. Returning
  # early for unconstrained models would leave those sensors stale.
  d = sensor.sensor_acc(m, d)

  return d
@named_scope
def step(m: Model, d: Data) -> Data:
  """Advance simulation.

  Dispatches to the Warp implementation when both the model and the data use
  it and Warp is installed. Otherwise evaluates the forward dynamics and then
  applies the integrator selected by `m.opt.integrator`.

  Args:
    m: model.
    d: data.

  Returns:
    Data advanced by one step.

  Raises:
    NotImplementedError: if `m.opt.integrator` is not EULER, RK4 or
      IMPLICITFAST.
  """
  use_warp = m.impl == Impl.WARP and d.impl == Impl.WARP and mjxw.WARP_INSTALLED
  if use_warp:
    from mujoco.mjx.warp import forward as mjxw_forward  # pylint: disable=g-import-not-at-top  # pytype: disable=import-error
    return mjxw_forward.step(m, d)

  # the dynamics are evaluated before the integrator is inspected
  d = forward(m, d)

  integrator = m.opt.integrator
  if integrator == IntegratorType.EULER:
    return euler(m, d)
  if integrator == IntegratorType.RK4:
    return rungekutta4(m, d)
  if integrator == IntegratorType.IMPLICITFAST:
    return implicit(m, d)

  raise NotImplementedError(f'integrator {m.opt.integrator} not implemented.')