# Copyright 2025 The Newton Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from typing import Optional, Tuple
import warp as wp
from mujoco_warp._src.math import motion_cross
from mujoco_warp._src.types import ConeType
from mujoco_warp._src.types import Data
from mujoco_warp._src.types import JointType
from mujoco_warp._src.types import Model
from mujoco_warp._src.types import State
from mujoco_warp._src.types import vec5
from mujoco_warp._src.warp_util import cache_kernel
from mujoco_warp._src.warp_util import event_scope
# Module-wide Warp option: do not generate backward (adjoint) code for the
# kernels in this module.
wp.set_module_options({"enable_backward": False})
@cache_kernel
def mul_m_sparse(check_skip: bool):
    """Build the sparse ``M @ vec`` kernel.

    Each thread owns one (world, dof) output entry and gathers its row from a
    row-pointer layout (entries ``qM_mulm_rowadr[row] .. qM_mulm_rowadr[row + 1]``),
    so no atomics are required.

    Args:
      check_skip: When True, the generated kernel returns early (leaving ``res``
        untouched) for worlds whose ``skip`` flag is set.
    """

    @wp.kernel(module="unique")
    def _mul_m_sparse(
        # Model:
        qM_mulm_rowadr: wp.array(dtype=int),
        qM_mulm_col: wp.array(dtype=int),
        qM_mulm_madr: wp.array(dtype=int),
        # Data in:
        qM_in: wp.array3d(dtype=float),
        # In:
        vec: wp.array2d(dtype=float),
        skip: wp.array(dtype=bool),
        # Out:
        res: wp.array2d(dtype=float),
    ):
        """Sparse matmul: one thread per DOF, gather-based (no atomics)."""
        worldid, rowid = wp.tid()

        if wp.static(check_skip):
            if skip[worldid]:
                return

        # Each entry holds a column index and an address into the packed qM storage.
        total = float(0.0)
        for entry in range(qM_mulm_rowadr[rowid], qM_mulm_rowadr[rowid + 1]):
            total += qM_in[worldid, 0, qM_mulm_madr[entry]] * vec[worldid, qM_mulm_col[entry]]

        res[worldid, rowid] = total

    return _mul_m_sparse
@cache_kernel
def mul_m_dense(nv: int, check_skip: bool):
    """Simple SIMT dense matmul: one thread per output element.

    Args:
      nv: Number of dofs; baked into the kernel as a compile-time loop bound.
      check_skip: When True, the generated kernel returns early (leaving ``res``
        untouched) for worlds whose ``skip`` flag is set.
    """

    @wp.kernel(module="unique")
    def _mul_m_dense(
        # Data in:
        qM_in: wp.array3d(dtype=float),
        # In:
        vec: wp.array2d(dtype=float),
        skip: wp.array(dtype=bool),
        # Out:
        res: wp.array2d(dtype=float),
    ):
        worldid, row = wp.tid()

        if wp.static(check_skip):
            if skip[worldid]:
                return

        # Dot product of row `row` of this world's qM with the world's vec.
        total = float(0.0)
        for col in range(wp.static(nv)):
            total += qM_in[worldid, row, col] * vec[worldid, col]

        res[worldid, row] = total

    return _mul_m_dense
[docs]
@event_scope
def mul_m(
    m: Model,
    d: Data,
    res: wp.array2d(dtype=float),
    vec: wp.array2d(dtype=float),
    skip: Optional[wp.array] = None,
    M: Optional[wp.array] = None,
):
    """Multiply vectors by inertia matrix; optionally skip per world.

    Dispatches to the sparse or dense kernel depending on ``m.is_sparse``.

    Args:
      m: The model containing kinematic and dynamic information (device).
      d: The data object containing the current state and output arrays (device).
      res: Result: qM @ vec.
      vec: Input vector to multiply by qM.
      skip: Per-world bitmask to skip computing output; ``res`` rows of
        skipped worlds are left unchanged. When None, every world is computed.
      M: Input matrix: M @ vec. Defaults to ``d.qM`` when None.
    """
    check_skip = skip is not None
    # The kernels always take a skip array; pass an empty placeholder if unused.
    skip_arr = skip or wp.empty(0, dtype=bool)
    matrix = d.qM if M is None else M

    if m.is_sparse:
        kernel = mul_m_sparse(check_skip)
        inputs = [m.qM_mulm_rowadr, m.qM_mulm_col, m.qM_mulm_madr, matrix, vec, skip_arr]
    else:
        kernel = mul_m_dense(m.nv, check_skip)
        inputs = [matrix, vec, skip_arr]

    wp.launch(kernel, dim=(d.nworld, m.nv), inputs=inputs, outputs=[res])
