DTOJ5021 最近公共祖先

题目

注:本题来源于2020牛客暑期多校训练营(第六场)D题data structure

题目描述

作为此次 NOIP 模拟的最后一道题,宫水三叶决定把题意说得简单一点
给一棵大小为$n$的以$rt$为根的树
有$m$组询问,每次询问 l,r,xl,r,xl,r,x,你要回答有多少$l \leqslant a < b \leqslant r$,满足$a,b$的最近公共祖先为$x$

输入格式

第一行三个整数$n,m,rt$
接下来$n-1$行,每行两个整数$x_i,y_i$,表示一条边
接下来$m$行,每行三个整数$l_i,r_i,x_i$,表示一组询问

输出格式

输出共$m$行,第$i$行表示第$i$个询问的答案

样例

样例输入

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
10 10 7
4 2
10 4
3 2
6 10
9 2
7 3
1 4
8 2
5 3
8 10 10
2 6 2
3 6 2
4 6 4
3 10 2
8 8 10
3 10 4
2 3 2
2 6 4
1 7 10

样例输出

1
2
3
4
5
6
7
8
9
10
0
2
0
1
7
0
2
0
1
0

数据范围与提示

本题采用捆绑测试
对于所有测试点,满足$1\leqslant n,m \leqslant 2\times 10^5,1\leqslant rt \leqslant n$
子任务编号 | $n,m$ | 分值
—|—|—
$1$|$\leqslant 200$|$5$
$2$|$\leqslant 2000$|$20$
$3$|$\leqslant 5\times 10^4$|$35$
$4$|$\leqslant 2\times 10^5$|$40$
提示:本题时间限制为 2S ,请选手注意 IO 用时

题解

