# Source code for gsply.utils

"""Utility functions for Gaussian Splatting operations."""

from __future__ import annotations

import logging
from typing import TYPE_CHECKING, Final

import numpy as np
from numba import jit, njit, prange
from numpy.typing import NDArray

from gsply.formats import SH_C0

if TYPE_CHECKING:
    from gsply.gsdata import GSData

logger = logging.getLogger(name=__name__)

# Shorthand for float32 NumPy arrays, used in the kernel/helper signatures.
Float32Array = NDArray[np.float32]

# Default clamp values recommended by the rendering pipeline.
# Lower bound applied to scales after exponentiation.
_DEFAULT_MIN_SCALE: Final[np.float32] = np.float32(0.0001)
# Upper bound applied to scales after exponentiation.
_DEFAULT_MAX_SCALE: Final[np.float32] = np.float32(1.0e2)
# Quaternion-norm floor below which normalization is treated as degenerate.
_DEFAULT_MIN_NORM: Final[np.float32] = np.float32(1.0e-8)


def sh2rgb(sh: np.ndarray | float) -> np.ndarray | float:
    """Convert SH DC coefficients to RGB colors.

    Applies ``rgb = sh * SH_C0 + 0.5``; this is the inverse of ``rgb2sh``.

    :param sh: SH DC coefficients (N, 3) or scalar
    :returns: RGB colors in [0, 1] range

    Example:
        >>> import gsply
        >>> sh = np.array([[0.0, 0.5, -0.5]])
        >>> rgb = gsply.sh2rgb(sh)
        >>> print(rgb)  # [[0.5, 0.641, 0.359]]
    """
    scaled = sh * SH_C0
    return scaled + 0.5
def rgb2sh(rgb: np.ndarray | float) -> np.ndarray | float:
    """Convert RGB colors to SH DC coefficients.

    Applies ``sh = (rgb - 0.5) / SH_C0``; this is the inverse of ``sh2rgb``.

    :param rgb: RGB colors in [0, 1] range (N, 3) or scalar
    :returns: SH DC coefficients

    Example:
        >>> import gsply
        >>> rgb = np.array([[1.0, 0.5, 0.0]])
        >>> sh = gsply.rgb2sh(rgb)
    """
    centered = rgb - 0.5
    return centered / SH_C0
@jit(nopython=True, parallel=True, fastmath=True, cache=True, nogil=True, boundscheck=False)
def _logit_impl(x: np.ndarray, out: np.ndarray, eps: float):
    """Numba kernel: element-wise clamped logit of ``x`` written into ``out``.

    Each element is clamped to ``[eps, 1 - eps]`` before computing
    ``log(p / (1 - p))``. Both arrays are traversed through ``.flat`` so any
    rank is accepted; ``out`` must have the same number of elements as ``x``.

    :param x: Input probabilities (read only)
    :param out: Destination array (written element by element)
    :param eps: Clamping margin
    """
    upper = 1.0 - eps  # loop-invariant upper clamp bound
    for k in prange(x.size):
        p = x.flat[k]
        if p < eps:
            p = eps
        elif p > upper:
            p = upper
        out.flat[k] = np.log(p / (1.0 - p))
[docs] def logit(x: np.ndarray | float, eps: float = 1e-6) -> np.ndarray | float: """Compute logit function (inverse sigmoid) with numerical stability. Optimized for both scalar and array inputs using Numba. Formula: log(x / (1 - x)) :param x: Input values in [0, 1] range (probabilities) :param eps: Epsilon for numerical stability (clamping) :returns: Logit values """ if np.isscalar(x): val = float(x) val = max(eps, min(val, 1.0 - eps)) return np.log(val / (1.0 - val)) out = np.empty_like(x) _logit_impl(x, out, eps) return out
@jit(nopython=True, parallel=True, fastmath=True, cache=True, nogil=True, boundscheck=False)
def _sigmoid_impl(x: np.ndarray, out: np.ndarray):
    """Numba kernel: element-wise numerically stable sigmoid of ``x`` into ``out``.

    Non-negative inputs use ``1 / (1 + exp(-t))``; negative inputs use
    ``exp(t) / (1 + exp(t))``, so ``exp`` is only evaluated on non-positive
    arguments and cannot overflow.

    :param x: Input logits (read only)
    :param out: Destination array with the same number of elements as ``x``
    """
    for k in prange(x.size):
        t = x.flat[k]
        if t >= 0:
            out.flat[k] = 1.0 / (1.0 + np.exp(-t))
        else:
            e = np.exp(t)
            out.flat[k] = e / (1.0 + e)
