Meta Segment Anything Model에 GUI 적용하기

개요

최근 심화반수업을 수강하는 학원생분을 통해 Meta AI(과거 Facebook) 에서 발표한 SAM (Segment Anything Model) 이라는 기술을 알게되었습니다. 

아래 웹사이트 링크에서 데모를 실행해 볼 수 있습니다.

Meta SAM demo link

잠깐 살펴본 바로는 이미지(정지영상, 실시간 모두)내부 특정 사물을 분류하는 기술입니다.

Meta 예제를 활용해 Python으로 GUI를 추가해 보았습니다.

[PyQt5 로 만든 예제]

이미지를 미리 학습된 Predict Model에 넣고 마우스로 클릭하면 위와 같이 Segment 됩니다.

(별표시가 마우스 클릭한 좌표)

자세한 내용은 Meta Git Link 를 참조바라며, 아래는 Meta의 Model Diagram 소개자료입니다.

[ 이미지 출처 : Meta ]

 Meta에서 제공하는 Git 예제를 공부 중 코드에서 이미지 좌표를 키보드로 일일이 입력하니 좀 불편하고, 다양한 이미지에서 시도해 보고 싶은 마음이 듭니다.

그래서 PyQt5를 이용해 GUI로 활용가능하도록 코드를 작성해 보았습니다.

목록으로 정리해 보자면 Meta의 Sample 에서 아래의 편리성을 추가하였습니다.

1. QFileDialog 를 이용, 다양한 이미지 선택 가능

2. Python Thread를 이용, 대용량 학습데이터 미리 로딩, 시간단축

3. 마우스 좌클릭을 이용, 클릭한 좌표에 대한 Multiple Segment 가능

4. 마우스 우클릭을 통해 Segment 초기화

 

개발환경 

  • Windows 11 Pro, Visual Studio 2022

  • Python 3.9 64bit

  • Matplotlib 3.6.2, opencv-python 4.6.0.66

  • PyTorch 1.8.1+cu101 (Cuda 10.1)

  • segment-anything 1.0


소스코드

2개의 파이썬 파일로 구성 (Meta_GUI.py, Meta_Func.py)


Meta_GUI.pyPyQt5로 만든 GUI 위젯.

Meta_Func.py 는 Meta Sample Code의 Function 집합.


Meta_GUI.py

시작파일이며, QWidget에서 상속받은 Window Class로 구현

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
from PyQt5.QtWidgets import QApplication, QWidget, QVBoxLayout, QPushButton, QFileDialog
 
from matplotlib.backend_bases import MouseButton
from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg
import matplotlib.pyplot as plt
 
from threading import Thread
import time
 
import cv2
import sys
import numpy as np
sys.path.append("..")
from segment_anything import sam_model_registry, SamPredictor
from Meta_Func import show_mask, show_points, show_box
 