@wp.kernel
def _apply_ft(
    # Model:
    nbody: int,
    body_parentid: wp.array(dtype=int),
    body_rootid: wp.array(dtype=int),
    dof_bodyid: wp.array(dtype=int),
    # Data in:
    xipos_in: wp.array2d(dtype=wp.vec3),
    subtree_com_in: wp.array2d(dtype=wp.vec3),
    cdof_in: wp.array2d(dtype=wp.spatial_vector),
    # In:
    ft_in: wp.array2d(dtype=wp.spatial_vector),
    flg_add: bool,
    # Out:
    qfrc_out: wp.array2d(dtype=float),
):
    """Project per-body spatial forces onto one dof; one thread per (world, dof).

    For every body in the subtree of the dof's body whose ``ft_in`` entry is
    nonzero, accumulates the Jacobian-transpose product evaluated at that
    body's ``xipos``. The top 3 components of ``ft_in`` are paired with the
    translational Jacobian and the bottom 3 with the rotational one, so
    ``ft_in`` is presumably laid out as (force, torque) -- TODO confirm.
    The sum overwrites ``qfrc_out[worldid, dofid]``, or is added to it when
    ``flg_add`` is True.
    """
    worldid, dofid = wp.tid()
    cdof = cdof_in[worldid, dofid]
    # top 3 entries of cdof are the angular (rotational) part
    rotational_cdof = wp.vec3(cdof[0], cdof[1], cdof[2])
    # cdof reordered to (linear, angular) so one dot with ft_body yields
    # dot(cdof_lin, top(ft)) + dot(cdof_ang, bottom(ft))
    jac = wp.spatial_vector(cdof[3], cdof[4], cdof[5], cdof[0], cdof[1], cdof[2])
    dofbodyid = dof_bodyid[dofid]
    accumul = float(0.0)
    # NOTE(review): starting at dofbodyid assumes descendants always have larger
    # body ids than their ancestors -- confirm against model compilation.
    for bodyid in range(dofbodyid, nbody):
        ft_body = ft_in[worldid, bodyid]
        # bodies with an all-zero spatial force contribute nothing
        if ft_body == wp.spatial_vector():
            continue
        # any body that is in the subtree of dofbodyid is part of the jacobian
        parentid = bodyid
        while parentid != 0 and parentid != dofbodyid:
            parentid = body_parentid[parentid]
        if parentid == 0:
            continue  # body is not part of the subtree
        # lever arm from the root's subtree com to the body's xipos
        offset = xipos_in[worldid, bodyid] - subtree_com_in[worldid, body_rootid[bodyid]]
        # translational Jacobian = cdof_lin + cross(cdof_ang, offset); the
        # cdof_lin term is already included in dot(jac, ft_body)
        cross_term = wp.cross(rotational_cdof, offset)
        accumul += wp.dot(jac, ft_body) + wp.dot(cross_term, wp.spatial_top(ft_body))
    if flg_add:
        qfrc_out[worldid, dofid] += accumul
    else:
        qfrc_out[worldid, dofid] = accumul
def apply_ft(m: Model, d: Data, ft: wp.array2d(dtype=wp.spatial_vector), qfrc: wp.array2d(dtype=float), flg_add: bool):
    """Map per-body spatial forces ``ft`` to generalized forces in ``qfrc``.

    Launches one thread per (world, dof). When ``flg_add`` is True the result
    is added to ``qfrc``; otherwise ``qfrc`` is overwritten.
    """
    model_args = [m.nbody, m.body_parentid, m.body_rootid, m.dof_bodyid]
    data_args = [d.xipos, d.subtree_com, d.cdof]
    wp.launch(
        kernel=_apply_ft,
        dim=(d.nworld, m.nv),
        inputs=model_args + data_args + [ft, flg_add],
        outputs=[qfrc],
    )
