백준 3783, Binary Search
https://www.acmicpc.net/problem/3783
알고리즘 문제 해결 전략 #13 에 나오는 bisection method 를 적용하기 위한 간단한 연습문제로 생각을 했으나,
소수 정밀도 때문에 bigint 자료형을 직접 만들어서 binary search 를 적용해서 풀었다.
bisection method 를 통해 정밀한 세제곱근을 구해서 푸는 방법은 소수 정밀도의 한계 때문에 힘들어 보인다.
SPOJ 의 같은 문제를 푼 아래 링크들이나, BOJ 에서 해결한 사람들 모두 이분탐색을 많이 사용했다.
Source Code
//BOJ3783
#define _CRT_SECURE_NO_WARNINGS
#include <bits/stdc++.h>
using namespace std;
// File-scope limb base (10^5).
// NOTE(review): bigInt declares its own `base` member and the arithmetic below
// reads it through `lhs.base`; this global does not appear to be referenced in
// the visible code.
const int base = 100000;
// Number of test cases; read from standard input in proc().
int T;
/**
 * Fixed-capacity non-negative big integer.
 *
 * The value is stored little-endian in `digits`, five decimal digits per limb
 * (limb base 100000). `lastdigit` is the index of the most significant limb;
 * the value zero is represented by lastdigit == 0 and digits[0] == 0.
 *
 * `base` is a static constant rather than a per-object const member. A
 * non-static const member deletes the implicit copy assignment, which is what
 * previously forced a hand-written operator= that took and returned the whole
 * ~4 KB object by value. With a static constant the compiler-generated copy
 * constructor and copy assignment (returning bigInt&) are correct, and
 * expressions such as `lhs.base` keep working unchanged.
 */
struct bigInt {
// Kept as a macro on purpose: the free functions that follow the struct use
// `ll` as well.
#define ll long long
    // Index of the most significant limb in use.
    int lastdigit;
    // Limb base: each limb holds a 5-decimal-digit chunk.
    static const ll base = 100000;
    // Limbs, least significant first.
    ll digits[505];

