Skip to content

Commit 57c3f93

Browse files
Implement throw & catch statements (#6916)
* Implement throw statement It already existed in the IR, so only parsing, checking and lowering was missing. * Initial catch implementation Likely very broken. * Error out when catch() isn't last in scope * Prevent accessing variables from scope preceding catch As those may actually not be available at that point. * Add IError and use it in Result type lowering * Add diagnostic tests * Allow caught throws in non-throw functions * Fix catch propagating between functions & SPIR-V merge issue * Add test for non-trivial error types * Fix MSVC build * Fix invalid value type from Result lowering * Also lower error handling in templates * Lower result types only after specialization * Attempt to disambiguate error enums by witness table * Revert matching by witness, types should be distinct too * Don't assert valueField when getting Result's error value It may not exist if the function returns void, but getting the error value is still legitimate. * Update tests for new error numbers & get rid of expected.txt * Change catch lowering to resemble breaking a loop ... To make SPIR-V happy. * Fix dead catch blocks and invalid cached dominator tree * More SPIR-V adjustment * Lower catch as two nested loops * Add defer interaction test and revert broken defer changes * Fix enum type when throwing literals * Cleanup and bikeshedding * Document error handling mechanism * Fix table of contents * Use boolean tag in Result<T, E> * Use anyValue storage for Result<T,E> * Remove IError * Fix formatting * Eradicate success values from docs and tests * Use parseModernParamDecl for catch parameter * Implement do-catch syntax * Implement catch-all * Fix formatting * Fix marshalling native calls that throw --------- Co-authored-by: Yong He <[email protected]>
1 parent d108bfa commit 57c3f93

26 files changed

+1003
-176
lines changed

docs/user-guide/03-convenience-features.md

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -793,6 +793,71 @@ by using the `[ForceInline]` decoration:
793793
int f(int x) { return x + 1; }
794794
```
795795

796+
Error handling
797+
-----------------
798+
799+
Slang supports an error handling mechanism that is superficially similar to
800+
exceptions in many other languages, but has some unique characteristics.
801+
802+
In contrast to C++ exceptions, this mechanism makes the control flow of errors
803+
more explicit, and the performance charasteristics are similar to adding an
804+
if-statement after every potentially throwing function call to check and handle
805+
the error.
806+
807+
In order to be able to throw an error, a function must declare the type of that
808+
error with `throws`:
809+
```
810+
enum MyError
811+
{
812+
Failure,
813+
CatastrophicFailure
814+
}
815+
816+
int f() throws MyError
817+
{
818+
if (computerIsBroken())
819+
throw MyError.CatastrophicFailure;
820+
return 42;
821+
}
822+
```
823+
Currently, functions may only throw a single type of error.
824+
825+
To call a function that may throw, you must prepend it with `try`:
826+
827+
```
828+
let result = try f();
829+
```
830+
831+
If you don't catch the `try`, related errors are re-thrown and the calling
832+
function must declare that it `throws` that error type:
833+
834+
```
835+
void g() throws MyError
836+
{
837+
// This would not compile if `g()` wasn't declared to throw MyError as well.
838+
let result = try f();
839+
printf("Success: %d\n", result);
840+
}
841+
```
842+
843+
To catch an error, you can use a `do-catch` statement:
844+
845+
```
846+
void g()
847+
{
848+
do
849+
{
850+
let result = try f();
851+
printf("Success: %d\n", result);
852+
}
853+
catch(err: MyError)
854+
{
855+
printf("Not good!\n");
856+
}
857+
}
858+
```
859+
860+
You can chain multiple catch statements for different types of errors.
796861

797862
Special Scoping Syntax
798863
-------------------

docs/user-guide/toc.html

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@
4949
<li data-link="convenience-features#extensions"><span>Extensions</span></li>
5050
<li data-link="convenience-features#multi-level-break"><span>Multi-level break</span></li>
5151
<li data-link="convenience-features#force-inlining"><span>Force inlining</span></li>
52+
<li data-link="convenience-features#error-handling"><span>Error handling</span></li>
5253
<li data-link="convenience-features#special-scoping-syntax"><span>Special Scoping Syntax</span></li>
5354
<li data-link="convenience-features#user-defined-attributes-experimental"><span>User Defined Attributes (Experimental)</span></li>
5455
</ul>

source/slang/slang-ast-iterator.h

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -438,6 +438,20 @@ struct ASTIterator
438438
dispatchIfNotNull(stmt->statement);
439439
}
440440

441+
void visitThrowStmt(ThrowStmt* stmt)
442+
{
443+
iterator->maybeDispatchCallback(stmt);
444+
iterator->visitExpr(stmt->expression);
445+
}
446+
447+
void visitCatchStmt(CatchStmt* stmt)
448+
{
449+
if (stmt->errorVar)
450+
iterator->visitDecl(stmt->errorVar);
451+
dispatchIfNotNull(stmt->tryBody);
452+
dispatchIfNotNull(stmt->handleBody);
453+
}
454+
441455
void visitWhileStmt(WhileStmt* stmt)
442456
{
443457
iterator->maybeDispatchCallback(stmt);

source/slang/slang-ast-stmt.h

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -294,6 +294,22 @@ class DeferStmt : public Stmt
294294
FIDDLE() Stmt* statement = nullptr;
295295
};
296296

297+
FIDDLE()
298+
class ThrowStmt : public Stmt
299+
{
300+
FIDDLE(...)
301+
FIDDLE() Expr* expression = nullptr;
302+
};
303+
304+
FIDDLE()
305+
class CatchStmt : public Stmt
306+
{
307+
FIDDLE(...)
308+
FIDDLE() ParamDecl* errorVar = nullptr; // null => catch-all
309+
FIDDLE() Stmt* tryBody = nullptr;
310+
FIDDLE() Stmt* handleBody = nullptr;
311+
};
312+
297313
FIDDLE()
298314
class ExpressionStmt : public Stmt
299315
{

source/slang/slang-check-decl.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -856,6 +856,15 @@ struct SemanticsDeclReferenceVisitor : public SemanticsDeclVisitorBase,
856856

857857
void visitDeferStmt(DeferStmt* stmt) { dispatchIfNotNull(stmt->statement); }
858858

859+
void visitThrowStmt(ThrowStmt* stmt) { dispatchIfNotNull(stmt->expression); }
860+
861+
void visitCatchStmt(CatchStmt* stmt)
862+
{
863+
dispatchIfNotNull(stmt->errorVar);
864+
dispatchIfNotNull(stmt->tryBody);
865+
dispatchIfNotNull(stmt->handleBody);
866+
}
867+
859868
void visitWhileStmt(WhileStmt* stmt)
860869
{
861870
dispatchIfNotNull(stmt->predicate);

source/slang/slang-check-expr.cpp

Lines changed: 44 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -3951,45 +3951,64 @@ Expr* SemanticsExprVisitor::visitTryExpr(TryExpr* expr)
39513951
return expr;
39523952

39533953
auto parentFunc = this->m_parentFunc;
3954-
// TODO: check if the try clause is caught.
3955-
// For now we assume all `try`s are not caught (because we don't have catch yet).
3956-
if (!parentFunc)
3954+
auto base = as<InvokeExpr>(expr->base);
3955+
if (!base)
39573956
{
3958-
getSink()->diagnose(expr, Diagnostics::uncaughtTryCallInNonThrowFunc);
3957+
getSink()->diagnose(expr, Diagnostics::tryClauseMustApplyToInvokeExpr);
39593958
return expr;
39603959
}
3961-
if (parentFunc->errorType->equals(m_astBuilder->getBottomType()))
3960+
3961+
auto callee = as<DeclRefExpr>(base->functionExpr);
3962+
if (!callee)
39623963
{
3963-
getSink()->diagnose(expr, Diagnostics::uncaughtTryCallInNonThrowFunc);
3964+
getSink()->diagnose(expr, Diagnostics::calleeOfTryCallMustBeFunc);
39643965
return expr;
39653966
}
3966-
if (!as<InvokeExpr>(expr->base))
3967+
3968+
auto funcCallee = as<FuncDecl>(callee->declRef.getDecl());
3969+
Stmt* catchStmt = nullptr;
3970+
if (funcCallee)
39673971
{
3968-
getSink()->diagnose(expr, Diagnostics::tryClauseMustApplyToInvokeExpr);
3972+
if (funcCallee->errorType->equals(m_astBuilder->getBottomType()))
3973+
{
3974+
getSink()->diagnose(expr, Diagnostics::tryInvokeCalleeShouldThrow, callee->declRef);
3975+
return expr;
3976+
}
3977+
catchStmt = findMatchingCatchStmt(funcCallee->errorType);
3978+
}
3979+
3980+
if (FindOuterStmt<DeferStmt>(catchStmt))
3981+
{
3982+
// 'try' may jump outside a defer statement, which isn't allowed for
3983+
// now.
3984+
getSink()->diagnose(expr, Diagnostics::uncaughtTryInsideDefer);
39693985
return expr;
39703986
}
3971-
auto base = as<InvokeExpr>(expr->base);
3972-
if (auto callee = as<DeclRefExpr>(base->functionExpr))
3987+
3988+
if (!catchStmt)
39733989
{
3974-
if (auto funcCallee = as<FuncDecl>(callee->declRef.getDecl()))
3990+
// Uncaught try.
3991+
if (!parentFunc)
39753992
{
3976-
if (funcCallee->errorType->equals(m_astBuilder->getBottomType()))
3977-
{
3978-
getSink()->diagnose(expr, Diagnostics::tryInvokeCalleeShouldThrow, callee->declRef);
3979-
}
3980-
if (!parentFunc->errorType->equals(funcCallee->errorType))
3981-
{
3982-
getSink()->diagnose(
3983-
expr,
3984-
Diagnostics::errorTypeOfCalleeIncompatibleWithCaller,
3985-
callee->declRef,
3986-
funcCallee->errorType,
3987-
parentFunc->errorType);
3988-
}
3993+
getSink()->diagnose(expr, Diagnostics::uncaughtTryCallInNonThrowFunc);
3994+
return expr;
3995+
}
3996+
if (parentFunc->errorType->equals(m_astBuilder->getBottomType()))
3997+
{
3998+
getSink()->diagnose(expr, Diagnostics::uncaughtTryCallInNonThrowFunc);
3999+
return expr;
4000+
}
4001+
if (funcCallee && !parentFunc->errorType->equals(funcCallee->errorType))
4002+
{
4003+
getSink()->diagnose(
4004+
expr,
4005+
Diagnostics::errorTypeOfCalleeIncompatibleWithCaller,
4006+
callee->declRef,
4007+
funcCallee->errorType,
4008+
parentFunc->errorType);
39894009
return expr;
39904010
}
39914011
}
3992-
getSink()->diagnose(expr, Diagnostics::calleeOfTryCallMustBeFunc);
39934012
return expr;
39944013
}
39954014

source/slang/slang-check-impl.h

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1028,6 +1028,20 @@ struct SemanticsContext
10281028
return result;
10291029
}
10301030

1031+
template<typename T>
1032+
T* FindOuterStmt(Stmt* searchUntil = nullptr)
1033+
{
1034+
for (auto outerStmtInfo = m_outerStmts; outerStmtInfo && outerStmtInfo->stmt != searchUntil;
1035+
outerStmtInfo = outerStmtInfo->next)
1036+
{
1037+
auto outerStmt = outerStmtInfo->stmt;
1038+
auto found = as<T>(outerStmt);
1039+
if (found)
1040+
return found;
1041+
}
1042+
return nullptr;
1043+
}
1044+
10311045
// Setup the flag to indicate disabling the short-circuiting evaluation
10321046
// for the logical expressions associted with the subcontext
10331047
SemanticsContext disableShortCircuitLogicalExpr()
@@ -2867,6 +2881,8 @@ struct SemanticsVisitor : public SemanticsContext
28672881
void addVisibilityModifier(Decl* decl, DeclVisibility vis);
28682882

28692883
void checkRayPayloadStructFields(StructDecl* structDecl);
2884+
2885+
CatchStmt* findMatchingCatchStmt(Type* errorType);
28702886
};
28712887

28722888

@@ -3011,9 +3027,6 @@ struct SemanticsStmtVisitor : public SemanticsVisitor, StmtVisitor<SemanticsStmt
30113027

30123028
void checkStmt(Stmt* stmt);
30133029

3014-
template<typename T>
3015-
T* FindOuterStmt(Stmt* searchUntil = nullptr);
3016-
30173030
Stmt* findOuterStmtWithLabel(Name* label);
30183031

30193032
void visitDeclStmt(DeclStmt* stmt);
@@ -3058,6 +3071,10 @@ struct SemanticsStmtVisitor : public SemanticsVisitor, StmtVisitor<SemanticsStmt
30583071

30593072
void visitDeferStmt(DeferStmt* stmt);
30603073

3074+
void visitThrowStmt(ThrowStmt* stmt);
3075+
3076+
void visitCatchStmt(CatchStmt* stmt);
3077+
30613078
void visitWhileStmt(WhileStmt* stmt);
30623079

30633080
void visitGpuForeachStmt(GpuForeachStmt* stmt);

source/slang/slang-check-stmt.cpp

Lines changed: 62 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,19 @@ void SemanticsVisitor::checkStmt(Stmt* stmt, SemanticsContext const& context)
4141
checkModifiers(stmt);
4242
}
4343

44+
CatchStmt* SemanticsVisitor::findMatchingCatchStmt(Type* errorType)
45+
{
46+
for (auto outerStmtInfo = m_outerStmts; outerStmtInfo; outerStmtInfo = outerStmtInfo->next)
47+
{
48+
if (auto catchStmt = as<CatchStmt>(outerStmtInfo->stmt))
49+
{
50+
if (!catchStmt->errorVar || catchStmt->errorVar->getType()->equals(errorType))
51+
return catchStmt;
52+
}
53+
}
54+
return nullptr;
55+
}
56+
4457
void SemanticsStmtVisitor::visitDeclStmt(DeclStmt* stmt)
4558
{
4659
// When we encounter a declaration during statement checking,
@@ -118,20 +131,6 @@ void SemanticsStmtVisitor::checkStmt(Stmt* stmt)
118131
SemanticsVisitor::checkStmt(stmt, *this);
119132
}
120133

121-
template<typename T>
122-
T* SemanticsStmtVisitor::FindOuterStmt(Stmt* searchUntil)
123-
{
124-
for (auto outerStmtInfo = m_outerStmts; outerStmtInfo && outerStmtInfo->stmt != searchUntil;
125-
outerStmtInfo = outerStmtInfo->next)
126-
{
127-
auto outerStmt = outerStmtInfo->stmt;
128-
auto found = as<T>(outerStmt);
129-
if (found)
130-
return found;
131-
}
132-
return nullptr;
133-
}
134-
135134
Stmt* SemanticsStmtVisitor::findOuterStmtWithLabel(Name* label)
136135
{
137136
for (auto outerStmtInfo = m_outerStmts; outerStmtInfo; outerStmtInfo = outerStmtInfo->next)
@@ -616,6 +615,55 @@ void SemanticsStmtVisitor::visitDeferStmt(DeferStmt* stmt)
616615
subContext.checkStmt(stmt->statement);
617616
}
618617

618+
void SemanticsStmtVisitor::visitThrowStmt(ThrowStmt* stmt)
619+
{
620+
stmt->expression = CheckTerm(stmt->expression);
621+
Stmt* catchStmt = findMatchingCatchStmt(stmt->expression->type);
622+
623+
auto parentFunc = getParentFunc();
624+
if (!catchStmt && (!parentFunc || parentFunc->errorType->equals(m_astBuilder->getBottomType())))
625+
{
626+
getSink()->diagnose(stmt, Diagnostics::uncaughtThrowInNonThrowFunc);
627+
return;
628+
}
629+
630+
if (!catchStmt && !stmt->expression->type->equals(m_astBuilder->getErrorType()))
631+
{
632+
if (!parentFunc->errorType->equals(stmt->expression->type))
633+
{
634+
getSink()->diagnose(
635+
stmt->expression,
636+
Diagnostics::throwTypeIncompatibleWithErrorType,
637+
stmt->expression->type,
638+
parentFunc->errorType);
639+
}
640+
}
641+
642+
if (FindOuterStmt<DeferStmt>(catchStmt))
643+
{
644+
// Allowing 'throw' to escape a defer statement gets quite complex, for
645+
// similar reasons as 'return' - if you have two (or more) defers,
646+
// both of which exit the outer scope, it's unclear which one gets
647+
// called and when. Both can't fully run. That kind of goes against the
648+
// point of 'defer', which is to _always_ run some code when exiting
649+
// scopes.
650+
getSink()->diagnose(stmt, Diagnostics::uncaughtThrowInsideDefer);
651+
}
652+
}
653+
654+
void SemanticsStmtVisitor::visitCatchStmt(CatchStmt* stmt)
655+
{
656+
if (stmt->errorVar)
657+
{
658+
ensureDeclBase(stmt->errorVar, DeclCheckState::DefinitionChecked, this);
659+
stmt->errorVar->hiddenFromLookup = false;
660+
}
661+
662+
WithOuterStmt subContext(this, stmt);
663+
subContext.checkStmt(stmt->tryBody);
664+
subContext.checkStmt(stmt->handleBody);
665+
}
666+
619667
void SemanticsStmtVisitor::visitExpressionStmt(ExpressionStmt* stmt)
620668
{
621669
stmt->expression = CheckExpr(stmt->expression);

0 commit comments

Comments
 (0)