본문 바로가기
알고리즘

세그먼트 트리 개념 및 구현 / 코딩테스트 주요 5대 알고리즘

by 개발하는 감자입니다 2024. 1. 18.
728x90

 

 

안녕하세요! 개발감자입니다.

오늘 세그먼트 트리 개념과 코딩테스트에서 구현하는 방법에 이야기해볼게요.

세그먼트 트리를 알아보기에 앞서 트리의 개념이 헷갈리는 분이 계신다면, 아래의 포스팅을 참고하시길 바랍니다.

 

[알고리즘 마스터] 14. 트리와 트라이

안녕하세요! 개발감자입니다🥔 오늘은 알고리즘 중 트리와 트라이에 대해서 정리하고 백준 문제를 통해 구현해보는 시간을 가져보도록 하겠습니다. 1. 트리 알아보기 트리는 그래프의 특수한

qkrrmsdud.tistory.com

 

1. 세그먼트 트리의 개념

 

세그먼트 트리(Segment Tree)는 구간 쿼리를 빠르게 처리하기 위한 자료구조로, 주로 배열과 같은 일차원 데이터에 대한 구간 합을 효율적으로 구하는 데 사용됩니다. 세그먼트 트리는 완전 이진 트리의 형태를 띄며, 각 노드는 해당 구간의 합, 최소값, 최대값 등을 나타냅니다.

2. 세그먼트 트리 핵심 이론

세그먼트 트리의 개념을 잘 설명하는 영상을 추천해드립니다. 이 영상을 보시면, 이해가 아주 잘 되실 겁니다.

아래의 영상을 토대로 세그먼트 트리 핵심 이론 내용을 정리해보도록 하겠습니다.

 

2-1. 트리 초기화하기

데이터의 개수가 N이라고 가정하고 시작합니다.

트리 리스트의 크기 구하는 방법 : 2**k >= N을 만족하는 k의 최솟값을 구한 후, ( 2**k )* 2 를 트리 리스트의 크기로 정의합니다.

예를 들어, 데이터의 개수가 5라면 2의 거듭제곱 중 위의 식을 만족하는 최솟값은 바로 3이 될 것입니다. 그럼 트리 리스트의 크기는 16이 됩니다.

위에서 구한 트리 리스트의 크기를 기반으로 데이터를 넣습니다. 리스트에 넣을 인덱스는 주어진 질의 인덱스 + 2**k -1 입니다.

2-2. 질의값 구하기

질의 값을 구하는 방법은

1) start_index %2 == 1 (오른쪽 노드)일 때 해당 노드를 선택한다

2) end_index % 2 == 0 (왼쪽 노드) 일 때 해당 노드를 선택한다.

3) start_index depth 변경 : start_index = (start_index+1)/2 연산 실행한다.

4) end_index depth 변경 : end_index = (end_index-1)/2 연산 실행한다.

5) 1-4번을 반복하다가 end_index<start_index가 되면 종료한다.

 

2-3. 데이터 업데이트 하기

자신의 부모 노드로 이동하면서 업데이트를 하지만, 어떤 값으로 업데이트할 지는 트리 타입 별로 조금 다릅니다.

구간 합, 최대 최소인 경우에 따라 데이터를 업데이트하는 계산이 다르죠.

구간합은 자식 노드를 모두 더해 부모노드를 업데이트를 계속하고, 부모 노드의 인덱스가 1인 경우에 그만둡니다.

 

3. 세그먼트 트리 구현하기

백준 문제 2042번 구간 합 구하기를 보면서 세그먼트 트리를 코딩테스트에서 어떻게 응용하는지 함께 보실까요?

위의 핵심이론을 그대로 구현하여 코드를 작성하였습니다. 제가 제출한 정답을 보면서 같이 이야기해볼게요.

import sys

input = sys.stdin.readline

# 세그먼트 트리 - 업데이트하고 구간합 구하기

n, m, sum = map(int,input().split()) # 수의 갯수, 수의 변경, 구간합을 구하는 횟수
change = []
for i in range(1,63):
    if 2**i>= n:
        k = i
        break
num = [0 for i in range(2**(k+1))]  

for i in range(2**k,2**k+n): # 주어진 수
    temp = int(input())
    num[i] = temp