    // Constructs the value zero (all limbs value-initialized to 0).
    bigInt() : lastdigit(0), digits() {}
};
// Returns 10 raised to the power n as an int; any n <= 0 yields 1.
int pow10(int n) {
    int power = 1;
    for (int remaining = n; remaining > 0; --remaining) {
        power *= 10;
    }
    return power;
}
// Lowers n.lastdigit past leading zero limbs so it indexes the most
// significant non-zero limb; it never drops below 0.
void zero_justify(bigInt& n) {
    for (; n.lastdigit > 0; --n.lastdigit) {
        if (n.digits[n.lastdigit] != 0) {
            break;
        }
    }
}
void init(const string& s, bigInt& n) {
// get input. 5 digit chunks
memset(n.digits, 0, sizeof(n.digits));
int ssize = s.size();
int idx, chunk;
idx = 0;
for (int i = 0; i < ssize; i += 5) {
chunk = 0;
for (int j = 0; j < 5; ++j) {
int tidx = ssize - 1 - i - j;
if (tidx >= 0) {
chunk += (s[tidx] - '0') * pow10(j);
}
}
n.digits[idx] = chunk;
++idx;
}
n.lastdigit = idx;
zero_justify(n);
}
// Three-way comparison with an inverted sign convention:
//   returns  1 when lhs <  rhs,
//   returns -1 when lhs >  rhs,
//   returns  0 when lhs == rhs.
// The limb count decides first; on a tie, limbs are compared from the most
// significant one downwards.
int compare(bigInt& lhs, bigInt& rhs) {
    if (lhs.lastdigit != rhs.lastdigit) {
        return lhs.lastdigit < rhs.lastdigit ? 1 : -1;
    }
    // Both operands have the same top limb index here.
    for (int i = lhs.lastdigit; i >= 0; --i) {
        const ll left = lhs.digits[i];
        const ll right = rhs.digits[i];
        if (left != right) {
            return left > right ? -1 : 1;
        }
    }
    return 0;
}
// Shifts n up by d limbs, i.e. multiplies it by 100000^d. Zero is left as is.
void digitshift(bigInt& n, int d) {
    const bool is_zero = (n.lastdigit == 0 && n.digits[0] == 0);
    if (is_zero) {
        return;
    }
    // Move from the top limb downwards so no limb is overwritten before it
    // has been copied.
    int src = n.lastdigit;
    while (src >= 0) {
        n.digits[src + d] = n.digits[src];
        --src;
    }
    // Clear the d low limbs that were vacated.
    for (int low = d - 1; low >= 0; --low) {
        n.digits[low] = 0;
    }
    n.lastdigit += d;
    zero_justify(n);
}
// Returns lhs + rhs using limb-wise addition with carry propagation.
bigInt operator + (bigInt lhs, bigInt rhs) {
    bigInt sum;
    // Reserve one limb above the longer operand for a final carry.
    const int longer = lhs.lastdigit > rhs.lastdigit ? lhs.lastdigit : rhs.lastdigit;
    sum.lastdigit = longer + 1;
    ll carry = 0;
    for (int i = 0; i <= sum.lastdigit; ++i) {
        const ll column = lhs.digits[i] + rhs.digits[i] + carry;
        carry = column / lhs.base;
        sum.digits[i] = column % lhs.base;
    }
    zero_justify(sum);
    return sum;
}
// Returns lhs + rhs for a machine-integer rhs.
// rhs is added into the lowest limb and the excess is carried upwards.
bigInt operator + (bigInt lhs, ll rhs) {
    bigInt result = lhs;
    result.digits[0] += rhs;
    // Allow one extra limb in case the carry spills past the current top.
    result.lastdigit++;
    ll carry = 0;
    int i = 0;
    while (i <= result.lastdigit) {
        const ll column = result.digits[i] + carry;
        carry = column / lhs.base;
        result.digits[i] = column % lhs.base;
        ++i;
    }
    zero_justify(result);
    return result;
}
// Returns lhs - rhs for a machine-integer rhs.
//
// Preconditions: rhs >= 0 and lhs >= rhs. If lhs < rhs, the borrow left over
// after the top limb is discarded and the result wraps, as before.
//
// The borrow carried into each limb is the full amount still owed, expressed
// in units of that limb. This makes the subtraction correct for rhs >= base
// as well; the previous single-bit borrow left a negative limb in that case.
// For 0 <= rhs < base the result is unchanged.
bigInt operator - (bigInt lhs, ll rhs) {
    bigInt ret = lhs;
    ll borrow = rhs;  // amount still to be subtracted from limb i
    for (int i = 0; i <= ret.lastdigit; ++i) {
        ll tmp = ret.digits[i] - borrow;
        borrow = 0;
        if (tmp < 0) {
            // Take just enough units from the next limb to make this one
            // non-negative.
            borrow = (-tmp + lhs.base - 1) / lhs.base;
            tmp += borrow * lhs.base;
        }
        ret.digits[i] = tmp % lhs.base;
    }
    zero_justify(ret);
    return ret;
}
// Returns lhs * rhs by schoolbook multiplication.
// All partial products are accumulated per column first; a single carry pass
// then reduces every limb below the base.
bigInt operator * (bigInt lhs, bigInt rhs) {
    bigInt product;
    const int lhs_top = lhs.lastdigit;
    const int rhs_top = rhs.lastdigit;
    for (int i = 0; i <= lhs_top; ++i) {
        const ll left = lhs.digits[i];
        for (int j = 0; j <= rhs_top; ++j) {
            product.digits[i + j] += left * rhs.digits[j];
        }
    }
    product.lastdigit = lhs_top + rhs_top + 1;
    ll carry = 0;
    for (int k = 0; k <= product.lastdigit; ++k) {
        const ll column = product.digits[k] + carry;
        carry = column / lhs.base;
        product.digits[k] = column % lhs.base;
    }
    zero_justify(product);
    return product;
}
// Returns the quotient of lhs divided by the machine integer rhs, computed by
// long division from the most significant limb down. The remainder is
// discarded. A zero divisor returns lhs unchanged.
bigInt operator / (bigInt lhs, long long rhs) {
    if (rhs == 0) {
        return lhs;
    }
    bigInt quotient = lhs;
    long long remainder = 0;
    for (int i = lhs.lastdigit; i >= 0; --i) {
        const long long current = lhs.base * remainder + quotient.digits[i];
        quotient.digits[i] = current / rhs;
        remainder = current % rhs;
    }
    zero_justify(quotient);
    return quotient;
}
// True when lhs is strictly greater than rhs.
bool operator > (bigInt lhs, bigInt rhs) {
    // compare() yields -1 when its first argument is the larger one.
    return compare(lhs, rhs) == -1;
}
// True when lhs is strictly less than rhs.
bool operator < (bigInt lhs, bigInt rhs) {
    // compare() yields 1 when its first argument is the smaller one.
    return compare(lhs, rhs) == 1;
}
// True when lhs is less than or equal to rhs, i.e. rhs is not below lhs.
bool operator <= (bigInt lhs, bigInt rhs) {
    // compare(rhs, lhs) == 1 would mean rhs < lhs.
    return compare(rhs, lhs) != 1;
}
// True when lhs is greater than or equal to rhs, i.e. rhs is not above lhs.
bool operator >= (bigInt lhs, bigInt rhs) {
    // compare(rhs, lhs) == -1 would mean rhs > lhs.
    return compare(rhs, lhs) != -1;
}
// True when both operands hold the same value.
bool operator == (bigInt lhs, bigInt rhs) {
    return compare(lhs, rhs) == 0;
}
// Prints a as a fixed-point number with 10 decimal digits.
// Limbs 0..19 are treated as the fractional part, so the decimal point goes
// after limb 20, and only limbs 19 and 18 of the fraction are printed.
void print10(bigInt a) {
    if (a.lastdigit < 20) {
        // No limb at index 20 or above: the integer part is zero.
        printf("0.");
        printf("%05d", static_cast<int>(a.digits[19]));
        printf("%05d", static_cast<int>(a.digits[18]));
        return;
    }
    // The top limb is printed without zero padding.
    int i = a.lastdigit;
    printf("%d", static_cast<int>(a.digits[i]));
    if (i == 20) {
        printf(".");
    }
    for (--i; i >= 18; --i) {
        printf("%05d", static_cast<int>(a.digits[i]));
        if (i == 20) {
            printf(".");
        }
    }
}
// Prints n in decimal: the top limb unpadded, every lower limb zero-padded to
// five digits.
void print(bigInt n) {
    int i = n.lastdigit;
    printf("%d", static_cast<int>(n.digits[i]));
    while (--i >= 0) {
        printf("%05d", static_cast<int>(n.digits[i]));
    }
}
// Reads T test cases; for each decimal integer read, prints its cube root
// truncated to 10 decimal places.
void proc() {
    // To read from a file instead, shadow cin with a local stream:
    //ifstream cin;
    //cin.open("input.txt");
    cin >> T;
    while (T--) {
        string s;
        cin >> s;

        bigInt lo, hi, target;
        init("0", lo);
        init(s, hi);
        // Scale the input by base^60 = 10^300 so the integer cube root has
        // 60 / 3 = 20 fractional limbs, which is enough precision.
        digitshift(hi, 60);
        target = hi;

        // Binary search for the largest mid with mid^3 <= target.
        bigInt mid;
        while (lo <= hi) {
            mid = (lo + hi) / 2;
            const bool too_big = mid * (mid * mid) > target;
            if (too_big) {
                hi = mid - 1;
            } else {
                lo = mid + 1;
            }
        }
        // The loop exits with lo one past the answer.
        lo = lo - 1;
        print10(lo);
        printf("\n");
    }
}
void test() {
bigInt a, b;
init("100000", b);
init("99999", a);
bigInt c = a*b;
print(c);
return;
}
int main() {
ios::sync_with_stdio(false);
cin.tie(NULL);
//test();
proc();
return 0;
}
아래는 내가 참고한 자료(위 참고링크 1). 해결 시간이 무려 4ms. 엄청나게 빠르다.
//Alfonso2 Peterssen
//SPOJ #291 "Cube Root"
//16 - 9 - 2008
#include <cstdio>
#include <algorithm>
#include <cstring>
using namespace std;
// REP(i, n): loop header that runs i over 0, 1, ..., n-1.
#define REP( i, n ) \
for ( int i = 0; i < (n); i++ )
// REPD(i, n): loop header that runs i over n-1, n-2, ..., 0.
#define REPD( i, n ) \
for ( int i = (n) - 1; i >= 0; i-- )
// Integer type used for limbs and intermediate results.
typedef long long int64;
// BASE     - limb base, 10^7 (each limb holds 7 decimal digits)
// BASE_LOG - number of decimal digits per limb
// MAXLEN   - capacity of the limb array inside bigint
const int
BASE = 10000000,
BASE_LOG = 7,
MAXLEN = 100;
// Number of test cases; read with scanf in main().
int T;
// Fixed-capacity big integer: `size` limbs stored least significant first in
// T, each limb in base BASE.
struct bigint {
    int size;
    int64 T[MAXLEN];

