중첩(Nested)클래스로 Iterator 직접 구현하기

C++ STL의 템플릿 클래스들이 제공하는 이터레이터를 직접 구현해본 예제 입니다.

예를 들면, std::vector 의 이터레이터를 다음과 같이 사용합니다.

std::vector<int> v;
v.push_back(5); 
v.push_back(3); 
v.push_back(7); 

// 이터레이터 선언
std::vector<int>::iterator itr = v.begin();

직접 템플릿 클래스를 만들고 이를 순회하는 iterator 도 직접 만들어 보는 것이죠.

이진 탐색 트리(Binary Search Tree, 이하 BST)의 Iterator를 직접 구현해보고자 하는 학원생이 있어 같이 공부하며 만들어 보았습니다.

이 게시물은 BST의 기본적인 구현을 설명하지 않습니다. 인터넷에 많은 자료가 있으므로 참조하기 바라며, Custom Iterator 의 내용 위주로 설명하겠습니다.


구글 검색으로 아래 사이트의 도움을 많이 받았습니다.

미국 "The Ohio State University, Steven J Zeil" 교수님이 구현한 예제이며, 해당 주제에 대해 잘 정리되어 있습니다.

링크 :  https://www.cs.odu.edu/~zeil/cs361/latest/Public/treetraversal/index.html


아래와 같이 사용합니다. 

Main.cpp

#include <iostream>
#include <string>
#include "bst.h"

int main()
{
    // int type bst
    ocs::bst<int> b1;    
    b1.insert(5);
    b1.insert(3);
    b1.insert(7);
    b1.insert(2);
    b1.insert(10);
    
    std::cout << "[b1]" << '\n';
    for (const auto& it : b1)
    {
        std::cout << it << std::endl;
    }
    std::cout << std::endl;


    // string type tree
    auto func = [](const auto& s1, const auto& s2)
    {
        return s1.size() < s2.size();
    };

    ocs::bst<std::string, decltype(func)> b2(func);
    b2.insert("abc");
    b2.insert("my name is kim");
    b2.insert("hi");
    b2.insert("hello world");    
    
    std::cout << "[b2]" << '\n';
    ocs::bst<std::string, decltype(func)>::iterator itr;

    itr = b2.begin();
    b2.erase(itr);

    for (itr = b2.begin(); itr!= b2.end(); ++itr)
    {
        std::cout << *itr << std::endl;        
    }
}

bst clsss의 객체를 생성하고, 이진트리에 값을 Insert 후 Custom Iterator로 순회하는 코드입니다.

BST가 구현된 bst.h는 뒤에 설명하고, 출력결과는 다음과 같습니다.

Custom Iterator가 잘 동작하는 모습입니다.

[b1, b2 BST를 생성 후 출력]

 

컴파일 오류가 발생하면, C++ 컴파일러 버전을 C++14 이상으로 잡아야 합니다. 

[VS 프로젝트 속성]

 24번 라인에 작성된 람다식의 auto 전달인자 때문입니다.

C++11까지 람다 함수의 전달인자는 명시적으로 제한되었지만, C++14부터는 auto 타입이 가능해졌습니다.

이 람다 함수는 템플릿으로 작성된 트리값의 타입이 문자열인 경우, 비교(트리의 왼쪽 or 오른쪽) 후 노드에 연결하기 위해 작성되었습니다.


bst.h

이진탐색트리를 구현한 템플릿 클래스 선언부입니다.
bst class 와 중첩 클래스(Nested class) 로 iteratornode class가 존재합니다.
#pragma once
#include <utility>
#include <stdexcept>

namespace ocs
{
    template<typename T, typename compare = std::less<>>
    class bst
    {
    private:
        // Nested class bst::node
        class node
        {            
        public:
            node(T _val = 0) :
                val(_val), left(nullptr), right(nullptr), parent(nullptr) {}
            ~node() {}

        private:
            friend class bst;
            node(T _val, node* _left, node* _right, node* _parent) :
                val(_val), left(_left), right(_right), parent(_parent) {}

