# Source code for mujoco.mjx._src.test_util

# Copyright 2023 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities for testing."""

import os
import sys
import time
from typing import Dict, Optional, Tuple
from xml.etree import ElementTree as ET

from etils import epath
import jax
import mujoco
# pylint: disable=g-importing-member
from mujoco.mjx._src import forward
from mujoco.mjx._src import io
from mujoco.mjx._src.types import Data
# pylint: enable=g-importing-member
import numpy as np


def _measure(fn, *args) -> Tuple[float, float]:
  """Times ahead-of-time compilation and one execution of a function.

  Args:
    fn: an object exposing `lower(*args).compile()` (e.g. a jitted or pmapped
      function); the compiled result must be callable with the same `args`.
    *args: positional arguments used both for lowering and for the timed call.

  Returns:
    A `(jit_time, run_time)` tuple in seconds: the wall-clock time spent
    lowering and compiling, and the wall-clock time of a single call to the
    compiled function including waiting for its result to be ready.
  """
  compile_start = time.perf_counter()
  executable = fn.lower(*args).compile()
  jit_time = time.perf_counter() - compile_start

  run_start = time.perf_counter()
  # block so that asynchronous dispatch does not hide the real execution time
  jax.block_until_ready(executable(*args))
  run_time = time.perf_counter() - run_start

  return jit_time, run_time