[docs]
@event_scope
def xfrc_accumulate(m: Model, d: Data, qfrc: wp.array2d(dtype=float)):
    """Map applied forces at each body via Jacobians to dof space and accumulate.

    The contribution of ``d.xfrc_applied`` is added to the existing contents
    of ``qfrc`` (it is not overwritten).

    Args:
      m: The model containing kinematic and dynamic information (device).
      d: The data object containing the current state and output arrays (device).
      qfrc: Total applied force mapped to dof space.
    """
    apply_ft(m, d, d.xfrc_applied, qfrc, flg_add=True)
@wp.func
def _decode_pyramid(
    njmax_in: int, pyramid: wp.array(dtype=float), efc_address: int, mu: vec5, condim: int
) -> wp.spatial_vector:
    """Converts pyramid representation to contact force.

    Args:
      njmax_in: Exclusive upper bound on indices into ``pyramid``; inside the
        friction loop, rows at or beyond it are read as 0.
      pyramid: Per-row constraint forces for one world (the caller in this
        file passes ``efc_force_in[worldid]``).
      efc_address: Index of this contact's first pyramid row.
      mu: Friction coefficients, one per tangent direction.
      condim: Contact dimensionality.

    Returns:
      Spatial vector whose component 0 is the normal force and components
      1..condim-1 are the tangential forces; other components keep the value
      from ``wp.spatial_vector()``.
    """
    force = wp.spatial_vector()
    if condim == 1:
        # frictionless: a single row carries the normal force
        # NOTE(review): unlike the loop below, this read is not checked against njmax_in
        force[0] = pyramid[efc_address]
        return force
    force[0] = float(0.0)
    for i in range(condim - 1):
        # tangent direction i owns the pair of rows (2*i, 2*i + 1)
        adr = 2 * i + efc_address
        if adr < njmax_in:
            dir1 = pyramid[adr]
        else:
            dir1 = 0.0
        if adr + 1 < njmax_in:
            dir2 = pyramid[adr + 1]
        else:
            dir2 = 0.0
        # normal force = sum of all pyramid rows
        force[0] += dir1 + dir2
        # tangent force i = (difference of the pair) * mu[i]
        force[i + 1] = (dir1 - dir2) * mu[i]
    return force
@wp.func
def contact_force_fn(
    # Model:
    opt_cone: int,
    # Data in:
    contact_frame_in: wp.array(dtype=wp.mat33),
    contact_friction_in: wp.array(dtype=vec5),
    contact_dim_in: wp.array(dtype=int),
    contact_efc_address_in: wp.array2d(dtype=int),
    efc_force_in: wp.array2d(dtype=float),
    njmax_in: int,
    nacon_in: wp.array(dtype=int),
    # In:
    worldid: int,
    contact_id: int,
    to_world_frame: bool,
) -> wp.spatial_vector:
    """Extract 6D force:torque for one contact, in contact frame by default.

    Args:
      opt_cone: Friction cone type; pyramidal cones are decoded with
        ``_decode_pyramid``, otherwise the efc rows are copied directly.
      worldid: World whose ``efc_force_in`` row is read.
      contact_id: Contact index; must satisfy ``0 <= contact_id < nacon_in[0]``.
      to_world_frame: If True, rotate both 3D halves out of the contact frame
        (row-vector times ``contact_frame_in[contact_id]``).

    Returns:
      The contact's force:torque, or all zeros if ``contact_id`` is out of
      range or the contact has no efc rows (first efc address < 0).
    """
    force = wp.spatial_vector(0.0, 0.0, 0.0, 0.0, 0.0, 0.0)

    # Validate the id before indexing any per-contact array. Valid ids are
    # 0..nacon-1 (the previous `<=` accepted one past the end).
    if contact_id < 0 or contact_id >= nacon_in[0]:
        return force

    efc_address = contact_efc_address_in[contact_id, 0]
    if efc_address < 0:
        # contact has no constraint rows
        return force

    condim = contact_dim_in[contact_id]
    if opt_cone == ConeType.PYRAMIDAL:
        force = _decode_pyramid(
            njmax_in,
            efc_force_in[worldid],
            efc_address,
            contact_friction_in[contact_id],
            condim,
        )
    else:
        # elliptic: one efc row per contact dimension
        for i in range(condim):
            adr = contact_efc_address_in[contact_id, i]
            if adr < njmax_in:
                force[i] = efc_force_in[worldid, adr]

    if to_world_frame:
        # Transform both top and bottom parts of spatial vector by the full contact frame matrix
        frame = contact_frame_in[contact_id]
        t = wp.spatial_top(force) @ frame
        b = wp.spatial_bottom(force) @ frame
        force = wp.spatial_vector(t, b)

    return force
