Skip to content

Commit 85b03fe

Browse files
committed
Big blockmerge overhaul started
1 parent e748654 commit 85b03fe

17 files changed

+471
-218
lines changed

devscripts/Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
ArchGDAL = "c9ce4bd3-c3d5-55b8-8973-c0e20141b8c3"
33
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
44
CFTime = "179af706-886a-5703-950a-314cd64e0468"
5+
CairoMakie = "13f3f980-e62b-5c42-98c6-ff1f3baf88f0"
56
Cthulhu = "f68482b8-f384-11e8-15f7-abe071a5a75f"
67
Dagger = "d58978e5-989f-55fb-8d15-ea34adc7bf54"
78
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
@@ -22,6 +23,7 @@ NetworkLayout = "46757867-2c16-5918-afeb-47bfcb05e46a"
2223
OnlineStats = "a15396b6-48d5-5d58-9928-6d29437db91e"
2324
PkgTemplates = "14b8a8f1-9102-5b29-a752-f990bacb7fe1"
2425
Primes = "27ebfcd6-29c5-5fa9-bf4b-fb8fc14df3ae"
26+
ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca"
2527
Proj = "c94c279d-25a6-4763-9509-64d165bea63e"
2628
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
2729
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"

devscripts/maintest.jl

Lines changed: 167 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -2,24 +2,85 @@ using DiskArrayEngine
22
import DiskArrayEngine as DAE
33
using DiskArrays: ChunkType, RegularChunks
44
using Statistics
5-
using Zarr, DiskArrays, OffsetArrays
6-
#using DiskArrayEngine: MWOp, internal_size, ProductArray, InputArray, getloopinds, UserOp, mysub, ArrayBuffer, NoFilter, AllMissing,
7-
# create_buffers, read_range, generate_inbuffers, generate_outbuffers, get_bufferindices, offset_from_range, generate_outbuffer_collection, put_buffer,
8-
# Output, _view, Input, applyfilter, apply_function, LoopWindows, GMDWop, results_as_diskarrays, create_userfunction, steps_per_chunk, apparent_chunksize,
9-
# find_adjust_candidates, generate_LoopRange, get_loopsplitter, split_loopranges_threads, merge_loopranges_threads, LocalRunner,
10-
# merge_outbuffer_collection, DistributedRunner
5+
using Zarr, DiskArrays
116
using StatsBase: rle, mode
127
using CFTime: timedecode
138
using Dates
149
using OnlineStats
1510
using Logging
1611
using Distributed
17-
#global_logger(SimpleLogger(stdout,Logging.Debug))
18-
#global_logger(SimpleLogger(stdout))
1912
using LoggingExtras
2013
using Test
2114
using DataStructures: OrderedSet
2215

