# Source code for gsply.writer

"""Writing functions for Gaussian splatting PLY files.

This module provides ultra-fast writing of Gaussian splatting PLY files
in uncompressed format, along with the quantization and bit-packing
kernels used for the chunk-compressed format.

API Examples:
    >>> from gsply import plywrite
    >>> plywrite("output.ply", means, scales, quats, opacities, sh0, shN)

    >>> # Or use format-specific writers
    >>> from gsply.writer import write_uncompressed
    >>> write_uncompressed("output.ply", means, scales, quats, opacities, sh0, shN)

Performance:
    - Write uncompressed: 3-7ms for 50K Gaussians (7-17M Gaussians/sec)
    - Write compressed: 2-11ms for 50K Gaussians (4-25M Gaussians/sec)
"""

import logging
from functools import lru_cache
from pathlib import Path
from typing import TYPE_CHECKING

import numba
import numpy as np

# Import numba for JIT optimization
from numba import jit

from gsply.formats import CHUNK_SIZE, SH_C0
from gsply.gsdata import (
    DataFormat,
    GSData,
    _create_format_dict,
    _interleave_sh0_jit,
    _interleave_shn_jit,
)

if TYPE_CHECKING:
    from gsply.torch.gstensor import GSTensor  # noqa: F401

logger = logging.getLogger(__name__)


# ======================================================================================
# I/O BUFFER SIZE CONSTANTS
# ======================================================================================

# Write-buffer sizes (in bytes) and the threshold used for buffer size selection
_SMALL_BUFFER_SIZE = 1 << 20  # 1 MiB (1,048,576 bytes) buffer for small files
_LARGE_BUFFER_SIZE = 2 << 20  # 2 MiB (2,097,152 bytes) buffer for large files
_LARGE_FILE_THRESHOLD = 10 * 1000 * 1000  # 10 MB (decimal) threshold


# ======================================================================================
# BIT-PACKING QUANTIZATION CONSTANTS
# ======================================================================================

# Quantization maxima: the largest integer code that fits in each bit-field width.
# Positions and scales use an 11-10-11 bit split, quaternion components use
# 10 bits each, and colors/opacity use 8 bits each.
_QUANTIZE_11_BIT_MAX = float((1 << 11) - 1)  # 2047.0 (X and Z fields)
_QUANTIZE_10_BIT_MAX = float((1 << 10) - 1)  # 1023.0 (Y field and quaternion components)
_QUANTIZE_8_BIT_MAX = float((1 << 8) - 1)  # 255.0 (RGB and opacity)

# Added before a truncating integer cast so values round to nearest
# instead of always rounding down (avoids truncation bias)
_ROUNDING_OFFSET = 1.0 / 2.0

# Bit offsets inside each packed 32-bit word, built up from the least
# significant field so every shift is the sum of the widths below it.

# Position/Scale word: (X:11 bits)(Y:10 bits)(Z:11 bits)
_POSITION_Z_SHIFT = 0  # bits 10-0: Z (11 bits)
_POSITION_Y_SHIFT = _POSITION_Z_SHIFT + 11  # bits 20-11: Y (10 bits)
_POSITION_X_SHIFT = _POSITION_Y_SHIFT + 10  # bits 31-21: X (11 bits)

# Quaternion word: (largest_idx:2 bits)(qa:10 bits)(qb:10 bits)(qc:10 bits)
_QUAT_C_SHIFT = 0  # bits 9-0: third remaining component (10 bits)
_QUAT_B_SHIFT = _QUAT_C_SHIFT + 10  # bits 19-10: second remaining component (10 bits)
_QUAT_A_SHIFT = _QUAT_B_SHIFT + 10  # bits 29-20: first remaining component (10 bits)
_QUAT_INDEX_SHIFT = _QUAT_A_SHIFT + 10  # bits 31-30: index of largest component (2 bits)

# Color word: (R:8 bits)(G:8 bits)(B:8 bits)(Opacity:8 bits)
_COLOR_O_SHIFT = 0  # bits 7-0: opacity
_COLOR_B_SHIFT = _COLOR_O_SHIFT + 8  # bits 15-8: blue
_COLOR_G_SHIFT = _COLOR_B_SHIFT + 8  # bits 23-16: green
_COLOR_R_SHIFT = _COLOR_G_SHIFT + 8  # bits 31-24: red


# ======================================================================================
# PRE-COMPUTED HEADER TEMPLATES (Optimization)
# ======================================================================================

# Header template for SH degree 0 (14 float properties per vertex).
# ``{num_gaussians}`` is a str.format placeholder filled in at write time.
_HEADER_TEMPLATE_SH0 = (
    "\n".join(
        (
            "ply",
            "format binary_little_endian 1.0",
            "element vertex {num_gaussians}",
            "property float x",
            "property float y",
            "property float z",
            "property float f_dc_0",
            "property float f_dc_1",
            "property float f_dc_2",
            "property float opacity",
            "property float scale_0",
            "property float scale_1",
            "property float scale_2",
            "property float rot_0",
            "property float rot_1",
            "property float rot_2",
            "property float rot_3",
            "end_header",
        )
    )
    + "\n"
)

# Newline-terminated ``f_rest_*`` property lines, keyed by the number of
# higher-order coefficients (9, 24 and 45 correspond to SH degrees 1-3).
_F_REST_PROPERTIES = {
    count: "".join(f"property float f_rest_{i}\n" for i in range(count))
    for count in (9, 24, 45)
}


@lru_cache(maxsize=32)
def _build_header_fast(num_gaussians: int, num_sh_rest: int | None) -> bytes:
    """Return the ASCII-encoded PLY header for a binary little-endian splat file.

    Results are memoized with an LRU cache keyed on both arguments. The
    SH-degree-0 header comes from ``_HEADER_TEMPLATE_SH0``; when higher-order
    coefficients are present, their ``f_rest_*`` lines are taken from
    ``_F_REST_PROPERTIES`` if available and generated on the fly otherwise.

    :param num_gaussians: Number of Gaussians (written into ``element vertex``)
    :param num_sh_rest: Number of higher-order SH coefficients (None for SH0)
    :returns: Header bytes ready to write
    """
    # SH degree 0: the whole header is a single pre-computed template.
    if num_sh_rest is None:
        return _HEADER_TEMPLATE_SH0.format(num_gaussians=num_gaussians).encode("ascii")

    # Pre-computed f_rest lines for SH degrees 1-3; generated for any other count.
    rest_lines = _F_REST_PROPERTIES.get(num_sh_rest)
    if rest_lines is None:
        rest_lines = "".join(f"property float f_rest_{i}\n" for i in range(num_sh_rest))

    leading = (
        "ply\n"
        "format binary_little_endian 1.0\n"
        f"element vertex {num_gaussians}\n"
        "property float x\n"
        "property float y\n"
        "property float z\n"
        "property float f_dc_0\n"
        "property float f_dc_1\n"
        "property float f_dc_2\n"
    )
    trailing = (
        "property float opacity\n"
        "property float scale_0\n"
        "property float scale_1\n"
        "property float scale_2\n"
        "property float rot_0\n"
        "property float rot_1\n"
        "property float rot_2\n"
        "property float rot_3\n"
        "end_header\n"
    )
    return "".join((leading, rest_lines, trailing)).encode("ascii")


# ======================================================================================
# JIT-COMPILED COMPRESSION FUNCTIONS
# ======================================================================================


@jit(nopython=True, parallel=True, fastmath=True, cache=True)
def _pack_positions_jit(
    sorted_means, chunk_indices, min_x, min_y, min_z, range_x, range_y, range_z
):
    """JIT-compiled position quantization and packing (11-10-11 bits) with parallel processing.

    Each coordinate is normalized against the bounds of its chunk, clamped to
    [0, 1], scaled to its bit width and rounded to the nearest integer code.
    Rounding (``+ _ROUNDING_OFFSET`` before the truncating cast) keeps this
    function consistent with the fused ``_pack_all_jit`` kernel, which
    quantizes the same way.

    Ranges are pre-computed once per chunk instead of once per vertex.

    :param sorted_means: (N, 3) float32 array of positions
    :param chunk_indices: int32 array of chunk indices for each vertex
    :param min_x: chunk minimum x bounds
    :param min_y: chunk minimum y bounds
    :param min_z: chunk minimum z bounds
    :param range_x: chunk x range (max - min, pre-computed)
    :param range_y: chunk y range (max - min, pre-computed)
    :param range_z: chunk z range (max - min, pre-computed)
    :returns: (N,) uint32 array of packed positions
    """
    n = len(sorted_means)
    packed = np.zeros(n, dtype=np.uint32)

    for i in numba.prange(n):
        chunk_idx = chunk_indices[i]

        # Normalize to [0, 1] using pre-computed ranges.
        # NOTE(review): a zero range divides by zero here; presumably the caller
        # substitutes a non-zero range for degenerate chunks - confirm.
        norm_x = (sorted_means[i, 0] - min_x[chunk_idx]) / range_x[chunk_idx]
        norm_y = (sorted_means[i, 1] - min_y[chunk_idx]) / range_y[chunk_idx]
        norm_z = (sorted_means[i, 2] - min_z[chunk_idx]) / range_z[chunk_idx]

        # Clamp
        norm_x = max(0.0, min(1.0, norm_x))
        norm_y = max(0.0, min(1.0, norm_y))
        norm_z = max(0.0, min(1.0, norm_z))

        # Quantize with round-to-nearest; the clamped input keeps the result
        # within the field width (e.g. 1.0 * 2047 + 0.5 truncates to 2047).
        px = np.uint32(norm_x * _QUANTIZE_11_BIT_MAX + _ROUNDING_OFFSET)
        py = np.uint32(norm_y * _QUANTIZE_10_BIT_MAX + _ROUNDING_OFFSET)
        pz = np.uint32(norm_z * _QUANTIZE_11_BIT_MAX + _ROUNDING_OFFSET)

        # Pack into 32-bit integer: (X:11 bits)(Y:10 bits)(Z:11 bits)
        packed[i] = (px << _POSITION_X_SHIFT) | (py << _POSITION_Y_SHIFT) | pz

    return packed


