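/*
 * Online-judge submission record, kept for reference (scraped page header):
 * 比赛 清华集训2017模板练习 评测结果 AAAAAAAAEE
 * 题目名称 超强的乘法问题 最终得分 80
 * 用户昵称 Cydiater 运行时间 1.296 s
 * 代码语言 C++ 内存使用 10.33 MiB
 * 提交时间 2017-07-17 16:49:07
 * 显示代码纯文本
 *
 * English: contest "Tsinghua Training 2017 template practice"; problem
 * "Super-strong multiplication problem"; per-test verdicts AAAAAAAAEE
 * (8 accepted, 2 with "E" verdicts, likely runtime errors); final score 80;
 * user Cydiater; run time 1.296 s; language C++; memory 10.33 MiB;
 * submitted 2017-07-17 16:49:07. The last header line is page residue
 * ("show code / plain text").
 */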
#include <bits/stdc++.h>

using namespace std;

// Competitive-programming shorthand macros used throughout this file.
#define ll long long
// Inclusive ascending loop: i runs from j up to and including n.
#define up(i,j,n)	for(int i=j;i<=n;i++)
// Inclusive descending loop: i runs from j down to and including n.
#define down(i,j,n)	for(int i=j;i>=n;i--)
// In-place max / min assignment (a is evaluated twice).
#define cmax(a,b)	a=max(a,b)
#define cmin(a,b)	a=min(a,b)
// In-place modular add / subtract / multiply via add(), pop(), mul().
#define cadd(a,b)	a=add(a,b)
#define cpop(a,b)	a=pop(a,b)
#define cmul(a,b)	a=mul(a,b)
// 2^i as an int; 1<<31 and beyond overflow int.
#define bin(i)		(1<<(i))
// Base name of the judge's redirected input/output files (FILE".in" / FILE".out").
// NOTE(review): this macro shadows the <cstdio> type name FILE for all code below it.
#define FILE		"bettermul"

// Capacity of every digit/coefficient buffer (s, A, B, num, NTT::omega, NTT::inv).
// The transform length is N+M rounded up to a power of two, so the buffers
// must hold 2^20 entries (plus slack, since Solve may touch index NM) to
// support operands of up to ~5*10^5 digits each. The previous 5e5+5 was
// overrun as soon as N+M exceeded 2^18 (transform length 2^19 = 524288).
// NOTE(review): assumes each operand has at most 524288 digits — confirm
// against the problem's constraints.
const int MAXN=(1<<20)+5;
// Conventional "infinity" sentinel (not used by the visible code).
const int oo=0x3f3f3f3f;
// NTT-friendly prime: 998244353 = 119 * 2^23 + 1, with primitive root 3,
// so power-of-two transform lengths up to 2^23 are available.
const int mod=998244353;
const int g=3;

// Modular arithmetic helpers; results are fully reduced into [0, mod) when the
// inputs already lie in [0, mod).
// Product modulo mod, computed in 64 bits so a*b cannot overflow.
int mul(int a,int b){
	long long prod=(long long)a*b;
	return (int)(prod%mod);
}
// Sum modulo mod using a single conditional subtraction.
int add(int a,int b){
	int r=a+b;
	if(r>=mod)r-=mod;
	return r;
}
// Difference modulo mod using a single conditional addition.
int pop(int a,int b){
	int r=a-b;
	if(r<0)r+=mod;
	return r;
}

int qpow(int a,int b){
	int c=1;
	while(b){
		if(b&1)cmul(c,a);
		cmul(a,a);b>>=1;
	}
	return c;
}

// Operand lengths (N, M) and the padded transform length NM.
int N,M,NM;
// A and B start as the operands' digits, least-significant first (filled by
// solution::Prepare); A is reused to hold the product coefficients.
// num receives the final base-10 digits of the product.
int A[MAXN],B[MAXN],num[MAXN];
// Scratch buffer each operand is read into as a string.
char s[MAXN];
// Running carry used while normalising coefficients into base-10 digits.
ll sum=0;

namespace NTT{
	int omega[MAXN],inv[MAXN];
	int Fix(int x){
		int p=0;
		while(bin(p)<x)p++;
		return bin(p);
	}
	void Prepare(int N){
		int x=qpow(g,(mod-1)/N);
		up(i,0,N-1){
			omega[i]=(!i?1:mul(omega[i-1],x));
			inv[i]=qpow(omega[i],mod-2);
		}
	}
	void Transform(int N,int *A,int *w){
		int p=0;
		up(i,0,N-1){
			if(p<i)swap(A[p],A[i]);
			for(int t=N>>1;(p^=t)<t;t>>=1);
		}
		for(int l=2;l<=N;l<<=1){
			int d=l>>1;
			for(int j=0;j<N;j+=l)up(k,0,d-1){
				int tmp=mul(A[j+k+d],w[N/l*k]);
				A[j+k+d]=pop(A[j+k],tmp);
				A[j+k]=add(A[j+k],tmp);
			}
		}
	}
	void DFT(int N,int *A){
		Transform(N,A,omega);
	}
	void IDFT(int N,int *A){
		Transform(N,A,inv);
		int INV=qpow(N,mod-2);
		up(i,0,N-1)cmul(A[i],INV);
	}
}

namespace solution{
	void Prepare(){
		scanf("%s",s);
		N=strlen(s);
		reverse(s,s+N);
		up(i,0,N-1)A[i]=s[i]-'0';
		scanf("%s",s);
		M=strlen(s);
		reverse(s,s+M);
		up(i,0,M-1)B[i]=s[i]-'0';
	}
	void Solve(){
		NM=NTT::Fix(N+M);
		NTT::Prepare(NM);
		NTT::DFT(NM,A);
		NTT::DFT(NM,B);
		up(i,0,NM-1)cmul(A[i],B[i]);
		NTT::IDFT(NM,A);
		up(i,0,NM-1){
			sum+=A[i];
			num[i]=sum%10;
			sum/=10;
		}
		if(sum)printf("%lld",sum);
		int upper=NM;
		while(upper>=0&&num[upper]==0)upper--;
		down(i,upper,0)printf("%d",num[i]);
		if(upper==-1)printf("0");
		puts("");
	}
}

int main(){
	freopen(FILE".in","r",stdin);
	freopen(FILE".out","w",stdout);
	using namespace solution;
	Prepare();
	Solve();
	return 0;
}