Posts [알고리즘] 백준 1275 - 커피숍2
Post
Cancel

[알고리즘] 백준 1275 - 커피숍2


문제

1275 - 커피숍2


접근

세그먼트 트리는 구간 합을 구할 때 사용하기 좋은 알고리즘이다.
단순히 풀게 되면 N개의 원소를 가진 배열에서 특정 구간 (x, y)의 합을 구하는데 O(N)의 시간이 걸린다.
그리고 이러한 구간합을 M번 구해야한다면 O(NM)의 시간이 걸릴 것이다.
N과 M이 충분히 커진다면, 이러한 방식은 오랜 시간이 걸릴 것이고, 이를 효율적으로 풀기 위해 세그먼트 트리를 이용하게 된다.

세그먼트 트리는 이진 트리의 형식으로 각 구간의 합을 미리 구해놓고 주어진 문제에 대한 답을 연산하는 형식이다.
이 방식을 이용하면 트리를 구성하는데 O(N), 구간합을 연산하는데 O(logN)이 걸리게 되며, 최종적으로 O(N + MlogN)의 시간이 소요된다.

여기서는 세그먼트 트리에 대한 자세한 이론은 다루지 않고, 구현에 초점을 맞출 것이다.

우선 간단하게 배열 형식으로 세그먼트 트리를 만드려고 한다.
그러기 위해선 세그먼트 트리를 구성할 때 전체 노드의 개수가 몇개일지 미리 알아야한다.

이진 트리이기 때문에, 높이를 알면 전체 배열의 크기도 대충 감이 온다.
높이는 원소의 개수가 N이라 할 때, logN의 올림값이 된다.
예를 들어 N이 8이면 높이는 3, N이 7이어도 높이는 3이다.

그럼 높이를 구했으니 전체 트리의 크기는 몇으로 잡아야 할까?
높이가 1일 때는 (1+2 = 3) 3개의 크기가 필요하고, 2일 때는 (1+2+4) = 7개, 3일 때는 (1+2+4+8) = 15개가 필요하다.
즉, 2^(높이+1)-1 개가 필요하다.
하지만 우리는 1을 빼지 않고 2^(높이+1)을 사용할 것이다.
이유는 배열의 인덱싱 때문으로, 뒤에서 설명한다.

이제 배열에서 어떻게 노드에 접근하는지 알아보자.
세그먼트 트리의 리프노드는 자기 자신의 값을 가지고 있고, 위로 올라갈수록 자신의 자식들의 합을 담게 된다.

위에서부터 인덱스를 0, 1, 2…. 순으로 매겼을 때,
자기 자신의 인덱스를 i라고 하면 왼쪽 자식은 2i, 오른쪽 자식은 2i + 1로 나타낼 수 있다.
이를 표현하기 위해서는 배열의 인덱스를 0부터시작하지 않고, 1부터 시작해야 한다.

그럼 자식노드에서 부모노드로 접근은 어떻게 할까?
현재 노드를 i라고 하면, 부모노드는 i//2가 될 것이다.

구간합을 담기 위해서는 리프노드부터 타고 올라가는게 편할 것이다.
이를 위해 재귀적으로 실행하며 세그먼트 트리를 만들어주자.

이제 세그먼트 트리에서 구간합을 구해야 한다.
특정 구간이 주어졌을 때 해당 구간의 합을 세그먼트 트리에서 어떻게 찾을까?

우리는 세그먼트 트리를 생성할 때 트리의 노드는 구간합의 값만을 담았지만, 해당 노드가 어떤 구간의 합을 나타내는지도 알 수 있는 방법이 있다.
바로 트리를 생성할 때 전달해주는 인자인 start, end가 해당 노드가 나타내는 구간이 된다.
그럼 이제 특정 구간 [x, y] 가 주어졌을 때를 생각해보자.