@jit(nopython=True, parallel=True, fastmath=True, cache=True)
def _pack_scales_jit(
    sorted_scales, chunk_indices, min_sx, min_sy, min_sz, range_sx, range_sy, range_sz
):
    """JIT-compiled scale quantization and packing (11-10-11 bits) with parallel processing.

    Each scale component is normalized against the bounds of its chunk, clamped
    to [0, 1], scaled to its bit width and rounded to the nearest integer code.
    Rounding (``+ _ROUNDING_OFFSET`` before the truncating cast) keeps this
    function consistent with the fused ``_pack_all_jit`` kernel, which
    quantizes the same way.

    Ranges are pre-computed once per chunk instead of once per vertex.

    :param sorted_scales: (N, 3) float32 array of scales
    :param chunk_indices: int32 array of chunk indices for each vertex
    :param min_sx: chunk minimum scale x bounds
    :param min_sy: chunk minimum scale y bounds
    :param min_sz: chunk minimum scale z bounds
    :param range_sx: chunk scale x range (max - min, pre-computed)
    :param range_sy: chunk scale y range (max - min, pre-computed)
    :param range_sz: chunk scale z range (max - min, pre-computed)
    :returns: (N,) uint32 array of packed scales
    """
    n = len(sorted_scales)
    packed = np.zeros(n, dtype=np.uint32)

    for i in numba.prange(n):
        chunk_idx = chunk_indices[i]

        # Normalize to [0, 1] using pre-computed ranges.
        # NOTE(review): a zero range divides by zero here; presumably the caller
        # substitutes a non-zero range for degenerate chunks - confirm.
        norm_sx = (sorted_scales[i, 0] - min_sx[chunk_idx]) / range_sx[chunk_idx]
        norm_sy = (sorted_scales[i, 1] - min_sy[chunk_idx]) / range_sy[chunk_idx]
        norm_sz = (sorted_scales[i, 2] - min_sz[chunk_idx]) / range_sz[chunk_idx]

        # Clamp (also absorbs values outside the chunk's stored bounds)
        norm_sx = max(0.0, min(1.0, norm_sx))
        norm_sy = max(0.0, min(1.0, norm_sy))
        norm_sz = max(0.0, min(1.0, norm_sz))

        # Quantize with round-to-nearest rather than truncation
        sx = np.uint32(norm_sx * _QUANTIZE_11_BIT_MAX + _ROUNDING_OFFSET)
        sy = np.uint32(norm_sy * _QUANTIZE_10_BIT_MAX + _ROUNDING_OFFSET)
        sz = np.uint32(norm_sz * _QUANTIZE_11_BIT_MAX + _ROUNDING_OFFSET)

        # Pack into 32-bit integer: (X:11 bits)(Y:10 bits)(Z:11 bits)
        packed[i] = (sx << _POSITION_X_SHIFT) | (sy << _POSITION_Y_SHIFT) | sz

    return packed


@jit(nopython=True, parallel=True, fastmath=True, cache=True)
def _pack_colors_jit(
    sorted_color_rgb,
    sorted_opacities,
    chunk_indices,
    min_r,
    min_g,
    min_b,
    range_r,
    range_g,
    range_b,
):
    """JIT-compiled color and opacity quantization and packing (8-8-8-8 bits) with parallel processing.

    RGB channels are normalized against the bounds of their chunk; opacity is
    mapped from logit space through a sigmoid. All four values are clamped to
    [0, 1], scaled to 8 bits and rounded to the nearest integer code. Rounding
    (``+ _ROUNDING_OFFSET`` before the truncating cast) keeps this function
    consistent with the fused ``_pack_all_jit`` kernel, which quantizes the
    same way.

    Ranges are pre-computed once per chunk instead of once per vertex.

    :param sorted_color_rgb: (N, 3) float32 array of pre-computed RGB colors (SH0 * SH_C0 + 0.5)
    :param sorted_opacities: (N,) float32 array of opacities (logit space)
    :param chunk_indices: int32 array of chunk indices for each vertex
    :param min_r: chunk minimum color r bounds
    :param min_g: chunk minimum color g bounds
    :param min_b: chunk minimum color b bounds
    :param range_r: chunk color r range (max - min, pre-computed)
    :param range_g: chunk color g range (max - min, pre-computed)
    :param range_b: chunk color b range (max - min, pre-computed)
    :returns: (N,) uint32 array of packed colors
    """
    n = len(sorted_color_rgb)
    packed = np.zeros(n, dtype=np.uint32)

    for i in numba.prange(n):
        chunk_idx = chunk_indices[i]

        # Use pre-computed RGB colors
        color_r = sorted_color_rgb[i, 0]
        color_g = sorted_color_rgb[i, 1]
        color_b = sorted_color_rgb[i, 2]

        # Normalize to [0, 1] using pre-computed ranges.
        # NOTE(review): a zero range divides by zero here; presumably the caller
        # substitutes a non-zero range for degenerate chunks - confirm.
        norm_r = (color_r - min_r[chunk_idx]) / range_r[chunk_idx]
        norm_g = (color_g - min_g[chunk_idx]) / range_g[chunk_idx]
        norm_b = (color_b - min_b[chunk_idx]) / range_b[chunk_idx]

        # Clamp
        norm_r = max(0.0, min(1.0, norm_r))
        norm_g = max(0.0, min(1.0, norm_g))
        norm_b = max(0.0, min(1.0, norm_b))

        # Quantize colors to 8-bit range with round-to-nearest
        cr = np.uint32(norm_r * _QUANTIZE_8_BIT_MAX + _ROUNDING_OFFSET)
        cg = np.uint32(norm_g * _QUANTIZE_8_BIT_MAX + _ROUNDING_OFFSET)
        cb = np.uint32(norm_b * _QUANTIZE_8_BIT_MAX + _ROUNDING_OFFSET)

        # Opacity: logit to linear (sigmoid), then the same rounded quantization
        opacity_linear = 1.0 / (1.0 + np.exp(-sorted_opacities[i]))
        opacity_linear = max(0.0, min(1.0, opacity_linear))
        co = np.uint32(opacity_linear * _QUANTIZE_8_BIT_MAX + _ROUNDING_OFFSET)

        # Pack into 32-bit integer: (R:8)(G:8)(B:8)(O:8)
        packed[i] = (cr << _COLOR_R_SHIFT) | (cg << _COLOR_G_SHIFT) | (cb << _COLOR_B_SHIFT) | co

    return packed