        private:
            T val;
            node* left;
            node* right;
            node* parent;
        };
    public:
        // Nested class bst::iterator
        class iterator
        {
        public:
            iterator();
            ~iterator() {}
            // operator overloading
            iterator& operator++();
            iterator& operator--();
            const iterator operator++(int);
            const iterator operator--(int);
            bool operator==(const iterator& it) const;
            bool operator!=(const iterator& it) const;
            T operator*() const;
            const T* operator->() const;

        private:
            friend class bst;
            iterator(const node* _node, const bst* _tree);

        private:
            const node* pNode;
            const bst* pTree;
        };

    // bst members
    public:
        bst(compare _cmp = compare());
        bst(const bst& other);
        bst(const bst&& other);
        bst& operator=(const bst& other);
        bst& operator=(const bst& other);
        ~bst();

    public:
        iterator insert(const T& _val);        
        bool erase(const iterator& it);
        const T& findMin() const;
        const T& findMax() const;
        iterator begin() const;
        iterator end() const;
        iterator find(const T& _val) const;
        size_t size() const;
        void clear();

    private:
        node* insert(const T& _val, node * &p, node *_parent);
        node* findMin(node *p) const;
        node* findMax(node *p) const;
        bool remove(node* &p, const T& _val);
        void removeAll(node* p);
        node* clone(node *p) const;

    private:
        compare cmp;
        node* root;
    };    

#include "bst.hpp"
} // end of namespace ocs

 

중첩 클래스의 내용을 제거하면 좀 더 가독성이 좋아집니다.

#pragma once
#include <utility>
#include <stdexcept>

namespace ocs
{
    template<typename T, typename compare = std::less<>>
    class bst
    {
    private:
        // Nested class bst::node
        class node
        ...생략...
        
    public:
        // Nested class bst::iterator
        class iterator
        ...생략...

    // bst members
    public:
        bst(compare _cmp = compare());
        bst(const bst& other);
        bst(const bst&& other);
        bst& operator=(const bst& other);
        bst& operator=(const bst& other);
        ~bst();

    public:
        iterator insert(const T& _val);        
        bool erase(const iterator& it);
        const T& findMin() const;
        const T& findMax() const;
        iterator begin() const;
        iterator end() const;
        iterator find(const T& _val) const;
        size_t size() const;
        void clear();

    private:
        node* insert(const T& _val, node * &p, node *_parent);
        node* findMin(node *p) const;
        node* findMax(node *p) const;
        bool remove(node* &p, const T& _val);
        void removeAll(node* p);
        node* clone(node *p) const;

    private:
        compare cmp;
        node* root;
    };    

#include "bst.hpp"
} // end of namespace ocs

 

한결 보기가 편해졌습니다. bst 클래스 UML 다이어 그램을 살펴보겠습니다.

[ bst class ]

bst class는 템플릿 클래스로 구성되어 있으며,

트리값을 저장할 typename T, 트리값을 비교할 typename compapre 를 가집니다.

  • 이진 검색 트리를 관리하는 Template Class
  • namespace ocs 내부에 선언
  • T 타입의 크다, 작다 비교용 compare 타입, 초기값 std::less<>
  • 중첩 클래스로 node, iterator 를 선언

  • Copy, Move Constructor 및 Copy, Move Assignment Operator 선언 
  • insert, erase, findMin, findMax, size, clear 등 트리 관리 멤버함수 선언
  • iterator class를 리턴하는 begin, end 등 멤버함수 선언

 

다음으로 bst::node 중첩클래스는 bst의 private영역에 선언되어 있습니다.

간단한 노드 클래스 입니다.

[ bst::node class ]

  •  node 형 포인터 좌, 우, 부모 및 트리 저장값을 멤버변수로 선언

 

마지막 bst::iterator 중첩클래스는 외부에서 사용해야 하므로 bst의 public 영역에 선언합니다.

[ bst::iterator class ]

  • node*, bst* 를 멤버 변수로 선언

  • 전위/후위 (prefix, postfix) 증가 및 감소 연산자 operator ++, -- 선언
  • 비교 연산자 operator ==, != 선언

  • 역참조(Pointer dereference) 연산자 operator * 선언
  • 화살표(Member selection) 연산자 operator -> 선언


bst.hpp

