Source code for mujoco.mjx._src.smooth

# Copyright 2023 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Core smooth dynamics functions."""

import jax
from jax import numpy as jp
import mujoco
from mujoco.mjx._src import math
from mujoco.mjx._src import scan
from mujoco.mjx._src import support
# pylint: disable=g-importing-member
from mujoco.mjx._src.types import CamLightType
from mujoco.mjx._src.types import Data
from mujoco.mjx._src.types import DataJAX
from mujoco.mjx._src.types import DisableBit
from mujoco.mjx._src.types import EqType
from mujoco.mjx._src.types import Impl
from mujoco.mjx._src.types import JointType
from mujoco.mjx._src.types import Model
from mujoco.mjx._src.types import ModelJAX
from mujoco.mjx._src.types import ObjType
from mujoco.mjx._src.types import TrnType
from mujoco.mjx._src.types import WrapType
# pylint: enable=g-importing-member
import mujoco.mjx.warp as mjxw
import numpy as np


def kinematics(m: Model, d: Data) -> Data:
  """Converts position/velocity from generalized coordinates to maximal.

  Computes, in the global frame, joint anchors and axes (xanchor, xaxis), body
  poses (xpos, xquat, xmat), body inertial frames (xipos, ximat) and, when the
  model has them, geom and site frames. Free and ball joint quaternions inside
  qpos are normalized along the way and the normalized qpos is written back.

  When both model and data use the Warp backend and Warp is installed, the call
  is delegated to the Warp implementation.

  Args:
    m: model
    d: data

  Returns:
    data with the kinematic fields listed above updated
  """
  if m.impl == Impl.WARP and d.impl == Impl.WARP and mjxw.WARP_INSTALLED:
    from mujoco.mjx.warp import smooth as mjxw_smooth  # pylint: disable=g-import-not-at-top  # pytype: disable=import-error

    return mjxw_smooth.kinematics(m, d)

  def fn(carry, jnt_typs, jnt_pos, jnt_axis, qpos, qpos0, pos, quat):
    # calculate joint anchors, axes, body pos and quat in global frame
    # also normalize qpos while we're at it

    # carry is the parent body's output tuple (same layout as this function's
    # return value), or None when there is no parent to compose with
    if carry is not None:
      _, _, _, parent_pos, parent_quat, _ = carry
      pos = parent_pos + math.rotate(pos, parent_quat)
      quat = math.quat_mul(parent_quat, quat)

    anchors, axes = [], []

    qpos_i = 0  # running offset into this body's slice of qpos
    for i, jnt_typ in enumerate(jnt_typs):
      if jnt_typ == JointType.FREE:
        anchor, axis = qpos[qpos_i : qpos_i + 3], jp.array([0.0, 0.0, 1.0])
      else:
        anchor = math.rotate(jnt_pos[i], quat) + pos
        axis = math.rotate(jnt_axis[i], quat)
      anchors, axes = anchors + [anchor], axes + [axis]

      if jnt_typ == JointType.FREE:
        # free joint: qpos holds the global position (3) and quaternion (4)
        pos = qpos[qpos_i : qpos_i + 3]
        quat = math.normalize(qpos[qpos_i + 3 : qpos_i + 7])
        qpos = qpos.at[qpos_i + 3 : qpos_i + 7].set(quat)
        qpos_i += 7
      elif jnt_typ == JointType.BALL:
        qloc = math.normalize(qpos[qpos_i : qpos_i + 4])
        qpos = qpos.at[qpos_i : qpos_i + 4].set(qloc)
        quat = math.quat_mul(quat, qloc)
        pos = anchor - math.rotate(jnt_pos[i], quat)  # off-center rotation
        qpos_i += 4
      elif jnt_typ == JointType.HINGE:
        # the hinge angle is measured relative to the reference pose qpos0
        angle = qpos[qpos_i] - qpos0[qpos_i]
        qloc = math.axis_angle_to_quat(jnt_axis[i], angle)
        quat = math.quat_mul(quat, qloc)
        pos = anchor - math.rotate(jnt_pos[i], quat)  # off-center rotation
        qpos_i += 1
      elif jnt_typ == JointType.SLIDE:
        pos += axis * (qpos[qpos_i] - qpos0[qpos_i])
        qpos_i += 1
      else:
        raise RuntimeError(f'unrecognized joint type: {jnt_typ}')

    # bodies without joints still return correctly shaped (0, 3) arrays
    anchor = jp.stack(anchors) if anchors else jp.empty((0, 3))
    axis = jp.stack(axes) if axes else jp.empty((0, 3))
    mat = math.quat_to_mat(quat)

    return qpos, anchor, axis, pos, quat, mat

  qpos, xanchor, xaxis, xpos, xquat, xmat = scan.body_tree(
      m,
      fn,
      'jjjqqbb',
      'qjjbbb',
      m.jnt_type,
      m.jnt_pos,
      m.jnt_axis,
      d.qpos,
      m.qpos0,
      m.body_pos,
      m.body_quat,
  )

  # NOTE(review): mocap poses are written only after the tree scan above, and
  # that scan propagated poses from m.body_pos / m.body_quat. Bodies whose
  # parent is a mocap body therefore do not appear to follow the mocap pose;
  # confirm against MuJoCo's mj_kinematics if mocap bodies can have children.
  # This also assumes mocap bodies appear in body order sorted by mocap id.
  if m.nmocap:
    xpos = xpos.at[m.body_mocapid >= 0].set(d.mocap_pos)
    mocap_quat = jax.vmap(math.normalize)(d.mocap_quat)
    xquat = xquat.at[m.body_mocapid >= 0].set(mocap_quat)
    xmat = xmat.at[m.body_mocapid >= 0].set(
        jax.vmap(math.quat_to_mat)(mocap_quat)
    )

  v_local_to_global = jax.vmap(support.local_to_global)

  # TODO(erikfrey): confirm that quats are more performant for mjx than mats
  xipos, ximat = v_local_to_global(xpos, xquat, m.body_ipos, m.body_iquat)
  d = d.replace(qpos=qpos, xanchor=xanchor, xaxis=xaxis, xpos=xpos)
  d = d.replace(xquat=xquat, xmat=xmat, xipos=xipos, ximat=ximat)

  # geom and site frames are attached to their parent body's frame
  if m.ngeom:
    geom_xpos, geom_xmat = v_local_to_global(
        xpos[m.geom_bodyid], xquat[m.geom_bodyid], m.geom_pos, m.geom_quat
    )
    d = d.replace(geom_xpos=geom_xpos, geom_xmat=geom_xmat)

  if m.nsite:
    site_xpos, site_xmat = v_local_to_global(
        xpos[m.site_bodyid], xquat[m.site_bodyid], m.site_pos, m.site_quat
    )
    d = d.replace(site_xpos=site_xpos, site_xmat=site_xmat)

  return d
def com_pos(m: Model, d: Data) -> Data:
  """Maps inertias and motion dofs to global frame centered at subtree-CoM.

  Writes ``subtree_com`` (center of mass of each body's subtree), ``cinert``
  (each body's inertia expressed about its root's subtree CoM, packed as
  triu(inertia) (6), mass * offset (3), mass (1)) and ``cdof`` (each dof's
  spatial motion axis about the same point, angular part first).

  Args:
    m: model (JAX backend)
    d: data (JAX backend) with kinematics already computed

  Returns:
    data with subtree_com, cinert and cdof updated

  Raises:
    ValueError: if m or d does not use the JAX backend implementation
  """
  if not isinstance(m._impl, ModelJAX) or not isinstance(d._impl, DataJAX):
    raise ValueError('com_pos requires JAX backend implementation.')

  # calculate center of mass of each subtree
  def subtree_sum(carry, xipos, body_mass):
    # mass-weighted position plus the children's accumulated sums
    pos, mass = xipos * body_mass, body_mass
    if carry is not None:
      subtree_pos, subtree_mass = carry
      pos, mass = pos + subtree_pos, mass + subtree_mass
    return pos, mass

  pos, mass = scan.body_tree(
      m, subtree_sum, 'bb', 'bb', d.xipos, m.body_mass, reverse=True
  )
  # massless subtrees fall back to the body's own inertial position
  cond = jp.tile(mass < mujoco.mjMINVAL, (3, 1)).T
  # take maximum to avoid NaN in gradient of jp.where
  subtree_com = jax.vmap(jp.divide)(pos, jp.maximum(mass, mujoco.mjMINVAL))
  subtree_com = jp.where(cond, d.xipos, subtree_com)
  d = d.replace(subtree_com=subtree_com)

  # map inertias to frame centered at subtree_com
  @jax.vmap
  def inert_com(inert, ximat, off, mass):
    # h is the skew-symmetric cross-product matrix built from off, so
    # h @ h.T * mass is the parallel-axis term for shifting by off
    h = jp.cross(off, -jp.eye(3))
    # (ximat * inert) == ximat @ diag(inert): rotate principal inertia to world
    inert = math.matmul_unroll((ximat * inert), ximat.T)
    inert += math.matmul_unroll(h, h.T) * mass
    # cinert is triu(inert), mass * off, mass
    inert = inert[([0, 1, 2, 0, 0, 1], [0, 1, 2, 1, 2, 2])]
    return jp.concatenate([inert, off * mass, mass[None]])

  root_com = subtree_com[m.body_rootid]
  offset = d.xipos - root_com
  cinert = inert_com(m.body_inertia, d.ximat, offset, m.body_mass)
  d = d.tree_replace({'_impl.cinert': cinert})

  # map motion dofs to global frame centered at subtree_com
  def cdof_fn(jnt_typs, root_com, xmat, xanchor, xaxis):
    cdofs = []

    # rotational dof about axis a, seen from a point offset o from the anchor
    dof_com_fn = lambda a, o: jp.concatenate([a, jp.cross(a, o)])

    for i, jnt_typ in enumerate(jnt_typs):
      offset = root_com - xanchor[i]
      if jnt_typ == JointType.FREE:
        cdofs.append(jp.eye(3, 6, 3))  # free translation
        # rotational dofs use the body frame's axes (columns of xmat)
        cdofs.append(jax.vmap(dof_com_fn, in_axes=(0, None))(xmat.T, offset))
      elif jnt_typ == JointType.BALL:
        cdofs.append(jax.vmap(dof_com_fn, in_axes=(0, None))(xmat.T, offset))
      elif jnt_typ == JointType.HINGE:
        cdof = dof_com_fn(xaxis[i], offset)
        cdofs.append(jp.expand_dims(cdof, 0))
      elif jnt_typ == JointType.SLIDE:
        # pure translation: zero angular part, joint axis as linear part
        cdof = jp.concatenate((jp.zeros((3,)), xaxis[i]))
        cdofs.append(jp.expand_dims(cdof, 0))
      else:
        raise RuntimeError(f'unrecognized joint type: {jnt_typ}')

    cdof = jp.concatenate(cdofs) if cdofs else jp.empty((0, 6))

    return cdof

  cdof = scan.flat(
      m,
      cdof_fn,
      'jbbjj',
      'v',
      m.jnt_type,
      root_com,
      d.xmat,
      d.xanchor,
      d.xaxis,
  )

  d = d.tree_replace({'cdof': cdof})

  return d