@jit(nopython=True, parallel=True, fastmath=True, cache=True)
def _pack_quaternions_jit(sorted_quats):
    """JIT-compiled quaternion normalization and packing (2+10-10-10 bits, smallest-three) with parallel processing.

    Each quaternion is normalized, sign-flipped so its largest-magnitude
    component is non-negative, and stored as the 2-bit index of that component
    plus the other three components quantized to 10 bits each. Quantization
    rounds to the nearest integer code (``+ _ROUNDING_OFFSET`` before the
    truncating cast), consistent with the fused ``_pack_all_jit`` kernel.

    :param sorted_quats: (N, 4) float32 array of quaternions
    :returns: (N,) uint32 array of packed quaternions
    """
    n = len(sorted_quats)
    packed = np.zeros(n, dtype=np.uint32)
    # sqrt(2)/2: maps the three stored components into [-0.5, 0.5] before the +0.5 shift
    norm_factor = np.sqrt(2.0) * 0.5

    for i in numba.prange(n):
        # Normalize quaternion (an all-zero quaternion is left unchanged)
        quat = sorted_quats[i]
        norm = np.sqrt(
            quat[0] * quat[0] + quat[1] * quat[1] + quat[2] * quat[2] + quat[3] * quat[3]
        )
        if norm > 0:
            quat = quat / norm

        # Find largest component by absolute value (first one wins on ties)
        abs_vals = np.abs(quat)
        largest_idx = 0
        largest_val = abs_vals[0]
        for j in range(1, 4):
            if abs_vals[j] > largest_val:
                largest_val = abs_vals[j]
                largest_idx = j

        # Flip quaternion if largest component is negative
        if quat[largest_idx] < 0:
            quat = -quat

        # Extract the three remaining components, keeping their original order
        three_components = np.zeros(3, dtype=np.float32)
        idx = 0
        for j in range(4):
            if j != largest_idx:
                three_components[idx] = quat[j]
                idx += 1

        # Normalize to [0, 1] for quantization
        qa_norm = three_components[0] * norm_factor + 0.5
        qb_norm = three_components[1] * norm_factor + 0.5
        qc_norm = three_components[2] * norm_factor + 0.5

        # Clamp
        qa_norm = max(0.0, min(1.0, qa_norm))
        qb_norm = max(0.0, min(1.0, qb_norm))
        qc_norm = max(0.0, min(1.0, qc_norm))

        # Quantize to 10-bit range with round-to-nearest
        qa_int = np.uint32(qa_norm * _QUANTIZE_10_BIT_MAX + _ROUNDING_OFFSET)
        qb_int = np.uint32(qb_norm * _QUANTIZE_10_BIT_MAX + _ROUNDING_OFFSET)
        qc_int = np.uint32(qc_norm * _QUANTIZE_10_BIT_MAX + _ROUNDING_OFFSET)

        # Pack into 32-bit integer: (index:2)(qa:10)(qb:10)(qc:10)
        packed[i] = (
            (np.uint32(largest_idx) << _QUAT_INDEX_SHIFT)
            | (qa_int << _QUAT_A_SHIFT)
            | (qb_int << _QUAT_B_SHIFT)
            | qc_int
        )

    return packed


# Right-shift that turns a vertex index into its chunk index (256 = 2^8).
# NOTE(review): this hard-codes 256 Gaussians per chunk; presumably it must stay
# in sync with CHUNK_SIZE from gsply.formats - confirm.
_CHUNK_SIZE_SHIFT_PACK = 8


@jit(nopython=True, parallel=True, fastmath=True, cache=True, nogil=True, boundscheck=False)
def _pack_all_jit(
    sorted_means,
    sorted_scales,
    sorted_color_rgb,
    sorted_opacities,
    sorted_quats,
    min_x,
    min_y,
    min_z,
    range_x,
    range_y,
    range_z,
    min_sx,
    min_sy,
    min_sz,
    range_sx,
    range_sy,
    range_sz,
    min_r,
    min_g,
    min_b,
    range_r,
    range_g,
    range_b,
):
    """Fused JIT-compiled packing of all vertex data in single parallel pass.

    Combines position, scale, color, and quaternion packing into one loop for:
    - Better cache locality (single pass over all data)
    - Reduced parallel overhead (1 loop instead of 4)
    - Chunk index computed inline (avoids redundant lookups)

    The chunk of vertex ``i`` is ``i >> _CHUNK_SIZE_SHIFT_PACK``, so the inputs
    must already be ordered such that consecutive runs of 256 rows form one
    chunk. Every quantized field is rounded to nearest by adding
    ``_ROUNDING_OFFSET`` before the truncating ``np.uint32`` cast.

    The function is compiled with ``boundscheck=False``: the per-chunk arrays
    are indexed without bounds checking, so they must hold an entry for every
    chunk index produced above.

    :param sorted_means: (N, 3) float32 array of positions
    :param sorted_scales: (N, 3) float32 array of scales
    :param sorted_color_rgb: (N, 3) float32 array of pre-computed RGB colors
    :param sorted_opacities: (N,) float32 array of opacities (logit space)
    :param sorted_quats: (N, 4) float32 array of quaternions
    :param min_x, min_y, min_z: chunk minimum position bounds
    :param range_x, range_y, range_z: chunk position ranges
    :param min_sx, min_sy, min_sz: chunk minimum scale bounds
    :param range_sx, range_sy, range_sz: chunk scale ranges
    :param min_r, min_g, min_b: chunk minimum color bounds
    :param range_r, range_g, range_b: chunk color ranges
    :returns: (N, 4) uint32 array with packed [position, quaternion, scale, color]
    """
    n = len(sorted_means)
    packed = np.zeros((n, 4), dtype=np.uint32)
    # sqrt(2)/2: maps the three stored quaternion components into [-0.5, 0.5]
    norm_factor = np.sqrt(2.0) * 0.5

    for i in numba.prange(n):
        # Compute chunk index inline (256 Gaussians per chunk)
        chunk_idx = i >> _CHUNK_SIZE_SHIFT_PACK

        # ======================================================================
        # SECTION 1: Pack positions (11-10-11 bits) -> output column 0
        # ======================================================================
        # NOTE(review): a zero range divides by zero in the normalizations below;
        # presumably the caller substitutes a non-zero range for degenerate
        # chunks - confirm.
        norm_x = (sorted_means[i, 0] - min_x[chunk_idx]) / range_x[chunk_idx]
        norm_y = (sorted_means[i, 1] - min_y[chunk_idx]) / range_y[chunk_idx]
        norm_z = (sorted_means[i, 2] - min_z[chunk_idx]) / range_z[chunk_idx]
        norm_x = max(0.0, min(1.0, norm_x))
        norm_y = max(0.0, min(1.0, norm_y))
        norm_z = max(0.0, min(1.0, norm_z))
        px = np.uint32(norm_x * _QUANTIZE_11_BIT_MAX + _ROUNDING_OFFSET)
        py = np.uint32(norm_y * _QUANTIZE_10_BIT_MAX + _ROUNDING_OFFSET)
        pz = np.uint32(norm_z * _QUANTIZE_11_BIT_MAX + _ROUNDING_OFFSET)
        packed[i, 0] = (px << _POSITION_X_SHIFT) | (py << _POSITION_Y_SHIFT) | pz

        # ======================================================================
        # SECTION 2: Pack scales (11-10-11 bits) -> output column 2
        # ======================================================================
        norm_sx = (sorted_scales[i, 0] - min_sx[chunk_idx]) / range_sx[chunk_idx]
        norm_sy = (sorted_scales[i, 1] - min_sy[chunk_idx]) / range_sy[chunk_idx]
        norm_sz = (sorted_scales[i, 2] - min_sz[chunk_idx]) / range_sz[chunk_idx]
        norm_sx = max(0.0, min(1.0, norm_sx))
        norm_sy = max(0.0, min(1.0, norm_sy))
        norm_sz = max(0.0, min(1.0, norm_sz))
        sx = np.uint32(norm_sx * _QUANTIZE_11_BIT_MAX + _ROUNDING_OFFSET)
        sy = np.uint32(norm_sy * _QUANTIZE_10_BIT_MAX + _ROUNDING_OFFSET)
        sz = np.uint32(norm_sz * _QUANTIZE_11_BIT_MAX + _ROUNDING_OFFSET)
        packed[i, 2] = (sx << _POSITION_X_SHIFT) | (sy << _POSITION_Y_SHIFT) | sz

        # ======================================================================
        # SECTION 3: Pack colors (8-8-8-8 bits) -> output column 3
        # ======================================================================
        color_r = sorted_color_rgb[i, 0]
        color_g = sorted_color_rgb[i, 1]
        color_b = sorted_color_rgb[i, 2]
        norm_r = (color_r - min_r[chunk_idx]) / range_r[chunk_idx]
        norm_g = (color_g - min_g[chunk_idx]) / range_g[chunk_idx]
        norm_b = (color_b - min_b[chunk_idx]) / range_b[chunk_idx]
        norm_r = max(0.0, min(1.0, norm_r))
        norm_g = max(0.0, min(1.0, norm_g))
        norm_b = max(0.0, min(1.0, norm_b))
        cr = np.uint32(norm_r * _QUANTIZE_8_BIT_MAX + _ROUNDING_OFFSET)
        cg = np.uint32(norm_g * _QUANTIZE_8_BIT_MAX + _ROUNDING_OFFSET)
        cb = np.uint32(norm_b * _QUANTIZE_8_BIT_MAX + _ROUNDING_OFFSET)
        # Opacity: logit to linear via sigmoid (not chunk-relative), rounded like the colors
        opacity_linear = 1.0 / (1.0 + np.exp(-sorted_opacities[i]))
        opacity_linear = max(0.0, min(1.0, opacity_linear))
        co = np.uint32(opacity_linear * _QUANTIZE_8_BIT_MAX + _ROUNDING_OFFSET)
        packed[i, 3] = (cr << _COLOR_R_SHIFT) | (cg << _COLOR_G_SHIFT) | (cb << _COLOR_B_SHIFT) | co

        # ======================================================================
        # SECTION 4: Pack quaternions (2+10-10-10 bits, smallest-three) -> output column 1
        # ======================================================================
        qw = sorted_quats[i, 0]
        qx = sorted_quats[i, 1]
        qy = sorted_quats[i, 2]
        qz = sorted_quats[i, 3]

        # --- Step 4.1: Normalize quaternion (an all-zero quaternion is left as is) ---
        qnorm = np.sqrt(qw * qw + qx * qx + qy * qy + qz * qz)
        if qnorm > 0:
            inv_norm = 1.0 / qnorm
            qw *= inv_norm
            qx *= inv_norm
            qy *= inv_norm
            qz *= inv_norm

        # --- Step 4.2: Find largest component by absolute value ---
        # Strict '>' comparisons mean the earliest component wins on ties.
        abs_w, abs_x, abs_y, abs_z = abs(qw), abs(qx), abs(qy), abs(qz)
        largest_idx = 0
        largest_val = abs_w
        if abs_x > largest_val:
            largest_idx = 1
            largest_val = abs_x
        if abs_y > largest_val:
            largest_idx = 2
            largest_val = abs_y
        if abs_z > largest_val:
            largest_idx = 3

        # --- Step 4.3: Get components in order, flip sign if largest is negative ---
        # q and -q encode the same rotation, so the dropped component can always
        # be made non-negative; the remaining three keep their w/x/y/z order.
        if largest_idx == 0:
            if qw < 0:
                qw, qx, qy, qz = -qw, -qx, -qy, -qz
            qa, qb, qc = qx, qy, qz
        elif largest_idx == 1:
            if qx < 0:
                qw, qx, qy, qz = -qw, -qx, -qy, -qz
            qa, qb, qc = qw, qy, qz
        elif largest_idx == 2:
            if qy < 0:
                qw, qx, qy, qz = -qw, -qx, -qy, -qz
            qa, qb, qc = qw, qx, qz
        else:
            if qz < 0:
                qw, qx, qy, qz = -qw, -qx, -qy, -qz
            qa, qb, qc = qw, qx, qy

        # --- Step 4.4: Normalize to [0, 1] for quantization ---
        qa_norm = qa * norm_factor + 0.5
        qb_norm = qb * norm_factor + 0.5
        qc_norm = qc * norm_factor + 0.5
        qa_norm = max(0.0, min(1.0, qa_norm))
        qb_norm = max(0.0, min(1.0, qb_norm))
        qc_norm = max(0.0, min(1.0, qc_norm))

        # --- Step 4.5: Quantize and pack ---
        qa_int = np.uint32(qa_norm * _QUANTIZE_10_BIT_MAX + _ROUNDING_OFFSET)
        qb_int = np.uint32(qb_norm * _QUANTIZE_10_BIT_MAX + _ROUNDING_OFFSET)
        qc_int = np.uint32(qc_norm * _QUANTIZE_10_BIT_MAX + _ROUNDING_OFFSET)
        packed[i, 1] = (
            (np.uint32(largest_idx) << _QUAT_INDEX_SHIFT)
            | (qa_int << _QUAT_A_SHIFT)
            | (qb_int << _QUAT_B_SHIFT)
            | qc_int
        )

    return packed


