Skip to content

Commit 48c26f8

Browse files
committed
Merge branch 'main' of github.com:EmmyLuaLs/emmylua-analyzer-rust
2 parents 467155c + b3b55ee commit 48c26f8

File tree

9 files changed

+230
-22
lines changed

9 files changed

+230
-22
lines changed

crates/emmylua_code_analysis/src/compilation/analyzer/doc/type_ref_tags.rs

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -250,19 +250,38 @@ pub fn analyze_return_cast(analyzer: &mut DocAnalyzer, tag: LuaDocTagReturnCast)
250250
if let Some(LuaSemanticDeclId::Signature(signature_id)) = get_owner_id(analyzer) {
251251
let name_token = tag.get_name_token()?;
252252
let name = name_token.get_name_text();
253-
let cast_op_type = tag.get_op_type()?;
253+
254+
let op_types: Vec<_> = tag.get_op_types().collect();
255+
let cast_op_type = op_types.first()?;
256+
257+
// Bind the true condition type
254258
if let Some(node_type) = cast_op_type.get_type() {
255259
let typ = infer_type(analyzer, node_type.clone());
256260
let infiled_syntax_id = InFiled::new(analyzer.file_id, node_type.get_syntax_id());
257261
let type_owner = LuaTypeOwner::SyntaxId(infiled_syntax_id);
258262
bind_type(analyzer.db, type_owner, LuaTypeCache::DocType(typ));
259263
};
260264

265+
// Bind the false condition type if present
266+
let fallback_cast = if op_types.len() > 1 {
267+
let fallback_op_type = &op_types[1];
268+
if let Some(node_type) = fallback_op_type.get_type() {
269+
let typ = infer_type(analyzer, node_type.clone());
270+
let infiled_syntax_id = InFiled::new(analyzer.file_id, node_type.get_syntax_id());
271+
let type_owner = LuaTypeOwner::SyntaxId(infiled_syntax_id);
272+
bind_type(analyzer.db, type_owner, LuaTypeCache::DocType(typ));
273+
}
274+
Some(fallback_op_type.to_ptr())
275+
} else {
276+
None
277+
};
278+
261279
analyzer.db.get_flow_index_mut().add_signature_cast(
262280
analyzer.file_id,
263281
signature_id,
264282
name.to_string(),
265283
cast_op_type.to_ptr(),
284+
fallback_cast,
266285
);
267286
} else {
268287
report_orphan_tag(analyzer, &tag);

crates/emmylua_code_analysis/src/compilation/test/flow.rs

Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1355,4 +1355,108 @@ _2 = a[1]
13551355
"#
13561356
));
13571357
}
1358+
1359+
#[test]
1360+
fn test_return_cast_with_fallback() {
1361+
let mut ws = VirtualWorkspace::new();
1362+
1363+
ws.def(
1364+
r#"
1365+
---@class Creature
1366+
1367+
---@class Player: Creature
1368+
1369+
---@class Monster: Creature
1370+
1371+
---@return boolean
1372+
---@return_cast creature Player else Monster
1373+
local function isPlayer(creature)
1374+
return true
1375+
end
1376+
1377+
local creature ---@type Creature
1378+
1379+
if isPlayer(creature) then
1380+
a = creature
1381+
else
1382+
b = creature
1383+
end
1384+
"#,
1385+
);
1386+
1387+
let a = ws.expr_ty("a");
1388+
let a_expected = ws.ty("Player");
1389+
assert_eq!(a, a_expected);
1390+
1391+
let b = ws.expr_ty("b");
1392+
let b_expected = ws.ty("Monster");
1393+
assert_eq!(b, b_expected);
1394+
}
1395+
1396+
#[test]
1397+
fn test_return_cast_with_fallback_self() {
1398+
let mut ws = VirtualWorkspace::new();
1399+
1400+
ws.def(
1401+
r#"
1402+
---@class Creature
1403+
1404+
---@class Player: Creature
1405+
1406+
---@class Monster: Creature
1407+
local m = {}
1408+
1409+
---@return boolean
1410+
---@return_cast self Player else Monster
1411+
function m:isPlayer()
1412+
end
1413+
1414+
if m:isPlayer() then
1415+
a = m
1416+
else
1417+
b = m
1418+
end
1419+
"#,
1420+
);
1421+
1422+
let a = ws.expr_ty("a");
1423+
let a_expected = ws.ty("Player");
1424+
assert_eq!(a, a_expected);
1425+
1426+
let b = ws.expr_ty("b");
1427+
let b_expected = ws.ty("Monster");
1428+
assert_eq!(b, b_expected);
1429+
}
1430+
1431+
#[test]
1432+
fn test_return_cast_backward_compatibility() {
1433+
let mut ws = VirtualWorkspace::new();
1434+
1435+
ws.def(
1436+
r#"
1437+
---@return boolean
1438+
---@return_cast n integer
1439+
local function isInteger(n)
1440+
return true
1441+
end
1442+
1443+
local a ---@type integer | string
1444+
1445+
if isInteger(a) then
1446+
d = a
1447+
else
1448+
e = a
1449+
end
1450+
"#,
1451+
);
1452+
1453+
let d = ws.expr_ty("d");
1454+
let d_expected = ws.ty("integer");
1455+
assert_eq!(d, d_expected);
1456+
1457+
// Should still use the original behavior (remove integer from union)
1458+
let e = ws.expr_ty("e");
1459+
let e_expected = ws.ty("string");
1460+
assert_eq!(e, e_expected);
1461+
}
13581462
}

