중첩(Nested)클래스로 Iterator 직접 구현하기

C++ STL의 템플릿 클래스들이 제공하는 이터레이터를 직접 구현해본 예제 입니다.

예를 들면, std::vector 의 이터레이터를 다음과 같이 사용합니다.

1
2
3
4
5
6
7
std::vector<int> v;
v.push_back(5); 
v.push_back(3); 
v.push_back(7); 
 
// 이터레이터 선언
std::vector<int>::iterator itr = v.begin();

직접 템플릿 클래스를 만들고 이를 순회하는 iterator 도 직접 만들어 보는 것이죠.

이진 탐색 트리(Binary Search Tree, 이하 BST)의 Iterator를 직접 구현해보고자 하는 학원생이 있어 같이 공부하며 만들어 보았습니다.

이 게시물은 BST의 기본적인 구현을 설명하지 않습니다. 인터넷에 많은 자료가 있으므로 참조하기 바라며, Custom Iterator 의 내용 위주로 설명하겠습니다.


구글 검색으로 아래 사이트의 도움을 많이 받았습니다.

미국 "The Ohio State University, Steven J Zeil" 교수님이 구현한 예제이며, 해당 주제에 대해 잘 정리되어 있습니다.

링크 :  https://www.cs.odu.edu/~zeil/cs361/latest/Public/treetraversal/index.html


아래와 같이 사용합니다. 

Main.cpp

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
#include <iostream>
#include <string>
#include "bst.h"
 
int main()
{
    // int type bst
    ocs::bst<int> b1;   
    b1.insert(5);
    b1.insert(3);
    b1.insert(7);
    b1.insert(2);
    b1.insert(10);
     
    std::cout << "[b1]" << '\n';
    for (const auto& it : b1)
    {
        std::cout << it << std::endl;
    }
    std::cout << std::endl;
 
 
    // string type tree
    auto func = [](const auto& s1, const auto& s2)
    {
        return s1.size() < s2.size();
    };
 
    ocs::bst<std::string, decltype(func)> b2(func);
    b2.insert("abc");
    b2.insert("my name is kim");
    b2.insert("hi");
    b2.insert("hello world");   
     
    std::cout << "[b2]" << '\n';
    ocs::bst<std::string, decltype(func)>::iterator itr;
 
    itr = b2.begin();
    b2.erase(itr);
 
    for (itr = b2.begin(); itr!= b2.end(); ++itr)
    {
        std::cout << *itr << std::endl;       
    }
}

bst clsss의 객체를 생성하고, 이진트리에 값을 Insert 후 Custom Iterator로 순회하는 코드입니다.

BST가 구현된 bst.h는 뒤에 설명하고, 출력결과는 다음과 같습니다.

Custom Iterator가 잘 동작하는 모습입니다.

[b1, b2 BST를 생성 후 출력]

 

컴파일 오류가 발생하면, C++ 컴파일러 버전을 C++14 이상으로 잡아야 합니다. 

[VS 프로젝트 속성]

 24번 라인에 작성된 람다식의 auto 전달인자 때문입니다.

C++11까지 람다 함수의 전달인자는 명시적으로 제한되었지만, C++14부터는 auto 타입이 가능해졌습니다.

이 람다 함수는 템플릿으로 작성된 트리값의 타입이 문자열인 경우, 비교(트리의 왼쪽 or 오른쪽) 후 노드에 연결하기 위해 작성되었습니다.


bst.h

이진탐색트리를 구현한 템플릿 클래스 선언부입니다.
bst class 와 중첩 클래스(Nested class) 로 iteratornode class가 존재합니다.
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
#pragma once
#include <utility>
#include <stdexcept>
 
namespace ocs
{
    template<typename T, typename compare = std::less<>>
    class bst
    {
    private:
        // Nested class bst::node
        class node
        {           
        public:
            node(T _val = 0) :
                val(_val), left(nullptr), right(nullptr), parent(nullptr) {}
            ~node() {}
 
        private:
            friend class bst;
            node(T _val, node* _left, node* _right, node* _parent) :
                val(_val), left(_left), right(_right), parent(_parent) {}
 