@jit(nopython=True, parallel=True, fastmath=True, cache=True)
def _compute_chunk_bounds_jit(
    sorted_means, sorted_scales, sorted_color_rgb, chunk_starts, chunk_ends
):
    """JIT-compiled per-chunk min/max bounds for positions, scales and colors.

    Chunks are processed in parallel. For every non-empty chunk the rows
    ``[start, end)`` are scanned once and the running minima/maxima are kept
    directly in the output row. Empty chunks (``start >= end``) keep an
    all-zero row. Scale bounds are clamped to [-20, 20] after the scan.

    :param sorted_means: (N, 3) float32 array of positions
    :param sorted_scales: (N, 3) float32 array of scales
    :param sorted_color_rgb: (N, 3) float32 array of pre-computed RGB colors (SH0 * SH_C0 + 0.5)
    :param chunk_starts: (num_chunks,) int array of chunk start indices
    :param chunk_ends: (num_chunks,) int array of chunk end indices
    :returns: (num_chunks, 18) float32 array with layout [0:6] min_x, min_y, min_z, max_x, max_y, max_z,
              [6:12] min_scale_x/y/z, max_scale_x/y/z (clamped to [-20,20]), [12:18] min_r, min_g, min_b, max_r, max_g, max_b
    """
    n_chunks = len(chunk_starts)
    bounds = np.zeros((n_chunks, 18), dtype=np.float32)

    for k in numba.prange(n_chunks):
        first = chunk_starts[k]
        stop = chunk_ends[k]

        # Only non-empty chunks are scanned; empty ones stay zero-filled.
        if first < stop:
            # Seed both the min and the max columns with the chunk's first row.
            # Column layout per group of six: [min(axis 0..2), max(axis 0..2)].
            for axis in range(3):
                bounds[k, axis] = sorted_means[first, axis]
                bounds[k, 3 + axis] = sorted_means[first, axis]
                bounds[k, 6 + axis] = sorted_scales[first, axis]
                bounds[k, 9 + axis] = sorted_scales[first, axis]
                bounds[k, 12 + axis] = sorted_color_rgb[first, axis]
                bounds[k, 15 + axis] = sorted_color_rgb[first, axis]

            # Fold the remaining rows of the chunk into the running bounds.
            for row in range(first + 1, stop):
                for axis in range(3):
                    pos = sorted_means[row, axis]
                    bounds[k, axis] = min(bounds[k, axis], pos)
                    bounds[k, 3 + axis] = max(bounds[k, 3 + axis], pos)

                    scale = sorted_scales[row, axis]
                    bounds[k, 6 + axis] = min(bounds[k, 6 + axis], scale)
                    bounds[k, 9 + axis] = max(bounds[k, 9 + axis], scale)

                    # Colors arrive already converted to RGB.
                    color = sorted_color_rgb[row, axis]
                    bounds[k, 12 + axis] = min(bounds[k, 12 + axis], color)
                    bounds[k, 15 + axis] = max(bounds[k, 15 + axis], color)

            # Clamp scale bounds to [-20, 20] (matches splat-transform)
            for col in range(6, 12):
                bounds[k, col] = max(-20.0, min(20.0, bounds[k, col]))

    return bounds


# ======================================================================================
# HELPER FUNCTIONS
# ======================================================================================


def _ensure_numpy_arrays(
    means, scales, quats, opacities, sh0, shn
) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray | None]:
    """Convert inputs to numpy arrays if they aren't already.

    :param means: Gaussian centers (any array-like)
    :param scales: Log scales (any array-like)
    :param quats: Rotations as quaternions (any array-like)
    :param opacities: Logit opacities (any array-like)
    :param sh0: DC spherical harmonics (any array-like)
    :param shn: Higher-order SH coefficients or None (any array-like or None)
    :return: Tuple of numpy arrays (may be converted to float32 if not already numpy arrays)
    :rtype: tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray | None]
    """
    if not isinstance(means, np.ndarray):
        means = np.asarray(means, dtype=np.float32)
    if not isinstance(scales, np.ndarray):
        scales = np.asarray(scales, dtype=np.float32)
    if not isinstance(quats, np.ndarray):
        quats = np.asarray(quats, dtype=np.float32)
    if not isinstance(opacities, np.ndarray):
        opacities = np.asarray(opacities, dtype=np.float32)
    if not isinstance(sh0, np.ndarray):
        sh0 = np.asarray(sh0, dtype=np.float32)
    if shn is not None and not isinstance(shn, np.ndarray):
        shn = np.asarray(shn, dtype=np.float32)
    return means, scales, quats, opacities, sh0, shn


