Skip to content

Commit b6b337e

Browse files
authored
Merge pull request #8 from hugoabonizio/feat/add-save-and-load
Add save and load methods
2 parents 09fe537 + c4a74df commit b6b337e

File tree

4 files changed

+68
-0
lines changed

4 files changed

+68
-0
lines changed

spec/network_spec.cr

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,4 +76,46 @@ describe Fann::Network do
7676
ann.close
7777
(result < [0.1]).should be_true
7878
end
79+
80+
context "saving current network" do
81+
it "saves standard networks" do
82+
tempfile = Tempfile.new("foo")
83+
File.size(tempfile.path).should eq 0
84+
ann = Fann::Network::Standard.new(2, [2], 1)
85+
ann.save(tempfile.path)
86+
(File.size(tempfile.path) > 0).should be_true
87+
end
88+
89+
it "saves cascade networks" do
90+
tempfile = Tempfile.new("bar")
91+
File.size(tempfile.path).should eq 0
92+
ann = Fann::Network::Cascade.new(2, 1)
93+
ann.save(tempfile.path)
94+
(File.size(tempfile.path) > 0).should be_true
95+
end
96+
end
97+
98+
context "loading a configuration file" do
99+
it "loads standard networks" do
100+
input = 2
101+
output = 1
102+
tempfile = Tempfile.new("standard")
103+
original = Fann::Network::Standard.new(input, [2], output)
104+
original.save(tempfile.path)
105+
loaded = Fann::Network::Standard.new(tempfile.path)
106+
loaded.input_size.should eq input
107+
loaded.output_size.should eq output
108+
end
109+
110+
it "loads cascade networks" do
111+
input = 2
112+
output = 1
113+
tempfile = Tempfile.new("cascade")
114+
original = Fann::Network::Cascade.new(input, output)
115+
original.save(tempfile.path)
116+
loaded = Fann::Network::Cascade.new(tempfile.path)
117+
loaded.input_size.should eq input
118+
loaded.output_size.should eq output
119+
end
120+
end
79121
end

spec/spec_helper.cr

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,3 @@
11
require "spec"
2+
require "tempfile"
23
require "../src/crystal-fann"

src/crystal-fann/cascade_network.cr

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,21 @@ module Fann
22
module Network
33
class Cascade
44
property :nn
5+
getter :input_size
6+
getter :output_size
57

68
def initialize(input : Int32, output : Int32)
79
@output_size = output
810
@input_size = input
911
@nn = LibFANN.create_shortcut(2, input, output)
1012
end
1113

14+
def initialize(path : String)
15+
@nn = LibFANN.create_from_file(path)
16+
@input_size = LibFANN.get_num_input(@nn)
17+
@output_size = LibFANN.get_num_output(@nn)
18+
end
19+
1220
def mse
1321
LibFANN.get_mse(@nn)
1422
end
@@ -43,6 +51,10 @@ module Fann
4351
result = LibFANN.run(@nn, input.to_unsafe)
4452
Slice.new(result, @output_size).to_a
4553
end
54+
55+
def save(path : String) : Int32
56+
LibFANN.save(@nn, path)
57+
end
4658
end
4759
end
4860
end

src/crystal-fann/standard_network.cr

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@ module Fann
22
module Network
33
class Standard
44
property :nn
5+
getter :input_size
6+
getter :output_size
57

68
def initialize(input : Int32, hidden : Array(Int32), output : Int32)
79
@logger = Logger.new(STDOUT)
@@ -17,6 +19,13 @@ module Fann
1719
@nn = LibFANN.create_standard_array(layers.size, layers.to_unsafe)
1820
end
1921

22+
def initialize(path : String)
23+
@logger = Logger.new(STDOUT)
24+
@nn = LibFANN.create_from_file(path)
25+
@input_size = LibFANN.get_num_input(@nn)
26+
@output_size = LibFANN.get_num_output(@nn)
27+
end
28+
2029
def mse
2130
LibFANN.get_mse(@nn)
2231
end
@@ -79,6 +88,10 @@ module Fann
7988
result = LibFANN.run(@nn, input.to_unsafe)
8089
Slice.new(result, @output_size).to_a
8190
end
91+
92+
def save(path : String) : Int32
93+
LibFANN.save(@nn, path)
94+
end
8295
end
8396
end
8497
end

0 commit comments

Comments
 (0)