HYSBZ 2588 Spoj 10628. Count on a tree 树链剖分+主席树

518人浏览 / 25人评论

题目链接

  • Spoj 10628. Count on a tree
    • 需要求一棵树上两点之间的第k小点权。
    • “第k小”想到主席树,“树上两点之间”想到树链剖分。
    • 我们先对这棵树进行轻重链剖分,得到每一条重链的顶端、各结点从属的重链等信息。
    • 树链剖分完成后,下面考虑如何建立树上主席树:
    • 我们把每个结点在其父结点的基础上建立主席树即可。
    • 所以,在树链剖分的第二次DFS中,我们可以建立好主席树。
    • 下面考虑如何在主席树上得出结点u到结点v之间所有点的权值线段树信息:
    • u与v之间的权值线段树=u与根的权值线段树+v与根的权值线段树-lca(u, v)与根的权值线段树-fa[lca(u, v)]的权值线段树
    • 我们只需要在查询时提供四个根结点,按照以上等式即可算出u与v之间的权值线段树,由此找出u与v之间结点的第k小值。
    • 注意:lca不需用倍增法,只需根据树链剖分即可求得。

代码:

#pragma GCC optimize(2)
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<map>
using namespace std;
const int maxn=100010;
map<int,int>get_ori;
struct Node
{
	int num,lch,rch;
	Node():num(0),lch(-1),rch(-1){}
}tree[maxn*40];
int tot2,a[maxn],rt[maxn],value[maxn],v[maxn];
inline void pushup(int p)
{
	tree[p].num=tree[tree[p].lch].num+tree[tree[p].rch].num;
}
int build(int l,int r)
{
	int p=tot2++;
	tree[p]=Node();
	if(l==r)
	{
		tree[p].num=0;
		return p;
	}
	int mid=(l+r)>>1;
	tree[p].lch=build(l,mid);
	tree[p].rch=build(mid+1,r);
	pushup(p);
	return p;
}
int add(int p,int l,int r,int x,int y,int z)
{
	int cp=tot2++;
	tree[cp]=Node();
	if(x<=l&&r<=y)
	{
		tree[cp].num=tree[p].num+r-l+1;
		return cp;
	}
	int mid=(l+r)>>1;
	if(x<=mid)
	{
		tree[cp].lch=add(tree[p].lch,l,mid,x,y,z);
	}
	else
	{
		tree[cp].lch=tree[p].lch;
	}
	if(mid<y)
	{
		tree[cp].rch=add(tree[p].rch,mid+1,r,x,y,z);
	}
	else
	{
		tree[cp].rch=tree[p].rch;
	}
	pushup(cp);
	return cp;
}
int find(int x,int y,int lca,int flca,int l,int r,int k)
{
	if(l==r)
	{
		return v[l];
	}
	int mid=(l+r)>>1;
	int tmp=tree[tree[y].lch].num+tree[tree[x].lch].num-tree[tree[lca].lch].num-tree[tree[flca].lch].num;
	if(tmp>=k)
	{
		return find(tree[x].lch,tree[y].lch,tree[lca].lch,tree[flca].lch,l,mid,k);
	}
	return find(tree[x].rch,tree[y].rch,tree[lca].rch,tree[flca].rch,mid+1,r,k-tmp);
}
struct Edge
{
	int to,next;
}edge[maxn<<1];
int head[maxn],tot,top[maxn],fa[maxn],deep[maxn],num[maxn],p[maxn],fp[maxn],son[maxn],pos;
void init()
{
	tot=pos=0;
	memset(head,-1,sizeof(head));
	memset(son,-1,sizeof(son));
}
void addedge(int u,int v)
{
	edge[tot].to=v;
	edge[tot].next=head[u];
	head[u]=tot++;
}
void dfs1(int u,int pre,int d)
{
	deep[u]=d;
	fa[u]=pre;
	num[u]=1;
	for(int i=head[u];i!=-1;i=edge[i].next)
	{
		int v=edge[i].to;
		if(v!=pre)
		{
			dfs1(v,u,d+1);
			num[u]+=num[v];
			if(son[u]==-1||num[v]>num[son[u]])
			{
				son[u]=v;
			}
		}
	}
}
void get_pos(int u,int sp,int n)
{
	top[u]=sp;
	p[u]=++pos;
	fp[p[u]]=u;
	int current_pos=get_ori[value[u]];
	rt[u]=add(rt[fa[u]],1,n,current_pos,current_pos,current_pos);
	if(son[u]==-1)
	{
		return;
	}
	get_pos(son[u],sp,n);
	for(int i=head[u];i!=-1;i=edge[i].next)
	{
		int v=edge[i].to;
		if(v!=son[u]&&v!=fa[u])
		{
			get_pos(v,v,n);
		}
	}
}
int lca(int a,int b)
{
	int x=top[a],y=top[b];
	while(x!=y)
	{
		if(deep[x]>deep[y])
		{
			a=fa[x];
			x=top[a];
		}
		else
		{
			b=fa[y];
			y=top[b];
		}
	}
	return deep[a]>deep[b]?b:a;
}
int main()
{
	init();
	int n,m;
	scanf("%d%d",&n,&m);
	for(int i=1;i<=n;++i)
	{
		scanf("%d",&value[i]);
		v[i]=value[i];
	}
	sort(v+1,v+1+n);
	int cnt=unique(v+1,v+1+n)-v-1;
	for(int i=1;i<=cnt;++i)
	{
		get_ori[v[i]]=i;
	}
	int u,v;
	for(int i=0;i<n-1;++i)
	{
		scanf("%d%d",&u,&v);
		addedge(u,v);
		addedge(v,u);
	}
	dfs1(1,0,0);
	rt[0]=build(1,cnt);
	get_pos(1,1,cnt);
	int k,ans=0;
	while(m--)
	{
		scanf("%d%d%d",&u,&v,&k);
		u^=ans;
		int current_lca=lca(u,v);
		printf("%d\n",ans=find(rt[u],rt[v],rt[current_lca],rt[fa[current_lca]],1,cnt,k));
	}
	return 0;
}

全部评论