def _convert_to_float32(
    means: np.ndarray,
    scales: np.ndarray,
    quats: np.ndarray,
    opacities: np.ndarray,
    sh0: np.ndarray,
    shn: np.ndarray | None,
) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray | None]:
    """Convert arrays to float32 dtype if needed (avoids copy when already float32).

    :param means: Gaussian centers array
    :type means: np.ndarray
    :param scales: Log scales array
    :type scales: np.ndarray
    :param quats: Rotations as quaternions array
    :type quats: np.ndarray
    :param opacities: Logit opacities array
    :type opacities: np.ndarray
    :param sh0: DC spherical harmonics array
    :type sh0: np.ndarray
    :param shn: Higher-order SH coefficients or None
    :type shn: np.ndarray | None
    :return: Tuple of float32 arrays
    :rtype: tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray | None]
    """
    # Fast path: check if all arrays are already float32
    all_float32 = (
        means.dtype == np.float32
        and scales.dtype == np.float32
        and quats.dtype == np.float32
        and opacities.dtype == np.float32
        and sh0.dtype == np.float32
        and (shn is None or shn.dtype == np.float32)
    )

    # Only convert dtype if needed (avoids copy when already float32)
    if not all_float32:
        if means.dtype != np.float32:
            means = means.astype(np.float32, copy=False)
        if scales.dtype != np.float32:
            scales = scales.astype(np.float32, copy=False)
        if quats.dtype != np.float32:
            quats = quats.astype(np.float32, copy=False)
        if opacities.dtype != np.float32:
            opacities = opacities.astype(np.float32, copy=False)
        if sh0.dtype != np.float32:
            sh0 = sh0.astype(np.float32, copy=False)
        if shn is not None and shn.dtype != np.float32:
            shn = shn.astype(np.float32, copy=False)

    return means, scales, quats, opacities, sh0, shn


def _validate_array_shapes(
    means: np.ndarray,
    scales: np.ndarray,
    quats: np.ndarray,
    opacities: np.ndarray,
    sh0: np.ndarray,
    num_gaussians: int,
) -> None:
    """Validate that all arrays have the expected shapes.

    Arrays are checked in the order means, scales, quats, opacities, sh0 and
    the first mismatch is reported. ``AssertionError`` is raised explicitly
    (rather than via ``assert`` statements) so the validation still runs when
    Python is started with ``-O``, which strips ``assert``.

    :param means: Gaussian centers array, shape (N, 3)
    :type means: np.ndarray
    :param scales: Log scales array, shape (N, 3)
    :type scales: np.ndarray
    :param quats: Rotations as quaternions array, shape (N, 4)
    :type quats: np.ndarray
    :param opacities: Logit opacities array, shape (N,)
    :type opacities: np.ndarray
    :param sh0: DC spherical harmonics array, shape (N, 3)
    :type sh0: np.ndarray
    :param num_gaussians: Expected number of Gaussians (N)
    :type num_gaussians: int
    :raises AssertionError: If any array has incorrect shape
    """
    same_count_hint = "Ensure all arrays have the same number of Gaussians (N)."
    # (label, array, expected shape, guidance appended to the error message)
    checks = (
        ("means", means, (num_gaussians, 3), same_count_hint),
        ("scales", scales, (num_gaussians, 3), same_count_hint),
        ("quats", quats, (num_gaussians, 4), "Quaternions must have 4 components (w, x, y, z)."),
        (
            "opacities",
            opacities,
            (num_gaussians,),
            "Opacities should be a 1D array with one value per Gaussian.",
        ),
        ("sh0", sh0, (num_gaussians, 3), "SH DC coefficients must have 3 components (RGB)."),
    )

    for label, array, expected, hint in checks:
        if array.shape == expected:
            continue
        # Format the expected shape by hand so the message text does not depend
        # on how the type of num_gaussians renders inside a tuple repr.
        if len(expected) == 2:
            expected_text = f"({num_gaussians}, {expected[1]})"
        else:
            expected_text = f"({num_gaussians},)"
        raise AssertionError(
            f"{label} array has incorrect shape: expected {expected_text}, "
            f"got {array.shape}. {hint}"
        )


def _flatten_shn(shn: np.ndarray | None, validate: bool) -> np.ndarray | None:
    """Flatten shN array from (N, K, 3) to (N, K*3) if needed.

    :param shn: Higher-order SH coefficients or None
    :type shn: np.ndarray | None
    :param validate: Whether to validate the shape
    :type validate: bool
    :return: Flattened shN array or None
    :rtype: np.ndarray | None
    """
    if shn is not None and shn.ndim == 3:
        n_gaussians, n_bands, n_components = shn.shape
        if validate:
            assert n_components == 3, f"shN must have shape (N, K, 3), got {shn.shape}"
        shn = shn.reshape(n_gaussians, n_bands * n_components)
    return shn


def _compute_chunk_boundaries(num_chunks: int, num_gaussians: int) -> tuple[np.ndarray, np.ndarray]:
    """Return the start and end index of every chunk.

    Chunk ``i`` starts at ``i * CHUNK_SIZE``. Its end is ``start + CHUNK_SIZE``
    clamped to ``num_gaussians``, so a trailing chunk is shortened when
    ``num_gaussians`` is not a multiple of ``CHUNK_SIZE``.

    :param num_chunks: Number of chunks
    :type num_chunks: int
    :param num_gaussians: Total number of Gaussians
    :type num_gaussians: int
    :return: Tuple of (chunk_starts, chunk_ends) arrays of shape (num_chunks,)
    :rtype: tuple[np.ndarray, np.ndarray]
    """
    chunk_ids = np.arange(num_chunks, dtype=np.int32)
    first_indices = chunk_ids * CHUNK_SIZE
    # Clamp so the last chunk never reaches past the final Gaussian.
    end_indices = np.minimum(first_indices + CHUNK_SIZE, num_gaussians)
    return first_indices, end_indices


def _validate_and_normalize_inputs(
    means: np.ndarray,
    scales: np.ndarray,
    quats: np.ndarray,
    opacities: np.ndarray,
    sh0: np.ndarray,
    shn: np.ndarray | None,
    validate: bool = False,
) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray | None]:
    """Run the shared input pipeline used by the writers.

    The pipeline is, in order: ``_ensure_numpy_arrays``, ``_convert_to_float32``,
    an optional shape check of the five mandatory arrays against
    ``means.shape[0]``, and finally ``_flatten_shn`` on the higher-order SH
    coefficients.

    :param means: Gaussian centers, shape (N, 3)
    :type means: np.ndarray
    :param scales: Log scales, shape (N, 3)
    :type scales: np.ndarray
    :param quats: Rotations as quaternions (wxyz), shape (N, 4)
    :type quats: np.ndarray
    :param opacities: Logit opacities, shape (N,)
    :type opacities: np.ndarray
    :param sh0: DC spherical harmonics, shape (N, 3)
    :type sh0: np.ndarray
    :param shn: Higher-order SH coefficients, shape (N, K, 3) or None
    :type shn: np.ndarray | None
    :param validate: Whether to validate shapes
    :type validate: bool
    :return: Tuple ``(means, scales, quats, opacities, sh0, shn)`` after normalization,
        with ``shn`` flattened to 2D when it was 3D
    :rtype: tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray | None]
    """
    # Coerce to numpy arrays first, then to float32.
    centers, log_scales, rotations, alphas, dc_terms, rest_terms = _ensure_numpy_arrays(
        means, scales, quats, opacities, sh0, shn
    )
    centers, log_scales, rotations, alphas, dc_terms, rest_terms = _convert_to_float32(
        centers, log_scales, rotations, alphas, dc_terms, rest_terms
    )

    # The row count of `means` is the reference N for every other array.
    expected_count = centers.shape[0]
    if validate:
        _validate_array_shapes(
            centers, log_scales, rotations, alphas, dc_terms, expected_count
        )

    # (N, K, 3) -> (N, K*3); None and 2D inputs pass through unchanged.
    rest_terms = _flatten_shn(rest_terms, validate)

    return centers, log_scales, rotations, alphas, dc_terms, rest_terms


