Skip to content

Commit c993884

Browse files
committed
allow json types
1 parent 30d5198 commit c993884

File tree

4 files changed

+81
-54
lines changed

4 files changed

+81
-54
lines changed

src/groups.jl

Lines changed: 28 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,20 @@
22
# BenchmarkGroup #
33
##################
44

5+
const KeyTypes = Union{String,Int,Float64}
6+
makekey(v::KeyTypes) = v
7+
makekey(v::Real) = (v2 = Float64(v); v2 == v ? v2 : string(v))
8+
makekey(v::Integer) = typemin(Int) <= v <= typemax(Int) ? Int(v) : string(v)
9+
makekey(v::Tuple) = (Any[i isa Tuple ? string(i) : makekey(i) for i in v]...,)::Tuple{Vararg{KeyTypes}}
10+
makekey(v::Any) = string(v)::String
11+
512
struct BenchmarkGroup
6-
tags::Vector{String}
7-
data::Dict{String,Any}
13+
tags::Vector{Any}
14+
data::Dict{Any,Any}
815
end
916

10-
BenchmarkGroup(tags::Vector{String}, args::Pair{String}...) = BenchmarkGroup(tags, Dict{String,Any}(args))
11-
BenchmarkGroup(tags::Vector, args::Pair...) = BenchmarkGroup(tags, Dict{String,Any}((string(k) => v for (k, v) in args)))
12-
BenchmarkGroup(args::Pair...) = BenchmarkGroup(String[], args...)
17+
BenchmarkGroup(tags::Vector, args::Pair...) = BenchmarkGroup(tags, Dict{Any,Any}((makekey(k) => v for (k, v) in args)))
18+
BenchmarkGroup(args::Pair...) = BenchmarkGroup([], args...)
1319