def benchmark(
    m: mujoco.MjModel,
    nstep: int = 1000,
    batch_size: int = 1024,
    unroll_steps: int = 1,
    solver: str = 'newton',
    iterations: int = 1,
    ls_iterations: int = 4,
) -> Tuple[float, float, int]:
  """Benchmark a model.

  Note that this mutates the passed-in model (`m.opt.solver`,
  `m.opt.iterations`, `m.opt.ls_iterations`) and the `XLA_FLAGS` environment
  variable of the current process.

  Args:
    m: the MuJoCo model to benchmark
    nstep: number of steps scanned over in the timed rollout
    batch_size: total batch size; each device receives
      `batch_size // jax.device_count()` randomly initialized data instances
    unroll_steps: value passed as `unroll` to `jax.lax.scan`
    solver: constraint solver to use, 'cg' or 'newton' (case-insensitive)
    iterations: value written to `m.opt.iterations`
    ls_iterations: value written to `m.opt.ls_iterations`

  Returns:
    A `(jit_time, run_time, steps)` tuple: compile time and run time in seconds
    as reported by `_measure`, and `steps = nstep * batch_size`.

  Raises:
    KeyError: if `solver` is neither 'cg' nor 'newton'.
  """
  # Append the flag only when it is not already present, so that repeated calls
  # in one process do not keep growing XLA_FLAGS with duplicates.
  triton_flag = '--xla_gpu_triton_gemm_any=True'
  xla_flags = os.environ.get('XLA_FLAGS', '')
  if triton_flag not in xla_flags:
    os.environ['XLA_FLAGS'] = f'{xla_flags} {triton_flag}'

  m.opt.solver = {
      'cg': mujoco.mjtSolver.mjSOL_CG,
      'newton': mujoco.mjtSolver.mjSOL_NEWTON,
  }[solver.lower()]
  m.opt.iterations = iterations
  m.opt.ls_iterations = ls_iterations
  m = io.put_model(m)

  @jax.pmap
  def init(key):
    # one key per batch element on this device
    key = jax.random.split(key, batch_size // jax.device_count())

    @jax.vmap
    def random_init(key):
      d = io.make_data(m)
      # small random initial velocities so batch elements differ
      qvel = 0.01 * jax.random.normal(key, shape=(m.nv,))
      d = d.replace(qvel=qvel)
      return d

    return random_init(key)

  key = jax.random.split(jax.random.key(0), jax.device_count())
  d = init(key)
  # make sure initialization is finished before anything is timed
  jax.block_until_ready(d)

  @jax.pmap
  def unroll(d):
    @jax.vmap
    def step(d, _):
      d = forward.step(m, d)
      return d, None

    d, _ = jax.lax.scan(step, d, None, length=nstep, unroll=unroll_steps)

    return d

  jit_time, run_time = _measure(unroll, d)
  steps = nstep * batch_size

  return jit_time, run_time, steps
def efc_order(m: mujoco.MjModel, d: mujoco.MjData, dx: Data) -> np.ndarray:
  """Returns a sort order such that dx.efc_*[order][:nefc] == d.efc_*.

  The first `ne + nf + nl` rows of `dx` keep their relative order, except that
  rows whose `efc_J` entries are all zero are moved to the end. Contact rows
  are then ordered by the index of the matching contact in `d.contact`
  (matched on identical `geom` and `pos`); contacts in `dx` with `dist > 0`
  are moved to the end.

  NOTE(review): the original docstring wrote the row count as
  `d._impl.nefc`; `d` is a `mujoco.MjData`, so this is presumably `d.nefc` —
  confirm.

  Args:
    m: the MuJoCo model; only `m.opt.cone` is read
    d: the MuJoCo data whose contact order is the reference
    dx: the MJX data whose constraint rows are to be reordered

  Returns:
    An index array from a stable argsort over the per-row sort keys.

  Raises:
    AssertionError: if an active contact of `dx` has no match in `d.contact`.
  """  # pytype: disable=attribute-error
  # reorder efc rows to skip inactive constraints and match contact order
  efl = dx._impl.ne + dx._impl.nf + dx._impl.nl  # pytype: disable=attribute-error
  order = np.arange(efl)
  order[(dx._impl.efc_J[:efl] == 0).all(axis=1)] = 2**16  # move empty rows to end  # pytype: disable=attribute-error
  for i in range(dx._impl.ncon):  # pytype: disable=attribute-error
    num_rows = dx._impl.contact.dim[i]  # pytype: disable=attribute-error
    if dx._impl.contact.dim[i] > 1 and m.opt.cone == mujoco.mjtCone.mjCONE_PYRAMIDAL:  # pytype: disable=attribute-error
      num_rows = (dx._impl.contact.dim[i] - 1) * 2  # pytype: disable=attribute-error
    if dx._impl.contact.dist[i] > 0:  # move empty contacts to end  # pytype: disable=attribute-error
      order = np.append(order, np.repeat(2**16, num_rows))
      continue
    contact_match = (d.contact.geom == dx._impl.contact.geom[i]).all(axis=-1)  # pytype: disable=attribute-error
    contact_match &= (d.contact.pos == dx._impl.contact.pos[i]).all(axis=-1)  # pytype: disable=attribute-error
    assert contact_match.any(), f'contact {i} not found'
    contact_id = np.nonzero(contact_match)[0][0]
    # every row of one contact shares a key; the stable sort keeps them together
    order = np.append(order, np.repeat(efl + contact_id, num_rows))

  return np.argsort(order, kind='stable')


# Value pools sampled by the random MJCF generator below.
_ACTUATOR_TYPES = ['motor', 'velocity', 'position', 'general', 'intvelocity']
_DYN_TYPES = ['none', 'integrator', 'filter', 'filterexact']
_DYN_PRMS = ['0.189', '2.1']
_JOINT_TYPES = ['free', 'hinge', 'slide', 'ball']
_JOINT_AXES = ['1 0 0', '0 1 0', '0 0 1']
_FRICTIONS = ['1.2 0.003 0.0002', '0.2 0.0001 0.0005']
_KP_POS = ['1', '2']
_KP_INTVEL = ['10000', '2000']
_KV_VEL = ['12', '1', '0', '0.1']
_PAIR_FRICTIONS = ['1.2 0.9 0.003 0.0002 0.0001']
_SOLREFS = ['0.04 1.01', '0.05 1.02', '0.03 1.1', '0.015 1.0']
_SOLIMPS = [
    '0.75 0.94 0.002 0.2 2',
    '0.8 0.99 0.001 0.3 6',
    '0.6 0.9 0.003 0.1 1',
]
_DIMS = ['3']
_MARGINS = ['0.0', '0.01', '0.02'] _GAPS = ['0.0', '0.005'] _GEARS = ['2.1 0.0 3.3 0 2.3 0', '5.0 3.1 0 2.3 0.0 1.1'] def p(pct: int) -> bool: assert 0 <= pct <= 100 return np.random.uniform(low=0, high=100) < pct def _make_joint(joint_type: str, name: str) -> Dict[str, str]: """Returns attributes for a joint.""" joint_attr = {'type': joint_type, 'name': name} if joint_type not in ('free', 'ball'): joint_attr['axis'] = np.random.choice(_JOINT_AXES) lb, ub = -np.random.uniform() * 90, np.random.uniform() * 90 joint_attr['range'] = f'{lb:.2f} {ub:.2f}' elif joint_type == 'ball': joint_attr['axis'] = '1 0 0' ub = np.random.uniform() * 90 joint_attr['range'] = f'0.0 {ub:.2f}' if p(50) and joint_type != 'free': lb, ub = -np.random.uniform(), np.random.uniform() joint_attr['actuatorfrcrange'] = f'{lb:.2f} {ub:.2f}' if joint_type not in ('free',): joint_attr['damping'] = '{:.2f}'.format(np.random.uniform() * 20) joint_attr['stiffness'] = '{:.2f}'.format(np.random.uniform() * 20) joint_attr['actuatorgravcomp'] = np.random.choice(['true', 'false']) return joint_attr def _geom_solparams( pair: bool = False, enable_contact: bool = True ) -> Dict[str, str]: """Returns geom solver parameters.""" params = { 'contype': np.random.choice(['0', '1']) if enable_contact else '0', 'conaffinity': np.random.choice(['0', '1']) if enable_contact else '0', 'priority': np.random.choice(['-1', '2']), 'solmix': np.random.choice(['0.0', '1.6']), 'friction': np.random.choice(_FRICTIONS), 'condim': np.random.choice(_DIMS), } pair_params = { 'solreffriction': np.random.choice(_SOLREFS), 'friction': np.random.choice(_PAIR_FRICTIONS), 'condim': np.random.choice(_DIMS), } params = pair_params if pair else params params.update({ 'solimp': np.random.choice(_SOLIMPS), 'solref': np.random.choice(_SOLREFS), 'margin': np.random.choice(_MARGINS), 'gap': np.random.choice(_GAPS), }) return params def _make_geom( pos: str, size: float, name: str, enable_contact: bool = True ) -> Dict[str, str]: """Returns 
attributes for a sphere geom.""" attr = { 'pos': pos, 'type': 'sphere', 'name': name, 'size': f'{size:.2f}', 'mass': '1', } attr.update(_geom_solparams(pair=False, enable_contact=enable_contact)) return attr def _make_actuator( actuator_type: str, joint: Optional[str] = None, site: Optional[str] = None, refsite: Optional[str] = None, ) -> Dict[str, str]: """Returns attributes for an actuator.""" if joint: attr = {'joint': joint} elif site: attr = {'site': site} else: raise ValueError('must provide a joint or site name') if refsite: attr['refsite'] = refsite attr['gear'] = np.random.choice(_GEARS) # set actuator type if actuator_type == 'position': attr['kp'] = np.random.choice(_KP_POS) attr['kv'] = np.random.choice(_KV_VEL) elif actuator_type == 'general': attr['biastype'] = 'affine' attr['gainprm'] = '35 0 0' attr['biasprm'] = '0 -35 -0.65' elif actuator_type == 'intvelocity': attr['kp'] = np.random.choice(_KP_INTVEL) lb, ub = -np.random.uniform(), np.random.uniform() attr['actrange'] = f'{lb:.2f} {ub:.2f}' elif actuator_type == 'velocity': attr['kv'] = np.random.choice(_KV_VEL) # set dyntype if actuator_type == 'general': attr['dyntype'] = np.random.choice(_DYN_TYPES) if attr['dyntype'] != 'none': attr['dynprm'] = np.random.choice(_DYN_PRMS) # ctrlrange if p(50) and actuator_type != 'intvelocity': lb, ub = -np.random.uniform(), np.random.uniform() attr['ctrlrange'] = f'{lb:.2f} {ub:.2f}' # forcerange if p(50): lb, ub = -np.random.uniform(), np.random.uniform() attr['forcerange'] = f'{lb*10:.2f} {ub*10:.2f}' return attr def create_mjcf( seed: int, min_trees: int = 1, max_trees: int = 1, max_tree_depth: int = 5, body_pos: Tuple[float, float, float] = (0.0, 0.0, -0.5), geom_pos: Tuple[float, float, float] = (0.0, 0.0, 0.0), max_stacked_joints=4, max_geoms_per_body=2, max_contact_excludes=1, max_contact_pairs=4, disable_actuation_pct: int = 0, add_actuators: bool = False, root_always_free: bool = False, enable_contact: bool = True, ) -> str: """Creates a random MJCF 
for testing. Args: seed: seed for rng min_trees: minimum number of kinematic trees to generate max_trees: maximum number of kinematic trees to generate max_tree_depth: the maximum tree depth body_pos: the default body position relative to the parent geom_pos: the default geom position in the body frame max_stacked_joints: maximum number of joints to stack for each body max_geoms_per_body: maximum number of geoms per body max_contact_excludes: maximum number of bodies to exlude from contact max_contact_pairs: maximum number of explicit geom contact pairs in the xml disable_actuation_pct: the percentage of time to disable actuation via the disable flag add_actuators: whether to add actuators root_always_free: if True, the root body of each kinematic tree has a free joint with the world enable_contact: if False, disables all contacts via contype/conaffinity Returns: an XML string for the MuJoCo config Raises: AssertionError when args are not in the correct ranges """ np.random.seed(seed) assert min_trees <= max_trees assert max_tree_depth >= 1 assert 0 <= disable_actuation_pct <= 100 assert max_stacked_joints >= 1 assert max_geoms_per_body >= 1 assert max_contact_excludes >= 1 assert max_contact_pairs >= 1 mjcf = ET.Element('mujoco') opt = ET.SubElement(mjcf, 'option', {'timestep': '0.005', 'solver': 'CG'}) world = ET.SubElement(mjcf, 'worldbody') ET.SubElement(mjcf, 'compiler', {'autolimits': 'true'}) # disable flags if p(disable_actuation_pct): ET.SubElement(opt, 'flag', {'actuation': 'disable'}) ET.SubElement( world, 'geom', { 'name': 'plane', 'type': 'plane', 'contype': '1' if enable_contact else '0', 'conaffinity': '1' if enable_contact else '0', 'size': '40 40 40', }, ) # kinematic trees tree_depth = np.random.randint(1, max_tree_depth + 1) def make_tree(body: ET.Element, depth: int) -> None: if depth >= tree_depth: return z_pos = np.random.uniform(low=-1, high=1) * 0.01 # small jitter pos = f'{body_pos[0]:.3f} {body_pos[1]:.3f} {body_pos[2] + z_pos:.3f}' 
n_bodies = len(list(mjcf.iter('body'))) gravcomp = np.random.uniform() * p(50) child = ET.SubElement( body, 'body', { 'pos': pos, 'name': f'body{n_bodies}', 'gravcomp': f'{gravcomp:.3f}', }, ) ET.SubElement(child, 'site', {'name': f'site{n_bodies}'}) n_joints = len(list(mjcf.iter('joint'))) for nj in range(np.random.randint(1, max_stacked_joints + 1)): joint_type = np.random.choice(_JOINT_TYPES) if nj == 0 and depth == 0 and root_always_free: joint_type = 'free' # free joint only allowed at top level while joint_type == 'free' and (depth > 0 or nj > 0): joint_type = np.random.choice(_JOINT_TYPES) joint_attr = _make_joint(joint_type, name=f'joint{n_joints + nj}') ET.SubElement(child, 'joint', joint_attr) prev_joints = child.findall('joint') had_ball_or_free = any( [j.get('type') in ('ball', 'free') for j in prev_joints] ) if had_ball_or_free: break # do not stack more joints n_geoms = len(list(mjcf.iter('geom'))) for _ in range(np.random.randint(1, max_geoms_per_body + 1)): pos = ('{:.2f} ' * 3).format(*geom_pos).strip() size = 0.2 + np.random.uniform(low=-1, high=1) * 0.02 geom_attr = _make_geom( pos, size, name=f'geom{n_geoms}', enable_contact=enable_contact ) ET.SubElement(child, 'geom', geom_attr) n_geoms += 1 make_tree(child, depth + 1) num_trees = np.random.randint(min_trees, max_trees + 1) for _ in range(num_trees): make_tree(world, 0) bodies = list(mjcf.iter('body')) n_bodies = len(bodies) # actuators if add_actuators: actuator = ET.SubElement(mjcf, 'actuator') n_joints = len(list(mjcf.iter('joint'))) nu = np.random.randint(1, n_joints + 1) actuators = [] # joint transmission for i in range(nu): actuator_type = np.random.choice(_ACTUATOR_TYPES) attr = _make_actuator(actuator_type, joint=f'joint{i}') actuators.append((actuator_type, attr)) # site transmission for i in range(np.random.randint(0, n_bodies)): actuator_type = np.random.choice(_ACTUATOR_TYPES) attr = _make_actuator(actuator_type, site=f'site{i}') actuators.append((actuator_type, attr)) # site 
transmission with refsite for i in range(np.random.randint(0, n_bodies)): j = np.random.randint(0, n_bodies) actuator_type = np.random.choice(_ACTUATOR_TYPES) attr = _make_actuator(actuator_type, site=f'site{i}', refsite=f'site{j}') actuators.append((actuator_type, attr)) np.random.shuffle(actuators) for typ, attr in actuators: ET.SubElement(actuator, typ, attr) # contact pairs contact = ET.SubElement(mjcf, 'contact') geoms = list(mjcf.iter('geom')) geom_names = [geom.get('name') for geom in geoms] n_geoms = len(geoms) pairs = set() for _ in range(min(max_contact_pairs, n_geoms * (n_geoms - 1) // 2)): if p(80): continue geom1, geom2 = np.random.choice(geom_names, replace=False, size=2) if geom1 > geom2: geom1, geom2 = geom2, geom1 if (geom1, geom2) in pairs: continue pairs.add((geom1, geom2)) attr = {'geom1': geom1, 'geom2': geom2} attr.update(_geom_solparams(pair=True)) ET.SubElement(contact, 'pair', attr) # exclude contacts body_names = [b.get('name') for b in bodies] for _ in range(min(max_contact_excludes, (n_bodies * (n_bodies - 1) // 2))): if p(50): continue body1, body2 = np.random.choice(body_names, replace=False, size=2) ET.SubElement(contact, 'exclude', {'body1': body1, 'body2': body2}) # ElementTree.indent is not available before Python 3.9 if sys.version_info.minor >= 9: ET.indent(mjcf) return ET.tostring(mjcf).decode('utf-8') def load_test_file(name: str) -> mujoco.MjModel: """Loads a mujoco.MjModel based on the file name.""" path = epath.resource_path('mujoco.mjx') / 'test_data' / name m = mujoco.MjModel.from_xml_path(path.as_posix()) return m