def camlight(m: Model, d: Data) -> Data:
  """Computes camera and light positions and orientations.

  Each camera pose starts as its local pose (cam_pos/cam_quat) composed with
  the parent body's frame, then is adjusted according to its mode:
    * TRACK / TRACKCOM: keep the reference orientation ``cam_mat0`` and follow
      the body position / subtree center of mass by a fixed offset.
    * TARGETBODY / TARGETBODYCOM: orient the camera to look at the target
      body position / target subtree center of mass.
  Target modes without a target body (cam_targetbodyid < 0) act as FIXED.
  Only ``cam_xpos`` and ``cam_xmat`` are written here.

  Args:
    m: model (JAX backend)
    d: data (JAX backend)

  Returns:
    data with cam_xpos and cam_xmat updated

  Raises:
    ValueError: if m or d does not use the JAX backend implementation
  """
  if not isinstance(m._impl, ModelJAX) or not isinstance(d._impl, DataJAX):
    raise ValueError('camlight requires JAX backend implementation.')

  if m.ncam == 0:
    return d.replace(cam_xpos=jp.zeros((0, 3)), cam_xmat=jp.zeros((0, 3, 3)))

  # target modes without a target body degrade to FIXED
  wants_target = (m.cam_mode == CamLightType.TARGETBODY) | (
      m.cam_mode == CamLightType.TARGETBODYCOM
  )
  modes = np.where(
      wants_target & (m.cam_targetbodyid < 0), CamLightType.FIXED, m.cam_mode
  )

  body_ids = m.cam_bodyid
  target_ids = m.cam_targetbodyid
  local_xpos, local_xmat = jax.vmap(support.local_to_global)(
      d.xpos[body_ids], d.xquat[body_ids], m.cam_pos, m.cam_quat
  )

  def _pose(idx, mode, xpos, xmat, body_pos, body_com, target_pos, target_com):
    # mode drives Python control flow, so it must be a concrete value here
    if mode == CamLightType.TRACK:
      return body_pos + m.cam_pos0[idx], m.cam_mat0[idx]
    if mode == CamLightType.TRACKCOM:
      return body_com + m.cam_poscom0[idx], m.cam_mat0[idx]
    if mode not in (CamLightType.TARGETBODY, CamLightType.TARGETBODYCOM):
      return xpos, xmat

    look_at = target_com if mode == CamLightType.TARGETBODYCOM else target_pos
    # z axis: opposite of the viewing direction (target -> camera)
    z_axis = math.normalize(xpos - look_at)
    # x axis: orthogonal to z and to the world vertical (0, 0, 1)
    x_axis = math.normalize(jp.cross(jp.array([0.0, 0.0, 1.0]), z_axis))
    # y axis: completes the right-handed frame
    y_axis = math.normalize(jp.cross(z_axis, x_axis))
    return xpos, jp.array([x_axis, y_axis, z_axis]).T

  cam_xpos, cam_xmat = scan.flat(
      m,
      _pose,
      'c' * 8,
      'cc',
      jp.arange(m.ncam),
      modes,
      local_xpos,
      local_xmat,
      d.xpos[body_ids],
      d.subtree_com[body_ids],
      d.xpos[target_ids],
      d.subtree_com[target_ids],
      group_by='c',
  )

  return d.replace(cam_xpos=cam_xpos, cam_xmat=cam_xmat)
def crb(m: Model, d: Data) -> Data:
  """Runs composite rigid body inertia algorithm.

  Accumulates every body's com-based inertia (``cinert``) over its subtree to
  form the composite inertia ``crb``, then assembles the joint-space inertia
  matrix ``qM`` from the composite inertias, ``cdof`` and ``dof_armature``.

  Args:
    m: model (JAX backend)
    d: data (JAX backend) with cinert and cdof populated

  Returns:
    data with crb and qM updated

  Raises:
    ValueError: if m or d does not use the JAX backend implementation
  """
  if not isinstance(m._impl, ModelJAX) or not isinstance(d._impl, DataJAX):
    raise ValueError('crb requires JAX backend implementation.')

  def _add_children(children_sum, body_inertia):
    # children_sum is the value carried up from child bodies (None at leaves)
    return body_inertia if children_sum is None else body_inertia + children_sum

  composite = scan.body_tree(
      m, _add_children, 'b', 'b', d._impl.cinert, reverse=True
  )
  # body 0 is the world body: it contributes no composite inertia
  composite = composite.at[0].set(0.0)

  # for every dof, multiply its body's composite inertia by the dof's axis
  dof_composite = jp.take(composite, jp.array(m.dof_bodyid), axis=0)
  composite_cdof = jax.vmap(math.inert_mul)(dof_composite, d.cdof)
  qm = support.make_m(m, composite_cdof, d.cdof, m.dof_armature)

  return d.tree_replace({'_impl.crb': composite, '_impl.qM': qm})
def factor_m(m: Model, d: Data) -> Data:
  """Gets factorization of inertia-like matrix M, assumed spd.

  Dense models: ``qLD`` receives the Cholesky factor of ``qM``.

  Sparse models: ``qM`` is stored in the sparse layout where row ``i``
  occupies addresses ``dof_Madr[i]`` .. ``dof_Madr[i + 1]``, starting with the
  diagonal entry and followed by one entry per ancestor dof. It is factored as
  L'*D*L in that same layout. ``qLD`` holds D at the ``dof_Madr`` (diagonal)
  addresses and the off-diagonal entries of L elsewhere, and ``qLDiagInv``
  holds 1 / D.

  Args:
    m: model (JAX backend)
    d: data (JAX backend) with qM populated

  Returns:
    data with qLD (and, for sparse models, qLDiagInv) updated

  Raises:
    ValueError: if m or d does not use the JAX backend implementation
  """
  if not isinstance(m._impl, ModelJAX) or not isinstance(d._impl, DataJAX):
    raise ValueError('factor_m requires JAX backend implementation.')

  if not support.is_sparse(m):
    qh, _ = jax.scipy.linalg.cho_factor(d._impl.qM)
    d = d.tree_replace({'_impl.qLD': qh})
    return d

  # depth of every dof in the dof tree (roots at 0); assumes a parent dof
  # always precedes its children so depth[parent] is already known
  depth = []
  for i in range(m.nv):
    depth.append(depth[m.dof_parentid[i]] + 1 if m.dof_parentid[i] != -1 else 0)

  # build up indices for where we will do backwards updates over qLD:
  # for every dof i and each ancestor j of i, row j is updated from row i.
  # Updates are grouped by depth[j] so each depth level is one batched scatter.
  updates_by_depth = {}
  madr_ds = []  # for every stored entry, the address of its row's diagonal
  for i in range(m.nv):
    madr_d = madr_ij = m.dof_Madr[i]
    j = i
    while True:
      madr_ds.append(madr_d)
      madr_ij, j = madr_ij + 1, m.dof_parentid[j]
      if j == -1:
        break
      out_beg, out_end = tuple(m.dof_Madr[j : j + 2])
      updates_by_depth.setdefault(depth[j], []).append(
          (out_beg, out_end, madr_d, madr_ij)
      )

  qld = d._impl.qM

  # deepest levels first: a row only receives updates at its own depth level,
  # so every row is final before it is read by a shallower level
  for _, level_updates in sorted(updates_by_depth.items(), reverse=True):
    # combine the updates into one update batch (per depth level)
    rows = []
    madr_ijs = []
    pivots = []
    out = []

    for b, e, madr_d, madr_ij in level_updates:
      width = e - b
      rows.append(np.arange(madr_ij, madr_ij + width))
      madr_ijs.append(np.full((width,), madr_ij))
      pivots.append(np.full((width,), madr_d))
      out.append(np.arange(b, e))

    rows = np.concatenate(rows)
    madr_ijs = np.concatenate(madr_ijs)
    pivots = np.concatenate(pivots)
    out = np.concatenate(out)

    # apply the update batch:
    #   row_j -= M(i, j) / M(i, i) * (matching entries of row_i)
    # scatter-add accumulates repeated `out` indices within one batch
    qld = qld.at[out].add(-(qld[madr_ijs] / qld[pivots]) * qld[rows])

  # TODO(erikfrey): determine if this minimum value guarding is necessary:
  # qld = qld.at[dof_madr].set(jp.maximum(qld[dof_madr], _MJ_MINVAL))

  # divide every row by its diagonal (giving L) but keep the diagonal (D);
  # an explicit integer dtype keeps the index valid when nv == 0
  qld_diag = qld[m.dof_Madr]
  qld = (qld / qld[jp.array(madr_ds, dtype=int)]).at[m.dof_Madr].set(qld_diag)

  d = d.tree_replace({'_impl.qLD': qld, '_impl.qLDiagInv': 1 / qld_diag})

  return d
