Skip to content

Commit 40e5f1c

Browse files
committed
port link-time-constant-array-size to use slang-rhi
1 parent 9b1401f commit 40e5f1c

File tree

2 files changed

+38
-71
lines changed

2 files changed

+38
-71
lines changed

lock

Whitespace-only changes.

tools/gfx-unit-test/link-time-constant-array-size.cpp

Lines changed: 38 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,17 @@
11
#include "core/slang-basic.h"
22
#include "core/slang-blob.h"
33
#include "gfx-test-util.h"
4-
#include "gfx-util/shader-cursor.h"
5-
#include "slang-gfx.h"
4+
#include <slang-rhi/shader-cursor.h>
5+
#include <slang-rhi.h>
66
#include "unit-test/slang-unit-test.h"
77

8-
using namespace gfx;
8+
using namespace rhi;
99

1010
namespace gfx_test
1111
{
1212
static Slang::Result loadProgram(
13-
gfx::IDevice* device,
14-
Slang::ComPtr<gfx::IShaderProgram>& outShaderProgram,
13+
IDevice* device,
14+
Slang::ComPtr<IShaderProgram>& outShaderProgram,
1515
const char* mainModuleName,
1616
const char* libModuleName,
1717
const char* entryPointName,
@@ -64,10 +64,10 @@ static Slang::Result loadProgram(
6464
slangReflection = composedProgram->getLayout();
6565

6666
// Create shader program
67-
gfx::IShaderProgram::Desc programDesc = {};
67+
ShaderProgramDesc programDesc = {};
6868
programDesc.slangGlobalScope = composedProgram.get();
6969

70-
auto shaderProgram = device->createProgram(programDesc);
70+
auto shaderProgram = device->createShaderProgram(programDesc);
7171

7272
outShaderProgram = shaderProgram;
7373
return SLANG_OK;
@@ -80,20 +80,18 @@ static void validateArraySizeInStruct(
8080
int expectedSize)
8181
{
8282
// Check reflection is available
83-
SLANG_CHECK(slangReflection != nullptr);
83+
GFX_CHECK_CALL_ABORT(slangReflection != nullptr);
8484

8585
// Get the global scope layout
8686
auto globalScope = slangReflection->getGlobalParamsVarLayout();
87-
SLANG_CHECK_MSG(globalScope != nullptr, "Could not get global scope layout");
87+
GFX_CHECK_CALL_ABORT(globalScope != nullptr);
8888

8989
auto typeLayout = globalScope->getTypeLayout();
90-
SLANG_CHECK_MSG(typeLayout != nullptr, "Global scope has no type layout");
90+
GFX_CHECK_CALL_ABORT(typeLayout != nullptr);
9191

9292
// Check if the global scope is a struct type
9393
auto kind = typeLayout->getKind();
94-
SLANG_CHECK_MSG(
95-
kind == slang::TypeReflection::Kind::Struct,
96-
"Global scope is not a struct type");
94+
GFX_CHECK_CALL_ABORT(kind == slang::TypeReflection::Kind::Struct);
9795

9896
// Find the buffer resource 'b'
9997
bool foundBuffer = false;
@@ -110,23 +108,19 @@ static void validateArraySizeInStruct(
110108

111109
// Get the type layout of the field
112110
auto fieldTypeLayout = fieldLayout->getTypeLayout();
113-
SLANG_CHECK_MSG(fieldTypeLayout != nullptr, "Field has no type layout");
111+
GFX_CHECK_CALL_ABORT(fieldTypeLayout != nullptr);
114112

115113
// Get the element type of the structured buffer
116114
auto elementTypeLayout = fieldTypeLayout->getElementTypeLayout();
117-
SLANG_CHECK_MSG(
118-
elementTypeLayout != nullptr,
119-
"Structured buffer has no element type layout");
115+
GFX_CHECK_CALL_ABORT(elementTypeLayout != nullptr);
120116

121117
// Check if it's a struct type
122118
auto elementKind = elementTypeLayout->getKind();
123-
SLANG_CHECK_MSG(
124-
elementKind == slang::TypeReflection::Kind::Struct,
125-
"Buffer element is not a struct type");
119+
GFX_CHECK_CALL_ABORT(elementKind == slang::TypeReflection::Kind::Struct);
126120

127121
// Get the field count of the struct
128122
auto structFieldCount = elementTypeLayout->getFieldCount();
129-
SLANG_CHECK_MSG(structFieldCount >= 1, "Struct has no fields");
123+
GFX_CHECK_CALL_ABORT(structFieldCount >= 1);
130124

131125
// Check for the 'xs' field
132126
bool foundXsField = false;
@@ -143,46 +137,33 @@ static void validateArraySizeInStruct(
143137
auto structFieldTypeLayout = structField->getTypeLayout();
144138
auto structFieldTypeKind = structFieldTypeLayout->getKind();
145139

146-
SLANG_CHECK_MSG(
147-
structFieldTypeKind == slang::TypeReflection::Kind::Array,
148-
"Field 'xs' is not an array type");
140+
GFX_CHECK_CALL_ABORT(structFieldTypeKind == slang::TypeReflection::Kind::Array);
149141

150142
// Check the array size
151143
auto arraySize = structFieldTypeLayout->getElementCount();
152144
// 0 becuase we haven't resolved the constant
153-
SLANG_CHECK_MSG(
154-
arraySize == 0,
155-
"Field 'xs' array size does not match expected size");
145+
GFX_CHECK_CALL_ABORT(arraySize == 0);
156146

157147
// 4 because we're resolving it
158148
const auto resolvedArraySize =
159149
structFieldTypeLayout->getElementCount(slangReflection);
160-
SLANG_CHECK_MSG(
161-
resolvedArraySize == expectedSize,
162-
"Field 'xs' array size does not match expected size");
150+
GFX_CHECK_CALL_ABORT(resolvedArraySize == expectedSize);
163151

164152
break;
165153
}
166154
}
167155

168-
SLANG_CHECK_MSG(foundXsField, "Could not find field 'xs' in struct S");
156+
GFX_CHECK_CALL_ABORT(foundXsField);
169157
break;
170158
}
171159
}
172160

173-
SLANG_CHECK_MSG(foundBuffer, "Could not find buffer 'b' in global scope");
161+
GFX_CHECK_CALL_ABORT(foundBuffer);
174162
}
175163

176164

177165
void linkTimeConstantArraySizeTestImpl(IDevice* device, UnitTestContext* context)
178166
{
179-
// Create transient heap
180-
Slang::ComPtr<ITransientResourceHeap> transientHeap;
181-
ITransientResourceHeap::Desc transientHeapDesc = {};
182-
transientHeapDesc.constantBufferSize = 4096;
183-
GFX_CHECK_CALL_ABORT(
184-
device->createTransientResourceHeap(transientHeapDesc, transientHeap.writeRef()));
185-
186167
// Load and link program
187168
ComPtr<IShaderProgram> shaderProgram;
188169
slang::ProgramLayout* slangReflection;
@@ -200,70 +181,56 @@ void linkTimeConstantArraySizeTestImpl(IDevice* device, UnitTestContext* context
200181
validateArraySizeInStruct(context, slangReflection, N);
201182

202183
// Create compute pipeline
203-
ComputePipelineStateDesc pipelineDesc = {};
184+
ComputePipelineDesc pipelineDesc = {};
204185
pipelineDesc.program = shaderProgram.get();
205-
ComPtr<gfx::IPipelineState> pipelineState;
186+
ComPtr<IComputePipeline> pipelineState;
206187
GFX_CHECK_CALL_ABORT(
207-
device->createComputePipelineState(pipelineDesc, pipelineState.writeRef()));
188+
device->createComputePipeline(pipelineDesc, pipelineState.writeRef()));
208189

209190
// Create buffer for struct S with array of size N
210191
int32_t initialData[] = {1, 2, 3, 4};
211-
IBufferResource::Desc bufferDesc = {};
212-
bufferDesc.sizeInBytes = N * sizeof(int32_t);
213-
bufferDesc.format = gfx::Format::Unknown;
192+
BufferDesc bufferDesc = {};
193+
bufferDesc.size = N * sizeof(int32_t);
194+
bufferDesc.format = Format::Undefined;
214195
bufferDesc.elementSize = sizeof(int32_t);
215-
bufferDesc.allowedStates = ResourceStateSet(
216-
ResourceState::ShaderResource,
217-
ResourceState::UnorderedAccess,
218-
ResourceState::CopyDestination,
219-
ResourceState::CopySource);
196+
bufferDesc.usage = BufferUsage::ShaderResource | BufferUsage::UnorderedAccess | BufferUsage::CopyDestination | BufferUsage::CopySource;
220197
bufferDesc.defaultState = ResourceState::UnorderedAccess;
221198
bufferDesc.memoryType = MemoryType::DeviceLocal;
222199

223-
ComPtr<IBufferResource> numbersBuffer;
200+
ComPtr<IBuffer> numbersBuffer;
224201
GFX_CHECK_CALL_ABORT(
225-
device->createBufferResource(bufferDesc, (void*)initialData, numbersBuffer.writeRef()));
226-
227-
ComPtr<IResourceView> bufferView;
228-
IResourceView::Desc viewDesc = {};
229-
viewDesc.type = IResourceView::Type::UnorderedAccess;
230-
viewDesc.format = Format::Unknown;
231-
GFX_CHECK_CALL_ABORT(
232-
device->createBufferView(numbersBuffer, nullptr, viewDesc, bufferView.writeRef()));
202+
device->createBuffer(bufferDesc, (void*)initialData, numbersBuffer.writeRef()));
233203

234204
// Record and execute command buffer
235205
{
236-
ICommandQueue::Desc queueDesc = {ICommandQueue::QueueType::Graphics};
237-
auto queue = device->createCommandQueue(queueDesc);
238-
239-
auto commandBuffer = transientHeap->createCommandBuffer();
240-
auto encoder = commandBuffer->encodeComputeCommands();
206+
auto queue = device->getQueue(QueueType::Graphics);
207+
auto commandEncoder = queue->createCommandEncoder();
208+
auto encoder = commandEncoder->beginComputePass();
241209

242210
auto rootObject = encoder->bindPipeline(pipelineState);
243211

244212
ShaderCursor rootCursor(rootObject);
245-
rootCursor.getPath("b").setResource(bufferView);
213+
rootCursor.getPath("b").setBinding(Binding(numbersBuffer));
246214

247215
encoder->dispatchCompute(1, 1, 1);
248-
encoder->endEncoding();
249-
commandBuffer->close();
250-
queue->executeCommandBuffer(commandBuffer);
216+
encoder->end();
217+
queue->submit(commandEncoder->finish());
251218
queue->waitOnHost();
252219
}
253220

254221
// Expected results: each element is input * N
255222
// With N=4 and inputs [1,2,3,4], expected output is [4,8,12,16]
256-
compareComputeResult(device, numbersBuffer, Slang::makeArray<int>(4, 8, 12, 16));
223+
compareComputeResult(device, numbersBuffer, std::array{4, 8, 12, 16});
257224
}
258225

259226
SLANG_UNIT_TEST(linkTimeConstantArraySizeD3D12)
260227
{
261-
runTestImpl(linkTimeConstantArraySizeTestImpl, unitTestContext, Slang::RenderApiFlag::D3D12);
228+
runTestImpl(linkTimeConstantArraySizeTestImpl, unitTestContext, DeviceType::D3D12);
262229
}
263230

264231
SLANG_UNIT_TEST(linkTimeConstantArraySizeVulkan)
265232
{
266-
runTestImpl(linkTimeConstantArraySizeTestImpl, unitTestContext, Slang::RenderApiFlag::Vulkan);
233+
runTestImpl(linkTimeConstantArraySizeTestImpl, unitTestContext, DeviceType::Vulkan);
267234
}
268235

269236
} // namespace gfx_test

0 commit comments

Comments
 (0)