Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions paddle/fluid/pybind/jit.cc
Original file line number Diff line number Diff line change
Expand Up @@ -71,14 +71,14 @@ void BindGuard(pybind11::module *m) {
.def(py::init<const py::function &>(), py::arg("guard_check_fn"));
py::class_<GuardGroup, GuardBase, std::shared_ptr<GuardGroup>>(
*m, "GuardGroup", R"DOC(GuardGroup Class.)DOC")
.def(py::init<std::vector<std::shared_ptr<GuardBase>>>(),
.def(py::init<const std::vector<std::shared_ptr<GuardBase>> &>(),
py::arg("guards"));
py::class_<TypeMatchGuard, GuardBase, std::shared_ptr<TypeMatchGuard>>(
*m, "TypeMatchGuard", R"DOC(TypeMatchGuard Class.)DOC")
.def(py::init<const py::type &>(), py::arg("py_type"));
py::class_<LengthMatchGuard, GuardBase, std::shared_ptr<LengthMatchGuard>>(
*m, "LengthMatchGuard", R"DOC(LengthMatchGuard Class.)DOC")
.def(py::init<Py_ssize_t>(), py::arg("length"));
.def(py::init<const Py_ssize_t &>(), py::arg("length"));
py::class_<ValueMatchGuard, GuardBase, std::shared_ptr<ValueMatchGuard>>(
*m, "ValueMatchGuard", R"DOC(ValueMatchGuard Class.)DOC")
.def(py::init<const py::object &>(), py::arg("py_value"));
Expand All @@ -90,6 +90,9 @@ void BindGuard(pybind11::module *m) {
py::class_<LayerMatchGuard, GuardBase, std::shared_ptr<LayerMatchGuard>>(
*m, "LayerMatchGuard", R"DOC(LayerMatchGuard Class.)DOC")
.def(py::init<const py::object &>(), py::arg("layer_obj"));
py::class_<ShapeMatchGuard, GuardBase, std::shared_ptr<ShapeMatchGuard>>(
*m, "ShapeMatchGuard", R"DOC(ShapeMatchGuard Class.)DOC")
.def(py::init<const std::vector<py::object> &>(), py::arg("shape"));

m->def(
"merge_guard",
Expand Down
53 changes: 39 additions & 14 deletions paddle/fluid/pybind/sot/guards.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/pybind/sot/guards.h"
#include "paddle/phi/api/include/tensor.h"

#if SOT_IS_SUPPORTED

Expand All @@ -25,8 +26,16 @@ static inline PyObject* PyObject_CallOneArg(PyObject* func, PyObject* arg) {
}
#endif

std::optional<paddle::Tensor> GetTensorFromPyObject(PyObject* obj) {
if (!paddle::pybind::PyCheckTensor(obj)) {
// TODO(zrr1999): PyCheckTensor only check if the object is a p_tensor_type.
return std::nullopt;
}
return reinterpret_cast<paddle::pybind::TensorObject*>(obj)->tensor;
}

bool LambdaGuard::check(PyObject* value) {
PyObject* x = PyObject_CallOneArg(_guard_check_fn, value);
PyObject* x = PyObject_CallOneArg(guard_check_fn_, value);
if (x == nullptr) {
PyErr_Clear();
return false;
Expand All @@ -37,7 +46,7 @@ bool LambdaGuard::check(PyObject* value) {
}

bool GuardGroup::check(PyObject* value) {
for (auto& guard : _guards) {
for (auto& guard : guards_) {
if (!guard->check(value)) {
return false;
}
Expand All @@ -46,17 +55,17 @@ bool GuardGroup::check(PyObject* value) {
}

bool TypeMatchGuard::check(PyObject* value) {
return Py_TYPE(value) == _expected;
return Py_TYPE(value) == expected_;
}

bool ValueMatchGuard::check(PyObject* value) {
if (value == _expected_value) {
if (value == expected_value_) {
return true;
}
if (Py_TYPE(value) != _expected_type) {
if (Py_TYPE(value) != expected_type_) {
return false;
}
int result = PyObject_RichCompareBool(value, _expected_value, Py_EQ);
int result = PyObject_RichCompareBool(value, expected_value_, Py_EQ);
// Check for exception
if (result == -1) {
PyErr_Clear();
Expand All @@ -66,25 +75,41 @@ bool ValueMatchGuard::check(PyObject* value) {
}

bool LengthMatchGuard::check(PyObject* value) {
return PySequence_Size(value) == _expected;
return PySequence_Size(value) == expected_;
}

bool DtypeMatchGuard::check(PyObject* value) {
if (!paddle::pybind::PyCheckTensor(value)) {
// TODO(zrr1999): PyCheckTensor only check if the object is a p_tensor_type.
auto tensor = GetTensorFromPyObject(value);
if (!tensor) {
return false;
}
auto dtype =
reinterpret_cast<paddle::pybind::TensorObject*>(value)->tensor.type();
return phi::TransToProtoVarType(dtype) == _expected;
auto dtype = tensor->type();
return phi::TransToProtoVarType(dtype) == expected_;
}

bool ShapeMatchGuard::check(PyObject* value) {
auto tensor = GetTensorFromPyObject(value);
if (!tensor) {
return false;
}
auto shape = tensor->shape();
if (shape.size() != expected_.size()) {
return false;
}
for (size_t i = 0; i < shape.size(); ++i) {
if (expected_[i] && shape[i] != *expected_[i]) {
return false;
}
}
return true;
}

bool LayerMatchGuard::check(PyObject* value) {
if (value != _layer_ptr) {
if (value != layer_ptr_) {
return false;
}
PyObject* training = PyObject_GetAttrString(value, "training");
return (training == Py_True) == _training;
return (training == Py_True) == training_;
}

#endif
81 changes: 50 additions & 31 deletions paddle/fluid/pybind/sot/guards.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ limitations under the License. */
#include <Python.h>
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/pybind/sot/macros.h"
#include "paddle/phi/core/framework/heter_service.pb.h"
#include "paddle/phi/core/utils/data_type.h"
#include "paddle/utils/pybind.h"
#include "pybind11/pybind11.h"
Expand All @@ -38,111 +37,131 @@ class GuardBase {
class LambdaGuard : public GuardBase {
public:
explicit LambdaGuard(PyObject* guard_check_fn)
: _guard_check_fn(guard_check_fn) {}
: guard_check_fn_(guard_check_fn) {}

explicit LambdaGuard(const py::function& guard_check_fn)
: _guard_check_fn(guard_check_fn.ptr()) {
Py_INCREF(_guard_check_fn);
: guard_check_fn_(guard_check_fn.ptr()) {
Py_INCREF(guard_check_fn_);
}

~LambdaGuard() { Py_DECREF(_guard_check_fn); }
~LambdaGuard() { Py_DECREF(guard_check_fn_); }

bool check(PyObject* value);

private:
PyObject* _guard_check_fn;
PyObject* guard_check_fn_;
};

class GuardGroup : public GuardBase {
public:
explicit GuardGroup(std::vector<std::shared_ptr<GuardBase>> guards) {
explicit GuardGroup(const std::vector<std::shared_ptr<GuardBase>>& guards) {
for (auto& guard : guards) {
if (auto group = dynamic_cast<GuardGroup*>(guard.get())) {
_guards.insert(
_guards.end(), group->_guards.begin(), group->_guards.end());
guards_.insert(
guards_.end(), group->guards_.begin(), group->guards_.end());
} else {
_guards.push_back(std::move(guard));
guards_.push_back(std::move(guard));
}
}
}
bool check(PyObject* value);

private:
std::vector<std::shared_ptr<GuardBase>> _guards;
std::vector<std::shared_ptr<GuardBase>> guards_;
};

class TypeMatchGuard : public GuardBase {
public:
explicit TypeMatchGuard(PyObject* type_ptr)
: _expected(reinterpret_cast<PyTypeObject*>(type_ptr)) {}
: expected_(reinterpret_cast<PyTypeObject*>(type_ptr)) {}

explicit TypeMatchGuard(const py::type& py_type)
: _expected(reinterpret_cast<PyTypeObject*>(py_type.ptr())) {}
: expected_(reinterpret_cast<PyTypeObject*>(py_type.ptr())) {}

bool check(PyObject* value);

private:
PyTypeObject* _expected;
PyTypeObject* expected_;
};

class ValueMatchGuard : public GuardBase {
public:
explicit ValueMatchGuard(PyObject* value_ptr)
: _expected_value(value_ptr), _expected_type(value_ptr->ob_type) {}
: expected_value_(value_ptr), expected_type_(value_ptr->ob_type) {}

explicit ValueMatchGuard(const py::object& py_value)
: _expected_value(py_value.ptr()),
_expected_type(Py_TYPE(py_value.ptr())) {
Py_INCREF(_expected_value);
: expected_value_(py_value.ptr()),
expected_type_(Py_TYPE(py_value.ptr())) {
Py_INCREF(expected_value_);
}

~ValueMatchGuard() { Py_DECREF(_expected_value); }
~ValueMatchGuard() { Py_DECREF(expected_value_); }

bool check(PyObject* value);

private:
PyObject* _expected_value;
PyTypeObject* _expected_type;
PyObject* expected_value_;
PyTypeObject* expected_type_;
};

class LengthMatchGuard : public GuardBase {
public:
explicit LengthMatchGuard(Py_ssize_t length) : _expected(length) {}
explicit LengthMatchGuard(const Py_ssize_t& length) : expected_(length) {}

bool check(PyObject* value);

private:
Py_ssize_t _expected;
Py_ssize_t expected_;
};

class DtypeMatchGuard : public GuardBase {
public:
explicit DtypeMatchGuard(const paddle::framework::proto::VarType& dtype_ptr)
: _expected(dtype_ptr.type()) {}
: expected_(dtype_ptr.type()) {}

explicit DtypeMatchGuard(const phi::DataType& dtype_ptr)
: _expected(phi::TransToProtoVarType(dtype_ptr)) {}
: expected_(phi::TransToProtoVarType(dtype_ptr)) {}

bool check(PyObject* value);

private:
int _expected;
int expected_;
};

class ShapeMatchGuard : public GuardBase {
public:
explicit ShapeMatchGuard(const std::vector<std::optional<int64_t>>& shape)
: expected_(shape) {}

explicit ShapeMatchGuard(const std::vector<py::object>& shape) {
expected_.resize(shape.size());
for (size_t i = 0; i < shape.size(); ++i) {
if (py::isinstance<py::int_>(shape[i]) && shape[i].cast<int64_t>() > 0) {
expected_[i] = std::make_optional(shape[i].cast<int64_t>());
}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

else 是 None 是么?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这块主要是针对SombolcInt,现在只要不是整数都被当成动态shape了,没有做其他类型的检查

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

那在构造的时候需要小心传入 -1,会产生奇怪的问题,或者直接不允许 -1(负数)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

}
}

bool check(PyObject* value);

private:
std::vector<std::optional<int64_t>> expected_;
};

class LayerMatchGuard : public GuardBase {
public:
explicit LayerMatchGuard(PyObject* layer_ptr) : _layer_ptr(layer_ptr) {
_training = PyObject_GetAttrString(layer_ptr, "training") == Py_True;
explicit LayerMatchGuard(PyObject* layer_ptr) : layer_ptr_(layer_ptr) {
training_ = PyObject_GetAttrString(layer_ptr, "training") == Py_True;
}

explicit LayerMatchGuard(const py::object& layer_obj)
: _layer_ptr(layer_obj.ptr()), _training(layer_obj.attr("training")) {}
: layer_ptr_(layer_obj.ptr()), training_(layer_obj.attr("training")) {}

bool check(PyObject* value);

private:
PyObject* _layer_ptr;
bool _training;
PyObject* layer_ptr_;
bool training_;
};

#endif
37 changes: 35 additions & 2 deletions python/paddle/jit/sot/opcode_translator/executor/guard.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,19 @@
import paddle

from ...profiler import EventGuard
from ...utils import current_symbol_registry, log, log_do
from ...utils import (
ENV_SOT_ENABLE_FASTER_GUARD,
current_symbol_registry,
log,
log_do,
)

Guard = Callable[[types.FrameType], bool]

if TYPE_CHECKING:
from .variables import VariableBase

GuardBase = paddle.framework.core.GuardBase
CheckGuardInputT = TypeVar("CheckGuardInputT", bound=VariableBase)

# NOTE(SigureMo): [How to write Stringified Guard?]
Expand Down Expand Up @@ -83,6 +89,33 @@ def __hash__(self):
return hash(self.inlined_expr)


class FasterStringifiedExpression(StringifiedExpression):
def __init__(
self,
expr_template: str,
faster_guard: GuardBase,
sub_exprs: list[StringifiedExpression],
free_vars: dict[str, Any],
):
self.faster_guard = faster_guard
if ENV_SOT_ENABLE_FASTER_GUARD:
original_expr_template = expr_template
guard_cls_name = faster_guard.__class__.__name__
guard_name = f"{guard_cls_name}_{id(faster_guard)}"
expr_template = (
guard_name + "(" + ", ".join(["{}"] * len(sub_exprs)) + ")"
)
free_vars = union_free_vars(
free_vars, {guard_name: faster_guard.check}
)
log(
3,
f"[FasterGuard]: transform {original_expr_template} to {expr_template}\n",
)

super().__init__(expr_template, sub_exprs, free_vars)


def union_free_vars(*free_vars: dict[str, Any]):
return {k: v for d in free_vars for k, v in d.items()}

Expand Down Expand Up @@ -132,7 +165,7 @@ def support_weak_ref(obj):


def check_guard(
fn: Callable[[CheckGuardInputT], list[StringifiedExpression]]
fn: Callable[[CheckGuardInputT], list[StringifiedExpression]],
) -> Callable[[CheckGuardInputT], list[StringifiedExpression]]:
def wrapper(self: CheckGuardInputT) -> list[StringifiedExpression]:
assert (
Expand Down
Loading