-
Notifications
You must be signed in to change notification settings - Fork 179
Implement strum(flatten) for EnumIter #425
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Implement strum(flatten) for EnumIter #425
Conversation
ff6eab2 to
7b81796
Compare
|
This is the manual implementation of what these macros generate. It has some #[derive(Debug, Eq, PartialEq)]
enum Vibe {
Weak,
Average,
Strong,
}
impl Vibe {
fn iter() -> <Self as IntoIterator>::IntoIter {
let vibe = Vibe::Weak;
vibe.into_iter()
}
}
impl IntoIterator for Vibe {
type Item = Vibe;
type IntoIter = std::vec::IntoIter<Vibe>;
fn into_iter(self) -> Self::IntoIter {
vec![Vibe::Weak, Vibe::Average, Vibe::Strong].into_iter()
}
}
const SHADE_NUM: usize = 5;
#[derive(Debug, Eq, PartialEq)]
enum Shade {
Light,
Med1(Vibe),
Med2(Vibe),
Med3(Vibe),
Dark,
}
impl Shade {
fn iter() -> ShadeIter {
ShadeIter {
idx: 0,
med1_iter: Some(Vibe::iter()),
med2_iter: Some(Vibe::iter()),
med3_iter: Some(Vibe::iter()),
back_idx: 0,
}
}
}
impl Shade {
fn simple_iter() -> impl DoubleEndedIterator<Item = Shade> {
vec![Shade::Light]
.into_iter()
.chain(Vibe::iter().map(Shade::Med1))
.chain(Vibe::iter().map(Shade::Med2))
.chain(Vibe::iter().map(Shade::Med3))
.chain(vec![Shade::Dark])
}
}
struct ShadeIter {
idx: usize,
med1_iter: Option<<Vibe as IntoIterator>::IntoIter>,
med2_iter: Option<<Vibe as IntoIterator>::IntoIter>,
med3_iter: Option<<Vibe as IntoIterator>::IntoIter>,
back_idx: usize,
}
#[derive(Debug)]
enum Res {
Done(Shade),
DoneStep(Shade),
EndStep,
End,
}
impl ShadeIter {
fn nested_get(
nested_iter: &mut Option<<Vibe as IntoIterator>::IntoIter>,
wrap: fn(<Vibe as IntoIterator>::Item) -> Shade,
forward: bool,
) -> Res {
let next_inner = if forward {
nested_iter.as_mut().and_then(|t| t.next())
} else {
nested_iter.as_mut().and_then(|t| t.next_back())
};
if let Some(it) = next_inner {
Res::DoneStep(wrap(it))
} else {
nested_iter.take();
Res::EndStep
}
}
fn get(&mut self, idx: usize, forward: bool) -> Res {
let res = match dbg!(idx) {
0 => Res::Done(Shade::Light),
1 => Self::nested_get(&mut self.med1_iter, Shade::Med1, forward),
2 => Self::nested_get(&mut self.med2_iter, Shade::Med2, forward),
3 => Self::nested_get(&mut self.med3_iter, Shade::Med3, forward),
4 => Res::Done(Shade::Dark),
_ => Res::End,
};
dbg!(res)
}
}
impl Iterator for ShadeIter {
type Item = Shade;
fn next(&mut self) -> Option<Self::Item> {
self.nth(0)
}
fn nth(&mut self, n: usize) -> Option<Self::Item> {
if self.back_idx + self.idx >= SHADE_NUM {
return None;
}
match ShadeIter::get(self, dbg!(self.idx) + dbg!(n), true) {
Res::Done(x) => {
// move to requested, and past it
self.idx += n + 1;
Some(x)
}
Res::DoneStep(x) => {
// move to requested, but not past it
self.idx += n;
Some(x)
}
Res::EndStep => {
// ok, this one failed, move past it and request again
self.idx += 1;
let res = self.nth(0);
res
}
Res::End => None,
}
}
fn size_hint(&self) -> (usize, Option<usize>) {
/*
let min = if self.idx + self.back_idx >= SHADE_NUM {
0
} else {
SHADE_NUM - self.idx - self.back_idx
};
*/
let med1_size = self.med1_iter.as_ref().map_or(0, |t| {
t.len()
});
let med2_size = self.med2_iter.as_ref().map_or(0, |t| {
t.len()
});
let med3_size = self.med3_iter.as_ref().map_or(0, |t| {
t.len()
});
let t = SHADE_NUM
+ dbg!(med1_size) - self.med1_iter.as_ref().map_or(0, |_| 1)
+ dbg!(med2_size) - self.med2_iter.as_ref().map_or(0, |_| 1)
+ dbg!(med3_size) - self.med3_iter.as_ref().map_or(0, |_| 1)
- dbg!(self.idx)
- dbg!(self.back_idx);
(t, Some(t))
}
}
impl ShadeIter {
fn nth_back(&mut self, back_n: usize) -> Option<Shade> {
if self.back_idx + self.idx >= SHADE_NUM {
return None;
}
let res = match ShadeIter::get(
self,
SHADE_NUM - dbg!(self.back_idx) - back_n - 1,
false,
) {
Res::Done(x) => {
// move to requested, and past it
self.back_idx += 1;
Some(x)
}
Res::DoneStep(x) => {
// move to requested, but not past it
Some(x)
}
Res::EndStep => {
// ok, this one failed, try the next one
self.back_idx += 1;
self.nth_back(0)
}
Res::End => None,
};
res
}
}
impl DoubleEndedIterator for ShadeIter {
fn next_back(&mut self) -> Option<Self::Item> {
self.nth_back(0)
}
}
impl ExactSizeIterator for ShadeIter {
fn len(&self) -> usize {
self.size_hint().0
}
}
fn main() {
println!("Hello, world!");
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn flatten() {
let result = Shade::iter().collect::<Vec<_>>();
let expected = vec![
Shade::Light,
Shade::Med1(Vibe::Weak),
Shade::Med1(Vibe::Average),
Shade::Med1(Vibe::Strong),
Shade::Med2(Vibe::Weak),
Shade::Med2(Vibe::Average),
Shade::Med2(Vibe::Strong),
Shade::Med3(Vibe::Weak),
Shade::Med3(Vibe::Average),
Shade::Med3(Vibe::Strong),
Shade::Dark,
];
assert_eq!(result, expected);
}
#[test]
fn flatten_back() {
let result = Shade::iter().rev().collect::<Vec<_>>();
let expected = vec![
Shade::Dark,
Shade::Med3(Vibe::Strong),
Shade::Med3(Vibe::Average),
Shade::Med3(Vibe::Weak),
Shade::Med2(Vibe::Strong),
Shade::Med2(Vibe::Average),
Shade::Med2(Vibe::Weak),
Shade::Med1(Vibe::Strong),
Shade::Med1(Vibe::Average),
Shade::Med1(Vibe::Weak),
Shade::Light,
];
assert_eq!(result, expected);
}
#[test]
fn iter_mixed_next_and_next_back() {
let mut iter = Shade::iter();
assert_eq!(iter.next(), Some(Shade::Light));
assert_eq!(iter.next_back(), Some(Shade::Dark));
assert_eq!(iter.next(), Some(Shade::Med1(Vibe::Weak)));
assert_eq!(iter.next_back(), Some(Shade::Med3(Vibe::Strong)));
assert_eq!(iter.next(), Some(Shade::Med1(Vibe::Average)));
assert_eq!(iter.next_back(), Some(Shade::Med3(Vibe::Average)));
assert_eq!(iter.next(), Some(Shade::Med1(Vibe::Strong)));
assert_eq!(iter.next_back(), Some(Shade::Med3(Vibe::Weak)));
assert_eq!(iter.next(), Some(Shade::Med2(Vibe::Weak)));
assert_eq!(iter.next_back(), Some(Shade::Med2(Vibe::Strong)));
assert_eq!(iter.next(), Some(Shade::Med2(Vibe::Average)));
assert_eq!(iter.next_back(), None);
}
#[test]
fn iter_quickheck() {
use rand::Rng;
let mut rng = rand::rng();
for _ in 0..1000 {
let mut iter = Shade::iter();
let mut simple_iter = Shade::simple_iter();
let mut results = vec![];
let mut expected = vec![];
for _ in 0..500 {
if rng.random_bool(0.5) {
results.push(iter.next());
expected.push(simple_iter.next());
} else {
results.push(iter.next_back());
expected.push(simple_iter.next_back());
}
}
assert_eq!(results, expected);
}
}
#[test]
fn iter_quickheck_sizehint() {
use rand::Rng;
let mut rng = rand::rng();
for _ in 0..1000 {
let mut iter = Shade::iter();
let mut simple_iter = Shade::simple_iter();
assert_eq!(dbg!(iter.size_hint()), simple_iter.size_hint());
for _ in 0..500 {
if rng.random_bool(0.5) {
dbg!("next");
_ = iter.next();
_ = simple_iter.next();
assert_eq!(dbg!(iter.size_hint()), simple_iter.size_hint());
} else {
dbg!("next_back");
_ = iter.next_back();
_ = simple_iter.next_back();
assert_eq!(dbg!(iter.size_hint()), simple_iter.size_hint());
}
}
}
}
#[test]
fn iter_quickheck_len() {
use rand::Rng;
let mut rng = rand::rng();
for _ in 0..1000 {
let mut iter = Shade::iter();
const MAX: usize = 11;
assert_eq!(dbg!(iter.len()), MAX);
for i in 1..=MAX {
if rng.random_bool(0.5) {
dbg!("next");
_ = iter.next();
} else {
dbg!("next_back");
_ = iter.next_back();
}
assert_eq!(dbg!(iter.len()), MAX - i);
}
}
}
}Open to your comments 🙌 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hi, thx for this PR I also need this but didn't get the time to look into it much, you're saving me big time !
One concern I have is that the generated code isn't compatible with no_std anymore due to the vec![].
Quick testing shows that a simple array also does the trick
Updated example, nothing much changes, every vec![] is replaced by [] and
type IntoIter = std::vec::IntoIter<Vibe>; becomes type IntoIter = <[Self; 3] as core::iter::IntoIterator>::IntoIter; which could get tricky, maybe generate an associated constant containing the number of variants (<[Self; 4 + 3] as core::iter::IntoIterator>::IntoIter works) ?
#![no_std]
#[derive(Debug, Eq, PartialEq)]
enum Vibe {
Weak,
Average,
Strong,
}
impl Vibe {
fn iter() -> <Self as IntoIterator>::IntoIter {
let vibe = Vibe::Weak;
vibe.into_iter()
}
}
impl IntoIterator for Vibe {
type Item = Vibe;
type IntoIter = <[Self; 3] as core::iter::IntoIterator>::IntoIter;
fn into_iter(self) -> Self::IntoIter {
[Vibe::Weak, Vibe::Average, Vibe::Strong].into_iter()
}
}
const SHADE_NUM: usize = 5;
#[derive(Debug, Eq, PartialEq)]
enum Shade {
Light,
Med1(Vibe),
Med2(Vibe),
Med3(Vibe),
Dark,
}
impl Shade {
fn iter() -> ShadeIter {
ShadeIter {
idx: 0,
med1_iter: Some(Vibe::iter()),
med2_iter: Some(Vibe::iter()),
med3_iter: Some(Vibe::iter()),
back_idx: 0,
}
}
}
impl Shade {
fn simple_iter() -> impl DoubleEndedIterator<Item = Shade> {
[Shade::Light]
.into_iter()
.chain(Vibe::iter().map(Shade::Med1))
.chain(Vibe::iter().map(Shade::Med2))
.chain(Vibe::iter().map(Shade::Med3))
.chain([Shade::Dark])
}
}
struct ShadeIter {
idx: usize,
med1_iter: Option<<Vibe as IntoIterator>::IntoIter>,
med2_iter: Option<<Vibe as IntoIterator>::IntoIter>,
med3_iter: Option<<Vibe as IntoIterator>::IntoIter>,
back_idx: usize,
}
#[derive(Debug)]
enum Res {
Done(Shade),
DoneStep(Shade),
EndStep,
End,
}
impl ShadeIter {
fn nested_get(
nested_iter: &mut Option<<Vibe as IntoIterator>::IntoIter>,
wrap: fn(<Vibe as IntoIterator>::Item) -> Shade,
forward: bool,
) -> Res {
let next_inner = if forward {
nested_iter.as_mut().and_then(|t| t.next())
} else {
nested_iter.as_mut().and_then(|t| t.next_back())
};
if let Some(it) = next_inner {
Res::DoneStep(wrap(it))
} else {
nested_iter.take();
Res::EndStep
}
}
fn get(&mut self, idx: usize, forward: bool) -> Res {
match idx {
0 => Res::Done(Shade::Light),
1 => Self::nested_get(&mut self.med1_iter, Shade::Med1, forward),
2 => Self::nested_get(&mut self.med2_iter, Shade::Med2, forward),
3 => Self::nested_get(&mut self.med3_iter, Shade::Med3, forward),
4 => Res::Done(Shade::Dark),
_ => Res::End,
}
}
}
impl Iterator for ShadeIter {
type Item = Shade;
fn next(&mut self) -> Option<Self::Item> {
self.nth(0)
}
fn nth(&mut self, n: usize) -> Option<Self::Item> {
if self.back_idx + self.idx >= SHADE_NUM {
return None;
}
match ShadeIter::get(self, self.idx + n, true) {
Res::Done(x) => {
// move to requested, and past it
self.idx += n + 1;
Some(x)
}
Res::DoneStep(x) => {
// move to requested, but not past it
self.idx += n;
Some(x)
}
Res::EndStep => {
// ok, this one failed, move past it and request again
self.idx += 1;
let res = self.nth(0);
res
}
Res::End => None,
}
}
fn size_hint(&self) -> (usize, Option<usize>) {
/*
let min = if self.idx + self.back_idx >= SHADE_NUM {
0
} else {
SHADE_NUM - self.idx - self.back_idx
};
*/
let med1_size = self.med1_iter.as_ref().map_or(0, |t| {
t.len()
});
let med2_size = self.med2_iter.as_ref().map_or(0, |t| {
t.len()
});
let med3_size = self.med3_iter.as_ref().map_or(0, |t| {
t.len()
});
let t = SHADE_NUM
+ (med1_size) - self.med1_iter.as_ref().map_or(0, |_| 1)
+ (med2_size) - self.med2_iter.as_ref().map_or(0, |_| 1)
+ (med3_size) - self.med3_iter.as_ref().map_or(0, |_| 1)
- (self.idx)
- (self.back_idx);
(t, Some(t))
}
}
impl ShadeIter {
fn nth_back(&mut self, back_n: usize) -> Option<Shade> {
if self.back_idx + self.idx >= SHADE_NUM {
return None;
}
let res = match ShadeIter::get(
self,
SHADE_NUM - self.back_idx - back_n - 1,
false,
) {
Res::Done(x) => {
// move to requested, and past it
self.back_idx += 1;
Some(x)
}
Res::DoneStep(x) => {
// move to requested, but not past it
Some(x)
}
Res::EndStep => {
// ok, this one failed, try the next one
self.back_idx += 1;
self.nth_back(0)
}
Res::End => None,
};
res
}
}
impl DoubleEndedIterator for ShadeIter {
fn next_back(&mut self) -> Option<Self::Item> {
self.nth_back(0)
}
}
impl ExactSizeIterator for ShadeIter {
fn len(&self) -> usize {
self.size_hint().0
}
}
const fn main() {
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn flatten() {
let result = Shade::iter().collect::<Vec<_>>();
let expected = vec![
Shade::Light,
Shade::Med1(Vibe::Weak),
Shade::Med1(Vibe::Average),
Shade::Med1(Vibe::Strong),
Shade::Med2(Vibe::Weak),
Shade::Med2(Vibe::Average),
Shade::Med2(Vibe::Strong),
Shade::Med3(Vibe::Weak),
Shade::Med3(Vibe::Average),
Shade::Med3(Vibe::Strong),
Shade::Dark,
];
assert_eq!(result, expected);
}
#[test]
fn flatten_back() {
let result = Shade::iter().rev().collect::<Vec<_>>();
let expected = vec![
Shade::Dark,
Shade::Med3(Vibe::Strong),
Shade::Med3(Vibe::Average),
Shade::Med3(Vibe::Weak),
Shade::Med2(Vibe::Strong),
Shade::Med2(Vibe::Average),
Shade::Med2(Vibe::Weak),
Shade::Med1(Vibe::Strong),
Shade::Med1(Vibe::Average),
Shade::Med1(Vibe::Weak),
Shade::Light,
];
assert_eq!(result, expected);
}
#[test]
fn iter_mixed_next_and_next_back() {
let mut iter = Shade::iter();
assert_eq!(iter.next(), Some(Shade::Light));
assert_eq!(iter.next_back(), Some(Shade::Dark));
assert_eq!(iter.next(), Some(Shade::Med1(Vibe::Weak)));
assert_eq!(iter.next_back(), Some(Shade::Med3(Vibe::Strong)));
assert_eq!(iter.next(), Some(Shade::Med1(Vibe::Average)));
assert_eq!(iter.next_back(), Some(Shade::Med3(Vibe::Average)));
assert_eq!(iter.next(), Some(Shade::Med1(Vibe::Strong)));
assert_eq!(iter.next_back(), Some(Shade::Med3(Vibe::Weak)));
assert_eq!(iter.next(), Some(Shade::Med2(Vibe::Weak)));
assert_eq!(iter.next_back(), Some(Shade::Med2(Vibe::Strong)));
assert_eq!(iter.next(), Some(Shade::Med2(Vibe::Average)));
assert_eq!(iter.next_back(), None);
}
#[test]
fn iter_quickheck() {
use rand::Rng;
let mut rng = rand::rng();
for _ in 0..1000 {
let mut iter = Shade::iter();
let mut simple_iter = Shade::simple_iter();
let mut results = vec![];
let mut expected = vec![];
for _ in 0..500 {
if rng.random_bool(0.5) {
results.push(iter.next());
expected.push(simple_iter.next());
} else {
results.push(iter.next_back());
expected.push(simple_iter.next_back());
}
}
assert_eq!(results, expected);
}
}
#[test]
fn iter_quickheck_sizehint() {
use rand::Rng;
let mut rng = rand::rng();
for _ in 0..1000 {
let mut iter = Shade::iter();
let mut simple_iter = Shade::simple_iter();
assert_eq!(dbg!(iter.size_hint()), simple_iter.size_hint());
for _ in 0..500 {
if rng.random_bool(0.5) {
dbg!("next");
_ = iter.next();
_ = simple_iter.next();
assert_eq!(dbg!(iter.size_hint()), simple_iter.size_hint());
} else {
dbg!("next_back");
_ = iter.next_back();
_ = simple_iter.next_back();
assert_eq!(dbg!(iter.size_hint()), simple_iter.size_hint());
}
}
}
}
#[test]
fn iter_quickheck_len() {
use rand::Rng;
let mut rng = rand::rng();
for _ in 0..1000 {
let mut iter = Shade::iter();
const MAX: usize = 11;
assert_eq!(dbg!(iter.len()), MAX);
for i in 1..=MAX {
if rng.random_bool(0.5) {
dbg!("next");
_ = iter.next();
} else {
dbg!("next_back");
_ = iter.next_back();
}
assert_eq!(dbg!(iter.len()), MAX - i);
}
}
}
}|
@vic1707 Vibe::iter() was added because I needed a nested iterator, and yeah, I didn't care much about its implementation, because it wouldn't be present in "real" code. Shade::simple_iter() is there so that I have something to compare results to without writing too many tests, so it wouldn't be present in generated code as well. Thanks for noting that, though. I guess the drawback of |
|
Sorry I for that misunderstanding on my part, good job, can't wait to see it land if the devs are ok 👍 |
| custom_keyword!(default_with); | ||
| custom_keyword!(props); | ||
| custom_keyword!(ascii_case_insensitive); | ||
| custom_keyword!(flatten); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
honestly, the biggest concern I have here is how should #[strum(flatten)] interact with other derives
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
basically this thing
Fixes #424
As I said, it's possible if slightly complex. I'm not an expert in writing iterators, though, so maybe it's possible to cut some rough edges; I just tried to make it correct.
I tried to produce a slim diff, but DoubleEndedIterator implementation went into pieces.
Also, you can see in tests,
Color::simple_iter()gives much simpler implementation, but maybe a bit slower to run and/or compile? I didn't bench it.UPD: I think I know how to simplify this a little (without going through implementation I pointed out above), so if you're interested I'll try to refactor it a bit