    // Builds a number whose lowest limb is x and whose limb count is `size`;
    // every other limb is zero.
    bigint(int64 x = 0, int size = 1) {
        this->size = size;
        for (int i = 0; i < MAXLEN; ++i) {
            T[i] = 0;
        }
        T[0] = x;
    }
};
// Strict less-than: the operand with fewer limbs is smaller; on equal limb
// counts the limbs are compared from the most significant one down.
bool operator < (const bigint& a, const bigint& b) {
    if (a.size < b.size) return true;
    if (a.size > b.size) return false;
    for (int i = a.size - 1; i >= 0; --i) {
        if (a.T[i] < b.T[i]) return true;
        if (a.T[i] > b.T[i]) return false;
    }
    return false;  // equal
}
// Less-than-or-equal, defined as "b is not strictly below a".
bool operator <= (const bigint& a, const bigint& b) {
    if (b < a) {
        return false;
    }
    return true;
}
// Normalizes bn in place and returns it:
//   - negative limbs are fixed by borrowing from the next higher limb,
//   - limbs are reduced below BASE with the excess carried upwards,
//   - a carry left over after the top limb is appended as new limbs,
//   - leading zero limbs are removed, always keeping at least one limb.
bigint& normal(bigint& bn) {
    int64 carry = 0;
    for (int i = 0; i < bn.size; ++i) {
        // Borrow one unit at a time until this limb is non-negative.
        while (bn.T[i] < 0) {
            bn.T[i] += BASE;
            bn.T[i + 1] -= 1;
        }
        const int64 total = bn.T[i] + carry;
        carry = total / BASE;
        bn.T[i] = total % BASE;
    }
    while (carry > 0) {
        bn.T[bn.size++] = carry % BASE;
        carry /= BASE;
    }
    while (bn.size > 1 && bn.T[bn.size - 1] == 0) {
        bn.size--;
    }
    return bn;
}
// Returns a + b: limbs are added pairwise, then normal() resolves the carries.
bigint operator + (bigint a, bigint& b) {
    if (a.size < b.size) {
        a.size = b.size;
    }
    for (int i = 0; i < a.size; ++i) {
        a.T[i] += b.T[i];
    }
    return normal(a);
}
// Returns a + x: x is added to the lowest limb and normal() carries the excess.
bigint operator + (bigint a, int64 x) {
    a.T[0] = a.T[0] + x;
    return normal(a);
}
// Returns a - b.
//
// Limbs are subtracted pairwise; normal() then resolves negative limbs by
// borrowing from the next higher limb.
// NOTE(review): normal() borrows from T[i + 1] without checking whether that
// limb is in use, so this appears to require a >= b — confirm before calling
// with a < b.
//
// Fix: the limb loop previously used `+=`, a copy of operator+, so this
// operator added instead of subtracting.
bigint operator - (bigint a, bigint& b) {
    a.size = max(a.size, b.size);
    for (int i = 0; i < a.size; ++i) {
        a.T[i] -= b.T[i];
    }
    return normal(a);
}
// Returns a - x: x is taken from the lowest limb and normal() resolves the
// borrow.
bigint operator - (bigint a, int64 x) {
    a.T[0] = a.T[0] - x;
    return normal(a);
}
// Returns a * b by schoolbook multiplication. Raw column sums are accumulated
// first; normal() then performs the carry propagation.
bigint operator * (bigint a, bigint b) {
    bigint c(0, a.size + b.size);
    for (int i = 0; i < a.size; ++i) {
        const int64 left = a.T[i];
        for (int j = 0; j < b.size; ++j) {
            c.T[i + j] += left * b.T[j];
        }
    }
    return normal(c);
}
// Returns a * x: every limb is scaled by x, then normal() carries the excess.
bigint operator * (bigint a, int64 x) {
    for (int i = 0; i < a.size; ++i) {
        a.T[i] = a.T[i] * x;
    }
    return normal(a);
}
// Divides bn by the machine integer x using long division from the most
// significant limb down. Returns the pair (quotient, remainder).
pair< bigint, int64 > divmod(bigint bn, int64 x) {
    int64 remainder = 0;
    for (int i = bn.size - 1; i >= 0; --i) {
        const int64 current = BASE * remainder + bn.T[i];
        bn.T[i] = current / x;
        remainder = current % x;
    }
    normal(bn);  // strips the leading zero limbs the division produced
    return pair< bigint, int64 >(bn, remainder);
}
// Returns the quotient of bn divided by x.
bigint operator / (bigint bn, int64 x) {
    pair< bigint, int64 > result = divmod(bn, x);
    return result.first;
}
// Returns the remainder of bn divided by x.
int64 operator % (bigint bn, int64 x) {
    pair< bigint, int64 > result = divmod(bn, x);
    return result.second;
}
// Prints bn in decimal: the top limb unpadded, every lower limb zero-padded
// to BASE_LOG digits.
void print(bigint bn) {
    int i = bn.size - 1;
    printf("%lld", bn.T[i]);
    for (--i; i >= 0; --i) {
        printf("%0*lld", BASE_LOG, bn.T[i]);
    }
}
// Reads one whitespace-delimited decimal token from stdin into bn.
//
// The token is split into BASE_LOG-digit groups starting from its most
// significant end, with the first group implicitly zero-padded on the left.
// Each group is appended after the limbs bn already holds, and the whole limb
// array is then reversed into little-endian order. For a default-constructed
// bn the initial zero limb therefore ends up on top and normal() strips it.
//
// Fixes over the previous version:
//   - scanf received `&line` (type char(*)[200]) where %s requires char*;
//   - the conversion had no field width, so a long token overflowed `line`;
//   - the scanf result was ignored. On a failed read the token is now treated
//     as empty, which leaves bn's limbs untouched.
void read(bigint& bn) {
    static char line[200];
    if (scanf("%199s", line) != 1) {
        line[0] = '\0';
    }
    int len = strlen(line);
    // Start at a non-positive offset so every group after the first is a full
    // BASE_LOG characters wide.
    int i = len % BASE_LOG;
    if (i > 0) i -= BASE_LOG;
    for (; i < len; i += BASE_LOG) {
        int64 x = 0;
        for (int j = 0; j < BASE_LOG; ++j) {
            // Positions before the start of the token count as digit 0.
            x = 10 * x + (i + j >= 0 ? line[i + j] - '0' : 0);
        }
        bn.T[bn.size++] = x;
    }
    reverse(bn.T, bn.T + bn.size);
    normal(bn);
}
int main() {
for (scanf("%d", &T); T--; ) {
bigint X;
read(X);
REP(i, 15) X = X * 100;
bigint lo(0);
bigint hi = X;
while (lo <= hi) {
bigint mid = (lo + hi) / 2;
if (mid * mid * mid <= X)
lo = mid + 1;
else
hi = mid - 1;
}
lo = lo - 1;
pair< bigint, int64 > R;
R = divmod(lo, 100000);
int p1 = R.second; lo = R.first;
R = divmod(lo, 100000);
int p2 = R.second; lo = R.first;
print(lo); printf(".%05d%05d\n", p2, p1);
}
return 0;
}