考虑将问题转化为:$x$的子树中在$[l,r]$之间的点对数,减去$x$的子节点的子树中在$[l,r]$之间的点对数(因为$x$的子节点一定是它的子树中在$[l,r]$之间的点对的公共祖先,所以最近公共祖先一定不为$x$)
前半部分很简单,直接使用主席树就可以了(这个稍微想一想就可以了,是基本的主席树)
主要是后半部分如何解决
有一个可以很容易想到的就是先重链剖分,然后就可以让重儿子按照上面的方式去算
对于轻儿子的话,没有什么太好的处理方法,所以考虑用莫队,$l$和$r$变化时,就让这个点一直沿着重链条就可以了
效率为$\Theta(mlogn+n\sqrt{n}logn)$,如果常数小(用莫队时的每一块的大小选得好),可能可以得$60$分
代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
#include<algorithm>
#include<iostream>
using namespace std;
struct ppap2
{
int l,r,x,pos;
}q[200010];
struct ppap3
{
int l,r,sum,ch[2];
}t[20000010];
vector<ppap2>Q[200010];
int n,m,rt,tot,cnt,Tot,d,head[200010],to[400010],nxt[400010],fa[200010],siz[200010],son[200010],top[200010],dfn[200010],nfd[200010],Dfn[200010],root[200010],k[200010];
long long s[200010],sum[200010],ans[200010];
template<class T>void read(T &x)
{
x=0;int f=0;char ch=getchar();
while(ch<'0'||ch>'9') {f|=(ch=='-');ch=getchar();}
while(ch>='0'&&ch<='9'){x=(x<<1)+(x<<3)+(ch^48);ch=getchar();}
x=f?-x:x;
return;
}
int cmp(const ppap2 &x,const ppap2 &y)
{
return k[x.l]==k[y.l]?((k[x.l]&1)?x.r<y.r:x.r>y.r):k[x.l]<k[y.l];
}
void add(int x,int y)
{
nxt[++tot]=head[x],to[tot]=y,head[x]=tot;
}
int dfs1(int x,int Fa)
{
fa[x]=Fa,siz[x]=1;
int maxson=-1;
for(int i=head[x];i;i=nxt[i]) if(to[i]!=fa[x]){
siz[x]+=dfs1(to[i],x);
if(siz[to[x]]>maxson) maxson=siz[to[x]],son[x]=to[i];
}
return siz[x];
}
void dfs2(int x,int Top)
{
top[x]=Top,dfn[x]=++cnt,nfd[cnt]=x;
if(son[x]) dfs2(son[x],Top);
for(int i=head[x];i;i=nxt[i]) if(to[i]!=fa[x]&&to[i]!=son[x]) dfs2(to[i],to[i]);
Dfn[x]=cnt;
}
int ask(int p,int x,int y)
{
if((!t[p].l)&&(!t[p].r)) return 0;
if(t[p].l==x&&t[p].r==y) return t[p].sum;
int mid=(t[p].l+t[p].r)/2;
if(y<=mid) return ask(t[p].ch[0],x,y);
else if(x>=mid+1) return ask(t[p].ch[1],x,y);
else return ask(t[p].ch[0],x,mid)+ask(t[p].ch[1],mid+1,y);
}
void build(int &p,int q,int l,int r,int x)
{
t[++Tot].l=l,t[Tot].r=r,p=Tot;
if(l==r){t[p].sum++;return;}
int mid=(l+r)>>1;
t[p].ch[x<=mid]=t[q].ch[x<=mid];
if(x>mid) build(t[p].ch[1],t[q].ch[1],mid+1,r,x);
else build(t[p].ch[0],t[q].ch[0],l,mid,x);
t[p].sum=t[t[p].ch[0]].sum+t[t[p].ch[1]].sum;
}
void add(int x)
{
for(int i;i=fa[top[x]];x=i) sum[i]+=s[top[x]],s[top[x]]++;
}
void dec(int x)
{
for(int i;i=fa[top[x]];x=i) s[top[x]]--,sum[i]-=s[top[x]];
}
void js()
{
for(int i=1;i<=n;i++) build(root[i],root[i-1],1,n,nfd[i]);
for(int i=1;i<=n;i++) if(son[i]) for(ppap2 j:Q[i]){
int x=ask(root[Dfn[i]],j.l,j.r)-ask(root[dfn[i]-1],j.l,j.r),sum=ask(root[Dfn[son[i]]],j.l,j.r)-ask(root[dfn[son[i]]-1],j.l,j.r);
ans[j.pos]+=1ll*x*(x-1)/2-1ll*sum*(sum-1)/2;
}
d=n/(int)sqrt(n*2/3);
for(int i=1;i<=n;i++) k[i]=(i-1)/d;
sort(q+1,q+m+1,cmp);
for(int i=1,l=1,r=0;i<=m;i++){
while(r<q[i].r) r++,add(r);
while(l>q[i].l) l--,add(l);
while(r>q[i].r) dec(r),r--;
while(l<q[i].l) dec(l),l++;
ans[q[i].pos]-=sum[q[i].x];
}
}
int main()
{
read(n),read(m),read(rt);
for(int i=1,x,y;i<n;i++) read(x),read(y),add(x,y),add(y,x);
for(int i=1;i<=m;i++) read(q[i].l),read(q[i].r),read(q[i].x),q[i].pos=i,Q[q[i].x].push_back(q[i]);
dfs1(rt,0),dfs2(rt,rt),js();
for(int i=1;i<=m;i++) cout<<ans[i]<<endl;
}