[docs] def sigmoid(x: np.ndarray | float) -> np.ndarray | float: """Compute sigmoid function (inverse logit) with numerical stability. Optimized for both scalar and array inputs using Numba. Formula: 1 / (1 + exp(-x)) :param x: Input values (logits) :returns: Values in [0, 1] range (probabilities) """ if np.isscalar(x): val = float(x) if val >= 0: return 1.0 / (1.0 + np.exp(-val)) z = np.exp(val) return z / (1.0 + z) out = np.empty_like(x) _sigmoid_impl(x, out) return out
@jit(nopython=True, parallel=True, fastmath=True, cache=True, nogil=True, boundscheck=False)
def _sh2rgb_inplace_jit(sh: np.ndarray, sh_c0: float):
    """Numba-accelerated in-place SH to RGB conversion.

    Computes ``sh * sh_c0 + 0.5`` for the first three columns of every row.

    :param sh: (N, 3) float32 array - modified in-place
    :param sh_c0: SH constant (0.28209479177387814)
    """
    n = sh.shape[0]
    for i in prange(n):
        for j in range(3):
            sh[i, j] = sh[i, j] * sh_c0 + 0.5


@jit(nopython=True, parallel=True, fastmath=True, cache=True, nogil=True, boundscheck=False)
def _rgb2sh_inplace_jit(rgb: np.ndarray, inv_sh_c0: float):
    """Numba-accelerated in-place RGB to SH conversion.

    Computes ``(rgb - 0.5) * inv_sh_c0`` for the first three columns of every row.

    :param rgb: (N, 3) float32 array - modified in-place
    :param inv_sh_c0: Inverse SH constant (1.0 / 0.28209479177387814)
    """
    n = rgb.shape[0]
    for i in prange(n):
        for j in range(3):
            rgb[i, j] = (rgb[i, j] - 0.5) * inv_sh_c0