        private:
            T val;
            node* left;
            node* right;
            node* parent;
        };
    public:
        // Nested class bst::iterator
        class iterator
        {
        public:
            iterator();
            ~iterator() {}
            // operator overloading
            iterator& operator++();
            iterator& operator--();
            const iterator operator++(int);
            const iterator operator--(int);
            bool operator==(const iterator& it) const;
            bool operator!=(const iterator& it) const;
            T operator*() const;
            const T* operator->() const;
 
        private:
            friend class bst;
            iterator(const node* _node, const bst* _tree);
 
        private:
            const node* pNode;
            const bst* pTree;
        };
 
    // bst members
    public:
        bst(compare _cmp = compare());
        bst(const bst& other);
        bst(const bst&& other);
        bst& operator=(const bst& other);
        bst& operator=(const bst& other);
        ~bst();
 
    public:
        iterator insert(const T& _val);       
        bool erase(const iterator& it);
        const T& findMin() const;
        const T& findMax() const;
        iterator begin() const;
        iterator end() const;
        iterator find(const T& _val) const;
        size_t size() const;
        void clear();
 
    private:
        node* insert(const T& _val, node * &p, node *_parent);
        node* findMin(node *p) const;
        node* findMax(node *p) const;
        bool remove(node* &p, const T& _val);
        void removeAll(node* p);
        node* clone(node *p) const;
 
    private:
        compare cmp;
        node* root;
    };   
 
#include "bst.hpp"
} // end of namespace ocs

 

중첩 클래스의 내용을 제거하면 좀 더 가독성이 좋아집니다.

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
#pragma once
#include <utility>
#include <stdexcept>
 
namespace ocs
{
    template<typename T, typename compare = std::less<>>
    class bst
    {
    private:
        // Nested class bst::node
        class node
        ...생략...
         
    public:
        // Nested class bst::iterator
        class iterator
        ...생략...
 
    // bst members
    public:
        bst(compare _cmp = compare());
        bst(const bst& other);
        bst(const bst&& other);
        bst& operator=(const bst& other);
        bst& operator=(const bst& other);
        ~bst();
 
    public:
        iterator insert(const T& _val);       
        bool erase(const iterator& it);
        const T& findMin() const;
        const T& findMax() const;
        iterator begin() const;
        iterator end() const;
        iterator find(const T& _val) const;
        size_t size() const;
        void clear();
 
    private:
        node* insert(const T& _val, node * &p, node *_parent);
        node* findMin(node *p) const;
        node* findMax(node *p) const;
        bool remove(node* &p, const T& _val);
        void removeAll(node* p);
        node* clone(node *p) const;
 
    private:
        compare cmp;
        node* root;
    };   
 
#include "bst.hpp"
} // end of namespace ocs

 

한결 보기가 편해졌습니다. bst 클래스 UML 다이어 그램을 살펴보겠습니다.

[ bst class ]

bst class는 템플릿 클래스로 구성되어 있으며,

트리값을 저장할 typename T, 트리값을 비교할 typename compapre 를 가집니다.

  • 이진 검색 트리를 관리하는 Template Class
  • namespace ocs 내부에 선언
  • T 타입의 크다, 작다 비교용 compare 타입, 초기값 std::less<>
  • 중첩 클래스로 node, iterator 를 선언

  • Copy, Move Constructor 및 Copy, Move Assignment Operator 선언 
  • insert, erase, findMin, findMax, size, clear 등 트리 관리 멤버함수 선언
  • iterator class를 리턴하는 begin, end 등 멤버함수 선언

 

다음으로 bst::node 중첩클래스는 bst의 private영역에 선언되어 있습니다.

간단한 노드 클래스 입니다.

[ bst::node class ]

  •  node 형 포인터 좌, 우, 부모 및 트리 저장값을 멤버변수로 선언

 

마지막 bst::iterator 중첩클래스는 외부에서 사용해야 하므로 bst의 public 영역에 선언합니다.

[ bst::iterator class ]

  • node*, bst* 를 멤버 변수로 선언

  • 전위/후위 (prefix, postfix) 증가 및 감소 연산자 operator ++, -- 선언
  • 비교 연산자 operator ==, != 선언

  • 역참조(Pointer dereference) 연산자 operator * 선언
  • 화살표(Member selection) 연산자 operator -> 선언


bst.hpp