템플릿 클래스는 *.h 선언부에 구현을 같이 포함하는 것이 일반적이지만, 코드의 가독성을 위해 bst.hpp에 bst 클래스의 구현(Implementation)을 포함하도록 나누어 보았습니다.

#pragma once
using namespace ocs;

// iterator class member function
template<typename T, typename compare>
bst<T, compare>::iterator::iterator() : pNode(nullptr), pTree(nullptr)
{
}

template<typename T, typename compare>
bst<T, compare>::iterator::iterator(const node* _node, const bst* _tree) : pNode(_node), pTree(_tree)
{    
}

template<typename T, typename compare>
typename bst<T, compare>::iterator& bst<T, compare>::iterator::operator++()
{
    node *p = nullptr;
    if (pNode == nullptr)
    {
        pNode = pTree->root;
        if (pNode == nullptr)
            throw std::exception("Error");

        while (pNode->left != nullptr)
            pNode = pNode->left;
    }
    else
        if (pNode->right != nullptr)
        {
            pNode = pNode->right;
            while (pNode->left != nullptr)
                pNode = pNode->left;
        }
        else
        {
            p = pNode->parent;
            while (p != nullptr && pNode == p->right)
            {
                pNode = p;
                p = p->parent;
            }
            pNode = p;
        }

    return *this;
}

template<typename T, typename compare>
typename bst<T, compare>::iterator& bst<T, compare>::iterator::operator--()
{
    node *p = nullptr;

    if (pNode == nullptr)
    {
        pNode = pTree->root;

        if (pNode == nullptr)
            throw std::exception("Error");

        while (pNode->right != nullptr)
            pNode = pNode->right;
    }
    else if (pNode->left != nullptr)
    {
        pNode = pNode->left;

        while (pNode->right != nullptr)
            pNode = pNode->right;
    }
    else
    {
        p = pNode->parent;
        while (p != nullptr && pNode == p->left)
        {
            pNode = p;
            p = p->parent;
        }
        pNode = p;
    }
    return *this;
}

template<typename T, typename compare>
typename const bst<T, compare>::iterator bst<T, compare>::iterator::operator++(int)
{
    auto temp = *this;
    this->operator++();
    return temp;
}

template<typename T, typename compare>
typename const bst<T, compare>::iterator bst<T, compare>::iterator::operator--(int)
{
    auto temp = *this;
    operator--();
    return temp;
}

template<typename T, typename compare>
typename bool bst<T, compare>::iterator::operator==(const iterator& it) const
{
    return pTree == it.pTree && pNode == it.pNode;
}

template<typename T, typename compare>
typename bool bst<T, compare>::iterator::operator!=(const iterator& it) const
{
    return pTree != it.pTree || pNode != it.pNode;
}

template<typename T, typename compare>
typename T bst<T, compare>::iterator::operator*() const
{
    if (pNode == nullptr)
        throw std::exception("Error");
    return pNode->val;
}

template<typename T, typename compare>
typename const T* bst<T, compare>::iterator::operator->() const
{
    return &pNode->val;
}



// bst clas member function
template<typename T, typename compare>
bst<T, compare>::bst(compare _cmp) : cmp(_cmp), root(nullptr)
{    
}

template<typename T, typename compare>
bst<T, compare>::bst(const bst& other)
{
    root = clone(other.root);
    cmp = other.cmp;
}

template<typename T, typename compare>
bst<T, compare>::bst(const bst&& other) : root(other.root), cmp(other.cmp)
{
    other.root = nullptr;
}

template<typename T, typename compare>
bst<T, compare>::~bst()
{
    clear();
}

template<typename T, typename compare>
typename bst<T, compare>& bst<T, compare>::operator=(const bst& other)
{
    bst temp(other);
    std::swap(*this, temp);
    return *this;
}

template<typename T, typename compare>
typename bst<T, compare>& bst<T, compare>::operator=(const bst&& other)
{
    std::swap(*this->root, other.root);
    std::swap(*this->cmp, other.cmp);
    return *this;
}

template<typename T, typename compare>
typename bst<T, compare>::node* bst<T, compare>::clone(node* p) const
{
    if (p == nullptr)
        return nullptr;
    else
        return new node(p->val, clone(p->left), clone(p->right), p->parent);
}