@wp.kernel
def contact_force_kernel(
    # Model:
    opt_cone: int,
    # Data in:
    contact_frame_in: wp.array(dtype=wp.mat33),
    contact_friction_in: wp.array(dtype=vec5),
    contact_dim_in: wp.array(dtype=int),
    contact_efc_address_in: wp.array2d(dtype=int),
    contact_worldid_in: wp.array(dtype=int),
    efc_force_in: wp.array2d(dtype=float),
    njmax_in: int,
    nacon_in: wp.array(dtype=int),
    # In:
    contact_ids: wp.array(dtype=int),
    to_world_frame: bool,
    # Out:
    out: wp.array(dtype=wp.spatial_vector),
):
    """Write the 6D force of each requested contact into ``out``.

    One thread per entry of ``contact_ids``. Ids at or beyond ``nacon_in[0]``
    are skipped and their ``out`` slot is left untouched.
    """
    tid = wp.tid()
    cid = contact_ids[tid]

    # Ignore ids that do not refer to an active contact.
    if cid >= nacon_in[0]:
        return

    out[tid] = contact_force_fn(
        opt_cone,
        contact_frame_in,
        contact_friction_in,
        contact_dim_in,
        contact_efc_address_in,
        efc_force_in,
        njmax_in,
        nacon_in,
        contact_worldid_in[cid],
        cid,
        to_world_frame,
    )
@wp.func
def transform_force(force: wp.vec3, torque: wp.vec3, offset: wp.vec3) -> wp.spatial_vector:
    """Shift a force/torque pair by ``offset`` and pack it as (torque, force).

    The returned top half is ``torque - cross(offset, force)`` and the bottom
    half is ``force`` unchanged. Note the output ordering (torque first)
    differs from the (force, torque) input ordering of the spatial-vector
    overload. Presumably ``offset`` is new reference point minus old
    reference point -- confirm at call sites.
    """
    return wp.spatial_vector(torque - wp.cross(offset, force), force)
@wp.func
def transform_force(frc: wp.spatial_vector, offset: wp.vec3) -> wp.spatial_vector:
    """Overload taking ``frc`` as (force, torque): top half is force, bottom is torque.

    Returns the shifted result packed as (torque, force).
    """
    force = wp.spatial_top(frc)
    torque = wp.spatial_bottom(frc)
    return transform_force(force, torque, offset)
@wp.func
def jac_dof(
    # Model:
    body_parentid: wp.array(dtype=int),
    body_rootid: wp.array(dtype=int),
    dof_bodyid: wp.array(dtype=int),
    # Data in:
    subtree_com_in: wp.array2d(dtype=wp.vec3),
    cdof_in: wp.array2d(dtype=wp.spatial_vector),
    # In:
    point: wp.vec3,
    bodyid: int,
    dofid: int,
    worldid: int,
) -> Tuple[wp.vec3, wp.vec3]:
    """Column ``dofid`` of the translational and rotational Jacobian at ``point``.

    Returns:
      (jacp, jacr): both zero when the dof's body is not ``bodyid`` or one of
      its ancestors; otherwise jacp = cdof_lin + cross(cdof_ang, offset) and
      jacr = cdof_ang, with offset measured from the root's subtree com.
    """
    dof_bodyid_ = dof_bodyid[dofid]
    # dofs attached to body 0 affect every body
    in_tree = int(dof_bodyid_ == 0)
    # walk from bodyid up the parent chain looking for the dof's body
    parentid = bodyid
    while parentid != 0:
        if parentid == dof_bodyid_:
            in_tree = 1
            break
        parentid = body_parentid[parentid]
    if not in_tree:
        return wp.vec3(0.0), wp.vec3(0.0)
    offset = point - wp.vec3(subtree_com_in[worldid, body_rootid[bodyid]])
    cdof = cdof_in[worldid, dofid]
    # cdof layout: top = angular, bottom = linear
    cdof_ang = wp.spatial_top(cdof)
    cdof_lin = wp.spatial_bottom(cdof)
    jacp = cdof_lin + wp.cross(cdof_ang, offset)
    jacr = cdof_ang
    return jacp, jacr