crates/emmylua_code_analysis/src/db_index/flow/mod.rs

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,11 +52,19 @@ impl LuaFlowIndex {
5252
signature_id: LuaSignatureId,
5353
name: String,
5454
cast: LuaAstPtr<LuaDocOpType>,
55+
fallback_cast: Option<LuaAstPtr<LuaDocOpType>>,
5556
) {
5657
self.signature_cast_cache
5758
.entry(file_id)
5859
.or_default()
59-
.insert(signature_id, LuaSignatureCast { name, cast });
60+
.insert(
61+
signature_id,
62+
LuaSignatureCast {
63+
name,
64+
cast,
65+
fallback_cast,
66+
},
67+
);
6068
}
6169
}
6270

crates/emmylua_code_analysis/src/db_index/flow/signature_cast.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,4 +4,5 @@ use emmylua_parser::{LuaAstPtr, LuaDocOpType};
44
pub struct LuaSignatureCast {
55
pub name: String,
66
pub cast: LuaAstPtr<LuaDocOpType>,
7+
pub fallback_cast: Option<LuaAstPtr<LuaDocOpType>>,
78
}

crates/emmylua_code_analysis/src/semantic/infer/narrow/condition_flow/call_flow.rs

Lines changed: 84 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -228,17 +228,50 @@ fn get_type_at_call_expr_by_signature_self(
228228
};
229229

230230
let signature_root = syntax_tree.get_chunk_node();
231-
let Some(cast_op_type) = signature_cast.cast.to_node(&signature_root) else {
232-
return Ok(ResultTypeOrContinue::Continue);
231+
232+
// Choose the appropriate cast based on condition_flow and whether fallback exists
233+
let result_type = match condition_flow {
234+
InferConditionFlow::TrueCondition => {
235+
let Some(cast_op_type) = signature_cast.cast.to_node(&signature_root) else {
236+
return Ok(ResultTypeOrContinue::Continue);
237+
};
238+
cast_type(
239+
db,
240+
signature_id.get_file_id(),
241+
cast_op_type,
242+
antecedent_type,
243+
condition_flow,
244+
)?
245+
}
246+
InferConditionFlow::FalseCondition => {
247+
// Use fallback_cast if available, otherwise use the default behavior
248+
if let Some(fallback_cast_ptr) = &signature_cast.fallback_cast {
249+
let Some(fallback_op_type) = fallback_cast_ptr.to_node(&signature_root) else {
250+
return Ok(ResultTypeOrContinue::Continue);
251+
};
252+
cast_type(
253+
db,
254+
signature_id.get_file_id(),
255+
fallback_op_type,
256+
antecedent_type.clone(),
257+
InferConditionFlow::TrueCondition, // Apply fallback as force cast
258+
)?
259+
} else {
260+
// Original behavior: remove the true type from antecedent
261+
let Some(cast_op_type) = signature_cast.cast.to_node(&signature_root) else {
262+
return Ok(ResultTypeOrContinue::Continue);
263+
};
264+
cast_type(
265+
db,
266+
signature_id.get_file_id(),
267+
cast_op_type,
268+
antecedent_type,
269+
condition_flow,
270+
)?
271+
}
272+
}
233273
};
234274

235-
let result_type = cast_type(
236-
db,
237-
signature_id.get_file_id(),
238-
cast_op_type,
239-
antecedent_type,
240-
condition_flow,
241-
)?;
242275
Ok(ResultTypeOrContinue::Result(result_type))
243276
}
244277

@@ -304,17 +337,50 @@ fn get_type_at_call_expr_by_signature_param_name(
304337
};
305338