16+
17+
18+
19+
inwindows1 = DAE.MovingWindow(1, 5, 5, 4)
20+
outwindows1 = 1:4
21+
inwindows2 = DAE.MovingWindow(1, 2, 2, 2)
22+
outwindows2 = 1:2
23+
24+
inar = DAE.InputArray(1:20, windows=(inwindows1,))
25+
outspecs = (DAE.create_outwindows(4, windows=(outwindows1,)),)
26+
f = create_userfunction(sum, Float64)
27+
op1 = DAE.GMDWop((inar,), outspecs, f)
28+
r = results_as_diskarrays(op1)[1]
29+
30+
inar2 = DAE.InputArray(r, windows=(inwindows2,))
31+
outspecs2 = (DAE.create_outwindows(4, windows=(outwindows2,)),)
32+
f2 = create_userfunction(sum, Float64)
33+
op2 = DAE.GMDWop((inar2,), outspecs2, f2)
34+
r2 = results_as_diskarrays(op2)[1]
35+
36+
g = DAE.result_to_graph(r2)
37+
38+
@test length(g.nodes) == 3
39+
@test g.nodes[1] == DAE.MwopOutNode(false, nothing, (2,), Float64)
40+
@test g.nodes[2] == DAE.MwopOutNode(false, nothing, (4,), Float64)
41+
@test g.nodes[3] == 1:20
42+
43+
@test length(g.connections) == 2
44+
conn1, conn2 = g.connections
45+
@test conn1.inputids == [3]
46+
@test conn1.outputids == [2]
47+
@test conn2.inputids == [2]
48+
@test conn2.outputids == [1]
49+
50+
nodemergestrategies = DAE.collect_strategies(g)
51+
52+
@test only(nodemergestrategies[2]) isa DAE.BlockMerge
53+
@test nodemergestrategies[1] == [nothing]
54+
@test nodemergestrategies[3] == [nothing]
55+
56+
dimmap = DAE.create_loopdimmap(conn1, conn2, 2)
57+
@test dimmap isa DAE.DimMap
58+
@test dimmap.d == Dict(1 => 1)
59+
60+
newop = DAE.merge_operations(DAE.BlockMerge, conn1, conn2, 2, dimmap)
61+
@test newop isa DAE.UserOp
62+
@test newop.f isa DAE.BlockFunctionChain
63+
@test newop.f.funcs[1] === f.f
64+
@test newop.f.funcs[2] === f2.f
65+
@test newop.f.args == [((1,), (1,)), ((2,), (2,))]
66+
@test newop.f.transfers == [1 => [2]]
67+
68+
newconn, newnodes = DAE.merged_connection(DAE.BlockMerge, g, conn1, conn2, 2, newop, nodemergestrategies, dimmap)
69+
70+
@test newconn isa DAE.MwopConnection
71+
@test newconn.f === newop
72+
@test newconn.inputids == [3, 4]
73+
@test newconn.outputids == [2, 1]
74+
75+
newconn.inwindows[1].windows.members[1]
76+
newconn.inwindows[2].windows.members[1]
77+
newconn.outwindows[1].windows.members[1]
78+
newconn.outwindows[2].windows.members[1]
79+
80+
81+
82+
83+
2384
g = zopen("https://s3.bgc-jena.mpg.de:9000/esdl-esdc-v2.1.1/esdc-8d-0.25deg-184x90x90-2.1.1.zarr")
2485

2586

@@ -30,29 +91,121 @@ t = g["time"]
3091
tvec = timedecode(t[:], t.attrs["units"]);
3192
groups = yearmonth.(tvec)
3293

33-
r = aggregate_diskarray(a, mean, (1 => nothing, 2 => 8, 3 => groups), strategy=:reduce)
94+
r = aggregate_diskarray(a, mean, (1 => nothing, 2 => 8, 3 => groups), strategy=:direct)
3495

3596
#a = compute(r)
3697

3798
r2 = aggregate_diskarray(r, maximum, (2 => nothing,))
3899

100+
101+
39102
r3 = r .+ 273.15
40103

41104
finalres = r2 .+ r3
42105

43-
finalres[45, 100]
106+
finalres[1, 45, 100]
44107

45108
g = DAE.MwopGraph()
46109
outnode = DAE.to_graph!(g, r2);
110+
47111
DAE.remove_aliases!(g)
112+
48113
using CairoMakie, GraphMakie
49-
#p = graphplot(g,elabels=DAE.edgenames(g),ilabels=DAE.nodenames(g))
114+
p = graphplot(g, elabels=DAE.edgenames(g), ilabels=DAE.nodenames(g))
115+
116+
#DAE.fuse_step_direct!(g)
50117

51-
dg = DAE.DimensionGraph(g)
52-
dg.concomps
53118

