@@ -857,6 +857,7 @@ class Context {
857857 Kind TheKind;
858858 Optional<AnyFunctionRef> Function;
859859 bool HandlesErrors = false ;
860+ bool HandlesAsync = false ;
860861
861862 // / Whether error-handling queries should ignore the function context, e.g.,
862863 // / for autoclosure and rethrows checks.
@@ -870,9 +871,10 @@ class Context {
870871 assert (TheKind != Kind::PotentiallyHandled);
871872 }
872873
873- explicit Context (bool handlesErrors, Optional<AnyFunctionRef> function)
874+ explicit Context (bool handlesErrors, bool handlesAsync,
875+ Optional<AnyFunctionRef> function)
874876 : TheKind(Kind::PotentiallyHandled), Function(function),
875- HandlesErrors(handlesErrors) { }
877+ HandlesErrors(handlesErrors), HandlesAsync(handlesAsync) { }
876878
877879public:
878880 // / Whether this is a function that rethrows.
@@ -910,7 +912,7 @@ class Context {
910912
911913 static Context forTopLevelCode (TopLevelCodeDecl *D) {
912914 // Top-level code implicitly handles errors and 'async' calls.
913- return Context (/* handlesErrors=*/ true , None);
915+ return Context (/* handlesErrors=*/ true , /* handlesAsync= */ true , None);
914916 }
915917
916918 static Context forFunction (AbstractFunctionDecl *D) {
@@ -930,8 +932,7 @@ class Context {
930932 }
931933 }
932934
933- bool handlesErrors = D->hasThrows ();
934- return Context (handlesErrors, AnyFunctionRef (D));
935+ return Context (D->hasThrows (), D->hasAsync (), AnyFunctionRef (D));
935936 }
936937
937938 static Context forDeferBody () {
@@ -956,12 +957,15 @@ class Context {
956957 static Context forClosure (AbstractClosureExpr *E) {
957958 // Determine whether the closure has throwing function type.
958959 bool closureTypeThrows = true ;
960+ bool closureTypeIsAsync = true ;
959961 if (auto closureType = E->getType ()) {
960- if (auto fnType = closureType->getAs <AnyFunctionType>())
962+ if (auto fnType = closureType->getAs <AnyFunctionType>()) {
961963 closureTypeThrows = fnType->isThrowing ();
964+ closureTypeIsAsync = fnType->isAsync ();
965+ }
962966 }
963967
964- return Context (closureTypeThrows, AnyFunctionRef (E));
968+ return Context (closureTypeThrows, closureTypeIsAsync, AnyFunctionRef (E));
965969 }
966970
967971 static Context forCatchPattern (CaseStmt *S) {
@@ -1013,6 +1017,10 @@ class Context {
10131017 llvm_unreachable (" bad error kind" );
10141018 }
10151019
1020+ bool handlesAsync () const {
1021+ return HandlesAsync;
1022+ }
1023+
10161024 DeclContext *getRethrowsDC () const {
10171025 if (!isRethrows ())
10181026 return nullptr ;
@@ -1182,7 +1190,6 @@ class Context {
11821190 case Kind::DeferBody:
11831191 diagnoseThrowInIllegalContext (Diags, E, getKind ());
11841192 return ;
1185-
11861193 }
11871194 llvm_unreachable (" bad context kind" );
11881195 }
@@ -1211,6 +1218,64 @@ class Context {
12111218 }
12121219 llvm_unreachable (" bad context kind" );
12131220 }
1221+
1222+ void diagnoseUncoveredAsyncSite (ASTContext &ctx, ASTNode node) {
1223+ SourceRange highlight;
1224+
1225+ // Generate more specific messages in some cases.
1226+ if (auto apply = dyn_cast_or_null<ApplyExpr>(node.dyn_cast <Expr*>()))
1227+ highlight = apply->getSourceRange ();
1228+
1229+ auto diag = diag::async_call_without_await;
1230+ if (isAutoClosure ())
1231+ diag = diag::async_call_without_await_in_autoclosure;
1232+ ctx.Diags .diagnose (node.getStartLoc (), diag)
1233+ .highlight (highlight);
1234+ }
1235+
1236+ void diagnoseAsyncInIllegalContext (DiagnosticEngine &Diags, ASTNode node) {
1237+ if (auto *e = node.dyn_cast <Expr*>()) {
1238+ if (isa<ApplyExpr>(e)) {
1239+ Diags.diagnose (e->getLoc (), diag::async_call_in_illegal_context,
1240+ static_cast <unsigned >(getKind ()));
1241+ return ;
1242+ }
1243+ }
1244+
1245+ Diags.diagnose (node.getStartLoc (), diag::await_in_illegal_context,
1246+ static_cast <unsigned >(getKind ()));
1247+ }
1248+
1249+ void maybeAddAsyncNote (DiagnosticEngine &Diags) {
1250+ if (!Function)
1251+ return ;
1252+
1253+ auto func = dyn_cast_or_null<FuncDecl>(Function->getAbstractFunctionDecl ());
1254+ if (!func)
1255+ return ;
1256+
1257+ func->diagnose (diag::note_add_async_to_function, func->getName ());
1258+ }
1259+
1260+ void diagnoseUnhandledAsyncSite (DiagnosticEngine &Diags, ASTNode node) {
1261+ switch (getKind ()) {
1262+ case Kind::PotentiallyHandled:
1263+ Diags.diagnose (node.getStartLoc (), diag::async_in_nonasync_function,
1264+ node.isExpr (ExprKind::Await), isAutoClosure ());
1265+ maybeAddAsyncNote (Diags);
1266+ return ;
1267+
1268+ case Kind::EnumElementInitializer:
1269+ case Kind::GlobalVarInitializer:
1270+ case Kind::IVarInitializer:
1271+ case Kind::DefaultArgument:
1272+ case Kind::CatchPattern:
1273+ case Kind::CatchGuard:
1274+ case Kind::DeferBody:
1275+ diagnoseAsyncInIllegalContext (Diags, node);
1276+ return ;
1277+ }
1278+ }
12141279};
12151280
12161281// / A class to walk over a local context and validate the correctness
@@ -1322,6 +1387,12 @@ class CheckEffectsCoverage : public EffectsHandlingWalker<CheckEffectsCoverage>
13221387 Self.MaxThrowingKind = ThrowingKind::None;
13231388 }
13241389
1390+ void resetCoverageForAutoclosureBody () {
1391+ Self.Flags .clear (ContextFlags::IsAsyncCovered);
1392+ Self.Flags .clear (ContextFlags::HasAnyAsyncSite);
1393+ Self.Flags .clear (ContextFlags::HasAnyAwait);
1394+ }
1395+
13251396 void resetCoverageForDoCatch () {
13261397 Self.Flags .reset ();
13271398 Self.MaxThrowingKind = ThrowingKind::None;
@@ -1409,6 +1480,7 @@ class CheckEffectsCoverage : public EffectsHandlingWalker<CheckEffectsCoverage>
14091480 ShouldRecurse_t checkAutoClosure (AutoClosureExpr *E) {
14101481 ContextScope scope (*this , Context::forClosure (E));
14111482 scope.enterSubFunction ();
1483+ scope.resetCoverageForAutoclosureBody ();
14121484 E->getBody ()->walk (*this );
14131485 scope.preserveCoverageFromAutoclosureBody ();
14141486 return ShouldNotRecurse;
@@ -1572,17 +1644,14 @@ class CheckEffectsCoverage : public EffectsHandlingWalker<CheckEffectsCoverage>
15721644 if (classification.isAsync ()) {
15731645 // Remember that we've seen an async call.
15741646 Flags.set (ContextFlags::HasAnyAsyncSite);
1575-
1647+
1648+ // Diagnose async calls in a context that doesn't handle async.
1649+ if (!CurContext.handlesAsync ()) {
1650+ CurContext.diagnoseUnhandledAsyncSite (Ctx.Diags , E);
1651+ }
15761652 // Diagnose async calls that are outside of an await context.
1577- if (!Flags.has (ContextFlags::IsAsyncCovered)) {
1578- SourceRange highlight;
1579-
1580- // Generate more specific messages in some cases.
1581- if (auto e = dyn_cast_or_null<ApplyExpr>(E.dyn_cast <Expr*>()))
1582- highlight = e->getSourceRange ();
1583-
1584- Ctx.Diags .diagnose (E.getStartLoc (), diag::async_call_without_await)
1585- .highlight (highlight);
1653+ else if (!Flags.has (ContextFlags::IsAsyncCovered)) {
1654+ CurContext.diagnoseUncoveredAsyncSite (Ctx, E);
15861655 }
15871656 }
15881657
@@ -1626,10 +1695,16 @@ class CheckEffectsCoverage : public EffectsHandlingWalker<CheckEffectsCoverage>
16261695 scope.enterAwait ();
16271696
16281697 E->getSubExpr ()->walk (*this );
1629-
1630- // Warn about 'await' expressions that weren't actually needed.
1631- if (!Flags.has (ContextFlags::HasAnyAsyncSite))
1632- Ctx.Diags .diagnose (E->getAwaitLoc (), diag::no_async_in_await);
1698+
1699+ // Warn about 'await' expressions that weren't actually needed, unless of
1700+ // course we're in a context that could never handle an 'async'. Then, we
1701+ // produce an error.
1702+ if (!Flags.has (ContextFlags::HasAnyAsyncSite)) {
1703+ if (CurContext.handlesAsync ())
1704+ Ctx.Diags .diagnose (E->getAwaitLoc (), diag::no_async_in_await);
1705+ else
1706+ CurContext.diagnoseUnhandledAsyncSite (Ctx.Diags , E);
1707+ }
16331708
16341709 // Inform the parent of the walk that an 'await' exists here.
16351710 scope.preserveCoverageFromAwaitOperand ();
0 commit comments