306339
let signature_root = syntax_tree.get_chunk_node();
307-
let Some(cast_op_type) = signature_cast.cast.to_node(&signature_root) else {
308-
return Ok(ResultTypeOrContinue::Continue);
340+
341+
// Choose the appropriate cast based on condition_flow and whether fallback exists
342+
let result_type = match condition_flow {
343+
InferConditionFlow::TrueCondition => {
344+
let Some(cast_op_type) = signature_cast.cast.to_node(&signature_root) else {
345+
return Ok(ResultTypeOrContinue::Continue);
346+
};
347+
cast_type(
348+
db,
349+
signature_id.get_file_id(),
350+
cast_op_type,
351+
antecedent_type,
352+
condition_flow,
353+
)?
354+
}
355+
InferConditionFlow::FalseCondition => {
356+
// Use fallback_cast if available, otherwise use the default behavior
357+
if let Some(fallback_cast_ptr) = &signature_cast.fallback_cast {
358+
let Some(fallback_op_type) = fallback_cast_ptr.to_node(&signature_root) else {
359+
return Ok(ResultTypeOrContinue::Continue);
360+
};
361+
cast_type(
362+
db,
363+
signature_id.get_file_id(),
364+
fallback_op_type,
365+
antecedent_type.clone(),
366+
InferConditionFlow::TrueCondition, // Apply fallback as force cast
367+
)?
368+
} else {
369+
// Original behavior: remove the true type from antecedent
370+
let Some(cast_op_type) = signature_cast.cast.to_node(&signature_root) else {
371+
return Ok(ResultTypeOrContinue::Continue);
372+
};
373+
cast_type(
374+
db,
375+
signature_id.get_file_id(),
376+
cast_op_type,
377+
antecedent_type,
378+
condition_flow,
379+
)?
380+
}
381+
}
309382
};
310383

311-
let result_type = cast_type(
312-
db,
313-
signature_id.get_file_id(),
314-
cast_op_type,
315-
antecedent_type,
316-
condition_flow,
317-
)?;
318384
Ok(ResultTypeOrContinue::Result(result_type))
319385
}
320386

crates/emmylua_parser/src/grammar/doc/tag.rs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -353,13 +353,21 @@ fn parse_tag_return(p: &mut LuaDocParser) -> DocParseResult {
353353
}
354354

355355
// ---@return_cast <param name> <type>
356+
// ---@return_cast <param name> <true_type> else <false_type>
356357
fn parse_tag_return_cast(p: &mut LuaDocParser) -> DocParseResult {
357358
p.set_state(LuaDocLexerState::Normal);
358359
let m = p.mark(LuaSyntaxKind::DocTagReturnCast);
359360
p.bump();
360361
expect_token(p, LuaTokenKind::TkName)?;
361362

362363
parse_op_type(p)?;
364+
365+
// Allow optional second type after 'else' for false condition
366+
if p.current_token() == LuaTokenKind::TkDocElse {
367+
p.bump();
368+
parse_op_type(p)?;
369+
}
370+
363371
p.set_state(LuaDocLexerState::Description);
364372
parse_description(p);
365373
Ok(m.complete(p))

crates/emmylua_parser/src/kind/lua_token_kind.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,7 @@ pub enum LuaTokenKind {
143143
TkDocAs, // as
144144
TkDocIn, // in
145145
TkDocInfer, // infer
146+
TkDocElse, // else (for return_cast)
146147
TkDocContinue, // ---
147148
TkDocContinueOr, // ---| or ---|+ or ---|>
148149
TkDocDetail, // a description

crates/emmylua_parser/src/lexer/lua_doc_lexer.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -638,6 +638,7 @@ fn to_token_or_name(text: &str) -> LuaTokenKind {
638638
"as" => LuaTokenKind::TkDocAs,
639639
"and" => LuaTokenKind::TkAnd,
640640
"or" => LuaTokenKind::TkOr,
641+
"else" => LuaTokenKind::TkDocElse,
641642
_ => LuaTokenKind::TkName,
642643
}
643644
}

crates/emmylua_parser/src/syntax/node/doc/tag.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1458,8 +1458,8 @@ impl LuaAstNode for LuaDocTagReturnCast {
14581458
impl LuaDocDescriptionOwner for LuaDocTagReturnCast {}
14591459

14601460
impl LuaDocTagReturnCast {
1461-
pub fn get_op_type(&self) -> Option<LuaDocOpType> {
1462-
self.child()
1461+
pub fn get_op_types(&self) -> LuaAstChildren<LuaDocOpType> {
1462+
self.children()
14631463
}
14641464

14651465
pub fn get_name_token(&self) -> Option<LuaNameToken> {

0 commit comments

Comments
 (0)