@cache_kernel
def _make_jac_kernel(has_jacp: bool, has_jacr: bool):
    """Build a Jacobian kernel specialized on which outputs are requested.

    Args:
      has_jacp: Emit writes to ``jacp_out`` (translational Jacobian).
      has_jacr: Emit writes to ``jacr_out`` (rotational Jacobian).
    """

    @wp.kernel(module="unique", enable_backward=False)
    def _jac(
        # Model:
        body_parentid: wp.array(dtype=int),
        body_rootid: wp.array(dtype=int),
        dof_bodyid: wp.array(dtype=int),
        # Data in:
        subtree_com_in: wp.array2d(dtype=wp.vec3),
        cdof_in: wp.array2d(dtype=wp.spatial_vector),
        # In:
        point_in: wp.array(dtype=wp.vec3),
        bodyid_in: wp.array(dtype=int),
        # Out:
        jacp_out: wp.array3d(dtype=float),
        jacr_out: wp.array3d(dtype=float),
    ):
        worldid, dofid = wp.tid()
        trans, rot = jac_dof(
            body_parentid,
            body_rootid,
            dof_bodyid,
            subtree_com_in,
            cdof_in,
            point_in[worldid],
            bodyid_in[worldid],
            dofid,
            worldid,
        )

        # Writes for unrequested outputs are compiled out via wp.static.
        if wp.static(has_jacp):
            for axis in range(3):
                jacp_out[worldid, axis, dofid] = trans[axis]
        if wp.static(has_jacr):
            for axis in range(3):
                jacr_out[worldid, axis, dofid] = rot[axis]

    return _jac
[docs]
@event_scope
def jac(
    m: Model,
    d: Data,
    jacp: wp.array | None,  # wp.array3d(dtype=float)
    jacr: wp.array | None,  # wp.array3d(dtype=float)
    point: wp.array(dtype=wp.vec3),
    body: wp.array(dtype=int),
):
    """Compute translational and rotational Jacobian for point on body.

    Args:
      m: The model containing kinematic and dynamic information (device).
      d: The data object containing the current state (device).
      jacp: Output translational Jacobian (optional).
      jacr: Output rotational Jacobian (optional).
      point: 3D point in global coordinates.
      body: Body ID for each world.
    """
    want_jacp = jacp is not None
    want_jacr = jacr is not None
    kernel = _make_jac_kernel(want_jacp, want_jacr)

    # Empty placeholders keep the launch signature fixed when an output is omitted.
    jacp_arr = jacp or wp.empty((0, 0, 0), dtype=float)
    jacr_arr = jacr or wp.empty((0, 0, 0), dtype=float)

    inputs = [m.body_parentid, m.body_rootid, m.dof_bodyid, d.subtree_com, d.cdof, point, body]
    wp.launch(kernel, dim=(d.nworld, m.nv), inputs=inputs, outputs=[jacp_arr, jacr_arr])
@wp.func
def jac_dot_dof(
    # Model:
    body_parentid: wp.array(dtype=int),
    body_rootid: wp.array(dtype=int),
    jnt_type: wp.array(dtype=int),
    jnt_dofadr: wp.array(dtype=int),
    dof_bodyid: wp.array(dtype=int),
    dof_jntid: wp.array(dtype=int),
    # Data in:
    subtree_com_in: wp.array2d(dtype=wp.vec3),
    cdof_in: wp.array2d(dtype=wp.spatial_vector),
    cvel_in: wp.array2d(dtype=wp.spatial_vector),
    cdof_dot_in: wp.array2d(dtype=wp.spatial_vector),
    # In:
    point: wp.vec3,
    bodyid: int,
    dofid: int,
    worldid: int,
) -> Tuple[wp.vec3, wp.vec3]:
    """Column ``dofid`` of the Jacobian time-derivative at ``point`` on ``bodyid``.

    Returns:
      (jacp_dot, jacr_dot): both zero when the dof's body is not ``bodyid`` or
      one of its ancestors.
    """
    dof_bodyid_ = dof_bodyid[dofid]
    # dofs attached to body 0 affect every body
    in_tree = int(dof_bodyid_ == 0)
    # walk from bodyid up the parent chain looking for the dof's body
    parentid = bodyid
    while parentid != 0:
        if parentid == dof_bodyid_:
            in_tree = 1
            break
        parentid = body_parentid[parentid]
    if not in_tree:
        return wp.vec3(0.0), wp.vec3(0.0)
    com = subtree_com_in[worldid, body_rootid[bodyid]]
    offset = point - com
    # transform spatial
    # linear velocity of `point`: cvel_lin - cross(offset, cvel_ang)
    cvel = cvel_in[worldid, bodyid]
    pvel_lin = wp.spatial_bottom(cvel) - wp.cross(offset, wp.spatial_top(cvel))
    cdof = cdof_in[worldid, dofid]
    cdof_dot = cdof_dot_in[worldid, dofid]
    # check for quaternion
    # ball joints and the rotational dofs of free joints (dof index >= jnt_dofadr + 3)
    dofjntid = dof_jntid[dofid]
    jnttype = jnt_type[dofjntid]
    jntdofadr = jnt_dofadr[dofjntid]
    if (jnttype == JointType.BALL) or ((jnttype == JointType.FREE) and dofid >= jntdofadr + 3):
        # compute cdof_dot for quaternion (use current body cvel)
        cvel = cvel_in[worldid, dof_bodyid[dofid]]
        cdof_dot = motion_cross(cvel, cdof)
    cdof_dot_ang = wp.spatial_top(cdof_dot)
    cdof_dot_lin = wp.spatial_bottom(cdof_dot)
    # construct translational Jacobian (correct for rotation)
    # first correction term, account for varying cdof
    correction1 = wp.cross(cdof_dot_ang, offset)
    # second correction term, account for point translational velocity
    correction2 = wp.cross(wp.spatial_top(cdof), pvel_lin)
    jacp = cdof_dot_lin + correction1 + correction2
    jacr = cdof_dot_ang
    return jacp, jacr
