1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72
| constexpr ll mod = 998244353;
/// Modular exponentiation by repeated squaring: returns a^k mod m.
///
/// @param a  Base; may be negative or >= m, it is reduced into [0, m) first.
/// @param k  Exponent; any value <= 0 skips the loop and yields 1 % m.
/// @param m  Modulus; expected to be positive. Products are formed in `ll`,
///           so (m - 1)^2 has to fit in that type.
/// @return   For positive m, the residue in [0, m).
ll qmi(ll a, ll k, ll m) {
    a %= m;
    // `%` keeps the sign of the dividend, so without this correction a
    // negative base would leak negative residues into the result.
    if (a < 0) {
        a += m;
    }

    ll res = 1 % m;  // `1 % m` rather than 1 so that m == 1 gives 0

    // Loop on `k > 0` instead of `k`: a negative exponent never reaches zero
    // under an arithmetic right shift and would otherwise spin forever.
    while (k > 0) {
        if (k & 1) {
            res = res * a % m;
        }
        a = a * a % m;
        k >>= 1;
    }
    return res;
}
void solve(){ ll n, m, A, B; std::cin >> n >> m >> A >> B;
ll dp[31][2][2][2][2][2][2]; memset(dp, -1, sizeof(dp)); std::function<ll(int, int, int, int, int, int, int)> dfs = [&](int pos, int limit1, int limit2, int same1, int limit3, int limit4, int same2) -> ll{ if(pos < 0){ return ((!same1) && (!same2)); } if(dp[pos][limit1][limit2][same1][limit3][limit4][same2] != -1) return dp[pos][limit1][limit2][same1][limit3][limit4][same2]; ll ans = 0; for(int i = 0; i <= 1; i++){ for(int j = 0; j <= 1; j++){ for(int k = 0; k <= 1; k++){ for(int l = 0; l <= 1; l++){ if((i + j + k + l) % 2 == 0){ int up1 = (A >> pos) & 1; int up2 = (B >> pos) & 1; if(i > up1 && limit1) continue; if(j > up1 && limit2) continue; if(k > up2 && limit3) continue; if(l > up2 && limit4) continue;
ans = (ans + dfs(pos - 1, limit1 && (i == up1), limit2 && (j == up1), (i == j) & same1, limit3 && k == up2, limit4 && l == up2, (k == l) & same2)) % mod; } } } } } return dp[pos][limit1][limit2][same1][limit3][limit4][same2] = ans; }; ll res = dfs(30, true, true, true, true, true, true); res = res * qmi(4, mod - 2, mod) % mod;
ll ans = 0; ans = (res *(qmi(2, n, mod) - 2) % mod * (qmi(2, m, mod) - 2) % mod + ans) % mod;
ans = (ans + (A + 1) * (B + 1) % mod) % mod;
ans = (ans + (A + 1) * (B + 1) % mod * B % mod * (qmi(2, m, mod) - 2) % mod * qmi(2, mod - 2, mod) % mod) % mod; ans = (ans + (B + 1) * (A + 1) % mod * A % mod * (qmi(2, n, mod) - 2) % mod * qmi(2, mod - 2, mod) % mod) % mod;
ans = (ans + mod) % mod; std::cout << ans << endl; }
|