Source code for embodichain.data.dataset

# ----------------------------------------------------------------------------
# Copyright (c) 2021-2026 DexForce Technology Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ----------------------------------------------------------------------------

import os
import sys
import shutil
import hashlib
import open3d as o3d

from embodichain.utils import logger


class EmbodiChainDataset(o3d.data.DownloadDataset):
    """Open3D download dataset that validates its local cache before use.

    The constructor inspects the previously downloaded zip file and its
    extracted directory and removes them when they look stale or corrupted,
    then hands over to ``o3d.data.DownloadDataset``. Per the original
    author's note, the parent class downloads and extracts again when the
    zip file is missing.
    """

    def __init__(self, prefix, data_descriptor, path):
        """Validate the cached artifacts, then initialize the parent dataset.

        Args:
            prefix: Dataset sub-directory name used under ``<path>/download``
                and ``<path>/extract``.
            data_descriptor: Object exposing ``urls`` (first entry names the
                zip file) and ``md5`` (expected checksum of that file).
            path: Data root directory containing ``download`` and ``extract``.
        """
        # Must run before the parent constructor: an invalid zip is deleted
        # here so that the parent sees it as missing and fetches it again.
        self.check_zip(prefix, data_descriptor, path)
        super().__init__(prefix, data_descriptor, path)

    def check_zip(self, prefix, data_descriptor, path):
        """Check the integrity of the zip file and its extracted contents.

        Three situations trigger a cleanup so that a fresh download happens:

        1. The zip file is missing: its download directory and the
           extraction directory are removed.
        2. The zip file's MD5 differs from ``data_descriptor.md5``: the zip
           file and the extraction directory are removed.
        3. The zip file is valid but the extraction directory is missing or
           empty: the zip file and the extraction directory are removed.

        Directories are only deleted when their path contains
        ``embodichain_data/download`` or ``embodichain_data/extract``.

        Args:
            prefix: Dataset sub-directory name.
            data_descriptor: Object exposing ``urls`` and ``md5``.
            path: Data root directory.
        """
        zip_file_name = os.path.split(data_descriptor.urls[0])[1]
        zip_dir_path = os.path.join(path, "download", prefix)
        zip_path = os.path.join(zip_dir_path, zip_file_name)
        extracted_path = os.path.join(path, "extract", prefix)

        def is_safe_path(path_to_check):
            """Verify if the path is within safe directory boundaries."""
            # Normalize separators first: os.path.join yields backslashes on
            # Windows, which would never match the forward-slash markers and
            # would make every removal be refused.
            normalized = path_to_check.replace(os.sep, "/")
            return (
                "embodichain_data/download" in normalized
                or "embodichain_data/extract" in normalized
            )

        def safe_remove_directory(dir_path):
            """Remove ``dir_path`` after validation; return True on success.

            A directory that does not exist counts as success.
            """
            if not is_safe_path(dir_path):
                logger.log_warning(
                    f"Safety check failed, refusing to delete directory: {dir_path}"
                )
                return False
            if not os.path.exists(dir_path):
                return True
            try:
                shutil.rmtree(dir_path)
            except OSError as e:
                logger.log_warning(f"Error while removing directory: {e}")
                return False
            logger.log_info(f"Successfully removed directory: {dir_path}")
            return True

        # Case 1: nothing downloaded yet (or the zip was deleted).
        if not os.path.exists(zip_path):
            safe_remove_directory(zip_dir_path)
            safe_remove_directory(extracted_path)
            logger.log_info(
                f"ZIP file not found at {zip_path}. "
                f"Cleaning up related directories for fresh download."
            )
            return

        # Case 2: the zip exists but its checksum is wrong.
        md5_existing = self.calculate_md5(zip_path)
        if md5_existing != data_descriptor.md5:
            os.remove(zip_path)
            safe_remove_directory(extracted_path)
            logger.log_warning(
                f"Invalid MD5 checksum detected:\n"
                f" - File: {zip_path}\n"
                f" - Expected MD5: {data_descriptor.md5}\n"
                f" - Actual MD5: {md5_existing}\n"
                f"Cleaned up invalid files and directories for fresh download."
            )
            return

        # Case 3: the zip is valid but there is no usable extraction.
        extraction_missing = not os.path.exists(extracted_path)
        if extraction_missing or not os.listdir(extracted_path):
            # Decide the reason BEFORE deleting anything; once the empty
            # directory is removed it would always be reported as missing.
            reason = (
                "Missing extraction directory."
                if extraction_missing
                else "Empty extraction directory."
            )
            # Removing the zip makes Open3D download and extract again.
            if os.path.exists(zip_path):
                os.remove(zip_path)
            safe_remove_directory(extracted_path)
            logger.log_info(
                f"Removed zip file {zip_path} and extracted path {extracted_path} to trigger Open3D download and extract. "
                f"Reason: {reason}"
            )
            return

    def calculate_md5(self, file_path, chunk_size=8192):
        """Calculate the MD5 checksum of a file.

        Args:
            file_path: Path of the file to hash.
            chunk_size: Number of bytes read per iteration, so large files
                are never loaded into memory at once.

        Returns:
            str: Hexadecimal MD5 digest of the file contents.
        """
        hash_md5 = hashlib.md5()
        with open(file_path, "rb") as f:
            while chunk := f.read(chunk_size):
                hash_md5.update(chunk)
        return hash_md5.hexdigest()