템플릿 클래스는 *.h 선언부에 구현을 같이 포함하는 것이 일반적이지만, 코드의 가독성을 위해 bst.hpp에 bst 클래스의 구현(Implementation)을 포함하도록 나누어 보았습니다.

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
#pragma once
using namespace ocs;
 
// iterator class member function
template<typename T, typename compare>
bst<T, compare>::iterator::iterator() : pNode(nullptr), pTree(nullptr)
{
}
 
template<typename T, typename compare>
bst<T, compare>::iterator::iterator(const node* _node, const bst* _tree) : pNode(_node), pTree(_tree)
{   
}
 
template<typename T, typename compare>
typename bst<T, compare>::iterator& bst<T, compare>::iterator::operator++()
{
    node *p = nullptr;
    if (pNode == nullptr)
    {
        pNode = pTree->root;
        if (pNode == nullptr)
            throw std::exception("Error");
 
        while (pNode->left != nullptr)
            pNode = pNode->left;
    }
    else
        if (pNode->right != nullptr)
        {
            pNode = pNode->right;
            while (pNode->left != nullptr)
                pNode = pNode->left;
        }
        else
        {
            p = pNode->parent;
            while (p != nullptr && pNode == p->right)
            {
                pNode = p;
                p = p->parent;
            }
            pNode = p;
        }
 
    return *this;
}
 
template<typename T, typename compare>
typename bst<T, compare>::iterator& bst<T, compare>::iterator::operator--()
{
    node *p = nullptr;
 
    if (pNode == nullptr)
    {
        pNode = pTree->root;
 
        if (pNode == nullptr)
            throw std::exception("Error");
 
        while (pNode->right != nullptr)
            pNode = pNode->right;
    }
    else if (pNode->left != nullptr)
    {
        pNode = pNode->left;
 
        while (pNode->right != nullptr)
            pNode = pNode->right;
    }
    else
    {
        p = pNode->parent;
        while (p != nullptr && pNode == p->left)
        {
            pNode = p;
            p = p->parent;
        }
        pNode = p;
    }
    return *this;
}
 
template<typename T, typename compare>
typename const bst<T, compare>::iterator bst<T, compare>::iterator::operator++(int)
{
    auto temp = *this;
    this->operator++();
    return temp;
}
 
template<typename T, typename compare>
typename const bst<T, compare>::iterator bst<T, compare>::iterator::operator--(int)
{
    auto temp = *this;
    operator--();
    return temp;
}
 
template<typename T, typename compare>
typename bool bst<T, compare>::iterator::operator==(const iterator& it) const
{
    return pTree == it.pTree && pNode == it.pNode;
}
 
template<typename T, typename compare>
typename bool bst<T, compare>::iterator::operator!=(const iterator& it) const
{
    return pTree != it.pTree || pNode != it.pNode;
}
 
template<typename T, typename compare>
typename T bst<T, compare>::iterator::operator*() const
{
    if (pNode == nullptr)
        throw std::exception("Error");
    return pNode->val;
}
 
template<typename T, typename compare>
typename const T* bst<T, compare>::iterator::operator->() const
{
    return &pNode->val;
}
 
 
 
// bst clas member function
template<typename T, typename compare>
bst<T, compare>::bst(compare _cmp) : cmp(_cmp), root(nullptr)
{   
}
 
template<typename T, typename compare>
bst<T, compare>::bst(const bst& other)
{
    root = clone(other.root);
    cmp = other.cmp;
}
 
template<typename T, typename compare>
bst<T, compare>::bst(const bst&& other) : root(other.root), cmp(other.cmp)
{
    other.root = nullptr;
}
 
template<typename T, typename compare>
bst<T, compare>::~bst()
{
    clear();
}
 
template<typename T, typename compare>
typename bst<T, compare>& bst<T, compare>::operator=(const bst& other)
{
    bst temp(other);
    std::swap(*this, temp);
    return *this;
}
 
template<typename T, typename compare>
typename bst<T, compare>& bst<T, compare>::operator=(const bst&& other)
{
    std::swap(*this->root, other.root);
    std::swap(*this->cmp, other.cmp);
    return *this;
}
 
template<typename T, typename compare>
typename bst<T, compare>::node* bst<T, compare>::clone(node* p) const
{
    if (p == nullptr)
        return nullptr;
    else
        return new node(p->val, clone(p->left), clone(p->right), p->parent);
}
 
