Skip to content

Commit a0ab679

Browse files
evonreiserikvonreisjagregorytermoshtt
authored
Derive macro for complex enums (#208)
Add new attribute macro `gen_stub_pyclass_complex_enum` that creates internal classes for the variants of rich enums, that is, enums with variants that have fields. Also creates in type_info.rs PyComplexEnumInfo and VariantInfo for manual creation of stub information for rich enums. Both of these are handled by 'ClassDef', which can now have internal classes and will write properly indented internal classes to the stub file. This PR also fixes the implementation on tuples so that tuples of length one are supported. --------- Co-authored-by: Erik von Reis <[email protected]> Co-authored-by: James Gregory <[email protected]> Co-authored-by: Toshiki Teramura <[email protected]>
1 parent cb72463 commit a0ab679

File tree

18 files changed

+817
-13
lines changed

18 files changed

+817
-13
lines changed

Cargo.lock

Lines changed: 4 additions & 4 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

examples/pure/pure.pyi

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,100 @@ class B(A):
4545
class MyDate(datetime.date):
4646
...
4747

48+
class NumberComplex:
49+
class FLOAT(NumberComplex):
50+
r"""
51+
Float variant
52+
"""
53+
__match_args__ = ("_0",)
54+
@property
55+
def _0(self) -> builtins.float: ...
56+
def __new__(cls, _0:builtins.float) -> NumberComplex.FLOAT: ...
57+
def __len__(self) -> builtins.int: ...
58+
def __getitem__(self, key:builtins.int) -> typing.Any: ...
59+
60+
class INTEGER(NumberComplex):
61+
r"""
62+
Integer variant
63+
"""
64+
__match_args__ = ("int",)
65+
@property
66+
def int(self) -> builtins.int:
67+
r"""
68+
The integer value
69+
"""
70+
def __new__(cls, int:builtins.int=2) -> NumberComplex.INTEGER: ...
71+
72+
...
73+
74+
class Shape1:
75+
r"""
76+
Example from PyO3 documentation for complex enum
77+
https://pyo3.rs/v0.25.1/class.html#complex-enums
78+
"""
79+
class Circle(Shape1):
80+
__match_args__ = ("radius",)
81+
@property
82+
def radius(self) -> builtins.float: ...
83+
def __new__(cls, radius:builtins.float) -> Shape1.Circle: ...
84+
85+
class Rectangle(Shape1):
86+
__match_args__ = ("width", "height",)
87+
@property
88+
def width(self) -> builtins.float: ...
89+
@property
90+
def height(self) -> builtins.float: ...
91+
def __new__(cls, width:builtins.float, height:builtins.float) -> Shape1.Rectangle: ...
92+
93+
class RegularPolygon(Shape1):
94+
__match_args__ = ("_0", "_1",)
95+
@property
96+
def _0(self) -> builtins.int: ...
97+
@property
98+
def _1(self) -> builtins.float: ...
99+
def __new__(cls, _0:builtins.int, _1:builtins.float) -> Shape1.RegularPolygon: ...
100+
def __len__(self) -> builtins.int: ...
101+
def __getitem__(self, key:builtins.int) -> typing.Any: ...
102+
103+
class Nothing(Shape1):
104+
__match_args__ = ((),)
105+
def __new__(cls) -> Shape1.Nothing: ...
106+
107+
...
108+
109+
class Shape2:
110+
r"""
111+
Example from PyO3 documentation for complex enum
112+
https://pyo3.rs/v0.25.1/class.html#complex-enums
113+
"""
114+
class Circle(Shape2):
115+
__match_args__ = ("radius",)
116+
@property
117+
def radius(self) -> builtins.float: ...
118+
def __new__(cls, radius:builtins.float=1.0) -> Shape2.Circle: ...
119+
120+
class Rectangle(Shape2):
121+
__match_args__ = ("width", "height",)
122+
@property
123+
def width(self) -> builtins.float: ...
124+
@property
125+
def height(self) -> builtins.float: ...
126+
def __new__(cls, *, width:builtins.float, height:builtins.float) -> Shape2.Rectangle: ...
127+
128+
class RegularPolygon(Shape2):
129+
__match_args__ = ("side_count", "radius",)
130+
@property
131+
def side_count(self) -> builtins.int: ...
132+
@property
133+
def radius(self) -> builtins.float: ...
134+
def __new__(cls, side_count:builtins.int, radius:builtins.float=1.0) -> Shape2.RegularPolygon: ...
135+
136+
class Nothing(Shape2):
137+
__match_args__ = ((),)
138+
def __new__(cls) -> Shape2.Nothing: ...
139+
140+
...
141+
48142
class Number(Enum):
49143
FLOAT = ...
50144
INTEGER = ...

examples/pure/src/lib.rs

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -163,6 +163,54 @@ pub enum NumberRenameAll {
163163
Integer,
164164
}
165165

166+
#[gen_stub_pyclass_complex_enum]
167+
#[pyclass]
168+
#[pyo3(rename_all = "UPPERCASE")]
169+
#[derive(Debug, Clone)]
170+
pub enum NumberComplex {
171+
/// Float variant
172+
Float(f64),
173+
/// Integer variant
174+
#[pyo3(constructor = (int=2))]
175+
Integer {
176+
/// The integer value
177+
int: i32,
178+
},
179+
}
180+
181+
/// Example from PyO3 documentation for complex enum
182+
/// https://pyo3.rs/v0.25.1/class.html#complex-enums
183+
#[gen_stub_pyclass_complex_enum]
184+
#[pyclass]
185+
enum Shape1 {
186+
Circle { radius: f64 },
187+
Rectangle { width: f64, height: f64 },
188+
RegularPolygon(u32, f64),
189+
Nothing {},
190+
}
191+
192+
/// Example from PyO3 documentation for complex enum
193+
/// https://pyo3.rs/v0.25.1/class.html#complex-enums
194+
#[gen_stub_pyclass_complex_enum]
195+
#[pyclass]
196+
enum Shape2 {
197+
#[pyo3(constructor = (radius=1.0))]
198+
Circle {
199+
radius: f64,
200+
},
201+
#[pyo3(constructor = (*, width, height))]
202+
Rectangle {
203+
width: f64,
204+
height: f64,
205+
},
206+
#[pyo3(constructor = (side_count, radius=1.0))]
207+
RegularPolygon {
208+
side_count: u32,
209+
radius: f64,
210+
},
211+
Nothing {},
212+
}
213+
166214
#[gen_stub_pymethods]
167215
#[pymethods]
168216
impl Number {
@@ -199,6 +247,9 @@ fn pure(m: &Bound<PyModule>) -> PyResult<()> {
199247
m.add_class::<MyDate>()?;
200248
m.add_class::<Number>()?;
201249
m.add_class::<NumberRenameAll>()?;
250+
m.add_class::<NumberComplex>()?;
251+
m.add_class::<Shape1>()?;
252+
m.add_class::<Shape2>()?;
202253
m.add_function(wrap_pyfunction!(sum, m)?)?;
203254
m.add_function(wrap_pyfunction!(create_dict, m)?)?;
204255
m.add_function(wrap_pyfunction!(read_dict, m)?)?;

examples/pure/tests/test_python.py

Lines changed: 71 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,13 @@
1-
from pure import sum, create_dict, read_dict, echo_path, ahash_dict
1+
from pure import (
2+
sum,
3+
create_dict,
4+
read_dict,
5+
echo_path,
6+
ahash_dict,
7+
NumberComplex,
8+
Shape1,
9+
Shape2,
10+
)
211
import pytest
312
import pathlib
413

@@ -36,6 +45,67 @@ def test_read_dict():
3645
)
3746

3847

48+
def test_number_complex():
49+
i = NumberComplex.INTEGER(1)
50+
f = NumberComplex.FLOAT(1.5)
51+
assert i.int == 1
52+
assert f._0 == 1.5
53+
assert len(f) == 1
54+
i2 = NumberComplex.INTEGER()
55+
assert i2.int == 2
56+
57+
58+
# Test code for complex enum case from PyO3 document
59+
# https://pyo3.rs/v0.25.1/class.html#complex-enums
60+
def test_complex_enum_shape1():
61+
circle = Shape1.Circle(radius=10.0)
62+
square = Shape1.RegularPolygon(4, 10.0)
63+
64+
assert isinstance(circle, Shape1)
65+
assert isinstance(circle, Shape1.Circle)
66+
assert circle.radius == 10.0
67+
68+
assert isinstance(square, Shape1)
69+
assert isinstance(square, Shape1.RegularPolygon)
70+
assert square[0] == 4 # Gets _0 field
71+
assert square[1] == 10.0 # Gets _1 field
72+
73+
def count_vertices(cls, shape):
74+
match shape:
75+
case cls.Circle():
76+
return 0
77+
case cls.Rectangle():
78+
return 4
79+
case cls.RegularPolygon(n):
80+
return n
81+
case cls.Nothing():
82+
return 0
83+
84+
assert count_vertices(Shape1, circle) == 0
85+
assert count_vertices(Shape1, square) == 4
86+
87+
88+
# Test code for complex enum case from PyO3 document
89+
# https://pyo3.rs/v0.25.1/class.html#complex-enums
90+
def test_complex_enum_shape2():
91+
circle = Shape2.Circle()
92+
assert isinstance(circle, Shape2)
93+
assert isinstance(circle, Shape2.Circle)
94+
assert circle.radius == 1.0
95+
96+
square = Shape2.Rectangle(width=1, height=1)
97+
assert isinstance(square, Shape2)
98+
assert isinstance(square, Shape2.Rectangle)
99+
assert square.width == 1
100+
assert square.height == 1
101+
102+
hexagon = Shape2.RegularPolygon(6)
103+
assert isinstance(hexagon, Shape2)
104+
assert isinstance(hexagon, Shape2.RegularPolygon)
105+
assert hexagon.side_count == 6
106+
assert hexagon.radius == 1
107+
108+
39109
def test_path():
40110
out = echo_path(pathlib.Path("test"))
41111
assert out == pathlib.Path("test")

pyo3-stub-gen-derive/src/gen_stub.rs

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,19 +72,22 @@ mod attr;
7272
mod member;
7373
mod method;
7474
mod pyclass;
75+
mod pyclass_complex_enum;
7576
mod pyclass_enum;
7677
mod pyfunction;
7778
mod pymethods;
7879
mod renaming;
7980
mod signature;
8081
mod stub_type;
8182
mod util;
83+
mod variant;
8284

8385
use arg::*;
8486
use attr::*;
8587
use member::*;
8688
use method::*;
8789
use pyclass::*;
90+
use pyclass_complex_enum::*;
8891
use pyclass_enum::*;
8992
use pyfunction::*;
9093
use pymethods::*;
@@ -123,6 +126,18 @@ pub fn pyclass_enum(item: TokenStream2) -> Result<TokenStream2> {
123126
})
124127
}
125128

