Skip to content

Commit 8909a6a

Browse files
committed
bottleneck ir module reading and writing
1 parent b282c88 commit 8909a6a

File tree

5 files changed

+217
-189
lines changed

5 files changed

+217
-189
lines changed

source/slang/slang-serialize-container.cpp

Lines changed: 44 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
#include "slang-serialize-ast.h"
1313
#include "slang-serialize-ir.h"
1414
#include "slang-serialize-source-loc.h"
15+
#include "slang-serialize-types.h"
1516

1617
namespace Slang
1718
{
@@ -137,11 +138,10 @@ struct ModuleEncodingContext
137138
IRModule* irModule = targetProgram->getOrCreateIRModuleForLayout(sink);
138139

139140
// Okay, we need to serialize this target program and its IR too...
140-
IRSerialData serialData;
141-
IRSerialWriter writer;
142-
143-
SLANG_RETURN_ON_FAIL(writer.write(irModule, _sourceLocWriter, &serialData));
144-
SLANG_RETURN_ON_FAIL(IRSerialWriter::writeTo(serialData, _cursor));
141+
{
142+
// SLANG_SCOPED_RIFF_BUILDER_LIST_CHUNK(_cursor, PropertyKeys<IRModule>::IRModule);
143+
writeSerializedModuleIR(_cursor, irModule, _sourceLocWriter);
144+
}
145145

146146
return SLANG_OK;
147147
}
@@ -216,10 +216,8 @@ struct ModuleEncodingContext
216216
//
217217
if (auto irModule = module->getIRModule())
218218
{
219-
IRSerialData serialData;
220-
IRSerialWriter writer;
221-
SLANG_RETURN_ON_FAIL(writer.write(irModule, _sourceLocWriter, &serialData));
222-
SLANG_RETURN_ON_FAIL(IRSerialWriter::writeTo(serialData, _cursor));
219+
// SLANG_SCOPED_RIFF_BUILDER_LIST_CHUNK(_cursor, PropertyKeys<IRModule>::IRModule);
220+
writeSerializedModuleIR(_cursor, irModule, _sourceLocWriter);
223221
}
224222

225223
// If we have AST information available, then we serialize it here.
@@ -582,31 +580,41 @@ SlangResult decodeModuleIR(
582580
Session* session,
583581
SerialSourceLocReader* sourceLocReader)
584582
{
585-
// IR serialization still uses the older approach, where
586-
// data gets deserialized from the RIFF into an intermediate
587-
// data structure (`IRSerialData`), and then the actual
588-
// in-memory structures are created based on the intermediate.
589-
//
590-
// Thus we start by running the `IRSerialReader::readContainer`
591-
// logic to get the `IRSerialData` representation.
592-
//
593-
// TODO(tfoley): This should all get streamlined so that we
594-
// are deserializing IR nodes directly from the format written
595-
// into the RIFF.
596-
//
597-
IRSerialData serialData;
598-
SLANG_RETURN_ON_FAIL(IRSerialReader::readFrom(chunk, &serialData));
583+
return readSerializedModuleIR(chunk, session, sourceLocReader, outIRModule);
584+
}
599585

600-
// Next we read the actual IR representation out from the
601-
// `serialData`. This is the step that may pull source-location
602-
// information from the provided `sourceLocReader`.
603-
//
604-
IRSerialReader reader;
605-
SLANG_RETURN_ON_FAIL(reader.read(serialData, session, sourceLocReader, outIRModule));
586+
static void calcModuleInstructionList(IRModule* module, List<IRInst*>& instsOut)
587+
{
588+
// We reserve 0 for null
589+
instsOut.setCount(1);
590+
instsOut[0] = nullptr;
606591

607-
return SLANG_OK;
592+
// Stack for parentInst
593+
List<IRInst*> parentInstStack;
594+
595+
IRModuleInst* moduleInst = module->getModuleInst();
596+
parentInstStack.add(moduleInst);
597+
598+
// Add to list
599+
instsOut.add(moduleInst);
600+
601+
// Traverse all of the instructions
602+
while (parentInstStack.getCount())
603+
{
604+
// If it's in the stack it is assumed it is already in the inst map
605+
IRInst* parentInst = parentInstStack.getLast();
606+
parentInstStack.removeLast();
607+
608+
IRInstListBase childrenList = parentInst->getDecorationsAndChildren();
609+
for (IRInst* child : childrenList)
610+
{
611+
instsOut.add(child);
612+
parentInstStack.add(child);
613+
}
614+
}
608615
}
609616