def solve_m(m: Model, d: Data, x: jax.Array) -> jax.Array:
  """Solves qM @ result = x using the factorization produced by factor_m.

  Dense models use a Cholesky solve against ``qLD``. Sparse models compute
  inv(L'*D*L) @ x by back-substitution: first inv(L'), then inv(D) via
  ``qLDiagInv``, then inv(L). Each triangular pass is batched by dof-tree
  depth so that every batch only reads already-final entries of x.

  Args:
    m: model (JAX backend)
    d: data (JAX backend) with qLD (and qLDiagInv for sparse) populated
    x: right-hand side vector indexed by dof

  Returns:
    the solution vector

  Raises:
    ValueError: if m or d does not use the JAX backend implementation
  """
  if not isinstance(m._impl, ModelJAX) or not isinstance(d._impl, DataJAX):
    raise ValueError('solve_m requires JAX backend implementation.')

  if not support.is_sparse(m):
    return jax.scipy.linalg.cho_solve((d._impl.qLD, False), x)

  parent = m.dof_parentid

  # depth of every dof in the dof tree; roots sit at depth 0
  depth = []
  for dof in range(m.nv):
    p = parent[dof]
    depth.append(0 if p == -1 else depth[p] + 1)

  # each (descendant row, ancestor col) pair owns one off-diagonal entry of L;
  # its address is found by stepping along the row while walking up the chain
  by_row_depth, by_col_depth = {}, {}
  for row in range(m.nv):
    adr, col = m.dof_Madr[row], parent[row]
    while col != -1:
      adr += 1
      by_row_depth.setdefault(depth[row], []).append((row, adr, col))
      by_col_depth.setdefault(depth[col], []).append((col, adr, row))
      col = parent[col]

  qld = d._impl.qLD

  # x <- inv(L') * x, deepest ancestor columns first
  for level in sorted(by_col_depth, reverse=True):
    col, adr, row = np.array(by_col_depth[level]).T
    x = x.at[col].add(-qld[adr] * x[row])

  # x <- inv(D) * x
  x = x * d._impl.qLDiagInv

  # x <- inv(L) * x, shallowest rows first
  for level in sorted(by_row_depth):
    row, adr, col = np.array(by_row_depth[level]).T
    x = x.at[row].add(-qld[adr] * x[col])

  return x
def com_vel(m: Model, d: Data) -> Data:
  """Computes cvel, cdof_dot.

  Walks down the kinematic tree accumulating each body's spatial velocity
  (``cvel``, angular part first) from its parent's velocity plus
  ``cdof * qvel`` for the body's dofs. ``cdof_dot`` for each dof is
  ``cvel x cdof`` using the velocity accumulated before that joint's
  contribution. The translational dofs of a free joint get zero cdof_dot.

  Args:
    m: model (JAX backend)
    d: data (JAX backend) with cdof populated

  Returns:
    data with cvel and cdof_dot updated

  Raises:
    ValueError: if m or d does not use the JAX backend implementation
  """
  if not isinstance(m._impl, ModelJAX) or not isinstance(d._impl, DataJAX):
    raise ValueError('com_vel requires JAX backend implementation.')

  # forward scan down tree: accumulate link center of mass velocity
  def fn(parent, jnt_typs, cdof, qvel):
    cvel = jp.zeros((6,)) if parent is None else parent[0]
    cross_fn = jax.vmap(math.motion_cross, in_axes=(None, 0))
    cdof_x_qvel = jax.vmap(jp.multiply)(cdof, qvel)

    dof_beg = 0
    cdof_dots = []
    for jnt_typ in jnt_typs:
      dof_end = dof_beg + JointType(jnt_typ).dof_width()
      if jnt_typ == JointType.FREE:
        # a free joint's dofs are 3 translations followed by 3 rotations;
        # slice relative to this joint's own dof offset
        dof_rot = dof_beg + 3
        cvel += jp.sum(cdof_x_qvel[dof_beg:dof_rot], axis=0)
        # rotational cdof_dot uses the velocity after the translation update
        cdof_ang_dot = cross_fn(cvel, cdof[dof_rot:dof_end])
        cvel += jp.sum(cdof_x_qvel[dof_rot:dof_end], axis=0)
        cdof_dots.append(jp.concatenate((jp.zeros((3, 6)), cdof_ang_dot)))
      else:
        cdof_dots.append(cross_fn(cvel, cdof[dof_beg:dof_end]))
        cvel += jp.sum(cdof_x_qvel[dof_beg:dof_end], axis=0)
      dof_beg = dof_end

    cdof_dot = jp.concatenate(cdof_dots) if cdof_dots else jp.empty((0, 6))
    return cvel, cdof_dot

  cvel, cdof_dot = scan.body_tree(
      m,
      fn,
      'jvv',
      'bv',
      m.jnt_type,
      d.cdof,
      d.qvel,
  )

  d = d.tree_replace({'cvel': cvel, 'cdof_dot': cdof_dot})

  return d
def subtree_vel(m: Model, d: Data) -> Data:
  """Subtree linear velocity and angular momentum.

  Computes ``subtree_linvel`` (linear velocity of each subtree's center of
  mass) and ``subtree_angmom`` (angular momentum of each subtree about its
  center of mass).

  Args:
    m: model (JAX backend)
    d: data (JAX backend) with cvel, xipos, ximat and subtree_com populated

  Returns:
    data with subtree_linvel and subtree_angmom updated

  Raises:
    ValueError: if m or d does not use the JAX backend implementation
  """
  if not isinstance(m._impl, ModelJAX) or not isinstance(d._impl, DataJAX):
    raise ValueError('subtree_vel requires JAX backend implementation.')

  # bodywise quantities
  def _forward(cvel, xipos, ximat, subtree_com_root, mass, inertia):
    ang, lin = jp.split(cvel, 2)

    # update linear velocity: shift the reference point from the root's
    # subtree CoM to this body's CoM (xipos)
    lin = lin - jp.cross(xipos - subtree_com_root, ang)

    # body linear momentum
    subtree_linvel = mass * lin
    # body angular momentum: (inertia * ximat) == ximat @ diag(inertia), so
    # this is ximat @ diag(inertia) @ ximat.T @ ang
    subtree_angmom = inertia * ximat @ ximat.T @ ang
    body_vel = jp.concatenate([ang, lin])

    return body_vel, subtree_linvel, subtree_angmom

  body_vel, subtree_linvel, subtree_angmom = jax.vmap(_forward)(
      d.cvel,
      d.xipos,
      d.ximat,
      d.subtree_com[m.body_rootid],
      m.body_mass,
      m.body_inertia,
  )

  # sum body linear momentum recursively up the kinematic tree
  # NOTE(review): in_types is 'bb' but only one argument is passed; confirm
  # scan.body_tree tolerates the extra type code.
  subtree_linvel = scan.body_tree(
      m,
      lambda x, y: y if x is None else x + y,
      'bb',
      'b',
      subtree_linvel,
      reverse=True,
  )
  # momentum / subtree mass -> subtree CoM velocity (guarded against zero mass)
  subtree_linvel /= jp.maximum(mujoco.mjMINVAL, m.body_subtreemass)[:, None]

  def _subtree_angmom(
      carry,
      angmom,
      com,
      com_parent,
      linvel,
      linvel_parent,
      subtreemass,
      xipos,
      vel,
      mass,
      mask,
  ):
    # angular momentum of mass m at x0 moving at v0, relative to (x1, v1);
    # note the parameter `m` shadows the model inside this helper
    def _momentum(x0, x1, v0, v1, m):
      dx = x0 - x1
      dv = v0 - v1
      dp = dv * m
      return jp.cross(dx, dp)

    # momentum wrt current body
    mom = mask * _momentum(xipos, com, vel[3:], linvel, mass)

    # momentum wrt parent
    mom_parent = mask * _momentum(
        com, com_parent, linvel, linvel_parent, subtreemass
    )

    # carry holds (children's subtree angmom, children's momentum about this
    # body's subtree), or None when there are no children
    if carry is None:
      return angmom + mom, mom_parent
    else:
      angmom_child, mom_parent_child = carry
      return angmom + mom + angmom_child + mom_parent_child, mom_parent

  subtree_angmom, _ = scan.body_tree(
      m,
      _subtree_angmom,
      'bbbbbbbbbb',
      'bb',
      subtree_angmom,
      d.subtree_com,
      d.subtree_com[m.body_parentid],
      subtree_linvel,
      subtree_linvel[m.body_parentid],
      m.body_subtreemass,
      d.xipos,
      body_vel,
      m.body_mass,
      # mask: body 0 (the world body) contributes no momentum terms
      jp.ones(m.nbody).at[0].set(0),
      reverse=True,
  )

  return d.tree_replace({
      '_impl.subtree_linvel': subtree_linvel,
      '_impl.subtree_angmom': subtree_angmom,
  })
