1
1
#include " core/slang-basic.h"
2
2
#include " core/slang-blob.h"
3
3
#include " gfx-test-util.h"
4
- #include " gfx-util /shader-cursor.h"
5
- #include " slang-gfx.h "
4
+ #include < slang-rhi /shader-cursor.h>
5
+ #include < slang-rhi.h >
6
6
#include " unit-test/slang-unit-test.h"
7
7
8
- using namespace gfx ;
8
+ using namespace rhi ;
9
9
10
10
namespace gfx_test
11
11
{
12
12
static Slang::Result loadProgram (
13
- gfx:: IDevice* device,
14
- Slang::ComPtr<gfx:: IShaderProgram>& outShaderProgram,
13
+ IDevice* device,
14
+ Slang::ComPtr<IShaderProgram>& outShaderProgram,
15
15
const char * mainModuleName,
16
16
const char * libModuleName,
17
17
const char * entryPointName,
@@ -64,10 +64,10 @@ static Slang::Result loadProgram(
64
64
slangReflection = composedProgram->getLayout ();
65
65
66
66
// Create shader program
67
- gfx::IShaderProgram::Desc programDesc = {};
67
+ ShaderProgramDesc programDesc = {};
68
68
programDesc.slangGlobalScope = composedProgram.get ();
69
69
70
- auto shaderProgram = device->createProgram (programDesc);
70
+ auto shaderProgram = device->createShaderProgram (programDesc);
71
71
72
72
outShaderProgram = shaderProgram;
73
73
return SLANG_OK;
@@ -80,20 +80,18 @@ static void validateArraySizeInStruct(
80
80
int expectedSize)
81
81
{
82
82
// Check reflection is available
83
- SLANG_CHECK (slangReflection != nullptr );
83
+ GFX_CHECK_CALL_ABORT (slangReflection != nullptr );
84
84
85
85
// Get the global scope layout
86
86
auto globalScope = slangReflection->getGlobalParamsVarLayout ();
87
- SLANG_CHECK_MSG (globalScope != nullptr , " Could not get global scope layout " );
87
+ GFX_CHECK_CALL_ABORT (globalScope != nullptr );
88
88
89
89
auto typeLayout = globalScope->getTypeLayout ();
90
- SLANG_CHECK_MSG (typeLayout != nullptr , " Global scope has no type layout " );
90
+ GFX_CHECK_CALL_ABORT (typeLayout != nullptr );
91
91
92
92
// Check if the global scope is a struct type
93
93
auto kind = typeLayout->getKind ();
94
- SLANG_CHECK_MSG (
95
- kind == slang::TypeReflection::Kind::Struct,
96
- " Global scope is not a struct type" );
94
+ GFX_CHECK_CALL_ABORT (kind == slang::TypeReflection::Kind::Struct);
97
95
98
96
// Find the buffer resource 'b'
99
97
bool foundBuffer = false ;
@@ -110,23 +108,19 @@ static void validateArraySizeInStruct(
110
108
111
109
// Get the type layout of the field
112
110
auto fieldTypeLayout = fieldLayout->getTypeLayout ();
113
- SLANG_CHECK_MSG (fieldTypeLayout != nullptr , " Field has no type layout " );
111
+ GFX_CHECK_CALL_ABORT (fieldTypeLayout != nullptr );
114
112
115
113
// Get the element type of the structured buffer
116
114
auto elementTypeLayout = fieldTypeLayout->getElementTypeLayout ();
117
- SLANG_CHECK_MSG (
118
- elementTypeLayout != nullptr ,
119
- " Structured buffer has no element type layout" );
115
+ GFX_CHECK_CALL_ABORT (elementTypeLayout != nullptr );
120
116
121
117
// Check if it's a struct type
122
118
auto elementKind = elementTypeLayout->getKind ();
123
- SLANG_CHECK_MSG (
124
- elementKind == slang::TypeReflection::Kind::Struct,
125
- " Buffer element is not a struct type" );
119
+ GFX_CHECK_CALL_ABORT (elementKind == slang::TypeReflection::Kind::Struct);
126
120
127
121
// Get the field count of the struct
128
122
auto structFieldCount = elementTypeLayout->getFieldCount ();
129
- SLANG_CHECK_MSG (structFieldCount >= 1 , " Struct has no fields " );
123
+ GFX_CHECK_CALL_ABORT (structFieldCount >= 1 );
130
124
131
125
// Check for the 'xs' field
132
126
bool foundXsField = false ;
@@ -143,46 +137,33 @@ static void validateArraySizeInStruct(
143
137
auto structFieldTypeLayout = structField->getTypeLayout ();
144
138
auto structFieldTypeKind = structFieldTypeLayout->getKind ();
145
139
146
- SLANG_CHECK_MSG (
147
- structFieldTypeKind == slang::TypeReflection::Kind::Array,
148
- " Field 'xs' is not an array type" );
140
+ GFX_CHECK_CALL_ABORT (structFieldTypeKind == slang::TypeReflection::Kind::Array);
149
141
150
142
// Check the array size
151
143
auto arraySize = structFieldTypeLayout->getElementCount ();
152
144
// 0 becuase we haven't resolved the constant
153
- SLANG_CHECK_MSG (
154
- arraySize == 0 ,
155
- " Field 'xs' array size does not match expected size" );
145
+ GFX_CHECK_CALL_ABORT (arraySize == 0 );
156
146
157
147
// 4 because we're resolving it
158
148
const auto resolvedArraySize =
159
149
structFieldTypeLayout->getElementCount (slangReflection);
160
- SLANG_CHECK_MSG (
161
- resolvedArraySize == expectedSize,
162
- " Field 'xs' array size does not match expected size" );
150
+ GFX_CHECK_CALL_ABORT (resolvedArraySize == expectedSize);
163
151
164
152
break ;
165
153
}
166
154
}
167
155
168
- SLANG_CHECK_MSG (foundXsField, " Could not find field 'xs' in struct S " );
156
+ GFX_CHECK_CALL_ABORT (foundXsField);
169
157
break ;
170
158
}
171
159
}
172
160
173
- SLANG_CHECK_MSG (foundBuffer, " Could not find buffer 'b' in global scope " );
161
+ GFX_CHECK_CALL_ABORT (foundBuffer);
174
162
}
175
163
176
164
177
165
void linkTimeConstantArraySizeTestImpl (IDevice* device, UnitTestContext* context)
178
166
{
179
- // Create transient heap
180
- Slang::ComPtr<ITransientResourceHeap> transientHeap;
181
- ITransientResourceHeap::Desc transientHeapDesc = {};
182
- transientHeapDesc.constantBufferSize = 4096 ;
183
- GFX_CHECK_CALL_ABORT (
184
- device->createTransientResourceHeap (transientHeapDesc, transientHeap.writeRef ()));
185
-
186
167
// Load and link program
187
168
ComPtr<IShaderProgram> shaderProgram;
188
169
slang::ProgramLayout* slangReflection;
@@ -200,70 +181,56 @@ void linkTimeConstantArraySizeTestImpl(IDevice* device, UnitTestContext* context
200
181
validateArraySizeInStruct (context, slangReflection, N);
201
182
202
183
// Create compute pipeline
203
- ComputePipelineStateDesc pipelineDesc = {};
184
+ ComputePipelineDesc pipelineDesc = {};
204
185
pipelineDesc.program = shaderProgram.get ();
205
- ComPtr<gfx::IPipelineState > pipelineState;
186
+ ComPtr<IComputePipeline > pipelineState;
206
187
GFX_CHECK_CALL_ABORT (
207
- device->createComputePipelineState (pipelineDesc, pipelineState.writeRef ()));
188
+ device->createComputePipeline (pipelineDesc, pipelineState.writeRef ()));
208
189
209
190
// Create buffer for struct S with array of size N
210
191
int32_t initialData[] = {1 , 2 , 3 , 4 };
211
- IBufferResource::Desc bufferDesc = {};
212
- bufferDesc.sizeInBytes = N * sizeof (int32_t );
213
- bufferDesc.format = gfx:: Format::Unknown ;
192
+ BufferDesc bufferDesc = {};
193
+ bufferDesc.size = N * sizeof (int32_t );
194
+ bufferDesc.format = Format::Undefined ;
214
195
bufferDesc.elementSize = sizeof (int32_t );
215
- bufferDesc.allowedStates = ResourceStateSet (
216
- ResourceState::ShaderResource,
217
- ResourceState::UnorderedAccess,
218
- ResourceState::CopyDestination,
219
- ResourceState::CopySource);
196
+ bufferDesc.usage = BufferUsage::ShaderResource | BufferUsage::UnorderedAccess | BufferUsage::CopyDestination | BufferUsage::CopySource;
220
197
bufferDesc.defaultState = ResourceState::UnorderedAccess;
221
198
bufferDesc.memoryType = MemoryType::DeviceLocal;
222
199
223
- ComPtr<IBufferResource > numbersBuffer;
200
+ ComPtr<IBuffer > numbersBuffer;
224
201
GFX_CHECK_CALL_ABORT (
225
- device->createBufferResource (bufferDesc, (void *)initialData, numbersBuffer.writeRef ()));
226
-
227
- ComPtr<IResourceView> bufferView;
228
- IResourceView::Desc viewDesc = {};
229
- viewDesc.type = IResourceView::Type::UnorderedAccess;
230
- viewDesc.format = Format::Unknown;
231
- GFX_CHECK_CALL_ABORT (
232
- device->createBufferView (numbersBuffer, nullptr , viewDesc, bufferView.writeRef ()));
202
+ device->createBuffer (bufferDesc, (void *)initialData, numbersBuffer.writeRef ()));
233
203
234
204
// Record and execute command buffer
235
205
{
236
- ICommandQueue::Desc queueDesc = {ICommandQueue::QueueType::Graphics};
237
- auto queue = device->createCommandQueue (queueDesc);
238
-
239
- auto commandBuffer = transientHeap->createCommandBuffer ();
240
- auto encoder = commandBuffer->encodeComputeCommands ();
206
+ auto queue = device->getQueue (QueueType::Graphics);
207
+ auto commandEncoder = queue->createCommandEncoder ();
208
+ auto encoder = commandEncoder->beginComputePass ();
241
209
242
210
auto rootObject = encoder->bindPipeline (pipelineState);
243
211
244
212
ShaderCursor rootCursor (rootObject);
245
- rootCursor.getPath (" b" ).setResource (bufferView );
213
+ rootCursor.getPath (" b" ).setBinding ( Binding (numbersBuffer) );
246
214
247
215
encoder->dispatchCompute (1 , 1 , 1 );
248
- encoder->endEncoding ();
249
- commandBuffer->close ();
250
- queue->executeCommandBuffer (commandBuffer);
216
+ encoder->end ();
217
+ queue->submit (commandEncoder->finish ());
251
218
queue->waitOnHost ();
252
219
}
253
220
254
221
// Expected results: each element is input * N
255
222
// With N=4 and inputs [1,2,3,4], expected output is [4,8,12,16]
256
- compareComputeResult (device, numbersBuffer, Slang::makeArray< int >( 4 , 8 , 12 , 16 ) );
223
+ compareComputeResult (device, numbersBuffer, std::array{ 4 , 8 , 12 , 16 } );
257
224
}
258
225
259
226
SLANG_UNIT_TEST (linkTimeConstantArraySizeD3D12)
260
227
{
261
- runTestImpl (linkTimeConstantArraySizeTestImpl, unitTestContext, Slang::RenderApiFlag ::D3D12);
228
+ runTestImpl (linkTimeConstantArraySizeTestImpl, unitTestContext, DeviceType ::D3D12);
262
229
}
263
230
264
231
SLANG_UNIT_TEST (linkTimeConstantArraySizeVulkan)
265
232
{
266
- runTestImpl (linkTimeConstantArraySizeTestImpl, unitTestContext, Slang::RenderApiFlag ::Vulkan);
233
+ runTestImpl (linkTimeConstantArraySizeTestImpl, unitTestContext, DeviceType ::Vulkan);
267
234
}
268
235
269
236
} // namespace gfx_test
0 commit comments