template<typename T, typename compare>
typename bst<T, compare>::iterator bst<T, compare>::insert(const T& _val)
{
    auto p = insert(_val, root, nullptr);
    if (p == nullptr)
        return end();
    else
        return iterator(p, this);
}

template<typename T, typename compare>
typename bst<T, compare>::node* bst<T, compare>::insert(const T& _val, node * &p, node *_parent)
{
    if (p == nullptr)
    {
        p = new node(_val, nullptr, nullptr, _parent);
        return p;
    }
    else
    {
        if (cmp(_val, p->val))
            return insert(_val, p->left, p);
        else if (cmp(p->val, _val))
            return insert(_val, p->right, p);
        else
            return nullptr; //duplicate
    }
}

template<typename T, typename compare>
typename bool bst<T, compare>::erase(const iterator& it)
{
    return remove(root, *it);
}

template<typename T, typename compare>
typename bool bst<T, compare>::remove(node* &p, const T& _val)
{
    if (p == nullptr)
        return false;
    if (cmp(_val, p->val))
        return remove(p->left, _val);
    else if (cmp(p->val, _val))
        return remove(p->right, _val);
    else if (p->left != nullptr && p->right != nullptr)
    {
        p->val = findMin(p->right)->val;
        remove(p->right, p->val);
        return true;
    }
    else
    {
        node *old = p;
        p = (p->left != nullptr) ? p->left : p->right;
        delete old;
        return true;
    }
}

template<typename T, typename compare>
typename void bst<T, compare>::removeAll(node *p)
{
    if (p != nullptr)
    {
        removeAll(p->left);
        removeAll(p->right);
        delete p;
    }
    p = nullptr;
}

template<typename T, typename compare>
typename const T& bst<T, compare>::findMin() const
{
    if (root == nullptr)
        throw std::exception("Error");
    return findMin(root)->val;
}

template<typename T, typename compare>
typename const T& bst<T, compare>::findMax() const
{
    if (root == nullptr)
        throw std::exception("Error");
    return findMax(root)->val;
}

template<typename T, typename compare>
typename bst<T, compare>::node* bst<T, compare>::findMin(node* p) const
{
    if (p == nullptr || p->left == nullptr)
        return p;
    return findMin(p->left);
}

template<typename T, typename compare>
typename bst<T, compare>::node* bst<T, compare>::findMax(node* p) const
{
    if (p == nullptr || p->right == nullptr)
        return p;
    return findMax(p->right);
}

template<typename T, typename compare>
typename bst<T, compare>::iterator bst<T, compare>::begin() const
{
    return iterator(findMin(root), this);
}

template<typename T, typename compare>
typename bst<T, compare>::iterator bst<T, compare>::end() const
{
    return iterator(nullptr, this);
}

template<typename T, typename compare>
typename bst<T, compare>::iterator bst<T, compare>::find(const T& _val) const
{
    node *p = root;
    while (p != nullptr && !(p->val == _val))
    {
        p = _val < p->val ? p->left : p->right;
    }
    return iterator(p, this);    
}

template<typename T, typename compare>
typename size_t bst<T, compare>::size() const
{
    size_t size = 0;
    for (iterator itr = begin(); itr != end(); ++itr)    
        size += 1;
    return size;
}

template<typename T, typename compare>
typename void bst<T, compare>::clear()
{
    removeAll(root);
}

bst, bst::node, bst::iterator 클래스의 멤버함수 구현이며, 중첩 클래스와 템플릿으로 인해 코드가 길어보입니다.

 

참고로 C++11 using 지시자의 템플릿 별칭을 이용하면 가독성이 아래와 같이 향상됩니다.

[원본 코드, bst::iterator 클래스의 레퍼런스를 리턴]

template<typename T, typename compare>
typename bst<T, compare>::iterator& bst<T, compare>::iterator::operator++() {}

[using 템플릿 별칭 선언]

template<typename T, typename compare>
using _iterator = typename ocs::bst<T, compare>::iterator;

[템플릿 별칭 사용, bst::iterator 클래스의 레퍼런스를 리턴]

template<typename T, typename compare>
typename _iterator<T, compare>& bst<T, compare>::iterator::operator++() {}