def rne(m: Model, d: Data, flg_acc: bool = False) -> Data:
  """Computes inverse dynamics using the recursive Newton-Euler algorithm.

  A forward pass down the kinematic tree accumulates each body's spatial
  acceleration. It starts from -gravity (or zero when gravity is disabled) and
  adds ``cdof_dot * qvel``, plus ``cdof * qacc`` when ``flg_acc`` is True.
  Each body's force is ``cinert * cacc + motion_cross_force(cvel,
  cinert * cvel)``. A backward pass sums forces from children into parents,
  and the totals are projected onto every dof through ``cdof``.

  Args:
    m: model (JAX backend)
    d: data (JAX backend) with cdof, cdof_dot, cvel and cinert populated
    flg_acc: include the cdof * qacc term; False removes the inertial term

  Returns:
    data with qfrc_bias updated

  Raises:
    ValueError: if m or d does not use the JAX backend implementation
  """
  if not isinstance(m._impl, ModelJAX) or not isinstance(d._impl, DataJAX):
    raise ValueError('rne requires JAX backend implementation.')

  def _acc_down(parent_acc, cdof_dot, qvel, cdof, qacc):
    # bodies without a parent start from the gravitational acceleration
    if parent_acc is not None:
      acc = parent_acc
    elif m.opt.disableflags & DisableBit.GRAVITY:
      acc = jp.zeros((6,))
    else:
      acc = jp.concatenate((jp.zeros((3,)), -m.opt.gravity))
    acc = acc + jp.sum(jax.vmap(jp.multiply)(cdof_dot, qvel), axis=0)
    if flg_acc:
      acc = acc + jp.sum(jax.vmap(jp.multiply)(cdof, qacc), axis=0)
    return acc

  body_acc = scan.body_tree(
      m, _acc_down, 'vvvv', 'b', d.cdof_dot, d.qvel, d.cdof, d.qacc
  )

  def _body_force(cinert, acc, vel):
    inertial = math.inert_mul(cinert, acc)
    momentum = math.inert_mul(cinert, vel)
    return inertial + math.motion_cross_force(vel, momentum)

  own_force = jax.vmap(_body_force)(d._impl.cinert, body_acc, d.cvel)

  def _force_up(child_total, force):
    # child_total is carried up from child bodies (None at leaves)
    return force if child_total is None else force + child_total

  total_force = scan.body_tree(m, _force_up, 'b', 'b', own_force, reverse=True)
  dof_force = total_force[jp.array(m.dof_bodyid)]

  return d.replace(qfrc_bias=jax.vmap(jp.dot)(d.cdof, dof_force))
def rne_postconstraint(m: Model, d: Data) -> Data:
  """RNE with complete data: compute cacc, cfrc_ext, cfrc_int.

  ``cfrc_ext`` collects external forces per body, expressed about the body's
  root subtree CoM in (torque, force) layout. Sources are ``xfrc_applied``,
  contact forces, and connect/weld equality constraint forces. ``cacc`` is
  each body's spatial acceleration including qacc. ``cfrc_int`` is the
  internal force obtained by subtracting cfrc_ext from the body's inertial
  force and summing over subtrees.

  Args:
    m: model (JAX backend)
    d: data (JAX backend) after constraint forces have been computed

  Returns:
    data with cacc, cfrc_int and cfrc_ext updated

  Raises:
    ValueError: if m or d does not use the JAX backend implementation
  """
  if not isinstance(m._impl, ModelJAX) or not isinstance(d._impl, DataJAX):
    raise ValueError('rne_postconstraint requires JAX backend implementation.')

  # moves a (force, torque) pair to a new reference point, where
  # offset = new reference point - point of application
  def _transform_force(frc, offset):
    force, torque = jp.split(frc, 2)
    torque -= jp.cross(offset, force)
    # spatial motion vector layout is flipped: (torque, force)
    return jp.concatenate([torque, force])

  # cfrc_ext = perturb
  cfrc_ext = jp.vstack([
      jp.zeros((1, 6)),  # world body
      jax.vmap(_transform_force)(
          d.xfrc_applied[1:], d.subtree_com[m.body_rootid][1:] - d.xipos[1:]
      ),
  ])

  # cfrc_ext += contacts

  # compute contact forces for each condim
  forces = []
  condim_idx = []
  for dim in set(d._impl.contact.dim):
    force, idx = support.contact_force_dim(m, d, dim)
    forces.append(force)
    condim_idx.append(idx)

  # update cfrc_ext with contact forces
  if forces:

    @jax.vmap
    def _contact_force_to_cfrc_ext(force, pos, frame, id1, id2, com1, com2):
      # force: contact to world frame (rotate each 3-vector by the frame)
      force = force.reshape((-1, 3)) @ frame
      force = force.reshape(-1)

      # contact force on bodies
      cfrc_com1 = _transform_force(force, com1 - pos)
      cfrc_com2 = _transform_force(force, com2 - pos)

      # mask: the world body (id 0) receives nothing
      mask1 = id1 != 0
      mask2 = id2 != 0

      # body 1 receives the reaction (negated) force, body 2 the action force
      return jp.vstack([-1 * cfrc_com1 * mask1, cfrc_com2 * mask2]), jp.array(
          [id1, id2]
      )

    condim_idx = jp.concatenate(condim_idx)
    frame = d._impl.contact.frame[condim_idx]
    pos = d._impl.contact.pos[condim_idx]
    id1 = jp.array(m.geom_bodyid)[d._impl.contact.geom[condim_idx, 0]]
    id2 = jp.array(m.geom_bodyid)[d._impl.contact.geom[condim_idx, 1]]
    com1 = d.subtree_com[jp.array(m.body_rootid)][id1]
    com2 = d.subtree_com[jp.array(m.body_rootid)][id2]
    cfrc_contact, cfrc_idx = _contact_force_to_cfrc_ext(
        jp.concatenate(forces), pos, frame, id1, id2, com1, com2
    )

    # scatter-add so several contacts on one body accumulate
    cfrc_ext = cfrc_ext.at[cfrc_idx.reshape(-1)].add(
        cfrc_contact.reshape((-1, 6))
    )

  # cfrc_ext += connect, weld
  cfrc_ext_equality = []
  cfrc_ext_equality_adr = []

  connect_id = m.eq_type == EqType.CONNECT
  nconnect = connect_id.sum()

  if nconnect:
    # NOTE(review): assumes connect constraints occupy the first
    # 3 * nconnect rows of efc_force (3 rows each) — confirm efc ordering
    cfrc_connect_force = d._impl.efc_force[: 3 * nconnect].reshape(
        (nconnect, 3)
    )

    is_site = m.eq_objtype == ObjType.SITE
    body1id = np.copy(m.eq_obj1id)
    body2id = np.copy(m.eq_obj2id)
    # connect: eq_data[:3] is the anchor on body 1, eq_data[3:6] on body 2
    pos1 = m.eq_data[:, :3]
    pos2 = m.eq_data[:, 3:6]

    if m.nsite:
      # site-based equalities: use the site's body and local position
      # NOTE(review): site_pos is gathered with every eq_obj*id, including
      # body ids of non-site equalities; this is only safe if out-of-range
      # gathers are tolerated (JAX clamps, NumPy raises) — confirm type
      body1id[is_site] = m.site_bodyid[m.eq_obj1id[is_site]]
      body2id[is_site] = m.site_bodyid[m.eq_obj2id[is_site]]
      pos1 = jp.where(is_site[:, None], m.site_pos[m.eq_obj1id], pos1)
      pos2 = jp.where(is_site[:, None], m.site_pos[m.eq_obj2id], pos2)

    # body 1
    k1_connect = body1id[connect_id]
    k1_connect_mask = k1_connect != 0
    offset1_connect = pos1[connect_id]
    # anchor point in the global frame
    pos1_connect = jax.vmap(lambda pnt, mat, vec: mat @ pnt + vec)(
        offset1_connect, d.xmat[k1_connect], d.xpos[k1_connect]
    )
    subtree_com1_connect = d.subtree_com[jp.array(m.body_rootid)[k1_connect]]
    cfrc_com1_connect = jax.vmap(
        lambda dif, frc, mask: mask * jp.concatenate([-jp.cross(dif, frc), frc])
    )(subtree_com1_connect - pos1_connect, cfrc_connect_force, k1_connect_mask)

    # body 2 (mask of -1 applies the opposite force; 0 for the world body)
    k2_connect = body2id[connect_id]
    k2_connect_mask = -1 * (k2_connect != 0)
    offset2_connect = pos2[connect_id]
    pos2_connect = jax.vmap(lambda pnt, mat, vec: mat @ pnt + vec)(
        offset2_connect, d.xmat[k2_connect], d.xpos[k2_connect]
    )
    subtree_com2_connect = d.subtree_com[jp.array(m.body_rootid)[k2_connect]]
    cfrc_com2_connect = jax.vmap(
        lambda dif, frc, mask: mask * jp.concatenate([-jp.cross(dif, frc), frc])
    )(subtree_com2_connect - pos2_connect, cfrc_connect_force, k2_connect_mask)

    cfrc_ext_equality.append(jp.vstack([cfrc_com1_connect, cfrc_com2_connect]))
    cfrc_ext_equality_adr.append(jp.concatenate([k1_connect, k2_connect]))

  weld_id = m.eq_type == EqType.WELD
  nweld = weld_id.sum()

  if nweld:
    # NOTE(review): assumes weld rows (6 each: force then torque) directly
    # follow the connect rows in efc_force — confirm efc ordering
    cfrc_weld = d._impl.efc_force[
        3 * nconnect : 3 * nconnect + 6 * nweld
    ].reshape((nweld, 6))
    cfrc_weld_force = cfrc_weld[:, :3]
    cfrc_weld_torque = cfrc_weld[:, 3:]

    is_site = m.eq_objtype == ObjType.SITE
    body1id = np.copy(m.eq_obj1id)
    body2id = np.copy(m.eq_obj2id)
    # weld: the anchor offsets in eq_data are swapped relative to connect
    pos1 = m.eq_data[:, 3:6]
    pos2 = m.eq_data[:, :3]

    if m.nsite:
      body1id[is_site] = m.site_bodyid[m.eq_obj1id[is_site]]
      body2id[is_site] = m.site_bodyid[m.eq_obj2id[is_site]]
      pos1 = jp.where(is_site[:, None], m.site_pos[m.eq_obj1id], pos1)
      pos2 = jp.where(is_site[:, None], m.site_pos[m.eq_obj2id], pos2)

    # body 1
    k1_weld = body1id[weld_id]
    k1_weld_mask = k1_weld != 0
    offset1_weld = pos1[weld_id]
    pos1_weld = jax.vmap(lambda pnt, mat, vec: mat @ pnt + vec)(
        offset1_weld, d.xmat[k1_weld], d.xpos[k1_weld]
    )
    subtree_com1_weld = d.subtree_com[jp.array(m.body_rootid)[k1_weld]]
    cfrc_com1_weld = jax.vmap(
        lambda dif, frc, trq, mask: mask
        * jp.concatenate([trq - jp.cross(dif, frc), frc])
    )(
        subtree_com1_weld - pos1_weld,
        cfrc_weld_force,
        cfrc_weld_torque,
        k1_weld_mask,
    )

    # body 2 (mask of -1 applies the opposite wrench; 0 for the world body)
    k2_weld = body2id[weld_id]
    k2_weld_mask = -1 * (k2_weld != 0)
    offset2_weld = pos2[weld_id]
    pos2_weld = jax.vmap(lambda pnt, mat, vec: mat @ pnt + vec)(
        offset2_weld, d.xmat[k2_weld], d.xpos[k2_weld]
    )
    subtree_com2_weld = d.subtree_com[jp.array(m.body_rootid)[k2_weld]]
    cfrc_com2_weld = jax.vmap(
        lambda dif, frc, trq, mask: mask
        * jp.concatenate([trq - jp.cross(dif, frc), frc])
    )(
        subtree_com2_weld - pos2_weld,
        cfrc_weld_force,
        cfrc_weld_torque,
        k2_weld_mask,
    )

    cfrc_ext_equality.append(jp.vstack([cfrc_com1_weld, cfrc_com2_weld]))
    cfrc_ext_equality_adr.append(jp.concatenate([k1_weld, k2_weld]))

  if nconnect or nweld:
    cfrc_ext = cfrc_ext.at[jp.concatenate(cfrc_ext_equality_adr)].add(
        jp.vstack(cfrc_ext_equality)
    )

  # forward pass over bodies: compute cacc, cfrc_int
  def _forward(carry, cfrc_ext, cinert, cvel, body_dofadr, body_dofnum):
    # carry is None only for a body without a parent; it gets the base
    # acceleration (-gravity or zero) and zero internal force
    if carry is None:
      if m.opt.disableflags & DisableBit.GRAVITY:
        cacc0 = jp.zeros(6)
      else:
        cacc0 = jp.concatenate((jp.zeros(3), -m.opt.gravity))
      return cacc0, jp.zeros(6)
    else:
      cacc_parent, _ = carry

    # create dof mask selecting this body's dofs
    indices = jp.arange(m.nv)
    mask = jp.logical_and(
        indices >= body_dofadr, indices < body_dofadr + body_dofnum
    )

    # cacc = cacc_parent + cdofdot * qvel + cdof * qacc
    cacc_vel = d.cdof_dot.T @ (mask * d.qvel)
    cacc_acc = d.cdof.T @ (mask * d.qacc)
    cacc = cacc_parent + cacc_vel + cacc_acc

    # cfrc_body = cinert * cacc + cvel x (cinert * cvel)
    cfrc_body = math.inert_mul(cinert, cacc)
    cfrc_corr = math.inert_mul(cinert, cvel)
    cfrc = math.motion_cross_force(cvel, cfrc_corr)
    cfrc_body = cfrc_body + cfrc

    cfrc_int = cfrc_body - cfrc_ext

    return cacc, cfrc_int

  cacc, cfrc_int = scan.body_tree(
      m,
      _forward,
      'bbbbb',
      'bb',
      cfrc_ext,
      d._impl.cinert,
      d.cvel,
      jp.array(m.body_dofadr),
      jp.array(m.body_dofnum),
  )

  # backward pass over bodies: accumulate cfrc_int from children
  cfrc_int = scan.body_tree(
      m,
      lambda c, p: p + c if c is not None else p,  # add child to parent
      'b',
      'b',
      cfrc_int,
      reverse=True,
  )

  # update data
  return d.tree_replace({
      '_impl.cacc': cacc,
      '_impl.cfrc_int': cfrc_int,
      '_impl.cfrc_ext': cfrc_ext,
  })