def _compress_data_internal(
    means: np.ndarray,
    scales: np.ndarray,
    quats: np.ndarray,
    opacities: np.ndarray,
    sh0: np.ndarray,
    shn: np.ndarray | None,
) -> tuple[bytes, np.ndarray, np.ndarray, np.ndarray | None, int, int]:
    """Internal function to compress Gaussian data (shared compression logic).

    This function contains the core compression logic shared by
    ``write_compressed()``, ``compress_to_bytes()`` and ``compress_to_arrays()``.
    All inputs must be pre-validated and normalized to float32.

    Gaussians are grouped into consecutive chunks of ``CHUNK_SIZE`` in their
    existing order, so no sorting or chunk-index array is needed.

    :param means: (N, 3) float32 - xyz positions
    :param scales: (N, 3) float32 - scale parameters
    :param quats: (N, 4) float32 - rotation quaternions
    :param opacities: (N,) float32 - opacity values
    :param sh0: (N, 3) float32 - DC spherical harmonics
    :param shn: (N, K*3) float32 or None - flattened SH coefficients
    :returns: Tuple of (header_bytes, chunk_bounds, packed_data, packed_sh, num_gaussians, num_chunks).
        ``packed_sh`` is ``None`` when ``shn`` is ``None`` or has zero columns.
    """
    num_gaussians = means.shape[0]
    num_chunks = (num_gaussians + CHUNK_SIZE - 1) // CHUNK_SIZE  # ceil division

    # SH0 -> RGB conversion, reused for the chunk bounds and for packing.
    color_rgb = sh0 * SH_C0 + 0.5

    chunk_starts, chunk_ends = _compute_chunk_boundaries(num_chunks, num_gaussians)

    # The JIT kernel returns the bounds array itself, so nothing is
    # pre-allocated here. Presumably shape (num_chunks, 18), matching the 18
    # chunk properties declared in the header below.
    chunk_bounds = _compute_chunk_bounds_jit(means, scales, color_rgb, chunk_starts, chunk_ends)

    # Column views into chunk_bounds. Layout: position min 0-2 / max 3-5,
    # scale min 6-8 / max 9-11, color min 12-14 / max 15-17.
    min_x, min_y, min_z = chunk_bounds[:, 0], chunk_bounds[:, 1], chunk_bounds[:, 2]
    max_x, max_y, max_z = chunk_bounds[:, 3], chunk_bounds[:, 4], chunk_bounds[:, 5]
    min_scale_x, min_scale_y, min_scale_z = (
        chunk_bounds[:, 6],
        chunk_bounds[:, 7],
        chunk_bounds[:, 8],
    )
    max_scale_x, max_scale_y, max_scale_z = (
        chunk_bounds[:, 9],
        chunk_bounds[:, 10],
        chunk_bounds[:, 11],
    )
    min_r, min_g, min_b = chunk_bounds[:, 12], chunk_bounds[:, 13], chunk_bounds[:, 14]
    max_r, max_g, max_b = chunk_bounds[:, 15], chunk_bounds[:, 16], chunk_bounds[:, 17]

    # Per-chunk ranges, floored at a tiny epsilon so a degenerate chunk
    # (max == min) cannot cause a division by zero during quantization.
    min_range_epsilon = np.float32(1e-10)

    range_x = np.maximum(max_x - min_x, min_range_epsilon)
    range_y = np.maximum(max_y - min_y, min_range_epsilon)
    range_z = np.maximum(max_z - min_z, min_range_epsilon)

    range_scale_x = np.maximum(max_scale_x - min_scale_x, min_range_epsilon)
    range_scale_y = np.maximum(max_scale_y - min_scale_y, min_range_epsilon)
    range_scale_z = np.maximum(max_scale_z - min_scale_z, min_range_epsilon)

    range_r = np.maximum(max_r - min_r, min_range_epsilon)
    range_g = np.maximum(max_g - min_g, min_range_epsilon)
    range_b = np.maximum(max_b - min_b, min_range_epsilon)

    # Fused JIT kernel: packs position, rotation, scale and color in one call.
    # Argument order must match the kernel's signature.
    packed_data = _pack_all_jit(
        means,
        scales,
        color_rgb,
        opacities,
        quats,
        min_x,
        min_y,
        min_z,
        range_x,
        range_y,
        range_z,
        min_scale_x,
        min_scale_y,
        min_scale_z,
        range_scale_x,
        range_scale_y,
        range_scale_z,
        min_r,
        min_g,
        min_b,
        range_r,
        range_g,
        range_b,
    )

    # SH coefficient compression (8-bit quantization)
    packed_sh = None
    if shn is not None and shn.shape[1] > 0:
        # Quantize to uint8: ((shN / 8 + 0.5) * 256), clamped to [0, 255]
        # Simplified to: shN * 32 + 128, clamped to [0, 255]
        packed_sh = np.clip(shn * 32.0 + 128.0, 0, 255).astype(np.uint8)

    # Chunk element: 18 float properties, in chunk_bounds column order.
    chunk_props = (
        "min_x",
        "min_y",
        "min_z",
        "max_x",
        "max_y",
        "max_z",
        "min_scale_x",
        "min_scale_y",
        "min_scale_z",
        "max_scale_x",
        "max_scale_y",
        "max_scale_z",
        "min_r",
        "min_g",
        "min_b",
        "max_r",
        "max_g",
        "max_b",
    )

    header_lines = [
        "ply",
        "format binary_little_endian 1.0",
        f"element chunk {num_chunks}",
    ]
    header_lines.extend(f"property float {prop}" for prop in chunk_props)

    # Vertex element: four packed uint32 words per Gaussian.
    header_lines.extend(
        (
            f"element vertex {num_gaussians}",
            "property uint packed_position",
            "property uint packed_rotation",
            "property uint packed_scale",
            "property uint packed_color",
        )
    )

    # Optional SH element: one uchar property per quantized coefficient.
    if packed_sh is not None:
        header_lines.append(f"element sh {num_gaussians}")
        header_lines.extend(f"property uchar coeff_{i}" for i in range(packed_sh.shape[1]))

    header_lines.append("end_header")
    header_bytes = ("\n".join(header_lines) + "\n").encode("ascii")

    return header_bytes, chunk_bounds, packed_data, packed_sh, num_gaussians, num_chunks


# ======================================================================================
# UNCOMPRESSED PLY WRITER
# ======================================================================================


def write_uncompressed(
    file_path: str | Path,
    data: "GSData",  # noqa: F821
    validate: bool = True,
) -> None:
    """Write an uncompressed Gaussian splatting PLY file.

    Always operates on GSData objects and picks one of two paths:

    - Zero-copy path: when ``data._base`` is not ``None``, the header is
      written and ``data._base`` is dumped to the file as-is. No arrays are
      interleaved and ``validate`` is not consulted on this path.
    - Standard path: the fields are unpacked, validated/normalized, made
      contiguous float32, interleaved into one ``(N, 14 + sh_coeffs)`` array by
      the JIT kernels, and written.

    Performance (figures quoted from the original author's notes):
        - Zero-copy path: ~15-20ms for 400K SH3 (vs 121ms without it), 6-8x faster
        - Standard path: ~20-120ms depending on size and SH degree
        - Peak: 70M Gaussians/sec for 400K Gaussians, SH0 (zero-copy)

    :param file_path: Output PLY file path
    :param data: GSData object containing Gaussian parameters
    :param validate: If True, validate input shapes (default True)

    Example:
        >>> # RECOMMENDED: Pass GSData directly (automatic zero-copy)
        >>> data = plyread("input.ply")
        >>> write_uncompressed("output.ply", data)
        >>>
        >>> # Create GSData from scratch
        >>> data = GSData(means, scales, quats, opacities, sh0, shN)
        >>> write_uncompressed("output.ply", data)
    """
    file_path = Path(file_path)

    # ZERO-COPY FAST PATH: write the consolidated _base array directly.
    if data._base is not None:
        num_gaussians = len(data)

        # shN.shape = (N, K, 3); the header needs the total count K * 3.
        if data.shN is not None and data.shN.size > 0:
            num_sh_rest = data.shN.shape[1] * 3
        else:
            num_sh_rest = None
        header_bytes = _build_header_fast(num_gaussians, num_sh_rest)

        if data._base.nbytes > _LARGE_FILE_THRESHOLD:
            buffer_size = _LARGE_BUFFER_SIZE
        else:
            buffer_size = _SMALL_BUFFER_SIZE

        with open(file_path, "wb", buffering=buffer_size) as f:
            f.write(header_bytes)
            data._base.tofile(f)

        logger.debug(
            f"[Gaussian PLY] Wrote uncompressed (zero-copy): {num_gaussians} Gaussians to {file_path.name}"
        )
        return

    # STANDARD PATH: build the interleaved array from the individual fields.
    means, scales, quats, opacities, sh0, shn = _validate_and_normalize_inputs(
        *data.unpack(), validate
    )

    num_gaussians = means.shape[0]

    # shn has already been flattened to (N, K*3) by the shared helper.
    if shn is None:
        num_sh_rest = None
    else:
        num_sh_rest = shn.shape[1]
    header_bytes = _build_header_fast(num_gaussians, num_sh_rest)

    # The JIT interleaving kernels need contiguous float32 inputs.
    means, sh0, scales, quats = (
        np.ascontiguousarray(field, dtype=np.float32) for field in (means, sh0, scales, quats)
    )
    opacities = np.ascontiguousarray(opacities.ravel(), dtype=np.float32)

    if shn is None:
        # 14 base properties per Gaussian.
        output_array = np.empty((num_gaussians, 14), dtype=np.float32)
        _interleave_sh0_jit(means, sh0, opacities, scales, quats, output_array)
    else:
        sh_coeffs = shn.shape[1]
        shn_flat = np.ascontiguousarray(shn, dtype=np.float32)
        output_array = np.empty((num_gaussians, 14 + sh_coeffs), dtype=np.float32)
        _interleave_shn_jit(means, sh0, shn_flat, opacities, scales, quats, output_array, sh_coeffs)

    # Larger payloads get the larger write buffer.
    if output_array.nbytes > _LARGE_FILE_THRESHOLD:
        buffer_size = _LARGE_BUFFER_SIZE
    else:
        buffer_size = _SMALL_BUFFER_SIZE

    with open(file_path, "wb", buffering=buffer_size) as f:
        f.write(header_bytes)
        output_array.tofile(f)

    logger.debug(
        f"[Gaussian PLY] Wrote uncompressed: {num_gaussians} Gaussians to {file_path.name}"
    )