for i in range(m+sum): # 수의 변경
    a,b,c = map(int, input().split())
    change.append([a,b,c])


# 구간합 구하기
for i in range(2**k-1,0,-1):
    num[i] = num[2*i] + num[2*i +1]
def changeBC(b,c):
    b_index = b + 2**k -1
    num[b_index] = c # 변경

    # 다른 데이터 업데이트 (구간합)
    # 오른쪽 왼쪽인지 확인
    hello = 1
    while hello == 1:
        temp = int(b_index%2)
        if temp == 1: # 오른쪽 노드
            parent_index = int((b_index-1)/2)
            num[parent_index] = num[b_index-1] + num[b_index]
        if temp == 0: # 왼쪽 노드
            parent_index = int(b_index/2)
            num[parent_index] = num[b_index] + num[b_index+1]

        b_index = parent_index
        if b_index == 1:
            hello = 0

def printSum(b,c):
    sum = 0
    start = b + 2**k -1
    end = c + 2**k -1
    while end >= start:
        if int(start%2) == 1: # 선택한다
            sum += num[start]

        if int(end %2) == 0: # 선택한다
            sum += num[end]
        
        start = int((start +1)/2)
        end = int((end-1)/2)

    print(sum)

for i in range(m+sum): # 수의 변경
    op = change[i][0]
    b = change[i][1]
    c = change[i][2]

    if op == 1:
        # b->c 로 바꾸기
        changeBC(b,c)
    elif op == 2:
        # 인덱스 b부터 c까지 합 출력하기
        printSum(b,c)

 

이 코드는 세그먼트 트리를 사용하여 수열의 업데이트와 구간 합을 구하는 기능을 구현한 것으로 보입니다. 코드를 세부적으로 설명해보겠습니다.

 

3-1. 라이브러리 및 입력

sys.stdin.readline을 사용하여 빠른 입력을 위한 라이브러리를 불러옵니다.

import sys
input = sys.stdin.readline

 

3-2. 입력 받기

n: 수의 개수, m: 수의 변경 횟수, sum: 구간합을 구하는 횟수를 입력받습니다.

n, m, sum = map(int,input().split())

 

3-3. 트리 초기화

주어진 수의 개수 n을 기반으로 세그먼트 트리를 초기화합니다.

k는 트리의 높이로, 2**k가 주어진 수의 개수를 초과하는 최소의 2의 거듭제곱을 나타냅니다.

num은 세그먼트 트리를 저장할 배열로 초기값은 0으로 설정합니다.

change = []
for i in range(1, 63):
    if 2**i >= n:
        k = i
        break
num = [0 for i in range(2**(k+1))]

 

3-4. 주어진 수 입력

주어진 수열을 세그먼트 트리의 리프 노드에 입력합니다.

for i in range(2**k, 2**k + n):
    temp = int(input())
    num[i] = temp

 

3-5. 수의 변경 및 구간합 구하기

수의 변경 및 구간합을 구할 쿼리 연산을 입력받습니다.

for i in range(m + sum):
    a, b, c = map(int, input().split())
    change.append([a, b, c])

 

3-6. 세그먼트 트리 업데이트 함수

주어진 인덱스 b의 값을 c로 변경하고, 업데이트된 값을 상위 노드까지 반영하는 함수입니다.

hello 변수와 while 루프를 사용하여 상위 노드까지 거슬러 올라가며 업데이트를 수행합니다.

def changeBC(b, c):
    # 업데이트 로직 생략

 

3-7. 구간합 출력 함수

주어진 범위 b부터 c까지의 구간합을 계산하는 함수입니다.

해당 구간에 대해 세그먼트 트리를 순회하면서 필요한 노드를 선택하여 합을 계산합니다.

def printSum(b, c):
    # 구간합 계산 로직 생략

 

3-8. 쿼리 수행 및 출력

입력받은 쿼리에 따라 수의 변경 또는 구간합을 출력합니다. 함수를 호출하여 해당 연산을 수행합니다.

for i in range(m + sum):
    op = change[i][0]
    b = change[i][1]
    c = change[i][2]

    if op == 1:
        changeBC(b, c)
    elif op == 2:
        printSum(b, c)

 

그럼 더 알찬 포스팅으로 돌아올게요. 개발감자였습니다 :)

728x90
반응형