def tendon(m: Model, d: Data) -> Data:
  """Computes tendon lengths and moments.

  Handles fixed (joint) tendons and spatial tendons made of site-site segments
  and site-geom-site sphere/cylinder wraps, with pulley divisors. Also fills
  the per-tendon wrap bookkeeping (ten_wrapadr, ten_wrapnum) and the per-wrap
  points/objects (wrap_xpos, wrap_obj).

  Delegates to the Warp implementation when both model and data use the Warp
  backend and Warp is installed.

  Args:
    m: model
    d: data with kinematics computed

  Returns:
    data with ten_length, ten_J, ten_wrapadr, ten_wrapnum, wrap_xpos and
    wrap_obj updated (unchanged if the model has no tendons)

  Raises:
    ValueError: if the non-Warp path is used without the JAX backend
  """
  if m.impl == Impl.WARP and d.impl == Impl.WARP and mjxw.WARP_INSTALLED:
    from mujoco.mjx.warp import smooth as mjxw_smooth  # pylint: disable=g-import-not-at-top  # pytype: disable=import-error

    return mjxw_smooth.tendon(m, d)

  if not isinstance(m._impl, ModelJAX) or not isinstance(d._impl, DataJAX):
    raise ValueError('tendon requires JAX backend implementation.')

  if not m.ntendon:
    return d

  # process joint tendons: tendons whose first wrap object is a joint
  (wrap_id_jnt,) = np.nonzero(m.wrap_type == WrapType.JOINT)
  (tendon_id_jnt,) = np.nonzero(np.isin(m.tendon_adr, wrap_id_jnt))
  ntendon_jnt = tendon_id_jnt.size
  wrap_objid_jnt = m.wrap_objid[wrap_id_jnt]
  tendon_num_jnt = m.tendon_num[tendon_id_jnt]

  # length = sum over joints of coefficient (wrap_prm) * qpos
  moment_jnt = m.wrap_prm[wrap_id_jnt]
  length_jnt = jax.ops.segment_sum(
      moment_jnt * d.qpos[m.jnt_qposadr[wrap_objid_jnt]],
      np.repeat(np.arange(ntendon_jnt), tendon_num_jnt),
      ntendon_jnt,
  )

  adr_moment_jnt = np.repeat(tendon_id_jnt, tendon_num_jnt)
  dofadr_moment_jnt = m.jnt_dofadr[wrap_objid_jnt]

  # process pulleys
  (wrap_id_pulley,) = np.nonzero(m.wrap_type == WrapType.PULLEY)
  nwrap_pulley = wrap_id_pulley.size

  tendon_wrapnum_pulley = np.array([
      sum((wrap_id_pulley >= adr) & (wrap_id_pulley < adr + num))
      for adr, num in zip(m.tendon_adr, m.tendon_num)
  ])

  # every wrap from a pulley to the end of its tendon is divided by that
  # pulley's divisor (wrap_prm, floored at mjMINVAL)
  divisor = np.ones(m.nwrap)
  for adr, num in zip(m.tendon_adr, m.tendon_num):
    for id_pulley in wrap_id_pulley:
      if adr <= id_pulley < adr + num:
        divisor[id_pulley : adr + num] = np.maximum(
            mujoco.mjMINVAL, m.wrap_prm[id_pulley]
        )

  # process spatial tendon sites
  (wrap_id_site,) = np.nonzero(m.wrap_type == WrapType.SITE)
  nwrap_site = wrap_id_site.size

  # find consecutive sites, skipping tendon transitions (drop pairs whose first
  # site is the last wrap of a tendon)
  (pair_id,) = np.nonzero(np.diff(wrap_id_site) == 1)
  wrap_id_site_pair = np.setdiff1d(wrap_id_site[pair_id], m.tendon_adr[1:] - 1)
  wrap_objid_site0 = m.wrap_objid[wrap_id_site_pair]
  wrap_objid_site1 = m.wrap_objid[wrap_id_site_pair + 1]

  # length of segment pnt0 -> pnt1 and its moment (difference of translational
  # jacobians projected on the segment direction); zero moment when both
  # points are on the same body
  @jax.vmap
  def _length_moment(pnt0, pnt1, body0, body1):
    dif = pnt1 - pnt0
    length = math.norm(dif)
    vec = jp.where(
        length < mujoco.mjMINVAL,
        jp.array([1.0, 0.0, 0.0]),
        math.safe_div(dif, length),
    )
    jacp1, _ = support.jac(m, d, pnt0, body0)
    jacp2, _ = support.jac(m, d, pnt1, body1)
    jacdif = jacp2 - jacp1
    moment = jp.where(body0 != body1, jacdif @ vec, jp.zeros(m.nv))
    return length, moment

  lengths_site, moments_site = _length_moment(
      d.site_xpos[wrap_objid_site0],
      d.site_xpos[wrap_objid_site1],
      m.site_bodyid[wrap_objid_site0],
      m.site_bodyid[wrap_objid_site1],
  )

  if wrap_id_site_pair.size:
    divisor_site_pair = divisor[wrap_id_site_pair]
    lengths_site /= divisor_site_pair
    moments_site /= divisor_site_pair[:, None]

  tendon_nsite = np.array([
      sum((wrap_id_site_pair >= adr) & (wrap_id_site_pair < adr + num))
      for adr, num in zip(m.tendon_adr, m.tendon_num)
  ])

  tendon_wrapnum_site = np.array([
      sum((wrap_id_site >= adr) & (wrap_id_site < adr + num))
      for adr, num in zip(m.tendon_adr, m.tendon_num)
  ])

  tendon_has_site = tendon_nsite > 0
  (tendon_id_site,) = np.nonzero(tendon_has_site)
  tendon_nsite = tendon_nsite[tendon_has_site]
  tendon_with_site = tendon_nsite.size

  # sum the site-pair segments belonging to each tendon
  ten_site_id = np.repeat(np.arange(tendon_with_site), tendon_nsite)
  length_site = jax.ops.segment_sum(lengths_site, ten_site_id, tendon_with_site)
  moment_site = jax.ops.segment_sum(moments_site, ten_site_id, tendon_with_site)

  # process spatial sphere/cylinder wrap
  (wrap_id_geom,) = np.nonzero(
      (m.wrap_type == WrapType.SPHERE) | (m.wrap_type == WrapType.CYLINDER)
  )

  # get objid for site-geom-site instances (assumes every wrapping geom is
  # directly preceded and followed by a site wrap)
  wrap_id_sitegeomsite = wrap_id_geom[:, None] + np.array([-1, 0, 1])[None]
  wrap_objid_site0, wrap_objid_geom, wrap_objid_site1 = m.wrap_objid[
      wrap_id_sitegeomsite
  ].T

  # get site positions before and after geom
  site_pnt0 = d.site_xpos[wrap_objid_site0]
  site_pnt1 = d.site_xpos[wrap_objid_site1]

  # get geom information
  geom_xpos = d.geom_xpos[wrap_objid_geom]
  geom_xmat = d.geom_xmat[wrap_objid_geom]
  geom_size = m.geom_size[wrap_objid_geom, 0]
  geom_type = m.wrap_type[wrap_id_geom]
  is_sphere = geom_type == WrapType.SPHERE

  # get body ids for site-geom-site instances
  body_id_site0 = m.site_bodyid[wrap_objid_site0]
  body_id_geom = m.geom_bodyid[wrap_objid_geom]
  body_id_site1 = m.site_bodyid[wrap_objid_site1]

  # find wrap object sidesites (if any exist); side_id < 0 means no sidesite,
  # in which case `side` holds an arbitrary site position that is presumably
  # ignored by support.wrap via has_sidesite — confirm
  side_id = np.round(m.wrap_prm[wrap_id_geom]).astype(int)
  side = d.site_xpos[side_id]
  has_sidesite = np.expand_dims(np.array(side_id >= 0), -1)

  # wrap inside
  # TODO(taylorhowell): check that is_wrap_inside is consistent with
  # site and geom relative positions
  (wrap_inside_id,) = np.nonzero(m._impl.is_wrap_inside)
  (wrap_outside_id,) = np.nonzero(~m._impl.is_wrap_inside)

  # compute geom wrap length and connect points (if wrap occurs)
  v_wrap = jax.vmap(
      support.wrap, in_axes=(0, 0, 0, 0, 0, 0, 0, 0, None, None, None, None)
  )
  lengths_inside, pnt0_inside, pnt1_inside = v_wrap(
      site_pnt0[wrap_inside_id],
      site_pnt1[wrap_inside_id],
      geom_xpos[wrap_inside_id],
      geom_xmat[wrap_inside_id],
      geom_size[wrap_inside_id],
      side[wrap_inside_id],
      has_sidesite[wrap_inside_id],
      is_sphere[wrap_inside_id],
      True,
      m._impl.wrap_inside_maxiter,
      m._impl.wrap_inside_tolerance,
      m._impl.wrap_inside_z_init,
  )

  lengths_outside, pnt0_outside, pnt1_outside = v_wrap(
      site_pnt0[wrap_outside_id],
      site_pnt1[wrap_outside_id],
      geom_xpos[wrap_outside_id],
      geom_xmat[wrap_outside_id],
      geom_size[wrap_outside_id],
      side[wrap_outside_id],
      has_sidesite[wrap_outside_id],
      is_sphere[wrap_outside_id],
      False,
      m._impl.wrap_inside_maxiter,
      m._impl.wrap_inside_tolerance,
      m._impl.wrap_inside_z_init,
  )

  # restore the original geom order after splitting into inside/outside
  wrap_id = np.argsort(np.concatenate([wrap_inside_id, wrap_outside_id]))

  vstack_ = lambda x, y: jp.vstack([x, y])[wrap_id]

  lengths_geomgeom = vstack_(lengths_inside, lengths_outside)
  geom_pnt0 = vstack_(pnt0_inside, pnt0_outside)
  geom_pnt1 = vstack_(pnt1_inside, pnt1_outside)
  lengths_geomgeom = lengths_geomgeom.reshape(-1)

  # identify geoms where wrap does not occur (signalled by a negative length)
  no_geom_wrap = lengths_geomgeom < 0
  wrap_objid_geom_skip = jp.where(no_geom_wrap, 0, wrap_objid_geom)

  # compute lengths for site-site (no wrap), site-geom, and geom-site segments
  def _distance(p0, p1):
    return jax.vmap(lambda x, y: math.norm(x - y))(p0, p1)

  lengths_sitesite = _distance(site_pnt0, site_pnt1)
  lengths_sitegeom = _distance(site_pnt0, geom_pnt0)
  lengths_geomsite = _distance(geom_pnt1, site_pnt1)

  # select length segments according to geom wrap
  lengths_geom = jp.where(
      no_geom_wrap,
      lengths_sitesite,
      lengths_sitegeom + lengths_geomgeom + lengths_geomsite,
  )

  # compute moments for site-site (no wrap), site-geom, geom-geom, and geom-site
  # segments
  _, moments_sitesite = _length_moment(
      site_pnt0, site_pnt1, body_id_site0, body_id_site1
  )
  _, moments_sitegeom = _length_moment(
      site_pnt0, geom_pnt0, body_id_site0, body_id_geom
  )
  _, moments_geomgeom = _length_moment(
      geom_pnt0, geom_pnt1, body_id_geom, body_id_geom
  )
  _, moments_geomsite = _length_moment(
      geom_pnt1, site_pnt1, body_id_geom, body_id_site1
  )

  # select moment segments according to geom wrap
  moments_geom = jp.where(
      no_geom_wrap[:, None],
      moments_sitesite,
      moments_sitegeom + moments_geomgeom + moments_geomsite,
  )

  if wrap_id_geom.size:
    divisor_geom = divisor[wrap_id_geom]
    lengths_geom /= divisor_geom
    moments_geom /= divisor_geom[:, None]

  # construct number of site-geom-site instances per tendon
  tendon_ngeom = np.array([
      sum((wrap_id_geom >= adr) & (wrap_id_geom < adr + num))
      for adr, num in zip(m.tendon_adr, m.tendon_num)
  ])
  tendon_has_geom = tendon_ngeom > 0
  tendon_ngeom = tendon_ngeom[tendon_has_geom]

  # identify tendons with at least one site-geom-site instance
  (tendon_id_geom,) = np.nonzero(tendon_has_geom)

  # combine site-geom-site segment lengths and moments for each tendon
  tendon_with_geom = tendon_ngeom.size
  ten_geom_id = np.repeat(np.arange(tendon_with_geom), tendon_ngeom)
  length_geom = jax.ops.segment_sum(lengths_geom, ten_geom_id, tendon_with_geom)
  moment_geom = jax.ops.segment_sum(moments_geom, ten_geom_id, tendon_with_geom)

  # calculate number of wrap objects per tendon, based on geom wrap
  # (a wrapping geom contributes two points, a non-wrapping geom none)
  wrapnums_geom = jp.where(no_geom_wrap, 0, 2)
  tendon_wrapnum_geom = jax.ops.segment_sum(
      wrapnums_geom, ten_geom_id, tendon_with_geom
  )

  # assemble length and moment
  ten_length = (
      jp.zeros_like(d.ten_length).at[tendon_id_jnt].set(length_jnt)
  )
  ten_length = ten_length.at[tendon_id_site].add(length_site)
  ten_length = ten_length.at[tendon_id_geom].add(length_geom)

  ten_moment = (
      jp.zeros_like(d._impl.ten_J)
      .at[adr_moment_jnt, dofadr_moment_jnt]
      .set(moment_jnt)
  )
  ten_moment = ten_moment.at[tendon_id_site].add(moment_site)
  ten_moment = ten_moment.at[tendon_id_geom].add(moment_geom)

  # construct wrap addresses: one slot per pulley/site, two per geom
  # (joint wraps get no slot)
  wrap_adr_pulley = []
  wrap_adr_site = []
  wrap_adr_geom = []
  count = 0
  for wrap_type in m.wrap_type:
    if wrap_type == WrapType.PULLEY:
      wrap_adr_pulley.append(count)
      count += 1
    elif wrap_type == WrapType.SITE:
      wrap_adr_site.append(count)
      count += 1
    elif wrap_type in (WrapType.SPHERE, WrapType.CYLINDER):
      wrap_adr_geom.append(count)
      wrap_adr_geom.append(count + 1)
      count += 2

  wrap_adr_pulley = np.array(wrap_adr_pulley).astype(int)
  wrap_adr_site = np.array(wrap_adr_site).astype(int)
  wrap_adr_geom = np.array(wrap_adr_geom).astype(int)
  wrap_adr = np.concatenate([wrap_adr_pulley, wrap_adr_site, wrap_adr_geom])

  ten_wrapnum = jp.array(tendon_wrapnum_pulley + tendon_wrapnum_site)
  ten_wrapnum = ten_wrapnum.at[tendon_id_geom].add(tendon_wrapnum_geom)
  # exclusive prefix sum: start address of each tendon's wrap points
  ten_wrapadr = jp.concatenate([jp.array([0]), jp.cumsum(ten_wrapnum)[:-1]])

  xpos_site = d.site_xpos[m.wrap_objid[wrap_id_site]]
  xpos_geom = jp.hstack([geom_pnt0, geom_pnt1]).reshape((-1, 3))

  # sort objects, moving no wrap geoms to bottom rows
  # NOTE(review): relies on jp.argsort being stable so the remaining rows keep
  # their relative order — confirm for the JAX version in use
  wrap_adr_sort = np.argsort(wrap_adr)
  skipped = (
      jp.zeros(count, dtype=bool)
      .at[wrap_adr_geom]
      .set(jp.repeat(no_geom_wrap, 2).reshape(-1))
  )
  sort = jp.argsort(skipped)

  wrap_xpos = jp.concatenate(
      [jp.zeros((nwrap_pulley, 3)), xpos_site, xpos_geom]
  )[wrap_adr_sort]
  wrap_xpos = jp.concatenate(
      [wrap_xpos[sort], jp.zeros((2 * m.nwrap - count, 3))]
  ).reshape((m.nwrap, 6))

  # object codes: -2 for pulleys, -1 for sites, geom id for wrapping geoms
  # (0 where the geom does not wrap)
  wrap_obj = jp.concatenate([
      -2 * jp.ones(nwrap_pulley, dtype=int),
      -1 * jp.ones(nwrap_site, dtype=int),
      jp.repeat(wrap_objid_geom_skip, 2).reshape(-1),
  ])[wrap_adr_sort]
  wrap_obj = jp.concatenate(
      [wrap_obj[sort], jp.zeros(2 * m.nwrap - count, dtype=int)]
  ).reshape((m.nwrap, 2))

  return d.tree_replace({
      'ten_length': ten_length,
      '_impl.ten_J': ten_moment,
      '_impl.ten_wrapadr': jp.array(ten_wrapadr, dtype=int),
      '_impl.ten_wrapnum': jp.array(ten_wrapnum, dtype=int),
      '_impl.wrap_xpos': wrap_xpos,
      '_impl.wrap_obj': jp.array(wrap_obj, dtype=int),
  })
