Meta Segment Anything Model에 GUI 적용하기
개요
최근 심화반수업을 수강하는 학원생분을 통해 Meta AI(과거 Facebook) 에서 발표한 SAM (Segment Anything Model) 이라는 기술을 알게되었습니다.
아래 웹사이트 링크에서 데모를 실행해 볼 수 있습니다.
잠깐 살펴본 바로는 이미지(정지영상, 실시간 모두)내부 특정 사물을 분류하는 기술입니다.
Meta 예제를 활용해 Python으로 GUI를 추가해 보았습니다.
[PyQt5 로 만든 예제] |
이미지를 미리 학습된 Predict Model에 넣고 마우스로 클릭하면 위와 같이 Segment 됩니다.
(별표시가 마우스 클릭한 좌표)
자세한 내용은 Meta Git Link 를 참조바라며, 아래는 Meta의 Model Diagram 소개자료입니다.
[ 이미지 출처 : Meta ] |
Meta에서 제공하는 Git 예제를 공부 중 코드에서 이미지 좌표를 키보드로
일일이 입력하니 좀 불편하고, 다양한 이미지에서 시도해 보고 싶은 마음이
듭니다.
그래서 PyQt5를 이용해 GUI로 활용가능하도록 코드를 작성해 보았습니다.
목록으로 정리해 보자면 Meta의 Sample 에서 아래의 편리성을 추가하였습니다.
1. QFileDialog 를 이용, 다양한 이미지 선택 가능
2. Python Thread를 이용, 대용량 학습데이터 미리 로딩, 시간단축
3. 마우스 좌클릭을 이용, 클릭한 좌표에 대한 Multiple Segment 가능
4. 마우스 우클릭을 통해 Segment 초기화
개발환경
-
Windows 11 Pro, Visual Studio 2022
-
Python 3.9 64bit
-
Matplotlib 3.6.2, opencv-python 4.6.0.66
-
PyTorch 1.8.1+cu101 (Cuda 10.1)
-
segment-anything 1.0
소스코드
2개의 파이썬 파일로 구성 (Meta_GUI.py, Meta_Func.py)
Meta_GUI.py 는 PyQt5로 만든 GUI 위젯.
Meta_Func.py 는 Meta Sample Code의 Function 집합.
Meta_GUI.py
시작파일이며, QWidget에서 상속받은 Window Class로 구현
from PyQt5.QtWidgets import QApplication, QWidget, QVBoxLayout, QPushButton, QFileDialog from matplotlib.backend_bases import MouseButton from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg import matplotlib.pyplot as plt from threading import Thread import time import cv2 import sys import numpy as np sys.path.append("..") from segment_anything import sam_model_registry, SamPredictor from Meta_Func import show_mask, show_points, show_box class Window(QWidget): def __init__(self): super().__init__() self.resize(1200,800) self.setWindowTitle('Ocean Coding School') self.run = True self.mx = 0 self.my = 0 self.LeftClick = False self.RightClick = False vbox = QVBoxLayout() self.pb = QPushButton('이미지 열기', self) vbox.addWidget(self.pb) self.fig = plt.Figure() self.canvas = FigureCanvasQTAgg(self.fig) vbox.addWidget(self.canvas) self.setLayout(vbox) self.pb.clicked.connect(self.onImgOpen) self.canvas.mpl_connect('motion_notify_event', self.onMouseMove) self.canvas.mpl_connect('button_press_event', self.onMouseClick) self.pb.setEnabled(False) self.t1 = Thread(target=self.createPredictor) self.t1.start() def createPredictor(self): sam_checkpoint = "d:/ML_Dataset/sam_vit_h_4b8939.pth" model_type = "vit_h" # Check cuda abailable in torch import torch isCuda = torch.cuda.is_available() print('Cuda is available : ', isCuda) if isCuda: device = 'cuda' else: device = 'cpu' sam = sam_model_registry[model_type](checkpoint=sam_checkpoint) sam.to(device=device) self.predictor = SamPredictor(sam) self.pb.setEnabled(True) def onImgOpen(self): self.run = False path = QFileDialog.getOpenFileName(self, '', '', 'Image Files(*.jpg *.png)') if path[0]: image = cv2.imread(path[0]) image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) self.fig.clear() ax = self.fig.subplots() ax.imshow(image) self.canvas.draw() # create thread self.run = True self.t2 = Thread(target=self.threadFunc, args=(image,ax)) self.t2.start() def onMouseMove(self, e): self.mx = e.xdata self.my = e.ydata #print('X:', e.x) #print('Y:', e.y) def onMouseClick(self, e): if e.button is MouseButton.LEFT: self.mx = e.xdata self.my = e.ydata print(e.xdata, e.ydata) self.LeftClick = True elif e.button is MouseButton.RIGHT: self.RightClick = True def closeEvent(self, e): self.run = False def threadFunc(self, image, ax): pd = self.predictor pd.set_image(image) while self.run: if self.RightClick: pd.reset_image() pd.set_image(image) ax.clear() ax.imshow(image) self.canvas.draw() self.RightClick = False if self.LeftClick: input_point = np.array([[self.mx, self.my]]) input_label = np.array([1]) masks, scores, logits = pd.predict( point_coords=input_point, point_labels=input_label, multimask_output=False, ) #self.fig.clear() for i, (mask, score) in enumerate(zip(masks, scores)): show_mask(mask, ax) show_points(input_point, input_label, ax) self.fig.suptitle(f"Mask {i+1}, Score: {score:.3f}", fontsize=12) #ax.axis('off') self.canvas.draw() self.LeftClick=False print('thread working...') time.sleep(0.1) if __name__ == '__main__': app = QApplication(sys.argv) w = Window() w.show() sys.exit(app.exec_())
Meta_Func.py
Meta 예제의 함수들을 정리한 파일.
import numpy as np import matplotlib.pyplot as plt def show_mask(mask, ax, random_color=False): if random_color: color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0) else: color = np.array([30/255, 144/255, 255/255, 0.6]) h, w = mask.shape[-2:] mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1) ax.imshow(mask_image) def show_points(coords, labels, ax, marker_size=375): pos_points = coords[labels==1] neg_points = coords[labels==0] ax.scatter(pos_points[:, 0], pos_points[:, 1], color='green', marker='*', s=marker_size, edgecolor='white', linewidth=1.25) ax.scatter(neg_points[:, 0], neg_points[:, 1], color='red', marker='*', s=marker_size, edgecolor='white', linewidth=1.25) def show_box(box, ax): x0, y0 = box[0], box[1] w, h = box[2] - box[0], box[3] - box[1] ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='green', facecolor=(0,0,0,0), lw=2))
유의사항
1. PyTorch GPU 지원 관련
Meta_GUI.py 53번 라인에서 PyTorch의 Cuda.is_available 은 True,
하지만, 제 노트북 기종이 오래되어 안타깝게도 GPU 버전이 지원되지 않네요.
2. Thread 를 2개 사용한 이유
Meta_GUI.py 46번 라인 t1 쓰레드는 SamPredictor 객체 생성시 지연 방지용.
Meta_GUI.py 84번 라인 t2 쓰레드는 Segment 된 이미지 Update 용도.
107번 라인의 predictor 에 set_image 함수가 긴 지연(약
30초이상)을 유발.
따라서 t2쓰레드의 predictor 가 동작하는 시점은 console에 'thread working...' 이 찍히고 나서 가능.
3. 미리 학습된 모델 다운로드
Meta_GUI.py 50번 라인 "ML_Dataset/sam_vit_h_4b8939.pth" 는 위에 Meta Git Link 에서 다운로드 후 여러분의 Local 경로로 설정.
감사합니다.
댓글
댓글 쓰기