@njit(parallel=True, fastmath=True, cache=True, nogil=True)
def _activate_gaussians_numba(
    scales: Float32Array,
    opacities: Float32Array,
    quats: Float32Array,
    min_scale: np.float32,
    max_scale: np.float32,
    min_quat_norm: np.float32,
) -> None:
    """
    Fused attribute activation kernel (all three arrays are modified in place).

    Per Gaussian ``i``:

    * scales: ``exp`` then clamp to ``[min_scale, max_scale]``
    * opacity: numerically stable sigmoid
    * quaternion: normalized to unit length, or replaced by ``(0, 0, 0, 1)``
      when its norm is below ``min_quat_norm``

    :param scales: Log-scale values, shape [N, 3]
    :param opacities: Logit opacities, shape [N]
    :param quats: Raw quaternions, shape [N, 4]
    :param min_scale: Minimum clamp value post-exp
    :param max_scale: Maximum clamp value post-exp
    :param min_quat_norm: Minimum allowable quaternion norm (safety floor)
    """
    count = scales.shape[0]
    for i in prange(count):
        # Scale activation: exp + clamp
        sx = np.exp(scales[i, 0])
        sy = np.exp(scales[i, 1])
        sz = np.exp(scales[i, 2])
        sx = min(max(sx, min_scale), max_scale)
        sy = min(max(sy, min_scale), max_scale)
        sz = min(max(sz, min_scale), max_scale)
        scales[i, 0] = sx
        scales[i, 1] = sy
        scales[i, 2] = sz

        # Opacity activation: numerically-stable sigmoid. exp() is only
        # evaluated on non-positive arguments, so it cannot overflow.
        opacity_logit = opacities[i]
        if opacity_logit >= 0.0:
            exp_term = np.exp(-opacity_logit)
            opacity_prob = 1.0 / (1.0 + exp_term)
        else:
            exp_term = np.exp(opacity_logit)
            opacity_prob = exp_term / (1.0 + exp_term)
        opacities[i] = opacity_prob

        # Quaternion activation: normalize with safety floor
        qx = quats[i, 0]
        qy = quats[i, 1]
        qz = quats[i, 2]
        qw = quats[i, 3]
        norm = np.sqrt(qx * qx + qy * qy + qz * qz + qw * qw)
        if norm < min_quat_norm:
            # NOTE(review): (0, 0, 0, 1) is the identity rotation only if the
            # components are stored as (x, y, z, w). If quats follow a
            # (w, x, y, z) layout, identity would be (1, 0, 0, 0) - confirm.
            quats[i, 0] = np.float32(0.0)
            quats[i, 1] = np.float32(0.0)
            quats[i, 2] = np.float32(0.0)
            quats[i, 3] = np.float32(1.0)
        else:
            inv = 1.0 / norm
            quats[i, 0] = qx * inv
            quats[i, 1] = qy * inv
            quats[i, 2] = qz * inv
            quats[i, 3] = qw * inv


@njit(parallel=True, fastmath=True, cache=True, nogil=True)
def _deactivate_gaussians_numba(
    scales: Float32Array,
    opacities: Float32Array,
    min_scale: np.float32,
    min_opacity: np.float32,
    max_opacity: np.float32,
) -> None:
    """
    Fused attribute deactivation kernel (reverse of activation), in place.

    Per Gaussian ``i``:

    * scales: clamp from below at ``min_scale``, then ``log``
    * opacity: clamp to ``[min_opacity, max_opacity]``, then logit

    :param scales: Linear scale values, shape [N, 3]
    :param opacities: Linear opacities, shape [N]
    :param min_scale: Minimum clamp value before log
    :param min_opacity: Minimum clamp value before logit
    :param max_opacity: Maximum clamp value before logit
    """
    count = scales.shape[0]
    for i in prange(count):
        # Scale deactivation: clamp + log (clamp keeps log() finite)
        sx = max(scales[i, 0], min_scale)
        sy = max(scales[i, 1], min_scale)
        sz = max(scales[i, 2], min_scale)
        scales[i, 0] = np.log(sx)
        scales[i, 1] = np.log(sy)
        scales[i, 2] = np.log(sz)

        # Opacity deactivation: clamp + logit (clamp keeps the ratio finite)
        opacity = opacities[i]
        if opacity < min_opacity:
            opacity = min_opacity
        elif opacity > max_opacity:
            opacity = max_opacity
        opacities[i] = np.log(opacity / (1.0 - opacity))


def _ensure_float32_contiguous(array: Float32Array | None, name: str) -> Float32Array:
    """
    Ensure arrays passed to kernels are float32 and C-contiguous.

    An array that already has native float32 dtype and C order is returned
    unchanged, so the kernels modify it in place. Otherwise exactly one
    converted copy is made; callers must assign that result back.

    :param array: Array to validate
    :param name: Attribute name (for error messages)
    :return: Array guaranteed to be float32 and contiguous
    :raises ValueError: If ``array`` is None
    """
    if array is None:
        raise ValueError(f"GSData.{name} is required.")
    # Single pass: dtype conversion and C-order layout together, so a
    # non-contiguous non-float32 input is copied once rather than twice.
    # asanyarray keeps ndarray subclasses and returns ``array`` itself when
    # no conversion is needed.
    return np.asanyarray(array, dtype=np.float32, order="C")