template<typename T, typename compare>
typename bst<T, compare>::iterator bst<T, compare>::insert(const T& _val)
{
    auto p = insert(_val, root, nullptr);
    if (p == nullptr)
        return end();
    else
        return iterator(p, this);
}
 
template<typename T, typename compare>
typename bst<T, compare>::node* bst<T, compare>::insert(const T& _val, node * &p, node *_parent)
{
    if (p == nullptr)
    {
        p = new node(_val, nullptr, nullptr, _parent);
        return p;
    }
    else
    {
        if (cmp(_val, p->val))
            return insert(_val, p->left, p);
        else if (cmp(p->val, _val))
            return insert(_val, p->right, p);
        else
            return nullptr; //duplicate
    }
}
 
template<typename T, typename compare>
typename bool bst<T, compare>::erase(const iterator& it)
{
    return remove(root, *it);
}
 
template<typename T, typename compare>
typename bool bst<T, compare>::remove(node* &p, const T& _val)
{
    if (p == nullptr)
        return false;
    if (cmp(_val, p->val))
        return remove(p->left, _val);
    else if (cmp(p->val, _val))
        return remove(p->right, _val);
    else if (p->left != nullptr && p->right != nullptr)
    {
        p->val = findMin(p->right)->val;
        remove(p->right, p->val);
        return true;
    }
    else
    {
        node *old = p;
        p = (p->left != nullptr) ? p->left : p->right;
        delete old;
        return true;
    }
}
 
template<typename T, typename compare>
typename void bst<T, compare>::removeAll(node *p)
{
    if (p != nullptr)
    {
        removeAll(p->left);
        removeAll(p->right);
        delete p;
    }
    p = nullptr;
}
 
template<typename T, typename compare>
typename const T& bst<T, compare>::findMin() const
{
    if (root == nullptr)
        throw std::exception("Error");
    return findMin(root)->val;
}
 
template<typename T, typename compare>
typename const T& bst<T, compare>::findMax() const
{
    if (root == nullptr)
        throw std::exception("Error");
    return findMax(root)->val;
}
 
template<typename T, typename compare>
typename bst<T, compare>::node* bst<T, compare>::findMin(node* p) const
{
    if (p == nullptr || p->left == nullptr)
        return p;
    return findMin(p->left);
}
 
template<typename T, typename compare>
typename bst<T, compare>::node* bst<T, compare>::findMax(node* p) const
{
    if (p == nullptr || p->right == nullptr)
        return p;
    return findMax(p->right);
}
 
template<typename T, typename compare>
typename bst<T, compare>::iterator bst<T, compare>::begin() const
{
    return iterator(findMin(root), this);
}
 
template<typename T, typename compare>
typename bst<T, compare>::iterator bst<T, compare>::end() const
{
    return iterator(nullptr, this);
}
 
template<typename T, typename compare>
typename bst<T, compare>::iterator bst<T, compare>::find(const T& _val) const
{
    node *p = root;
    while (p != nullptr && !(p->val == _val))
    {
        p = _val < p->val ? p->left : p->right;
    }
    return iterator(p, this);   
}
 
template<typename T, typename compare>
typename size_t bst<T, compare>::size() const
{
    size_t size = 0;
    for (iterator itr = begin(); itr != end(); ++itr)   
        size += 1;
    return size;
}
 
template<typename T, typename compare>
typename void bst<T, compare>::clear()
{
    removeAll(root);
}

bst, bst::node, bst::iterator 클래스의 멤버함수 구현이며, 중첩 클래스와 템플릿으로 인해 코드가 길어보입니다.

 

참고로 C++11 using 지시자의 템플릿 별칭을 이용하면 가독성이 아래와 같이 향상됩니다.

[원본 코드, bst::iterator 클래스의 레퍼런스를 리턴]

1
2
template<typename T, typename compare>
typename bst<T, compare>::iterator& bst<T, compare>::iterator::operator++() {}

[using 템플릿 별칭 선언]

1
2
template<typename T, typename compare>
using _iterator = typename ocs::bst<T, compare>::iterator;

[템플릿 별칭 사용, bst::iterator 클래스의 레퍼런스를 리턴]

1
2
template<typename T, typename compare>
typename _iterator<T, compare>& bst<T, compare>::iterator::operator++() {}


Custom Iterator

