Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions anylabeling/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1 @@
__appname__ = "AnyLabeling"
__version__ = "0.0.9"
from .app_info import *
3 changes: 3 additions & 0 deletions anylabeling/app_info.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
__appname__ = "AnyLabeling"
__appdescription__ = "Effortless data labeling with AI support"
__version__ = "0.1.1"
11 changes: 7 additions & 4 deletions anylabeling/services/auto_labeling/model.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from abc import abstractmethod
import logging
import os
import pathlib
import urllib.request
from abc import abstractmethod

import yaml

from .types import AutoLabelingResult
Expand Down Expand Up @@ -52,12 +53,14 @@ def get_model_abs_path(self, model_path):
# Try download model from url
if model_path.startswith("anylabeling_assets/"):
self.on_message(
"Downloading model from model registry. This may take a while..."
"Downloading model from model registry. This may take a"
" while..."
)
relative_path = model_path.replace("anylabeling_assets/", "")
download_url = self.BASE_DOWNLOAD_URL + relative_path
model_abs_path = os.path.join(
os.path.abspath("data"), relative_path
home_dir = os.path.expanduser("~")
model_abs_path = os.path.abspath(
os.path.join(home_dir, "data", relative_path)
)
if os.path.exists(model_abs_path):
return model_abs_path
Expand Down
16 changes: 10 additions & 6 deletions anylabeling/services/auto_labeling/model_manager.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
import copy
import yaml
import importlib.resources as pkg_resources
from threading import Lock
from PyQt5.QtCore import QObject, pyqtSignal

import yaml
from PyQt5.QtCore import QObject, QThread, pyqtSignal, pyqtSlot
from anylabeling.utils import GenericWorker
import importlib.resources as pkg_resources

from anylabeling import configs as anylabeling_configs
from anylabeling.services.auto_labeling.types import AutoLabelingResult
from anylabeling.utils import GenericWorker


class ModelManager(QObject):
Expand Down Expand Up @@ -108,12 +108,16 @@ def _load_model(self, model_name):
if model_info["type"] == "yolov5":
from .yolov5 import YOLOv5

model_info["model"] = YOLOv5(model_info, on_message=self.new_model_status.emit)
model_info["model"] = YOLOv5(
model_info, on_message=self.new_model_status.emit
)
self.auto_segmentation_model_unselected.emit()
elif model_info["type"] == "segment_anything":
from .segment_anything import SegmentAnything

model_info["model"] = SegmentAnything(model_info, on_message=self.new_model_status.emit)
model_info["model"] = SegmentAnything(
model_info, on_message=self.new_model_status.emit
)
self.auto_segmentation_model_selected.emit()
else:
raise Exception(f"Unknown model type: {model_info['type']}")
Expand Down
3 changes: 2 additions & 1 deletion anylabeling/services/auto_labeling/segment_anything.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,14 @@
import os
from copy import deepcopy

import onnxruntime
import cv2
import numpy as np
import onnxruntime
from PyQt5 import QtCore

from anylabeling.views.labeling.shape import Shape
from anylabeling.views.labeling.utils.opencv import qt_img_to_cv_img

from .model import Model
from .types import AutoLabelingResult

Expand Down
4 changes: 3 additions & 1 deletion anylabeling/services/auto_labeling/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ def __init__(self, shapes, replace=True):


class AutoLabelingMode:
NONE = None
OBJECT = "AUTOLABEL_OBJECT"
ADD = "AUTOLABEL_ADD"
REMOVE = "AUTOLABEL_REMOVE"
Expand Down Expand Up @@ -43,3 +42,6 @@ def __eq__(self, other):
self.edit_mode == other.edit_mode
and self.shape_type == other.shape_type
)


AutoLabelingMode.NONE = AutoLabelingMode(None, None)
3 changes: 3 additions & 0 deletions anylabeling/services/auto_labeling/yolov5.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

from anylabeling.views.labeling.shape import Shape
from anylabeling.views.labeling.utils.opencv import qt_img_to_cv_img

from .model import Model
from .types import AutoLabelingResult

Expand Down Expand Up @@ -38,6 +39,8 @@ def __init__(self, model_config, on_message) -> None:
raise Exception(f"Model not found: {model_abs_path}")

self.net = cv2.dnn.readNet(model_abs_path)
self.net.setPreferableBackend(cv2.dnn.DNN_BACKEND_CUDA)
self.net.setPreferableTarget(cv2.dnn.DNN_TARGET_CUDA)
self.classes = self.config["classes"]

