Skip to content

fix: lua bind validation from inheritance class names #3444

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Mar 18, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions src/lua/functions/items/weapon_functions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -124,11 +124,11 @@ int WeaponFunctions::luaWeaponRegister(lua_State* L) {
if (weaponPtr && *weaponPtr) {
WeaponShared_ptr weapon = *weaponPtr;
if (weapon->weaponType == WEAPON_DISTANCE || weapon->weaponType == WEAPON_AMMO || weapon->weaponType == WEAPON_MISSILE) {
weapon = Lua::getUserdataShared<WeaponDistance>(L, 1, "WeaponDistance");
weapon = Lua::getUserdataShared<WeaponDistance>(L, 1, "Weapon");
} else if (weapon->weaponType == WEAPON_WAND) {
weapon = Lua::getUserdataShared<WeaponWand>(L, 1, "WeaponWand");
weapon = Lua::getUserdataShared<WeaponWand>(L, 1, "Weapon");
} else {
weapon = Lua::getUserdataShared<WeaponMelee>(L, 1, "WeaponMelee");
weapon = Lua::getUserdataShared<WeaponMelee>(L, 1, "Weapon");
}

const uint16_t id = weapon->getID();
Expand Down Expand Up @@ -279,7 +279,7 @@ int WeaponFunctions::luaWeaponBreakChance(lua_State* L) {

int WeaponFunctions::luaWeaponWandDamage(lua_State* L) {
// weapon:damage(damage[min, max]) only use this if the weapon is a wand!
const auto &weapon = Lua::getUserdataShared<WeaponWand>(L, 1, "WeaponWand");
const auto &weapon = Lua::getUserdataShared<WeaponWand>(L, 1, "Weapon");
if (weapon) {
weapon->setMinChange(Lua::getNumber<uint32_t>(L, 2));
if (lua_gettop(L) > 2) {
Expand Down Expand Up @@ -557,7 +557,7 @@ int WeaponFunctions::luaWeaponSlotType(lua_State* L) {

int WeaponFunctions::luaWeaponAmmoType(lua_State* L) {
// weapon:ammoType(type)
const auto &weapon = Lua::getUserdataShared<WeaponDistance>(L, 1, "WeaponDistance");
const auto &weapon = Lua::getUserdataShared<WeaponDistance>(L, 1, "Weapon");
if (weapon) {
const uint16_t id = weapon->getID();
ItemType &it = Item::items.getItemType(id);
Expand Down
60 changes: 60 additions & 0 deletions src/lua/functions/lua_functions_loader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -310,6 +310,9 @@ void Lua::setWeakMetatable(lua_State* L, int32_t index, const std::string &name)
lua_pushnil(L);
lua_setfield(L, metatable, "__gc");

lua_pushstring(L, name.c_str());
lua_setfield(L, metatable, "__name");

lua_remove(L, childMetatable);
} else {
luaL_getmetatable(L, weakName.c_str());
Expand Down Expand Up @@ -696,6 +699,14 @@ void Lua::registerClass(lua_State* L, const std::string &className, const std::s
lua_pushnumber(L, parents);
lua_rawseti(L, metatable, 'p');

lua_pushstring(L, className.c_str());
lua_setfield(L, metatable, "__name");

if (!baseClass.empty()) {
lua_pushstring(L, baseClass.c_str());
lua_setfield(L, metatable, "baseclass");
}

// className.metatable['t'] = type
auto userTypeEnum = magic_enum::enum_cast<LuaData_t>(className);
if (userTypeEnum.has_value()) {
Expand Down Expand Up @@ -803,3 +814,52 @@ int Lua::validateDispatcherContext(std::string_view fncName) {

return 0;
}

bool Lua::checkMetatableInheritance(lua_State* L, int index, const char* expectedName) {
if (!lua_getmetatable(L, index)) {
return false;
}

// Traverse the inheritance chain.
bool found = false;
while (true) {
// Check the "__name" field.
lua_getfield(L, -1, "__name");
const char* currentName = lua_tostring(L, -1);
lua_pop(L, 1); // Remove __name.
if (currentName && strcmp(currentName, expectedName) == 0) {
found = true;
break;
}

// Check for a "baseclass" field.
lua_getfield(L, -1, "baseclass");
if (lua_isstring(L, -1)) {
const char* baseName = lua_tostring(L, -1);
lua_pop(L, 1); // Remove baseclass value.
if (baseName && strcmp(baseName, expectedName) == 0) {
found = true;
break;
}
// Move to the metatable of the base class.
luaL_getmetatable(L, baseName);
if (lua_isnil(L, -1)) {
lua_pop(L, 1);
break;
}
lua_remove(L, -2); // Remove current metatable.
continue;
}
lua_pop(L, 1); // Remove non-string baseclass.

// Fallback: try the "__index" table.
lua_getfield(L, -1, "__index");
if (!lua_istable(L, -1)) {
lua_pop(L, 1);
break;
}
lua_remove(L, -2); // Remove current metatable; keep __index.
}
lua_pop(L, 1); // Remove final metatable.
return found;
}
50 changes: 34 additions & 16 deletions src/lua/functions/lua_functions_loader.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -216,33 +216,51 @@ class Lua {
}

/**
* @brief Retrieves a shared pointer to a userdata object from the Lua stack.
* @brief Checks whether the metatable of the object at the given Lua stack index inherits from a specified metatable.
*
* This function verifies whether the metatable of the object either exactly matches the expected metatable
* or, by traversing its inheritance chain via the "__name" and "baseclass" fields (or via the "__index" table),
* determines if it ultimately inherits from a metatable with the expected name.
*
* @param L The Lua state.
* @param index The stack index of the object.
* @param expectedName The expected metatable name to check for in the inheritance chain.
* @return true if the object's metatable is the expected one or inherits from it; false otherwise.
*/
static bool checkMetatableInheritance(lua_State* L, int index, const char* expectedName);

/**
* @brief Retrieves a shared pointer to a userdata object from the Lua stack with inheritance support.
*
* This function attempts to extract a `std::shared_ptr<T>` from the given Lua stack index.
* It ensures that the userdata at the specified index has the expected metatable before
* attempting to retrieve it. This validation prevents crashes due to invalid or outdated
* Lua bindings, ensuring that only correctly-typed userdata is accessed.
* It ensures that the userdata at the specified index either has the expected metatable directly
* or inherits from the expected metatable (by traversing the inheritance chain via the "__name"
* and "baseclass" fields) before attempting to retrieve it. This validation prevents crashes
* due to invalid or outdated Lua bindings, ensuring that only correctly-typed userdata is accessed.
*
* @tparam T The C++ class type of the userdata.
* @param L The Lua state.
* @param arg The index of the Lua stack where the userdata is expected to be.
* @param metatableName The expected metatable name associated with the userdata.
* This ensures that the retrieved object is of the correct type.
* The metatable name should match the one assigned when the userdata
* was originally pushed into Lua.
* @param expectedMetatableName The expected metatable name associated with the userdata.
* This ensures that the retrieved object is of the correct type.
* The metatable name should match the one assigned when the userdata
* was originally pushed into Lua, or be found within its inheritance chain.
*
* @return std::shared_ptr<T> A valid shared pointer to the requested object if the userdata
* exists and has the correct metatable. If the userdata is missing or has an incorrect
* metatable, returns nullptr.
* @return std::shared_ptr<T> A valid shared pointer to the requested object if the userdata exists
* and either has the correct metatable or inherits from the expected metatable. If the userdata
* is missing or does not satisfy these conditions, returns nullptr.
*/
template <class T>
static std::shared_ptr<T> getUserdataShared(lua_State* L, int32_t arg, const char* metatableName) {
// Verify that the userdata at 'arg' has the correct metatable.
auto userdata = static_cast<std::shared_ptr<T>*>(luaL_testudata(L, arg, metatableName));
static std::shared_ptr<T> getUserdataShared(lua_State* L, int32_t arg, const char* expectedMetatableName) {
auto userdata = static_cast<std::shared_ptr<T>*>(luaL_testudata(L, arg, expectedMetatableName));
if (!userdata) {
return nullptr;
if (!checkMetatableInheritance(L, arg, expectedMetatableName)) {
return nullptr;
}

userdata = static_cast<std::shared_ptr<T>*>(lua_touserdata(L, arg));
}
return *userdata;
return userdata ? *userdata : nullptr;
}

template <class T>
Expand Down
Loading