그럼 Custom Iterator가 어떻게 동작하는지 살펴보겠습니다.

1. bst class 객체 생성, 후 insert

1
2
3
4
ocs::bst<int> b1;
b1.insert(5);
b1.insert(3);
b1.insert(7);
  • 현재까지 bst::iterator 객체 생성 X.

 

2. bst::begin() 함수 호출 후, 리턴되는 iterator 객체 itr에 대입

  • 대입연산자 우변의 b1.begin() 함수 호출

1
ocs::bst<int>::iterator itr = b1.begin();

  • bst::begin() 함수는 bst::iterator 클래스의 임시 객체를 생성
1
2
3
4
typename bst<T, compare>::iterator bst<T, compare>::begin() const
{
    return iterator(findMin(root), this);
}


  • bst::begin() 에 의해 생성되는 iterator 임시 객체는 findMin() 함수를 이용, 트리 값 5,3,7 중 가장 작은 3값을 가진 node*와 bst (this) 를 이용해 bst::Iterator 클래스 private 생성자 호출
  •  bst::iterator 의 private 생성자 모습
1
2
3
4
bst<T, compare>::iterator::iterator(const node* _node, const bst* _tree)
: pNode(_node), pTree(_tree)
{      
}
  • 따라서 bst::begin() 에 의해 생성되는 iterator 임시객체는 3값을 갖는 node* 와 b1 트리의 값을 멤버변수로 가짐
  • 이제 이 임시객체를 itr 타입에 대입하면 operator=() 이 아닌 iterator의 Copy Constructor가 호출되며, 디폴트 복사생성자는 멤버변수끼리 얕은 복사를 수행하므로 itr::pNode는 3값을 갖는 node*를 가리키고, itr::pTree는 b1를 가르키게 됨.

 

여기까지가 Custom Iterator 객체가 생성되는 과정입니다.

Iterator 임시객체의 소멸자는 Heap에 생성된 node *를 delete 하지 않으므로 우려할 부분은 없습니다.

node 객체를 동적생성, 해지하는 것은 bst 클래스에서 이루어 집니다.

아래 그림을 살펴보면 이해가 쉽습니다.

[ bst, iterator, node 관계 ]

3. bst::iterator::operator 반복

  • 코드를 아래와 같이 구성합니다.

1
2
3
4
5
6
7
8
9
ocs::bst<int> b1;
b1.insert(5);
b1.insert(3);
b1.insert(7);      
 
for (auto itr = b1.begin(); itr != b1.end(); ++itr)
{
    std::cout << *itr << std::endl;
}
  • for 루프 전 상황은 2번 설명의 그림처럼 b1, node 객체가 만들어진 상황
  • auto itr 는 bst::begin() 함수 리턴값에 의해 bst::iterator객체 저장
  • *itr 은 operator*() 연산자 함수 호출, 3을 출력
1
2
3
4
5
6
typename T bst<T, compare>::iterator::operator*() const
{
    if (pNode == nullptr)
        throw std::exception("Error");
    return pNode->val;
}
  • iterator 는 bst::begin() 함수호출에 의해 생성 시, 가장 작은 값을 가지는 노드를 node* pNode에 저장하기 때문
  • 반복문의 1회전 때 노드값 3출력 후, ++itr 에 의해 operator++(), 전위 증가 연산자 호출

 typename bst<T, compare>::iterator& bst<T, compare>::iterator::operator++()

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
{
    node *p = nullptr;
    if (pNode == nullptr)
    {
        pNode = pTree->root;
        if (pNode == nullptr)
            throw std::exception("Error");
 
        while (pNode->left != nullptr)
            pNode = pNode->left;
    }
    else
        if (pNode->right != nullptr)
        {
            pNode = pNode->right;
            while (pNode->left != nullptr)
                pNode = pNode->left;
        }
        else
        {
            p = pNode->parent;
            while (p != nullptr && pNode == p->right)
            {
                pNode = p;
                p = p->parent;
            }
            pNode = p;
        }
 
    return *this;
}
이상으로 모든 코드 설명을 마칩니다.
 
참고로 Custom std::vector 예제도 작성해 두었으므로 참조바랍니다.

감사합니다.

댓글

이 블로그의 인기 게시물

Qt Designer 설치하기

PyQt5 기반 동영상 플레이어앱 만들기