중첩(Nested)클래스로 Iterator 직접 구현하기
C++ STL의 템플릿 클래스들이 제공하는 이터레이터를 직접 구현해본 예제 입니다.
예를 들면, std::vector 의 이터레이터를 다음과 같이 사용합니다.
1 2 3 4 5 6 7 | std::vector< int > v; v.push_back(5); v.push_back(3); v.push_back(7); // 이터레이터 선언 std::vector< int >::iterator itr = v.begin(); |
직접 템플릿 클래스를 만들고 이를 순회하는 iterator 도 직접 만들어 보는
것이죠.
이진 탐색 트리(Binary Search Tree, 이하 BST)의 Iterator를 직접 구현해보고자 하는 학원생이 있어 같이 공부하며 만들어 보았습니다.
이 게시물은 BST의 기본적인 구현을 설명하지 않습니다. 인터넷에 많은 자료가 있으므로 참조하기 바라며, Custom Iterator 의 내용 위주로 설명하겠습니다.
구글 검색으로 아래 사이트의 도움을 많이 받았습니다.
미국 "The Ohio State University, Steven J Zeil" 교수님이 구현한
예제이며, 해당 주제에 대해 잘 정리되어 있습니다.
링크 :
https://www.cs.odu.edu/~zeil/cs361/latest/Public/treetraversal/index.html
아래와 같이 사용합니다.
Main.cpp
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 | #include <iostream> #include <string> #include "bst.h" int main() { // int type bst ocs::bst< int > b1; b1.insert(5); b1.insert(3); b1.insert(7); b1.insert(2); b1.insert(10); std::cout << "[b1]" << '\n' ; for ( const auto& it : b1) { std::cout << it << std::endl; } std::cout << std::endl; // string type tree auto func = []( const auto& s1, const auto& s2) { return s1.size() < s2.size(); }; ocs::bst<std::string, decltype(func)> b2(func); b2.insert( "abc" ); b2.insert( "my name is kim" ); b2.insert( "hi" ); b2.insert( "hello world" ); std::cout << "[b2]" << '\n' ; ocs::bst<std::string, decltype(func)>::iterator itr; itr = b2.begin(); b2.erase(itr); for (itr = b2.begin(); itr!= b2.end(); ++itr) { std::cout << *itr << std::endl; } } |
bst clsss의 객체를 생성하고, 이진트리에 값을 Insert 후 Custom Iterator로
순회하는 코드입니다.
BST가 구현된 bst.h는 뒤에 설명하고, 출력결과는 다음과 같습니다.
Custom Iterator가 잘 동작하는 모습입니다.
![]() |
[b1, b2 BST를 생성 후 출력] |
컴파일 오류가 발생하면, C++ 컴파일러 버전을 C++14 이상으로 잡아야 합니다.
![]() |
[VS 프로젝트 속성] |
24번 라인에 작성된 람다식의 auto 전달인자 때문입니다.
C++11까지 람다 함수의 전달인자는 명시적으로 제한되었지만, C++14부터는 auto 타입이 가능해졌습니다.
이 람다 함수는 템플릿으로 작성된 트리값의 타입이 문자열인 경우, 비교(트리의 왼쪽 or 오른쪽) 후 노드에 연결하기 위해 작성되었습니다.
bst.h
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 | #pragma once #include <utility> #include <stdexcept> namespace ocs { template < typename T, typename compare = std::less<>> class bst { private : // Nested class bst::node class node { public : node(T _val = 0) : val(_val), left(nullptr), right(nullptr), parent(nullptr) {} ~node() {} private : friend class bst; node(T _val, node* _left, node* _right, node* _parent) : val(_val), left(_left), right(_right), parent(_parent) {} private : T val; node* left; node* right; node* parent; }; public : // Nested class bst::iterator class iterator { public : iterator(); ~iterator() {} // operator overloading iterator& operator++(); iterator& operator--(); const iterator operator++( int ); const iterator operator--( int ); bool operator==( const iterator& it) const ; bool operator!=( const iterator& it) const ; T operator*() const ; const T* operator->() const ; private : friend class bst; iterator( const node* _node, const bst* _tree); private : const node* pNode; const bst* pTree; }; // bst members public : bst(compare _cmp = compare()); bst( const bst& other); bst( const bst&& other); bst& operator=( const bst& other); bst& operator=( const bst& other); ~bst(); public : iterator insert( const T& _val); bool erase( const iterator& it); const T& findMin() const ; const T& findMax() const ; iterator begin() const ; iterator end() const ; iterator find( const T& _val) const ; size_t size() const ; void clear(); private : node* insert( const T& _val, node * &p, node *_parent); node* findMin(node *p) const ; node* findMax(node *p) const ; bool remove (node* &p, const T& _val); void removeAll(node* p); node* clone(node *p) const ; private : compare cmp; node* root; }; #include "bst.hpp" } // end of namespace ocs |
중첩 클래스의 내용을 제거하면 좀 더 가독성이 좋아집니다.
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 | #pragma once #include <utility> #include <stdexcept> namespace ocs { template < typename T, typename compare = std::less<>> class bst { private : // Nested class bst::node class node ...생략... public : // Nested class bst::iterator class iterator ...생략... // bst members public : bst(compare _cmp = compare()); bst( const bst& other); bst( const bst&& other); bst& operator=( const bst& other); bst& operator=( const bst& other); ~bst(); public : iterator insert( const T& _val); bool erase( const iterator& it); const T& findMin() const ; const T& findMax() const ; iterator begin() const ; iterator end() const ; iterator find( const T& _val) const ; size_t size() const ; void clear(); private : node* insert( const T& _val, node * &p, node *_parent); node* findMin(node *p) const ; node* findMax(node *p) const ; bool remove (node* &p, const T& _val); void removeAll(node* p); node* clone(node *p) const ; private : compare cmp; node* root; }; #include "bst.hpp" } // end of namespace ocs |
한결 보기가 편해졌습니다. bst 클래스 UML 다이어 그램을 살펴보겠습니다.
![]() |
[ bst class ] |
bst class는 템플릿 클래스로 구성되어 있으며,
트리값을 저장할 typename T, 트리값을 비교할 typename compapre 를 가집니다.
- 이진 검색 트리를 관리하는 Template Class
- namespace ocs 내부에 선언
-
T 타입의 크다, 작다 비교용 compare 타입, 초기값 std::less<>
- 중첩 클래스로 node, iterator 를 선언
- Copy, Move Constructor 및 Copy, Move Assignment Operator 선언
- insert, erase, findMin, findMax, size, clear 등 트리 관리 멤버함수 선언
- iterator class를 리턴하는 begin, end 등 멤버함수 선언
다음으로 bst::node 중첩클래스는 bst의 private영역에 선언되어 있습니다.
간단한 노드 클래스 입니다.
![]() |
[ bst::node class ] |
-
node 형 포인터 좌, 우, 부모 및 트리 저장값을 멤버변수로 선언
마지막 bst::iterator 중첩클래스는 외부에서 사용해야 하므로 bst의 public 영역에 선언합니다.
![]() |
[ bst::iterator class ] |
- node*, bst* 를 멤버 변수로 선언
- 전위/후위 (prefix, postfix) 증가 및 감소 연산자 operator ++, -- 선언
- 비교 연산자 operator ==, != 선언
- 역참조(Pointer dereference) 연산자 operator * 선언
- 화살표(Member selection) 연산자 operator -> 선언
bst.hpp
템플릿 클래스는 *.h 선언부에 구현을 같이 포함하는 것이 일반적이지만, 코드의
가독성을 위해 bst.hpp에 bst 클래스의 구현(Implementation)을 포함하도록 나누어
보았습니다.
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 | #pragma once using namespace ocs; // iterator class member function template < typename T, typename compare> bst<T, compare>::iterator::iterator() : pNode(nullptr), pTree(nullptr) { } template < typename T, typename compare> bst<T, compare>::iterator::iterator( const node* _node, const bst* _tree) : pNode(_node), pTree(_tree) { } template < typename T, typename compare> typename bst<T, compare>::iterator& bst<T, compare>::iterator::operator++() { node *p = nullptr; if (pNode == nullptr) { pNode = pTree->root; if (pNode == nullptr) throw std::exception( "Error" ); while (pNode->left != nullptr) pNode = pNode->left; } else if (pNode->right != nullptr) { pNode = pNode->right; while (pNode->left != nullptr) pNode = pNode->left; } else { p = pNode->parent; while (p != nullptr && pNode == p->right) { pNode = p; p = p->parent; } pNode = p; } return * this ; } template < typename T, typename compare> typename bst<T, compare>::iterator& bst<T, compare>::iterator::operator--() { node *p = nullptr; if (pNode == nullptr) { pNode = pTree->root; if (pNode == nullptr) throw std::exception( "Error" ); while (pNode->right != nullptr) pNode = pNode->right; } else if (pNode->left != nullptr) { pNode = pNode->left; while (pNode->right != nullptr) pNode = pNode->right; } else { p = pNode->parent; while (p != nullptr && pNode == p->left) { pNode = p; p = p->parent; } pNode = p; } return * this ; } template < typename T, typename compare> typename const bst<T, compare>::iterator bst<T, compare>::iterator::operator++( int ) { auto temp = * this ; this ->operator++(); return temp; } template < typename T, typename compare> typename const bst<T, compare>::iterator bst<T, compare>::iterator::operator--( int ) { auto temp = * this ; operator--(); return temp; } template < typename T, typename compare> typename bool bst<T, compare>::iterator::operator==( const iterator& it) const { return pTree == it.pTree && pNode == it.pNode; } template < typename T, typename compare> typename bool bst<T, compare>::iterator::operator!=( const iterator& it) const { return pTree != it.pTree || pNode != it.pNode; } template < typename T, typename compare> typename T bst<T, compare>::iterator::operator*() const { if (pNode == nullptr) throw std::exception( "Error" ); return pNode->val; } template < typename T, typename compare> typename const T* bst<T, compare>::iterator::operator->() const { return &pNode->val; } // bst clas member function template < typename T, typename compare> bst<T, compare>::bst(compare _cmp) : cmp(_cmp), root(nullptr) { } template < typename T, typename compare> bst<T, compare>::bst( const bst& other) { root = clone(other.root); cmp = other.cmp; } template < typename T, typename compare> bst<T, compare>::bst( const bst&& other) : root(other.root), cmp(other.cmp) { other.root = nullptr; } template < typename T, typename compare> bst<T, compare>::~bst() { clear(); } template < typename T, typename compare> typename bst<T, compare>& bst<T, compare>::operator=( const bst& other) { bst temp(other); std::swap(* this , temp); return * this ; } template < typename T, typename compare> typename bst<T, compare>& bst<T, compare>::operator=( const bst&& other) { std::swap(* this ->root, other.root); std::swap(* this ->cmp, other.cmp); return * this ; } template < typename T, typename compare> typename bst<T, compare>::node* bst<T, compare>::clone(node* p) const { if (p == nullptr) return nullptr; else return new node(p->val, clone(p->left), clone(p->right), p->parent); } template < typename T, typename compare> typename bst<T, compare>::iterator bst<T, compare>::insert( const T& _val) { auto p = insert(_val, root, nullptr); if (p == nullptr) return end(); else return iterator(p, this ); } template < typename T, typename compare> typename bst<T, compare>::node* bst<T, compare>::insert( const T& _val, node * &p, node *_parent) { if (p == nullptr) { p = new node(_val, nullptr, nullptr, _parent); return p; } else { if (cmp(_val, p->val)) return insert(_val, p->left, p); else if (cmp(p->val, _val)) return insert(_val, p->right, p); else return nullptr; //duplicate } } template < typename T, typename compare> typename bool bst<T, compare>::erase( const iterator& it) { return remove (root, *it); } template < typename T, typename compare> typename bool bst<T, compare>:: remove (node* &p, const T& _val) { if (p == nullptr) return false ; if (cmp(_val, p->val)) return remove (p->left, _val); else if (cmp(p->val, _val)) return remove (p->right, _val); else if (p->left != nullptr && p->right != nullptr) { p->val = findMin(p->right)->val; remove (p->right, p->val); return true ; } else { node *old = p; p = (p->left != nullptr) ? p->left : p->right; delete old; return true ; } } template < typename T, typename compare> typename void bst<T, compare>::removeAll(node *p) { if (p != nullptr) { removeAll(p->left); removeAll(p->right); delete p; } p = nullptr; } template < typename T, typename compare> typename const T& bst<T, compare>::findMin() const { if (root == nullptr) throw std::exception( "Error" ); return findMin(root)->val; } template < typename T, typename compare> typename const T& bst<T, compare>::findMax() const { if (root == nullptr) throw std::exception( "Error" ); return findMax(root)->val; } template < typename T, typename compare> typename bst<T, compare>::node* bst<T, compare>::findMin(node* p) const { if (p == nullptr || p->left == nullptr) return p; return findMin(p->left); } template < typename T, typename compare> typename bst<T, compare>::node* bst<T, compare>::findMax(node* p) const { if (p == nullptr || p->right == nullptr) return p; return findMax(p->right); } template < typename T, typename compare> typename bst<T, compare>::iterator bst<T, compare>::begin() const { return iterator(findMin(root), this ); } template < typename T, typename compare> typename bst<T, compare>::iterator bst<T, compare>::end() const { return iterator(nullptr, this ); } template < typename T, typename compare> typename bst<T, compare>::iterator bst<T, compare>::find( const T& _val) const { node *p = root; while (p != nullptr && !(p->val == _val)) { p = _val < p->val ? p->left : p->right; } return iterator(p, this ); } template < typename T, typename compare> typename size_t bst<T, compare>::size() const { size_t size = 0; for (iterator itr = begin(); itr != end(); ++itr) size += 1; return size; } template < typename T, typename compare> typename void bst<T, compare>::clear() { removeAll(root); } |
bst, bst::node, bst::iterator 클래스의 멤버함수 구현이며, 중첩 클래스와 템플릿으로 인해 코드가 길어보입니다.
참고로 C++11 using 지시자의 템플릿 별칭을 이용하면 가독성이 아래와 같이 향상됩니다.
[원본 코드, bst::iterator 클래스의 레퍼런스를 리턴]
1 2 | template < typename T, typename compare> typename bst<T, compare>::iterator& bst<T, compare>::iterator::operator++() {} |
[using 템플릿 별칭 선언]
1 2 | template < typename T, typename compare> using _iterator = typename ocs::bst<T, compare>::iterator; |
[템플릿 별칭 사용, bst::iterator 클래스의 레퍼런스를 리턴]
1 2 | template < typename T, typename compare> typename _iterator<T, compare>& bst<T, compare>::iterator::operator++() {} |
Custom Iterator
1. bst class 객체 생성, 후 insert
1 2 3 4 | ocs::bst< int > b1; b1.insert(5); b1.insert(3); b1.insert(7); |
- 현재까지 bst::iterator 객체 생성 X.
2. bst::begin() 함수 호출 후, 리턴되는 iterator 객체 itr에 대입
- 대입연산자 우변의 b1.begin() 함수 호출
1 | ocs::bst< int >::iterator itr = b1.begin(); |
- bst::begin() 함수는 bst::iterator 클래스의 임시 객체를 생성
1 2 3 4 | typename bst<T, compare>::iterator bst<T, compare>::begin() const { return iterator(findMin(root), this ); } |
- bst::begin() 에 의해 생성되는 iterator 임시 객체는 findMin() 함수를 이용, 트리 값 5,3,7 중 가장 작은 3값을 가진 node*와 bst (this) 를 이용해 bst::Iterator 클래스 private 생성자 호출
- bst::iterator 의 private 생성자 모습
1 2 3 4 | bst<T, compare>::iterator::iterator( const node* _node, const bst* _tree) : pNode(_node), pTree(_tree) { } |
- 따라서 bst::begin() 에 의해 생성되는 iterator 임시객체는 3값을 갖는 node* 와 b1 트리의 값을 멤버변수로 가짐
- 이제 이 임시객체를 itr 타입에 대입하면 operator=() 이 아닌 iterator의 Copy Constructor가 호출되며, 디폴트 복사생성자는 멤버변수끼리 얕은 복사를 수행하므로 itr::pNode는 3값을 갖는 node*를 가리키고, itr::pTree는 b1를 가르키게 됨.
여기까지가 Custom Iterator 객체가 생성되는 과정입니다.
Iterator 임시객체의 소멸자는 Heap에 생성된 node *를 delete 하지 않으므로 우려할 부분은 없습니다.
node 객체를 동적생성, 해지하는 것은 bst 클래스에서 이루어 집니다.
아래 그림을 살펴보면 이해가 쉽습니다.
![]() |
[ bst, iterator, node 관계 ] |
3. bst::iterator::operator 반복
- 코드를 아래와 같이 구성합니다.
1 2 3 4 5 6 7 8 9 | ocs::bst< int > b1; b1.insert(5); b1.insert(3); b1.insert(7); for (auto itr = b1.begin(); itr != b1.end(); ++itr) { std::cout << *itr << std::endl; } |
- for 루프 전 상황은 2번 설명의 그림처럼 b1, node 객체가 만들어진 상황
-
auto itr 는 bst::begin() 함수 리턴값에 의해 bst::iterator객체 저장
- *itr 은 operator*() 연산자 함수 호출, 3을 출력
1 2 3 4 5 6 | typename T bst<T, compare>::iterator::operator*() const { if (pNode == nullptr) throw std::exception( "Error" ); return pNode->val; } |
- iterator 는 bst::begin() 함수호출에 의해 생성 시, 가장 작은 값을 가지는 노드를 node* pNode에 저장하기 때문
- 반복문의 1회전 때 노드값 3출력 후, ++itr 에 의해 operator++(), 전위 증가 연산자 호출
typename bst<T, compare>::iterator& bst<T, compare>::iterator::operator++()
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 | { node *p = nullptr; if (pNode == nullptr) { pNode = pTree->root; if (pNode == nullptr) throw std::exception( "Error" ); while (pNode->left != nullptr) pNode = pNode->left; } else if (pNode->right != nullptr) { pNode = pNode->right; while (pNode->left != nullptr) pNode = pNode->left; } else { p = pNode->parent; while (p != nullptr && pNode == p->right) { pNode = p; p = p->parent; } pNode = p; } return * this ; } |
댓글
댓글 쓰기