|
| 1 | +from sqlalchemy.event import listen, remove |
| 2 | + |
| 3 | +g_tracer = None |
| 4 | +g_trace_all = True |
| 5 | + |
| 6 | +def init_tracing(tracer, trace_all=False): |
| 7 | + ''' |
| 8 | + Set our global tracer. |
| 9 | + Tracer objects from our pyramid/flask/django libraries |
| 10 | + can be passed as well. |
| 11 | + ''' |
| 12 | + global g_tracer, g_trace_all |
| 13 | + if hasattr(tracer, '_tracer'): |
| 14 | + g_tracer = tracer._tracer |
| 15 | + else: |
| 16 | + g_tracer = tracer |
| 17 | + |
| 18 | + g_trace_all = trace_all |
| 19 | + |
| 20 | +def get_traced(stmt_obj): |
| 21 | + ''' |
| 22 | + Gets a bool indicating whether or not this |
| 23 | + statement is marked for tracing. |
| 24 | + ''' |
| 25 | + return getattr(stmt_obj, '_traced', False) |
| 26 | + |
| 27 | +def set_traced(stmt_obj): |
| 28 | + ''' |
| 29 | + Mark a statement to be traced. |
| 30 | + ''' |
| 31 | + stmt_obj._traced = True |
| 32 | + |
| 33 | +def get_parent_span(stmt_obj): |
| 34 | + ''' |
| 35 | + Gets a parent span for this statement, if any. |
| 36 | + ''' |
| 37 | + return getattr(stmt_obj, '_parent_span', None) |
| 38 | + |
| 39 | +def set_parent_span(stmt_obj, parent_span): |
| 40 | + ''' |
| 41 | + Marks a statement as a child of a span. |
| 42 | + It gets marked to be traced if it wasn't before. |
| 43 | + ''' |
| 44 | + stmt_obj._parent_span = parent_span |
| 45 | + stmt_obj._traced = True |
| 46 | + |
| 47 | +def has_parent_span(stmt_obj): |
| 48 | + ''' |
| 49 | + Get whether or not the statement has |
| 50 | + a parent span. |
| 51 | + ''' |
| 52 | + return hasattr(stmt_obj, '_parent_span') |
| 53 | + |
| 54 | +def get_span(stmt_obj): |
| 55 | + ''' |
| 56 | + Get the span of a statement object, if any. |
| 57 | + ''' |
| 58 | + return getattr(stmt_obj, '_span', None) |
| 59 | + |
| 60 | +def register_connectable(obj): |
| 61 | + ''' |
| 62 | + Register an object to have its events be traced. |
| 63 | + Any Connectable object is accepted, which |
| 64 | + includes Connection and Engine. |
| 65 | + ''' |
| 66 | + listen(obj, 'before_cursor_execute', _before_cursor_handler) |
| 67 | + listen(obj, 'after_cursor_execute', _after_cursor_handler) |
| 68 | + listen(obj, 'handle_error', _error_handler) |
| 69 | + |
| 70 | +def unregister_connectable(obj): |
| 71 | + ''' |
| 72 | + Remove a connectable from having its events being |
| 73 | + traced. |
| 74 | + ''' |
| 75 | + remove(obj, 'before_cursor_execute', _before_cursor_handler) |
| 76 | + remove(obj, 'after_cursor_execute', _after_cursor_handler) |
| 77 | + remove(obj, 'handle_error', _error_handler) |
| 78 | + |
| 79 | +def _get_operation_name(stmt_obj): |
| 80 | + return stmt_obj.__visit_name__ |
| 81 | + |
| 82 | +def _normalize_stmt(statement): |
| 83 | + return statement.strip().replace('\n', '').replace('\t', '') |
| 84 | + |
| 85 | +def _before_cursor_handler(conn, cursor, statement, parameters, context, executemany): |
| 86 | + if context.compiled is None: # PRAGMA |
| 87 | + return |
| 88 | + |
| 89 | + # Don't trace if trace_all is disabled and the statement wasn't marked |
| 90 | + stmt_obj = context.compiled.statement |
| 91 | + if not (g_trace_all or get_traced(stmt_obj)): |
| 92 | + return |
| 93 | + |
| 94 | + parent_span = get_parent_span(stmt_obj) |
| 95 | + operation_name = _get_operation_name(stmt_obj) |
| 96 | + |
| 97 | + # Start a new span for this query. |
| 98 | + span = g_tracer.start_span(operation_name=operation_name, child_of=parent_span) |
| 99 | + span.set_tag('component', 'sqlalchemy') |
| 100 | + span.set_tag('db.type', 'sql') |
| 101 | + span.set_tag('db.statement', _normalize_stmt(statement)) |
| 102 | + |
| 103 | + stmt_obj._span = span |
| 104 | + |
| 105 | +def _after_cursor_handler(conn, cursor, statement, parameters, context, executemany): |
| 106 | + if context.compiled is None: # PRAGMA |
| 107 | + return |
| 108 | + |
| 109 | + stmt_obj = context.compiled.statement |
| 110 | + span = get_span(stmt_obj) |
| 111 | + if span is None: |
| 112 | + return |
| 113 | + |
| 114 | + span.finish() |
| 115 | + |
| 116 | +def _error_handler(exception_context): |
| 117 | + execution_context = exception_context.execution_context |
| 118 | + stmt_obj = execution_context.compiled.statement |
| 119 | + span = get_span(stmt_obj) |
| 120 | + if span is None: |
| 121 | + return |
| 122 | + |
| 123 | + span.set_tag('error', 'true') |
| 124 | + span.finish() |
| 125 | + |
0 commit comments