不知道这题能不能发出来,如果不能请联系我,我什么都会做的
题意:给一棵 nnn 个结点的树,每个结点有个 ax+bax+bax+b,求所有根到叶子的乘积之和。系数模 998244353998244353998244353。
链的情况就是分治 NTT,所以树上没有弱于这个的做法。
考虑链分治,先对树做长链剖分,然后对根所在的链分治,维护两个多项式,一个链上所有结点的乘积,一个从区间起点往下走,从区间中某个位置拐出去,走到所有叶子的路径乘积之和。递归到分治树的叶子的时候就递归算原树上的轻儿子。
为了保证复杂度,NTT 的长度应该开当前区间所有虚儿子的最大深度和区间长度的较大值,而非区间起点的深度。这样每条链只会在链头的父亲所在的链 分治的时候贡献 O(logn)\Omicron(\log n)O(logn) 次 NTT 的长度,总复杂度是 O(nlog2n)\Omicron(n\log^2n)O(nlog2n),并且上界很松。
第一次写封装多项式,挺舒服的
#include <iostream>
#include <cstdio>
#include <cstring>
#include <cctype>
#include <vector>
#define MAXN ((1<<18)+5)
using namespace std;
inline int read()
{int ans=0;char c=getchar();while (!isdigit(c)) c=getchar();while (isdigit(c)) ans=(ans<<3)+(ans<<1)+(c^48),c=getchar();return ans;
}
const int MOD=998244353;
typedef long long ll;
inline int add(const int& x,const int& y){return x+y>=MOD? x+y-MOD:x+y;}
inline int dec(const int& x,const int& y){return x<y? x-y+MOD:x-y;}
inline int qpow(int a,int p)
{int ans=1;while (p){if (p&1) ans=(ll)ans*a%MOD;a=(ll)a*a%MOD,p>>=1;}return ans;
}
#define inv(x) qpow(x,MOD-2)
int rt[2][24];
int r[MAXN],l,lim;
inline void init(){lim=1<<l;for (int i=0;i<lim;i++) r[i]=(r[i>>1]>>1)|((i&1)<<(l-1));}
void ntt(int* a,int type)
{for (int i=0;i<lim;i++) if (i<r[i]) swap(a[i],a[r[i]]);for (int L=0;L<l;L++){int mid=1<<L,len=mid<<1;int Wn=rt[type][L+1];for (int s=0;s<lim;s+=len){ll w=1;for (int k=0;k<mid;k++,w=w*Wn%MOD){int x=a[s+k],y=w*a[s+mid+k]%MOD;a[s+k]=add(x,y),a[s+mid+k]=dec(x,y);}}}if (type){int t=inv(lim);for (int i=0;i<lim;i++) a[i]=(ll)a[i]*t%MOD;}
}
struct poly
{int *a,n;inline poly():n(0){}inline poly(int x):n(x){a=new int[x];memset(a,0,sizeof(int)*n);}inline poly(int x,int y):n(2){a=new int[2];a[0]=x,a[1]=y;}inline int& operator [](const int& i){return a[i];}inline const int& operator [](const int& i)const{return a[i];}
};
inline poly operator *(const poly& a,const poly& b)
{static int ta[MAXN],tb[MAXN];poly c(a.n+b.n-1);for (l=0;(1<<l)<c.n;++l);init();for (int i=0;i<lim;i++) ta[i]=tb[i]=0;for (int i=0;i<a.n;i++) ta[i]=a[i];for (int i=0;i<b.n;i++) tb[i]=b[i];ntt(ta,0),ntt(tb,0);for (int i=0;i<lim;i++) ta[i]=(ll)ta[i]*tb[i]%MOD;ntt(ta,1);for (int i=0;i<c.n;i++) c[i]=ta[i];return c;
}
inline poly operator +(const poly& a,const poly& b)
{poly c(max(a.n,b.n));for (int i=0;i<c.n;i++) c[i]=add(i<a.n? a[i]:0,i<b.n? b[i]:0);return c;
}
vector<int> e[MAXN];
int buf[MAXN],*tp=buf;
int fa[MAXN],son[MAXN],mx[MAXN];
int *lis[MAXN];
inline int* newbuf(int x){int* p=tp;tp+=x;return p;}
void dfs(int u,int f)
{fa[u]=f;for (int i=0;i<(int)e[u].size();i++)if (e[u][i]!=f){dfs(e[u][i],u);if (mx[e[u][i]]>mx[son[u]]) son[u]=e[u][i];}mx[u]=mx[son[u]]+1;
}
void dfs(int u,int* cur)
{*(lis[u]=cur)=u;if (son[u]) dfs(son[u],cur+1);for (int i=0;i<(int)e[u].size();i++)if (e[u][i]!=fa[u]&&e[u][i]!=son[u])dfs(e[u][i],newbuf(mx[e[u][i]]));
}
int rval[MAXN],gval[MAXN];
pair<poly,poly> solve(int* L,int* R)
{if (L==R){int u=*L;poly tmp;for (int i=0;i<(int)e[u].size();i++)if (e[u][i]!=fa[u]&&e[u][i]!=son[u])tmp=tmp+solve(lis[e[u][i]],lis[e[u][i]]+mx[e[u][i]]-1).second;if ((int)e[u].size()==(fa[u]>0)) tmp=poly(1),tmp[0]=1;return make_pair(poly(rval[u],gval[u]),poly(rval[u],gval[u])*tmp);}int* mid=L+((R-L)>>1);pair<poly,poly> lans=solve(L,mid),rans=solve(mid+1,R);return make_pair(lans.first*rans.first,lans.first*rans.second+lans.second);
}
poly ans;
int main()
{freopen("slime.in","r",stdin);freopen("slime.out","w",stdout);rt[0][23]=qpow(3,119),rt[1][23]=inv(rt[0][23]);for (int i=22;i>=0;i--){rt[0][i]=(ll)rt[0][i+1]*rt[0][i+1]%MOD;rt[1][i]=(ll)rt[1][i+1]*rt[1][i+1]%MOD;}int n=read();read();for (int i=1;i<=n;i++) rval[i]=read();for (int i=1;i<=n;i++) gval[i]=read();for (int i=1;i<n;i++){int u,v;u=read(),v=read();e[u].push_back(v),e[v].push_back(u);}dfs(1,0);dfs(1,newbuf(mx[1]));ans=solve(lis[1],lis[1]+mx[1]-1).second;for (int i=0;i<=n;i++) printf("%d\n",(i<ans.n? ans[i]:0));return 0;
}