Skip to content

Commit 9447a7e

Browse files
SH support in NME GS (#16625)
1 parent 1db825f commit 9447a7e

File tree

10 files changed

+108
-23
lines changed

10 files changed

+108
-23
lines changed

packages/dev/core/src/Materials/GaussianSplatting/gaussianSplattingMaterial.ts

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -199,7 +199,7 @@ export class GaussianSplattingMaterial extends PushMaterial {
199199
"invViewport",
200200
"dataTextureSize",
201201
"focal",
202-
"vEyePosition",
202+
"eyePosition",
203203
"kernelSize",
204204
];
205205
const samplers = ["covariancesATexture", "covariancesBTexture", "centersTexture", "colorsTexture", "shTexture0", "shTexture1", "shTexture2"];
@@ -293,10 +293,7 @@ export class GaussianSplattingMaterial extends PushMaterial {
293293

294294
effect.setFloat2("focal", focal, focal);
295295
effect.setFloat("kernelSize", gsMaterial && gsMaterial.kernelSize ? gsMaterial.kernelSize : GaussianSplattingMaterial.KernelSize);
296-
297-
// vEyePosition doesn't get automatially bound on MacOS with Chromium for no apparent reason.
298-
// Binding it manually here instead. Remove next line when SH rendering is fine on that platform.
299-
scene.bindEyePosition(effect);
296+
scene.bindEyePosition(effect, "eyePosition", true);
300297

301298
if (gsMesh.covariancesATexture) {
302299
const textureSize = gsMesh.covariancesATexture.getSize();

packages/dev/core/src/Materials/Node/Blocks/GaussianSplatting/gaussianBlock.ts

Lines changed: 27 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@ export class GaussianBlock extends NodeMaterialBlock {
2222
this.registerInput("splatColor", NodeMaterialBlockConnectionPointTypes.Color4, false, NodeMaterialBlockTargets.Fragment);
2323

2424
this.registerOutput("rgba", NodeMaterialBlockConnectionPointTypes.Color4, NodeMaterialBlockTargets.Fragment);
25+
this.registerOutput("rgb", NodeMaterialBlockConnectionPointTypes.Color3, NodeMaterialBlockTargets.Fragment);
26+
this.registerOutput("alpha", NodeMaterialBlockConnectionPointTypes.Float, NodeMaterialBlockTargets.Fragment);
2527
}
2628

2729
/**
@@ -46,6 +48,20 @@ export class GaussianBlock extends NodeMaterialBlock {
4648
return this._outputs[0];
4749
}
4850

51+
/**
52+
* Gets the rgb output component
53+
*/
54+
public get rgb(): NodeMaterialConnectionPoint {
55+
return this._outputs[1];
56+
}
57+
58+
/**
59+
* Gets the alpha output component
60+
*/
61+
public get alpha(): NodeMaterialConnectionPoint {
62+
return this._outputs[2];
63+
}
64+
4965
/**
5066
* Initialize the block and prepare the context for build
5167
* @param state defines the state that will be used for the build
@@ -68,15 +84,23 @@ export class GaussianBlock extends NodeMaterialBlock {
6884
state._emitFunctionFromInclude("fogFragmentDeclaration", comments);
6985
state._emitFunctionFromInclude("gaussianSplattingFragmentDeclaration", comments);
7086
state._emitVaryingFromString("vPosition", NodeMaterialBlockConnectionPointTypes.Vector2);
87+
88+
const tempSplatColor = state._getFreeVariableName("tempSplatColor");
7189
const color = this.splatColor;
72-
const output = this._outputs[0];
90+
const rgba = this._outputs[0];
91+
const rgb = this._outputs[1];
92+
const alpha = this._outputs[2];
7393

7494
if (state.shaderLanguage === ShaderLanguage.WGSL) {
75-
state.compilationString += `${state._declareOutput(output)} = gaussianColor(${color.associatedVariableName}, input.vPosition);\n`;
95+
state.compilationString += `let ${tempSplatColor}:vec4f = gaussianColor(${color.associatedVariableName}, input.vPosition);\n`;
7696
} else {
77-
state.compilationString += `${state._declareOutput(output)} = gaussianColor(${color.associatedVariableName});\n`;
97+
state.compilationString += `vec4 ${tempSplatColor} = gaussianColor(${color.associatedVariableName});\n`;
7898
}
7999

100+
state.compilationString += `${state._declareOutput(rgba)} = ${tempSplatColor}.rgba;`;
101+
state.compilationString += `${state._declareOutput(rgb)} = ${tempSplatColor}.rgb;`;
102+
state.compilationString += `${state._declareOutput(alpha)} = ${tempSplatColor}.a;`;
103+
80104
return this;
81105
}
82106
}

packages/dev/core/src/Materials/Node/Blocks/GaussianSplatting/gaussianSplattingBlock.ts

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,10 @@ import { NodeMaterialBlockTargets } from "../../Enums/nodeMaterialBlockTargets";
55
import type { NodeMaterialConnectionPoint } from "../../nodeMaterialBlockConnectionPoint";
66
import { RegisterClass } from "../../../../Misc/typeStore";
77
import { VertexBuffer } from "core/Meshes/buffer";
8+
import type { GaussianSplattingMesh } from "core/Meshes/GaussianSplatting/gaussianSplattingMesh";
89
import { ShaderLanguage } from "core/Materials/shaderLanguage";
10+
import type { AbstractMesh } from "core/Meshes/abstractMesh";
11+
import type { NodeMaterial, NodeMaterialDefines } from "../../nodeMaterial";
912

1013
/**
1114
* Block used for the Gaussian Splatting
@@ -27,6 +30,7 @@ export class GaussianSplattingBlock extends NodeMaterialBlock {
2730
this.registerInput("projection", NodeMaterialBlockConnectionPointTypes.Matrix, false, NodeMaterialBlockTargets.Vertex);
2831

2932
this.registerOutput("splatVertex", NodeMaterialBlockConnectionPointTypes.Vector4, NodeMaterialBlockTargets.Vertex);
33+
this.registerOutput("SH", NodeMaterialBlockConnectionPointTypes.Color3, NodeMaterialBlockTargets.Vertex);
3034
}
3135

3236
/**
@@ -79,6 +83,13 @@ export class GaussianSplattingBlock extends NodeMaterialBlock {
7983
return this._outputs[0];
8084
}
8185

86+
/**
87+
* Gets the SH output contribution
88+
*/
89+
public get SH(): NodeMaterialConnectionPoint {
90+
return this._outputs[1];
91+
}
92+
8293
/**
8394
* Initialize the block and prepare the context for build
8495
* @param state defines the state that will be used for the build
@@ -87,6 +98,18 @@ export class GaussianSplattingBlock extends NodeMaterialBlock {
8798
state._excludeVariableName("focal");
8899
state._excludeVariableName("invViewport");
89100
state._excludeVariableName("kernelSize");
101+
state._excludeVariableName("eyePosition");
102+
}
103+
/**
104+
* Update defines for shader compilation
105+
* @param mesh defines the mesh to be rendered
106+
* @param nodeMaterial defines the node material requesting the update
107+
* @param defines defines the material defines to update
108+
*/
109+
public override prepareDefines(mesh: AbstractMesh, nodeMaterial: NodeMaterial, defines: NodeMaterialDefines) {
110+
if (mesh.getClassName() == "GaussianSplattingMesh") {
111+
defines.setValue("SH_DEGREE", (<GaussianSplattingMesh>mesh).shDegree, true);
112+
}
90113
}
91114

92115
protected override _buildBlock(state: NodeMaterialBuildState) {
@@ -96,12 +119,16 @@ export class GaussianSplattingBlock extends NodeMaterialBlock {
96119
return;
97120
}
98121

122+
state.sharedData.blocksWithDefines.push(this);
123+
99124
const comments = `//${this.name}`;
100125
state._emitFunctionFromInclude("gaussianSplattingVertexDeclaration", comments);
101126
state._emitFunctionFromInclude("gaussianSplatting", comments);
127+
state._emitFunctionFromInclude("helperFunctions", comments);
102128
state._emitUniformFromString("focal", NodeMaterialBlockConnectionPointTypes.Vector2);
103129
state._emitUniformFromString("invViewport", NodeMaterialBlockConnectionPointTypes.Vector2);
104130
state._emitUniformFromString("kernelSize", NodeMaterialBlockConnectionPointTypes.Float);
131+
state._emitUniformFromString("eyePosition", NodeMaterialBlockConnectionPointTypes.Vector3);
105132
state.attributes.push(VertexBuffer.PositionKind);
106133
state.sharedData.nodeMaterial.backFaceCulling = false;
107134

@@ -111,6 +138,7 @@ export class GaussianSplattingBlock extends NodeMaterialBlock {
111138
const view = this.view;
112139
const projection = this.projection;
113140
const output = this.splatVertex;
141+
const sh = this.SH;
114142

115143
const addF = state.fSuffix;
116144
let splatScaleParameter = `vec2${addF}(1.,1.)`;
@@ -124,6 +152,28 @@ export class GaussianSplattingBlock extends NodeMaterialBlock {
124152
input = "input.position";
125153
uniforms = ", uniforms.focal, uniforms.invViewport, uniforms.kernelSize";
126154
}
155+
if (this.SH.isConnected) {
156+
state.compilationString += `#if SH_DEGREE > 0\n`;
157+
158+
if (state.shaderLanguage === ShaderLanguage.WGSL) {
159+
state.compilationString += `let worldRot: mat3x3f = mat3x3f(${world.associatedVariableName}[0].xyz, ${world.associatedVariableName}[1].xyz, ${world.associatedVariableName}[2].xyz);`;
160+
state.compilationString += `let normWorldRot: mat3x3f = inverseMat3(worldRot);`;
161+
state.compilationString += `var dir: vec3f = normalize(normWorldRot * (${splatPosition.associatedVariableName}.xyz - uniforms.eyePosition));\n`;
162+
} else {
163+
state.compilationString += `mat3 worldRot = mat3(${world.associatedVariableName});`;
164+
state.compilationString += `mat3 normWorldRot = inverseMat3(worldRot);`;
165+
state.compilationString += `vec3 dir = normalize(normWorldRot * (${splatPosition.associatedVariableName}.xyz - eyePosition));\n`;
166+
}
167+
168+
state.compilationString += `dir *= vec3${addF}(1.,1.,-1.);\n`;
169+
state.compilationString += `${state._declareOutput(sh)} = computeSH(splat, dir);\n`;
170+
state.compilationString += `#else\n`;
171+
state.compilationString += `${state._declareOutput(sh)} = vec3${addF}(0.,0.,0.);\n`;
172+
state.compilationString += `#endif;\n`;
173+
} else {
174+
state.compilationString += `${state._declareOutput(sh)} = vec3${addF}(0.,0.,0.);`;
175+
}
176+
127177
state.compilationString += `${state._declareOutput(output)} = gaussianSplatting(${input}, ${splatPosition.associatedVariableName}, ${splatScaleParameter}, covA, covB, ${world.associatedVariableName}, ${view.associatedVariableName}, ${projection.associatedVariableName}${uniforms});\n`;
128178
return this;
129179
}

packages/dev/core/src/Materials/Node/Blocks/GaussianSplatting/splatReaderBlock.ts

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,9 @@ export class SplatReaderBlock extends NodeMaterialBlock {
9696
state._emit2DSampler("covariancesBTexture");
9797
state._emit2DSampler("centersTexture");
9898
state._emit2DSampler("colorsTexture");
99+
state._emit2DSampler("shTexture0", "SH_DEGREE > 0", undefined, undefined, true, "highp");
100+
state._emit2DSampler("shTexture1", "SH_DEGREE > 0", undefined, undefined, true, "highp");
101+
state._emit2DSampler("shTexture2", "SH_DEGREE > 0", undefined, undefined, true, "highp");
99102

100103
state._emitFunctionFromInclude("gaussianSplattingVertexDeclaration", comments);
101104
state._emitFunctionFromInclude("gaussianSplatting", comments);

packages/dev/core/src/Materials/Node/nodeMaterialBuildState.ts

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -296,17 +296,20 @@ export class NodeMaterialBuildState {
296296
/**
297297
* @internal
298298
*/
299-
public _emit2DSampler(name: string, define = "", force = false, annotation?: string) {
299+
public _emit2DSampler(name: string, define = "", force = false, annotation?: string, unsignedSampler?: boolean, precision?: string) {
300300
if (this.samplers.indexOf(name) < 0 || force) {
301301
if (define) {
302302
this._samplerDeclaration += `#if ${define}\n`;
303303
}
304304

305305
if (this.shaderLanguage === ShaderLanguage.WGSL) {
306+
const unsignedSamplerPrefix = unsignedSampler ? "u" : "f";
306307
this._samplerDeclaration += `var ${name + Constants.AUTOSAMPLERSUFFIX}: sampler;\n`;
307-
this._samplerDeclaration += `var ${name}: texture_2d<f32>;\n`;
308+
this._samplerDeclaration += `var ${name}: texture_2d<${unsignedSamplerPrefix}32>;\n`;
308309
} else {
309-
this._samplerDeclaration += `uniform sampler2D ${name}; ${annotation ? annotation : ""}\n`;
310+
const unsignedSamplerPrefix = unsignedSampler ? "u" : "";
311+
const precisionDecl = precision ?? "";
312+
this._samplerDeclaration += `uniform ${precisionDecl} ${unsignedSamplerPrefix}sampler2D ${name}; ${annotation ? annotation : ""}\n`;
310313
}
311314

312315
if (define) {

packages/dev/core/src/Materials/Node/nodeMaterialDefault.ts

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ import { Texture } from "../Textures/texture";
1313
import { Tools } from "core/Misc/tools";
1414
import { CurrentScreenBlock } from "./Blocks/Dual/currentScreenBlock";
1515
import { Color4 } from "core/Maths/math.color";
16+
import { AddBlock } from "./Blocks/addBlock";
1617

1718
/**
1819
* Clear the material and set it to a default state for gaussian splatting
@@ -56,13 +57,19 @@ export function SetToDefaultGaussianSplatting(nodeMaterial: NodeMaterial): void
5657
view.connectTo(gs, { input: "view" });
5758
projection.connectTo(gs, { input: "projection" });
5859

60+
const addBlock = new AddBlock("Add SH");
61+
5962
// from color to gaussian color
6063
const gaussian = new GaussianBlock("Gaussian");
6164
splatReader.connectTo(gaussian, { input: "splatColor", output: "splatColor" });
6265

6366
// fragment and vertex outputs
6467
const fragmentOutput = new FragmentOutputBlock("FragmentOutput");
65-
gaussian.connectTo(fragmentOutput);
68+
69+
gs.SH.connectTo(addBlock.left);
70+
gaussian.rgb.connectTo(addBlock.right);
71+
addBlock.output.connectTo(fragmentOutput.rgb);
72+
gaussian.alpha.connectTo(fragmentOutput.a);
6673

6774
const vertexOutput = new VertexOutputBlock("VertexOutput");
6875
gs.connectTo(vertexOutput);

packages/dev/core/src/Shaders/ShadersInclude/gaussianSplatting.fx

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -126,12 +126,11 @@ vec4 decompose(uint value)
126126
return components * vec4(2./255.) - vec4(1.);
127127
}
128128

129-
vec3 computeSH(Splat splat, vec3 color, vec3 dir)
129+
vec3 computeSH(Splat splat, vec3 dir)
130130
{
131131
vec3 sh[16];
132132

133-
sh[0] = color;
134-
133+
sh[0] = vec3(0.,0.,0.);
135134
#if SH_DEGREE > 0
136135
vec4 sh00 = decompose(splat.sh0.x);
137136
vec4 sh01 = decompose(splat.sh0.y);
@@ -172,9 +171,9 @@ vec3 computeSH(Splat splat, vec3 color, vec3 dir)
172171
return computeColorFromSHDegree(dir, sh);
173172
}
174173
#else
175-
vec3 computeSH(Splat splat, vec3 color, vec3 dir)
174+
vec3 computeSH(Splat splat, vec3 dir)
176175
{
177-
return color;
176+
return vec3(0.,0.,0.);
178177
}
179178
#endif
180179

packages/dev/core/src/Shaders/gaussianSplatting.vertex.fx

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ uniform vec2 invViewport;
1818
uniform vec2 dataTextureSize;
1919
uniform vec2 focal;
2020
uniform float kernelSize;
21+
uniform vec3 eyePosition;
2122

2223
uniform sampler2D covariancesATexture;
2324
uniform sampler2D covariancesBTexture;
@@ -54,9 +55,9 @@ void main () {
5455
mat3 worldRot = mat3(world);
5556
mat3 normWorldRot = inverseMat3(worldRot);
5657

57-
vec3 dir = normalize(normWorldRot * (worldPos.xyz - vEyePosition.xyz));
58+
vec3 dir = normalize(normWorldRot * (worldPos.xyz - eyePosition));
5859
dir *= vec3(1.,1.,-1.); // convert to Babylon Space
59-
vColor.xyz = computeSH(splat, splat.color.xyz, dir);
60+
vColor.xyz = splat.color.xyz + computeSH(splat, dir);
6061
#endif
6162

6263
gl_Position = gaussianSplatting(position, worldPos.xyz, vec2(1.,1.), covA, covB, world, view, projection);

packages/dev/core/src/ShadersWGSL/ShadersInclude/gaussianSplatting.fx

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -111,11 +111,11 @@ fn decompose(value: u32) -> vec4f
111111
return components * vec4f(2./255.) - vec4f(1.);
112112
}
113113

114-
fn computeSH(splat: Splat, color: vec3f, dir: vec3f) -> vec3f
114+
fn computeSH(splat: Splat, dir: vec3f) -> vec3f
115115
{
116116
var sh: array<vec3<f32>, 16>;
117117

118-
sh[0] = color;
118+
sh[0] = vec3f(0., 0., 0.);
119119

120120
#if SH_DEGREE > 0
121121
let sh00: vec4f = decompose(splat.sh0.x);

packages/dev/core/src/ShadersWGSL/gaussianSplatting.vertex.fx

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ uniform invViewport: vec2f;
1515
uniform dataTextureSize: vec2f;
1616
uniform focal: vec2f;
1717
uniform kernelSize: f32;
18+
uniform eyePosition: vec3f;
1819

1920
// textures
2021
var covariancesATexture: texture_2d<f32>;
@@ -51,9 +52,9 @@ fn main(input : VertexInputs) -> FragmentInputs {
5152
let worldRot: mat3x3f = mat3x3f(mesh.world[0].xyz, mesh.world[1].xyz, mesh.world[2].xyz);
5253
let normWorldRot: mat3x3f = inverseMat3(worldRot);
5354

54-
var dir: vec3f = normalize(normWorldRot * (worldPos.xyz - scene.vEyePosition.xyz));
55+
var dir: vec3f = normalize(normWorldRot * (worldPos.xyz - uniforms.eyePosition.xyz));
5556
dir *= vec3f(1.,1.,-1.); // convert to Babylon Space
56-
vertexOutputs.vColor = vec4f(computeSH(splat, splat.color.xyz, dir), splat.color.w);
57+
vertexOutputs.vColor = vec4f(splat.color.xyz + computeSH(splat, dir), splat.color.w);
5758
#else
5859
vertexOutputs.vColor = splat.color;
5960
#endif

0 commit comments

Comments
 (0)