1420
function addgroup!(suite::BenchmarkGroup, id, args...)
1521
g = BenchmarkGroup(args...)
@@ -25,18 +31,14 @@ Base.copy(group::BenchmarkGroup) = BenchmarkGroup(copy(group.tags), copy(group.d
2531
Base.similar(group::BenchmarkGroup) = BenchmarkGroup(copy(group.tags), empty(group.data))
2632
Base.isempty(group::BenchmarkGroup) = isempty(group.data)
2733
Base.length(group::BenchmarkGroup) = length(group.data)
28-
Base.getindex(group::BenchmarkGroup, k::String) = getindex(group.data, k)
29-
Base.getindex(group::BenchmarkGroup, k) = getindex(group.data, string(k))
30-
Base.getindex(group::BenchmarkGroup, k...) = getindex(group.data, string(k))
31-
Base.setindex!(group::BenchmarkGroup, v, k::String) = setindex!(group.data, v, k)
32-
Base.setindex!(group::BenchmarkGroup, v, k) = setindex!(group.data, v, string(k))
33-
Base.setindex!(group::BenchmarkGroup, v, k...) = setindex!(group.data, v, string(k))
34-
Base.delete!(group::BenchmarkGroup, k::String) = delete!(group.data, k)
35-
Base.delete!(group::BenchmarkGroup, k) = delete!(group.data, string(k))
36-
Base.delete!(group::BenchmarkGroup, k...) = delete!(group.data, string(k))
37-
Base.haskey(group::BenchmarkGroup, k::String) = haskey(group.data, k)
38-
Base.haskey(group::BenchmarkGroup, k) = haskey(group.data, string(k))
39-
Base.haskey(group::BenchmarkGroup, k...) = haskey(group.data, string(k))
34+
Base.getindex(group::BenchmarkGroup, k) = getindex(group.data, makekey(k))
35+
Base.getindex(group::BenchmarkGroup, k...) = getindex(group.data, makekey(k))
36+
Base.setindex!(group::BenchmarkGroup, v, k) = setindex!(group.data, v, makekey(k))
37+
Base.setindex!(group::BenchmarkGroup, v, k...) = setindex!(group.data, v, makekey(k))
38+
Base.delete!(group::BenchmarkGroup, k) = delete!(group.data, makekey(k))
39+
Base.delete!(group::BenchmarkGroup, k...) = delete!(group.data, makekey(k))
40+
Base.haskey(group::BenchmarkGroup, k) = haskey(group.data, makekey(k))
41+
Base.haskey(group::BenchmarkGroup, k...) = haskey(group.data, makekey(k))
4042
Base.keys(group::BenchmarkGroup) = keys(group.data)
4143
Base.values(group::BenchmarkGroup) = values(group.data)
4244
Base.iterate(group::BenchmarkGroup, i=1) = iterate(group.data, i)
@@ -128,11 +130,11 @@ end
128130
# leaf iteration/indexing #
129131
#-------------------------#
130132

131-
leaves(group::BenchmarkGroup) = leaves!(Any[], String[], group)
133+
leaves(group::BenchmarkGroup) = leaves!([], [], group)
132134

133135
function leaves!(results, parents, group::BenchmarkGroup)
134136
for (k, v) in group
135-
keys = vcat(parents, k)
137+
keys = Base.typed_vcat(Any, parents, k)
136138
if isa(v, BenchmarkGroup)
137139
leaves!(results, keys, v)
138140
else
@@ -172,16 +174,16 @@ macro tagged(expr)
172174
return :(BenchmarkTools.TagFilter(tags -> $(tagpredicate!(expr))))
173175
end
174176

175-
tagpredicate!(@nospecialize tag) = :(in(string($(esc(tag))), tags))
177+
tagpredicate!(@nospecialize tag) = :(in(makekey($(esc(tag))), tags))
176178

177179
function tagpredicate!(sym::Symbol)
178180
sym == :ALL && return true
179-
return :(in(string($(esc(sym))), tags))
181+
return :(in(makekey($(esc(sym))), tags))
180182
end
181183

182184
# build the body of the tag predicate in place
183185
function tagpredicate!(expr::Expr)
184-
expr.head == :quote && return :(in(string($(esc(expr))), tags))
186+
expr.head == :quote && return :(in(makekey($(esc(expr))), tags))
185187
for i in 1:length(expr.args)
186188
f = (i == 1 && expr.head === :call ? esc : tagpredicate!)
187189
expr.args[i] = f(expr.args[i])
@@ -191,17 +193,17 @@ end
191193

192194
function Base.getindex(src::BenchmarkGroup, f::TagFilter)
193195
dest = similar(src)
194-
loadtagged!(f, dest, src, src, String[], src.tags)
196+
loadtagged!(f, dest, src, src, [], src.tags)
195197
return dest
196198
end
197199

198200
# normal union doesn't have the behavior we want
199201
# (e.g. union(["1"], "2") === ["1", '2'])
200-
keyunion(args...) = unique(vcat(args...))
202+
keyunion(args...) = unique(Base.typed_vcat(Any, args...))
201203

202204
function tagunion(args...)
203205
unflattened = keyunion(args...)
204-
result = String[]
206+
result = []
205207
for i in unflattened
206208
if isa(i, Tuple)
207209
for j in i
@@ -215,7 +217,7 @@ function tagunion(args...)
215217
end
216218

217219
function loadtagged!(f::TagFilter, dest::BenchmarkGroup, src::BenchmarkGroup,
218-
group::BenchmarkGroup, keys::Vector{String}, tags::Vector{String})
220+
group::BenchmarkGroup, keys::Vector, tags::Vector)
219221
if f.predicate(tags)
220222
child_dest = createchild!(dest, src, keys)
221223
for (k, v) in group

src/serialization.jl

Lines changed: 42 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -2,44 +2,66 @@ const VERSIONS = Dict("Julia" => string(VERSION),
22
"BenchmarkTools" => string(BENCHMARKTOOLS_VERSION))
33

44
# TODO: Add any new types as they're added
5-
const SUPPORTED_TYPES = [Benchmark, BenchmarkGroup, Parameters, TagFilter, Trial,
6-
TrialEstimate, TrialJudgement, TrialRatio]
5+
const SUPPORTED_TYPES = Dict{Symbol,Type}(Base.typename(x).name => x for x in [
6+
BenchmarkGroup, Parameters, TagFilter, Trial,
7+
TrialEstimate, TrialJudgement, TrialRatio])
8+
# n.b. Benchmark type not included here, since it is gensym'd
79

8-
for T in SUPPORTED_TYPES
9-
@eval function JSON.lower(x::$T)
10-
d = Dict{String,Any}()
11-
for i = 1:nfields(x)
12-
name = String(fieldname($T, i))
13-
field = getfield(x, i)
14-
value = typeof(field) in SUPPORTED_TYPES ? JSON.lower(field) : field
15-
push!(d, name => value)
16-
end
17-
[string(typeof(x)), d]
10+
function JSON.lower(x::Union{values(SUPPORTED_TYPES)...})
11+
d = Dict{String,Any}()
12+
T = typeof(x)
13+
for i = 1:nfields(x)
14+
name = String(fieldname(T, i))
15+
field = getfield(x, i)
16+
ft = typeof(field)
17+
value = ft <: get(SUPPORTED_TYPES, ft.name.name, Union{}) ? JSON.lower(field) : field
18+
d[name] = value
1819
end
20+
[string(typeof(x).name.name), d]
1921
end
2022

23+
# a minimal 'eval' function, mirroring KeyTypes, but being slightly more lenient
24+
safeeval(@nospecialize x) = x
25+
safeeval(x::QuoteNode) = x.value
26+
function safeeval(x::Expr)
27+
x.head === :quote && return x.args[1]
28+
x.head === :inert && return x.args[1]
29+
x.head === :tuple && return ((safeeval(a) for a in x.args)...,)
30+
x
31+
end
2132
function recover(x::Vector)
2233
length(x) == 2 || throw(ArgumentError("Expecting a vector of length 2"))
2334
typename = x[1]::String
2435
fields = x[2]::Dict
25-
T = Core.eval(@__MODULE__, Meta.parse(typename))::Type
36+
startswith(typename, "BenchmarkTools.") && (typename = typename[sizeof("BenchmarkTools.")+1:end])
37+
T = SUPPORTED_TYPES[Symbol(typename)]
2638
fc = fieldcount(T)
2739
xs = Vector{Any}(undef, fc)
2840
for i = 1:fc
2941
ft = fieldtype(T, i)
3042
fn = String(fieldname(T, i))
31-
xs[i] = if ft in SUPPORTED_TYPES
32-
recover(fields[fn])
43+
if ft <: get(SUPPORTED_TYPES, ft.name.name, Union{})
44+
xsi = recover(fields[fn])
3345
else
34-
convert(ft, fields[fn])
46+
xsi = convert(ft, fields[fn])
3547
end
36-
if T == BenchmarkGroup && xs[i] isa Dict
37-
for (k, v) in xs[i]
48+
if T == BenchmarkGroup && xsi isa Dict
49+
for (k, v) in copy(xsi)
50+
k = k::String
51+
if startswith(k, "(") || startswith(k, ":")
52+
kt = Meta.parse(k, raise=false)
53+
if !(kt isa Expr && kt.head === :error)
54+
delete!(xsi, k)
55+
k = safeeval(kt)
56+
xsi[k] = v
57+
end
58+
end
3859
if v isa Vector && length(v) == 2 && v[1] isa String
39-
xs[i][k] = recover(v)
60+
xsi[k] = recover(v)
4061
end
4162
end
4263
end
64+
xs[i] = xsi
4365
end
4466
T(xs...)
4567
end
@@ -73,7 +95,7 @@ function save(io::IO, args...)
7395
"The name will be ignored and the object will be serialized " *
7496
"in the order it appears in the input.")
7597
continue
76-
elseif !any(T->arg isa T, SUPPORTED_TYPES)
98+
elseif !(arg isa get(SUPPORTED_TYPES, typeof(arg).name.name, Union{}))
7799
throw(ArgumentError("Only BenchmarkTools types can be serialized."))
78100
end
79101
push!(goodargs, arg)