경우의 수는 3가지가 나온다. 더 잘게 쪼개보면 4가지라고 할 수 있을 것이다.

  1. [x, y]가 [start, end]를 완전히 포함하는 경우
  2. [x, y]가 [start, end]에 완전히 포함되는 경우
  3. [x, y]와 [start, end]와 겹치지 않는 경우
  4. [x, y]와 start, end]가 겹쳐져 있는 경우(1,2,3을 제외한 나머지 전부)

1번의 경우는 가장 베스트한 경우이다.
해당 노드의 결과를 리턴해주면 된다.
예를 들어 x, y가 [3, 7]이고 start, end는 [4, 7]이라고 하자.
[start, end]의 모든 범위를 [x, y]에서 가지고 있다.
이 경우 현재 노드의 값을 그대로 반환해주면, 우리는 [4, 7]의 범위의 합을 바로 구하는 것이다.
[3]에 대해서는 생각하지 않아도 된다.
분할 정복이기 때문에 다른 쪽에서 알아서 연산해줄 것이다.

2번의 경우는 1번과 반대의 경우이다.
이 경우에는 더 아래로 내려가야 한다.
위의 예시를 거꾸로 해보자.
x, y가 [4, 7]이고 start, end가 [3, 7]이다.
현재 노드는 3~7까지의 값을 모두 더한 것이기 때문에, 우리가 원하는 4~7의 값을 더한 값을 얻을 수 없다.
그렇다면 현재 노드의 자식노드로 내려가서 다시 연산을 수행하면 된다.

3번의 경우는 볼 것도 없다.
현재 노드에서 아무리 내려가도 원하는 답은 나오지 않는다.
0을 리턴하고 종료하자.

4번은 결국 잘 생각해보면 2번과 동일한 케이스가 된다.

이제 이를 바탕으로 구간합을 구하는 코드를 짜보자.

마지막은 값의 변경이다.

특정 인덱스의 값을 변경하면 그 인덱스를 포함하는 구간합들이 모두 바뀌어야 한다.
이는 트리를 똑같이 탐색하며 현재 [start, end] 범위내에 인덱스가 있으면 해당 구간합에 변경된 수치만큼을 더해주면 된다.

주어지는 인덱스에서 실제 배열의 인덱스는 -1을 해줘야되는데, 이 부분을 그냥 넘어가서 왜 틀렸는지.. 시간을 많이 썼다.

코드

  • 파이썬 코드
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
import sys
import math
input = sys.stdin.readline

def createSegmentTree(start, end, node, arr, tree):
    if start == end:
        tree[node] = arr[start]
        return tree[node]
    mid = (start+end)//2
    tree[node] = createSegmentTree(start, mid, node * 2, arr, tree) + createSegmentTree(mid+1, end, node*2+1, arr, tree)
    return tree[node]

def calcIntervalSum(start, end, left, right, node, tree):
    if start > right or end < left:
        return 0
    elif start >= left and end <= right:
        return tree[node]
    else:
        mid = (start+end)//2
        return calcIntervalSum(start, mid, left, right, node*2, tree) + calcIntervalSum(mid+1, end, left, right, node*2+1, tree)

def changeValue(start, end, node, idx, dis, tree):
    if idx < start or idx > end:
        return
    mid = (start+end)//2

    tree[node] += dis
    if start != end:
        changeValue(start, mid, node*2, idx, dis, tree)
        changeValue(mid+1, end, node*2 + 1, idx, dis, tree)

if __name__ == "__main__":
    n, q = map(int, input().split())
    arr = list(map(int, input().split()))
    high = math.ceil(math.log2(n))
    size = 2 ** (high + 1)
    tree = [0] * size
    createSegmentTree(0, n-1, 1, arr, tree)

    for _ in range(q):
        x, y, changeNode, changeVal = map(int, input().split())
        if x > y:
            x, y = y, x
        ans = calcIntervalSum(0, n-1, x-1, y-1, 1, tree)
        print(ans)
        dis = changeVal - arr[changeNode-1]
        arr[changeNode-1] += dis
        changeValue(0, n-1, 1, changeNode-1, dis, tree)



This post is licensed under CC BY 4.0 by the author.