# Copyright 2023 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Core non-smooth constraint functions."""
from typing import Optional, Tuple, Union
import jax
from jax import numpy as jp
import mujoco
from mujoco.mjx._src import collision_driver
from mujoco.mjx._src import math
from mujoco.mjx._src import support
# pylint: disable=g-importing-member
from mujoco.mjx._src.dataclasses import PyTreeNode
from mujoco.mjx._src.types import ConeType
from mujoco.mjx._src.types import ConstraintType
from mujoco.mjx._src.types import Contact
from mujoco.mjx._src.types import Data
from mujoco.mjx._src.types import DataJAX
from mujoco.mjx._src.types import DisableBit
from mujoco.mjx._src.types import EqType
from mujoco.mjx._src.types import JointType
from mujoco.mjx._src.types import Model
from mujoco.mjx._src.types import ModelJAX
from mujoco.mjx._src.types import ObjType
from mujoco.mjx._src.types import OptionJAX
# pylint: enable=g-importing-member
import numpy as np
class _Efc(PyTreeNode):
"""Support data for creating constraint matrices."""
J: jax.Array
pos_aref: jax.Array
pos_imp: jax.Array
invweight: jax.Array
solref: jax.Array
solimp: jax.Array
margin: jax.Array
frictionloss: jax.Array
def _kbi(
m: Model,
solref: jax.Array,
solimp: jax.Array,
pos: jax.Array,
) -> Tuple[jax.Array, jax.Array, jax.Array]:
"""Calculates stiffness, damping, and impedance of a constraint."""
timeconst, dampratio = solref
if not m.opt.disableflags & DisableBit.REFSAFE:
timeconst = jp.maximum(timeconst, 2 * m.opt.timestep)
dmin, dmax, width, mid, power = solimp
dmin = jp.clip(dmin, mujoco.mjMINIMP, mujoco.mjMAXIMP)
dmax = jp.clip(dmax, mujoco.mjMINIMP, mujoco.mjMAXIMP)
width = jp.maximum(mujoco.mjMINVAL, width)
mid = jp.clip(mid, mujoco.mjMINIMP, mujoco.mjMAXIMP)
power = jp.maximum(1, power)
# See https://mujoco.readthedocs.io/en/latest/modeling.html#solver-parameters
k = 1 / (dmax * dmax * timeconst * timeconst * dampratio * dampratio)
b = 2 / (dmax * timeconst)
# TODO(robotics-simulation): check various solparam settings in model gen test
k = jp.where(solref[0] <= 0, -solref[0] / (dmax * dmax), k)
b = jp.where(solref[1] <= 0, -solref[1] / dmax, b)
imp_x = jp.abs(pos) / width
imp_a = (1.0 / jp.power(mid, power - 1)) * jp.power(imp_x, power)
imp_b = 1 - (1.0 / jp.power(1 - mid, power - 1)) * jp.power(1 - imp_x, power)
imp_y = jp.where(imp_x < mid, imp_a, imp_b)
imp = dmin + imp_y * (dmax - dmin)
imp = jp.clip(imp, dmin, dmax)
imp = jp.where(imp_x > 1.0, dmax, imp)
return k, b, imp # corresponds to K, B, I of efc_KBIP
def _row(j: jax.Array, *args) -> _Efc:
"""Creates an efc row, ensuring args all have same row count."""
if len(j.shape) < 2:
return _Efc(j, *args) # if j isn't batched, ignore
args = list(args)
for i, arg in enumerate(args):
if not arg.shape or arg.shape[0] != j.shape[0]:
args[i] = jp.tile(arg, (j.shape[0],) + (1,) * (len(arg.shape)))
return _Efc(j, *args)
def _efc_equality_connect(m: Model, d: Data) -> Optional[_Efc]:
"""Calculates constraint rows for connect equality constraints."""
eq_id = np.nonzero(m.eq_type == EqType.CONNECT)[0]
if (m.opt.disableflags & DisableBit.EQUALITY) or eq_id.size == 0:
return None
@jax.vmap
def rows(
is_site, obj1id, obj2id, body1id, body2id, data, solref, solimp, active
):
anchor1, anchor2 = data[0:3], data[3:6]
pos1 = d.xmat[body1id] @ anchor1 + d.xpos[body1id]
pos2 = d.xmat[body2id] @ anchor2 + d.xpos[body2id]
if m.nsite:
pos1 = jp.where(is_site, d.site_xpos[obj1id], pos1)
pos2 = jp.where(is_site, d.site_xpos[obj2id], pos2)
# error is difference in global positions
pos = pos1 - pos2
# compute Jacobian difference (opposite of contact: 0 - 1)
jacp1, _ = support.jac(m, d, pos1, body1id)
jacp2, _ = support.jac(m, d, pos2, body2id)
j = (jacp1 - jacp2).T
pos_imp = math.norm(pos)
invweight = m.body_invweight0[body1id, 0] + m.body_invweight0[body2id, 0]
zero = jp.zeros_like(pos)
efc = _row(j, pos, pos_imp, invweight, solref, solimp, zero, zero)
return jax.tree_util.tree_map(lambda x: x * active, efc)
is_site = m.eq_objtype == ObjType.SITE
body1id = np.copy(m.eq_obj1id)
body2id = np.copy(m.eq_obj2id)
if m.nsite:
body1id[is_site] = m.site_bodyid[m.eq_obj1id[is_site]]
body2id[is_site] = m.site_bodyid[m.eq_obj2id[is_site]]
args = (
is_site,
m.eq_obj1id,
m.eq_obj2id,
body1id,
body2id,
m.eq_data,
m.eq_solref,
m.eq_solimp,
d.eq_active,
)
args = jax.tree_util.tree_map(lambda x: x[eq_id], args)
# concatenate to drop row grouping
return jax.tree_util.tree_map(jp.concatenate, rows(*args))
def _efc_equality_weld(m: Model, d: Data) -> Optional[_Efc]:
"""Calculates constraint rows for weld equality constraints."""
eq_id = np.nonzero(m.eq_type == EqType.WELD)[0]
if (m.opt.disableflags & DisableBit.EQUALITY) or eq_id.size == 0:
return None
@jax.vmap
def rows(
is_site, obj1id, obj2id, body1id, body2id, data, solref, solimp, active
):
anchor1, anchor2 = data[0:3], data[3:6]
relpose, torquescale = data[6:10], data[10]
# error is difference in global position and orientation
pos1 = d.xmat[body1id] @ anchor2 + d.xpos[body1id]
pos2 = d.xmat[body2id] @ anchor1 + d.xpos[body2id]
if m.nsite:
pos1 = jp.where(is_site, d.site_xpos[obj1id], pos1)
pos2 = jp.where(is_site, d.site_xpos[obj2id], pos2)
cpos = pos1 - pos2
# compute Jacobian difference (opposite of contact: 0 - 1)
jacp1, jacr1 = support.jac(m, d, pos1, body1id)
jacp2, jacr2 = support.jac(m, d, pos2, body2id)
jacdifp = jacp1 - jacp2
jacdifr = (jacr1 - jacr2) * torquescale
# compute orientation error: neg(q1) * q0 * relpose (axis components only)
quat = math.quat_mul(d.xquat[body1id], relpose)
quat1 = math.quat_inv(d.xquat[body2id])
if m.nsite:
quat = jp.where(
is_site, math.quat_mul(d.xquat[body1id], m.site_quat[obj1id]), quat
)
quat1 = jp.where(
is_site,
math.quat_inv(math.quat_mul(d.xquat[body2id], m.site_quat[obj2id])),
quat1,
)
crot = math.quat_mul(quat1, quat)[1:] # copy axis components
pos = jp.concatenate((cpos, crot * torquescale))
# correct rotation Jacobian: 0.5 * neg(q1) * (jac0-jac1) * q0 * relpose
jac_fn = lambda j: math.quat_mul(math.quat_mul_axis(quat1, j), quat)[1:]
jacdifr = 0.5 * jax.vmap(jac_fn)(jacdifr)
j = jp.concatenate((jacdifp.T, jacdifr.T))
pos_imp = math.norm(pos)
invweight = m.body_invweight0[body1id] + m.body_invweight0[body2id]
invweight = jp.repeat(invweight, 3, axis=0)
zero = jp.zeros_like(pos)
efc = _row(j, pos, pos_imp, invweight, solref, solimp, zero, zero)
return jax.tree_util.tree_map(lambda x: x * active, efc)
is_site = m.eq_objtype == ObjType.SITE
body1id = np.copy(m.eq_obj1id)
body2id = np.copy(m.eq_obj2id)
if m.nsite:
body1id[is_site] = m.site_bodyid[m.eq_obj1id[is_site]]
body2id[is_site] = m.site_bodyid[m.eq_obj2id[is_site]]
args = (
is_site,
m.eq_obj1id,
m.eq_obj2id,
body1id,
body2id,
m.eq_data,
m.eq_solref,
m.eq_solimp,
d.eq_active,
)
args = jax.tree_util.tree_map(lambda x: x[eq_id], args)
# concatenate to drop row grouping
return jax.tree_util.tree_map(jp.concatenate, rows(*args))
def _efc_equality_joint(m: Model, d: Data) -> Optional[_Efc]:
"""Calculates constraint rows for joint equality constraints."""
eq_id = np.nonzero(m.eq_type == EqType.JOINT)[0]
if (m.opt.disableflags & DisableBit.EQUALITY) or eq_id.size == 0:
return None
@jax.vmap
def rows(
obj2id, data, solref, solimp, active, dofadr1, dofadr2, qposadr1, qposadr2
):
pos1, pos2 = d.qpos[qposadr1], d.qpos[qposadr2]
ref1, ref2 = m.qpos0[qposadr1], m.qpos0[qposadr2]
dif = (pos2 - ref2) * (obj2id > -1)
dif_power = jp.power(dif, jp.arange(0, 5))
pos = pos1 - ref1 - jp.dot(data[:5], dif_power)
deriv = jp.dot(data[1:5], dif_power[:4] * jp.arange(1, 5)) * (obj2id > -1)
j = jp.zeros((m.nv)).at[dofadr2].set(-deriv).at[dofadr1].set(1.0)
invweight = m.dof_invweight0[dofadr1]
invweight += m.dof_invweight0[dofadr2] * (obj2id > -1)
zero = jp.zeros_like(pos)
efc = _row(j, pos, pos, invweight, solref, solimp, zero, zero)
return jax.tree_util.tree_map(lambda x: x * active, efc)
args = (m.eq_obj1id, m.eq_obj2id, m.eq_data, m.eq_solref, m.eq_solimp)
args += (d.eq_active,)
args = jax.tree_util.tree_map(lambda x: x[eq_id], args)
dofadr1, dofadr2 = m.jnt_dofadr[args[0]], m.jnt_dofadr[args[1]]
qposadr1, qposadr2 = m.jnt_qposadr[args[0]], m.jnt_qposadr[args[1]]
args = args[1:] + (dofadr1, dofadr2, qposadr1, qposadr2)
return rows(*args)
def _efc_equality_tendon(m: Model, d: Data) -> Optional[_Efc]:
"""Calculates constraint rows for tendon equality constraints."""
if not isinstance(m._impl, ModelJAX) or not isinstance(d._impl, DataJAX):
raise ValueError(
'_efc_equality_tendon requires JAX backend implementation.'
)
eq_id = np.nonzero(m.eq_type == EqType.TENDON)[0]
if (m.opt.disableflags & DisableBit.EQUALITY) or eq_id.size == 0:
return None
obj1id, obj2id, data, solref, solimp, active = jax.tree_util.tree_map(
lambda x: x[eq_id],
(
m.eq_obj1id,
m.eq_obj2id,
m.eq_data,
m.eq_solref,
m.eq_solimp,
d.eq_active,
),
)
@jax.vmap
def rows(
obj2id, data, solref, solimp, invweight, jac1, jac2, pos1, pos2, active
):
dif = pos2 * (obj2id > -1)
dif_power = jp.power(dif, jp.arange(0, 5))
pos = pos1 - jp.dot(data[:5], dif_power)
deriv = jp.dot(data[1:5], dif_power[:4] * jp.arange(1, 5)) * (obj2id > -1)
j = jac1 + jac2 * -deriv
zero = jp.zeros_like(pos)
efc = _row(j, pos, pos, invweight, solref, solimp, zero, zero)
return jax.tree_util.tree_map(lambda x: x * active, efc)
inv1, inv2 = m.tendon_invweight0[obj1id], m.tendon_invweight0[obj2id]
jac1, jac2 = d._impl.ten_J[obj1id], d._impl.ten_J[obj2id]
pos1 = d.ten_length[obj1id] - m.tendon_length0[obj1id]
pos2 = d.ten_length[obj2id] - m.tendon_length0[obj2id]
invweight = inv1 + inv2 * (obj2id > -1)
return rows(
obj2id, data, solref, solimp, invweight, jac1, jac2, pos1, pos2, active
)
def _efc_friction(m: Model, d: Data) -> Optional[_Efc]:
"""Calculates constraint rows for dof frictionloss."""
if not isinstance(m._impl, ModelJAX) or not isinstance(d._impl, DataJAX):
raise ValueError('_efc_friction requires JAX backend implementation.')
dof_id = np.nonzero(m._impl.dof_hasfrictionloss)[0]
tendon_id = np.nonzero(m._impl.tendon_hasfrictionloss)[0]
size = dof_id.size + tendon_id.size
if (m.opt.disableflags & DisableBit.FRICTIONLOSS) or (size == 0):
return None
args_dof = (jp.eye(m.nv), m.dof_frictionloss, m.dof_invweight0, m.dof_solref)
args_dof += (m.dof_solimp,)
args_dof = jax.tree_util.tree_map(lambda x: x[dof_id], args_dof)
args_ten = (d._impl.ten_J, m.tendon_frictionloss, m.tendon_invweight0)
args_ten += (m.tendon_solref_fri, m.tendon_solimp_fri)
args_ten = jax.tree_util.tree_map(lambda x: x[tendon_id], args_ten)
args = jax.tree_util.tree_map(
lambda *x: jp.concatenate(x), args_dof, args_ten
)
@jax.vmap
def rows(j, frictionloss, invweight, solref, solimp):
z = jp.zeros_like(frictionloss)
return _row(j, z, z, invweight, solref, solimp, z, frictionloss)
return rows(*args)
def _efc_limit_ball(m: Model, d: Data) -> Optional[_Efc]:
"""Calculates constraint rows for ball joint limits."""
jnt_id = np.nonzero((m.jnt_type == JointType.BALL) & m.jnt_limited)[0]
if (m.opt.disableflags & DisableBit.LIMIT) or jnt_id.size == 0:
return None
@jax.vmap
def rows(qposadr, dofadr, jnt_range, jnt_margin, solref, solimp):
axis, angle = math.quat_to_axis_angle(d.qpos[jp.arange(4) + qposadr])
# ball rotation angle is always positive
axis, angle = math.normalize_with_norm(axis * angle)
pos = jp.amax(jnt_range) - angle - jnt_margin
active = pos < 0
j = jp.zeros(m.nv).at[jp.arange(3) + dofadr].set(-axis)
invweight = m.dof_invweight0[dofadr]
z = jp.zeros_like(pos)
return _row(
j * active, pos * active, pos, invweight, solref, solimp, jnt_margin, z
)
args = (m.jnt_qposadr, m.jnt_dofadr, m.jnt_range, m.jnt_margin, m.jnt_solref)
args += (m.jnt_solimp,)
args = jax.tree_util.tree_map(lambda x: x[jnt_id], args)
return rows(*args)
def _efc_limit_slide_hinge(m: Model, d: Data) -> Optional[_Efc]:
"""Calculates constraint rows for slide and hinge joint limits."""
slide_hinge = np.isin(m.jnt_type, (JointType.SLIDE, JointType.HINGE))
jnt_id = np.nonzero(slide_hinge & m.jnt_limited)[0]
if (m.opt.disableflags & DisableBit.LIMIT) or jnt_id.size == 0:
return None
@jax.vmap
def rows(qposadr, dofadr, jnt_range, jnt_margin, solref, solimp):
qpos = d.qpos[qposadr]
dist_min, dist_max = qpos - jnt_range[0], jnt_range[1] - qpos
pos = jp.minimum(dist_min, dist_max) - jnt_margin
active = pos < 0
j = jp.zeros(m.nv).at[dofadr].set((dist_min < dist_max) * 2 - 1)
invweight = m.dof_invweight0[dofadr]
z = jp.zeros_like(pos)
return _row(
j * active, pos * active, pos, invweight, solref, solimp, jnt_margin, z
)
args = (m.jnt_qposadr, m.jnt_dofadr, m.jnt_range, m.jnt_margin, m.jnt_solref)
args += (m.jnt_solimp,)
args = jax.tree_util.tree_map(lambda x: x[jnt_id], args)
return rows(*args)
def _efc_limit_tendon(m: Model, d: Data) -> Optional[_Efc]:
"""Calculates constraint rows for tendon limits."""
if not isinstance(m._impl, ModelJAX) or not isinstance(d._impl, DataJAX):
raise ValueError('_efc_limit_tendon requires JAX backend implementation.')
tendon_id = np.nonzero(m.tendon_limited)[0]
if (m.opt.disableflags & DisableBit.LIMIT) or tendon_id.size == 0:
return None
length, j, range_, margin, invweight, solref, solimp = jax.tree_util.tree_map(
lambda x: x[tendon_id],
(
d.ten_length,
d._impl.ten_J,
m.tendon_range,
m.tendon_margin,
m.tendon_invweight0,
m.tendon_solref_lim,
m.tendon_solimp_lim,
),
)
dist_min, dist_max = length - range_[:, 0], range_[:, 1] - length
pos = jp.minimum(dist_min, dist_max) - margin
active = pos < 0
j = jax.vmap(jp.multiply)(j, ((dist_min < dist_max) * 2 - 1) * active)
zero = jp.zeros_like(pos)
return jax.vmap(_row)(
j, pos * active, pos, invweight, solref, solimp, margin, zero
)
def _efc_contact_frictionless(m: Model, d: Data) -> Optional[_Efc]:
"""Calculates constraint rows for frictionless contacts."""
if not isinstance(m._impl, ModelJAX) or not isinstance(d._impl, DataJAX):
raise ValueError(
'_efc_contact_frictionless requires JAX backend implementation.'
)
con_id = np.nonzero(d._impl.contact.dim == 1)[0]
if con_id.size == 0:
return None
@jax.vmap
def rows(c: Contact):
pos = c.dist - c.includemargin
active = pos < 0
body1, body2 = jp.array(m.geom_bodyid)[c.geom]
jac1p, _ = support.jac(m, d, c.pos, body1)
jac2p, _ = support.jac(m, d, c.pos, body2)
j = (c.frame @ (jac2p - jac1p).T)[0]
invweight = m.body_invweight0[body1, 0] + m.body_invweight0[body2, 0]
return _row(
j * active,
pos * active,
pos,
invweight,
c.solref,
c.solimp,
c.includemargin,
jp.zeros_like(pos),
)
contact = jax.tree_util.tree_map(lambda x: x[con_id], d._impl.contact)
return rows(contact)
def _efc_contact_pyramidal(m: Model, d: Data, condim: int) -> Optional[_Efc]:
"""Calculates constraint rows for frictional pyramidal contacts."""
if (
not isinstance(m._impl, ModelJAX)
or not isinstance(d._impl, DataJAX)
or not isinstance(m.opt._impl, OptionJAX)
):
raise ValueError(
'_efc_contact_pyramidal requires JAX backend implementation.'
)
con_id = np.nonzero(d._impl.contact.dim == condim)[0]
if con_id.size == 0:
return None
@jax.vmap
def rows(c: Contact):
pos = c.dist - c.includemargin
active = pos < 0
body1, body2 = jp.array(m.geom_bodyid)[c.geom]
jac1p, jac1r = support.jac(m, d, c.pos, body1)
jac2p, jac2r = support.jac(m, d, c.pos, body2)
diff = c.frame @ (jac2p - jac1p).T
if condim > 3:
diff = jp.concatenate((diff, (c.frame @ (jac2r - jac1r).T)), axis=0)
# a pair of opposing pyramid edges per friction dimension
# repeat friction directions with positive and negative sign
fri = jp.repeat(c.friction[: condim - 1], 2, axis=0).at[1::2].mul(-1)
# repeat condims of jacdiff to match +/- friction directions
j = diff[0] + jp.repeat(diff[1:condim], 2, axis=0) * fri[:, None]
# pyramidal has common invweight across all edges
invweight = m.body_invweight0[body1, 0] + m.body_invweight0[body2, 0]
invweight = invweight + fri[0] * fri[0] * invweight
invweight = invweight * 2 * fri[0] * fri[0] / m.opt.impratio
return _row(
j * active,
pos * active,
pos,
invweight,
c.solref,
c.solimp,
c.includemargin,
jp.zeros_like(pos),
)
contact = jax.tree_util.tree_map(lambda x: x[con_id], d._impl.contact)
# concatenate to drop row grouping
return jax.tree_util.tree_map(jp.concatenate, rows(contact))
def _efc_contact_elliptic(m: Model, d: Data, condim: int) -> Optional[_Efc]:
"""Calculates constraint rows for frictional elliptic contacts."""
if (
not isinstance(m._impl, ModelJAX)
or not isinstance(d._impl, DataJAX)
or not isinstance(m.opt._impl, OptionJAX)
):
raise ValueError(
'_efc_contact_elliptic requires JAX backend implementation.'
)
con_id = np.nonzero(d._impl.contact.dim == condim)[0]
if con_id.size == 0:
return None
@jax.vmap
def rows(c: Contact):
pos = c.dist - c.includemargin
active = pos < 0
obj1id, obj2id = jp.array(m.geom_bodyid)[c.geom]
jac1p, jac1r = support.jac(m, d, c.pos, obj1id)
jac2p, jac2r = support.jac(m, d, c.pos, obj2id)
j = c.frame @ (jac2p - jac1p).T
if condim > 3:
j = jp.concatenate((j, (c.frame @ (jac2r - jac1r).T)[: condim - 3]))
invweight = m.body_invweight0[obj1id, 0] + m.body_invweight0[obj2id, 0]
# normal row comes from solref, remaining rows from solreffriction
solreffriction = c.solreffriction + c.solref * ~c.solreffriction.any()
solreffriction = jp.tile(solreffriction, (condim - 1, 1))
solref = jp.concatenate((c.solref[None], solreffriction))
fri = jp.square(c.friction[0]) / jp.square(c.friction[1 : condim - 1])
invweight = jp.array([invweight, invweight / m.opt.impratio])
invweight = jp.concatenate((invweight, invweight[1] * fri))
pos_aref = jp.zeros(condim).at[0].set(pos)
return _row(
j * active,
pos_aref * active,
pos,
invweight,
solref,
c.solimp,
c.includemargin,
jp.zeros_like(pos),
)
contact = jax.tree_util.tree_map(lambda x: x[con_id], d._impl.contact)
# concatenate to drop row grouping
return jax.tree_util.tree_map(jp.concatenate, rows(contact))
def counts(efc_type: np.ndarray) -> Tuple[int, int, int, int]:
"""Returns equality, friction, limit, and contact constraint counts."""
ne = (efc_type == ConstraintType.EQUALITY).sum()
nf = (efc_type == ConstraintType.FRICTION_DOF).sum()
nf += (efc_type == ConstraintType.FRICTION_TENDON).sum()
nl = (efc_type == ConstraintType.LIMIT_JOINT).sum()
nl += (efc_type == ConstraintType.LIMIT_TENDON).sum()
nc_f = (efc_type == ConstraintType.CONTACT_FRICTIONLESS).sum()
nc_p = (efc_type == ConstraintType.CONTACT_PYRAMIDAL).sum()
nc_e = (efc_type == ConstraintType.CONTACT_ELLIPTIC).sum()
nc = nc_f + nc_p + nc_e
return ne, nf, nl, nc
def make_efc_type(
m: Union[Model, mujoco.MjModel], dim: Optional[np.ndarray] = None
) -> np.ndarray:
"""Returns efc_type that outlines the type of each constraint row."""
if m.opt.disableflags & DisableBit.CONSTRAINT:
return np.empty(0, dtype=int)
dim = collision_driver.make_condim(m) if dim is None else dim
efc_types = []
if not m.opt.disableflags & DisableBit.EQUALITY:
num_rows = (m.eq_type == EqType.CONNECT).sum() * 3
num_rows += (m.eq_type == EqType.WELD).sum() * 6
num_rows += (m.eq_type == EqType.JOINT).sum()
num_rows += (m.eq_type == EqType.TENDON).sum()
efc_types += [ConstraintType.EQUALITY] * num_rows
if not m.opt.disableflags & DisableBit.FRICTIONLOSS:
nf_dof = (
m._impl.dof_hasfrictionloss.sum()
if isinstance(m, Model) and isinstance(m._impl, ModelJAX)
else (m.dof_frictionloss > 0).sum()
)
efc_types += [ConstraintType.FRICTION_DOF] * nf_dof
nf_tendon = (
m._impl.tendon_hasfrictionloss.sum()
if isinstance(m, Model) and isinstance(m._impl, ModelJAX)
else (m.tendon_frictionloss > 0).sum()
)
efc_types += [ConstraintType.FRICTION_TENDON] * nf_tendon
if not m.opt.disableflags & DisableBit.LIMIT:
efc_types += [ConstraintType.LIMIT_JOINT] * m.jnt_limited.sum()
efc_types += [ConstraintType.LIMIT_TENDON] * m.tendon_limited.sum()
if not m.opt.disableflags & DisableBit.CONTACT:
for condim in (1, 3, 4, 6):
n = (dim == condim).sum()
if condim == 1:
efc_types += [ConstraintType.CONTACT_FRICTIONLESS] * n
elif m.opt.cone == ConeType.PYRAMIDAL:
efc_types += [ConstraintType.CONTACT_PYRAMIDAL] * (condim - 1) * 2 * n
elif m.opt.cone == ConeType.ELLIPTIC:
efc_types += [ConstraintType.CONTACT_ELLIPTIC] * condim * n
else:
raise ValueError(f'Unknown cone: {m.opt.cone}')
return np.array(efc_types)
def make_efc_address(
m: Union[Model, mujoco.MjModel], dim: np.ndarray, efc_type: np.ndarray
) -> np.ndarray:
"""Returns efc_address that maps contacts to constraint row address."""
offsets = np.array([0], dtype=int)
for condim in (1, 3, 4, 6):
n = (dim == condim).sum()
if n == 0:
continue
if condim == 1:
offsets = np.concatenate((offsets, [1] * n))
elif m.opt.cone == ConeType.PYRAMIDAL:
offsets = np.concatenate((offsets, [(condim - 1) * 2] * n))
elif m.opt.cone == ConeType.ELLIPTIC:
offsets = np.concatenate((offsets, [condim] * n))
else:
raise ValueError(f'Unknown cone: {m.opt.cone}')
_, _, _, nc = counts(efc_type)
address = efc_type.size - nc + np.cumsum(offsets)[:-1]
return address
[docs]
def make_constraint(m: Model, d: Data) -> Data:
"""Creates constraint jacobians and other supporting data."""
if m.opt.disableflags & DisableBit.CONSTRAINT:
efcs = ()
else:
efcs = (
_efc_equality_connect(m, d),
_efc_equality_weld(m, d),
_efc_equality_joint(m, d),
_efc_equality_tendon(m, d),
_efc_friction(m, d),
_efc_limit_ball(m, d),
_efc_limit_slide_hinge(m, d),
_efc_limit_tendon(m, d),
_efc_contact_frictionless(m, d),
)
if m.opt.cone == ConeType.ELLIPTIC:
con_fn = _efc_contact_elliptic
else:
con_fn = _efc_contact_pyramidal
efcs += tuple(con_fn(m, d, dim) for dim in (3, 4, 6))
efcs = tuple(efc for efc in efcs if efc is not None)
if not efcs:
z = jp.empty(0)
d = d.tree_replace({'_impl.efc_J': jp.empty((0, m.nv))})
d = d.tree_replace({
'_impl.efc_D': z,
'_impl.efc_aref': z,
'_impl.efc_frictionloss': z,
'_impl.efc_pos': z,
'_impl.efc_margin': z,
})
return d
efc = jax.tree_util.tree_map(lambda *x: jp.concatenate(x), *efcs)
@jax.vmap
def fn(efc):
k, b, imp = _kbi(m, efc.solref, efc.solimp, efc.pos_imp)
r = jp.maximum(efc.invweight * (1 - imp) / imp, mujoco.mjMINVAL)
aref = -b * (efc.J @ d.qvel) - k * imp * efc.pos_aref
return aref, r, efc.pos_aref + efc.margin, efc.margin, efc.frictionloss
aref, r, pos, margin, frictionloss = fn(efc)
d = d.tree_replace({
'_impl.efc_J': efc.J,
'_impl.efc_D': 1 / r,
'_impl.efc_aref': aref,
'_impl.efc_pos': pos,
'_impl.efc_margin': margin,
})
d = d.tree_replace({'_impl.efc_frictionloss': frictionloss})
return d