Source code for mujoco_tools.data_processor
import numpy as np
from typing import Dict
def parse_data_arg(data_str: str) -> Dict[str, np.ndarray]:
    """Parse a "name path name path ..." string and load every referenced file.

    The string is split on whitespace into alternating data names and file
    paths. Each path is loaded into a numpy array: ``.npy`` files through
    ``np.load`` and ``.txt`` files as comma-separated floats, one row per
    non-blank line.

    Example:
        "qpos data/qpos.npy ctrl data/ctrl.npy"
        -> {"qpos": <array loaded from data/qpos.npy>,
            "ctrl": <array loaded from data/ctrl.npy>}

    Args:
        data_str: Whitespace-separated name/path pairs. A falsy value (such
            as the empty string) yields an empty dictionary.

    Returns:
        Dictionary mapping each data name to the array loaded from its path.
        If a name is given more than once, the last pair wins.

    Raises:
        ValueError: If the tokens do not form complete name/path pairs, if a
            path ends in neither ``.npy`` nor ``.txt``, or if a text file
            contains a field that cannot be converted to float.
    """
    if not data_str:
        return {}

    parts = data_str.split()
    if len(parts) % 2 != 0:
        raise ValueError("Data argument must be pairs of type and path")

    loaded = {}
    # Even positions hold the names, odd positions the matching paths.
    for name, path in zip(parts[::2], parts[1::2]):
        if path.endswith('.npy'):
            loaded[name] = np.load(path)
        elif path.endswith('.txt'):
            with open(path, 'r') as f:
                # Skip blank lines (e.g. trailing newlines at the end of the
                # file); they would otherwise fail in float('').
                rows = [
                    [float(num) for num in line.strip().split(',')]
                    for line in f
                    if line.strip()
                ]
            loaded[name] = np.array(rows)
        else:
            raise ValueError(f"Unsupported file format for {path}. Must be .npy or .txt")
    return loaded
class InputDataProcessor:
    """Load simulation input arrays described by a single string.

    The string is either a path ending in ``.npz`` or a whitespace-separated
    list of "type path" pairs that is handed to ``parse_data_arg``.
    """

    # Array names extracted from a .npz archive, in the order they are
    # looked up; names absent from the archive are skipped.
    NPZ_KEYS = (
        "qpos", "qvel", "ctrl",
        "self.data.qpos", "self.data.qvel", "self.data.ctrl",
    )

    def __init__(self, input_str: str):
        """Initialize with a string input

        Args:
            input_str: Either a path to .npz file or string in format
                "type1 path1 type2 path2". ``None`` and the empty string
                both mean "no input data".

        Raises:
            ValueError: If ``input_str`` is neither ``None`` nor a string.
        """
        if input_str is None:
            input_str = ""
        # Validate the type before anything else so that falsy non-strings
        # (0, [], False) are rejected instead of being treated as empty.
        if not isinstance(input_str, str):
            raise ValueError("Input must be a string")
        self.input_str = input_str

    def process(self) -> Dict[str, np.ndarray]:
        """Process the input string into a standardized dictionary format

        Returns:
            dict: Dictionary mapping data types to their corresponding numpy
            arrays. Input ending in ``.npz`` is read as an archive; anything
            else is passed to ``parse_data_arg``.
        """
        if self.input_str.endswith('.npz'):
            return self._process_npz(self.input_str)
        return parse_data_arg(self.input_str)

    def _process_npz(self, npz_path: str) -> Dict[str, np.ndarray]:
        """Extract the recognised arrays from a .npz file into a dictionary

        Args:
            npz_path: Path to the .npz file

        Returns:
            dict: Dictionary mapping each name in ``NPZ_KEYS`` that is present
            in the archive to its numpy array, in ``NPZ_KEYS`` order.
        """
        with np.load(npz_path) as data:
            return {key: data[key] for key in self.NPZ_KEYS if key in data}