[docs]
def get_state(m: Model, d: Data, state: wp.array2d(dtype=float), sig: int, active: Optional[wp.array] = None):
    """Copy concatenated state components specified by sig from Data into state.

    The bits of the integer sig correspond to element fields of State.
    Selected components are packed contiguously into each world's row of
    ``state`` in increasing bit position. Per-component sizes:
    TIME 1, QPOS nq, QVEL nv, ACT na, WARMSTART nv, CTRL nu, QFRC_APPLIED nv,
    XFRC_APPLIED 6 * nbody, EQ_ACTIVE neq, MOCAP_POS 3 * nmocap,
    MOCAP_QUAT 4 * nmocap.

    Args:
      m: The model containing kinematic and dynamic information (device).
      d: The data object containing the current state and output information (device).
      state: Concatenation of state components.
      sig: Bitflag specifying state components.
      active: Per-world bitmask for getting state. Worlds whose flag is False
        are left untouched. When None, every world is processed.

    Raises:
      ValueError: If ``sig >= 2**State.NSTATE``.
    """
    if sig >= (1 << State.NSTATE):
        raise ValueError(f"invalid state signature {sig} >= 2^mjNSTATE")

    @wp.kernel(module="unique", enable_backward=False)
    def _get_state(
        # Model:
        nq: int,
        nv: int,
        nu: int,
        na: int,
        nbody: int,
        neq: int,
        nmocap: int,
        # Data in:
        time_in: wp.array(dtype=float),
        qpos_in: wp.array2d(dtype=float),
        qvel_in: wp.array2d(dtype=float),
        act_in: wp.array2d(dtype=float),
        qacc_warmstart_in: wp.array2d(dtype=float),
        ctrl_in: wp.array2d(dtype=float),
        qfrc_applied_in: wp.array2d(dtype=float),
        xfrc_applied_in: wp.array2d(dtype=wp.spatial_vector),
        eq_active_in: wp.array2d(dtype=bool),
        mocap_pos_in: wp.array2d(dtype=wp.vec3),
        mocap_quat_in: wp.array2d(dtype=wp.quat),
        # In:
        sig_in: int,
        active_in: wp.array(dtype=bool),
        # Out:
        state_out: wp.array2d(dtype=float),
    ):
        worldid = wp.tid()

        # active_in is only consulted when the caller supplied a mask.
        if wp.static(active is not None):
            if not active_in[worldid]:
                return

        # running write offset into this world's state row
        adr = int(0)
        for i in range(State.NSTATE.value):
            element = 1 << i
            if element & sig_in:
                if element == State.TIME:
                    state_out[worldid, adr] = time_in[worldid]
                    adr += 1
                elif element == State.QPOS:
                    for j in range(nq):
                        state_out[worldid, adr + j] = qpos_in[worldid, j]
                    adr += nq
                elif element == State.QVEL:
                    for j in range(nv):
                        state_out[worldid, adr + j] = qvel_in[worldid, j]
                    adr += nv
                elif element == State.ACT:
                    for j in range(na):
                        state_out[worldid, adr + j] = act_in[worldid, j]
                    adr += na
                elif element == State.WARMSTART:
                    for j in range(nv):
                        state_out[worldid, adr + j] = qacc_warmstart_in[worldid, j]
                    adr += nv
                elif element == State.CTRL:
                    for j in range(nu):
                        state_out[worldid, adr + j] = ctrl_in[worldid, j]
                    adr += nu
                elif element == State.QFRC_APPLIED:
                    for j in range(nv):
                        state_out[worldid, adr + j] = qfrc_applied_in[worldid, j]
                    adr += nv
                elif element == State.XFRC_APPLIED:
                    for j in range(nbody):
                        xfrc = xfrc_applied_in[worldid, j]
                        state_out[worldid, adr + 0] = xfrc[0]
                        state_out[worldid, adr + 1] = xfrc[1]
                        state_out[worldid, adr + 2] = xfrc[2]
                        state_out[worldid, adr + 3] = xfrc[3]
                        state_out[worldid, adr + 4] = xfrc[4]
                        state_out[worldid, adr + 5] = xfrc[5]
                        adr += 6
                elif element == State.EQ_ACTIVE:
                    for j in range(neq):
                        state_out[worldid, adr + j] = float(eq_active_in[worldid, j])
                    # Advance by the full component size. The previous `adr += j`
                    # was off by one and undefined when neq == 0.
                    adr += neq
                elif element == State.MOCAP_POS:
                    for j in range(nmocap):
                        pos = mocap_pos_in[worldid, j]
                        state_out[worldid, adr + 0] = pos[0]
                        state_out[worldid, adr + 1] = pos[1]
                        state_out[worldid, adr + 2] = pos[2]
                        adr += 3
                elif element == State.MOCAP_QUAT:
                    for j in range(nmocap):
                        quat = mocap_quat_in[worldid, j]
                        state_out[worldid, adr + 0] = quat[0]
                        state_out[worldid, adr + 1] = quat[1]
                        state_out[worldid, adr + 2] = quat[2]
                        state_out[worldid, adr + 3] = quat[3]
                        adr += 4

    wp.launch(
        _get_state,
        dim=d.nworld,
        inputs=[
            m.nq,
            m.nv,
            m.nu,
            m.na,
            m.nbody,
            m.neq,
            m.nmocap,
            d.time,
            d.qpos,
            d.qvel,
            d.act,
            d.qacc_warmstart,
            d.ctrl,
            d.qfrc_applied,
            d.xfrc_applied,
            d.eq_active,
            d.mocap_pos,
            d.mocap_quat,
            int(sig),
            active or wp.ones(d.nworld, dtype=bool),
        ],
        outputs=[state],
    )