119+
nodemergestrategies = DAE.collect_strategies(g)
120+
i_eliminate = findfirst(nodemergestrategies) do strat
121+
!isempty(strat) && !all(isnothing, strat)
122+
end
123+
### DAE.eliminate_node(g, i_eliminate, nodemergestrategies[i_eliminate], BlockMerge)
124+
nodegraph = g
125+
inconids = DAE.inconnections(nodegraph, i_eliminate)
126+
outconids = DAE.outconnections(nodegraph, i_eliminate)
127+
inconns = nodegraph.connections[inconids]
128+
outconns = nodegraph.connections[outconids]
129+
130+
inconn = only(inconns)
131+
outconn = only(outconns)
132+
133+
dimmap = DAE.create_loopdimmap(inconn, outconn, i_eliminate)
134+
135+
newop = DAE.merge_operations(DAE.BlockMerge, inconn, outconn, i_eliminate, dimmap)
136+
137+
newconn = DAE.merged_connection(DAE.BlockMerge, nodegraph, inconn, outconn, i_eliminate, newop, nodemergestrategies, dimmap)
138+
139+
newconn.inputids
140+
newconn.outputids
141+
newconn.inwindows[2].windows.members[2]
142+
143+
144+
nodemergestrategies = DAE.collect_strategies(g)
145+
i_eliminate = findfirst(nodemergestrategies) do strat
146+
!isempty(strat) && !all(isnothing, strat)
147+
end
148+
149+
nodegraph = g;
150+
inconids = DAE.inconnections(nodegraph, i_eliminate)
151+
outconids = DAE.outconnections(nodegraph, i_eliminate)
152+
inconns = nodegraph.connections[inconids]
153+
outconns = nodegraph.connections[outconids]
154+
155+
inconn = only(inconns)
156+
outconn = only(outconns)
157+
158+
dimmap = DAE.create_loopdimmap(inconn, outconn, i_eliminate)
159+
160+
chain1 = DAE.BlockFunctionChain(inconn)
161+
chain2 = DAE.BlockFunctionChain(outconn)
162+
163+
to_eliminate = i_eliminate
164+
165+
chain1.args
166+
chain2.args
167+
ifrom = findfirst(==(to_eliminate), inconn.outputids)
168+
ito = findall(==(to_eliminate), outconn.inputids)
169+
transfer = ifrom => ito
170+
171+
newfunc = DAE.build_chain(chain1, chain2, dimmap, transfer)
172+
173+
174+
newop = DAE.merge_operations(DAE.BlockMerge, inconn, outconn, i_eliminate, dimmap)
175+
176+
newconn = DAE.merged_connection(DAE.BlockMerge, nodegraph, inconn, outconn, i_eliminate, newop, nodemergestrategies, dimmap)
177+
178+
deleteat!(nodegraph.connections, [inconids; outconids])
179+
push!(nodegraph.connections, newconn)
180+
181+
nodegraph.connections
182+
183+
conn = only(nodegraph.connections)
184+
op = conn.f
185+
inputs = InputArray.(g.nodes[conn.inputids], conn.inwindows)
186+
outspecs = map(g.nodes[conn.outputids], conn.outwindows) do outnode, outwindow
187+
(; lw=outwindow, chunks=outnode.chunks, ismem=outnode.ismem)
188+
end
189+
190+
191+
function gmwop_from_conn(conn)
192+
op = conn.f
193+
inputs = InputArray.(g.nodes[conn.inputids], conn.inwindows)
194+
outspecs = map(g.nodes[conn.outputids], conn.outwindows) do outnode, outwindow
195+
(; lw=outwindow, chunks=outnode.chunks, ismem=outnode.ismem)
196+
end
197+
DAE.GMDWop(inputs, outspecs, op)
198+
end
199+
200+
201+
using Graphs: nv, outneighbors, inneighbors
202+
ioutnodes = (findall(n -> !isempty(inneighbors(g, n)) && isempty(outneighbors(g, n)), 1:nv(g)))
203+
lastop = findall(conn -> all(in(conn.outputids), ioutnodes), g.connections) |> only
204+
op = gmwop_from_conn(g.connections[lastop])
205+
rgraph = results_as_diskarrays(op)[1]
206+
207+
runner = rgraph[1, 45, 100]
54208

