Skip to content

Commit 5519b5d

Browse files
committed
Added add_floor utility to db
1 parent 980d7cc commit 5519b5d

File tree

3 files changed

+43
-5
lines changed

3 files changed

+43
-5
lines changed

MLStructFP/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
__email__ = '[email protected]'
1111
__keywords__ = ['ml', 'ai', 'floor plan', 'architectural', 'dataset', 'cnn']
1212
__license__ = 'MIT'
13-
__version__ = '0.7.2'
13+
__version__ = '0.7.3'
1414

1515
# URL
1616
__url__ = 'https://github.com/MLSTRUCT/MLSTRUCT-FP'

MLStructFP/db/_db_loader.py

Lines changed: 30 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ class DbLoader(object):
3333
__filter: Optional[Callable[['Floor'], bool]]
3434
__filtered_floors: List['Floor']
3535
__floor: Dict[int, 'Floor']
36+
__floor_categories: Dict[int, str]
3637
__path: str
3738

3839
def __init__(self, db: str, floor_only: bool = False) -> None:
@@ -45,17 +46,17 @@ def __init__(self, db: str, floor_only: bool = False) -> None:
4546
assert os.path.isfile(db), f'Dataset file {db} not found'
4647
self.__filter = None
4748
self.__filtered_floors = []
48-
self.__path = str(Path(os.path.realpath(db)).parent)
4949
self.__floor = {}
50+
self.__floor_categories: Dict[int, str] = {}
51+
self.__path = str(Path(os.path.realpath(db)).parent)
5052

5153
with open(db, 'r', encoding='utf8') as dbfile:
5254
data: dict = json.load(dbfile)
5355
meta: dict = data['meta'] if 'meta' in data else {}
5456

5557
# Load metadata
56-
floor_categories: Dict[int, str] = {}
5758
for cat in (meta['floor_categories'] if 'floor_categories' in meta else {}):
58-
floor_categories[meta['floor_categories'][cat]] = cat
59+
self.__floor_categories[meta['floor_categories'][cat]] = cat
5960
item_types: Dict[int, Tuple[str, str]] = {}
6061
for cat in (meta['item_types'] if 'item_types' in meta else {}):
6162
ic = meta['item_types'][cat]
@@ -83,7 +84,7 @@ def __init__(self, db: str, floor_only: bool = False) -> None:
8384
project_id=project_id,
8485
project_label=project_label[project_id] if project_id in project_label else '',
8586
category=f_cat,
86-
category_name=floor_categories.get(f_cat, ''),
87+
category_name=self.__floor_categories.get(f_cat, ''),
8788
elevation=f_data['elevation'] if 'elevation' in f_data else False
8889
)
8990
if floor_only:
@@ -153,6 +154,31 @@ def __init__(self, db: str, floor_only: bool = False) -> None:
153154
def __getitem__(self, item: int) -> 'Floor':
154155
return self.__floor[item]
155156

157+
def add_floor(self, floor_image: str, scale: float, category: int, elevation: bool) -> 'Floor':
158+
"""
159+
Adds a floor to the dataset. No project.
160+
161+
:param floor_image: Floor image file
162+
:param scale: Image scale
163+
:param category: Floor category
164+
:param elevation: Floor is elevation
165+
:return: Added floor object
166+
"""
167+
assert os.path.isfile(floor_image)
168+
f_id: int = len(self.__floor) + 1
169+
f = Floor(
170+
floor_id=int(f_id),
171+
image_path=floor_image,
172+
image_scale=scale,
173+
project_id=-1,
174+
project_label='',
175+
category=category,
176+
category_name=self.__floor_categories.get(category, ''),
177+
elevation=elevation
178+
)
179+
self.__floor[f_id] = f
180+
return f
181+
156182
@property
157183
def floors(self) -> Tuple['Floor', ...]:
158184
if len(self.__filtered_floors) == 0:

test/test_db.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,18 @@ def test_hist(self) -> None:
107107
db = DbLoader(DB_PATH)
108108
self.assertEqual(db.hist(show_plot=False), ('',))
109109

110+
def test_add_floor(self) -> None:
111+
"""
112+
Test add floor to database.
113+
"""
114+
db = DbLoader(DB_PATH)
115+
f0 = db.floors[0]
116+
f = db.add_floor(floor_image=f0.image_path, scale=f0.image_scale, category=f0.category, elevation=f0.elevation)
117+
self.assertEqual(f.image_path, f0.image_path)
118+
self.assertEqual(f.image_scale, f0.image_scale)
119+
self.assertEqual(f.category, f0.category)
120+
self.assertEqual(f.elevation, f0.elevation)
121+
110122
def test_image(self) -> None:
111123
"""
112124
Test image obtain in binary/photo.

0 commit comments

Comments
 (0)