Source code for mujoco.mjx._src.support

# Copyright 2023 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Engine support functions."""

from collections.abc import Iterable, Sequence
from typing import Optional, Tuple, Union

import jax
from jax import numpy as jp
import mujoco
from mujoco.introspect import mjxmacro
from mujoco.mjx._src import math
from mujoco.mjx._src import scan
# pylint: disable=g-importing-member
from mujoco.mjx._src.types import ConeType
from mujoco.mjx._src.types import Data
from mujoco.mjx._src.types import JacobianType
from mujoco.mjx._src.types import JointType
from mujoco.mjx._src.types import Model
# pylint: enable=g-importing-member
import numpy as np


[docs] def is_sparse(m: Union[mujoco.MjModel, Model]) -> bool: """Return True if this model should create sparse mass matrices. Args: m: a MuJoCo or MJX model Returns: True if provided model should create sparse mass matrices Modern TPUs have specialized hardware for rapidly operating over sparse matrices, whereas GPUs tend to be faster with dense matrices as long as they fit onto the device. As such, the default behavior in MJX (via ``JacobianType.AUTO``) is sparse if ``nv`` is >= 60 or MJX detects a TPU as the default backend, otherwise dense. """ # AUTO is a rough heuristic - you may see better performance for your workload # and compute by explicitly setting jacobian to dense or sparse if m.opt.jacobian == JacobianType.AUTO: return m.nv >= 60 or jax.default_backend() == 'tpu' return m.opt.jacobian == JacobianType.SPARSE
def make_m( m: Model, a: jax.Array, b: jax.Array, d: Optional[jax.Array] = None ) -> jax.Array: """Computes M = a @ b.T + diag(d).""" ij = [] for i in range(m.nv): j = i while j > -1: ij.append((i, j)) j = m.dof_parentid[j] i, j = (jp.array(x) for x in zip(*ij)) if not is_sparse(m): qm = a @ b.T if d is not None: qm += jp.diag(d) mask = jp.zeros((m.nv, m.nv), dtype=bool).at[(i, j)].set(True) qm = qm * mask qm = qm + jp.tril(qm, -1).T return qm a_i = jp.take(a, i, axis=0) b_j = jp.take(b, j, axis=0) qm = jax.vmap(jp.dot)(a_i, b_j) # add diagonal if d is not None: qm = qm.at[m.dof_Madr].add(d) return qm
[docs] def full_m(m: Model, d: Data) -> jax.Array: """Reconstitute dense mass matrix from qM.""" if not is_sparse(m): return d._impl.qM # pytype: disable=attribute-error ij = [] for i in range(m.nv): j = i while j > -1: ij.append((i, j)) j = m.dof_parentid[j] i, j = (jp.array(x) for x in zip(*ij)) mat = jp.zeros((m.nv, m.nv)).at[(i, j)].set(d._impl.qM) # pytype: disable=attribute-error # also set upper triangular mat = mat + jp.tril(mat, -1).T return mat
[docs] def mul_m(m: Model, d: Data, vec: jax.Array) -> jax.Array: """Multiply vector by inertia matrix.""" if not is_sparse(m): return d._impl.qM @ vec # pytype: disable=attribute-error diag_mul = d._impl.qM[jp.array(m.dof_Madr)] * vec # pytype: disable=attribute-error is_, js, madr_ijs = [], [], [] for i in range(m.nv): madr_ij, j = m.dof_Madr[i], i while True: madr_ij, j = madr_ij + 1, m.dof_parentid[j] if j == -1: break is_, js, madr_ijs = is_ + [i], js + [j], madr_ijs + [madr_ij] i, j, madr_ij = (jp.array(x, dtype=jp.int32) for x in (is_, js, madr_ijs)) out = diag_mul.at[i].add(d._impl.qM[madr_ij] * vec[j]) # pytype: disable=attribute-error out = out.at[j].add(d._impl.qM[madr_ij] * vec[i]) # pytype: disable=attribute-error return out
[docs] def jac( m: Model, d: Data, point: jax.Array, body_id: jax.Array ) -> Tuple[jax.Array, jax.Array]: """Compute pair of (NV, 3) Jacobians of global point attached to body.""" # TODO(taylorhowell): statically construct mask fn = lambda carry, b: b if carry is None else b + carry mask = (jp.arange(m.nbody) == body_id) * 1 mask = scan.body_tree(m, fn, 'b', 'b', mask, reverse=True) mask = mask[jp.array(m.dof_bodyid)] > 0 offset = point - d.subtree_com[jp.array(m.body_rootid)[body_id]] jacp = jax.vmap(lambda a, b=offset: a[3:] + jp.cross(a[:3], b))(d.cdof) # pytype: disable=attribute-error jacp = jax.vmap(jp.multiply)(jacp, mask) jacr = jax.vmap(jp.multiply)(d.cdof[:, :3], mask) # pytype: disable=attribute-error return jacp, jacr
def jac_dot( m: Model, d: Data, point: jax.Array, body_id: jax.Array ) -> Tuple[jax.Array, jax.Array]: """Compute pair of (NV, 3) Jacobian time derivatives of global point attached to body.""" # TODO(taylorhowell): statically construct mask fn = lambda carry, b: b if carry is None else b + carry mask = (jp.arange(m.nbody) == body_id) * 1 mask = scan.body_tree(m, fn, 'b', 'b', mask, reverse=True) mask = mask[jp.array(m.dof_bodyid)] > 0 offset = point - d.subtree_com[jp.array(m.body_rootid)[body_id]] pvel_lin = d.cvel[body_id][3:] - jp.cross(offset, d.cvel[body_id][:3]) cdof = d.cdof cdof_dot = d.cdof_dot # check for quaternion jnt_type = m.jnt_type[m.dof_jntid] dof_adr = m.jnt_dofadr[m.dof_jntid] is_quat = (jnt_type == JointType.BALL) | ( jnt_type == JointType.FREE & (np.arange(m.nv) >= dof_adr + 3) ) # compute cdof_dot for quaternion (use current body cvel) cdof_dot_quat = jax.vmap(math.motion_cross)(d.cvel[m.dof_bodyid], cdof) cdof_dot = jp.where(is_quat[:, None], cdof_dot_quat, cdof_dot) jacp = jax.vmap( lambda a, b: a[3:] + jp.cross(a[:3], offset) + jp.cross(b[:3], pvel_lin) )(cdof_dot, cdof) jacp = jax.vmap(jp.multiply)(jacp, mask) jacr = jax.vmap(jp.multiply)(cdof_dot[:, :3], mask) # pytype: disable=attribute-error return jacp, jacr
[docs] def apply_ft( m: Model, d: Data, force: jax.Array, torque: jax.Array, point: jax.Array, body_id: jax.Array, ) -> jax.Array: """Apply Cartesian force and torque.""" jacp, jacr = jac(m, d, point, body_id) return jacp @ force + jacr @ torque
[docs] def xfrc_accumulate(m: Model, d: Data) -> jax.Array: """Accumulate xfrc_applied into a qfrc.""" qfrc = jax.vmap(apply_ft, in_axes=(None, None, 0, 0, 0, 0))( m, d, d.xfrc_applied[:, :3], d.xfrc_applied[:, 3:], d.xipos, jp.arange(m.nbody), ) return jp.sum(qfrc, axis=0)
def local_to_global( world_pos: jax.Array, world_quat: jax.Array, local_pos: jax.Array, local_quat: jax.Array, ) -> Tuple[jax.Array, jax.Array]: """Converts local position/orientation to world frame.""" pos = world_pos + math.rotate(local_pos, world_quat) mat = math.quat_to_mat(math.quat_mul(world_quat, local_quat)) return pos, mat def _getnum(m: Union[Model, mujoco.MjModel], obj: mujoco._enums.mjtObj) -> int: """Gets the number of objects for the given object type.""" return { mujoco.mjtObj.mjOBJ_BODY: m.nbody, mujoco.mjtObj.mjOBJ_JOINT: m.njnt, mujoco.mjtObj.mjOBJ_GEOM: m.ngeom, mujoco.mjtObj.mjOBJ_SITE: m.nsite, mujoco.mjtObj.mjOBJ_CAMERA: m.ncam, mujoco.mjtObj.mjOBJ_MESH: m.nmesh, mujoco.mjtObj.mjOBJ_HFIELD: m.nhfield, mujoco.mjtObj.mjOBJ_PAIR: m.npair, mujoco.mjtObj.mjOBJ_EQUALITY: m.neq, mujoco.mjtObj.mjOBJ_TENDON: m.ntendon, mujoco.mjtObj.mjOBJ_ACTUATOR: m.nu, mujoco.mjtObj.mjOBJ_SENSOR: m.nsensor, mujoco.mjtObj.mjOBJ_NUMERIC: m.nnumeric, mujoco.mjtObj.mjOBJ_TUPLE: m.ntuple, mujoco.mjtObj.mjOBJ_KEY: m.nkey, }.get(obj, 0) def _getadr( m: Union[Model, mujoco.MjModel], obj: mujoco._enums.mjtObj ) -> np.ndarray: """Gets the name addresses for the given object type.""" return { mujoco.mjtObj.mjOBJ_BODY: m.name_bodyadr, mujoco.mjtObj.mjOBJ_JOINT: m.name_jntadr, mujoco.mjtObj.mjOBJ_GEOM: m.name_geomadr, mujoco.mjtObj.mjOBJ_SITE: m.name_siteadr, mujoco.mjtObj.mjOBJ_CAMERA: m.name_camadr, mujoco.mjtObj.mjOBJ_MESH: m.name_meshadr, mujoco.mjtObj.mjOBJ_HFIELD: m.name_hfieldadr, mujoco.mjtObj.mjOBJ_PAIR: m.name_pairadr, mujoco.mjtObj.mjOBJ_EQUALITY: m.name_eqadr, mujoco.mjtObj.mjOBJ_TENDON: m.name_tendonadr, mujoco.mjtObj.mjOBJ_ACTUATOR: m.name_actuatoradr, mujoco.mjtObj.mjOBJ_SENSOR: m.name_sensoradr, mujoco.mjtObj.mjOBJ_NUMERIC: m.name_numericadr, mujoco.mjtObj.mjOBJ_TUPLE: m.name_tupleadr, mujoco.mjtObj.mjOBJ_KEY: m.name_keyadr, }[obj]
[docs] def id2name( m: Union[Model, mujoco.MjModel], typ: mujoco._enums.mjtObj, i: int ) -> Optional[str]: """Gets the name of an object with the specified mjtObj type and ids. See mujoco.id2name for more info. Args: m: mujoco.MjModel or mjx.Model typ: mujoco.mjtObj type i: the id Returns: the name string, or None if not found """ num = _getnum(m, typ) if i < 0 or i >= num: return None adr = _getadr(m, typ) name = m.names[adr[i] :].decode('utf-8').split('\x00', 1)[0] return name or None
[docs] def name2id( m: Union[Model, mujoco.MjModel], typ: mujoco._enums.mjtObj, name: str ) -> int: """Gets the id of an object with the specified mjtObj type and name. See mujoco.mj_name2id for more info. Args: m: mujoco.MjModel or mjx.Model typ: mujoco.mjtObj type name: the name of the object Returns: the id, or -1 if not found """ num = _getnum(m, typ) adr = _getadr(m, typ) # TODO: consider using MjModel.names_map instead names_map = { m.names[adr[i] :].decode('utf-8').split('\x00', 1)[0]: i for i in range(num) } return names_map.get(name, -1)
class BindModel(object): """Class holding the requested MJX Model and spec id for binding a spec to Model.""" def __init__(self, model: Model, specs: Sequence[mujoco.MjStruct]): self.model = model self.prefix = '' ids = [] for spec in specs: if model.signature != spec.signature: raise ValueError( 'mjSpec signature does not match mjx.Model signature:' f' {spec.signature} != {model.signature}' ) elif spec.id < 0: raise KeyError(f'invalid id: {spec.id}') elif isinstance(spec, mujoco.MjsBody): self.prefix = 'body_' elif isinstance(spec, mujoco.MjsJoint): self.prefix = 'jnt_' elif isinstance(spec, mujoco.MjsGeom): self.prefix = 'geom_' elif isinstance(spec, mujoco.MjsSite): self.prefix = 'site_' elif isinstance(spec, mujoco.MjsLight): self.prefix = 'light_' elif isinstance(spec, mujoco.MjsCamera): self.prefix = 'cam_' elif isinstance(spec, mujoco.MjsMesh): self.prefix = 'mesh_' elif isinstance(spec, mujoco.MjsHField): self.prefix = 'hfield_' elif isinstance(spec, mujoco.MjsPair): self.prefix = 'pair_' elif isinstance(spec, mujoco.MjsTendon): self.prefix = 'tendon_' elif isinstance(spec, mujoco.MjsActuator): self.prefix = 'actuator_' elif isinstance(spec, mujoco.MjsSensor): self.prefix = 'sensor_' elif isinstance(spec, mujoco.MjsNumeric): self.prefix = 'numeric_' elif isinstance(spec, mujoco.MjsText): self.prefix = 'text_' elif isinstance(spec, mujoco.MjsTuple): self.prefix = 'tuple_' elif isinstance(spec, mujoco.MjsKey): self.prefix = 'key_' elif isinstance(spec, mujoco.MjsEquality): self.prefix = 'eq_' elif isinstance(spec, mujoco.MjsExclude): self.prefix = 'exclude_' elif isinstance(spec, mujoco.MjsSkin): self.prefix = 'skin_' elif isinstance(spec, mujoco.MjsMaterial): self.prefix = 'material_' else: raise ValueError('invalid spec type') ids.append(spec.id) if len(ids) == 1: self.id = ids[0] else: self.id = ids def _slice(self, name: str, idx: Union[int, slice, Sequence[int]]): _, expected_dim = mjxmacro.MJMODEL[name] var = getattr(self.model, name) if expected_dim == '1': return var[..., idx] elif expected_dim == '9': return var[..., idx, :, :] return var[..., idx, :] def __getattr__(self, name: str): return self._slice(self.prefix + name, self.id) def _bind_model( self: Model, obj: mujoco.MjStruct | Iterable[mujoco.MjStruct] ) -> BindModel: """Bind a Mujoco spec to an MJX Model.""" if isinstance(obj, mujoco.MjStruct): obj = (obj,) else: obj = tuple(obj) return BindModel(self, obj) class BindData(object): """Class holding the requested MJX Data and spec id for binding a spec to Data.""" def __init__( self, data: Data, model: Model, specs: Sequence[mujoco.MjStruct] ): self.data = data self.model = model self.prefix = '' ids = [] for spec in specs: if model.signature != spec.signature: raise ValueError( 'mjSpec signature does not match mjx.Model signature:' f' {spec.signature} != {model.signature}' ) if spec.id < 0: raise KeyError(f'invalid id: {spec.id}') elif isinstance(spec, mujoco.MjsBody): pass elif isinstance(spec, mujoco.MjsJoint): self.prefix = 'jnt_' elif isinstance(spec, mujoco.MjsGeom): self.prefix = 'geom_' elif isinstance(spec, mujoco.MjsSite): self.prefix = 'site_' elif isinstance(spec, mujoco.MjsLight): self.prefix = 'light_' elif isinstance(spec, mujoco.MjsCamera): self.prefix = 'cam_' elif isinstance(spec, mujoco.MjsTendon): self.prefix = 'ten_' elif isinstance(spec, mujoco.MjsActuator): self.prefix = 'actuator_' elif isinstance(spec, mujoco.MjsSensor): self.prefix = 'sensor_' elif isinstance(spec, mujoco.MjsEquality): self.prefix = 'eq_' else: raise ValueError('invalid spec type') ids.append(spec.id) if len(ids) == 1: self.id = ids[0] else: self.id = ids def __getname(self, name: str): """Get the name of the attribute and check if the type is correct.""" if name == 'sensordata': if self.prefix == 'sensor_': return name else: raise AttributeError('sensordata is not available for this type') if name == 'ctrl': if self.prefix == 'actuator_': return name else: raise AttributeError('ctrl is not available for this type') if ( name == 'qpos' or name == 'qvel' or name == 'qacc' or name.startswith('qfrc_') ): if self.prefix == 'jnt_': return name else: raise AttributeError( 'qpos, qvel, qacc, qfrc are not available for this type' ) else: return self.prefix + name def _slice(self, name: str, idx: Union[int, slice, Sequence[int]]): _, expected_dim = mjxmacro.MJDATA[name] var = getattr(self.data, name) if expected_dim == '1': return var[..., idx] elif expected_dim == '9': return var[..., idx, :, :] return var[..., idx, :] def __getattr__(self, name: str): if name in ('sensordata', 'qpos', 'qvel', 'qacc') or ( name.startswith('qfrc_') ): adr = num = 0 if name == 'sensordata': adr = self.model.sensor_adr[self.id] num = self.model.sensor_dim[self.id] elif name == 'qpos': adr = self.model.jnt_qposadr[self.id] typ = self.model.jnt_type[self.id] num = sum((typ == jt) * jt.qpos_width() for jt in JointType) elif name == 'qvel' or name == 'qacc' or name.startswith('qfrc_'): adr = self.model.jnt_dofadr[self.id] typ = self.model.jnt_type[self.id] num = sum((typ == jt) * jt.dof_width() for jt in JointType) if isinstance(self.id, list): idx = [] for a, n in zip(adr, num): idx.extend(a + j for j in range(n)) return self._slice(self.__getname(name), idx) elif num > 1: return self._slice(self.__getname(name), slice(adr, adr + num)) else: return self._slice(self.__getname(name), adr) elif name in ('mocap_pos', 'mocap_quat'): return self._slice(self.__getname(name), self.model.body_mocapid[self.id]) return self._slice(self.__getname(name), self.id) def set(self, name: str, value: jax.Array) -> Data: """Set the value of an array in an MJX Data.""" if name == 'sensordata': raise AttributeError('sensordata is readonly') array = getattr(self.data, self.__getname(name)) dim = 1 if len(array.shape) == 1 else array.shape[-1] try: iter(value) except TypeError: value = [value] if name in ('qpos', 'qvel', 'qacc', 'mocap_pos', 'mocap_quat'): adr = num = 0 if name == 'qpos': adr = self.model.jnt_qposadr[self.id] typ = self.model.jnt_type[self.id] num = sum((typ == jt) * jt.qpos_width() for jt in JointType) elif name == 'qvel' or name == 'qacc': adr = self.model.jnt_dofadr[self.id] typ = self.model.jnt_type[self.id] num = sum((typ == jt) * jt.dof_width() for jt in JointType) elif name == 'mocap_pos': adr = self.model.body_mocapid[self.id] * 3 num = np.ones_like(self.id, dtype=int) * 3 elif name == 'mocap_quat': adr = self.model.body_mocapid[self.id] * 4 num = np.ones_like(self.id, dtype=int) * 4 if not isinstance(self.id, list): adr = [adr] num = [num] elif isinstance(self.id, list): adr = (np.array(self.id) * dim).tolist() num = [dim for _ in range(len(self.id))] else: adr = [self.id * dim] num = [dim] i = 0 value = jax.numpy.array(value).flatten() for a, n in zip(adr, num): shape = array.shape array = array.flatten().at[a : a + n].set(value[i : i + n]).reshape(shape) i += n return self.data.replace(**{self.__getname(name): array}) def _bind_data( self: Data, model: Model, obj: mujoco.MjStruct | Iterable[mujoco.MjStruct] ) -> BindData: """Bind a Mujoco spec to an MJX Data.""" if isinstance(obj, mujoco.MjStruct): obj = (obj,) else: obj = tuple(obj) return BindData(self, model, obj) Model.bind = _bind_model Data.bind = _bind_data def _decode_pyramid( pyramid: jax.Array, mu: jax.Array, condim: int ) -> jax.Array: """Converts pyramid representation to contact force.""" force = jp.zeros(6, dtype=float) if condim == 1: return force.at[0].set(pyramid[0]) # force_normal = sum(pyramid0_i + pyramid1_i) force = force.at[0].set(pyramid[0 : 2 * (condim - 1)].sum()) # force_tangent_i = (pyramid0_i - pyramid1_i) * mu_i i = np.arange(0, condim - 1) force = force.at[i + 1].set((pyramid[2 * i] - pyramid[2 * i + 1]) * mu[i]) return force def contact_force( m: Model, d: Data, contact_id: int, to_world_frame: bool = False ) -> jax.Array: """Extract 6D force:torque for one contact, in contact frame by default.""" efc_address = d._impl.contact.efc_address[contact_id] # pytype: disable=attribute-error condim = d._impl.contact.dim[contact_id] # pytype: disable=attribute-error if m.opt.cone == ConeType.PYRAMIDAL: force = _decode_pyramid( d._impl.efc_force[efc_address:], d._impl.contact.friction[contact_id], condim # pytype: disable=attribute-error ) elif m.opt.cone == ConeType.ELLIPTIC: force = d._impl.efc_force[efc_address : efc_address + condim] # pytype: disable=attribute-error force = jp.concatenate([force, jp.zeros((6 - condim))]) else: raise ValueError(f'Unknown cone type: {m.opt.cone}') if to_world_frame: force = force.reshape((-1, 3)) @ d._impl.contact.frame[contact_id] # pytype: disable=attribute-error force = force.reshape(-1) return force * (efc_address >= 0) def contact_force_dim( m: Model, d: Data, dim: int ) -> Tuple[jax.Array, np.ndarray]: """Extract 6D force:torque for contacts with dimension dim.""" # valid contact and condim indices idx_dim = (d._impl.contact.efc_address >= 0) & (d._impl.contact.dim == dim) # pytype: disable=attribute-error # contact force from efc if m.opt.cone == ConeType.PYRAMIDAL: efc_address = ( d._impl.contact.efc_address[idx_dim, None] # pytype: disable=attribute-error + np.arange(np.where(dim == 1, 1, 2 * (dim - 1)))[None] ) efc_force = d._impl.efc_force[efc_address] # pytype: disable=attribute-error force = jax.vmap(_decode_pyramid, in_axes=(0, 0, None))( efc_force, d._impl.contact.friction[idx_dim], dim # pytype: disable=attribute-error ) elif m.opt.cone == ConeType.ELLIPTIC: efc_address = d._impl.contact.efc_address[idx_dim, None] + np.arange(dim)[None] # pytype: disable=attribute-error force = d._impl.efc_force[efc_address] # pytype: disable=attribute-error force = jp.hstack([force, jp.zeros((force.shape[0], 6 - dim))]) else: raise ValueError(f'Unknown cone type: {m.opt.cone}.') return force, np.where(idx_dim)[0] def _length_circle( p0: jax.Array, p1: jax.Array, ind: jax.Array, rad: jax.Array ) -> jax.Array: """Compute length of circle.""" # compute angle between 0 and pi p0n = math.normalize(p0).reshape(-1) p1n = math.normalize(p1).reshape(-1) # clip input to closed interval for jp.arccos to prevent potential nan # TODO(taylorhowell): add test for case where clip is necessary angle = jp.arccos(jp.clip(jp.dot(p0n, p1n), -1, 1)) # flip if necessary cross = p0[1] * p1[0] - p0[0] * p1[1] flip = ((cross > 0) & (ind != 0)) | ((cross < 0) & (ind == 0)) angle = jp.where(flip, 2 * jp.pi - angle, angle) return rad * angle def _is_intersect( p1: jax.Array, p2: jax.Array, p3: jax.Array, p4: jax.Array ) -> jax.Array: """Check for intersection between two lines defined by their endpoints.""" # compute determinant det = (p4[1] - p3[1]) * (p2[0] - p1[0]) - (p4[0] - p3[0]) * (p2[1] - p1[1]) # compute intersection point on each line a = math.safe_div( (p4[0] - p3[0]) * (p1[1] - p3[1]) - (p4[1] - p3[1]) * (p1[0] - p3[0]), det ) b = math.safe_div( (p2[0] - p1[0]) * (p1[1] - p3[1]) - (p2[1] - p1[1]) * (p1[0] - p3[0]), det ) return jp.where( jp.abs(det) < mujoco.mjMINVAL, 0, (a >= 0) & (a <= 1) & (b >= 0) & (b <= 1), ) def wrap_circle( d: jax.Array, sd: jax.Array, sidesite: jax.Array, rad: jax.Array ) -> Tuple[jax.Array, jax.Array]: """Compute circle wrap arc length and end points.""" # check cases sqlen0 = d[0] ** 2 + d[1] ** 2 sqlen1 = d[2] ** 2 + d[3] ** 2 sqrad = rad * rad dif = jp.array([d[2] - d[0], d[3] - d[1]]) dd = dif[0] ** 2 + dif[1] ** 2 a = jp.clip( -(dif[0] * d[0] + dif[1] * d[1]) / jp.maximum(mujoco.mjMINVAL, dd), 0, 1 ) seg = jp.array([a * dif[0] + d[0], a * dif[1] + d[1]]) point_inside0 = sqlen0 < sqrad point_inside1 = sqlen1 < sqrad circle_too_small = rad < mujoco.mjMINVAL points_too_close = dd < mujoco.mjMINVAL intersect_and_side = (seg[0] ** 2 + seg[1] ** 2 > sqrad) & ( jp.where(sidesite, 0, 1) | (jp.dot(sd, seg) >= 0) ) # construct the two solutions, compute goodness def _sol(sgn): sqrt0 = jp.sqrt(jp.maximum(mujoco.mjMINVAL, sqlen0 - sqrad)) sqrt1 = jp.sqrt(jp.maximum(mujoco.mjMINVAL, sqlen1 - sqrad)) d00 = (d[0] * sqrad + sgn * rad * d[1] * sqrt0) / jp.maximum( mujoco.mjMINVAL, sqlen0 ) d01 = (d[1] * sqrad - sgn * rad * d[0] * sqrt0) / jp.maximum( mujoco.mjMINVAL, sqlen0 ) d10 = (d[2] * sqrad - sgn * rad * d[3] * sqrt1) / jp.maximum( mujoco.mjMINVAL, sqlen1 ) d11 = (d[3] * sqrad + sgn * rad * d[2] * sqrt1) / jp.maximum( mujoco.mjMINVAL, sqlen1 ) sol = jp.array([[d00, d01], [d10, d11]]) # goodness: close to sd, or shorter path tmp0 = sol[0] + sol[1] tmp0 = math.normalize(tmp0).reshape(-1) good0 = jp.dot(tmp0, sd) tmp1 = (sol[0] - sol[1]).reshape(-1) good1 = -jp.dot(tmp1, tmp1) good = jp.where(sidesite, good0, good1) # penalize for intersection intersect = _is_intersect(d[:2], sol[0], d[2:], sol[1]) good = jp.where(intersect, -10000, good) return sol, good sol, good = jax.vmap(_sol)(jp.array([1, -1])) # select the better solution i = jp.argmax(good) sol = sol[i] pnt = sol.reshape(-1) # check for intersection intersect = _is_intersect(d[:2], pnt[:2], d[2:], pnt[2:]) # compute curve length wlen = _length_circle(sol[0], sol[1], i, rad) # check cases invalid = ( point_inside0 | point_inside1 | circle_too_small | points_too_close | intersect_and_side | intersect ) wlen = jp.where(invalid, -1, wlen) pnt = jp.where(invalid, jp.zeros(4), pnt) return wlen, pnt def wrap_inside( end: jax.Array, radius: jax.Array, maxiter: int, tolerance: float, z_init: float, ) -> Tuple[jax.Array, jax.Array]: """Compute 2D inside wrap point. Args: end: 2D points radius: radius of circle Returns: status: 0 if wrap, else -1 concatentated 2D wrap points: jax.Array """ mjMINVAL = mujoco.mjMINVAL # pylint: disable=invalid-name # constants len0 = math.norm(end[:2]) len1 = math.norm(end[2:]) dif = jp.array([end[2] - end[0], end[3] - end[1]]) dd = dif[0] * dif[0] + dif[1] * dif[1] # either point inside circle or circle too small: no wrap no_wrap0 = ( (len0 <= radius) | (len1 <= radius) | (radius < mjMINVAL) | (len0 < mjMINVAL) | (len1 < mjMINVAL) ) # find nearest point on line segment to origin: d0 + a*dif a = -1 * (dif[0] * end[0] + dif[1] * end[1]) / jp.maximum(mjMINVAL, dd) tmp = end[:2] + a * dif # segment-circle intersection: no wrap no_wrap1 = (dd > mjMINVAL) & (a > 0) & (a < 1) & (math.norm(tmp) <= radius) # prepare default in case of numerical failure: average pnt_avg = 0.5 * jp.array([end[0] + end[2], end[1] + end[3]]) pnt_avg = radius * math.normalize(pnt_avg) # compute function parameters: asin(A*z) + asin(B*z) - 2*asin(z) + G = 0 A = radius / jp.maximum(mjMINVAL, len0) # pylint: disable=invalid-name B = radius / jp.maximum(mjMINVAL, len1) # pylint: disable=invalid-name cosG = (len0 * len0 + len1 * len1 - dd) / jp.maximum(mjMINVAL, 2 * len0 * len1) # pylint: disable=invalid-name no_wrap2 = cosG < -1 + mjMINVAL early_return0 = cosG > 1 - mjMINVAL G = jp.arccos(cosG) # pylint: disable=invalid-name # initialize solver z = jp.array([z_init]) f = jp.arcsin(A * z) + jp.arcsin(B * z) - 2 * jp.arcsin(z) + G # make sure initialization is not on the other side early_return1 = f > 0 # iteratively solve with Newton's method def _newton(carry, _): # unpack z, f, status_prev = carry # check current solution converged = jp.abs(f) <= tolerance # compute derivative df = ( A / jp.maximum(mjMINVAL, jp.sqrt(1 - z * z * A * A)) + B / jp.maximum(mjMINVAL, jp.sqrt(1 - z * z * B * B)) - 2 / jp.maximum(mjMINVAL, jp.sqrt(1 - z * z)) ) # check sign; SHOULD NOT OCCUR status0 = df > -mjMINVAL # new point z_next = z - (1 - converged) * math.safe_div(f, df) # make sure we are moving to the left; SHOULD NOT OCCUR status1 = z_next > z # evaluate solution f_next = ( jp.arcsin(A * z_next) + jp.arcsin(B * z_next) - 2 * jp.arcsin(z_next) + G ) # exit if positive; SHOULD NOT OCCUR status2 = f_next > tolerance return ( z_next, f_next, status_prev | status0 | status1 | status2, ), None # TODO(taylorhowell): compare performance of jax.lax.scan and jax.lax.while_loop z, _, early_return2 = jax.lax.scan( _newton, (z, f, jp.array([False])), None, maxiter )[0] # finalize: rotation by ang from vec = a or b, depending on cross(a,b) sign sign = end[0] * end[3] - end[1] * end[2] > 0 vec = jp.where(sign, end[:2], end[2:]) vec = math.normalize(vec) ang = jp.arcsin(z) - jp.where(sign, jp.arcsin(A * z), jp.arcsin(B * z)) pnt_sol = radius * jp.array([ jp.cos(ang) * vec[0] - jp.sin(ang) * vec[1], jp.sin(ang) * vec[0] + jp.cos(ang) * vec[1], ]).reshape(-1) no_wrap = no_wrap0 | no_wrap1 | no_wrap2 early_return = early_return0 | early_return1 | early_return2 status = -1 * no_wrap * jp.ones(1) pnt = jp.where(early_return, pnt_avg, pnt_sol) pnt = jp.where(no_wrap, jp.zeros(2), pnt) return status, jp.concatenate([pnt, pnt]) def wrap( x0: jax.Array, x1: jax.Array, xpos: jax.Array, xmat: jax.Array, size: jax.Array, side: jax.Array, sidesite: jax.Array, is_sphere: jax.Array, is_wrap_inside: bool, wrap_inside_maxiter: int, wrap_inside_tolerance: float, wrap_inside_z_init: float, ) -> Tuple[jax.Array, jax.Array, jax.Array]: """Wrap tendon around sphere or cylinder.""" # map sites to wrap object's local frame p0 = xmat.T @ (x0 - xpos) p1 = xmat.T @ (x1 - xpos) close_to_origin = (jp.linalg.norm(p0) < mujoco.mjMINVAL) | ( jp.linalg.norm(p1) < mujoco.mjMINVAL ) # compute axes for sphere # 1st axis axis0 = p0 axis0 = math.normalize(axis0) # compute normal to p0-0-p1 plane = cross(p0, p1) normal = jp.cross(p0, p1) normal, nrm = math.normalize_with_norm(normal) # compute alternative normal (if (p0, p1) are parallel) # find max component of axis0 axis_alt = jp.ones(3).at[jp.argmax(axis0)].set(0) normal_alt = jp.cross(axis0, axis_alt) normal_alt = math.normalize(normal_alt) normal = jp.where(nrm < mujoco.mjMINVAL, normal_alt, normal) # 2nd axis axis1 = jp.cross(normal, axis0) axis1 = math.normalize(axis1) # set geom dependent axes axis0 = jp.where(is_sphere, axis0, jp.array([1.0, 0.0, 0.0])) axis1 = jp.where(is_sphere, axis1, jp.array([0.0, 1.0, 0.0])) # project points in 2D frame: p => d d = jp.array([ jp.dot(p0, axis0), jp.dot(p0, axis1), jp.dot(p1, axis0), jp.dot(p1, axis1), ]) # compute sidesite projection s = xmat.T @ (side - xpos) sd = jp.array([jp.dot(s, axis0), jp.dot(s, axis1)]) sd = math.normalize(sd) * size if is_wrap_inside: wlen, pnt = wrap_inside( d, size, wrap_inside_maxiter, wrap_inside_tolerance, wrap_inside_z_init ) else: wlen, pnt = wrap_circle(d, sd, sidesite, size) no_wrap = wlen < 0 # reconstruct 3D points in local frame: res res0 = axis0 * pnt[0] + axis1 * pnt[1] res1 = axis0 * pnt[2] + axis1 * pnt[3] res = jp.concatenate([res0, res1]) # perform correction for cylinder case l0 = jp.sqrt( (p0[0] - res[0]) * (p0[0] - res[0]) + (p0[1] - res[1]) * (p0[1] - res[1]) ) l1 = jp.sqrt( (p1[0] - res[3]) * (p1[0] - res[3]) + (p1[1] - res[4]) * (p1[1] - res[4]) ) r2 = p0[2] + (p1[2] - p0[2]) * math.safe_div(l0, l0 + wlen + l1) r5 = p0[2] + (p1[2] - p0[2]) * math.safe_div(l0 + wlen, l0 + wlen + l1) height = jp.abs(r5 - r2) wlen = jp.where(is_sphere, wlen, jp.sqrt(wlen * wlen + height * height)) res = jp.where( is_sphere, res, res.at[jp.array([2, 5])].set(jp.concatenate([r2, r5])) ) # map wrap points back to global frame wpnt0 = xmat @ res[:3] + xpos wpnt1 = xmat @ res[3:] + xpos # check cases for no wrap invalid = close_to_origin | no_wrap wlen = jp.where(invalid, -1, wlen) wpnt0 = jp.where(invalid, jp.zeros(3), wpnt0) wpnt1 = jp.where(invalid, jp.zeros(3), wpnt1) return wlen, wpnt0, wpnt1 def muscle_gain_length( length: jax.Array, lmin: jax.Array, lmax: jax.Array ) -> jax.Array: """Normalized muscle length-gain curve.""" # mid-ranges (maximum is at 1.0) a = 0.5 * (lmin + 1) b = 0.5 * (1 + lmax) out0 = 0.5 * jp.square( (length - lmin) / jp.maximum(mujoco.mjMINVAL, a - lmin) ) out1 = 1 - 0.5 * jp.square((1 - length) / jp.maximum(mujoco.mjMINVAL, 1 - a)) out2 = 1 - 0.5 * jp.square((length - 1) / jp.maximum(mujoco.mjMINVAL, b - 1)) out3 = 0.5 * jp.square( (lmax - length) / jp.maximum(mujoco.mjMINVAL, lmax - b) ) out = jp.where(length <= b, out2, out3) out = jp.where(length <= 1, out1, out) out = jp.where(length <= a, out0, out) out = jp.where((lmin <= length) & (length <= lmax), out, 0.0) return out def muscle_gain( length: jax.Array, vel: jax.Array, lengthrange: jax.Array, acc0: jax.Array, prm: jax.Array, ) -> jax.Array: """Muscle active force.""" # unpack parameters lrange = prm[:2] force, scale, lmin, lmax, vmax, _, fvmax = prm[2:9] force = jp.where(force < 0, scale / jp.maximum(mujoco.mjMINVAL, acc0), force) # optimum length L0 = (lengthrange[1] - lengthrange[0]) / jp.maximum( # pylint:disable=invalid-name mujoco.mjMINVAL, lrange[1] - lrange[0] ) # normalized length and velocity L = lrange[0] + (length - lengthrange[0]) / jp.maximum(mujoco.mjMINVAL, L0) # pylint:disable=invalid-name V = vel / jp.maximum(mujoco.mjMINVAL, L0 * vmax) # pylint:disable=invalid-name # length curve FL = muscle_gain_length(L, lmin, lmax) # pylint:disable=invalid-name # velocity curve y = fvmax - 1 FV = jp.where( # pylint:disable=invalid-name V <= y, fvmax - jp.square(y - V) / jp.maximum(mujoco.mjMINVAL, y), fvmax ) FV = jp.where(V <= 0, jp.square(V + 1), FV) # pylint:disable=invalid-name FV = jp.where(V <= -1, 0, FV) # pylint:disable=invalid-name # compute FVL and scale, make it negative return -force * FL * FV def muscle_bias( length: jax.Array, lengthrange: jax.Array, acc0: jax.Array, prm: jax.Array ) -> jax.Array: """Muscle passive force.""" # unpack parameters lrange = prm[:2] force, scale, _, lmax, _, fpmax = prm[2:8] force = jp.where(force < 0, scale / jp.maximum(mujoco.mjMINVAL, acc0), force) # optimum length L0 = (lengthrange[1] - lengthrange[0]) / jp.maximum( # pylint:disable=invalid-name mujoco.mjMINVAL, lrange[1] - lrange[0] ) # normalized length L = lrange[0] + (length - lengthrange[0]) / jp.maximum(mujoco.mjMINVAL, L0) # pylint:disable=invalid-name # half-quadratic to (L0 + lmax) / 2, linear beyond b = 0.5 * (1 + lmax) out1 = ( -force * fpmax * 0.5 * jp.square((L - 1) / jp.maximum(mujoco.mjMINVAL, b - 1)) ) out2 = -force * fpmax * (0.5 + (L - b) / jp.maximum(mujoco.mjMINVAL, b - 1)) out = jp.where(L <= b, out1, out2) out = jp.where(L <= 1, 0.0, out) return out def muscle_dynamics_timescale( dctrl: jax.Array, tau_act: jax.Array, tau_deact: jax.Array, smoothing_width: jax.Array, ) -> jax.Array: """Muscle time constant with optional smoothing.""" # hard switching tau_hard = jp.where(dctrl > 0, tau_act, tau_deact) def _sigmoid(x): # sigmoid function over 0 <= x <= 1 using quintic polynomial # sigmoid: f(x) = 6 * x^5 - 15 * x^4 + 10 * x^3 # solution of f(0) = f'(0) = f''(0) = 0, f(1) = 1, f'(1) = f''(1) = 0 sol = x * x * x * (3 * x * (2 * x - 5) + 10) sol = jp.where(x <= 0, 0, sol) sol = jp.where(x >= 1, 1, sol) return sol # smooth switching # scale by width, center around 0.5 midpoint, rescale to bounds tau_smooth = tau_deact + (tau_act - tau_deact) * _sigmoid( math.safe_div(dctrl, smoothing_width) + 0.5 ) return jp.where(smoothing_width < mujoco.mjMINVAL, tau_hard, tau_smooth) def muscle_dynamics( ctrl: jax.Array, act: jax.Array, prm: jax.Array ) -> jax.Array: """Muscle activation dynamics.""" # clamp control ctrlclamp = jp.clip(ctrl, 0, 1) # clamp activation actclamp = jp.clip(act, 0, 1) # compute timescales as in Millard et at. (2013) # https://doi.org/10.1115/1.4023390 tau_act = prm[0] * (0.5 + 1.5 * actclamp) # activation timescale tau_deact = prm[1] / (0.5 + 1.5 * actclamp) # deactivation timescale smoothing_width = prm[2] # width of smoothing sigmoid dctrl = ctrlclamp - act # excess excitation tau = muscle_dynamics_timescale(dctrl, tau_act, tau_deact, smoothing_width) # filter output return dctrl / jp.maximum(mujoco.mjMINVAL, tau)