考虑怎么优化呢?刚刚的问题在于轻儿子可能很多,这大大增加了第二部分的时间复杂度,怎么办呢?我们考虑以为伪树剖,设置一个闸值$Siz$,对于所有$size>Siz$的子树,把它们都当做$x$的重儿子
这样,我们就可以减少轻儿子的数量,从而平衡两边的复杂度
好像还有可以优化的,就是主席树可以换成一个树状数组,这样代码即短,又可以节省时间(因为主席树是一棵升级版的线段树,虽然线段树的效率是$\Theta(nlogn)$,但是常数巨大,受许多因素影响,但是树状数组的效率比较稳定,这样更容易通过,不需要卡常)
附上代码:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
#pragma GCC optimize(2)
#include<algorithm>
#include<iostream>
#include<cstdio>
#include<vector>
#include<cmath>
using namespace std;
struct ppap1
{
int l,r,x,h;
}q[200010];
struct ppap2
{
int x,h,f1,f2,f3;
};
struct ppap3
{
int c[200010];
}t;
vector<int> hson[200010],lson[200010];
vector<ppap2> Q[200010];
vector<int> temp;
int n,m,rt,tot,cnt,Tot,d,nt=-1,head[200010],to[400010],nxt[400010],fa[200010],siz[200010],top[200010],dfn[200010],Dfn[200010],root[200010],k[200010];
long long s[200010],sum[200010],ans[200010];
template<class T>void read(T &x)
{
x=0;int f=0;char ch=getchar();
while(ch<'0'||ch>'9') {f|=(ch=='-');ch=getchar();}
while(ch>='0'&&ch<='9'){x=(x<<1)+(x<<3)+(ch^48);ch=getchar();}
x=f?-x:x;
return;
}
int cmp(const ppap1 &x,const ppap1 &y)
{
return k[x.l]==k[y.l]?((k[x.l]&1)?x.r<y.r:x.r>y.r):k[x.l]<k[y.l];
}
void add(int x,int y)
{
nxt[++tot]=head[x],to[tot]=y,head[x]=tot;
}
void dfs1(int x)
{
siz[x]=1,dfn[x]=++cnt;
for(int i=head[x];i;i=nxt[i]) if(to[i]!=fa[x]) fa[to[i]]=x,dfs1(to[i]),siz[x]+=siz[to[i]];
int Siz=sqrt(siz[x])/3;
for(int i=head[x];i;i=nxt[i]) if(to[i]!=fa[x])
if(siz[to[i]]>Siz) hson[x].push_back(to[i]);
else lson[x].push_back(to[i]);
Dfn[x]=cnt;
}
void dfs2(int x,int Top)
{
top[x]=Top;
for(int y:hson[x]) dfs2(y,Top);
for(int y:lson[x]) dfs2(y,y);
}
void Add(int x)
{
for(;x<=n;x+=(x&(-x))) t.c[x]++;
}
int ask(int x)
{
int ans=0;
for(;x;x-=(x&(-x))) ans+=t.c[x];
return ans;
}
void add(int x)
{
for(int i;i=fa[top[x]];x=i) sum[i]+=s[top[x]]++;
}
void dec(int x)
{
for(int i;i=fa[top[x]];x=i) sum[i]-=--s[top[x]];
}
void js()
{
for(int i=1;i<=n;i++){
Add(dfn[i]);
for(ppap2 x:Q[i]){
temp[x.f3]+=(ask(Dfn[x.x])-ask(dfn[x.x]-1))*x.f2;
if(x.f2==1) ans[x.h]+=1ll*temp[x.f3]*(temp[x.f3]-1)/2*(x.f1?1:-1);
}
}
d=n/sqrt(n*2/3);
for(int i=1;i<=n;i++) k[i]=(i-1)/d;
sort(q+1,q+m+1,cmp);
for(int i=1,l=1,r=0;i<=m;i++){
while(r<q[i].r) add(++r);
while(l>q[i].l) add(--l);
while(r>q[i].r) dec(r--);
while(l<q[i].l) dec(l++);
ans[q[i].h]-=sum[q[i].x];
}
}
int main()
{
read(n),read(m),read(rt);
for(int i=1,x,y;i<n;i++) read(x),read(y),add(x,y),add(y,x);
dfs1(rt),dfs2(rt,rt);
for(int i=1;i<=m;i++){
read(q[i].l),read(q[i].r),read(q[i].x),q[i].h=i,temp.push_back(0),nt++,Q[q[i].l-1].push_back((ppap2){q[i].x,i,1,-1,nt}),Q[q[i].r].push_back((ppap2){q[i].x,i,1,1,nt});
for(int x:hson[q[i].x]) temp.push_back(0),nt++,Q[q[i].l-1].push_back((ppap2){x,i,0,-1,nt}),Q[q[i].r].push_back((ppap2){x,i,0,1,nt});
}
js();
for(int i=1;i<=m;i++) printf("%lld\n",ans[i]);
}

温馨提示:
请不要把排序的$cmp$程序写成类似于这样的形式:
1
2
3
4
5
6
int cmp(const ppap1 &x,const ppap1 &y)
{
return k[x.l]==k[y.l]&&(k[x.l]&1)&&x.r<y.r;
return k[x.l]==k[y.l]&&x.r>y.r;
return k[x.l]<k[y.l];
}

否则,你就会T,虽然我也不知道是什么原因,但是我知道:三目运算符行!(难道是太多&&了?求大佬解答)就为了这个问题,我调了快一个小时/kk