def _site_dof_mask(m: Model) -> np.ndarray:
  """Builds the per-actuator dof mask used by site transmissions.

  Every entry starts at 1.0. For each actuator whose second transmission id
  is set (``actuator_trnid[:, 1] != -1``), both sites are mapped to the last
  dof of the body they are welded to. The two dof chains are then walked up
  via ``dof_parentid`` until they meet or one of them reaches the root (-1).
  If they meet, that shared dof and all of its ancestors are set to 0.0 in
  the actuator's row.

  Args:
    m: model providing actuator, site, body and dof topology arrays.

  Returns:
    Array of shape (m.nu, m.nv) whose entries are 0.0 or 1.0.
  """
  mask = np.ones((m.nu, m.nv))
  parent = m.dof_parentid

  def last_dof(site):
    # last dof address of the body this site is welded to
    body = m.body_weldid[m.site_bodyid[site]]
    return m.body_dofadr[body] + m.body_dofnum[body] - 1

  (actuators,) = np.nonzero(m.actuator_trnid[:, 1] != -1)
  for act in actuators:
    site, refsite = m.actuator_trnid[act]
    a = last_dof(site)
    b = last_dof(refsite)

    # step the chain with the larger dof address toward the root until the
    # two chains coincide, or stop once either reaches -1
    while a != b:
      if a < b:
        b = parent[b]
      else:
        a = parent[a]
      if a == -1 or b == -1:
        break

    if a != b:
      # no common ancestral dof: leave this row fully unmasked
      continue

    # clear the shared dof and its whole parental chain
    while a >= 0:
      mask[act, a] = 0.0
      a = parent[a]

  return mask
