| 180 | } |
| 181 | |
| 182 | func TestTraceBatchNormal(t *testing.T) { |
| 183 | t.Parallel() |
| 184 | |
| 185 | tracer := &testTracer{} |
| 186 | |
| 187 | ctr := defaultConnTestRunner |
| 188 | ctr.CreateConfig = func(ctx context.Context, t testing.TB) *pgx.ConnConfig { |
| 189 | config := defaultConnTestRunner.CreateConfig(ctx, t) |
| 190 | config.Tracer = tracer |
| 191 | return config |
| 192 | } |
| 193 | |
| 194 | ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) |
| 195 | defer cancel() |
| 196 | |
| 197 | pgxtest.RunWithQueryExecModes(ctx, t, ctr, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { |
| 198 | traceBatchStartCalled := false |
| 199 | tracer.traceBatchStart = func(ctx context.Context, conn *pgx.Conn, data pgx.TraceBatchStartData) context.Context { |
| 200 | traceBatchStartCalled = true |
| 201 | require.NotNil(t, data.Batch) |
| 202 | require.Equal(t, 2, data.Batch.Len()) |
| 203 | return context.WithValue(ctx, ctxKey("fromTraceBatchStart"), "foo") |
| 204 | } |
| 205 | |
| 206 | traceBatchQueryCalledCount := 0 |
| 207 | tracer.traceBatchQuery = func(ctx context.Context, conn *pgx.Conn, data pgx.TraceBatchQueryData) { |
| 208 | traceBatchQueryCalledCount++ |
| 209 | require.Equal(t, "foo", ctx.Value(ctxKey("fromTraceBatchStart"))) |
| 210 | require.NoError(t, data.Err) |
| 211 | } |
| 212 | |
| 213 | traceBatchEndCalled := false |
| 214 | tracer.traceBatchEnd = func(ctx context.Context, conn *pgx.Conn, data pgx.TraceBatchEndData) { |
| 215 | traceBatchEndCalled = true |
| 216 | require.Equal(t, "foo", ctx.Value(ctxKey("fromTraceBatchStart"))) |
| 217 | require.NoError(t, data.Err) |
| 218 | } |
| 219 | |
| 220 | batch := &pgx.Batch{} |
| 221 | batch.Queue(`select 1`) |
| 222 | batch.Queue(`select 2`) |
| 223 | |
| 224 | br := conn.SendBatch(context.Background(), batch) |
| 225 | require.True(t, traceBatchStartCalled) |
| 226 | |
| 227 | var n int32 |
| 228 | err := br.QueryRow().Scan(&n) |
| 229 | require.NoError(t, err) |
| 230 | require.EqualValues(t, 1, n) |
| 231 | require.EqualValues(t, 1, traceBatchQueryCalledCount) |
| 232 | |
| 233 | err = br.QueryRow().Scan(&n) |
| 234 | require.NoError(t, err) |
| 235 | require.EqualValues(t, 2, n) |
| 236 | require.EqualValues(t, 2, traceBatchQueryCalledCount) |
| 237 | |
| 238 | err = br.Close() |
| 239 | require.NoError(t, err) |