플래티넘5 : 세그먼트 트리 문제이다.

생각

전형적인 세그먼트 트리 문제이다. 구간에 대한 합을 연속된 쿼리로 물어보는 경우인데, 이런 경우 단순하게 짜면 그냥 펑!이다. 이 문제 같은 경우는 추가적으로 수가 변형되는데, 이전 BOJ - 최소값(10868)에서 한 단계 업그레이드 되었다고 보면된다.

update

업데이트는 별 것 없다. 내가 바꾸고 싶은 녀석의 트리 하위 노드에 접속하여, 그 노드가 영향을 주는 노드만 쏙쏙 업데이트 해주면 된다. 만약 2번 값을 바꾸고 싶다면 2번이 영향을 주는 노드는 파란색 노드들 이다.

void update(int changed_index, ll changed_value, int index, int start, int end){
        if (start == end) {
           nodes[index] = changed_value;
           return;
        }
        int mid = (start+end)/2;
        if (start <= changed_index && changed_index <= mid) {
            update(changed_index, changed_value, 2*index, start, mid);
        } else {
            update(changed_index, changed_value, 2*index+1, mid+1, end);
        }
        nodes[index] = nodes[2*index] + nodes[2*index+1];
    }

사실 이 부분을 생각할 때, 위 방법 말고, 차이를 다 더해주는 방법을 생각했는데 좋지 않아 적지 않는다.

Code

이번에는 클래스로 구현했다. 그런데, 클래스로 구현하면 속도가 매우 느리다고 한다. private 키워드도 느리니 쓰지 말라한다. 흥. 구조체로 구현하라고 한다.

#include <iostream>
#include <cmath>
#include <vector>
using namespace std;
typedef long long ll;
 
int n, m, k;
ll arr[1000001];
 
class SegmentTree{
private:
    ll* nodes;
    ll* input;
 
public:
    SegmentTree(int size, ll* input){
        int h = ceil(log2(size))+1;
        int treeSize = (1 << h);
        nodes = new ll[treeSize];
        this->input = input;
 
        init(1, 0, size-1);
    }
 
    ~SegmentTree(){
        delete[] nodes;
    }
 
    void init(int index, int start, int end){
        if (start == end) {
            nodes[index] = input[start];
            return;
        }
        int mid = (start+end)/2;
        init(2*index, start, mid);
        init(2*index+1, mid+1, end);
        nodes[index] = nodes[index*2] + nodes[index*2+1];
    }
 
    ll findSum(int index, int start, int end, int left, int right){
        if (left > end || right < start){
            return 0;
        } else if (left <= start && end <= right){
            return nodes[index];
        }
        int mid = (start+end)/2;
        return findSum(2*index, start, mid, left, right) + findSum(2*index+1, mid+1, end, left, right);
    }
 
    void update(int changed_index, ll changed_value, int index, int start, int end){
        if (start == end) {
           nodes[index] = changed_value;
           return;
        }
        int mid = (start+end)/2;
        if (start <= changed_index && changed_index <= mid) {
            update(changed_index, changed_value, 2*index, start, mid);
        } else {
            update(changed_index, changed_value, 2*index+1, mid+1, end);
        }
        nodes[index] = nodes[2*index] + nodes[2*index+1];
    }
};
 
int main(){
    ios_base::sync_with_stdio(false);
    cin.tie(NULL);
    cout.tie(NULL);
 
    cin >> n >> m >> k;
    for (int i = 0; i < n; i++) {
        cin >> arr[i];
    }
    SegmentTree st(n, arr);
 
    int changeCount = 0, printCount = 0;
 
    while (!(changeCount == m && printCount == k)) {
        int a, b;
        cin >> a >> b;
        if (a == 1) {
            ll c;
            cin >> c;
            changeCount++;
            st.update(b-1, c, 1, 0, n-1);
        } else {
            int c;
            cin >> c;
            printCount++;
            cout << st.findSum(1, 0, n-1, b-1, c-1) << '\n';
        }
    }
 
    return 0;
}
#include<iostream>
#include<string>
#include<vector>
#include<algorithm>
#include<cmath>
#include<map>
#include<queue>
using namespace std;
typedef long long ll;
int N, M, K;
 
void init(vector<ll>& input, vector<ll>& tree, int node, int start, int end){
    if (start == end) {
        tree[node] = input[start];
        return;
    }
    int mid = (start+end)/2;
    init(input, tree, 2*node, start, mid);
    init(input, tree, 2*node+1, mid+1, end);
    tree[node] = tree[node*2] + tree[node*2+1];
}
 
ll findSum(vector<ll>& tree, int node, int start, int end, int left, int right){
    if (end < left || start > right) return 0;
    else if (start >= left && end <= right) return tree[node];
 
    int mid = (start+end)/2;
    ll a = findSum(tree, 2*node, start, mid, left, right);
    ll b = findSum(tree, 2*node+1, mid+1, end, left, right);
    return a+b;
}
 
void update(vector<ll>& tree, int loc, ll value, int node, int start, int end) {
    if (start == end) {
        tree[node] = value;
        return;
    }
 
    int mid = (start+end)/2;
 
    if (start <= loc && loc <= mid) {
        update(tree, loc, value, 2*node, start, mid);
    } else {
        update(tree, loc, value, 2*node+1, mid+1, end);
    }
    tree[node] = tree[2*node] + tree[2*node+1];
}
 
int main(){
    ios_base::sync_with_stdio(false);
    cin.tie(NULL); cout.tie(NULL);
 
    cin >> N >> M >> K;
    int h = int(ceil(log2(N)));
    int treeSize = (1 << (h+1));
    vector<ll> a(N);
    vector<ll> tree(treeSize, 0);
    for (int i = 0; i < N; i++) {
        cin >> a[i];
    }
    init(a, tree, 1, 0, N-1);
 
    for (int i = 0; i < M+K; i++) {
        int action;
        cin >> action;
        if (action == 1){
            int loc;
            ll value;
            cin >> loc >> value;
            update(tree, loc-1, value, 1, 0, N-1);
 
        } else {
            int left, right;
            cin >> left >> right;
            cout << findSum(tree, 1, 0, N-1, left-1, right-1) << '\n';
        }
    }
 
    return 0;
}
 

Reference