Skip to content

Add __ir_bytes virtual value #8057

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 12 commits into
base: master
Choose a base branch
from
9 changes: 9 additions & 0 deletions source/slang/slang-ast-expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -609,6 +609,15 @@ class ReturnValExpr : public Expr
Scope* scope = nullptr;
};

// Expression node for the `__ir_bytes` keyword: an array of the bytes in the
// current module.
FIDDLE()
class IRBytesExpr : public Expr
{
FIDDLE(...)

// Scope that was current in the parser when the keyword was parsed.
Scope* scope = nullptr;
};

// An expression that binds a temporary variable in a local expression context
FIDDLE()
class LetExpr : public Expr
Expand Down
1 change: 1 addition & 0 deletions source/slang/slang-ast-iterator.h
Original file line number Diff line number Diff line change
Expand Up @@ -251,6 +251,7 @@ struct ASTIterator
iterator->maybeDispatchCallback(expr);
}
void visitReturnValExpr(ReturnValExpr* expr) { iterator->maybeDispatchCallback(expr); }
void VisitIRBytesExpr(IRBytesExpr* expr) { iterator->maybeDispatchCallback(expr); }

void visitAndTypeExpr(AndTypeExpr* expr)
{
Expand Down
4 changes: 4 additions & 0 deletions source/slang/slang-ast-print.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -674,6 +674,10 @@ void ASTPrinter::addExpr(Expr* expr)
{
sb << "__return_val";
}
else if (as<IRBytesExpr>(expr))
{
sb << "__ir_bytes";
}
else if (const auto letExpr = as<LetExpr>(expr))
{
sb << "let ";
Expand Down
7 changes: 7 additions & 0 deletions source/slang/slang-ast-val.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1753,6 +1753,13 @@ Val* FuncCallIntVal::_substituteImplOverride(
return this;
}

// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! IRBytesCountIntVal !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!

// Appends the textual form of this value, the fixed string
// `__ir_bytes.getCount()`, to `out`.
void IRBytesCountIntVal::_toTextOverride(StringBuilder& out)
{
out << "__ir_bytes.getCount()";
}

// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! SizeOfIntVal !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!

void SizeOfIntVal::_toTextOverride(StringBuilder& out)
Expand Down
11 changes: 11 additions & 0 deletions source/slang/slang-ast-val.h
Original file line number Diff line number Diff line change
Expand Up @@ -255,6 +255,17 @@ class FuncCallIntVal : public IntVal
Val* _linkTimeResolveOverride(Dictionary<String, IntVal*>& map);
};

// Integer value standing for the element count of `__ir_bytes`.
FIDDLE()
class IRBytesCountIntVal : public IntVal
{
FIDDLE(...)
// Stores `inType` and `inScope` as this value's operands.
IRBytesCountIntVal(Type* inType, Scope* inScope) { setOperands(inType, inScope); }

// Writes the textual form of this value into `out`.
void _toTextOverride(StringBuilder& out);

// Always reports this value as a link-time value.
bool _isLinkTimeValOverride() { return true; }
};

FIDDLE(abstract)
class SizeOfLikeIntVal : public IntVal
{
Expand Down
8 changes: 8 additions & 0 deletions source/slang/slang-check-expr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5712,6 +5712,14 @@ Expr* SemanticsExprVisitor::visitReturnValExpr(ReturnValExpr* expr)
return expr;
}

// Type-checks `__ir_bytes`: the expression is given the type "array of uint8"
// whose length is an `IRBytesCountIntVal` built from the int type and the
// expression's scope. Returns the same expression node.
Expr* SemanticsExprVisitor::visitIRBytesExpr(IRBytesExpr* expr)
{
    auto lengthType = m_astBuilder->getIntType();
    auto byteCount = m_astBuilder->getOrCreate<IRBytesCountIntVal>(lengthType, expr->scope);
    auto byteType = m_astBuilder->getUInt8Type();
    expr->type = m_astBuilder->getArrayType(byteType, byteCount);
    return expr;
}

Expr* SemanticsExprVisitor::visitAndTypeExpr(AndTypeExpr* expr)
{
// The left and right sides of an `&` for types must both be types.
Expand Down
1 change: 1 addition & 0 deletions source/slang/slang-check-impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -3120,6 +3120,7 @@ struct SemanticsExprVisitor : public SemanticsVisitor, ExprVisitor<SemanticsExpr
Expr* visitThisInterfaceExpr(ThisInterfaceExpr* expr);
Expr* visitCastToSuperTypeExpr(CastToSuperTypeExpr* expr);
Expr* visitReturnValExpr(ReturnValExpr* expr);
Expr* visitIRBytesExpr(IRBytesExpr* expr);
Expr* visitAndTypeExpr(AndTypeExpr* expr);
Expr* visitPointerTypeExpr(PointerTypeExpr* expr);
Expr* visitModifiedTypeExpr(ModifiedTypeExpr* expr);
Expand Down
25 changes: 25 additions & 0 deletions source/slang/slang-emit-c-like.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2886,6 +2886,31 @@ void CLikeSourceEmitter::defaultEmitInstExpr(IRInst* inst, const EmitOpInfo& inO
m_writer->emit(getName(inst));
break;

case kIROp_BlobAsArray:
{
    // Emit the blob operand as a brace-enclosed list of `0xNN,` byte literals.
    static const char kHexDigits[] = "0123456789ABCDEF";
    auto blobBytes = cast<IRBlobLit>(inst->getOperand(0))->getStringSlice();
    m_writer->emit("{");
    for (Index byteIndex = 0; byteIndex < blobBytes.getLength(); ++byteIndex)
    {
        // Sixteen 5-character entries per row keeps each row at 80 characters.
        if (byteIndex % 16 == 0)
            m_writer->emit("\n");
        m_writer->emit("0x");
        char byte = blobBytes[byteIndex];
        // High nibble first, then low nibble; the masks discard any sign extension.
        m_writer->emitChar(kHexDigits[(byte >> 4) & 0x0f]);
        m_writer->emitChar(kHexDigits[byte & 0x0f]);
        m_writer->emit(",");
    }
    m_writer->emit("\n}");
    break;
}
case kIROp_MakeArray:
case kIROp_MakeStruct:
{
Expand Down
43 changes: 43 additions & 0 deletions source/slang/slang-emit.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -699,6 +699,16 @@ Result linkAndOptimizeIR(
auto targetProgram = codeGenContext->getTargetProgram();
auto targetCompilerOptions = targetRequest->getOptionSet();

Dictionary<String, ComPtr<ISlangBlob>> moduleBlobs;
codeGenContext->getProgram()->enumerateModules(
[&](Module* module)
{
auto name = module->getIRModule()->getName()->text;
ComPtr<ISlangBlob> blob;
module->serialize(blob.writeRef());
moduleBlobs.add(name, blob);
});

// Get the artifact desc for the target
const auto artifactDesc = ArtifactDescUtil::makeDescForCompileTarget(asExternal(target));

Expand Down Expand Up @@ -1050,6 +1060,39 @@ Result linkAndOptimizeIR(

finalizeSpecialization(irModule);

{
    // Replace every global `IRBytes` / `IRBytesCount` placeholder with the
    // serialized bytes (respectively the byte count) of the module named by
    // the placeholder's first operand. Placeholders are collected and removed
    // after the walk so the global list is not edited while iterating it.
    List<IRInst*> removeList;
    IRBuilder builder(irModule);
    for (auto inst : irModule->getGlobalInsts())
    {
        const auto op = inst->getOp();
        if (op != kIROp_IRBytes && op != kIROp_IRBytesCount)
            continue;

        // `as<>` yields null when the operand is not a string literal, and the
        // module table may have no entry (or a null blob) for the name. Fail
        // the link step instead of dereferencing a null pointer.
        auto moduleNameLit = as<IRStringLit>(inst->getOperand(0));
        if (!moduleNameLit)
            return SLANG_FAIL;
        auto blobEntry = moduleBlobs.tryGetValue(moduleNameLit->getStringSlice());
        if (!blobEntry || !blobEntry->get())
            return SLANG_FAIL;
        auto& blob = *blobEntry;

        if (op == kIROp_IRBytes)
        {
            IRInst* args[] = {builder.getBlobValue(blob)};
            inst->replaceUsesWith(
                builder.emitIntrinsicInst(inst->getFullType(), kIROp_BlobAsArray, 1, args));
        }
        else
        {
            inst->replaceUsesWith(builder.getIntValue(blob->getBufferSize()));
        }
        removeList.add(inst);
    }
    for (auto inst : removeList)
    {
        inst->removeAndDeallocate();
    }
}

// Lower `Result<T,E>` types into ordinary struct types. This must happen
// after specialization, since otherwise incompatible copies of the lowered
// result structure are generated.
Expand Down
3 changes: 3 additions & 0 deletions source/slang/slang-ir-insts-stable-names.lua
Original file line number Diff line number Diff line change
Expand Up @@ -669,4 +669,7 @@ return {
["SPIRVAsmOperand.__sampledType"] = 665,
["SPIRVAsmOperand.__imageType"] = 666,
["SPIRVAsmOperand.__sampledImageType"] = 667,
["BlobAsArray"] = 668,
["IRBytes"] = 669,
["IRBytesCount"] = 670,
}
3 changes: 3 additions & 0 deletions source/slang/slang-ir-insts.lua
Original file line number Diff line number Diff line change
Expand Up @@ -563,6 +563,7 @@ local insts = {
min_operands = 1,
},
},
{ BlobAsArray = { min_operands = 1, hoistable = true } },
{ makeArray = {} },
{ makeArrayFromElement = { min_operands = 1 } },
{ makeCoopVector = {} },
Expand Down Expand Up @@ -1908,6 +1909,8 @@ local insts = {
{ IsSignedInt = { min_operands = 1 } },
{ IsVector = { min_operands = 1 } },
{ GetDynamicResourceHeap = { hoistable = true } },
{ IRBytes = { min_operands = 1, hoistable = true } },
{ IRBytesCount = { min_operands = 1, hoistable = true } },
{ ForwardDifferentiate = { min_operands = 1 } },
-- Produces the primal computation of backward derivatives, will return an intermediate context for
-- backward derivative func.
Expand Down
2 changes: 1 addition & 1 deletion source/slang/slang-ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -2408,7 +2408,7 @@ struct IRModule : RefObject
// anything to do with serialization format
//
const static UInt k_minSupportedModuleVersion = 1;
const static UInt k_maxSupportedModuleVersion = 1;
const static UInt k_maxSupportedModuleVersion = 2;
static_assert(k_minSupportedModuleVersion <= k_maxSupportedModuleVersion);

private:
Expand Down
20 changes: 20 additions & 0 deletions source/slang/slang-lower-to-ir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1657,6 +1657,15 @@ struct ValLoweringVisitor : ValVisitor<ValLoweringVisitor, LoweredValInfo, Lower
return resVal;
}

// Lowers the `__ir_bytes` element count to an `IRBytesCount` instruction of
// int type whose single operand is the name of the builder's module.
LoweredValInfo visitIRBytesCountIntVal(IRBytesCountIntVal*)
{
    auto irBuilder = context->irBuilder;
    IRInst* moduleNameOperand =
        irBuilder->getStringValue(irBuilder->getModule()->getName()->text.getUnownedSlice());
    auto countInst = irBuilder->emitIntrinsicInst(
        irBuilder->getIntType(),
        kIROp_IRBytesCount,
        1,
        &moduleNameOperand);
    return LoweredValInfo::simple(countInst);
}

LoweredValInfo visitTypeCastIntVal(TypeCastIntVal* val)
{
auto baseVal = lowerVal(context, val->getBase());
Expand Down Expand Up @@ -4610,6 +4619,17 @@ struct ExprLoweringVisitorBase : public ExprVisitor<Derived, LoweredValInfo>

LoweredValInfo visitReturnValExpr(ReturnValExpr*) { return context->returnDestination; }

// Lowers `__ir_bytes` to an `IRBytes` instruction. Its type is an array of
// uint8 whose length is an `IRBytesCount` instruction; both instructions take
// the name of the builder's module as their single operand.
LoweredValInfo visitIRBytesExpr(IRBytesExpr*)
{
    auto irBuilder = context->irBuilder;
    IRInst* moduleNameOperand =
        irBuilder->getStringValue(irBuilder->getModule()->getName()->text.getUnownedSlice());
    auto elementCount = irBuilder->emitIntrinsicInst(
        irBuilder->getIntType(),
        kIROp_IRBytesCount,
        1,
        &moduleNameOperand);
    auto arrayType = irBuilder->getArrayType(irBuilder->getUInt8Type(), elementCount);
    auto bytesInst = irBuilder->emitIntrinsicInst(arrayType, kIROp_IRBytes, 1, &moduleNameOperand);
    return LoweredValInfo::simple(bytesInst);
}

LoweredValInfo visitMemberExpr(MemberExpr* expr)
{
auto loweredType = lowerType(context, expr->type);
Expand Down
8 changes: 8 additions & 0 deletions source/slang/slang-parser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6980,6 +6980,13 @@ static NodeBase* parseReturnValExpr(Parser* parser, void* /*userData*/)
return expr;
}

// Parse callback registered for the `__ir_bytes` keyword: creates the
// expression node and records the parser's current scope on it.
static NodeBase* parseIRBytesExpr(Parser* parser, void* /*userData*/)
{
    auto bytesExpr = parser->astBuilder->create<IRBytesExpr>();
    bytesExpr->scope = parser->currentScope;
    return bytesExpr;
}

static Expr* parseBoolLitExpr(Parser* parser, bool value)
{
BoolLiteralExpr* expr = parser->astBuilder->create<BoolLiteralExpr>();
Expand Down Expand Up @@ -9636,6 +9643,7 @@ static const SyntaxParseInfo g_parseSyntaxEntries[] = {
_makeParseExpr("true", parseTrueExpr),
_makeParseExpr("false", parseFalseExpr),
_makeParseExpr("__return_val", parseReturnValExpr),
_makeParseExpr("__ir_bytes", parseIRBytesExpr),
_makeParseExpr("nullptr", parseNullPtrExpr),
_makeParseExpr("none", parseNoneExpr),
_makeParseExpr("try", parseTryExpr),
Expand Down
102 changes: 102 additions & 0 deletions tests/cpu-program/gfx-heterogeneous.slang
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
//TEST:EXECUTABLE:
import gfx;
import slang;

// Compute entry point (4 threads per group): each thread writes its
// dispatch-thread x index into the matching element of `buffer`.
[shader("compute")]
[numthreads(4, 1, 1)]
void computeMain(
uint3 sv_dispatchThreadID: SV_DispatchThreadID,
uniform RWStructuredBuffer<float> buffer
)
{
// The value read here is not used afterwards in this function.
var input = buffer[sv_dispatchThreadID.x];
buffer[sv_dispatchThreadID.x] = sv_dispatchThreadID.x;
}

// Host-side entry point of the test. Creates a CPU gfx device, builds a shader
// program from this module's own bytes (`__ir_bytes`), dispatches
// `computeMain` once, then reads back and prints the first four floats of the
// output buffer. Returns 0 on completion, or -1 if no device could be created.
export __extern_cpp int main()
{
// Create a CPU device; print "fail" and bail out if none was produced.
gfx.DeviceDesc deviceDesc = {};
deviceDesc.deviceType = gfx.DeviceType.CPU;
Optional<gfx.IDevice> device;
gfx.gfxCreateDevice(&deviceDesc, device);
if (device == none)
{
printf("fail\n");
return -1;
}

// Create a graphics-type command queue on the device.
gfx.CommandQueueDesc queueDesc = {gfx::QueueType::Graphics};
queueDesc.type = gfx.QueueType.Graphics;
Optional<gfx.ICommandQueue> queue;
device.value.createCommandQueue(&queueDesc, queue);

// Describe the shader program: the source is the `__ir_bytes` array, passed
// as a Slang module binary, with `computeMain` as the only entry point.
gfx.ShaderProgramDesc2 programDesc = {};
var s = __ir_bytes;
programDesc.sourceData = (void*) &s[0];
programDesc.sourceType = gfx.ShaderModuleSourceType.SlangModuleBinary;
programDesc.sourceDataSize = s.getCount();
programDesc.entryPointCount = 1;
NativeString entryPointName = "computeMain";
programDesc.entryPointNames = &entryPointName;
Optional<gfx.IShaderProgram> program;
Optional<slang.ISlangBlob> diagBlob;
device.value.createProgram2(&programDesc, program, diagBlob);

// Create the compute pipeline state for that program.
// NOTE(review): `program`, and the other Optional results below, are accessed
// via `.value` without a `none` check.
Optional<gfx.IPipelineState> pipeline;
gfx.ComputePipelineStateDesc pipelineDesc;
pipelineDesc.program = NativeRef<gfx.IShaderProgram>(program.value);
device.value.createComputePipelineState(&pipelineDesc, pipeline);

// Create the transient resource heap that command buffers are created from.
Optional<gfx.ITransientResourceHeap> transientHeap;
gfx.TransientResourceHeapDesc transientHeapDesc;
transientHeapDesc.constantBufferDescriptorCount = 64;
transientHeapDesc.constantBufferSize = 1024;
transientHeapDesc.srvDescriptorCount = 1024;
transientHeapDesc.uavDescriptorCount = 1024;
transientHeapDesc.samplerDescriptorCount = 256;
transientHeapDesc.accelerationStructureDescriptorCount = 32;
device.value.createTransientResourceHeap(&transientHeapDesc, transientHeap);

// Create a 256-byte device-local buffer with 4-byte elements, usable for
// unordered access, with no initial data.
Optional<gfx.IBufferResource> buffer;
gfx.BufferResourceDesc bufferDesc = {};
bufferDesc.memoryType = gfx.MemoryType.DeviceLocal;
bufferDesc.allowedStates.add(gfx.ResourceState.UnorderedAccess);
bufferDesc.defaultState = gfx.ResourceState.UnorderedAccess;
bufferDesc.elementSize = 4;
bufferDesc.sizeInBytes = 256;
bufferDesc.type = gfx.ResourceType.Buffer;
device.value.createBufferResource(&bufferDesc, nullptr, buffer);

// Create an unordered-access view of the buffer.
Optional<gfx.IResourceView> bufferView;
gfx.ResourceViewDesc viewDesc;
viewDesc.type = gfx.ResourceViewType.UnorderedAccess;
device.value.createBufferView(buffer.value, none, &viewDesc, bufferView);

// Record the work: bind the pipeline, set the buffer view on entry point 0
// at a zero-initialized shader offset, and dispatch a single 1x1x1 group.
Optional<gfx.ICommandBuffer> commandBuffer;
transientHeap.value.createCommandBuffer(commandBuffer);
Optional<gfx.IComputeCommandEncoder> encoder;
commandBuffer.value.encodeComputeCommands(encoder);
Optional<gfx.IShaderObject> rootObject;
encoder.value.bindPipeline(pipeline.value, rootObject);
Optional<gfx.IShaderObject> entryPointObject;
rootObject.value.getEntryPoint(0, entryPointObject);
gfx.ShaderOffset offset = {};
entryPointObject.value.setResource(&offset, bufferView.value);
encoder.value.dispatchCompute(1, 1, 1);
encoder.value.endEncoding();
commandBuffer.value.close();

// Submit the command buffer and block until the queue has finished.
NativeRef<gfx.ICommandBuffer> commandBufferRef = NativeRef<gfx.ICommandBuffer>(commandBuffer.value);
queue.value.executeCommandBuffers(1, &commandBufferRef, none, 0);
queue.value.waitOnHost();

// Read back the first 16 bytes of the buffer.
Optional<slang.ISlangBlob> blob;
device.value.readBufferResource(buffer.value, 0, 16, blob);

// Print those bytes as four floats, one per line with one decimal place.
for (int i = 0; i < 4; i++)
{
float val = ((float *)blob.value.getBufferPointer())[i];
printf("%.1f\n", val);
}
return 0;
}
Loading