55-
DAE.fuse_step_direct!(g)
56209

57210

58211

@@ -98,13 +251,7 @@ remaining_conn.outputids
98251
remaining_conn.outwindows[1]
99252

100253

101-
op = remaining_conn.f
102-
inputs = InputArray.(g.nodes[remaining_conn.inputids], remaining_conn.inwindows)
103-
outspecs = map(g.nodes[remaining_conn.outputids], remaining_conn.outwindows) do outnode, outwindow
104-
(; lw=outwindow, chunks=outnode.chunks, ismem=outnode.ismem)
105-
end
106254

107-
mergedop = DAE.GMDWop(inputs, outspecs, op)
108255

109256
rnow = DAE.results_as_diskarrays(mergedop)[2]
110257

src/buffers.jl

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,6 @@ windowbuffersize(looprange, window) = maximum(c -> internal_size(inner_index(win
3131
function generate_inbuffers(inars, loopranges)
3232
map(inars) do ia
3333
et = eltype(ia.a)
34-
#@show loopranges
35-
3634
Array{et}(undef, getbufsize(ia, loopranges))
3735
end
3836
end

src/disjointranges.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ windowmaximum(w) = maximum(maximum,w)
2525
last_contains_value(w::AbstractVector{<:Number},i) = findlast(<=(i),w)
2626
last_contains_value(w::AbstractRange,i) = searchsortedlast(w,i)
2727
function last_contains_value(w,i)
28-
ii = findlast(r->maximum(r)<=i,w)
28+
ii = findlast(r -> i in r, w)
2929
if ii === nothing
3030
length(w)+1
3131
else

src/enginearrays.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ function collect_bcdims(A)
3939
newsize
4040
elseif newsize == 1
4141
oldsize
42-
elseif odlsize == 1
42+
elseif oldsize == 1
4343
newsize
4444
else
4545
error("Dimension length do not match $newsize and $oldsize")

src/executionplan.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -198,6 +198,8 @@ function get_chunkspec(outspec,ot)
198198
si = map(m->last(last(m))-first(first(m))+1,outspec.lw.windows.members)
199199
if cs isa GridChunks
200200
cs = cs.chunks
201+
elseif cs === nothing
202+
cs = map(_->nothing,si)
201203
end
202204
cs = map(cs,si) do csnow,s
203205
if csnow === nothing

src/graphs.jl

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ describe(z::ZArray, _) = Zarr.zname(z)
2222
describe(z::Array{<:Any,0}, _) = z[]
2323
describe(z, i) = "Input $i"
2424
EmptyInput(n::MwopOutNode) = EmptyInput{n.eltype,length(n.size)}(n.size)
25+
Base.eltype(n::MwopOutNode) = n.eltype
2526

2627
mutable struct MwopGraph <: AbstractGraph{Int}
2728
dims::UnitRange{Int}
@@ -257,7 +258,9 @@ function eliminate_node(nodegraph, i_eliminate, strategies, appliedstrat)
257258

258259
newop = merge_operations(appliedstrat, inconn, outconn, i_eliminate, dimmap)
259260

260-
newconn = merged_connection(appliedstrat, nodegraph, inconn, outconn, i_eliminate, newop, strategies, dimmap)
261+
newconn, newnodes = merged_connection(appliedstrat, nodegraph, inconn, outconn, i_eliminate, newop, strategies, dimmap)
262+
263+
append!(nodegraph.nodes, newnodes)
261264

262265
deleteat!(nodegraph.connections, [inconids; outconids])
263266
push!(nodegraph.connections, newconn)
@@ -319,4 +322,10 @@ function fuse_graph!(g::MwopGraph)
319322
while length(g.connections) > 1
320323
fuse_step_block!(g::MwopGraph)
321324
end
325+
end
326+
327+
function result_to_graph(res)
328+
g = MwopGraph()
329+
to_graph!(g, res)
330+
g
322331
end

0 commit comments

Comments
 (0)