| 294 | } |
| 295 | |
| 296 | func TestTraceBatchErrorWhileReadingResults(t *testing.T) { |
| 297 | t.Parallel() |
| 298 | |
| 299 | tracer := &testTracer{} |
| 300 | |
| 301 | ctr := defaultConnTestRunner |
| 302 | ctr.CreateConfig = func(ctx context.Context, t testing.TB) *pgx.ConnConfig { |
| 303 | config := defaultConnTestRunner.CreateConfig(ctx, t) |
| 304 | config.Tracer = tracer |
| 305 | return config |
| 306 | } |
| 307 | |
| 308 | ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) |
| 309 | defer cancel() |
| 310 | |
| 311 | pgxtest.RunWithQueryExecModes(ctx, t, ctr, []pgx.QueryExecMode{pgx.QueryExecModeSimpleProtocol}, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { |
| 312 | traceBatchStartCalled := false |
| 313 | tracer.traceBatchStart = func(ctx context.Context, conn *pgx.Conn, data pgx.TraceBatchStartData) context.Context { |
| 314 | traceBatchStartCalled = true |
| 315 | require.NotNil(t, data.Batch) |
| 316 | require.Equal(t, 3, data.Batch.Len()) |
| 317 | return context.WithValue(ctx, ctxKey("fromTraceBatchStart"), "foo") |
| 318 | } |
| 319 | |
| 320 | traceBatchQueryCalledCount := 0 |
| 321 | tracer.traceBatchQuery = func(ctx context.Context, conn *pgx.Conn, data pgx.TraceBatchQueryData) { |
| 322 | traceBatchQueryCalledCount++ |
| 323 | require.Equal(t, "foo", ctx.Value(ctxKey("fromTraceBatchStart"))) |
| 324 | if traceBatchQueryCalledCount == 2 { |
| 325 | require.Error(t, data.Err) |
| 326 | } else { |
| 327 | require.NoError(t, data.Err) |
| 328 | } |
| 329 | } |
| 330 | |
| 331 | traceBatchEndCalled := false |
| 332 | tracer.traceBatchEnd = func(ctx context.Context, conn *pgx.Conn, data pgx.TraceBatchEndData) { |
| 333 | traceBatchEndCalled = true |
| 334 | require.Equal(t, "foo", ctx.Value(ctxKey("fromTraceBatchStart"))) |
| 335 | require.Error(t, data.Err) |
| 336 | } |
| 337 | |
| 338 | batch := &pgx.Batch{} |
| 339 | batch.Queue(`select 1`) |
| 340 | batch.Queue(`select 2/n-2 from generate_series(0,10) n`) |
| 341 | batch.Queue(`select 3`) |
| 342 | |
| 343 | br := conn.SendBatch(context.Background(), batch) |
| 344 | require.True(t, traceBatchStartCalled) |
| 345 | |
| 346 | commandTag, err := br.Exec() |
| 347 | require.NoError(t, err) |
| 348 | require.Equal(t, "SELECT 1", commandTag.String()) |
| 349 | |
| 350 | commandTag, err = br.Exec() |
| 351 | require.Error(t, err) |
| 352 | require.Equal(t, "", commandTag.String()) |
| 353 | |