@lidavidm
Member

This doesn't support variable-width types (e.g. strings) as the implementation here is columnwise. I will work on those separately (they require a rowwise implementation).

Also fixes a small bug in the CommonNumericType implementation (I noticed uint8 was getting promoted to int8).

@lidavidm lidavidm force-pushed the arrow-13064 branch 3 times, most recently from 3e7e6cb to a00643e Compare June 25, 2021 17:34
@lidavidm
Member Author

lidavidm commented Jun 25, 2021

Note there's some code here for fixed-width types that duplicates what's in ARROW-9430/#10412. They should probably get unified.

@jorisvandenbossche
Member

jorisvandenbossche commented Jun 29, 2021

To start my usual round of name bike-shedding ;), I would propose to not call this "select", to avoid the confusion with SQL's SELECT / dplyr's select() function (I think the "select" name only comes from np.select?)
Given this is for SQL's CASE, maybe something like "select_case" or "case_when" (or "choose_case", or ..)?

@lidavidm
Member Author

case_when sounds good to me. Or perhaps switch_case by analogy to if_else.

@pitrou
Member

pitrou commented Jun 29, 2021

Ok for case_when, another possibility is cond (like in Lisp ;-)).

@lidavidm
Member Author

Renamed to case_when.

@pitrou
Member

pitrou commented Jun 30, 2021

It looks like this needs rebasing.

@lidavidm
Member Author

Thanks for the heads up, rebased.

Member

@pitrou pitrou left a comment

Very nice, thank you. Do you want to add a couple benchmarks for this?

Member

Nit, but would be nice to normalize includes here. I think we normally use #include "arrow/..." for intra-Arrow inclusions.

Member

Instead use ValidateOutput from kernels/test_util.h. It will also check that no data is left uninitialized.

Member

I wonder if this should be moved to kernels/test_util.{h,cc} instead?

Member

Also worth noting: this seems to be very similar in function to CheckWithDifferentShapes. I think it'd be worthwhile to unify these two and promote it to test_util.h so it can be reused by more varargs scalar kernels

Member Author

It's also somewhat overlapping with CheckScalar now that I look at it. I'll take a look and see if I can't consolidate all three. (I may do so in a different PR until we decide what to do with the implementation here, if we want to split it up into separate 'case' and 'when' functions or not as suggested.)

Member

Also {scalar_false, values2, scalar_true, values1} -> values1?

Member

So a condition being null is the same as being false, right?

Member Author

Yes, null is the same as false.

Member

Given the definition of ReplaceTypes, this check doesn't seem necessary?
cc @bkietz for a second opinion.

Member

It seems to me we could restructure "case_when" and yield a less daunting interface:

case(
  when(cond_0, value_0),
  when(cond_1, value_1),
  value_else
)

Where when masks slots with null wherever its condition is not true and case is only a variadic coalescing function (takes the first non null).

We'd be allocating a new null bitmap on each call to when which is not ideal. However typing is far clearer (when(cond: Boolean, value: T): T, case(...values: T): T) and case/when can be independently unit tested.

@pitrou what do you think?

Member

Note: variadic coalesce is independently useful https://issues.apache.org/jira/browse/ARROW-13136

Member

Another way to transpose this would be:

case(
  first_true_in(cond_0, cond1),
  value0,
  value1,
  value_else
)

Where first_true_in returns the index of the first true and case's first argument is an index into the rest of its arguments:

assert first_true_in([True, False, False], [True, True, False]) == [0, 1, None]
assert case([0, 1, None], 'a b c'.split(), 'd e f'.split(), 'x y z'.split()) == 'a e z'.split()
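
A minimal pure-Python sketch of the two helpers used in the asserts above, with plain lists and None standing in for null. The function bodies are illustrative assumptions, not Arrow kernels; only the names first_true_in and case and the two asserts come from the comment.

```python
from typing import Any, List, Optional, Sequence


def first_true_in(*conds: Sequence[Optional[bool]]) -> List[Optional[int]]:
    """Per row, return the index of the first condition that is True, else None."""
    out: List[Optional[int]] = []
    for row in zip(*conds):
        # A null (None) condition is treated the same as False.
        out.append(next((i for i, c in enumerate(row) if c is True), None))
    return out


def case(indices: Sequence[Optional[int]], *values: Sequence[Any]) -> List[Any]:
    """Row i takes values[indices[i]][i]; a null index falls back to the last argument."""
    *branches, else_values = values
    return [
        else_values[i] if idx is None else branches[idx][i]
        for i, idx in enumerate(indices)
    ]


# The two asserts from the comment above.
assert first_true_in([True, False, False], [True, True, False]) == [0, 1, None]
assert case([0, 1, None], 'a b c'.split(), 'd e f'.split(), 'x y z'.split()) == 'a e z'.split()
```

Note that this model works row by row, which is the property the replies below raise when comparing it with a columnwise implementation.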

Member Author

In the second example, case would essentially be 'choose' (ARROW-13220). I'm curious which would be faster; the coalesce PR is implemented columnwise (for fixed-width types) and tries to do block copies, while the 'choose' PR I'm looking at is necessarily rowwise. Probably depends on the distribution of the inputs.

Might the masking 'when' kernel and 'first_true_in' be useful kernels to have anyways?

Member

SGTM, thanks!

Member

I think that

case(
  when(cond_0, value_0),
  when(cond_1, value_1),
  value_else
)

is not necessarily identical to case_when(cond_0, value_0, cond_1, value_1, value_else) as done in this PR, when it comes to nulls in the values?

For example, if you have a null in value_0, it should be chosen if the corresponding bit of cond_0 is True. However, if you implement that with a masking when and then coalesce (taking the first non null), it would not pick the null value.
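
The difference shows up in a small pure-Python model (lists, with None as null). The helpers when, coalesce and case_when below are hypothetical reference implementations written for illustration; they are not the kernels from this PR.

```python
def when(cond, value):
    """Mask value with null wherever cond is not True."""
    return [v if c is True else None for c, v in zip(cond, value)]


def coalesce(*columns):
    """Per row, take the first non-null entry (or null if there is none)."""
    return [next((v for v in row if v is not None), None) for row in zip(*columns)]


def case_when(conds, *values):
    """Per row, take the value paired with the first True condition.

    values holds one column per condition, optionally followed by an 'else' column.
    """
    has_else = len(values) > len(conds)
    out = []
    for i in range(len(conds[0])):
        chosen = values[len(conds)][i] if has_else else None
        for cond, value in zip(conds, values):
            if cond[i] is True:
                chosen = value[i]  # may itself be null, and that null is kept
                break
        out.append(chosen)
    return out


cond_0 = [True, False]
value_0 = [None, 10]      # null in the slot selected by cond_0
value_else = [99, 99]

direct = case_when([cond_0], value_0, value_else)
composed = coalesce(when(cond_0, value_0), value_else)
assert direct == [None, 99]    # the selected null survives
assert composed == [99, 99]    # coalesce skips past it to the else value
```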

Member Author

Hmm, you're right…

We can still try the 'choose' based version though I would expect that to be slower since it works rowwise.

Member Author

I put up PRs for both choose: #10642 and coalesce: #10608 (though as you point out, we can't quite use coalesce for this).

Member Author

A quick check with choose shows that it's only about 2/3 as fast:

CaseWhenBench32/1048576/0              48183534 ns     48183323 ns           13 bytes_per_second=334.659M/s
CaseWhenBench64/1048576/0              48610718 ns     48610432 ns           13 bytes_per_second=660.866M/s
CaseWhenBench32/1048576/99             48226819 ns     48226677 ns           15 bytes_per_second=334.327M/s
CaseWhenBench64/1048576/99             48891957 ns     48892172 ns           13 bytes_per_second=656.996M/s

though note we have to call both first_true_in, then fill_null, to get the 'else' behavior which doesn't help.

Member

Somewhere, you should also test what happens when case_when is called with zero arguments.

Member

Instead of uint8_t* out_values, you may want this to take an ArrayData* out, since you'll need it for non-fixed-width types?

Member Author

For non-fixed-width types, it might be simpler to handle it with a builder and AppendScalar?

Member

This seems a bit excessive?

Member Author

This seems to be faster as written, oddly.

Before:

-----------------------------------------------------------------------------------------------
Benchmark                                     Time             CPU   Iterations UserCounters...
-----------------------------------------------------------------------------------------------
CaseWhenBench32/1048576/0              31933112 ns     31932368 ns           22 bytes_per_second=504.974M/s
CaseWhenBench64/1048576/0              33170481 ns     33168735 ns           21 bytes_per_second=968.533M/s
CaseWhenBench32/1048576/99             32487300 ns     32487411 ns           21 bytes_per_second=496.299M/s
CaseWhenBench64/1048576/99             33682029 ns     33680901 ns           21 bytes_per_second=953.715M/s
CaseWhenBench32Contiguous/1048576/0     7255445 ns      7255387 ns           96 bytes_per_second=1.632G/s
CaseWhenBench64Contiguous/1048576/0     7932437 ns      7932171 ns           88 bytes_per_second=2.97013G/s
CaseWhenBench32Contiguous/1048576/99    7526742 ns      7526709 ns           92 bytes_per_second=1.57303G/s
CaseWhenBench64Contiguous/1048576/99    8172498 ns      8172239 ns           83 bytes_per_second=2.88261G/s

After:

-----------------------------------------------------------------------------------------------
Benchmark                                     Time             CPU   Iterations UserCounters...
-----------------------------------------------------------------------------------------------
CaseWhenBench32/1048576/0              44166172 ns     44165634 ns           16 bytes_per_second=365.103M/s
CaseWhenBench64/1048576/0              44605356 ns     44603995 ns           16 bytes_per_second=720.227M/s
CaseWhenBench32/1048576/99             44867670 ns     44867051 ns           16 bytes_per_second=359.361M/s
CaseWhenBench64/1048576/99             45077818 ns     45076721 ns           15 bytes_per_second=712.607M/s
CaseWhenBench32Contiguous/1048576/0    17757494 ns     17757271 ns           39 bytes_per_second=682.819M/s
CaseWhenBench64Contiguous/1048576/0    18236327 ns     18235892 ns           38 bytes_per_second=1.29193G/s
CaseWhenBench32Contiguous/1048576/99   15051281 ns     15051008 ns           39 bytes_per_second=805.518M/s
CaseWhenBench64Contiguous/1048576/99   15504081 ns     15503998 ns           45 bytes_per_second=1.51944G/s

Member

You can use GetNullCount() I think?

Member

This doesn't seem significantly different from CheckDispatchBest

Member

I agree; there's too much indirection here to see what's actually being tested. Please inline some of this.

Member

std::copy has to tolerate overlapping ranges when the destination starts before the source (unlike memcpy), so for simple pointers like this it gets inlined to memmove. Memmove can be slower (or faster) than memcpy, depending on your libc. It shouldn't differ by a very wide margin though, so I don't think it's necessary to avoid std::copy.

@lidavidm lidavidm marked this pull request as draft July 1, 2021 17:14
@lidavidm lidavidm marked this pull request as ready for review July 2, 2021 16:49
@lidavidm
Member Author

lidavidm commented Jul 7, 2021

I've changed this so the signature is CaseWhen<T>(struct(bool...), T...) instead; if this looks more reasonable, we can proceed, else I'll revert it.

@jorisvandenbossche
Member

What's the advantage of using a struct of boolean arrays as first argument instead of a vector? (eg for a struct you need to give it (dummy) names)

@lidavidm
Member Author

lidavidm commented Jul 7, 2021

@bkietz and I discussed this. Mostly it simplifies the function signature a lot and gets rid of the custom typechecking code in DispatchBest. It does require materializing scalar conditions and would need callers to do some pre-processing if they wish to avoid that (e.g. dropping any values associated with a scalar false/null condition, using values associated with a scalar true condition as the 'else' clause).

It's mostly a point of discussion. I realized I forgot to benchmark so let me do that now.
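
The caller-side pre-processing mentioned above might look roughly like the following. This is a sketch under assumptions: simplify_branches is a hypothetical helper, scalar conditions are modelled as Python True/False/None and array conditions as lists, and dropping the branches that follow a scalar true is an inference from first-match semantics rather than something stated in this thread.

```python
def simplify_branches(branches, else_value=None):
    """Remove scalar conditions before building the struct of boolean arrays.

    branches: list of (cond, value) pairs, where cond is either a scalar
              (True, False, None) or a list of booleans (an array condition).
    Returns (kept_branches, new_else_value).
    """
    kept = []
    for cond, value in branches:
        if cond is None or cond is False:
            # Scalar false/null: this branch can never be selected.
            continue
        if cond is True:
            # Scalar true: always matches when reached, so its value acts as
            # the 'else' clause and later branches are unreachable.
            return kept, value
        kept.append((cond, value))
    return kept, else_value


kept, new_else = simplify_branches(
    [(False, 'a'), ([True, False], 'b'), (True, 'c'), ([False, True], 'd')],
    else_value='e',
)
assert kept == [([True, False], 'b')]
assert new_else == 'c'
```

Only the array conditions that remain in kept would then need to be assembled into the struct argument.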

@lidavidm
Member Author

lidavidm commented Jul 7, 2021

This is quite a bit faster, though (actually almost suspiciously so; also, the benchmark loop doesn't include constructing a StructArray from the individual boolean arrays).

Struct:

CaseWhenBench32/1048576/0               5780694 ns      5780134 ns          116 bytes_per_second=2.72434G/s
CaseWhenBench64/1048576/0               6079444 ns      6078645 ns          115 bytes_per_second=5.16103G/s
CaseWhenBench32/1048576/99              5717893 ns      5717306 ns          122 bytes_per_second=2.75402G/s
CaseWhenBench64/1048576/99              6160254 ns      6159488 ns          110 bytes_per_second=5.09281G/s
CaseWhenBench32Contiguous/1048576/0     1652127 ns      1645047 ns          423 bytes_per_second=7.19786G/s
CaseWhenBench64Contiguous/1048576/0     1986278 ns      1986066 ns          334 bytes_per_second=11.8624G/s
CaseWhenBench32Contiguous/1048576/99    1826875 ns      1826675 ns          377 bytes_per_second=6.48156G/s
CaseWhenBench64Contiguous/1048576/99    2415119 ns      2414697 ns          312 bytes_per_second=9.75582G/s

Alternating arguments: (EDIT: updated since I forgot to rebuild first)

CaseWhenBench32/1048576/0              32471264 ns     32470765 ns           22 bytes_per_second=496.601M/s
CaseWhenBench64/1048576/0              33443346 ns     33443152 ns           21 bytes_per_second=960.585M/s
CaseWhenBench32/1048576/99             33171599 ns     33171474 ns           21 bytes_per_second=486.065M/s
CaseWhenBench64/1048576/99             33970437 ns     33969940 ns           21 bytes_per_second=945.6M/s
CaseWhenBench32Contiguous/1048576/0     7543780 ns      7543809 ns           92 bytes_per_second=1.56961G/s
CaseWhenBench64Contiguous/1048576/0     7982627 ns      7982575 ns           86 bytes_per_second=2.95137G/s
CaseWhenBench32Contiguous/1048576/99    7894886 ns      7894844 ns           89 bytes_per_second=1.49968G/s
CaseWhenBench64Contiguous/1048576/99    8419204 ns      8419033 ns           82 bytes_per_second=2.79811G/s

@lidavidm
Member Author

Any other thoughts here? If we're ok with CaseWhen(struct(bool...), T...) over CaseWhen(bool, T, bool, T, ...) I can make this kernel also support writing into slices.

@bkietz
Member

bkietz commented Jul 13, 2021

I'm fine with CaseWhen(struct(bool...), T...), and I'd love for this to support writing into slices. Could you also file a follow up JIRA for simplification of call expressions involving case_when?

@lidavidm
Member Author

Filed ARROW-13325.

@lidavidm
Member Author

Updated to support can_write_into_slices. Also, updated the benchmark bytes_per_second scaling to reflect the size of the output and not the size of all the inputs so the numbers are more reasonable.

-----------------------------------------------------------------------------------------------
Benchmark                                     Time             CPU   Iterations UserCounters...
-----------------------------------------------------------------------------------------------
CaseWhenBench32/1048576/0               6065579 ns      6065602 ns           97 bytes_per_second=659.456M/s
CaseWhenBench64/1048576/0               6485780 ns      6485814 ns           93 bytes_per_second=1.20455G/s
CaseWhenBench32/1048576/99              6280547 ns      6280511 ns          114 bytes_per_second=636.831M/s
CaseWhenBench64/1048576/99              6694807 ns      6694710 ns          105 bytes_per_second=1.16686G/s
CaseWhenBench32Contiguous/1048576/0     1761802 ns      1761781 ns          378 bytes_per_second=2.21722G/s
CaseWhenBench64Contiguous/1048576/0     2267441 ns      2267379 ns          280 bytes_per_second=3.44561G/s
CaseWhenBench32Contiguous/1048576/99    1897469 ns      1897449 ns          325 bytes_per_second=2.05849G/s
CaseWhenBench64Contiguous/1048576/99    2515244 ns      2515231 ns          275 bytes_per_second=3.10578G/s
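
As a sanity check on the rescaled counter, the first two rows above can be reproduced from the output size alone. The sketch assumes the rate is output bytes divided by CPU time and that the M/G suffixes are binary (1024-based); both assumptions are consistent with the reported numbers but are not stated in the comment. The helper name is hypothetical.

```python
def output_bytes_per_second(n_rows: int, width_bytes: int, cpu_ns: int) -> float:
    """Throughput counted over the output buffer only (n_rows * width_bytes)."""
    return n_rows * width_bytes / (cpu_ns * 1e-9)


# CaseWhenBench32/1048576/0: 4-byte values, CPU time 6065602 ns
mib_per_s_32 = output_bytes_per_second(1048576, 4, 6065602) / 2**20
# CaseWhenBench64/1048576/0: 8-byte values, CPU time 6485814 ns
gib_per_s_64 = output_bytes_per_second(1048576, 8, 6485814) / 2**30

print(f"{mib_per_s_32:.3f} MiB/s")  # compare with bytes_per_second=659.456M/s
print(f"{gib_per_s_64:.5f} GiB/s")  # compare with bytes_per_second=1.20455G/s
```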

Member

@bkietz bkietz left a comment

Some minor comments, otherwise this looks good to go

Member

This (and some of the other comments) need to be updated with the new signature

Comment on lines 150 to 152
Member

Suggested change
- auto null_scalar = MakeNullScalar(boolean());
- ASSERT_OK_AND_ASSIGN(auto nulls,
-                      MakeArrayFromScalar(*null_scalar, len - 2 * (len / 3)));
+ ASSERT_OK_AND_ASSIGN(auto nulls, MakeArrayOfNull(boolean(), len - 2 * (len / 3)));

Comment on lines +111 to +118
Member

Suggested change
- auto cond1 = std::static_pointer_cast<BooleanArray>(
-     rand.ArrayOf(boolean(), len, /*null_probability=*/0.01));
- auto cond2 = std::static_pointer_cast<BooleanArray>(
-     rand.ArrayOf(boolean(), len, /*null_probability=*/0.01));
- auto cond3 = std::static_pointer_cast<BooleanArray>(
-     rand.ArrayOf(boolean(), len, /*null_probability=*/0.01));
+ auto fld = field("cond", boolean(), key_value_metadata({{"null_probability", "1.0"}}));
+ auto cond = rand.ArrayOf(field("", struct_({fld, fld, fld})), len);

?

@lidavidm
Member Author

Fixed docstrings and fixed a TODO I noticed about fully initializing the output buffer.

Also added a benchmark case for when both the cond struct array and the child cond boolean arrays can have nulls. This case is especially terrible. I made a slight optimization to eliminate one of the more egregious offenders I saw in perf, but it's still bad even then:

-----------------------------------------------------------------------------------------------
Benchmark                                     Time             CPU   Iterations UserCounters...
-----------------------------------------------------------------------------------------------
CaseWhenBench64/1048576/0               6045379 ns      6045204 ns          125 bytes_per_second=1.29235G/s
CaseWhenBench64/1048576/99              6106510 ns      6106346 ns          123 bytes_per_second=1.27929G/s
CaseWhenBench64OuterNulls/1048576/0    32735686 ns     32733991 ns           23 bytes_per_second=244.394M/s
CaseWhenBench64OuterNulls/1048576/99   34098317 ns     34097473 ns           21 bytes_per_second=234.599M/s
CaseWhenBench64Contiguous/1048576/0     2229728 ns      2229687 ns          338 bytes_per_second=3.50386G/s
CaseWhenBench64Contiguous/1048576/99    2456084 ns      2456085 ns          237 bytes_per_second=3.18058G/s

@bkietz
Member

bkietz commented Jul 14, 2021

@lidavidm what would you think about just raising an error for top level nulls? It doesn't seem like a useful case to me

@lidavidm
Member Author

Should be fine. It would certainly trim down the inner loop a lot.

@lidavidm
Member Author

Removed support for toplevel nulls.

@lidavidm
Member Author

@bkietz anything left either here or in #10608?

@jorisvandenbossche
Member

> Any other thoughts here? If we're ok with CaseWhen(struct(bool...), T...) over CaseWhen(bool, T, bool, T, ...)

As mentioned in the meeting, I don't have a strong opinion on the exact signature (assuming you can create the inputs easily in the bindings either way), and if it simplifies the implementation / signature quite a bit, that sounds like a good reason.

I quickly tried it out in Python:

In [7]: cond = pc.project(pa.array([True, False, None]), pa.array([False, True, None]), field_names=[b"a", b"b"])

In [8]: pc.case_when(cond, pa.array([1, 2, 3]), pa.array([11, 12, 13]))
Out[8]: 
<pyarrow.lib.Int64Array object at 0x7f104124c820>
[
  1,
  12,
  null
]

What I find a little bit annoying is that I have to provide "dummy" names (that are never used) to my boolean conditions in project (maybe we could have some default for this so for this specific case specifying names is not required?)
(I also noticed that the field names needed to be bytes and not strings, I suppose that's an error in the cython bindings, will open an issue about that).
One other issue is that this uses the "project" kernel, which I thought was not really meant for users? (ARROW-11206) I could of course also have used pa.StructArray.from_arrays I assume.

Member

@bkietz bkietz left a comment

LGTM.

@jorisvandenbossche I'll make a PR for ARROW-11206 shortly and allow make_struct to be called with no names provided (defaulting to empty names or str(field_index) or so). That should enable you to write

pc.case_when(
    pc.make_struct([True, False, None],
                   [False, True, None]),
    [1, 2, 3],
    [11, 12, 13])

ASSERT_FALSE(sig.MatchesInputs(args));
}
{
KernelSignature sig({int8(), utf8()}, utf8(), /*is_varargs=*/true);
Member

Thanks for working this out

@bkietz bkietz closed this in dbeed52 Jul 15, 2021
@jorisvandenbossche
Member

@bkietz sounds good!
