중첩(Nested)클래스로 Iterator 직접 구현하기
C++ STL의 템플릿 클래스들이 제공하는 이터레이터를 직접 구현해본 예제 입니다.
예를 들면, std::vector 의 이터레이터를 다음과 같이 사용합니다.
std::vector<int> v; v.push_back(5); v.push_back(3); v.push_back(7); // 이터레이터 선언 std::vector<int>::iterator itr = v.begin();
직접 템플릿 클래스를 만들고 이를 순회하는 iterator 도 직접 만들어 보는
것이죠.
이진 탐색 트리(Binary Search Tree, 이하 BST)의 Iterator를 직접 구현해보고자 하는 학원생이 있어 같이 공부하며 만들어 보았습니다.
이 게시물은 BST의 기본적인 구현을 설명하지 않습니다. 인터넷에 많은 자료가 있으므로 참조하기 바라며, Custom Iterator 의 내용 위주로 설명하겠습니다.
구글 검색으로 아래 사이트의 도움을 많이 받았습니다.
미국 "The Ohio State University, Steven J Zeil" 교수님이 구현한
예제이며, 해당 주제에 대해 잘 정리되어 있습니다.
링크 :
https://www.cs.odu.edu/~zeil/cs361/latest/Public/treetraversal/index.html
아래와 같이 사용합니다.
Main.cpp
#include <iostream> #include <string> #include "bst.h" int main() { // int type bst ocs::bst<int> b1; b1.insert(5); b1.insert(3); b1.insert(7); b1.insert(2); b1.insert(10); std::cout << "[b1]" << '\n'; for (const auto& it : b1) { std::cout << it << std::endl; } std::cout << std::endl; // string type tree auto func = [](const auto& s1, const auto& s2) { return s1.size() < s2.size(); }; ocs::bst<std::string, decltype(func)> b2(func); b2.insert("abc"); b2.insert("my name is kim"); b2.insert("hi"); b2.insert("hello world"); std::cout << "[b2]" << '\n'; ocs::bst<std::string, decltype(func)>::iterator itr; itr = b2.begin(); b2.erase(itr); for (itr = b2.begin(); itr!= b2.end(); ++itr) { std::cout << *itr << std::endl; } }
bst clsss의 객체를 생성하고, 이진트리에 값을 Insert 후 Custom Iterator로
순회하는 코드입니다.
BST가 구현된 bst.h는 뒤에 설명하고, 출력결과는 다음과 같습니다.
Custom Iterator가 잘 동작하는 모습입니다.
[b1, b2 BST를 생성 후 출력] |
컴파일 오류가 발생하면, C++ 컴파일러 버전을 C++14 이상으로 잡아야 합니다.
[VS 프로젝트 속성] |
24번 라인에 작성된 람다식의 auto 전달인자 때문입니다.
C++11까지 람다 함수의 전달인자는 명시적으로 제한되었지만, C++14부터는 auto 타입이 가능해졌습니다.
이 람다 함수는 템플릿으로 작성된 트리값의 타입이 문자열인 경우, 비교(트리의 왼쪽 or 오른쪽) 후 노드에 연결하기 위해 작성되었습니다.
bst.h
#pragma once #include <utility> #include <stdexcept> namespace ocs { template<typename T, typename compare = std::less<>> class bst { private: // Nested class bst::node class node { public: node(T _val = 0) : val(_val), left(nullptr), right(nullptr), parent(nullptr) {} ~node() {} private: friend class bst; node(T _val, node* _left, node* _right, node* _parent) : val(_val), left(_left), right(_right), parent(_parent) {} private: T val; node* left; node* right; node* parent; }; public: // Nested class bst::iterator class iterator { public: iterator(); ~iterator() {} // operator overloading iterator& operator++(); iterator& operator--(); const iterator operator++(int); const iterator operator--(int); bool operator==(const iterator& it) const; bool operator!=(const iterator& it) const; T operator*() const; const T* operator->() const; private: friend class bst; iterator(const node* _node, const bst* _tree); private: const node* pNode; const bst* pTree; }; // bst members public: bst(compare _cmp = compare()); bst(const bst& other); bst(const bst&& other); bst& operator=(const bst& other); bst& operator=(const bst& other); ~bst(); public: iterator insert(const T& _val); bool erase(const iterator& it); const T& findMin() const; const T& findMax() const; iterator begin() const; iterator end() const; iterator find(const T& _val) const; size_t size() const; void clear(); private: node* insert(const T& _val, node * &p, node *_parent); node* findMin(node *p) const; node* findMax(node *p) const; bool remove(node* &p, const T& _val); void removeAll(node* p); node* clone(node *p) const; private: compare cmp; node* root; }; #include "bst.hpp" } // end of namespace ocs
중첩 클래스의 내용을 제거하면 좀 더 가독성이 좋아집니다.
#pragma once #include <utility> #include <stdexcept> namespace ocs { template<typename T, typename compare = std::less<>> class bst { private: // Nested class bst::node class node ...생략... public: // Nested class bst::iterator class iterator ...생략... // bst members public: bst(compare _cmp = compare()); bst(const bst& other); bst(const bst&& other); bst& operator=(const bst& other); bst& operator=(const bst& other); ~bst(); public: iterator insert(const T& _val); bool erase(const iterator& it); const T& findMin() const; const T& findMax() const; iterator begin() const; iterator end() const; iterator find(const T& _val) const; size_t size() const; void clear(); private: node* insert(const T& _val, node * &p, node *_parent); node* findMin(node *p) const; node* findMax(node *p) const; bool remove(node* &p, const T& _val); void removeAll(node* p); node* clone(node *p) const; private: compare cmp; node* root; }; #include "bst.hpp" } // end of namespace ocs
한결 보기가 편해졌습니다. bst 클래스 UML 다이어 그램을 살펴보겠습니다.
[ bst class ] |
bst class는 템플릿 클래스로 구성되어 있으며,
트리값을 저장할 typename T, 트리값을 비교할 typename compapre 를 가집니다.
- 이진 검색 트리를 관리하는 Template Class
- namespace ocs 내부에 선언
-
T 타입의 크다, 작다 비교용 compare 타입, 초기값 std::less<>
- 중첩 클래스로 node, iterator 를 선언
- Copy, Move Constructor 및 Copy, Move Assignment Operator 선언
- insert, erase, findMin, findMax, size, clear 등 트리 관리 멤버함수 선언
- iterator class를 리턴하는 begin, end 등 멤버함수 선언
다음으로 bst::node 중첩클래스는 bst의 private영역에 선언되어 있습니다.
간단한 노드 클래스 입니다.
[ bst::node class ] |
-
node 형 포인터 좌, 우, 부모 및 트리 저장값을 멤버변수로 선언
마지막 bst::iterator 중첩클래스는 외부에서 사용해야 하므로 bst의 public 영역에 선언합니다.
[ bst::iterator class ] |
- node*, bst* 를 멤버 변수로 선언
- 전위/후위 (prefix, postfix) 증가 및 감소 연산자 operator ++, -- 선언
- 비교 연산자 operator ==, != 선언
- 역참조(Pointer dereference) 연산자 operator * 선언
- 화살표(Member selection) 연산자 operator -> 선언
bst.hpp
템플릿 클래스는 *.h 선언부에 구현을 같이 포함하는 것이 일반적이지만, 코드의
가독성을 위해 bst.hpp에 bst 클래스의 구현(Implementation)을 포함하도록 나누어
보았습니다.
#pragma once using namespace ocs; // iterator class member function template<typename T, typename compare> bst<T, compare>::iterator::iterator() : pNode(nullptr), pTree(nullptr) { } template<typename T, typename compare> bst<T, compare>::iterator::iterator(const node* _node, const bst* _tree) : pNode(_node), pTree(_tree) { } template<typename T, typename compare> typename bst<T, compare>::iterator& bst<T, compare>::iterator::operator++() { node *p = nullptr; if (pNode == nullptr) { pNode = pTree->root; if (pNode == nullptr) throw std::exception("Error"); while (pNode->left != nullptr) pNode = pNode->left; } else if (pNode->right != nullptr) { pNode = pNode->right; while (pNode->left != nullptr) pNode = pNode->left; } else { p = pNode->parent; while (p != nullptr && pNode == p->right) { pNode = p; p = p->parent; } pNode = p; } return *this; } template<typename T, typename compare> typename bst<T, compare>::iterator& bst<T, compare>::iterator::operator--() { node *p = nullptr; if (pNode == nullptr) { pNode = pTree->root; if (pNode == nullptr) throw std::exception("Error"); while (pNode->right != nullptr) pNode = pNode->right; } else if (pNode->left != nullptr) { pNode = pNode->left; while (pNode->right != nullptr) pNode = pNode->right; } else { p = pNode->parent; while (p != nullptr && pNode == p->left) { pNode = p; p = p->parent; } pNode = p; } return *this; } template<typename T, typename compare> typename const bst<T, compare>::iterator bst<T, compare>::iterator::operator++(int) { auto temp = *this; this->operator++(); return temp; } template<typename T, typename compare> typename const bst<T, compare>::iterator bst<T, compare>::iterator::operator--(int) { auto temp = *this; operator--(); return temp; } template<typename T, typename compare> typename bool bst<T, compare>::iterator::operator==(const iterator& it) const { return pTree == it.pTree && pNode == it.pNode; } template<typename T, typename compare> typename bool bst<T, compare>::iterator::operator!=(const iterator& it) const { return pTree != it.pTree || pNode != it.pNode; } template<typename T, typename compare> typename T bst<T, compare>::iterator::operator*() const { if (pNode == nullptr) throw std::exception("Error"); return pNode->val; } template<typename T, typename compare> typename const T* bst<T, compare>::iterator::operator->() const { return &pNode->val; } // bst clas member function template<typename T, typename compare> bst<T, compare>::bst(compare _cmp) : cmp(_cmp), root(nullptr) { } template<typename T, typename compare> bst<T, compare>::bst(const bst& other) { root = clone(other.root); cmp = other.cmp; } template<typename T, typename compare> bst<T, compare>::bst(const bst&& other) : root(other.root), cmp(other.cmp) { other.root = nullptr; } template<typename T, typename compare> bst<T, compare>::~bst() { clear(); } template<typename T, typename compare> typename bst<T, compare>& bst<T, compare>::operator=(const bst& other) { bst temp(other); std::swap(*this, temp); return *this; } template<typename T, typename compare> typename bst<T, compare>& bst<T, compare>::operator=(const bst&& other) { std::swap(*this->root, other.root); std::swap(*this->cmp, other.cmp); return *this; } template<typename T, typename compare> typename bst<T, compare>::node* bst<T, compare>::clone(node* p) const { if (p == nullptr) return nullptr; else return new node(p->val, clone(p->left), clone(p->right), p->parent); } template<typename T, typename compare> typename bst<T, compare>::iterator bst<T, compare>::insert(const T& _val) { auto p = insert(_val, root, nullptr); if (p == nullptr) return end(); else return iterator(p, this); } template<typename T, typename compare> typename bst<T, compare>::node* bst<T, compare>::insert(const T& _val, node * &p, node *_parent) { if (p == nullptr) { p = new node(_val, nullptr, nullptr, _parent); return p; } else { if (cmp(_val, p->val)) return insert(_val, p->left, p); else if (cmp(p->val, _val)) return insert(_val, p->right, p); else return nullptr; //duplicate } } template<typename T, typename compare> typename bool bst<T, compare>::erase(const iterator& it) { return remove(root, *it); } template<typename T, typename compare> typename bool bst<T, compare>::remove(node* &p, const T& _val) { if (p == nullptr) return false; if (cmp(_val, p->val)) return remove(p->left, _val); else if (cmp(p->val, _val)) return remove(p->right, _val); else if (p->left != nullptr && p->right != nullptr) { p->val = findMin(p->right)->val; remove(p->right, p->val); return true; } else { node *old = p; p = (p->left != nullptr) ? p->left : p->right; delete old; return true; } } template<typename T, typename compare> typename void bst<T, compare>::removeAll(node *p) { if (p != nullptr) { removeAll(p->left); removeAll(p->right); delete p; } p = nullptr; } template<typename T, typename compare> typename const T& bst<T, compare>::findMin() const { if (root == nullptr) throw std::exception("Error"); return findMin(root)->val; } template<typename T, typename compare> typename const T& bst<T, compare>::findMax() const { if (root == nullptr) throw std::exception("Error"); return findMax(root)->val; } template<typename T, typename compare> typename bst<T, compare>::node* bst<T, compare>::findMin(node* p) const { if (p == nullptr || p->left == nullptr) return p; return findMin(p->left); } template<typename T, typename compare> typename bst<T, compare>::node* bst<T, compare>::findMax(node* p) const { if (p == nullptr || p->right == nullptr) return p; return findMax(p->right); } template<typename T, typename compare> typename bst<T, compare>::iterator bst<T, compare>::begin() const { return iterator(findMin(root), this); } template<typename T, typename compare> typename bst<T, compare>::iterator bst<T, compare>::end() const { return iterator(nullptr, this); } template<typename T, typename compare> typename bst<T, compare>::iterator bst<T, compare>::find(const T& _val) const { node *p = root; while (p != nullptr && !(p->val == _val)) { p = _val < p->val ? p->left : p->right; } return iterator(p, this); } template<typename T, typename compare> typename size_t bst<T, compare>::size() const { size_t size = 0; for (iterator itr = begin(); itr != end(); ++itr) size += 1; return size; } template<typename T, typename compare> typename void bst<T, compare>::clear() { removeAll(root); }
bst, bst::node, bst::iterator 클래스의 멤버함수 구현이며, 중첩 클래스와 템플릿으로 인해 코드가 길어보입니다.
참고로 C++11 using 지시자의 템플릿 별칭을 이용하면 가독성이 아래와 같이 향상됩니다.
[원본 코드, bst::iterator 클래스의 레퍼런스를 리턴]
template<typename T, typename compare> typename bst<T, compare>::iterator& bst<T, compare>::iterator::operator++() {}
[using 템플릿 별칭 선언]
template<typename T, typename compare> using _iterator = typename ocs::bst<T, compare>::iterator;
[템플릿 별칭 사용, bst::iterator 클래스의 레퍼런스를 리턴]
template<typename T, typename compare> typename _iterator<T, compare>& bst<T, compare>::iterator::operator++() {}
Custom Iterator
1. bst class 객체 생성, 후 insert
ocs::bst<int> b1; b1.insert(5); b1.insert(3); b1.insert(7);
- 현재까지 bst::iterator 객체 생성 X.
2. bst::begin() 함수 호출 후, 리턴되는 iterator 객체 itr에 대입
- 대입연산자 우변의 b1.begin() 함수 호출
ocs::bst<int>::iterator itr = b1.begin();
- bst::begin() 함수는 bst::iterator 클래스의 임시 객체를 생성
typename bst<T, compare>::iterator bst<T, compare>::begin() const { return iterator(findMin(root), this); }
- bst::begin() 에 의해 생성되는 iterator 임시 객체는 findMin() 함수를 이용, 트리 값 5,3,7 중 가장 작은 3값을 가진 node*와 bst (this) 를 이용해 bst::Iterator 클래스 private 생성자 호출
- bst::iterator 의 private 생성자 모습
bst<T, compare>::iterator::iterator(const node* _node, const bst* _tree) : pNode(_node), pTree(_tree) { }
- 따라서 bst::begin() 에 의해 생성되는 iterator 임시객체는 3값을 갖는 node* 와 b1 트리의 값을 멤버변수로 가짐
- 이제 이 임시객체를 itr 타입에 대입하면 operator=() 이 아닌 iterator의 Copy Constructor가 호출되며, 디폴트 복사생성자는 멤버변수끼리 얕은 복사를 수행하므로 itr::pNode는 3값을 갖는 node*를 가리키고, itr::pTree는 b1를 가르키게 됨.
여기까지가 Custom Iterator 객체가 생성되는 과정입니다.
Iterator 임시객체의 소멸자는 Heap에 생성된 node *를 delete 하지 않으므로 우려할 부분은 없습니다.
node 객체를 동적생성, 해지하는 것은 bst 클래스에서 이루어 집니다.
아래 그림을 살펴보면 이해가 쉽습니다.
[ bst, iterator, node 관계 ] |
3. bst::iterator::operator 반복
- 코드를 아래와 같이 구성합니다.
ocs::bst<int> b1; b1.insert(5); b1.insert(3); b1.insert(7); for (auto itr = b1.begin(); itr != b1.end(); ++itr) { std::cout << *itr << std::endl; }
- for 루프 전 상황은 2번 설명의 그림처럼 b1, node 객체가 만들어진 상황
-
auto itr 는 bst::begin() 함수 리턴값에 의해 bst::iterator객체 저장
- *itr 은 operator*() 연산자 함수 호출, 3을 출력
typename T bst<T, compare>::iterator::operator*() const { if (pNode == nullptr) throw std::exception("Error"); return pNode->val; }
- iterator 는 bst::begin() 함수호출에 의해 생성 시, 가장 작은 값을 가지는 노드를 node* pNode에 저장하기 때문
- 반복문의 1회전 때 노드값 3출력 후, ++itr 에 의해 operator++(), 전위 증가 연산자 호출
typename bst<T, compare>::iterator& bst<T, compare>::iterator::operator++()
{ node *p = nullptr; if (pNode == nullptr) { pNode = pTree->root; if (pNode == nullptr) throw std::exception("Error"); while (pNode->left != nullptr) pNode = pNode->left; } else if (pNode->right != nullptr) { pNode = pNode->right; while (pNode->left != nullptr) pNode = pNode->left; } else { p = pNode->parent; while (p != nullptr && pNode == p->right) { pNode = p; p = p->parent; } pNode = p; } return *this; }이상으로 모든 코드 설명을 마칩니다.
댓글
댓글 쓰기