class Window(QWidget):
 
    def __init__(self):
        super().__init__()
        self.resize(1200,800)
        self.setWindowTitle('Ocean Coding School')
 
        self.run = True
        self.mx = 0
        self.my = 0
        self.LeftClick = False
        self.RightClick = False
         
        vbox = QVBoxLayout()
 
        self.pb = QPushButton('이미지 열기', self)
        vbox.addWidget(self.pb)
 
        self.fig = plt.Figure()
        self.canvas = FigureCanvasQTAgg(self.fig)
        vbox.addWidget(self.canvas)
 
        self.setLayout(vbox)
 
        self.pb.clicked.connect(self.onImgOpen)
        self.canvas.mpl_connect('motion_notify_event', self.onMouseMove)
        self.canvas.mpl_connect('button_press_event', self.onMouseClick)
 
        self.pb.setEnabled(False)
        self.t1 = Thread(target=self.createPredictor)
        self.t1.start()
 
    def createPredictor(self):
        sam_checkpoint = "d:/ML_Dataset/sam_vit_h_4b8939.pth"
        model_type = "vit_h"
 
        # Check cuda abailable in torch
        import torch
        isCuda = torch.cuda.is_available()
        print('Cuda is available : ', isCuda)
 
        if isCuda:
            device = 'cuda'
        else:
            device = 'cpu'
 
        sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
        sam.to(device=device)
 
        self.predictor = SamPredictor(sam)
        self.pb.setEnabled(True)       
 
    def onImgOpen(self):
        self.run = False
        path = QFileDialog.getOpenFileName(self, '', '', 'Image Files(*.jpg *.png)')
        if path[0]:
 
            image = cv2.imread(path[0])
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
 
            self.fig.clear()
            ax = self.fig.subplots()
            ax.imshow(image)
            self.canvas.draw()
 
            # create thread
            self.run = True       
            self.t2 = Thread(target=self.threadFunc, args=(image,ax))       
            self.t2.start()
 
    def onMouseMove(self, e):
        self.mx = e.xdata
        self.my = e.ydata
        #print('X:', e.x)
        #print('Y:', e.y)       
 
    def onMouseClick(self, e):      
        if e.button is MouseButton.LEFT:       
            self.mx = e.xdata
            self.my = e.ydata
            print(e.xdata, e.ydata)
            self.LeftClick = True
        elif e.button is MouseButton.RIGHT:
            self.RightClick = True    
 
    def closeEvent(self, e):
        self.run = False
 
    def threadFunc(self, image, ax):
        pd = self.predictor
        pd.set_image(image)
 
        while self.run:
            if self.RightClick:                               
                pd.reset_image()
                pd.set_image(image)  
                ax.clear()
                ax.imshow(image)
                self.canvas.draw()
                self.RightClick = False               
 
 
            if self.LeftClick:
                input_point = np.array([[self.mx, self.my]])
                input_label = np.array([1])
 
                masks, scores, logits = pd.predict(
                point_coords=input_point,
                point_labels=input_label,
                multimask_output=False,
                )
 
                #self.fig.clear()
 
                for i, (mask, score) in enumerate(zip(masks, scores)):               
                    show_mask(mask, ax)
                    show_points(input_point, input_label, ax)
                 
                    self.fig.suptitle(f"Mask {i+1}, Score: {score:.3f}", fontsize=12)
                    #ax.axis('off')
 
                self.canvas.draw()
                self.LeftClick=False
                 
            print('thread working...')
            time.sleep(0.1)
 
 
if __name__ == '__main__':
    app = QApplication(sys.argv)
    w = Window()
    w.show()
    sys.exit(app.exec_())


Meta_Func.py

Meta 예제의 함수들을 정리한 파일.

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
import numpy as np
import matplotlib.pyplot as plt
 
def show_mask(mask, ax, random_color=False):
    if random_color:
        color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    else:
        color = np.array([30/255, 144/255, 255/255, 0.6])
    h, w = mask.shape[-2:]
    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
    ax.imshow(mask_image)
     
def show_points(coords, labels, ax, marker_size=375):
    pos_points = coords[labels==1]
    neg_points = coords[labels==0]
    ax.scatter(pos_points[:, 0], pos_points[:, 1], color='green', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
    ax.scatter(neg_points[:, 0], neg_points[:, 1], color='red', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)  
     
def show_box(box, ax):
    x0, y0 = box[0], box[1]
    w, h = box[2] - box[0], box[3] - box[1]
    ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='green', facecolor=(0,0,0,0), lw=2))


유의사항

1. PyTorch GPU 지원 관련

Meta_GUI.py 53번 라인에서 PyTorchCuda.is_availableTrue,

하지만, 제 노트북 기종이 오래되어 안타깝게도 GPU 버전이 지원되지 않네요. 


2. Thread 를 2개 사용한 이유

Meta_GUI.py 46번 라인 t1 쓰레드는 SamPredictor 객체 생성시 지연 방지용.

Meta_GUI.py 84번 라인 t2 쓰레드는 Segment 된 이미지 Update 용도.

107번 라인의 predictor set_image 함수가 긴 지연(약 30초이상)을 유발.

따라서 t2쓰레드의 predictor 가 동작하는 시점은 console에 'thread working...' 이 찍히고 나서 가능.

 

3. 미리 학습된 모델 다운로드

Meta_GUI.py 50번 라인 "ML_Dataset/sam_vit_h_4b8939.pth" 는 위에 Meta Git Link 에서 다운로드 후 여러분의 Local 경로로 설정.


감사합니다.

댓글

이 블로그의 인기 게시물

Qt Designer 설치하기

PyQt5 기반 동영상 플레이어앱 만들기