중첩(Nested)클래스로 Iterator 직접 구현하기
C++ STL의 템플릿 클래스들이 제공하는 이터레이터를 직접 구현해본 예제 입니다.
예를 들면, std::vector 의 이터레이터를 다음과 같이 사용합니다.
std::vector<int> v; v.push_back(5); v.push_back(3); v.push_back(7); // 이터레이터 선언 std::vector<int>::iterator itr = v.begin();
직접 템플릿 클래스를 만들고 이를 순회하는 iterator 도 직접 만들어 보는
것이죠.
이진 탐색 트리(Binary Search Tree, 이하 BST)의 Iterator를 직접 구현해보고자 하는 학원생이 있어 같이 공부하며 만들어 보았습니다.
이 게시물은 BST의 기본적인 구현을 설명하지 않습니다. 인터넷에 많은 자료가 있으므로 참조하기 바라며, Custom Iterator 의 내용 위주로 설명하겠습니다.
구글 검색으로 아래 사이트의 도움을 많이 받았습니다.
미국 "The Ohio State University, Steven J Zeil" 교수님이 구현한
예제이며, 해당 주제에 대해 잘 정리되어 있습니다.
링크 :
https://www.cs.odu.edu/~zeil/cs361/latest/Public/treetraversal/index.html
아래와 같이 사용합니다.
Main.cpp
#include <iostream>
#include <string>
#include "bst.h"
int main()
{
// int type bst
ocs::bst<int> b1;
b1.insert(5);
b1.insert(3);
b1.insert(7);
b1.insert(2);
b1.insert(10);
std::cout << "[b1]" << '\n';
for (const auto& it : b1)
{
std::cout << it << std::endl;
}
std::cout << std::endl;
// string type tree
auto func = [](const auto& s1, const auto& s2)
{
return s1.size() < s2.size();
};
ocs::bst<std::string, decltype(func)> b2(func);
b2.insert("abc");
b2.insert("my name is kim");
b2.insert("hi");
b2.insert("hello world");
std::cout << "[b2]" << '\n';
ocs::bst<std::string, decltype(func)>::iterator itr;
itr = b2.begin();
b2.erase(itr);
for (itr = b2.begin(); itr!= b2.end(); ++itr)
{
std::cout << *itr << std::endl;
}
}
bst clsss의 객체를 생성하고, 이진트리에 값을 Insert 후 Custom Iterator로
순회하는 코드입니다.
BST가 구현된 bst.h는 뒤에 설명하고, 출력결과는 다음과 같습니다.
Custom Iterator가 잘 동작하는 모습입니다.
|
|
[b1, b2 BST를 생성 후 출력] |
컴파일 오류가 발생하면, C++ 컴파일러 버전을 C++14 이상으로 잡아야 합니다.
|
|
[VS 프로젝트 속성] |
24번 라인에 작성된 람다식의 auto 전달인자 때문입니다.
C++11까지 람다 함수의 전달인자는 명시적으로 제한되었지만, C++14부터는 auto 타입이 가능해졌습니다.
이 람다 함수는 템플릿으로 작성된 트리값의 타입이 문자열인 경우, 비교(트리의 왼쪽 or 오른쪽) 후 노드에 연결하기 위해 작성되었습니다.
bst.h
#pragma once
#include <utility>
#include <stdexcept>
namespace ocs
{
template<typename T, typename compare = std::less<>>
class bst
{
private:
// Nested class bst::node
class node
{
public:
node(T _val = 0) :
val(_val), left(nullptr), right(nullptr), parent(nullptr) {}
~node() {}
private:
friend class bst;
node(T _val, node* _left, node* _right, node* _parent) :
val(_val), left(_left), right(_right), parent(_parent) {}
private:
T val;
node* left;
node* right;
node* parent;
};
public:
// Nested class bst::iterator
class iterator
{
public:
iterator();
~iterator() {}
// operator overloading
iterator& operator++();
iterator& operator--();
const iterator operator++(int);
const iterator operator--(int);
bool operator==(const iterator& it) const;
bool operator!=(const iterator& it) const;
T operator*() const;
const T* operator->() const;
private:
friend class bst;
iterator(const node* _node, const bst* _tree);
private:
const node* pNode;
const bst* pTree;
};
// bst members
public:
bst(compare _cmp = compare());
bst(const bst& other);
bst(const bst&& other);
bst& operator=(const bst& other);
bst& operator=(const bst& other);
~bst();
public:
iterator insert(const T& _val);
bool erase(const iterator& it);
const T& findMin() const;
const T& findMax() const;
iterator begin() const;
iterator end() const;
iterator find(const T& _val) const;
size_t size() const;
void clear();
private:
node* insert(const T& _val, node * &p, node *_parent);
node* findMin(node *p) const;
node* findMax(node *p) const;
bool remove(node* &p, const T& _val);
void removeAll(node* p);
node* clone(node *p) const;
private:
compare cmp;
node* root;
};
#include "bst.hpp"
} // end of namespace ocs
중첩 클래스의 내용을 제거하면 좀 더 가독성이 좋아집니다.
#pragma once
#include <utility>
#include <stdexcept>
namespace ocs
{
template<typename T, typename compare = std::less<>>
class bst
{
private:
// Nested class bst::node
class node
...생략...
public:
// Nested class bst::iterator
class iterator
...생략...
// bst members
public:
bst(compare _cmp = compare());
bst(const bst& other);
bst(const bst&& other);
bst& operator=(const bst& other);
bst& operator=(const bst& other);
~bst();
public:
iterator insert(const T& _val);
bool erase(const iterator& it);
const T& findMin() const;
const T& findMax() const;
iterator begin() const;
iterator end() const;
iterator find(const T& _val) const;
size_t size() const;
void clear();
private:
node* insert(const T& _val, node * &p, node *_parent);
node* findMin(node *p) const;
node* findMax(node *p) const;
bool remove(node* &p, const T& _val);
void removeAll(node* p);
node* clone(node *p) const;
private:
compare cmp;
node* root;
};
#include "bst.hpp"
} // end of namespace ocs
한결 보기가 편해졌습니다. bst 클래스 UML 다이어 그램을 살펴보겠습니다.
|
| [ bst class ] |
bst class는 템플릿 클래스로 구성되어 있으며,
트리값을 저장할 typename T, 트리값을 비교할 typename compapre 를 가집니다.
- 이진 검색 트리를 관리하는 Template Class
- namespace ocs 내부에 선언
-
T 타입의 크다, 작다 비교용 compare 타입, 초기값 std::less<>
- 중첩 클래스로 node, iterator 를 선언
- Copy, Move Constructor 및 Copy, Move Assignment Operator 선언
- insert, erase, findMin, findMax, size, clear 등 트리 관리 멤버함수 선언
- iterator class를 리턴하는 begin, end 등 멤버함수 선언
다음으로 bst::node 중첩클래스는 bst의 private영역에 선언되어 있습니다.
간단한 노드 클래스 입니다.
|
|
[ bst::node class ] |
-
node 형 포인터 좌, 우, 부모 및 트리 저장값을 멤버변수로 선언
마지막 bst::iterator 중첩클래스는 외부에서 사용해야 하므로 bst의 public 영역에 선언합니다.
|
| [ bst::iterator class ] |
- node*, bst* 를 멤버 변수로 선언
- 전위/후위 (prefix, postfix) 증가 및 감소 연산자 operator ++, -- 선언
- 비교 연산자 operator ==, != 선언
- 역참조(Pointer dereference) 연산자 operator * 선언
- 화살표(Member selection) 연산자 operator -> 선언
bst.hpp
템플릿 클래스는 *.h 선언부에 구현을 같이 포함하는 것이 일반적이지만, 코드의
가독성을 위해 bst.hpp에 bst 클래스의 구현(Implementation)을 포함하도록 나누어
보았습니다.
#pragma once
using namespace ocs;
// iterator class member function
template<typename T, typename compare>
bst<T, compare>::iterator::iterator() : pNode(nullptr), pTree(nullptr)
{
}
template<typename T, typename compare>
bst<T, compare>::iterator::iterator(const node* _node, const bst* _tree) : pNode(_node), pTree(_tree)
{
}
template<typename T, typename compare>
typename bst<T, compare>::iterator& bst<T, compare>::iterator::operator++()
{
node *p = nullptr;
if (pNode == nullptr)
{
pNode = pTree->root;
if (pNode == nullptr)
throw std::exception("Error");
while (pNode->left != nullptr)
pNode = pNode->left;
}
else
if (pNode->right != nullptr)
{
pNode = pNode->right;
while (pNode->left != nullptr)
pNode = pNode->left;
}
else
{
p = pNode->parent;
while (p != nullptr && pNode == p->right)
{
pNode = p;
p = p->parent;
}
pNode = p;
}
return *this;
}
template<typename T, typename compare>
typename bst<T, compare>::iterator& bst<T, compare>::iterator::operator--()
{
node *p = nullptr;
if (pNode == nullptr)
{
pNode = pTree->root;
if (pNode == nullptr)
throw std::exception("Error");
while (pNode->right != nullptr)
pNode = pNode->right;
}
else if (pNode->left != nullptr)
{
pNode = pNode->left;
while (pNode->right != nullptr)
pNode = pNode->right;
}
else
{
p = pNode->parent;
while (p != nullptr && pNode == p->left)
{
pNode = p;
p = p->parent;
}
pNode = p;
}
return *this;
}
template<typename T, typename compare>
typename const bst<T, compare>::iterator bst<T, compare>::iterator::operator++(int)
{
auto temp = *this;
this->operator++();
return temp;
}
template<typename T, typename compare>
typename const bst<T, compare>::iterator bst<T, compare>::iterator::operator--(int)
{
auto temp = *this;
operator--();
return temp;
}
template<typename T, typename compare>
typename bool bst<T, compare>::iterator::operator==(const iterator& it) const
{
return pTree == it.pTree && pNode == it.pNode;
}
template<typename T, typename compare>
typename bool bst<T, compare>::iterator::operator!=(const iterator& it) const
{
return pTree != it.pTree || pNode != it.pNode;
}
template<typename T, typename compare>
typename T bst<T, compare>::iterator::operator*() const
{
if (pNode == nullptr)
throw std::exception("Error");
return pNode->val;
}
template<typename T, typename compare>
typename const T* bst<T, compare>::iterator::operator->() const
{
return &pNode->val;
}
// bst clas member function
template<typename T, typename compare>
bst<T, compare>::bst(compare _cmp) : cmp(_cmp), root(nullptr)
{
}
template<typename T, typename compare>
bst<T, compare>::bst(const bst& other)
{
root = clone(other.root);
cmp = other.cmp;
}
template<typename T, typename compare>
bst<T, compare>::bst(const bst&& other) : root(other.root), cmp(other.cmp)
{
other.root = nullptr;
}
template<typename T, typename compare>
bst<T, compare>::~bst()
{
clear();
}
template<typename T, typename compare>
typename bst<T, compare>& bst<T, compare>::operator=(const bst& other)
{
bst temp(other);
std::swap(*this, temp);
return *this;
}
template<typename T, typename compare>
typename bst<T, compare>& bst<T, compare>::operator=(const bst&& other)
{
std::swap(*this->root, other.root);
std::swap(*this->cmp, other.cmp);
return *this;
}
template<typename T, typename compare>
typename bst<T, compare>::node* bst<T, compare>::clone(node* p) const
{
if (p == nullptr)
return nullptr;
else
return new node(p->val, clone(p->left), clone(p->right), p->parent);
}
template<typename T, typename compare>
typename bst<T, compare>::iterator bst<T, compare>::insert(const T& _val)
{
auto p = insert(_val, root, nullptr);
if (p == nullptr)
return end();
else
return iterator(p, this);
}
template<typename T, typename compare>
typename bst<T, compare>::node* bst<T, compare>::insert(const T& _val, node * &p, node *_parent)
{
if (p == nullptr)
{
p = new node(_val, nullptr, nullptr, _parent);
return p;
}
else
{
if (cmp(_val, p->val))
return insert(_val, p->left, p);
else if (cmp(p->val, _val))
return insert(_val, p->right, p);
else
return nullptr; //duplicate
}
}
template<typename T, typename compare>
typename bool bst<T, compare>::erase(const iterator& it)
{
return remove(root, *it);
}
template<typename T, typename compare>
typename bool bst<T, compare>::remove(node* &p, const T& _val)
{
if (p == nullptr)
return false;
if (cmp(_val, p->val))
return remove(p->left, _val);
else if (cmp(p->val, _val))
return remove(p->right, _val);
else if (p->left != nullptr && p->right != nullptr)
{
p->val = findMin(p->right)->val;
remove(p->right, p->val);
return true;
}
else
{
node *old = p;
p = (p->left != nullptr) ? p->left : p->right;
delete old;
return true;
}
}
template<typename T, typename compare>
typename void bst<T, compare>::removeAll(node *p)
{
if (p != nullptr)
{
removeAll(p->left);
removeAll(p->right);
delete p;
}
p = nullptr;
}
template<typename T, typename compare>
typename const T& bst<T, compare>::findMin() const
{
if (root == nullptr)
throw std::exception("Error");
return findMin(root)->val;
}
template<typename T, typename compare>
typename const T& bst<T, compare>::findMax() const
{
if (root == nullptr)
throw std::exception("Error");
return findMax(root)->val;
}
template<typename T, typename compare>
typename bst<T, compare>::node* bst<T, compare>::findMin(node* p) const
{
if (p == nullptr || p->left == nullptr)
return p;
return findMin(p->left);
}
template<typename T, typename compare>
typename bst<T, compare>::node* bst<T, compare>::findMax(node* p) const
{
if (p == nullptr || p->right == nullptr)
return p;
return findMax(p->right);
}
template<typename T, typename compare>
typename bst<T, compare>::iterator bst<T, compare>::begin() const
{
return iterator(findMin(root), this);
}
template<typename T, typename compare>
typename bst<T, compare>::iterator bst<T, compare>::end() const
{
return iterator(nullptr, this);
}
template<typename T, typename compare>
typename bst<T, compare>::iterator bst<T, compare>::find(const T& _val) const
{
node *p = root;
while (p != nullptr && !(p->val == _val))
{
p = _val < p->val ? p->left : p->right;
}
return iterator(p, this);
}
template<typename T, typename compare>
typename size_t bst<T, compare>::size() const
{
size_t size = 0;
for (iterator itr = begin(); itr != end(); ++itr)
size += 1;
return size;
}
template<typename T, typename compare>
typename void bst<T, compare>::clear()
{
removeAll(root);
}
bst, bst::node, bst::iterator 클래스의 멤버함수 구현이며, 중첩 클래스와 템플릿으로 인해 코드가 길어보입니다.
참고로 C++11 using 지시자의 템플릿 별칭을 이용하면 가독성이 아래와 같이 향상됩니다.
[원본 코드, bst::iterator 클래스의 레퍼런스를 리턴]
template<typename T, typename compare>
typename bst<T, compare>::iterator& bst<T, compare>::iterator::operator++() {}
[using 템플릿 별칭 선언]
template<typename T, typename compare> using _iterator = typename ocs::bst<T, compare>::iterator;
[템플릿 별칭 사용, bst::iterator 클래스의 레퍼런스를 리턴]
template<typename T, typename compare>
typename _iterator<T, compare>& bst<T, compare>::iterator::operator++() {}
Custom Iterator
1. bst class 객체 생성, 후 insert
ocs::bst<int> b1; b1.insert(5); b1.insert(3); b1.insert(7);
- 현재까지 bst::iterator 객체 생성 X.
2. bst::begin() 함수 호출 후, 리턴되는 iterator 객체 itr에 대입
- 대입연산자 우변의 b1.begin() 함수 호출
ocs::bst<int>::iterator itr = b1.begin();
- bst::begin() 함수는 bst::iterator 클래스의 임시 객체를 생성
typename bst<T, compare>::iterator bst<T, compare>::begin() const
{
return iterator(findMin(root), this);
}
- bst::begin() 에 의해 생성되는 iterator 임시 객체는 findMin() 함수를 이용, 트리 값 5,3,7 중 가장 작은 3값을 가진 node*와 bst (this) 를 이용해 bst::Iterator 클래스 private 생성자 호출
- bst::iterator 의 private 생성자 모습
bst<T, compare>::iterator::iterator(const node* _node, const bst* _tree)
: pNode(_node), pTree(_tree)
{
}
- 따라서 bst::begin() 에 의해 생성되는 iterator 임시객체는 3값을 갖는 node* 와 b1 트리의 값을 멤버변수로 가짐
- 이제 이 임시객체를 itr 타입에 대입하면 operator=() 이 아닌 iterator의 Copy Constructor가 호출되며, 디폴트 복사생성자는 멤버변수끼리 얕은 복사를 수행하므로 itr::pNode는 3값을 갖는 node*를 가리키고, itr::pTree는 b1를 가르키게 됨.
여기까지가 Custom Iterator 객체가 생성되는 과정입니다.
Iterator 임시객체의 소멸자는 Heap에 생성된 node *를 delete 하지 않으므로 우려할 부분은 없습니다.
node 객체를 동적생성, 해지하는 것은 bst 클래스에서 이루어 집니다.
아래 그림을 살펴보면 이해가 쉽습니다.
|
|
[ bst, iterator, node 관계 ] |
3. bst::iterator::operator 반복
- 코드를 아래와 같이 구성합니다.
ocs::bst<int> b1;
b1.insert(5);
b1.insert(3);
b1.insert(7);
for (auto itr = b1.begin(); itr != b1.end(); ++itr)
{
std::cout << *itr << std::endl;
}
- for 루프 전 상황은 2번 설명의 그림처럼 b1, node 객체가 만들어진 상황
-
auto itr 는 bst::begin() 함수 리턴값에 의해 bst::iterator객체 저장
- *itr 은 operator*() 연산자 함수 호출, 3을 출력
typename T bst<T, compare>::iterator::operator*() const
{
if (pNode == nullptr)
throw std::exception("Error");
return pNode->val;
}
- iterator 는 bst::begin() 함수호출에 의해 생성 시, 가장 작은 값을 가지는 노드를 node* pNode에 저장하기 때문
- 반복문의 1회전 때 노드값 3출력 후, ++itr 에 의해 operator++(), 전위 증가 연산자 호출
typename bst<T, compare>::iterator& bst<T, compare>::iterator::operator++()
{
node *p = nullptr;
if (pNode == nullptr)
{
pNode = pTree->root;
if (pNode == nullptr)
throw std::exception("Error");
while (pNode->left != nullptr)
pNode = pNode->left;
}
else
if (pNode->right != nullptr)
{
pNode = pNode->right;
while (pNode->left != nullptr)
pNode = pNode->left;
}
else
{
p = pNode->parent;
while (p != nullptr && pNode == p->right)
{
pNode = p;
p = p->parent;
}
pNode = p;
}
return *this;
}
이상으로 모든 코드 설명을 마칩니다.






댓글
댓글 쓰기