test/GroupsTests.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -216,9 +216,12 @@ gnest = BenchmarkGroup(["1"],
216216
10 => BenchmarkGroup(["3"]),
217217
11 => BenchmarkGroup()))
218218

219-
@test sort!(leaves(gnest)) == Any[(["2","1"],1), (["4","5"],6), (["7"],8), (["a","(11, \"b\")"],:b), (Any["a","a"],:a)]
219+
@test sort(leaves(gnest), by=string) ==
220+
Any[(Any["2",1],1), (Any["a","a"],:a), (Any["a",(11,"b")],:b), (Any[4,5],6), (Any[7],8)]
220221

221222
@test gnest[@tagged 11 || 10] == BenchmarkGroup(["1"],
223+
"a" => BenchmarkGroup(["3"],
224+
(11, "b") => :b),
222225
9 => gnest[9])
223226

224227
@test gnest[@tagged "3"] == BenchmarkGroup(["1"], "2" => gnest["2"], 4 => gnest[4], "a" => gnest["a"],

test/SerializationTests.jl

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ module SerializationTests
33
using BenchmarkTools
44
using Test
55

6-
eq(x::T, y::T) where {T<:Union{BenchmarkTools.SUPPORTED_TYPES...}} =
6+
eq(x::T, y::T) where {T<:Union{values(BenchmarkTools.SUPPORTED_TYPES)...}} =
77
all(i->eq(getfield(x, i), getfield(y, i)), 1:fieldcount(T))
88
eq(x::T, y::T) where {T} = isapprox(x, y)
99

@@ -25,13 +25,13 @@ end
2525
withtempdir() do
2626
tmp = joinpath(pwd(), "tmp.json")
2727

28-
BenchmarkTools.save(tmp, b, bb)
28+
BenchmarkTools.save(tmp, b.params, bb)
2929
@test isfile(tmp)
3030

3131
results = BenchmarkTools.load(tmp)
3232
@test results isa Vector{Any}
3333
@test length(results) == 2
34-
@test eq(results[1], b)
34+
@test eq(results[1], b.params)
3535
@test eq(results[2], bb)
3636
end
3737

@@ -56,18 +56,18 @@ end
5656
tune!(b)
5757
bb = run(b)
5858

59-
@test_throws ArgumentError BenchmarkTools.save("x.jld", b)
60-
@test_throws ArgumentError BenchmarkTools.save("x.txt", b)
59+
@test_throws ArgumentError BenchmarkTools.save("x.jld", b.params)
60+
@test_throws ArgumentError BenchmarkTools.save("x.txt", b.params)
6161
@test_throws ArgumentError BenchmarkTools.save("x.json")
6262
@test_throws ArgumentError BenchmarkTools.save("x.json", 1)
6363

6464
withtempdir() do
6565
tmp = joinpath(pwd(), "tmp.json")
66-
@test_logs (:warn, r"Naming variables") BenchmarkTools.save(tmp, "b", b)
66+
@test_logs (:warn, r"Naming variables") BenchmarkTools.save(tmp, "b", b.params)
6767
@test isfile(tmp)
6868
results = BenchmarkTools.load(tmp)
6969
@test length(results) == 1
70-
@test eq(results[1], b)
70+
@test eq(results[1], b.params)
7171
end
7272

7373
@test_throws ArgumentError BenchmarkTools.load("x.jld")

0 commit comments

Comments
 (0)