树状数组

树状数组

template<class T>
struct BIT {
    T t[N];
    int size;
    void resize(int s) { size = s; }
    T query(int x) {
        assert(x <= size);
        T s = 0;
        for (; x; x -= lowbit(x)) {
            s += t[x];
        }
        return s;
    }
    void add(int x, T s) {//a[x]+=s
        assert(x != 0);
        for (; x <= size; x += lowbit(x)) {
            t[x] += s;
        }
    }
};

应用

直接存点的:

c.add(x,k) 单点修改
c.query(x)-c.query(x-1) 单点查询

c.add(x,k) 单点修改
c.query(r)-c.query(l-1) 区间查询

利用差分存点的:

c.add(l,d); c.add(r+1,-d) 区间修改
c.query(x) 单点查询

利用差分存点的:

c1存i;c2存i*a[i];

c1.add(l,d); c1.add(r+1,-d);c2.add(l,l*d);c2.add(r+1,-(r+1)*d);区间修改
((r+1)*c1.query(r)-c2.query(r))-(l*c1.query(l-1)-c2.query(l-1));区间查询

利用差分存点原理:

a1=d1a2=d1+d2a3=d1+d2+d3a1+a2++an=i=1n(x+1i)di=(x+1)i=1ndii=1nidia_1=d_1 \\ a_2=d_1+d_2\\ a_3=d_1+d_2+d_3\\ ……\\ a_{1}+a_{2}+…+a_{n}=\sum_{i=1}^{n}(x+1-i)d_{i}=(x+1)\sum_{i=1}^{n}d_{i}-\sum_{i=1}^{n}i*d_{i}

注意在add函数中,下标不能有零,否则加上lowbit会死循环

树状数组二分

查询最大的pos满足a[1]+a[2]+……+a[pos]<=s

struct BIT {
    T t[N];
    int size;
    void resize(int s) { size = s; }
    T query(ll s) {//查询最大的pos满足a[1]+a[2]+……+a[pos]<=s
        int pos = 0;
        for (int j = 18; j>=0; j--) {
            if (pos + (1 << j) <= size && t[pos + (1 << j)] <= s) {
                pos += (1 << j);
                s -= t[pos];
            }
        }
        return pos;
    }
    void add(int x, T s) {//a[x]+=s
        assert(x != 0);
        for (; x <= size; x += lowbit(x)) {
            t[x] += s;
        }
    }
};

高维树状数组

二维:

template<class T>
struct BIT {
    T t[N][M];
    int size_n, size_m;
    void resize(int s1, int s2) { size_n = s1; size_m = s2; }
    T query(int x, int y) {
        assert(x <= size_n); assert(y <= size_m);
        T s = 0;
        for (int p = x; p; p -= lowbit(p)) {
            for (int q = y; q; q -= lowbit(q)) {
                s += t[p][q];
            }
        }
        return s;
    }
    void add(int x,int y, T s) {//a[x]+=s
        assert(x != 0); assert(y != 0);
        for (int p = x; p <= size_n; p += lowbit(p)) {
            for (int q = y; q <= size_m; q += lowbit(q)) {
                t[p][q] += s;
            }
        }
    }
};

Last updated