DEFAULT_DATA_MODULES = [ "embodichain.data", "embodichain.data.assets", ]
[docs] def get_data_class(dataset_name: str, extra_modules: list[str] | None = None): """Retrieve the dataset class from the available modules. Args: dataset_name (str): The name of the dataset class. extra_modules (list[str] | None): Optional list of additional module names to search for the dataset class. Returns: type: The dataset class. Raises: AttributeError: If the dataset class is not found in any module. """ module_names = DEFAULT_DATA_MODULES + ( extra_modules if extra_modules is not None else [] ) for module_name in module_names: try: return getattr(sys.modules[module_name], dataset_name) except AttributeError: continue raise AttributeError(f"Dataset class '{dataset_name}' not found in any module.")
def get_data_path(data_path_in_config: str) -> str:
    """Get the absolute path of the data file.

    Resolution order:

    1. If ``data_path_in_config`` is an absolute path, return it directly.
    2. If a matching file/directory exists under
       ``EMBODICHAIN_DEFAULT_DATA_ROOT`` (which can be overridden via the
       ``EMBODICHAIN_DATA_ROOT`` environment variable), return that path.
    3. Otherwise, resolve via the registered data-class download mechanism.

    Args:
        data_path_in_config (str): The dataset path in the format
            ``"dataset_name/subpath"``. A bare ``"dataset_name"`` resolves to
            the dataset's extraction directory.

    Returns:
        str: The absolute path of the data file.

    Raises:
        AttributeError: Propagated from ``get_data_class`` when no dataset
            class named ``dataset_name`` is found.
    """
    if os.path.isabs(data_path_in_config):
        return data_path_in_config

    # Try resolving under the user-configurable data root first.
    from embodichain.data.constants import EMBODICHAIN_DEFAULT_DATA_ROOT

    local_path = os.path.join(EMBODICHAIN_DEFAULT_DATA_ROOT, data_path_in_config)
    if os.path.exists(local_path):
        return local_path

    # Fall back to the data-class download mechanism.
    dataset_name, *sub_parts = data_path_in_config.split("/")
    data_class = get_data_class(dataset_name)
    data_dir = data_class().extract_dir

    # Joining onto data_dir directly also covers a path with no sub-path:
    # os.path.join(*[]) raises TypeError, whereas this returns data_dir.
    return os.path.join(data_dir, *sub_parts)