def apply_pre_activations(
    data: GSData,
    *,
    min_scale: float = float(_DEFAULT_MIN_SCALE),
    max_scale: float = float(_DEFAULT_MAX_SCALE),
    min_quat_norm: float = float(_DEFAULT_MIN_NORM),
    inplace: bool = True,
) -> GSData:
    """
    Activate GSData attributes (scales, opacities, quaternions) in a single fused pass.

    This function uses a fused Numba kernel that processes all three attributes
    together for optimal performance (~8-15x faster than individual operations).

    Scales are exponentiated and clamped, opacities pass through a sigmoid, and
    quaternions are normalized. If an attribute is not already C-contiguous
    float32, a converted copy is processed and assigned back to ``data``.

    :param data: GSData instance to process
    :param min_scale: Minimum allowed scale value after exponentiation
    :param max_scale: Maximum allowed scale value after exponentiation
        (``inf`` disables the upper clamp)
    :param min_quat_norm: Norm floor for normalizing quaternions (avoids NaNs)
    :param inplace: If False, returns a copy before activation
    :return: GSData with activated attributes (either modified in-place or copy)
    :raises ValueError: If a clamp parameter is non-positive or NaN, if
        ``max_scale < min_scale``, or if attribute shapes/lengths are invalid

    Example:
        >>> import gsply
        >>> data = gsply.plyread("scene_logits.ply")
        >>> gsply.apply_pre_activations(data, inplace=True)
    """
    # Lazy import to avoid circular dependency (cached in function attribute).
    # NOTE(review): the cached class is not referenced elsewhere in this
    # function; confirm whether anything else relies on it before removing.
    if not hasattr(apply_pre_activations, "_GSData"):
        from gsply.gsdata import GSData as _GSData

        apply_pre_activations._GSData = _GSData

    # Comparisons are written as ``not (x > 0)`` so that NaN, which fails every
    # comparison, is rejected instead of silently reaching the kernel.
    if not (min_scale > 0):
        raise ValueError("min_scale must be positive to avoid degenerate exponentiation results.")
    if not (max_scale > 0 and max_scale >= min_scale):
        raise ValueError("max_scale must be positive and >= min_scale.")
    if not (min_quat_norm > 0):
        raise ValueError("min_quat_norm must be positive.")

    if not inplace:
        data = data.copy()

    scales = _ensure_float32_contiguous(data.scales, "scales")
    opacities = _ensure_float32_contiguous(data.opacities, "opacities")
    quats = _ensure_float32_contiguous(data.quats, "quats")

    if scales.ndim != 2 or scales.shape[1] != 3:
        raise ValueError("scales must have shape [N, 3].")
    if quats.ndim != 2 or quats.shape[1] != 4:
        raise ValueError("quats must have shape [N, 4].")

    # The kernel expects 1D opacities; a reshape of a contiguous [N, 1] array
    # is a view, so the kernel's writes land in ``opacities``.
    if opacities.ndim == 2 and opacities.shape[1] == 1:
        opacity_view = opacities.reshape(opacities.shape[0])
    elif opacities.ndim == 1:
        opacity_view = opacities
    else:
        raise ValueError("opacities must be 1D or have shape [N, 1].")

    n_gaussians = scales.shape[0]
    if quats.shape[0] != n_gaussians or opacity_view.shape[0] != n_gaussians:
        raise ValueError("scales, opacities, and quats must have matching lengths.")

    _activate_gaussians_numba(
        scales,
        opacity_view,
        quats,
        np.float32(min_scale),
        np.float32(max_scale),
        np.float32(min_quat_norm),
    )

    # Reassign in case _ensure_float32_contiguous produced converted copies.
    data.scales = scales
    data.opacities = opacities
    data.quats = quats

    logger.debug(
        "[PreActivation] Activated %d Gaussians (min_scale=%.2e, max_scale=%.2f, min_quat_norm=%.2e)",
        scales.shape[0],
        min_scale,
        max_scale,
        min_quat_norm,
    )
    return data
