Source code for embodichain.utils.warp.kernels

# ----------------------------------------------------------------------------
# Copyright (c) 2021-2026 DexForce Technology Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ----------------------------------------------------------------------------

import warp as wp
from typing import Any


@wp.kernel(enable_backward=False)
def reshape_tiled_image(
    tiled_image_buffer: Any,
    batched_image: Any,
    image_height: int,
    image_width: int,
    num_channels: int,
    num_tiles_x: int,
):
    """Un-tile a flat tiled render buffer into a per-camera image batch.

    One thread handles one output pixel: the thread index is unpacked as
    ``(camera_id, height_id, width_id)``, and every channel of that pixel
    is copied from the tiled buffer into ``batched_image``.

    The tiled buffer is indexed as a flat, row-major image whose rows
    contain ``num_tiles_x`` tiles side by side, each tile being
    ``image_width`` pixels wide with ``num_channels`` interleaved values
    per pixel. Tile rows are read bottom-to-top, i.e. camera 0 is taken
    from the last tile row.

    Args:
        tiled_image_buffer: Flat input buffer holding the tiled image.
            Documented shape is
            (height * width * num_channels * num_cameras,).
        batched_image: Output batch, indexed as
            (num_cameras, height, width, num_channels).
        image_height: Height of a single tile/image in pixels.
        image_width: Width of a single tile/image in pixels.
        num_channels: Number of channels per pixel.
        num_tiles_x: Number of tiles along the x-direction of the tiled
            buffer.
    """
    # One thread per (camera, row, column) output pixel.
    camera_id, height_id, width_id = wp.tid()

    # Column of the tile that belongs to this camera.
    tile_col = camera_id % num_tiles_x

    # TODO: Tiles are currently laid out bottom-to-top; this should be changed.
    # NOTE(review): the flip uses num_tiles_x as the number of tile rows,
    # which presumably assumes a square tile grid - confirm with the renderer.
    tile_row = num_tiles_x - 1 - (camera_id // num_tiles_x)

    # Pixel coordinates inside the full tiled image.
    tiled_y = image_height * tile_row + height_id
    tiled_x = tile_col * image_width + width_id

    # Number of pixels in one full row of the tiled image.
    pixels_per_row = num_tiles_x * image_width

    # Flat offset of the first channel of the source pixel.
    src_offset = num_channels * (pixels_per_row * tiled_y + tiled_x)

    # Copy every channel, casting to the output array's element type.
    for channel in range(num_channels):
        batched_image[camera_id, height_id, width_id, channel] = batched_image.dtype(
            tiled_image_buffer[src_offset + channel]
        )
# Concrete instantiations of the generic `reshape_tiled_image` kernel.

# uint32 buffers: required for non-colored segmentation annotators
# (input and output are both uint32 here).
wp.overload(
    reshape_tiled_image,
    {
        "tiled_image_buffer": wp.array(dtype=wp.uint32),
        "batched_image": wp.array(dtype=wp.uint32, ndim=4),
    },
)

# uint8 buffers: used for 4 channel annotators.
wp.overload(
    reshape_tiled_image,
    {
        "tiled_image_buffer": wp.array(dtype=wp.uint8),
        "batched_image": wp.array(dtype=wp.uint8, ndim=4),
    },
)

# float32 buffers: used for single channel annotators.
wp.overload(
    reshape_tiled_image,
    {
        "tiled_image_buffer": wp.array(dtype=wp.float32),
        "batched_image": wp.array(dtype=wp.float32, ndim=4),
    },
)


@wp.kernel(enable_backward=False)
def scatter_contact_data(
    contact_data: wp.array(dtype=wp.float32, ndim=2),
    user_ids: wp.array(dtype=wp.int32, ndim=2),
    env_ids: wp.array(dtype=wp.int32, ndim=1),
    num_contacts_per_env: wp.array(dtype=wp.int32, ndim=1),
    max_contacts_per_env: int,
    # Output buffers
    position: wp.array(dtype=wp.float32, ndim=3),
    normal: wp.array(dtype=wp.float32, ndim=3),
    friction: wp.array(dtype=wp.float32, ndim=3),
    impulse: wp.array(dtype=wp.float32, ndim=2),
    distance: wp.array(dtype=wp.float32, ndim=2),
    user_ids_out: wp.array(dtype=wp.int32, ndim=3),
    is_valid: wp.array(dtype=wp.bool, ndim=2),
):
    """Distribute a flat list of contacts into per-environment buffers.

    Each thread processes one row of ``contact_data``. It looks up the
    environment the contact belongs to, reserves the next free slot of
    that environment by atomically incrementing its counter, and copies
    the contact fields into that slot of every output buffer.

    Args:
        contact_data: Input contact data. Shape is (n_contact, 11).
            Columns: [x, y, z, nx, ny, nz, fx, fy, fz, impulse, distance].
        user_ids: Input user IDs for each contact. Shape is (n_contact, 2).
        env_ids: Environment ID for each contact. Shape is (n_contact,).
        num_contacts_per_env: Per-environment contact counter, updated
            atomically by this kernel. Shape is (num_envs,).
        max_contacts_per_env: Capacity of each environment's slot range.
        position: Output positions. Shape is
            (num_envs, max_contacts_per_env, 3).
        normal: Output normals. Shape is
            (num_envs, max_contacts_per_env, 3).
        friction: Output friction vectors. Shape is
            (num_envs, max_contacts_per_env, 3).
        impulse: Output impulses. Shape is (num_envs, max_contacts_per_env).
        distance: Output distances. Shape is
            (num_envs, max_contacts_per_env).
        user_ids_out: Output user IDs. Shape is
            (num_envs, max_contacts_per_env, 2).
        is_valid: Output validity mask, set to True for every written
            slot. Shape is (num_envs, max_contacts_per_env).

    Note:
        Contacts that arrive once an environment's slots are exhausted
        are dropped without any error, and the counter increment made for
        them is undone, so ``num_contacts_per_env`` ends up counting only
        the contacts actually written (at most ``max_contacts_per_env``).
        This kernel only ever sets ``is_valid`` entries to True; it never
        clears them.
    """
    tid = wp.tid()

    # Threads beyond the last contact row have nothing to do.
    if tid >= contact_data.shape[0]:
        return

    env = env_ids[tid]

    # Reserve a slot: atomic_add yields the counter value before the increment.
    slot = wp.atomic_add(num_contacts_per_env, env, 1)

    if slot < max_contacts_per_env:
        # Columns 0-2 / 3-5 / 6-8 hold position / normal / friction.
        for axis in range(3):
            position[env, slot, axis] = contact_data[tid, axis]
            normal[env, slot, axis] = contact_data[tid, axis + 3]
            friction[env, slot, axis] = contact_data[tid, axis + 6]

        # Columns 9 and 10 hold the scalar impulse and distance.
        impulse[env, slot] = contact_data[tid, 9]
        distance[env, slot] = contact_data[tid, 10]

        user_ids_out[env, slot, 0] = user_ids[tid, 0]
        user_ids_out[env, slot, 1] = user_ids[tid, 1]

        is_valid[env, slot] = True
    else:
        # Buffer is full: roll back the reservation so the counter only
        # reflects contacts that were actually stored.
        wp.atomic_sub(num_contacts_per_env, env, 1)