[docs]
def set_state(m: Model, d: Data, state: wp.array2d(dtype=float), sig: int, active: Optional[wp.array] = None):
    """Copy concatenated state components specified by sig from state into Data.

    The bits of the integer sig correspond to element fields of State.
    Selected components are read contiguously from each world's row of
    ``state`` in increasing bit position, using the same packing as
    ``get_state``. Per-component sizes:
    TIME 1, QPOS nq, QVEL nv, ACT na, WARMSTART nv, CTRL nu, QFRC_APPLIED nv,
    XFRC_APPLIED 6 * nbody, EQ_ACTIVE neq, MOCAP_POS 3 * nmocap,
    MOCAP_QUAT 4 * nmocap.

    Args:
      m: The model containing kinematic and dynamic information (device).
      d: The data object containing the current state and output information (device).
      state: Concatenation of state components.
      sig: Bitflag specifying state components.
      active: Per-world bitmask for setting state. Worlds whose flag is False
        are left untouched. When None, every world is processed.

    Raises:
      ValueError: If ``sig >= 2**State.NSTATE``.
    """
    if sig >= (1 << State.NSTATE):
        raise ValueError(f"invalid state signature {sig} >= 2^mjNSTATE")

    @wp.kernel(module="unique", enable_backward=False)
    def _set_state(
        # Model:
        nq: int,
        nv: int,
        nu: int,
        na: int,
        nbody: int,
        neq: int,
        nmocap: int,
        # In:
        sig_in: int,
        active_in: wp.array(dtype=bool),
        state_in: wp.array2d(dtype=float),
        # Data out:
        time_out: wp.array(dtype=float),
        qpos_out: wp.array2d(dtype=float),
        qvel_out: wp.array2d(dtype=float),
        act_out: wp.array2d(dtype=float),
        qacc_warmstart_out: wp.array2d(dtype=float),
        ctrl_out: wp.array2d(dtype=float),
        qfrc_applied_out: wp.array2d(dtype=float),
        xfrc_applied_out: wp.array2d(dtype=wp.spatial_vector),
        eq_active_out: wp.array2d(dtype=bool),
        mocap_pos_out: wp.array2d(dtype=wp.vec3),
        mocap_quat_out: wp.array2d(dtype=wp.quat),
    ):
        worldid = wp.tid()

        # active_in is only consulted when the caller supplied a mask.
        if wp.static(active is not None):
            if not active_in[worldid]:
                return

        # running read offset into this world's state row
        adr = int(0)
        for i in range(State.NSTATE.value):
            element = 1 << i
            if element & sig_in:
                if element == State.TIME:
                    time_out[worldid] = state_in[worldid, adr]
                    adr += 1
                elif element == State.QPOS:
                    for j in range(nq):
                        qpos_out[worldid, j] = state_in[worldid, adr + j]
                    adr += nq
                elif element == State.QVEL:
                    for j in range(nv):
                        qvel_out[worldid, j] = state_in[worldid, adr + j]
                    adr += nv
                elif element == State.ACT:
                    for j in range(na):
                        act_out[worldid, j] = state_in[worldid, adr + j]
                    adr += na
                elif element == State.WARMSTART:
                    for j in range(nv):
                        qacc_warmstart_out[worldid, j] = state_in[worldid, adr + j]
                    adr += nv
                elif element == State.CTRL:
                    for j in range(nu):
                        ctrl_out[worldid, j] = state_in[worldid, adr + j]
                    adr += nu
                elif element == State.QFRC_APPLIED:
                    for j in range(nv):
                        qfrc_applied_out[worldid, j] = state_in[worldid, adr + j]
                    adr += nv
                elif element == State.XFRC_APPLIED:
                    for j in range(nbody):
                        xfrc = wp.spatial_vector(
                            state_in[worldid, adr + 0],
                            state_in[worldid, adr + 1],
                            state_in[worldid, adr + 2],
                            state_in[worldid, adr + 3],
                            state_in[worldid, adr + 4],
                            state_in[worldid, adr + 5],
                        )
                        xfrc_applied_out[worldid, j] = xfrc
                        adr += 6
                elif element == State.EQ_ACTIVE:
                    for j in range(neq):
                        eq_active_out[worldid, j] = bool(state_in[worldid, adr + j])
                    # Advance by the full component size. The previous `adr += j`
                    # was off by one and undefined when neq == 0.
                    adr += neq
                elif element == State.MOCAP_POS:
                    for j in range(nmocap):
                        # x, y, z in order. Previously x and y were swapped, which
                        # broke the get_state/set_state round trip.
                        pos = wp.vec3(
                            state_in[worldid, adr + 0],
                            state_in[worldid, adr + 1],
                            state_in[worldid, adr + 2],
                        )
                        mocap_pos_out[worldid, j] = pos
                        adr += 3
                elif element == State.MOCAP_QUAT:
                    for j in range(nmocap):
                        quat = wp.quat(
                            state_in[worldid, adr + 0],
                            state_in[worldid, adr + 1],
                            state_in[worldid, adr + 2],
                            state_in[worldid, adr + 3],
                        )
                        mocap_quat_out[worldid, j] = quat
                        adr += 4

    wp.launch(
        _set_state,
        dim=d.nworld,
        inputs=[
            m.nq,
            m.nv,
            m.nu,
            m.na,
            m.nbody,
            m.neq,
            m.nmocap,
            int(sig),
            active or wp.ones(d.nworld, dtype=bool),
            state,
        ],
        outputs=[
            d.time,
            d.qpos,
            d.qvel,
            d.act,
            d.qacc_warmstart,
            d.ctrl,
            d.qfrc_applied,
            d.xfrc_applied,
            d.eq_active,
            d.mocap_pos,
            d.mocap_quat,
        ],
    )