diff --git a/py/modmath.c b/py/modmath.c index c3ea55f936..f80b73ab52 100644 --- a/py/modmath.c +++ b/py/modmath.c @@ -17,10 +17,14 @@ mp_obj_t mp_math_ ## py_name(mp_obj_t x_obj, mp_obj_t y_obj) { return mp_obj_new_float(MICROPY_FLOAT_C_FUN(c_name)(mp_obj_get_float(x_obj), mp_obj_get_float(y_obj))); } \ STATIC MP_DEFINE_CONST_FUN_OBJ_2(mp_math_## py_name ## _obj, mp_math_ ## py_name); -#define MATH_FUN_BOOL1(py_name, c_name) \ +#define MATH_FUN_1_TO_BOOL(py_name, c_name) \ mp_obj_t mp_math_ ## py_name(mp_obj_t x_obj) { return MP_BOOL(c_name(mp_obj_get_float(x_obj))); } \ STATIC MP_DEFINE_CONST_FUN_OBJ_1(mp_math_## py_name ## _obj, mp_math_ ## py_name); +#define MATH_FUN_1_TO_INT(py_name, c_name) \ + mp_obj_t mp_math_ ## py_name(mp_obj_t x_obj) { return mp_obj_new_int((machine_int_t)MICROPY_FLOAT_C_FUN(c_name)(mp_obj_get_float(x_obj))); } \ + STATIC MP_DEFINE_CONST_FUN_OBJ_1(mp_math_## py_name ## _obj, mp_math_ ## py_name); + STATIC const mp_obj_float_t mp_math_e_obj = {{&mp_type_float}, M_E}; STATIC const mp_obj_float_t mp_math_pi_obj = {{&mp_type_float}, M_PI}; @@ -44,15 +48,15 @@ MATH_FUN_1(acos, acos) MATH_FUN_1(asin, asin) MATH_FUN_1(atan, atan) MATH_FUN_2(atan2, atan2) -MATH_FUN_1(ceil, ceil) +MATH_FUN_1_TO_INT(ceil, ceil) MATH_FUN_2(copysign, copysign) MATH_FUN_1(fabs, fabs) -MATH_FUN_1(floor, floor) //TODO: delegate to x.__floor__() if x is not a float +MATH_FUN_1_TO_INT(floor, floor) //TODO: delegate to x.__floor__() if x is not a float MATH_FUN_2(fmod, fmod) -MATH_FUN_BOOL1(isfinite, isfinite) -MATH_FUN_BOOL1(isinf, isinf) -MATH_FUN_BOOL1(isnan, isnan) -MATH_FUN_1(trunc, trunc) +MATH_FUN_1_TO_BOOL(isfinite, isfinite) +MATH_FUN_1_TO_BOOL(isinf, isinf) +MATH_FUN_1_TO_BOOL(isnan, isnan) +MATH_FUN_1_TO_INT(trunc, trunc) MATH_FUN_2(ldexp, ldexp) MATH_FUN_1(erf, erf) MATH_FUN_1(erfc, erfc) diff --git a/tests/basics/math-fun.py b/tests/basics/math-fun.py index 1301dc2a5b..eb80ab0f54 100644 --- a/tests/basics/math-fun.py +++ b/tests/basics/math-fun.py @@ -6,41 +6,41 @@ test_values = [-100., -1.23456, -1, -0.5, 0.0, 0.5, 1.23456, 100.] p_test_values = [0.1, 0.5, 1.23456] unit_range_test_values = [-1., -0.75, -0.5, -0.25, 0., 0.25, 0.5, 0.75, 1.] -functions = [(sqrt, p_test_values), - (exp, test_values), - (expm1, test_values), - (log, p_test_values), - (log2, p_test_values), - (log10, p_test_values), - (cosh, test_values), - (sinh, test_values), - (tanh, test_values), - (acosh, [1.0, 5.0, 1.0]), - (asinh, test_values), - (atanh, [-0.99, -0.5, 0.0, 0.5, 0.99]), - (cos, test_values), - (sin, test_values), - (tan, test_values), - (acos, unit_range_test_values), - (asin, unit_range_test_values), - (atan, test_values), - (ceil, test_values), - (fabs, test_values), - (floor, test_values), - #(frexp, test_values), - (trunc, test_values) +functions = [('sqrt', sqrt, p_test_values), + ('exp', exp, test_values), + ('expm1', expm1, test_values), + ('log', log, p_test_values), + ('log2', log2, p_test_values), + ('log10', log10, p_test_values), + ('cosh', cosh, test_values), + ('sinh', sinh, test_values), + ('tanh', tanh, test_values), + ('acosh', acosh, [1.0, 5.0, 1.0]), + ('asinh', asinh, test_values), + ('atanh', atanh, [-0.99, -0.5, 0.0, 0.5, 0.99]), + ('cos', cos, test_values), + ('sin', sin, test_values), + ('tan', tan, test_values), + ('acos', acos, unit_range_test_values), + ('asin', asin, unit_range_test_values), + ('atan', atan, test_values), + ('ceil', ceil, test_values), + ('fabs', fabs, test_values), + ('floor', floor, test_values), + #('frexp', frexp, test_values), + ('trunc', trunc, test_values) ] -for function, test_vals in functions: +for function_name, function, test_vals in functions: + print(function_name) for value in test_vals: - print("{:8.7f}".format(function(value))) + print("{:8.7g}".format(function(value))) -binary_functions = [(copysign, [(23., 42.), (-23., 42.), (23., -42.), +binary_functions = [('copysign', copysign, [(23., 42.), (-23., 42.), (23., -42.), (-23., -42.), (1., 0.0), (1., -0.0)]) ] -for function, test_vals in binary_functions: +for function_name, function, test_vals in binary_functions: + print(function_name) for value1, value2 in test_vals: - print("{:8.7f}".format(function(value1, value2))) - - + print("{:8.7g}".format(function(value1, value2)))