def transmission(m: Model, d: Data) -> Data:
  """Computes actuator/transmission lengths and moments.

  Handles joint, joint-in-parent, site and tendon transmissions and writes
  ``actuator_length`` (reshaped to (nu,)) and ``_impl.actuator_moment``
  (reshaped to (nu, nv)).

  Args:
    m: model; must be backed by the JAX implementation.
    d: data; must be backed by the JAX implementation. Reads qpos, xquat,
      site_xpos, site_xmat, ten_length and ``_impl.ten_J``.

  Returns:
    ``d`` with the actuator length and moment fields replaced, or ``d``
    unchanged when the model has no actuators.

  Raises:
    ValueError: if ``m`` or ``d`` is not backed by the JAX implementation.
    RuntimeError: for an unrecognized transmission or joint type.
  """
  if not (isinstance(m._impl, ModelJAX) and isinstance(d._impl, DataJAX)):
    raise ValueError('transmission requires JAX backend implementation.')

  if not m.nu:
    return d

  def _joint_transmission(trntype, gear, jnt_typ, dofadr, qpos):
    # returns (length, moment entries, dof indices receiving those entries)
    in_parent = trntype == TrnType.JOINTINPARENT
    if jnt_typ == JointType.FREE:
      moment = gear
      if in_parent:
        rotated = math.rotate(gear[3:], math.quat_inv(qpos[3:]))
        moment = moment.at[3:].set(rotated)
      return jp.zeros(1), moment, dofadr + jp.arange(6)
    if jnt_typ == JointType.BALL:
      axis, angle = math.quat_to_axis_angle(qpos)
      if in_parent:
        gearaxis = math.rotate(gear[:3], math.quat_inv(qpos))
      else:
        gearaxis = gear[:3]
      length = jp.dot(axis * angle, gearaxis)[None]
      return length, gearaxis, dofadr + jp.arange(3)
    if jnt_typ in (JointType.SLIDE, JointType.HINGE):
      return qpos * gear[0], gear[:1], dofadr[None]
    raise RuntimeError(f'unrecognized joint type: {JointType(jnt_typ)}')

  def _site_transmission(trnid, gear, has_ref, dof_mask, xpos, xmat, quat):
    # row 0 of xpos/xmat/quat belongs to trnid[0] and row 1 to trnid[1]
    # (the reference site), presumably as gathered by scan.flat
    body, refbody = jp.array(m.site_bodyid)[trnid]
    jacp, jacr = support.jac(m, d, xpos[0], body)
    length = jp.zeros(1)
    frame = xmat[0]
    if has_ref:
      # translation expressed in the reference frame, plus rotation difference
      offset = xmat[1].T @ (xpos[0] - xpos[1])
      rot = math.quat_sub(quat[0], quat[1])
      length += jp.dot(jp.concatenate([offset, rot]), gear)
      refp, refr = support.jac(m, d, xpos[1], refbody)
      jacp, jacr = jacp - refp, jacr - refr
      frame = xmat[1]
    # drop dofs shared by both sites' chains (mask built by _site_dof_mask)
    jac = jp.concatenate((jacp, jacr), axis=1) * dof_mask[:, None]
    wrench = jp.concatenate((frame @ gear[:3], frame @ gear[3:]))
    return length, jac @ wrench

  def fn(
      trntype,
      trnid,
      gear,
      jnt_typ,
      m_j,
      qpos,
      has_refsite,
      site_dof_mask,
      site_xpos,
      site_xmat,
      site_quat,
  ):
    # trntype is concrete per actuator group, so Python branching is static
    if trntype in (TrnType.JOINT, TrnType.JOINTINPARENT):
      length, entries, dofs = _joint_transmission(
          trntype, gear, jnt_typ, m_j, qpos
      )
      return length, jp.zeros((m.nv,)).at[dofs].set(entries)
    if trntype == TrnType.SITE:
      return _site_transmission(
          trnid,
          gear,
          has_refsite,
          site_dof_mask,
          site_xpos,
          site_xmat,
          site_quat,
      )
    if trntype == TrnType.TENDON:
      tendon = trnid[0]
      return (
          d.ten_length[tendon] * gear[:1],
          d._impl.ten_J[tendon] * gear[0],
      )
    raise RuntimeError(f'unrecognized trntype: {TrnType(trntype)}')

  # per-actuator inputs needed only by site transmissions
  has_refsite = m.actuator_trnid[:, 1] != -1
  site_dof_mask = _site_dof_mask(m)
  # NOTE(review): the site orientation is composed as site_quat * xquat[body]
  # here, whereas world-frame site orientations are usually
  # xquat[body] * site_quat. Confirm this order matches the C engine's
  # refsite transmission before relying on or changing it.
  site_quat = jax.vmap(math.quat_mul)(m.site_quat, d.xquat[m.site_bodyid])

  length, moment = scan.flat(
      m,
      fn,
      'uuujjquusss',
      'uu',
      m.actuator_trntype,
      jp.array(m.actuator_trnid),
      m.actuator_gear,
      m.jnt_type,
      jp.array(m.jnt_dofadr),
      d.qpos,
      has_refsite,
      jp.array(site_dof_mask),
      d.site_xpos,
      d.site_xmat,
      site_quat,
      group_by='u',
  )

  return d.tree_replace({
      'actuator_length': length.reshape((m.nu,)),
      '_impl.actuator_moment': moment.reshape((m.nu, m.nv)),
  })