Custom Iterator

그럼 Custom Iterator가 어떻게 동작하는지 살펴보겠습니다.

1. bst class 객체 생성, 후 insert

ocs::bst<int> b1;	
b1.insert(5);
b1.insert(3);
b1.insert(7);
  • 현재까지 bst::iterator 객체 생성 X.

 

2. bst::begin() 함수 호출 후, 리턴되는 iterator 객체 itr에 대입

  • 대입연산자 우변의 b1.begin() 함수 호출

ocs::bst<int>::iterator itr = b1.begin();

  • bst::begin() 함수는 bst::iterator 클래스의 임시 객체를 생성
typename bst<T, compare>::iterator bst<T, compare>::begin() const
{
	return iterator(findMin(root), this);
}


  • bst::begin() 에 의해 생성되는 iterator 임시 객체는 findMin() 함수를 이용, 트리 값 5,3,7 중 가장 작은 3값을 가진 node*와 bst (this) 를 이용해 bst::Iterator 클래스 private 생성자 호출
  •  bst::iterator 의 private 생성자 모습
bst<T, compare>::iterator::iterator(const node* _node, const bst* _tree)
: pNode(_node), pTree(_tree)
{		
}
  • 따라서 bst::begin() 에 의해 생성되는 iterator 임시객체는 3값을 갖는 node* 와 b1 트리의 값을 멤버변수로 가짐
  • 이제 이 임시객체를 itr 타입에 대입하면 operator=() 이 아닌 iterator의 Copy Constructor가 호출되며, 디폴트 복사생성자는 멤버변수끼리 얕은 복사를 수행하므로 itr::pNode는 3값을 갖는 node*를 가리키고, itr::pTree는 b1를 가르키게 됨.

 

여기까지가 Custom Iterator 객체가 생성되는 과정입니다.

Iterator 임시객체의 소멸자는 Heap에 생성된 node *를 delete 하지 않으므로 우려할 부분은 없습니다.

node 객체를 동적생성, 해지하는 것은 bst 클래스에서 이루어 집니다.

아래 그림을 살펴보면 이해가 쉽습니다.

[ bst, iterator, node 관계 ]

3. bst::iterator::operator 반복

  • 코드를 아래와 같이 구성합니다.

ocs::bst<int> b1;	
b1.insert(5);
b1.insert(3);
b1.insert(7);		

for (auto itr = b1.begin(); itr != b1.end(); ++itr)
{
    std::cout << *itr << std::endl;
}
  • for 루프 전 상황은 2번 설명의 그림처럼 b1, node 객체가 만들어진 상황
  • auto itr 는 bst::begin() 함수 리턴값에 의해 bst::iterator객체 저장
  • *itr 은 operator*() 연산자 함수 호출, 3을 출력
typename T bst<T, compare>::iterator::operator*() const
{
    if (pNode == nullptr)
        throw std::exception("Error");
    return pNode->val;
}
  • iterator 는 bst::begin() 함수호출에 의해 생성 시, 가장 작은 값을 가지는 노드를 node* pNode에 저장하기 때문
  • 반복문의 1회전 때 노드값 3출력 후, ++itr 에 의해 operator++(), 전위 증가 연산자 호출

 typename bst<T, compare>::iterator& bst<T, compare>::iterator::operator++()

{
	node *p = nullptr;
	if (pNode == nullptr)
	{
		pNode = pTree->root;
		if (pNode == nullptr)
			throw std::exception("Error");

		while (pNode->left != nullptr)
			pNode = pNode->left;
	}
	else
		if (pNode->right != nullptr)
		{
			pNode = pNode->right;
			while (pNode->left != nullptr)
				pNode = pNode->left;
		}
		else
		{
			p = pNode->parent;
			while (p != nullptr && pNode == p->right)
			{
				pNode = p;
				p = p->parent;
			}
			pNode = p;
		}

	return *this;
}
이상으로 모든 코드 설명을 마칩니다.
 
참고로 Custom std::vector 예제도 작성해 두었으므로 참조바랍니다.

감사합니다.

댓글

이 블로그의 인기 게시물

Qt Designer 설치하기

PyQt5 기반 동영상 플레이어앱 만들기