Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions src/larktools/ebnf_grammar.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,7 @@
//
// The top level rule at which the matching/expansion process starts is named "start".

start: assign_var

assign_var: VARNAME "=" arith_expr
start: multi_line_block

variable: VARNAME ("[" INDEX "]")*
VARNAME: LETTER (LETTER | DIGIT)*
Expand All @@ -27,9 +25,11 @@
// https://lark-parser.readthedocs.io/en/stable/tree_construction.html


line: arith_expr

multi_line_block: (line _NL? | _NL )*
line: arith_expr | assignment

assignment: VARNAME "=" arith_expr


arith_expr: sum
sum: product | addition | subtraction
Expand Down
14 changes: 14 additions & 0 deletions src/larktools/evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,8 @@ def eval_line(node, env):
child_name = get_name(child)
if child_name == "arith_expr":
return eval_arith_expr(child, env)
elif child_name == "assignment":
return eval_assignment(child, env)

def eval_multi_line_block(node, env):
# this can be either an arithmetic expression or
Expand All @@ -128,6 +130,18 @@ def eval_multi_line_block(node, env):
res = eval_line(child, env)
return res


def eval_assignment(node, env):
# assign result of an expression to a variable
child1, child2 = get_children(node)[0:2]
assert get_name(child1) == "VARNAME"
assert get_name(child2) == "arith_expr"

varname = get_value(child1)
env[varname] = eval_arith_expr(child2, env)
return env[varname]


def eval_variable(node, env):
children = get_children(node)
assert get_name(children[0]) == "VARNAME"
Expand Down
25 changes: 23 additions & 2 deletions tests/test_syntax.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,9 @@ def parse_and_eval(self, expression: str, env: Optional[Union[None, dict]] = Non
return res


def _parse_and_assert(expression: str, expected: Union[int, float]) -> None:
def _parse_and_assert(expression: str, expected: Union[int, float], env: Optional[Union[None, dict]] = None) -> None:
parser = SyntaxParser()
res = parser.parse_and_eval(expression)
res = parser.parse_and_eval(expression, env)
assert expected == res

def test_multi_line():
Expand All @@ -30,3 +30,24 @@ def test_multi_line():
_parse_and_assert("5+5\n3+4\n1+2", 3)
_parse_and_assert("\n\n5\n\n3\n8", 8)


def test_assignment():
_parse_and_assert("a=5", 5)
_parse_and_assert("z=1+2+3", 6)
_parse_and_assert("y=(1+2+3)", 6)

def test_assignment_env_variable():
# check env variables is set
env = {"a":1}
_parse_and_assert("a=3", 3, env=env)
assert env["a"] == 3

_parse_and_assert("y = x + 3", 20, env={"x":17, "i":123})
_parse_and_assert("y = x + i", 20, env={"x":17, "i":3})

def test_assign_multiline():
_parse_and_assert("x=3 \n y=4 \n z=x+y",7)
_parse_and_assert("x=1 \n z = x + y", 3, env={"y":2})