def pre_process(self, input_image, net):
Expand Down
7 changes: 4 additions & 3 deletions anylabeling/views/labeling/__main__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
__appname__ = "AnyLabeling"

import argparse
import codecs
import logging
Expand All @@ -10,6 +8,7 @@
import yaml
from PyQt5 import QtCore, QtWidgets

from ...app_info import __appname__
from .config import get_config
from .label_widget import MainWindow
from .logger import logger
Expand Down Expand Up @@ -37,7 +36,9 @@ def main():
"recognized as file, else as directory)"
),
)
default_config_file = os.path.join(os.path.expanduser("~"), ".anylabelingrc")
default_config_file = os.path.join(
os.path.expanduser("~"), ".anylabelingrc"
)
parser.add_argument(
"--config",
dest="config",
Expand Down
1 change: 1 addition & 0 deletions anylabeling/views/labeling/label_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import PIL.Image

from ...app_info import __version__
from . import utils
from .logger import logger

Expand Down
121 changes: 67 additions & 54 deletions anylabeling/views/labeling/label_widget.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
__appname__ = "AnyLabeling"

import functools
import html
import math
Expand All @@ -12,34 +10,21 @@
import natsort
from PyQt5 import QtCore, QtGui, QtWidgets
from PyQt5.QtCore import Qt, pyqtSlot
from PyQt5.QtWidgets import (
QDockWidget,
QHBoxLayout,
QLabel,
QPlainTextEdit,
QVBoxLayout,
QWhatsThis,
)
from PyQt5.QtWidgets import (QDockWidget, QHBoxLayout, QLabel, QPlainTextEdit,
QVBoxLayout, QWhatsThis)

from anylabeling.services.auto_labeling.types import AutoLabelingMode

from ...app_info import __appname__
from . import utils
from .config import get_config
from .label_file import LabelFile, LabelFileError
from .logger import logger
from .shape import Shape
from .widgets import (
BrightnessContrastDialog,
Canvas,
FileDialogPreview,
LabelDialog,
LabelListWidget,
LabelListWidgetItem,
ToolBar,
UniqueLabelQListWidget,
ZoomWidget,
AutoLabelingWidget,
)

from anylabeling.services.auto_labeling.types import AutoLabelingMode
from .widgets import (AutoLabelingWidget, BrightnessContrastDialog, Canvas,
FileDialogPreview, LabelDialog, LabelListWidget,
LabelListWidgetItem, ToolBar, UniqueLabelQListWidget,
ZoomWidget)

LABEL_COLORMAP = imgviz.label_colormap()

Expand Down Expand Up @@ -879,9 +864,6 @@ def __init__(
self.auto_labeling_widget.finish_auto_labeling_object_action_requested.connect(
self.finish_auto_labeling_object
)
self.canvas.auto_labeling_mode_changed.connect(
lambda mode: self.auto_labeling_widget.update_button_colors(mode)
)
self.auto_labeling_widget.hide() # Hide by default
central_layout.addWidget(self.label_instruction)
central_layout.addWidget(self.auto_labeling_widget)
Expand Down Expand Up @@ -1216,6 +1198,10 @@ def toggle_draw_mode(self, edit=True, create_mode="rectangle"):
self.actions.edit_mode.setEnabled(not edit)

def set_edit_mode(self):
# Diable auto labeling
self.clear_auto_labeling_marks()
self.auto_labeling_widget.set_auto_labeling_mode(None)

self.toggle_draw_mode(True)

def update_file_menu(self):
Expand Down Expand Up @@ -1496,7 +1482,18 @@ def format_shape(s):
)
return data

shapes = [format_shape(item.shape()) for item in self.label_list]
# Get current shapes
# Excluding auto labeling special shapes
shapes = [
format_shape(item.shape())
for item in self.label_list
if item.shape().label
not in [
AutoLabelingMode.OBJECT,
AutoLabelingMode.ADD,
AutoLabelingMode.REMOVE,
]
]
flags = {}
for i in range(self.flag_widget.count()):
item = self.flag_widget.item(i)
Expand Down Expand Up @@ -1728,6 +1725,8 @@ def toggle_polygons(self, value):
def load_file(self, filename=None):
"""Load the specified file, or the last opened file if None."""

self.clear_auto_labeling_marks()

# Changing file_list_widget loads file
if filename in self.image_list and (
self.file_list_widget.currentRow()
Expand Down Expand Up @@ -2364,25 +2363,54 @@ def new_shapes_from_auto_labeling(self, auto_labeling_result):
self.label_list.clear()
self.load_shapes(auto_labeling_result.shapes, replace=True)
else: # Just update existing shapes
# Remove shapes with label "AUTOLABEL_OBJECT"
# Remove shapes with label AutoLabelingMode.OBJECT
for shape in self.canvas.shapes:
if shape.label == "AUTOLABEL_OBJECT":
if shape.label == AutoLabelingMode.OBJECT:
item = self.label_list.find_item_by_shape(shape)
self.label_list.remove_item(item)
self.load_shapes(auto_labeling_result.shapes, replace=False)

self.set_dirty()

def clear_auto_labeling_marks(self):
"""Clear auto labeling marks."""
"""Clear auto labeling marks from the current image."""
# Clean up label list
for shape in self.canvas.shapes:
if shape.label in [
AutoLabelingMode.OBJECT,
AutoLabelingMode.ADD,
AutoLabelingMode.REMOVE,
]:
item = self.label_list.find_item_by_shape(shape)
self.label_list.remove_item(item)
try:
item = self.label_list.find_item_by_shape(shape)
self.label_list.remove_item(item)
except ValueError:
pass

# Clean up unique label list
for shape_label in [
AutoLabelingMode.OBJECT,
AutoLabelingMode.ADD,
AutoLabelingMode.REMOVE,
]:
for item in self.unique_label_list.find_items_by_label(
shape_label
):
self.unique_label_list.takeItem(
self.unique_label_list.row(item)
)

# Remove shapes from the canvas
self.canvas.shapes = [
shape
for shape in self.canvas.shapes
if shape.label
not in [
AutoLabelingMode.OBJECT,
AutoLabelingMode.ADD,
AutoLabelingMode.REMOVE,
]
]
self.canvas.update()

def finish_auto_labeling_object(self):
Expand Down Expand Up @@ -2424,8 +2452,10 @@ def finish_auto_labeling_object(self):
shape.group_id = group_id
# Update unique label list
if not self.unique_label_list.find_items_by_label(shape.label):
unique_label_item = self.unique_label_list.create_item_from_label(
shape.label
unique_label_item = (
self.unique_label_list.create_item_from_label(
shape.label
)
)
self.unique_label_list.addItem(unique_label_item)
rgb = self._get_rgb_by_label(shape.label)
Expand All @@ -2445,25 +2475,8 @@ def finish_auto_labeling_object(self):
else:
item.setText(f"{shape.label} ({shape.group_id})")

# Remove all ADD and REMOVE marks
for shape in self.canvas.shapes:
if shape.label in [AutoLabelingMode.ADD, AutoLabelingMode.REMOVE]:
item = self.label_list.find_item_by_shape(shape)
self.label_list.remove_item(item)

# Remove all auto labeling marks from unique label list
for item in self.unique_label_list.find_items_by_label(
AutoLabelingMode.OBJECT
):
self.unique_label_list.takeItem(self.unique_label_list.row(item))
for item in self.unique_label_list.find_items_by_label(
AutoLabelingMode.ADD
):
self.unique_label_list.takeItem(self.unique_label_list.row(item))
for item in self.unique_label_list.find_items_by_label(
AutoLabelingMode.REMOVE
):
self.unique_label_list.takeItem(self.unique_label_list.row(item))
# Clean up auto labeling objects
self.clear_auto_labeling_marks()

# Update shape colors
for shape in self.canvas.shapes:
Expand Down
34 changes: 7 additions & 27 deletions anylabeling/views/labeling/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,30 +1,10 @@
# flake8: noqa

from ._io import lblsave
from .image import (
apply_exif_orientation,
img_arr_to_b64,
img_b64_to_arr,
img_data_to_arr,
img_data_to_pil,
img_data_to_png_data,
img_pil_to_data,
)
from .qt import (
Struct,
add_actions,
distance,
distance_to_line,
fmt_shortcut,
label_validator,
new_action,
new_button,
new_icon,
)
from .shape import (
labelme_shapes_to_label,
masks_to_bboxes,
polygons_to_mask,
shape_to_mask,
shapes_to_label,
)
from .image import (apply_exif_orientation, img_arr_to_b64, img_b64_to_arr,
img_data_to_arr, img_data_to_pil, img_data_to_png_data,
img_pil_to_data)
from .qt import (Struct, add_actions, distance, distance_to_line, fmt_shortcut,
label_validator, new_action, new_button, new_icon)
from .shape import (labelme_shapes_to_label, masks_to_bboxes, polygons_to_mask,
shape_to_mask, shapes_to_label)
Loading