Skip to content

Commit 7a23ec2

Browse files
authored
fix(server): fix memory leak on lua error (#4236)
The bug: calling lua_error does not return; instead it unwinds the Lua call stack until an error handler is found or the script exits. This led to a memory leak for objects that should release their memory in a destructor. A specific example is absl::FixedArray<string_view, 4> args(argc), which allocates on the heap if argc > 4. The free was never called, leading to a memory leak. The fix: add scoping to the function so that the destructor is called before raising the error. Signed-off-by: adi_holden <[email protected]>
1 parent 1c09056 commit 7a23ec2

File tree

3 files changed

+116
-45
lines changed

3 files changed

+116
-45
lines changed

src/core/interpreter.cc

Lines changed: 65 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -256,8 +256,10 @@ void SetGlobalArrayInternal(lua_State* lua, const char* name, Interpreter::Slice
256256
/* In case the error set into the Lua stack by PushError() was generated
257257
* by the non-error-trapping version of redis.pcall(), which is redis.call(),
258258
* this function will raise the Lua error so that the execution of the
259-
* script will be halted. */
260-
int RaiseError(lua_State* lua) {
259+
* script will be halted.
260+
* This function never returns, it unwinds the Lua call stack until an error handler is found or the
261+
* script exits */
262+
int RaiseErrorAndAbort(lua_State* lua) {
261263
lua_pushstring(lua, "err");
262264
lua_gettable(lua, -2);
263265
return lua_error(lua);
@@ -467,7 +469,7 @@ int RedisLogCommand(lua_State* lua) {
467469
int argc = lua_gettop(lua);
468470
if (argc < 2) {
469471
PushError(lua, "redis.log() requires two arguments or more.");
470-
return RaiseError(lua);
472+
return RaiseErrorAndAbort(lua);
471473
}
472474

473475
return 0;
@@ -891,40 +893,12 @@ void Interpreter::RunGC() {
891893
lua_gc(lua_, LUA_GCCOLLECT);
892894
}
893895

894-
// Returns number of results, which is always 1 in this case.
895-
// Please note that lua resets the stack once the function returns so no need
896-
// to unwind the stack manually in the function (though lua allows doing this).
897-
int Interpreter::RedisGenericCommand(bool raise_error, bool async, ObjectExplorer* explorer) {
898-
/* By using Lua debug hooks it is possible to trigger a recursive call
899-
* to luaRedisGenericCommand(), which normally should never happen.
900-
* To make this function reentrant is futile and makes it slower, but
901-
* we should at least detect such a misuse, and abort. */
902-
if (cmd_depth_) {
903-
const char* recursion_warning =
904-
"luaRedisGenericCommand() recursive call detected. "
905-
"Are you doing funny stuff with Lua debug hooks?";
906-
PushError(lua_, recursion_warning);
907-
return 1;
908-
}
909-
910-
if (!redis_func_) {
911-
PushError(lua_, "internal error - redis function not defined");
912-
return raise_error ? RaiseError(lua_) : 1;
913-
}
914-
915-
cmd_depth_++;
896+
std::optional<absl::FixedArray<std::string_view, 4>> Interpreter::PrepareArgs() {
916897
int argc = lua_gettop(lua_);
917-
918-
#define RETURN_ERROR(err) \
919-
{ \
920-
PushError(lua_, err); \
921-
cmd_depth_--; \
922-
return raise_error ? RaiseError(lua_) : 1; \
923-
}
924-
925898
/* Require at least one argument */
926899
if (argc == 0) {
927-
RETURN_ERROR("Please specify at least one argument for redis.call()");
900+
PushError(lua_, "Please specify at least one argument for redis.call()");
901+
return std::nullopt;
928902
}
929903

930904
size_t blob_len = 0;
@@ -947,21 +921,22 @@ int Interpreter::RedisGenericCommand(bool raise_error, bool async, ObjectExplore
947921
blob_len += lua_rawlen(lua_, idx) + 1;
948922
continue;
949923
default:
950-
RETURN_ERROR("Lua redis() command arguments must be strings or integers");
924+
PushError(lua_, "Lua redis() command arguments must be strings or integers");
925+
return std::nullopt;
951926
}
952927
}
953928

954-
char name_buffer[32]; // backing storage for cmd name
955929
absl::FixedArray<string_view, 4> args(argc);
956930

957931
// Copy command name to name_buffer and set it as first arg.
958932
unsigned name_len = lua_rawlen(lua_, 1);
959-
if (name_len >= sizeof(name_buffer)) {
960-
RETURN_ERROR("Lua redis() command name too long");
933+
if (name_len >= sizeof(name_buffer_)) {
934+
PushError(lua_, "Lua redis() command name too long");
935+
return std::nullopt;
961936
}
962937

963-
memcpy(name_buffer, lua_tostring(lua_, 1), name_len);
964-
args[0] = {name_buffer, name_len};
938+
memcpy(name_buffer_, lua_tostring(lua_, 1), name_len);
939+
args[0] = {name_buffer_, name_len};
965940
buffer_.resize(blob_len + 4, '\0'); // backing storage for args
966941

967942
char* cur = buffer_.data();
@@ -993,7 +968,13 @@ int Interpreter::RedisGenericCommand(bool raise_error, bool async, ObjectExplore
993968
/* Pop all arguments from the stack, we do not need them anymore
994969
* and this way we guaranty we will have room on the stack for the result. */
995970
lua_pop(lua_, argc);
971+
return args;
972+
}
996973

974+
// Calls redis function
975+
// Returns false if error needs to be raised.
976+
bool Interpreter::CallRedisFunction(bool raise_error, bool async, ObjectExplorer* explorer,
977+
SliceSpan args) {
997978
// Calling with custom explorer is not supported with errors or async
998979
DCHECK(explorer == nullptr || (!raise_error && !async));
999980

@@ -1003,8 +984,8 @@ int Interpreter::RedisGenericCommand(bool raise_error, bool async, ObjectExplore
1003984
translator.emplace(lua_);
1004985
explorer = &*translator;
1005986
}
1006-
1007-
redis_func_(CallArgs{SliceSpan{args}, &buffer_, explorer, async, raise_error, &raise_error});
987+
cmd_depth_++;
988+
redis_func_(CallArgs{args, &buffer_, explorer, async, raise_error, &raise_error});
1008989
cmd_depth_--;
1009990

1010991
// Shrink reusable buffer if it's too big.
@@ -1014,18 +995,57 @@ int Interpreter::RedisGenericCommand(bool raise_error, bool async, ObjectExplore
1014995
}
1015996

1016997
if (!translator)
1017-
return 0;
998+
return true;
1018999

10191000
// Raise error for regular 'call' command if needed.
10201001
if (raise_error && translator->HasError()) {
10211002
// error is already on top of stack
1022-
return RaiseError(lua_);
1003+
return false;
10231004
}
10241005

10251006
if (!async)
10261007
DCHECK_EQ(1, lua_gettop(lua_));
10271008

1028-
return 1;
1009+
return true;
1010+
}
1011+
1012+
// Returns number of results, which is always 1 in this case.
1013+
// Please note that lua resets the stack once the function returns so no need
1014+
// to unwind the stack manually in the function (though lua allows doing this).
1015+
int Interpreter::RedisGenericCommand(bool raise_error, bool async, ObjectExplorer* explorer) {
1016+
/* By using Lua debug hooks it is possible to trigger a recursive call
1017+
* to luaRedisGenericCommand(), which normally should never happen.
1018+
* To make this function reentrant is futile and makes it slower, but
1019+
* we should at least detect such a misuse, and abort. */
1020+
if (cmd_depth_) {
1021+
const char* recursion_warning =
1022+
"luaRedisGenericCommand() recursive call detected. "
1023+
"Are you doing funny stuff with Lua debug hooks?";
1024+
PushError(lua_, recursion_warning);
1025+
return 1;
1026+
}
1027+
1028+
if (!redis_func_) {
1029+
PushError(lua_, "internal error - redis function not defined");
1030+
if (raise_error) {
1031+
return RaiseErrorAndAbort(lua_);
1032+
}
1033+
return 1;
1034+
}
1035+
1036+
// IMPORTANT! all allocations within this function must be freed
1037+
// BEFORE calling RaiseErrorAndAbort in case of script error. RaiseErrorAndAbort
1038+
// uses longjmp which bypasses stack unwinding and skips the destruction of objects.
1039+
{
1040+
std::optional<absl::FixedArray<std::string_view, 4>> args = PrepareArgs();
1041+
if (args.has_value()) {
1042+
raise_error = !CallRedisFunction(raise_error, async, explorer, SliceSpan{*args});
1043+
}
1044+
}
1045+
if (!raise_error) {
1046+
return 1;
1047+
}
1048+
return RaiseErrorAndAbort(lua_); // this function never returns, it unwinds the Lua call stack
10291049
}
10301050

10311051
int Interpreter::RedisCallCommand(lua_State* lua) {

src/core/interpreter.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
#pragma once
66

7+
#include <absl/container/fixed_array.h>
78
#include <absl/types/span.h>
89

910
#include <functional>
@@ -139,10 +140,14 @@ class Interpreter {
139140
static int RedisACallCommand(lua_State* lua);
140141
static int RedisAPCallCommand(lua_State* lua);
141142

143+
std::optional<absl::FixedArray<std::string_view, 4>> PrepareArgs();
144+
bool CallRedisFunction(bool raise_error, bool async, ObjectExplorer* explorer, SliceSpan args);
145+
142146
lua_State* lua_;
143147
unsigned cmd_depth_ = 0;
144148
RedisFunc redis_func_;
145149
std::string buffer_;
150+
char name_buffer_[32]; // backing storage for cmd name
146151
};
147152

148153
// Manages an internal interpreter pool. This allows multiple connections residing on the same

tests/dragonfly/memory_test.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,3 +114,49 @@ async def test_rss_oom_ratio(df_factory: DflyInstanceFactory, admin_port):
114114
# creating a new client should not fail after memory usage decreases
115115
client = df_server.client()
116116
await client.execute_command("set x y")
117+
118+
119+
@pytest.mark.asyncio
120+
@dfly_args(
121+
{
122+
"maxmemory": "512mb",
123+
"proactor_threads": 1,
124+
}
125+
)
126+
async def test_eval_with_oom(df_factory: DflyInstanceFactory):
127+
"""
128+
Test running eval commands when dragonfly returns OOM on write commands and check rss memory
129+
This test was written after detecting a memory leak in script runs in the OOM state
130+
"""
131+
df_server = df_factory.create()
132+
df_server.start()
133+
134+
client = df_server.client()
135+
await client.execute_command("DEBUG POPULATE 20000 key 40000 RAND")
136+
137+
await asyncio.sleep(1) # Wait for another RSS heartbeat update in Dragonfly
138+
139+
info = await client.info("memory")
140+
logging.debug(f'Used memory {info["used_memory"]}, rss {info["used_memory_rss"]}')
141+
142+
reject_limit = 512 * 1024 * 1024 # 512mb
143+
assert info["used_memory"] > reject_limit
144+
rss_before_eval = info["used_memory_rss"]
145+
146+
pipe = client.pipeline(transaction=False)
147+
MSET_SCRIPT = """
148+
redis.call('MSET', KEYS[1], ARGV[1], KEYS[2], ARGV[2])
149+
"""
150+
151+
for _ in range(20):
152+
for _ in range(8000):
153+
pipe.eval(MSET_SCRIPT, 2, "x1", "y1", "x2", "y2")
154+
# reject mset due to oom
155+
with pytest.raises(redis.exceptions.ResponseError):
156+
await pipe.execute()
157+
158+
await asyncio.sleep(1) # Wait for another RSS heartbeat update in Dragonfly
159+
160+
info = await client.info("memory")
161+
logging.debug(f'Used memory {info["used_memory"]}, rss {info["used_memory_rss"]}')
162+
assert rss_before_eval * 1.01 > info["used_memory_rss"]

0 commit comments

Comments
 (0)