# ----------------------------------------------------------------------------
# Copyright (c) 2021-2026 DexForce Technology Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ----------------------------------------------------------------------------
"""CLI tool for pre-downloading EmbodiChain data assets.
Usage::
# List all available assets
python -m embodichain.data.download list
# List assets in a specific category
python -m embodichain.data.download list --category robot
# Download a specific asset by name
python -m embodichain.data.download download --name CobotMagicArm
# Download all assets in a category
python -m embodichain.data.download download --category robot
# Download everything
python -m embodichain.data.download download --all
"""
from __future__ import annotations
import argparse
import importlib
import inspect
import os
import shutil
import sys
import open3d as o3d
from embodichain.data.constants import EMBODICHAIN_DEFAULT_DATA_ROOT
# Mapping from category name to the module path that defines the asset classes.
CATEGORY_MODULES: dict[str, str] = {
"demo": "embodichain.data.assets.demo_assets",
"eef": "embodichain.data.assets.eef_assets",
"materials": "embodichain.data.assets.materials",
"obj": "embodichain.data.assets.obj_assets",
"robot": "embodichain.data.assets.robot_assets",
"scene": "embodichain.data.assets.scene_assets",
"w1": "embodichain.data.assets.w1_assets",
}
def _get_asset_classes(module_path: str) -> list[tuple[str, type]]:
"""Return (name, cls) pairs for all DownloadDataset subclasses in *module_path*."""
module = importlib.import_module(module_path)
results: list[tuple[str, type]] = []
for name, obj in inspect.getmembers(module, inspect.isclass):
if (
issubclass(obj, o3d.data.DownloadDataset)
and obj is not o3d.data.DownloadDataset
and obj.__module__ == module.__name__
):
results.append((name, obj))
results.sort(key=lambda x: x[0])
return results
[docs]
def get_registry() -> dict[str, list[tuple[str, type]]]:
"""Build ``{category: [(class_name, class), ...]}`` for every category."""
registry: dict[str, list[tuple[str, type]]] = {}
for category, module_path in CATEGORY_MODULES.items():
registry[category] = _get_asset_classes(module_path)
return registry
[docs]
def find_asset_class(
name: str, registry: dict[str, list[tuple[str, type]]]
) -> tuple[str, type] | None:
"""Find an asset class by name (case-insensitive) across all categories."""
name_lower = name.lower()
for category, assets in registry.items():
for cls_name, cls in assets:
if cls_name.lower() == name_lower:
return category, cls
return None
# ---------------------------------------------------------------------------
# Download helpers
# ---------------------------------------------------------------------------
def _ensure_extract(data_obj: o3d.data.DownloadDataset, prefix: str) -> None:
"""For non-zip assets, copy the downloaded file into the extract directory.
``o3d.data.DownloadDataset`` extracts zip archives automatically but
leaves single-file downloads (e.g. ``.glb``) only in the download
directory. This helper copies them to the extract tree so that
``get_data_path`` can find them under ``<data_root>/extract/<prefix>/``.
"""
extract_dir = os.path.join(EMBODICHAIN_DEFAULT_DATA_ROOT, "extract", prefix)
if os.path.exists(extract_dir) and os.listdir(extract_dir):
return # already extracted
download_dir = os.path.join(EMBODICHAIN_DEFAULT_DATA_ROOT, "download", prefix)
if not os.path.isdir(download_dir):
return
os.makedirs(extract_dir, exist_ok=True)
for item in os.listdir(download_dir):
src = os.path.join(download_dir, item)
dst = os.path.join(extract_dir, item)
if os.path.isdir(src):
shutil.copytree(src, dst, dirs_exist_ok=True)
else:
shutil.copy2(src, dst)
print(f" Copied non-zip asset to extract dir: {extract_dir}")
[docs]
def download_asset(cls_name: str, cls: type) -> None:
"""Instantiate an asset class to trigger download, then ensure extraction."""
print(f" Downloading {cls_name} ...")
try:
data_obj = cls()
_ensure_extract(data_obj, cls_name)
print(f" ✓ {cls_name} ready")
except Exception as exc:
print(f" ✗ {cls_name} failed: {exc}", file=sys.stderr)
# ---------------------------------------------------------------------------
# CLI commands
# ---------------------------------------------------------------------------
[docs]
def cmd_list(args: argparse.Namespace) -> None:
"""List available assets."""
registry = get_registry()
categories = [args.category] if args.category else sorted(registry)
for category in categories:
assets = registry.get(category)
if assets is None:
print(f"Unknown category: {category}", file=sys.stderr)
print(
f"Available categories: {', '.join(sorted(registry))}", file=sys.stderr
)
sys.exit(1)
print(f"\n[{category}] ({len(assets)} assets)")
for cls_name, _ in assets:
# Show whether it is already downloaded
extract_dir = os.path.join(
EMBODICHAIN_DEFAULT_DATA_ROOT, "extract", cls_name
)
status = (
"✓" if os.path.isdir(extract_dir) and os.listdir(extract_dir) else " "
)
print(f" [{status}] {cls_name}")
print(f"\nData root: {EMBODICHAIN_DEFAULT_DATA_ROOT}")
[docs]
def cmd_download(args: argparse.Namespace) -> None:
"""Download assets by name, category, or everything."""
registry = get_registry()
targets: list[tuple[str, type]] = []
if args.all:
for assets in registry.values():
targets.extend(assets)
elif args.category:
assets = registry.get(args.category)
if assets is None:
print(f"Unknown category: {args.category}", file=sys.stderr)
print(
f"Available categories: {', '.join(sorted(registry))}", file=sys.stderr
)
sys.exit(1)
targets.extend(assets)
elif args.name:
result = find_asset_class(args.name, registry)
if result is None:
print(f"Asset '{args.name}' not found.", file=sys.stderr)
print("Use 'list' to see available assets.", file=sys.stderr)
sys.exit(1)
_category, cls = result
targets.append((args.name, cls))
else:
print("Specify --name, --category, or --all.", file=sys.stderr)
sys.exit(1)
print(f"Data root: {EMBODICHAIN_DEFAULT_DATA_ROOT}")
print(f"Downloading {len(targets)} asset(s) ...\n")
for cls_name, cls in targets:
download_asset(cls_name, cls)
print(f"\nDone. {len(targets)} asset(s) processed.")
[docs]
def main() -> None:
parser = argparse.ArgumentParser(
prog="embodichain.data.download",
description="Pre-download EmbodiChain data assets from HuggingFace.",
)
subparsers = parser.add_subparsers(dest="command")
# --- list ---
list_parser = subparsers.add_parser("list", help="List available assets.")
list_parser.add_argument(
"--category",
choices=sorted(CATEGORY_MODULES),
help="Show only assets in this category.",
)
# --- download ---
dl_parser = subparsers.add_parser("download", help="Download assets.")
dl_group = dl_parser.add_mutually_exclusive_group(required=True)
dl_group.add_argument("--name", help="Download a single asset by class name.")
dl_group.add_argument(
"--category",
choices=sorted(CATEGORY_MODULES),
help="Download all assets in a category.",
)
dl_group.add_argument(
"--all", action="store_true", help="Download every registered asset."
)
args = parser.parse_args()
if args.command == "list":
cmd_list(args)
elif args.command == "download":
cmd_download(args)
else:
parser.print_help()
sys.exit(1)
if __name__ == "__main__":
main()