129+
pub fn pyclass_complex_enum(item: TokenStream2) -> Result<TokenStream2> {
130+
let inner = PyComplexEnumInfo::try_from(parse2::<ItemEnum>(item.clone())?)?;
131+
let derive_stub_type = StubType::from(&inner);
132+
Ok(quote! {
133+
#item
134+
#derive_stub_type
135+
pyo3_stub_gen::inventory::submit! {
136+
#inner
137+
}
138+
})
139+
}
140+
126141
pub fn pymethods(item: TokenStream2) -> Result<TokenStream2> {
127142
let mut item_impl = parse2::<ItemImpl>(item)?;
128143
let inner = PyMethodsInfo::try_from(item_impl.clone())?;

pyo3-stub-gen-derive/src/gen_stub/attr.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ pub enum Attr {
5858
Set,
5959
SetAll,
6060
Module(String),
61+
Constructor(Signature),
6162
Signature(Signature),
6263
RenameAll(RenamingRule),
6364
Extends(Type),
@@ -148,6 +149,9 @@ pub fn parse_pyo3_attr(attr: &Attribute) -> Result<Vec<Attr>> {
148149
[Ident(ident), Punct(_), Group(group)] => {
149150
if ident == "signature" {
150151
pyo3_attrs.push(Attr::Signature(syn::parse2(group.to_token_stream())?));
152+
} else if ident == "constructor" {
153+
pyo3_attrs
154+
.push(Attr::Constructor(syn::parse2(group.to_token_stream())?));
151155
}
152156
}
153157
[Ident(ident), Punct(_), Ident(ident2)] => {

pyo3-stub-gen-derive/src/gen_stub/member.rs

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,12 @@ use crate::gen_stub::{attr::parse_gen_stub_default, extract_documents};
22

33
use super::{escape_return_type, parse_pyo3_attrs, Attr};
44

5+
use crate::gen_stub::arg::ArgInfo;
56
use proc_macro2::TokenStream as TokenStream2;
67
use quote::{quote, ToTokens, TokenStreamExt};
78
use syn::{Attribute, Error, Expr, Field, FnArg, ImplItemConst, ImplItemFn, Result, Type};
89

9-
#[derive(Debug)]
10+
#[derive(Debug, Clone)]
1011
pub struct MemberInfo {
1112
doc: String,
1213
name: String,
@@ -196,3 +197,11 @@ impl ToTokens for MemberInfo {
196197
})
197198
}
198199
}
200+
201+
impl From<MemberInfo> for ArgInfo {
202+
fn from(value: MemberInfo) -> Self {
203+
let MemberInfo { name, r#type, .. } = value;
204+
205+
Self { name, r#type }
206+
}
207+
}

pyo3-stub-gen-derive/src/gen_stub/pyclass.rs

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,8 @@
1+
use super::{extract_documents, parse_pyo3_attrs, util::quote_option, Attr, MemberInfo, StubType};
12
use proc_macro2::TokenStream as TokenStream2;
23
use quote::{quote, ToTokens, TokenStreamExt};
34
use syn::{parse_quote, Error, ItemStruct, Result, Type};
45

5-
use super::{extract_documents, parse_pyo3_attrs, util::quote_option, Attr, MemberInfo, StubType};
6-
76
pub struct PyClassInfo {
87
pyclass_name: String,
98
struct_type: Type,
@@ -142,7 +141,7 @@ mod test {
142141
"#,
143142
)?;
144143
let out = PyClassInfo::try_from(input)?.to_token_stream();
145-
insta::assert_snapshot!(format_as_value(out), @r###"
144+
insta::assert_snapshot!(format_as_value(out), @r#"
146145
::pyo3_stub_gen::type_info::PyClassInfo {
147146
pyclass_name: "Placeholder",
148147
struct_id: std::any::TypeId::of::<PyPlaceholder>,
@@ -171,7 +170,7 @@ mod test {
171170
doc: "",
172171
bases: &[],
173172
}
174-
"###);
173+
"#);
175174
Ok(())
176175
}
177176

0 commit comments

Comments
 (0)