# ======================================================================================
# COMPRESSED PLY WRITER (VECTORIZED)
# ======================================================================================


def write_compressed(
    file_path: str | Path,
    means: np.ndarray,
    scales: np.ndarray,
    quats: np.ndarray,
    opacities: np.ndarray,
    sh0: np.ndarray,
    shN: np.ndarray | None = None,  # noqa: N803
    validate: bool = True,
) -> None:
    """Write compressed Gaussian splatting PLY file (PlayCanvas format).

    Inputs go through the shared validate/normalize helper, are compressed by
    ``_compress_data_internal`` and then written in this order: header, chunk
    bounds, packed vertex data and, if present, the packed SH coefficients.

    :param file_path: Output PLY file path
    :param means: (N, 3) - xyz positions
    :param scales: (N, 3) - scale parameters
    :param quats: (N, 4) - rotation quaternions (must be normalized)
    :param opacities: (N,) - opacity values
    :param sh0: (N, 3) - DC spherical harmonics
    :param shN: (N, K, 3) or (N, K*3) - Higher-order SH coefficients (optional)
    :param validate: If True, validate input shapes (default True)

    Performance (figures quoted from the original author's notes):
        - With JIT: ~15ms for 400K Gaussians, SH0 (27M Gaussians/sec)
        - With JIT: ~92ms for 400K Gaussians, SH3 (4.4M Gaussians/sec)
        - Reported compression ratio: 3.8-14.5x

    Format:
        Compressed PLY with chunk-based quantization:
        - 256 Gaussians per chunk
        - Position: 11-10-11 bit quantization
        - Scale: 11-10-11 bit quantization
        - Color: 8-8-8-8 bit quantization
        - Quaternion: smallest-three encoding (2+10+10+10 bits)
        - SH coefficients: 8-bit quantization (optional)

    Example:
        >>> write_compressed("output.ply", means, scales, quats, opacities, sh0, shN)
    """
    file_path = Path(file_path)

    normalized = _validate_and_normalize_inputs(
        means, scales, quats, opacities, sh0, shN, validate
    )
    means, scales, quats, opacities, sh0, shN = normalized  # noqa: N806

    compressed_parts = _compress_data_internal(means, scales, quats, opacities, sh0, shN)
    header_bytes, chunk_bounds, packed_data, packed_sh, num_gaussians, num_chunks = (
        compressed_parts
    )

    # Binary sections follow the header in element order; SH is optional.
    sections = [chunk_bounds, packed_data]
    if packed_sh is not None:
        sections.append(packed_sh)

    with open(file_path, "wb") as f:
        f.write(header_bytes)
        for section in sections:
            section.tofile(f)

    logger.debug(
        f"[Gaussian PLY] Wrote compressed: {num_gaussians} Gaussians to {file_path.name} "
        f"({num_chunks} chunks, {len(header_bytes) + chunk_bounds.nbytes + packed_data.nbytes + (packed_sh.nbytes if packed_sh is not None else 0)} bytes)"
    )
def compress_to_bytes(
    data_or_means: GSData | np.ndarray,
    scales: np.ndarray | None = None,
    quats: np.ndarray | None = None,
    opacities: np.ndarray | None = None,
    sh0: np.ndarray | None = None,
    shN: np.ndarray | None = None,  # noqa: N803
    validate: bool = True,
) -> bytes:
    """Compress Gaussian splatting data to bytes (PlayCanvas format).

    Produces the same byte sequence the compressed writer would put on disk
    (header, chunk bounds, packed data, optional packed SH) but returns it
    instead of writing a file. Useful for network transfer or custom storage.

    When a GSData object is passed, its fields are used and the remaining
    array arguments are ignored.

    :param data_or_means: Either a GSData object or means array (N, 3) float32
    :param scales: Gaussian scales (N, 3) float32 (required if first arg is means)
    :param quats: Gaussian quaternions (N, 4) float32 (required if first arg is means)
    :param opacities: Gaussian opacities (N,) float32 (required if first arg is means)
    :param sh0: Degree 0 SH coefficients RGB (N, 3) float32 (required if first arg is means)
    :param shN: Optional higher degree SH coefficients (N, K, 3) float32
    :param validate: Whether to validate inputs
    :returns: Complete compressed PLY file as bytes
    :raises ValueError: If arrays are passed individually and any of scales, quats,
        opacities or sh0 is missing

    Example:
        >>> from gsply import plyread, compress_to_bytes
        >>> # Method 1: Using GSData (recommended)
        >>> data = plyread("model.ply")
        >>> compressed_bytes = compress_to_bytes(data)
        >>>
        >>> # Method 2: Using individual arrays (backward compatible)
        >>> compressed_bytes = compress_to_bytes(
        ...     means, scales, quats, opacities, sh0, shN
        ... )
        >>>
        >>> # Save or transmit
        >>> with open("output.compressed.ply", "wb") as f:
        ...     f.write(compressed_bytes)
    """
    if isinstance(data_or_means, GSData):
        # GSData input: take every field from the container.
        container = data_or_means
        means, scales, quats, opacities, sh0, shN = (  # noqa: N806
            container.means,
            container.scales,
            container.quats,
            container.opacities,
            container.sh0,
            container.shN,
        )
    else:
        # Individual arrays: the first positional argument is the means array.
        means = data_or_means
        if any(required is None for required in (scales, quats, opacities, sh0)):
            raise ValueError(
                "When passing individual arrays, scales, quats, opacities, and sh0 are required. "
                "Consider using GSData for cleaner API: compress_to_bytes(data)"
            )

    normalized = _validate_and_normalize_inputs(
        means, scales, quats, opacities, sh0, shN, validate
    )
    means, scales, quats, opacities, sh0, shN = normalized  # noqa: N806

    header_bytes, chunk_bounds, packed_data, packed_sh, num_gaussians, num_chunks = (
        _compress_data_internal(means, scales, quats, opacities, sh0, shN)
    )

    # A missing SH section contributes an empty tail, so one join covers both cases.
    sh_tail = b"" if packed_sh is None else packed_sh.tobytes()
    total_bytes = b"".join(
        (header_bytes, chunk_bounds.tobytes(), packed_data.tobytes(), sh_tail)
    )

    logger.debug(
        f"[Gaussian PLY] Compressed to bytes: {num_gaussians} Gaussians "
        f"({num_chunks} chunks, {len(total_bytes)} bytes)"
    )

    return total_bytes
def compress_to_arrays(
    data_or_means: GSData | np.ndarray,
    scales: np.ndarray | None = None,
    quats: np.ndarray | None = None,
    opacities: np.ndarray | None = None,
    sh0: np.ndarray | None = None,
    shN: np.ndarray | None = None,  # noqa: N803
    validate: bool = True,
) -> tuple[bytes, np.ndarray, np.ndarray, np.ndarray | None]:
    """Compress Gaussian splatting data to component arrays (PlayCanvas format).

    Runs the same compression as the compressed writer but returns the pieces
    (header, chunk bounds, packed data, packed SH) separately instead of
    writing them to disk. Useful for custom processing or partial updates.

    When a GSData object is passed, its fields are used and the remaining
    array arguments are ignored.

    :param data_or_means: Either a GSData object or means array (N, 3) float32
    :param scales: Gaussian scales (N, 3) float32 (required if first arg is means)
    :param quats: Gaussian quaternions (N, 4) float32 (required if first arg is means)
    :param opacities: Gaussian opacities (N,) float32 (required if first arg is means)
    :param sh0: Degree 0 SH coefficients RGB (N, 3) float32 (required if first arg is means)
    :param shN: Optional higher degree SH coefficients (N, K, 3) float32
    :param validate: Whether to validate inputs
    :returns: Tuple containing header_bytes (PLY header as bytes),
        chunk_bounds (Chunk boundary array (num_chunks, 18) float32),
        packed_data (Main compressed data array (N, 4) uint32),
        packed_sh (Optional compressed SH data array uint8)
    :raises ValueError: If arrays are passed individually and any of scales, quats,
        opacities or sh0 is missing

    Example:
        >>> from gsply import plyread, compress_to_arrays
        >>> # Method 1: Using GSData (recommended)
        >>> data = plyread("model.ply")
        >>> header, chunks, packed, sh = compress_to_arrays(data)
        >>>
        >>> # Method 2: Using individual arrays (backward compatible)
        >>> header, chunks, packed, sh = compress_to_arrays(
        ...     means, scales, quats, opacities, sh0, shN
        ... )
        >>>
        >>> # Process components individually
        >>> print(f"Header size: {len(header)} bytes")
        >>> print(f"Chunks shape: {chunks.shape}")
        >>> print(f"Packed data: {packed.nbytes} bytes")
    """
    if isinstance(data_or_means, GSData):
        # GSData input: take every field from the container.
        container = data_or_means
        means, scales, quats, opacities, sh0, shN = (  # noqa: N806
            container.means,
            container.scales,
            container.quats,
            container.opacities,
            container.sh0,
            container.shN,
        )
    else:
        # Individual arrays: the first positional argument is the means array.
        means = data_or_means
        if any(required is None for required in (scales, quats, opacities, sh0)):
            raise ValueError(
                "When passing individual arrays, scales, quats, opacities, and sh0 are required. "
                "Consider using GSData for cleaner API: compress_to_arrays(data)"
            )

    normalized = _validate_and_normalize_inputs(
        means, scales, quats, opacities, sh0, shN, validate
    )
    means, scales, quats, opacities, sh0, shN = normalized  # noqa: N806

    header_bytes, chunk_bounds, packed_data, packed_sh, num_gaussians, num_chunks = (
        _compress_data_internal(means, scales, quats, opacities, sh0, shN)
    )

    logger.debug(
        f"[Gaussian PLY] Compressed to arrays: {num_gaussians} Gaussians "
        f"({num_chunks} chunks, header={len(header_bytes)} bytes, "
        f"bounds={chunk_bounds.nbytes} bytes, data={packed_data.nbytes} bytes, "
        f"sh={packed_sh.nbytes if packed_sh is not None else 0} bytes)"
    )

    return header_bytes, chunk_bounds, packed_data, packed_sh