def tendon_armature(m: Model, d: Data) -> Data:
  """Add tendon armature to qM.

  Forms J^T diag(tendon_armature) J with J = ``d._impl.ten_J`` and adds it
  to ``d._impl.qM``. For sparse models the dense product is first gathered
  at the (dof, ancestor dof) index pairs.

  Args:
    m: model; must be backed by the JAX implementation.
    d: data; must be backed by the JAX implementation.

  Returns:
    ``d`` with ``_impl.qM`` updated, or ``d`` unchanged when the model has no
    tendons.

  Raises:
    ValueError: if ``m`` or ``d`` is not backed by the JAX implementation.
  """
  if not (isinstance(m._impl, ModelJAX) and isinstance(d._impl, DataJAX)):
    raise ValueError('tendon_armature requires JAX backend implementation.')

  if not m.ntendon:
    return d

  ten_j = d._impl.ten_J
  # scale each tendon row by its armature, then contract: J^T A J
  # TODO(taylorhowell): if sparse, compute sparse JTAJ
  weighted = jax.vmap(jp.multiply)(ten_j, m.tendon_armature)
  jtaj = ten_j.T @ weighted

  if support.is_sparse(m):
    # gather order: for each dof, the diagonal first, then every ancestor dof
    # up the dof_parentid chain (presumably the sparse qM layout)
    pairs = []
    for row in range(m.nv):
      col = row
      while col > -1:
        pairs.append((row, col))
        col = m.dof_parentid[col]
    rows, cols = zip(*pairs)
    jtaj = jtaj[(jp.array(rows), jp.array(cols))]

  return d.tree_replace({'_impl.qM': d._impl.qM + jtaj})
def tendon_dot(m: Model, d: Data) -> jax.Array:
  """Compute time derivative of dense tendon Jacobian for all tendons.

  Row ``t`` of the result is d/dt of tendon ``t``'s Jacobian row. The only
  contributions come from segments between consecutive SITE wraps of the same
  tendon. Geom wrapping is not handled yet (see the TODO at the end). Rows of
  tendons that have no such site pair stay zero.

  Args:
    m: model; must be backed by the JAX implementation.
    d: data; must be backed by the JAX implementation. Reads site_xpos,
      subtree_com and cvel, and evaluates ``support.jac`` and
      ``support.jac_dot`` at the site points.

  Returns:
    Array of shape (m.ntendon, m.nv).

  Raises:
    ValueError: if ``m`` or ``d`` is not backed by the JAX implementation.
  """
  if not isinstance(m._impl, ModelJAX) or not isinstance(d._impl, DataJAX):
    raise ValueError('tendon_dot requires JAX backend implementation.')

  ten_Jdot = jp.zeros((m.ntendon, m.nv))  # pylint: disable=invalid-name

  if not m.ntendon:
    return ten_Jdot

  # process pulleys
  # Inside each tendon, every wrap from a pulley onward gets the divisor
  # max(mjMINVAL, wrap_prm[pulley]). Pulleys are visited in increasing order,
  # so a later pulley overwrites an earlier one for the wraps that follow it.
  (wrap_id_pulley,) = np.nonzero(m.wrap_type == WrapType.PULLEY)
  divisor = np.ones(m.nwrap)
  for adr, num in zip(m.tendon_adr, m.tendon_num):
    for id_pulley in wrap_id_pulley:
      if adr <= id_pulley < adr + num:
        divisor[id_pulley : adr + num] = np.maximum(
            mujoco.mjMINVAL, m.wrap_prm[id_pulley]
        )

  # process spatial tendon sites
  (wrap_id_site,) = np.nonzero(m.wrap_type == WrapType.SITE)

  # find consecutive sites, skipping tendon transitions
  # Each pair is (wrap i, wrap i + 1), both SITE wraps. The pair is dropped
  # when i is the last wrap of a tendon (tendon_adr[k + 1] - 1).
  (pair_id,) = np.nonzero(np.diff(wrap_id_site) == 1)
  wrap_id_site_pair = np.setdiff1d(wrap_id_site[pair_id], m.tendon_adr[1:] - 1)
  wrap_objid_site0 = m.wrap_objid[wrap_id_site_pair]
  wrap_objid_site1 = m.wrap_objid[wrap_id_site_pair + 1]
  site_bodyid0 = m.site_bodyid[wrap_objid_site0]
  site_bodyid1 = m.site_bodyid[wrap_objid_site1]
  site_xpos0 = d.site_xpos[wrap_objid_site0]
  site_xpos1 = d.site_xpos[wrap_objid_site1]
  subtree_com0 = d.subtree_com[m.body_rootid[site_bodyid0]]
  subtree_com1 = d.subtree_com[m.body_rootid[site_bodyid1]]
  # Point velocity: cvel[3:] - (site_xpos - root subtree_com) x cvel[:3].
  # This is presumably the world-frame linear velocity of each site point.
  site_vel0 = jax.vmap(lambda a, b: a[3:] - jp.cross(b, a[:3]))(
      d.cvel[site_bodyid0], site_xpos0 - subtree_com0
  )
  site_vel1 = jax.vmap(lambda a, b: a[3:] - jp.cross(b, a[:3]))(
      d.cvel[site_bodyid1], site_xpos1 - subtree_com1
  )

  @jax.vmap
  def _momentdot(wpnt0, wpnt1, wvel0, wvel1, body0, body1):
    # Returns one site segment's contribution to a Jdot row, shape (m.nv,).
    # dpnt = 3D position difference, normalize
    # (a degenerate segment falls back to the x axis)
    dpnt = wpnt1 - wpnt0
    norm = math.norm(dpnt)
    dpnt = jp.where(
        norm < mujoco.mjMINVAL,
        jp.array([1.0, 0.0, 0.0]),
        math.safe_div(dpnt, norm),
    )

    # dvel = d / dt(dpnt)
    # (remove the component along dpnt, then divide by the segment length)
    dvel = wvel1 - wvel0
    dot = jp.dot(dpnt, dvel)
    dvel += dpnt * -dot
    dvel = jp.where(norm > mujoco.mjMINVAL, math.safe_div(dvel, norm), 0.0)

    # get endpoint JacobianDots, subtract
    jacp1, _ = support.jac_dot(m, d, wpnt0, body0)
    jacp2, _ = support.jac_dot(m, d, wpnt1, body1)
    jacdif = jacp2 - jacp1

    # chain rule, first term: Jdot += d / dt(jac2 - jac1) * dpnt
    tmp0 = jacdif @ dpnt

    # get endpoint Jacobians, subtract
    jacp1, _ = support.jac(m, d, wpnt0, body0)
    jacp2, _ = support.jac(m, d, wpnt1, body1)
    jacdif = jacp2 - jacp1

    # chain rule, second term: Jdot += (jac2 - jac1) * d/dt (dpnt)
    tmp1 = jacdif @ dvel

    # a segment whose two sites lie on the same body contributes nothing
    return jp.where(body0 != body1, tmp0 + tmp1, jp.zeros(m.nv))

  momentdots = _momentdot(
      site_xpos0,
      site_xpos1,
      site_vel0,
      site_vel1,
      site_bodyid0,
      site_bodyid1,
  )

  if wrap_id_site_pair.size:
    # apply the pulley divisor in effect at each pair's first site
    divisor_site_pair = divisor[wrap_id_site_pair]
    momentdots /= divisor_site_pair[:, None]
    # count site pairs per tendon so that the per-pair rows can be summed
    # into their owning tendon with segment_sum
    tendon_nsite = np.array([
        sum((wrap_id_site_pair >= adr) & (wrap_id_site_pair < adr + num))
        for adr, num in zip(m.tendon_adr, m.tendon_num)
    ])
    tendon_has_site = tendon_nsite > 0
    (tendon_id_site,) = np.nonzero(tendon_has_site)
    tendon_nsite = tendon_nsite[tendon_has_site]
    tendon_with_site = tendon_nsite.size
    ten_site_id = np.repeat(np.arange(tendon_with_site), tendon_nsite)
    momentdot = jax.ops.segment_sum(momentdots, ten_site_id, tendon_with_site)
    ten_Jdot = ten_Jdot.at[tendon_id_site].set(momentdot)  # pylint: disable=invalid-name

  # TODO(taylorhowell): time derivatives for geoms

  return ten_Jdot
def tendon_bias(m: Model, d: Data) -> Data:
  """Add bias force due to tendon armature.

  For every tendon t, adds ten_J[t] * armature[t] * dot(ten_Jdot[t], qvel)
  to ``qfrc_bias``. Here ``ten_Jdot`` is the dense Jacobian time derivative
  returned by ``tendon_dot``.

  Args:
    m: model; must be backed by the JAX implementation.
    d: data; must be backed by the JAX implementation.

  Returns:
    ``d`` with ``qfrc_bias`` updated, or ``d`` unchanged when the model has
    no tendons.

  Raises:
    ValueError: if ``m`` or ``d`` is not backed by the JAX implementation.
  """
  if not (isinstance(m._impl, ModelJAX) and isinstance(d._impl, DataJAX)):
    raise ValueError('tendon_bias requires JAX backend implementation.')

  if not m.ntendon:
    return d

  # dense d/dt(tendon Jacobian), one row per tendon
  jdot = tendon_dot(m, d)

  # per-tendon scalar: armature * (Jdot @ qvel)
  scale = m.tendon_armature * jp.dot(jdot, d.qvel)

  # scale each tendon's Jacobian row, then sum the rows over tendons
  scaled_rows = jax.vmap(jp.multiply)(d._impl.ten_J, scale)
  bias = jp.sum(scaled_rows, axis=0)

  return d.tree_replace({'qfrc_bias': d.qfrc_bias + bias})