From 17298af61ededcfa4e96a6fdc6f484c79ffe220d Mon Sep 17 00:00:00 2001 From: Michael Buesch Date: Thu, 10 Dec 2015 13:28:37 +0100 Subject: [PATCH] py/modmath: Add domain error checking to sqrt, log, log2, log10. These functions will raise 'ValueError: math domain error' on invalid input. --- py/modmath.c | 34 ++++++++++++++++++++++++++++------ tests/float/math_fun.py | 21 +++++++++++++-------- 2 files changed, 41 insertions(+), 14 deletions(-) diff --git a/py/modmath.c b/py/modmath.c index e2836cbd26..cffd8cf491 100644 --- a/py/modmath.c +++ b/py/modmath.c @@ -25,6 +25,7 @@ */ #include "py/builtin.h" +#include "py/nlr.h" #if MICROPY_PY_BUILTINS_FLOAT && MICROPY_PY_MATH @@ -35,7 +36,10 @@ /// The `math` module provides some basic mathematical funtions for /// working with floating-point numbers. -//TODO: Change macros to check for overflow and raise OverflowError or RangeError +STATIC NORETURN void math_error(void) { + nlr_raise(mp_obj_new_exception_msg_varg(&mp_type_ValueError, "math domain error")); +} + #define MATH_FUN_1(py_name, c_name) \ STATIC mp_obj_t mp_math_ ## py_name(mp_obj_t x_obj) { return mp_obj_new_float(MICROPY_FLOAT_C_FUN(c_name)(mp_obj_get_float(x_obj))); } \ STATIC MP_DEFINE_CONST_FUN_OBJ_1(mp_math_## py_name ## _obj, mp_math_ ## py_name); @@ -52,6 +56,16 @@ STATIC mp_obj_t mp_math_ ## py_name(mp_obj_t x_obj) { mp_int_t x = MICROPY_FLOAT_C_FUN(c_name)(mp_obj_get_float(x_obj)); return mp_obj_new_int(x); } \ STATIC MP_DEFINE_CONST_FUN_OBJ_1(mp_math_## py_name ## _obj, mp_math_ ## py_name); +#define MATH_FUN_1_ERRCOND(py_name, c_name, error_condition) \ + STATIC mp_obj_t mp_math_ ## py_name(mp_obj_t x_obj) { \ + mp_float_t x = mp_obj_get_float(x_obj); \ + if (error_condition) { \ + math_error(); \ + } \ + return mp_obj_new_float(MICROPY_FLOAT_C_FUN(c_name)(x)); \ + } \ + STATIC MP_DEFINE_CONST_FUN_OBJ_1(mp_math_## py_name ## _obj, mp_math_ ## py_name); + #if MP_NEED_LOG2 // 1.442695040888963407354163704 is 1/_M_LN2 #define log2(x) (log(x) * 1.442695040888963407354163704) @@ -59,7 +73,7 @@ /// \function sqrt(x) /// Returns the square root of `x`. -MATH_FUN_1(sqrt, sqrt) +MATH_FUN_1_ERRCOND(sqrt, sqrt, (x < (mp_float_t)0.0)) /// \function pow(x, y) /// Returns `x` to the power of `y`. MATH_FUN_2(pow, pow) @@ -69,9 +83,9 @@ MATH_FUN_1(exp, exp) /// \function expm1(x) MATH_FUN_1(expm1, expm1) /// \function log2(x) -MATH_FUN_1(log2, log2) +MATH_FUN_1_ERRCOND(log2, log2, (x <= (mp_float_t)0.0)) /// \function log10(x) -MATH_FUN_1(log10, log10) +MATH_FUN_1_ERRCOND(log10, log10, (x <= (mp_float_t)0.0)) /// \function cosh(x) MATH_FUN_1(cosh, cosh) /// \function sinh(x) @@ -139,11 +153,19 @@ MATH_FUN_1(lgamma, lgamma) // log(x[, base]) STATIC mp_obj_t mp_math_log(mp_uint_t n_args, const mp_obj_t *args) { - mp_float_t l = MICROPY_FLOAT_C_FUN(log)(mp_obj_get_float(args[0])); + mp_float_t x = mp_obj_get_float(args[0]); + if (x <= (mp_float_t)0.0) { + math_error(); + } + mp_float_t l = MICROPY_FLOAT_C_FUN(log)(x); if (n_args == 1) { return mp_obj_new_float(l); } else { - return mp_obj_new_float(l / MICROPY_FLOAT_C_FUN(log)(mp_obj_get_float(args[1]))); + mp_float_t base = mp_obj_get_float(args[1]); + if (base <= (mp_float_t)0.0) { + math_error(); + } + return mp_obj_new_float(l / MICROPY_FLOAT_C_FUN(log)(base)); } } STATIC MP_DEFINE_CONST_FUN_OBJ_VAR_BETWEEN(mp_math_log_obj, 1, 2, mp_math_log); diff --git a/tests/float/math_fun.py b/tests/float/math_fun.py index 94411a36bd..277507ecf0 100644 --- a/tests/float/math_fun.py +++ b/tests/float/math_fun.py @@ -9,15 +9,14 @@ except ImportError: test_values = [-100., -1.23456, -1, -0.5, 0.0, 0.5, 1.23456, 100.] test_values_small = [-10., -1.23456, -1, -0.5, 0.0, 0.5, 1.23456, 10.] # so we don't overflow 32-bit precision -p_test_values = [0.1, 0.5, 1.23456] unit_range_test_values = [-1., -0.75, -0.5, -0.25, 0., 0.25, 0.5, 0.75, 1.] -functions = [('sqrt', sqrt, p_test_values), +functions = [('sqrt', sqrt, test_values), ('exp', exp, test_values_small), ('expm1', expm1, test_values_small), - ('log', log, p_test_values), - ('log2', log2, p_test_values), - ('log10', log10, p_test_values), + ('log', log, test_values), + ('log2', log2, test_values), + ('log10', log10, test_values), ('cosh', cosh, test_values_small), ('sinh', sinh, test_values_small), ('tanh', tanh, test_values_small), @@ -41,7 +40,10 @@ functions = [('sqrt', sqrt, p_test_values), for function_name, function, test_vals in functions: print(function_name) for value in test_vals: - print("{:.5g}".format(function(value))) + try: + print("{:.5g}".format(function(value))) + except ValueError as e: + print(str(e)) tuple_functions = [('frexp', frexp, test_values), ('modf', modf, test_values), @@ -59,10 +61,13 @@ binary_functions = [('copysign', copysign, [(23., 42.), (-23., 42.), (23., -42.) ('atan2', atan2, ((1., 0.), (0., 1.), (2., 0.5), (-3., 5.), (-3., -4.),)), ('fmod', fmod, ((1., 1.), (0., 1.), (2., 0.5), (-3., 5.), (-3., -4.),)), ('ldexp', ldexp, ((1., 0), (0., 1), (2., 2), (3., -2), (-3., -4),)), - ('log', log, ((2., 2.), (3., 2.), (4., 5.))), + ('log', log, ((2., 2.), (3., 2.), (4., 5.), (0., 1.), (1., 0.), (-1., 1.), (1., -1.))), ] for function_name, function, test_vals in binary_functions: print(function_name) for value1, value2 in test_vals: - print("{:.5g}".format(function(value1, value2))) + try: + print("{:.5g}".format(function(value1, value2))) + except ValueError as e: + print(str(e))