617+
610618
/* static */ SlangResult SerialContainerUtil::verifyIRSerialize(
611619
IRModule* module,
612620
Session* session,
@@ -615,9 +623,8 @@ SlangResult decodeModuleIR(
615623
// Verify if we can stream out with raw source locs
616624

617625
List<IRInst*> originalInsts;
618-
IRSerialWriter::calcInstructionList(module, originalInsts);
619-
620-
IRSerialData irData;
626+
SLANG_ASSERT(false);
627+
calcModuleInstructionList(module, originalInsts);
621628

622629
OwnedMemoryStream memoryStream(FileAccess::ReadWrite);
623630

@@ -627,7 +634,6 @@ SlangResult decodeModuleIR(
627634

628635
// Need to put all of this in a module chunk
629636
SLANG_SCOPED_RIFF_BUILDER_LIST_CHUNK(cursor, SerialBinary::kModuleFourCC);
630-
631637
RefPtr<SerialSourceLocWriter> sourceLocWriter;
632638

633639
if (options.sourceManagerToUseWhenSerializingSourceLocs)
@@ -636,12 +642,7 @@ SlangResult decodeModuleIR(
636642
new SerialSourceLocWriter(options.sourceManagerToUseWhenSerializingSourceLocs);
637643
}
638644

639-
{
640-
// Write IR out to `irData`
641-
IRSerialWriter writer;
642-
SLANG_RETURN_ON_FAIL(writer.write(module, sourceLocWriter, &irData));
643-
}
644-
SLANG_RETURN_ON_FAIL(IRSerialWriter::writeTo(irData, cursor));
645+
writeSerializedModuleIR(cursor, module, sourceLocWriter);
645646

646647
// Write the debug info Riff container
647648
if (sourceLocWriter)
@@ -703,25 +704,13 @@ SlangResult decodeModuleIR(
703704
return SLANG_FAIL;
704705
}
705706

706-
{
707-
IRSerialData irReadData;
708-
IRSerialReader reader;
709-
SLANG_RETURN_ON_FAIL(reader.readFrom(irChunk, &irReadData));
710-
711-
// Check the stream read data is the same
712-
if (irData != irReadData)
713-
{
714-
SLANG_ASSERT(!"Streamed in data doesn't match");
715-
return SLANG_FAIL;
716-
}
717-
718-
SLANG_RETURN_ON_FAIL(reader.read(irData, session, sourceLocReader, irReadModule));
719-
}
707+
SLANG_RETURN_ON_FAIL(
708+
readSerializedModuleIR(irChunk, session, sourceLocReader, irReadModule));
720709
}
721710
}
722711

723712
List<IRInst*> readInsts;
724-
IRSerialWriter::calcInstructionList(irReadModule, readInsts);
713+
calcModuleInstructionList(irReadModule, readInsts);
725714

726715
if (readInsts.getCount() != originalInsts.getCount())
727716
{

source/slang/slang-serialize-container.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,9 @@ struct StringChunk : RIFF::DataChunk
6868
String getValue() const;
6969
};
7070

71-
struct IRModuleChunk;
71+
struct IRModuleChunk : RIFF::ListChunk
72+
{
73+
};
7274

7375
struct ASTModuleChunk : RIFF::ListChunk
7476
{

source/slang/slang-serialize-ir.cpp

Lines changed: 146 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,102 @@
99
namespace Slang
1010
{
1111

12+
struct IRSerialWriter
13+
{
14+
typedef IRSerialData Ser;
15+
typedef IRSerialBinary Bin;
16+
17+
Result write(
18+
IRModule* module,
19+
SerialSourceLocWriter* sourceLocWriter,
20+
IRSerialData* serialData);
21+
22+
/// Write to a container
23+
static Result writeTo(const IRSerialData& data, RIFF::BuildCursor& cursor);
24+
25+
/// Get an instruction index from an instruction
26+
Ser::InstIndex getInstIndex(IRInst* inst) const
27+
{
28+
return inst ? Ser::InstIndex(m_instMap.getValue(inst)) : Ser::InstIndex(0);
29+
}
30+
31+
/// Get a slice from an index
32+
UnownedStringSlice getStringSlice(Ser::StringIndex index) const
33+
{
34+
return m_stringSlicePool.getSlice(StringSlicePool::Handle(index));
35+
}
36+
/// Get index from string representations
37+
Ser::StringIndex getStringIndex(StringRepresentation* string)
38+
{
39+
return Ser::StringIndex(m_stringSlicePool.add(string));
40+
}
41+
Ser::StringIndex getStringIndex(const UnownedStringSlice& slice)
42+
{
43+
return Ser::StringIndex(m_stringSlicePool.add(slice));
44+
}
45+
Ser::StringIndex getStringIndex(Name* name)
46+
{
47+
return name ? getStringIndex(name->text) : SerialStringData::kNullStringIndex;
48+
}
49+
Ser::StringIndex getStringIndex(const char* chars)
50+
{
51+
return Ser::StringIndex(m_stringSlicePool.add(chars));
52+
}
53+
Ser::StringIndex getStringIndex(const String& string)
54+
{
55+
return Ser::StringIndex(m_stringSlicePool.add(string.getUnownedSlice()));
56+
}
57+
58+
StringSlicePool& getStringPool() { return m_stringSlicePool; }
59+
60+
IRSerialWriter()
61+
: m_serialData(nullptr), m_stringSlicePool(StringSlicePool::Style::Default)
62+
{
63+
}
64+
65+
protected:
66+
void _addInstruction(IRInst* inst);
67+
Result _calcDebugInfo(SerialSourceLocWriter* sourceLocWriter);
68+
69+
List<IRInst*> m_insts; ///< Instructions in same order as stored in the
70+
71+
List<IRDecoration*>
72+
m_decorations; ///< Holds all decorations in order of the instructions as found
73+
List<IRInst*> m_instWithFirstDecoration; ///< All decorations are held in this order after all
74+
///< the regular instructions
75+
76+
Dictionary<IRInst*, Ser::InstIndex> m_instMap; ///< Map an instruction to an instruction index
77+
78+
StringSlicePool m_stringSlicePool;
79+
IRSerialData* m_serialData; ///< Where the data is stored
80+
};
81+
82+
struct IRSerialReader
83+
{
84+
typedef IRSerialData Ser;
85+
86+
/// Read a stream to fill in dataOut IRSerialData
87+
static Result readFrom(RIFF::ListChunk const* irModuleChunk, IRSerialData* outData);
88+
89+
/// Read a module from serial data
90+
Result read(
91+
const IRSerialData& data,
92+
Session* session,
93+
SerialSourceLocReader* sourceLocReader,
94+
RefPtr<IRModule>& outModule);
95+
96+
IRSerialReader()
97+
: m_serialData(nullptr), m_module(nullptr), m_stringTable(StringSlicePool::Style::Default)
98+
{
99+
}
100+
101+
protected:
102+
StringSlicePool m_stringTable;
103+
104+
const IRSerialData* m_serialData;
105+
IRModule* m_module;
106+
};
107+
12108
static bool _isConstant(IROp opIn)
13109
{
14110
const int op = (kIROpMask_OpMask & opIn);
@@ -359,37 +455,6 @@ Result _writeInstArrayChunk(
359455
return SLANG_OK;
360456
}
361457

362-
/* static */ void IRSerialWriter::calcInstructionList(IRModule* module, List<IRInst*>& instsOut)
363-
{
364-
// We reserve 0 for null
365-
instsOut.setCount(1);
366-
instsOut[0] = nullptr;
367-
368-
// Stack for parentInst
369-
List<IRInst*> parentInstStack;
370-
371-
IRModuleInst* moduleInst = module->getModuleInst();
372-
parentInstStack.add(moduleInst);
373-
374-
// Add to list
375-
instsOut.add(moduleInst);
376-
377-
// Traverse all of the instructions
378-
while (parentInstStack.getCount())
379-
{
380-
// If it's in the stack it is assumed it is already in the inst map
381-
IRInst* parentInst = parentInstStack.getLast();
382-
parentInstStack.removeLast();
383-
384-
IRInstListBase childrenList = parentInst->getDecorationsAndChildren();
385-
for (IRInst* child : childrenList)
386-
{
387-
instsOut.add(child);
388-
parentInstStack.add(child);
389-
}
390-
}
391-
}
392-
393458
// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! IRSerialReader !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
394459

395460
static Result _readInstArrayChunk(RIFF::DataChunk const* chunk, List<IRSerialData::Inst>& arrayOut)
@@ -399,7 +464,7 @@ static Result _readInstArrayChunk(RIFF::DataChunk const* chunk, List<IRSerialDat
399464
}
400465

401466
/* static */ Result IRSerialReader::readFrom(
402-
IRModuleChunk const* irModuleChunk,
467+
RIFF::ListChunk const* irModuleChunk,
403468
IRSerialData* outData)
404469
{
405470
typedef IRSerialBinary Bin;
@@ -758,4 +823,53 @@ Result IRSerialReader::read(
758823
return SLANG_OK;
759824
}
760825

826+
void writeSerializedModuleIR(
827+
RIFF::BuildCursor& cursor,
828+
IRModule* irModule,
829+
SerialSourceLocWriter* sourceLocWriter)
830+
{
831+
IRSerialData serialData;
832+
IRSerialWriter writer;
833+
834+
writer.write(irModule, sourceLocWriter, &serialData);
835+
IRSerialWriter::writeTo(serialData, cursor);
836+
}
837+
838+
SlangResult readSerializedModuleIR(
839+
RIFF::Chunk const* chunk,
840+
// [[maybe_unused]] ISlangBlob* blobHoldingSerializedData,
841+
Session* session,
842+
// [[maybe_unused]] DiagnosticSink* sink,
843+
SerialSourceLocReader* sourceLocReader,
844+
RefPtr<IRModule>& outIRModule)
845+
{
846+
// IR serialization still uses the older approach, where
847+
// data gets deserialized from the RIFF into an intermediate
848+
// data structure (`IRSerialData`), and then the actual
849+
// in-memory structures are created based on the intermediate.
850+
//
851+
// Thus we start by running the `IRSerialReader::readContainer`
852+
// logic to get the `IRSerialData` representation.
853+
//
854+
// TODO(tfoley): This should all get streamlined so that we
855+
// are deserializing IR nodes directly from the format written
856+
// into the RIFF.
857+
const auto moduleChunk = as<RIFF::ListChunk>(chunk);
858+
if (!moduleChunk)
859+
{
860+
SLANG_UNEXPECTED("invalid format for serialized module IR");
861+
}
862+
IRSerialData serialData;
863+
SLANG_RETURN_ON_FAIL(IRSerialReader::readFrom(moduleChunk, &serialData));
864+
865+
// Next we read the actual IR representation out from the
866+
// `serialData`. This is the step that may pull source-location
867+
// information from the provided `sourceLocReader`.
868+
//
869+
IRSerialReader reader;
870+
SLANG_RETURN_ON_FAIL(reader.read(serialData, session, sourceLocReader, outIRModule));
871+
872+
return SLANG_OK;
873+
}
874+
761875
} // namespace Slang

0 commit comments

Comments
 (0)