# ======================================================================================
# UNIFIED WRITING API
# ======================================================================================


def plywrite(
    file_path: str | Path,
    data: "GSData | GSTensor | np.ndarray",  # noqa: F821
    scales: np.ndarray | None = None,
    quats: np.ndarray | None = None,
    opacities: np.ndarray | None = None,
    sh0: np.ndarray | None = None,
    shN: np.ndarray | None = None,  # noqa: N803
    compressed: bool = False,
    validate: bool = True,
) -> None:
    """Write Gaussian splatting PLY file with automatic optimization.

    The helper accepts either a :class:`gsply.GSData` instance (recommended), a
    :class:`gsply.GSTensor` instance (converted to GSData automatically), or the
    individual Gaussian arrays. When `_base` is available the uncompressed writer
    streams the consolidated buffer directly to disk; otherwise the arrays are
    interleaved once before writing.

    File format selection happens automatically: the compressed path is chosen
    when `compressed=True` or when the destination filename already ends with
    `.compressed.ply` / `.ply_compressed`.

    Before writing, data whose format flags are not PLY format (log-scales,
    logit-opacities) is converted with ``data.normalize(inplace=True)``.

    :param file_path: Output PLY file path (extension auto-adjusted if compressed=True)
    :param data: GSData object, GSTensor object, OR (N, 3) xyz positions array
    :param scales: (N, 3) scale parameters (required if data is array)
    :param quats: (N, 4) rotation quaternions (required if data is array)
    :param opacities: (N,) opacity values (required if data is array)
    :param sh0: (N, 3) DC spherical harmonics (required if data is array)
    :param shN: (N, K, 3) or (N, K*3) - Higher-order SH coefficients (optional)
    :param compressed: If True, write compressed format and auto-adjust extension
    :param validate: If True, validate input shapes (default True); forwarded to
        both the compressed and the uncompressed writer
    :raises ValueError: If individual arrays are passed and any of scales, quats,
        opacities or sh0 is missing

    Performance (figures quoted from the original author's notes):
        - GSData from plyread: ~7ms for 400K Gaussians (zero-copy, 53 M/s)
        - GSData created manually: ~19ms for 400K Gaussians (49 M/s)
        - Individual arrays: ~19ms for 400K Gaussians (converted first)

    Example:
        >>> # RECOMMENDED: Pass GSData from file (automatic zero-copy)
        >>> data = plyread("input.ply")
        >>> plywrite("output.ply", data)
        >>>
        >>> # GSData created manually
        >>> data = GSData(means=means, scales=scales, ...)
        >>> plywrite("output.ply", data)
        >>>
        >>> # GSTensor (converted to GSData automatically)
        >>> gstensor = plyread_gpu("input.compressed.ply", device="cuda")
        >>> plywrite("output.ply", gstensor, compressed=False)  # Uncompressed PLY
        >>>
        >>> # Individual arrays
        >>> plywrite("output.ply", means, scales, quats, opacities, sh0, shN)
        >>>
        >>> # Write compressed format
        >>> plywrite("output.ply", data, compressed=True)
    """
    from gsply.gsdata import GSData  # noqa: PLC0415

    file_path = Path(file_path)

    # Lazy import so the writer works without PyTorch. Only the import is
    # guarded: a failure inside to_gsdata() must propagate instead of being
    # swallowed and resurfacing as a misleading "arrays missing" ValueError.
    try:
        from gsply.torch.gstensor import GSTensor as gstensor_cls  # noqa: PLC0415
    except (ImportError, RuntimeError):
        # PyTorch not available or has import issues, skip GSTensor check
        gstensor_cls = None

    if gstensor_cls is not None and isinstance(data, gstensor_cls):
        # Convert GSTensor to GSData (transfers to CPU)
        data = data.to_gsdata()

    # Convert individual arrays to GSData
    if not isinstance(data, GSData):
        # `data` is actually the means array in this branch.
        if any(x is None for x in [scales, quats, opacities, sh0]):
            raise ValueError(
                "When passing individual arrays, all of data (means), scales, quats, "
                "opacities, and sh0 must be provided"
            )

        from gsply.gsdata import (  # noqa: PLC0415
            _detect_format_from_values,
            _get_sh_order_format,
        )

        # Detect whether scales/opacities are already in PLY format from their values.
        scales_format, opacities_format = _detect_format_from_values(scales, opacities)

        # Determine SH degree for the format dict.
        if shN is not None and shN.shape[1] > 0:
            if shN.ndim == 2:
                # Flattened (N, K*3): recover the band count K.
                sh_bands = shN.shape[1] // 3
            else:
                sh_bands = shN.shape[1]
            from gsply.formats import SH_BANDS_TO_DEGREE  # noqa: PLC0415

            sh_degree = SH_BANDS_TO_DEGREE.get(sh_bands, 0)
        else:
            sh_degree = 0

        format_dict = _create_format_dict(
            scales=scales_format,
            opacities=opacities_format,
            sh0=DataFormat.SH0_SH,  # Assume SH format
            sh_order=_get_sh_order_format(sh_degree),
            means=DataFormat.MEANS_RAW,
            quats=DataFormat.QUATS_RAW,
        )

        data = GSData(
            means=data,
            scales=scales,
            quats=quats,
            opacities=opacities,
            sh0=sh0,
            shN=shN if shN is not None else np.empty((data.shape[0], 0, 3), dtype=np.float32),
            _base=None,  # No _base for manually created data
            _format=format_dict,
        )

    # Auto-detect compression from extension
    is_compressed_ext = file_path.name.endswith((".ply_compressed", ".compressed.ply"))
    use_compressed = compressed or is_compressed_ext

    # compressed=True without a compressed extension: adjust the file name.
    if compressed and not is_compressed_ext:
        if file_path.suffix == ".ply":
            # Replace .ply with .compressed.ply
            file_path = file_path.with_suffix(".compressed.ply")
        else:
            # No .ply suffix: just append
            file_path = Path(str(file_path) + ".compressed.ply")

    # Both writers expect PLY format (log-scales, logit-opacities), so the
    # conversion is done once here instead of in each branch.
    scales_format = data._format.get("scales")
    opacities_format = data._format.get("opacities")
    if scales_format != DataFormat.SCALES_PLY or opacities_format != DataFormat.OPACITIES_PLY:
        logger.debug(
            f"[PLY Write] Converting from {scales_format}/{opacities_format} to PLY format before writing"
        )
        # inplace=True avoids a copy.
        # NOTE(review): for a caller-supplied GSData this presumably modifies
        # the caller's object; confirm that is acceptable.
        data = data.normalize(inplace=True)

    if use_compressed:
        # The compressed writer takes individual arrays, not GSData.
        means, scales, quats, opacities, sh0, shN = data.unpack()  # noqa: N806
        write_compressed(
            file_path, means, scales, quats, opacities, sh0, shN, validate=validate
        )
    else:
        write_uncompressed(file_path, data, validate=validate)
# Public API of gsply.writer.
__all__ = [
    "plywrite",
    "write_uncompressed",
    "write_compressed",
    "compress_to_bytes",
    "compress_to_arrays",
]