def apply_pre_deactivations(
    data: GSData,
    *,
    min_scale: float = 1e-9,
    min_opacity: float = 1e-4,
    max_opacity: float = 1.0 - 1e-4,
    inplace: bool = True,
) -> GSData:
    """
    Deactivate GSData attributes (scales, opacities) in a single fused pass.

    This function uses a fused Numba kernel that processes scales and opacities
    together for optimal performance (~8-15x faster than individual operations).

    Scales are clamped from below and log-transformed; opacities are clamped to
    ``[min_opacity, max_opacity]`` and logit-transformed. If an attribute is not
    already C-contiguous float32, a converted copy is processed and assigned
    back to ``data``.

    :param data: GSData instance to process
    :param min_scale: Minimum allowed scale value before logarithm
    :param min_opacity: Minimum allowed opacity value before logit
    :param max_opacity: Maximum allowed opacity value before logit
    :param inplace: If False, returns a copy before deactivation
    :return: GSData with deactivated attributes (either modified in-place or copy)
    :raises ValueError: If a clamp parameter is out of range or NaN, or if
        attribute shapes/lengths are invalid

    Example:
        >>> import gsply
        >>> data = gsply.GSData(...)  # Linear format
        >>> gsply.apply_pre_deactivations(data, inplace=True)
    """
    # Lazy import to avoid circular dependency (cached in function attribute).
    # NOTE(review): the cached class is not referenced elsewhere in this
    # function; confirm whether anything else relies on it before removing.
    if not hasattr(apply_pre_deactivations, "_GSData"):
        from gsply.gsdata import GSData as _GSData

        apply_pre_deactivations._GSData = _GSData

    # Comparisons are written in the "must hold" direction so that NaN, which
    # fails every comparison, is rejected instead of reaching the kernel.
    if not (min_scale > 0):
        raise ValueError("min_scale must be positive to avoid degenerate logarithm results.")
    if not (min_opacity > 0 and max_opacity < 1.0):
        raise ValueError("min_opacity must be positive and max_opacity must be < 1.0.")
    if not (max_opacity > min_opacity):
        raise ValueError("max_opacity must be > min_opacity.")

    if not inplace:
        data = data.copy()

    scales = _ensure_float32_contiguous(data.scales, "scales")
    opacities = _ensure_float32_contiguous(data.opacities, "opacities")

    if scales.ndim != 2 or scales.shape[1] != 3:
        raise ValueError("scales must have shape [N, 3].")

    # The kernel expects 1D opacities; a reshape of a contiguous [N, 1] array
    # is a view, so the kernel's writes land in ``opacities``.
    if opacities.ndim == 2 and opacities.shape[1] == 1:
        opacity_view = opacities.reshape(opacities.shape[0])
    elif opacities.ndim == 1:
        opacity_view = opacities
    else:
        raise ValueError("opacities must be 1D or have shape [N, 1].")

    n_gaussians = scales.shape[0]
    if opacity_view.shape[0] != n_gaussians:
        raise ValueError("scales and opacities must have matching lengths.")

    _deactivate_gaussians_numba(
        scales,
        opacity_view,
        np.float32(min_scale),
        np.float32(min_opacity),
        np.float32(max_opacity),
    )

    # Reassign in case _ensure_float32_contiguous produced converted copies.
    data.scales = scales
    data.opacities = opacities

    logger.debug(
        "[PreDeactivation] Deactivated %d Gaussians (min_scale=%.2e, min_opacity=%.2e, max_opacity=%.2e)",
        scales.shape[0],
        min_scale,
        min_opacity,
        max_opacity,
    )
    return data
# Public API re-exported by ``from gsply.utils import *``.
__all__ = [
    "sh2rgb",
    "rgb2sh",
    "SH_C0",
    "sigmoid",
    "logit",
    "apply_pre_activations",
    "apply_pre_deactivations",
]