「BZOJ-3295」[Cqoi2011]动态逆序对-树套树

对于序列A,它的逆序对数定义为满足 $i < j$,且 $A_i > A_j$ 的数对 $(i, j)$ 的个数。给 $1$ 到 $n$ 的一个排列,按照某种顺序依次删除m个元素,你的任务是在每次删除一个元素之前统计整个序列的逆序对数。

链接

BZOJ-3295

输入

输入第一行包含两个整数n和m,即初始元素的个数和删除的元素个数。以下n行每行包含一个1到n之间的正整数,即初始排列。以下m行每行一个正整数,依次为每次删除的元素。

输出

输出包含m行,依次为删除每个元素之前,逆序对的个数。

样例

输入

1
2
3
4
5
6
7
8
9
10
5 4
1
5
3
4
2
5
1
4
2

输出

1
2
3
4
5
2
2
1

题解

直接树状数组套上线段树就可以水过了,虽然套主席树更靠谱

代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
#include <bits/stdc++.h>
const int IN_LEN = 1000000, OUT_LEN = 1000000;
inline int nextChar() {
static char buf[IN_LEN], *h, *t;
if (h == t) {
t = (h = buf) + fread(buf, 1, IN_LEN, stdin);
if (h == t) return -1;
}
return *h++;
}
template<class T>
inline bool read(T &x) {
static bool iosig = 0;
static char c;
for (iosig = 0, c = nextChar(); !isdigit(c); c = nextChar()) {
if (c == -1) return false;
if (c == '-') iosig = 1;
}
for (x = 0; isdigit(c); c = nextChar()) x = (x << 1) + (x << 3) + (c ^ '0');
if (iosig) x = -x;
return true;
}
char obuf[OUT_LEN], *oh = obuf;
inline void writeChar(const char c) {
if (oh == obuf + OUT_LEN) fwrite(obuf, 1, OUT_LEN, stdout), oh = obuf;
*oh++ = c;
}
template<class T>
inline void write(T x) {
static int buf[30], cnt;
if (!x) writeChar(48);
else {
if (x < 0) writeChar('-'), x = -x;
for (cnt = 0; x; x /= 10) buf[++cnt] = x % 10 + 48;
while (cnt) writeChar(buf[cnt--]);
}
}
inline void flush() { fwrite(obuf, 1, oh - obuf, stdout); }
const static int MAXN = 100005;
typedef long long ll;
int n, m, x, num, pos[MAXN];
ll ans;
struct Node {
Node *lc, *rc;
int cnt;
} node[MAXN * 90], *cur = node, *root[MAXN];
inline void insert(Node *&k, int l, int r, int x) {
if (!k) k = ++cur;
k->cnt++;
if (l == r) return;
register int mid = l + r >> 1;
if (x <= mid) insert(k->lc, l, mid, x);
else insert(k->rc, mid + 1, r, x);
}
inline void remove(Node *k, int l, int r, int x) {
if (!k) return;
k->cnt--;
if (l == r) return;
register int mid = l + r >> 1;
if (x <= mid) remove(k->lc, l, mid, x);
else remove(k->rc, mid + 1, r, x);
}
inline int query(Node *k, int l, int r, int s, int t) {
if (!k) return 0;
if (l == s && r == t) return k->cnt;
register int mid = l + r >> 1;
if (t <= mid) return query(k->lc, l, mid, s, t);
else if (s > mid) return query(k->rc, mid + 1, r, s, t);
else return query(k->lc, l, mid, s, mid) + query(k->rc, mid + 1, r, mid + 1, t);
}
inline int lowbit(int k) { return k & -k; }
int main() {
read(n), read(m);
for (register int i = 1; i <= n; i++) {
read(x), pos[x] = i;
for (register int j = i; j <= n; j += lowbit(j)) insert(root[j], 1, n, x);
if (x != n) for (register int j = i; j; j -= lowbit(j)) ans += (ll)query(root[j], 1, n, x + 1, n);
}
for (register int i = 1; i <= m; i++) {
read(x), write(ans), writeChar('\n');
if (x != n) for (register int j = pos[x]; j; j -= lowbit(j)) ans -= (ll)query(root[j], 1, n, x + 1, n);
if (x != 1) {
for (register int j = n; j; j -= lowbit(j)) ans -= (ll)query(root[j], 1, n, 1, x - 1);
for (register int j = pos[x] - 1; j; j -= lowbit(j)) ans += (ll)query(root[j], 1, n, 1, x - 1);
}
for (register int j = pos[x]; j <= n; j += lowbit(j)) remove(root[j], 1, n, x);
}
flush();
return 0;
}
分享到