항상 궁금했던 세그먼트 트리..!
solved.ac에 태그를 보면, 다른 건 뭔지는 알겠는데 '세그먼트 트리'만 전혀 몰라서 궁금했다..!!
공부하고 싶어서 스터디 문제에 추가했고, 오늘 공부해봤다 ~~
단순 배열에서 구간 합을 구하려면 O(N) 걸리는데, 세그먼트 트리를 사용하면 O(logN)으로 시간을 줄여준다 ! (꼭 구간 합을 구할 때만 사용하는 것은 아니고, 구간에서 최댓값이나 최솟값을 찾을 때도 사용한다고 한다.)
내가 세그먼트 트리에 관해서 쓰면 정말 좋겠지만,, 그건 미뤄두고 일단 뭘 참고했는지만 쓰겠다.
먼저 관련된 유튜브를 2개 정도 열심히 봤다!
그리고 아래 쪽 유튜브를 보면서 구현했다. 유튜브에서는 파이썬은 아니고 C?로 쓰신 것 같은데 설명을 듣고 주석을 달면서 비슷하게 코드를 짰다.
https://www.youtube.com/watch?v=LHWLXRfpgf4&t=1813s
https://www.youtube.com/watch?v=ahFB9eCnI6c
import sys
input = sys.stdin.readline
MIIS = lambda:map(int,input().split())
# build - segment tree 생성(루트부터 아래로 반씩 나누면서)
def build(arr, node, nodeleft, noderight):
if nodeleft == noderight: # 리프 노드- 범위에 노드가 하나니까 그걸 넣으면 됨.
segment_tree[node] = arr[nodeleft]
return segment_tree[node]
# 범위에 노드가 두 개 이상이면
mid = int((nodeleft + noderight)/2)
left_value = build(arr, node*2, nodeleft, mid) # 왼쪽 노드
right_value = build(arr, node*2+1, mid+1, noderight) # 오른쪽 노드
segment_tree[node] = left_value+right_value
return segment_tree[node]
# 구간 지정 -> 합 구하기
def total_sum(left, right, node, nodeleft, noderight):
if right<nodeleft or noderight < left: # 범위 아예 바깥에 있는 노드들
return 0 # 그냥 무시
if left <= nodeleft and noderight <= right: # 노드가 범위에 완전히 포함됨.
return segment_tree[node]
# 범위가 노드에 걸친다. (쪼개져서 내려감.)
mid = int((nodeleft + noderight)/2)
return total_sum(left, right, node*2, nodeleft, mid) + total_sum(left, right, node*2+1, mid+1, noderight)
# 값 업데이트
def update(index, newvalue, node, nodeleft, noderight):
if index < nodeleft or noderight < index: # 값 변경되는 리프와 상관없는 노드
return segment_tree[node] # 그대로 값 리턴(업데이트 안 일어남)
if nodeleft==noderight: # 값 변경될 리프
segment_tree[node] = newvalue # 업데이트
return segment_tree[node]
mid = int((nodeleft + noderight) / 2)
left_value = update(index, newvalue, node*2, nodeleft, mid)
right_value = update(index, newvalue, node*2+1, mid+1, noderight)
segment_tree[node] = left_value + right_value
return segment_tree[node]
#입력 받기
N, Q = MIIS()
arr = list(MIIS())
# 세그먼트 트리 만들기
segment_tree = [0]*(N*4 + 1)
build(arr, 1, 0, N-1)
# 명령 수행
for _ in range(Q):
x, y, a, b= MIIS()
if x > y:
x, y = y, x
print(total_sum(x-1, y-1, 1, 0, N-1))
update(a-1, b, 1, 0, N-1)
728x90
'즐거운 PS 👩💻🥰' 카테고리의 다른 글
[백준-파이썬] 비트 마스크 문제들 (0) | 2021.12.11 |
---|---|
[백준-파이썬] 6087 : 레이저 통신 (0) | 2021.12.09 |
[SWEA - 파이썬] 2382.미생물 격리 (0) | 2021.12.06 |
[백준-파이썬] 9844: Gecko (0) | 2021.12.05 |
[백준-파이썬] 23739: 벼락치기 (0) | 2021.12.03 |