Skip to content

Commit c3229b0

Browse files
mvaligurskyMartin Valigursky
andauthored
Generate intervals texture for unified splat copying on GPU (#7881)
Co-authored-by: Martin Valigursky <[email protected]>
1 parent 3776044 commit c3229b0

File tree

8 files changed

+350
-89
lines changed

8 files changed

+350
-89
lines changed

examples/src/examples/gaussian-splatting/lod.example.mjs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,8 @@ app.on('destroy', () => {
4949
pc.Tracing.set(pc.TRACEID_SHADER_ALLOC, true);
5050

5151
const assets = {
52-
church: new pc.Asset('gsplat', 'gsplat', { url: `${rootPath}/static/assets/splats/morocco.ply` }),
52+
// church: new pc.Asset('gsplat', 'gsplat', { url: `${rootPath}/static/assets/splats/morocco.ply` }),
53+
church: new pc.Asset('gsplat', 'gsplat', { url: `${rootPath}/static/assets/splats/dubai.ply` }),
5354
logo: new pc.Asset('gsplat', 'gsplat', { url: `${rootPath}/static/assets/splats/pclogo.ply` }),
5455
guitar: new pc.Asset('gsplat', 'gsplat', { url: `${rootPath}/static/assets/splats/guitar.compressed.ply` }),
5556
skull: new pc.Asset('gsplat', 'gsplat', { url: `${rootPath}/static/assets/splats/skull.ply` })

src/scene/gsplat/unified/gspat-state.js

Lines changed: 8 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
11
import { Vec3 } from '../../../core/math/vec3.js';
2-
import { Texture } from '../../../platform/graphics/texture.js';
3-
import { ADDRESS_CLAMP_TO_EDGE, FILTER_NEAREST, PIXELFORMAT_R32U } from '../../../platform/graphics/constants.js';
42
import { Vec4 } from '../../../core/math/vec4.js';
53
import { Debug } from '../../../core/debug.js';
4+
import { GSplatIntervalTexture } from './gsplat-interval-texture.js';
65

76
/**
87
* @import { GraphNode } from "../../graph-node.js"
@@ -52,11 +51,11 @@ class GSplatState {
5251
viewport = new Vec4();
5352

5453
/**
55-
* Texture that maps target indices to source splat indices based on intervals
54+
* Manager for the intervals texture generation
5655
*
57-
* @type {Texture|null}
56+
* @type {GSplatIntervalTexture}
5857
*/
59-
intervalsTexture = null;
58+
intervalTexture;
6059

6160
/**
6261
* @param {GraphicsDevice} device - The graphics device
@@ -67,89 +66,22 @@ class GSplatState {
6766
this.device = device;
6867
this.resource = resource;
6968
this.node = node;
69+
this.intervalTexture = new GSplatIntervalTexture(device);
7070
}
7171

7272
destroy() {
7373
this.intervals.length = 0;
74-
this.intervalsTexture?.destroy();
75-
this.intervalsTexture = null;
74+
this.intervalTexture.destroy();
7675
}
7776

7877
setLines(start, count, textureSize, activeSplats) {
7978
this.lineStart = start;
8079
this.lineCount = count;
8180
this.padding = textureSize * count - activeSplats;
81+
Debug.assert(this.padding >= 0);
8282
this.viewport.set(0, start, textureSize, count);
8383
}
8484

85-
/**
86-
* Creates a texture that maps target indices to source splat indices based on intervals
87-
*/
88-
updateIntervalsTexture() {
89-
// Count total number of splats referenced by intervals
90-
let totalSplats = 0;
91-
for (let i = 0; i < this.intervals.length; i += 2) {
92-
const start = this.intervals[i];
93-
const end = this.intervals[i + 1];
94-
totalSplats += (end - start);
95-
}
96-
97-
this.activeSplats = totalSplats;
98-
99-
if (totalSplats === 0) {
100-
return;
101-
}
102-
103-
// Estimate roughly square texture size
104-
const maxTextureSize = this.device.maxTextureSize;
105-
let textureWidth = Math.ceil(Math.sqrt(totalSplats));
106-
textureWidth = Math.min(textureWidth, maxTextureSize);
107-
const textureHeight = Math.ceil(totalSplats / textureWidth);
108-
109-
// Create initial 1x1 texture
110-
if (!this.intervalsTexture) {
111-
this.intervalsTexture = this.createTexture('intervalsTexture', PIXELFORMAT_R32U, 1, 1);
112-
}
113-
114-
// Resize texture if dimensions changed
115-
if (this.intervalsTexture.width !== textureWidth || this.intervalsTexture.height !== textureHeight) {
116-
this.intervalsTexture.resize(textureWidth, textureHeight);
117-
}
118-
119-
// update mapping data
120-
if (this.intervalsTexture) {
121-
const pixels = this.intervalsTexture.lock();
122-
let targetIndex = 0;
123-
124-
for (let i = 0; i < this.intervals.length; i += 2) {
125-
const start = this.intervals[i];
126-
const end = this.intervals[i + 1];
127-
128-
for (let splatIndex = start; splatIndex < end; splatIndex++) {
129-
pixels[targetIndex] = splatIndex;
130-
targetIndex++;
131-
}
132-
}
133-
134-
this.intervalsTexture.unlock();
135-
}
136-
}
137-
138-
createTexture(name, format, width, height) {
139-
return new Texture(this.device, {
140-
name: name,
141-
width: width,
142-
height: height,
143-
format: format,
144-
cubemap: false,
145-
mipmaps: false,
146-
minFilter: FILTER_NEAREST,
147-
magFilter: FILTER_NEAREST,
148-
addressU: ADDRESS_CLAMP_TO_EDGE,
149-
addressV: ADDRESS_CLAMP_TO_EDGE
150-
});
151-
}
152-
15385
/**
15486
* @param {GraphNode} cameraNode - The camera node for LOD calculation
15587
*/
@@ -247,7 +179,7 @@ class GSplatState {
247179
// console.log(`Block LOD Distribution (blocks: ${numBlocks}): LOD 0: ${pcts[0]}%, LOD 1: ${pcts[1]}%, LOD 2: ${pcts[2]}%`);
248180
// }
249181

250-
this.updateIntervalsTexture();
182+
this.activeSplats = this.intervalTexture.update(this.intervals);
251183
}
252184
}
253185

src/scene/gsplat/unified/gsplat-info.js

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -169,21 +169,23 @@ class GSplatInfo {
169169
const scope = device.scope;
170170
Debug.assert(resource);
171171

172-
// set up splat resource properties
172+
// render using render state
173+
const { activeSplats, lineStart, viewport, intervalTexture } = this.renderState;
174+
175+
// assign material properties to scope
173176
this.material.setParameters(this.device);
174177

175178
// matrix to transform splats to the world space
176179
scope.resolve('uTransform').setValue(this.node.getWorldTransform().data);
177180

178181
if (resource.hasLod) {
179182
// Set LOD intervals texture for remapping of indices
180-
scope.resolve('uIntervalsTexture').setValue(this.renderState.intervalsTexture);
183+
scope.resolve('uIntervalsTexture').setValue(intervalTexture.texture);
181184
}
182185

183-
const renderState = this.renderState;
184-
scope.resolve('uActiveSplats').setValue(renderState.activeSplats);
185-
scope.resolve('uStartLine').setValue(renderState.lineStart);
186-
scope.resolve('uViewportWidth').setValue(renderState.viewport.z); // this is textureSize, TODO: replace it
186+
scope.resolve('uActiveSplats').setValue(activeSplats);
187+
scope.resolve('uStartLine').setValue(lineStart);
188+
scope.resolve('uViewportWidth').setValue(viewport.z);
187189

188190
// SH related
189191
scope.resolve('matrix_model').setValue(this.node.getWorldTransform().data);
@@ -192,7 +194,7 @@ class GSplatInfo {
192194
const viewMat = _viewMat.copy(viewInvMat).invert();
193195
scope.resolve('matrix_view').setValue(viewMat.data);
194196

195-
drawQuadWithShader(device, renderTarget, this.copyShader, renderState.viewport, renderState.viewport);
197+
drawQuadWithShader(device, renderTarget, this.copyShader, viewport, viewport);
196198
}
197199
}
198200

Lines changed: 186 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,186 @@
1+
import { Texture } from '../../../platform/graphics/texture.js';
2+
import {
3+
ADDRESS_CLAMP_TO_EDGE, FILTER_NEAREST, PIXELFORMAT_R32U, PIXELFORMAT_RG32U, CULLFACE_NONE,
4+
SEMANTIC_POSITION
5+
} from '../../../platform/graphics/constants.js';
6+
import { RenderTarget } from '../../../platform/graphics/render-target.js';
7+
import { drawQuadWithShader } from '../../../scene/graphics/quad-render-utils.js';
8+
import { BlendState } from '../../../platform/graphics/blend-state.js';
9+
import { DepthState } from '../../../platform/graphics/depth-state.js';
10+
import { ShaderUtils } from '../../shader-lib/shader-utils.js';
11+
import gsplatIntervalTextureGLSL from '../../shader-lib/glsl/chunks/gsplat/frag/gsplatIntervalTexture.js';
12+
import gsplatIntervalTextureWGSL from '../../shader-lib/wgsl/chunks/gsplat/frag/gsplatIntervalTexture.js';
13+
14+
/**
15+
* @import { GraphicsDevice } from '../../../platform/graphics/graphics-device.js'
16+
* @import { Shader } from '../../../platform/graphics/shader.js'
17+
*/
18+
19+
/**
20+
* Manages the intervals texture generation for GSplat LOD system using GPU acceleration. A list of
21+
* intervals is provided to the update method, and the texture is generated on the GPU. The texture
22+
* is then used to map target indices to source splat indices.
23+
*
24+
* @ignore
25+
*/
26+
class GSplatIntervalTexture {
27+
/** @type {GraphicsDevice} */
28+
device;
29+
30+
/**
31+
* Texture that maps target indices to source splat indices based on intervals
32+
*
33+
* @type {Texture|null}
34+
*/
35+
texture = null;
36+
37+
/**
38+
* Texture that stores interval data (start + accumulated sum pairs) for GPU processing
39+
*
40+
* @type {Texture|null}
41+
*/
42+
intervalsDataTexture = null;
43+
44+
/**
45+
* Shader for generating intervals texture on GPU
46+
*
47+
* @type {Shader|null}
48+
*/
49+
shader = null;
50+
51+
/**
52+
* @param {GraphicsDevice} device - The graphics device
53+
*/
54+
constructor(device) {
55+
this.device = device;
56+
}
57+
58+
destroy() {
59+
this.texture?.destroy();
60+
this.texture = null;
61+
this.intervalsDataTexture?.destroy();
62+
this.intervalsDataTexture = null;
63+
this.shader = null;
64+
}
65+
66+
/**
67+
* Creates shader for GPU-based intervals texture generation
68+
*/
69+
getShader() {
70+
if (!this.shader) {
71+
this.shader = ShaderUtils.createShader(this.device, {
72+
uniqueName: 'GSplatIntervalsShader',
73+
attributes: { aPosition: SEMANTIC_POSITION },
74+
vertexChunk: 'quadVS',
75+
fragmentGLSL: gsplatIntervalTextureGLSL,
76+
fragmentWGSL: gsplatIntervalTextureWGSL,
77+
fragmentOutputTypes: ['uint']
78+
});
79+
}
80+
81+
return this.shader;
82+
}
83+
84+
/**
85+
* Creates a texture with specified parameters
86+
*/
87+
createTexture(name, format, width, height) {
88+
return new Texture(this.device, {
89+
name: name,
90+
width: width,
91+
height: height,
92+
format: format,
93+
cubemap: false,
94+
mipmaps: false,
95+
minFilter: FILTER_NEAREST,
96+
magFilter: FILTER_NEAREST,
97+
addressU: ADDRESS_CLAMP_TO_EDGE,
98+
addressV: ADDRESS_CLAMP_TO_EDGE
99+
});
100+
}
101+
102+
/**
103+
* Updates the intervals texture based on provided intervals array
104+
*
105+
* @param {number[]} intervals - Array of intervals (start, end pairs)
106+
* @returns {number} The number of active splats
107+
*/
108+
update(intervals) {
109+
if (!intervals || intervals.length === 0) {
110+
return 0;
111+
}
112+
113+
// Count total number of splats referenced by intervals
114+
let totalSplats = 0;
115+
for (let i = 0; i < intervals.length; i += 2) {
116+
const start = intervals[i];
117+
const end = intervals[i + 1];
118+
totalSplats += (end - start);
119+
}
120+
121+
// Calculate texture dimensions for output intervals texture
122+
const maxTextureSize = this.device.maxTextureSize;
123+
let textureWidth = Math.ceil(Math.sqrt(totalSplats));
124+
textureWidth = Math.min(textureWidth, maxTextureSize);
125+
const textureHeight = Math.ceil(totalSplats / textureWidth);
126+
127+
// Create/resize main intervals texture
128+
if (!this.texture) {
129+
this.texture = this.createTexture('intervalsTexture', PIXELFORMAT_R32U, textureWidth, textureHeight);
130+
}
131+
if (this.texture.width !== textureWidth || this.texture.height !== textureHeight) {
132+
this.texture.resize(textureWidth, textureHeight);
133+
}
134+
135+
// Prepare intervals data with CPU prefix sum
136+
const numIntervals = intervals.length / 2;
137+
const dataTextureSize = Math.ceil(Math.sqrt(numIntervals));
138+
139+
// Create/resize intervals data texture
140+
if (!this.intervalsDataTexture) {
141+
this.intervalsDataTexture = this.createTexture('intervalsData', PIXELFORMAT_RG32U, dataTextureSize, dataTextureSize);
142+
}
143+
if (this.intervalsDataTexture.width !== dataTextureSize) {
144+
this.intervalsDataTexture.resize(dataTextureSize, dataTextureSize);
145+
}
146+
147+
// Compute intervals data with accumulated sums on CPU
148+
// TODO: consider doing this using compute shader on WebGPU
149+
const intervalsData = this.intervalsDataTexture.lock();
150+
let runningSum = 0;
151+
152+
for (let i = 0; i < numIntervals; i++) {
153+
const start = intervals[i * 2];
154+
const end = intervals[i * 2 + 1];
155+
const intervalSize = end - start;
156+
runningSum += intervalSize;
157+
158+
intervalsData[i * 2] = start; // R: interval start
159+
intervalsData[i * 2 + 1] = runningSum; // G: accumulated sum
160+
}
161+
this.intervalsDataTexture.unlock();
162+
163+
// Generate intervals texture on GPU
164+
const renderTarget = new RenderTarget({
165+
colorBuffer: this.texture,
166+
depth: false
167+
});
168+
169+
const scope = this.device.scope;
170+
scope.resolve('uIntervalsTexture').setValue(this.intervalsDataTexture);
171+
scope.resolve('uNumIntervals').setValue(numIntervals);
172+
scope.resolve('uTextureWidth').setValue(textureWidth);
173+
scope.resolve('uActiveSplats').setValue(totalSplats);
174+
175+
this.device.setCullMode(CULLFACE_NONE);
176+
this.device.setBlendState(BlendState.NOBLEND);
177+
this.device.setDepthState(DepthState.NODEPTH);
178+
179+
drawQuadWithShader(this.device, renderTarget, this.getShader());
180+
181+
renderTarget.destroy();
182+
return totalSplats;
183+
}
184+
}
185+
186+
export { GSplatIntervalTexture };

src/scene/gsplat/unified/gsplat-manager.js

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -294,7 +294,7 @@ class GSplatManager {
294294
// Reassign lines based on current LOD active splats
295295
this.assignLines(this.splats, textureSize);
296296

297-
// give sorter info related it needs to generate global centers array for sorting
297+
// give sorter info it needs to generate global centers array for sorting
298298
const payload = this.centerBuffer.update(this.splats);
299299
this.sorter.setIntervals(payload);
300300
}

src/scene/gsplat/unified/gsplat-unified-sort-worker.js

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -121,8 +121,11 @@ function UnifiedSortWorker() {
121121
// }
122122
// }
123123

124-
const minDist = -1000;
125-
const maxDist = 1000;
124+
// const minDist = -1000;
125+
// const maxDist = 1000;
126+
const minDist = -1872;
127+
const maxDist = 1891;
128+
126129
// const minDist = -50;
127130
// const maxDist = 10;
128131

@@ -131,8 +134,7 @@ function UnifiedSortWorker() {
131134
const numVertices = textureSize * textureSize;
132135

133136
// calculate number of bits needed to store sorting result
134-
// const compareBits = Math.max(10, Math.min(20, Math.round(Math.log2(numVertices / 4))));
135-
const compareBits = 20;
137+
const compareBits = Math.max(10, Math.min(20, Math.round(Math.log2(numVertices / 4))));
136138

137139
const bucketCount = 2 ** compareBits + 1;
138140

0 commit comments

Comments
 (0)